[WIP] Use torch 2.2 distributed checkpoint APIs for FSDP #19497

Draft · carmocca wants to merge 5 commits into master

Conversation

carmocca (Member) commented on Feb 19, 2024

What does this PR do?

Fixes #19462

Resources

TODO:

  • Test with different checkpoints. PyTorch docs say:

    There are no guarantees of backwards compatibility across PyTorch versions for saved state_dicts.

  • Apply the same changes to the Trainer
  • Run tests with _TORCH_GREATER_EQUAL_2_2 = False, since CI only tests against PyTorch 2.2

📚 Documentation preview 📚: https://pytorch-lightning--19497.org.readthedocs.build/en/19497/

carmocca added the refactor and strategy: fsdp (Fully Sharded Data Parallel) labels on Feb 19, 2024
carmocca self-assigned this on Feb 19, 2024
github-actions bot added the fabric (lightning.fabric.Fabric) label on Feb 19, 2024
carmocca (Member, Author) left a comment:

Currently blocked by pytorch/pytorch#119800 (comment)

if _TORCH_GREATER_EQUAL_2_2:
    from torch.distributed.checkpoint.state_dict import StateDictOptions, set_model_state_dict

# `cpu_offload` is disabled because, when used with `full_state_dict`, only rank 0 loads the state dict
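
For illustration only, a minimal sketch of what the new-API load path could look like (not the PR's actual code; `module` and `full_sd` are assumed names):

from torch.distributed.checkpoint.state_dict import StateDictOptions, set_model_state_dict

def _load_full_state_dict_sketch(module, full_sd):
    # Hypothetical helper: load a full (unsharded) state dict into an FSDP-wrapped module.
    # `cpu_offload` stays False here, per the note above that enabling it together with
    # `full_state_dict` means only rank 0 loads the state dict.
    options = StateDictOptions(full_state_dict=True, cpu_offload=False)
    set_model_state_dict(module, full_sd, options=options)
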
carmocca (Member, Author) left a comment:

Notice that the other path sets rank0_only=False. I asked if this could be configurable in pytorch/pytorch#112837 (comment)
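
For comparison, a rough sketch of that other, pre-2.2 path (illustrative names, not the exact Fabric code); `FullStateDictConfig` exposes the `rank0_only` flag that `StateDictOptions` does not:

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import FullStateDictConfig, StateDictType

def _legacy_load_full_state_dict_sketch(fsdp_module, checkpoint_path):
    # Hypothetical helper: every rank materializes the full state dict (rank0_only=False)
    # and loads it while the FULL_STATE_DICT context is active.
    config = FullStateDictConfig(offload_to_cpu=True, rank0_only=False)
    with FSDP.state_dict_type(fsdp_module, StateDictType.FULL_STATE_DICT, config):
        fsdp_module.load_state_dict(torch.load(checkpoint_path, map_location="cpu"))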

@@ -440,6 +439,7 @@ def save_checkpoint(
         )
         if filter is not None and self._state_dict_type == "sharded":
             # https://github.com/pytorch/pytorch/issues/105379
+            # FIXME: revisit support with new APIs
carmocca (Member, Author) left a comment:

Reminder to myself
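
As context for "revisit support with new APIs", a hedged sketch of what a sharded save could look like with the 2.2 APIs (`module` and `path` are illustrative; how a user-provided `filter` would be applied is exactly what the FIXME leaves open):

import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.state_dict import StateDictOptions, get_model_state_dict

def _save_sharded_checkpoint_sketch(module, path):
    # Hypothetical helper: with non-full options the state dict keeps its sharded tensors,
    # and dcp.save writes them out as a checkpoint directory with per-rank files plus metadata.
    sharded_sd = get_model_state_dict(module, options=StateDictOptions(full_state_dict=False))
    dcp.save({"model": sharded_sd}, checkpoint_id=path)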

Labels
fabric (lightning.fabric.Fabric) · refactor · strategy: fsdp (Fully Sharded Data Parallel)
Projects
None yet
Development

Successfully merging this pull request may close these issues.

FSDP checkpointing uses deprecated APIs with PyTorch 2.2
1 participant