Activation checkpointing modules with kwargs in forward

I have some PyTorch training code that uses activation checkpointing, and I’m attempting to port it to Habana Gaudi 2 machines.

Original code

The original code looks like this:

from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import apply_activation_checkpointing
apply_activation_checkpointing(
    model,
    check_fn=lambda x: isinstance(x, BasicLayer3D)
)

My attempt to port

and now I’m attempting to use the Habana checkpointing wrapper from DeepSpeed:

apply_activation_checkpointing(
    model,
    check_fn=lambda x: isinstance(x, BasicLayer3D),
    checkpoint_wrapper_fn=habana_checkpoint_wrapper,
)

where habana_checkpoint_wrapper is identical to torch.distributed.algorithms._checkpoint.checkpoint_wrapper.checkpoint_wrapper, except that I set checkpoint_fn = deepspeed.runtime.activation_checkpointing.checkpointing.checkpoint.
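
For reference, here is a minimal sketch of what I mean by habana_checkpoint_wrapper, assuming a functools.partial over the stock PyTorch wrapper (the DeepSpeed import path is the one mentioned above):

import functools
import deepspeed
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import checkpoint_wrapper

# Reuse the stock PyTorch wrapper, but swap in DeepSpeed's checkpoint
# implementation as the checkpoint_fn it calls in forward.
habana_checkpoint_wrapper = functools.partial(
    checkpoint_wrapper,
    checkpoint_fn=deepspeed.runtime.activation_checkpointing.checkpointing.checkpoint,
)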

This uses the Habana DeepSpeed fork, all running on a fresh install of PyTorch with Habana support on Intel ITDC.

Error

I get the following error:

  File "/home/sdp/.venv/lib/python3.10/site-packages/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py", line 169, in forward
    return self.checkpoint_fn(  # type: ignore[misc]
  File "/home/sdp/.venv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 600, in _fn
    return fn(*args, **kwargs)
TypeError: checkpoint() got an unexpected keyword argument 'rollout_step'

The traceback suggests that you are using torch.compile (note the torch/_dynamo/eval_frame.py frame). Under torch.compile, it is recommended to use PyTorch native activation checkpointing, since torch.compile treats it in a special way. Non-native implementations are not traced correctly by torch.compile; you may need to disable the compiler for those regions, e.g. via the @torch.compiler.disable decorator.
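
As a sketch of that suggestion (assuming the model and BasicLayer3D from the question), you could keep the native, non-reentrant PyTorch checkpointing so torch.compile can trace it:

import functools
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    apply_activation_checkpointing,
    checkpoint_wrapper,
    CheckpointImpl,
)

# Use the native (non-reentrant) checkpoint implementation, which
# torch.compile handles specially, instead of a custom checkpoint_fn.
apply_activation_checkpointing(
    model,
    checkpoint_wrapper_fn=functools.partial(
        checkpoint_wrapper, checkpoint_impl=CheckpointImpl.NO_REENTRANT
    ),
    check_fn=lambda x: isinstance(x, BasicLayer3D),
)

If you do need the DeepSpeed checkpoint implementation, the alternative mentioned above is to keep it out of the compiled region, e.g. by decorating the relevant forward with @torch.compiler.disable.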