I have some PyTorch training code that uses activation checkpointing, which I’m attempting to port to Habana Gaudi 2 machines.
Original code
The original code looks like this:
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import apply_activation_checkpointing

apply_activation_checkpointing(
    model,
    check_fn=lambda x: isinstance(x, BasicLayer3D),
)
My attempt to port
and now I’m attempting to use the Habana checkpointing wrapper from DeepSpeed:
apply_activation_checkpointing(
    model,
    check_fn=lambda x: isinstance(x, BasicLayer3D),
    checkpoint_wrapper_fn=habana_checkpoint_wrapper,
)
where habana_checkpoint_wrapper is identical to torch.distributed.algorithms._checkpoint.checkpoint_wrapper.checkpoint_wrapper, except that I set checkpoint_fn = deepspeed.runtime.activation_checkpointing.checkpointing.checkpoint.
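To make the substitution concrete, here is a minimal sketch of how such a wrapper can be built with functools.partial. The two functions below are self-contained stand-ins for the real torch checkpoint_wrapper and the DeepSpeed checkpoint function, not the actual library code:

```python
from functools import partial

def checkpoint_wrapper(module, checkpoint_fn=None):
    # Stand-in for torch's checkpoint_wrapper: the real one returns a
    # module whose forward() routes through checkpoint_fn.
    return ("wrapped", module, checkpoint_fn)

def deepspeed_checkpoint(function, *args):
    # Stand-in for deepspeed.runtime.activation_checkpointing.checkpointing.checkpoint.
    return function(*args)

# habana_checkpoint_wrapper: torch's wrapper with checkpoint_fn pinned
# to the DeepSpeed checkpoint implementation.
habana_checkpoint_wrapper = partial(checkpoint_wrapper, checkpoint_fn=deepspeed_checkpoint)
```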
This uses the Habana DeepSpeed wrapper, all running on a fresh install of PyTorch with Habana support on Intel ITDC.
Error
I get the following error:
File "/home/sdp/.venv/lib/python3.10/site-packages/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py", line 169, in forward
    return self.checkpoint_fn(  # type: ignore[misc]
File "/home/sdp/.venv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 600, in _fn
    return fn(*args, **kwargs)
TypeError: checkpoint() got an unexpected keyword argument 'rollout_step'
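As I read the traceback, the wrapper's forward() forwards keyword arguments to the configured checkpoint_fn, and the DeepSpeed function's signature apparently has no rollout_step parameter. A minimal pure-Python illustration of that failure mode (the functions are placeholders, not the real torch/DeepSpeed code):

```python
def checkpoint(function, *args):
    # Stand-in for a checkpoint implementation whose signature does not
    # accept a 'rollout_step' keyword argument.
    return function(*args)

def forward(checkpoint_fn, fn, *args, **checkpoint_kwargs):
    # Stand-in for the wrapper's forward(), which passes extra keyword
    # arguments through to checkpoint_fn unchanged.
    return checkpoint_fn(fn, *args, **checkpoint_kwargs)

try:
    forward(checkpoint, lambda x: x + 1, 1, rollout_step=0)
except TypeError as e:
    print(e)  # checkpoint() got an unexpected keyword argument 'rollout_step'
```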