PyTorch Distributed Load Updates or Returns state_dict #125096

Open
mvpatel2000 opened this issue Apr 27, 2024 · 5 comments
Labels
module: distributed_checkpoint · oncall: distributed (Add this issue/PR to distributed oncall triage queue) · triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

@mvpatel2000
Contributor

mvpatel2000 commented Apr 27, 2024

🚀 The feature, motivation and pitch

Torch distributed checkpoint's load_state_dict (https://github.com/pytorch/pytorch/blob/main/torch/distributed/checkpoint/state_dict_loader.py#L20)
updates the passed-in state_dict (and returns it). This function is deprecated in torch 2.3 in favor of load (https://github.com/pytorch/pytorch/blob/main/torch/distributed/checkpoint/state_dict_loader.py#L48), which neither returns nor updates the passed-in state_dict. Instead, it only calls load_state_dict on the Stateful elements of the given state_dict.
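To make the difference concrete, here is a deliberately simplified, pure-Python model of the flow described above. It is a sketch, not the real implementation: `Stateful` is a local Protocol that only mirrors the `state_dict()`/`load_state_dict()` pair, the checkpoint is assumed to be a plain dict of already-loaded values, and the swap-in step is a hypothetical stand-in for DCP's storage reader and planner (no tensors or distributed I/O are modeled). Only the Stateful element receives its loaded value; the plain top-level entry in the caller's dict keeps its old value.

```python
from typing import Any, Dict, Protocol, runtime_checkable


@runtime_checkable
class Stateful(Protocol):
    # Local stand-in: any object with state_dict() and load_state_dict()
    # passes isinstance(obj, Stateful).
    def state_dict(self) -> Dict[str, Any]: ...

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None: ...


def load_without_write_back(state_dict: Dict[str, Any], checkpoint: Dict[str, Any]) -> None:
    # Simplified model of the current load(); not the real DCP implementation.
    keys = sorted(state_dict.keys())
    # Named after the variable referenced in the issue.
    statetful_sd: Dict[str, Any] = {}
    for key in keys:
        elem = state_dict[key]
        statetful_sd[key] = elem.state_dict() if isinstance(elem, Stateful) else elem
    # Hypothetical stand-in for the storage read: swap in checkpointed values.
    for key in keys:
        if key in checkpoint:
            statetful_sd[key] = checkpoint[key]
    # Only Stateful elements receive the loaded values; nothing is written
    # back into `state_dict`, so plain entries keep their old values.
    for key in keys:
        elem = state_dict[key]
        if isinstance(elem, Stateful):
            elem.load_state_dict(statetful_sd[key])


class Counter:
    # Toy Stateful-style object.
    def __init__(self) -> None:
        self.count = 0

    def state_dict(self) -> Dict[str, Any]:
        return {"count": self.count}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        self.count = state_dict["count"]


sd = {"counter": Counter(), "rng": {"seed": 0}}
load_without_write_back(sd, {"counter": {"count": 7}, "rng": {"seed": 1234}})
print(sd["counter"].count)  # 7
print(sd["rng"])  # {'seed': 0}: the loaded value never reaches the caller
```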

Unfortunately, this new API is quite limiting. For example, the state_dict that Composer passes for checkpointing also stores various RNG tensors in a dict for determinism. To use the new API, we have to rewrap all of these in a Stateful class, which is a largely pointless abstraction. We would prefer to receive the loaded state_dict and then manually call load_state_dict on the appropriate sub-items.
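The rewrapping workaround looks roughly like the following sketch. `RNGState` is a hypothetical wrapper, not Composer's actual code, and Python's `random` module stands in for the RNG tensors; the class exists only to expose the `state_dict()`/`load_state_dict()` pair that load acts on.

```python
import random
from typing import Any, Dict


class RNGState:
    # Hypothetical wrapper: adds no behavior of its own, it only makes the
    # global `random` state look like a Stateful-style object.
    def state_dict(self) -> Dict[str, Any]:
        return {"python_rng": random.getstate()}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        random.setstate(state_dict["python_rng"])


rng = RNGState()
saved = rng.state_dict()
first = random.random()
rng.load_state_dict(saved)  # rewind the generator to the saved state
assert random.random() == first
```

Every additional piece of plain state needs a similar wrapper, which is the boilerplate this request aims to avoid.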

Can we modify load to update the passed-in state_dict? This would entail adding:

state_dict[key] = statetful_sd[key]

after https://github.com/pytorch/pytorch/blob/main/torch/distributed/checkpoint/state_dict_loader.py#L172-L177
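A minimal pure-Python model of the proposed change, under the same simplifying assumptions as a toy: `Stateful` is a local Protocol, the checkpoint is a plain dict, and the swap-in loop is a hypothetical stand-in for the storage read. The only addition over the existing loop is the `state_dict[key] = statetful_sd[key]` assignment, which lets the caller pass plain dicts and then call load_state_dict on sub-items manually, as the request describes.

```python
from typing import Any, Dict, Protocol, runtime_checkable


@runtime_checkable
class Stateful(Protocol):
    # Local stand-in for the Stateful interface referenced in the issue.
    def state_dict(self) -> Dict[str, Any]: ...

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None: ...


def load_with_write_back(state_dict: Dict[str, Any], checkpoint: Dict[str, Any]) -> None:
    # Simplified model of load() plus the proposed assignment; not DCP code.
    keys = sorted(state_dict.keys())
    statetful_sd: Dict[str, Any] = {}
    for key in keys:
        elem = state_dict[key]
        statetful_sd[key] = elem.state_dict() if isinstance(elem, Stateful) else elem
    # Hypothetical stand-in for the storage read: swap in checkpointed values.
    for key in keys:
        if key in checkpoint:
            statetful_sd[key] = checkpoint[key]
    for key in keys:
        elem = state_dict[key]
        if isinstance(elem, Stateful):
            elem.load_state_dict(statetful_sd[key])
        state_dict[key] = statetful_sd[key]  # the proposed write-back


class Counter:
    # Toy object whose state is loaded manually below.
    def __init__(self) -> None:
        self.count = 0

    def state_dict(self) -> Dict[str, Any]:
        return {"count": self.count}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        self.count = state_dict["count"]


counter = Counter()
# Pass plain dicts only, then pick out the loaded sub-items manually.
sd = {"counter": counter.state_dict(), "rng": {"seed": 0}}
load_with_write_back(sd, {"counter": {"count": 7}, "rng": {"seed": 1234}})
counter.load_state_dict(sd["counter"])
print(counter.count, sd["rng"])  # 7 {'seed': 1234}
```

In this sketch the assignment runs for every key, so a Stateful entry in the caller's dict would also be replaced by its loaded dict after load_state_dict has been called on it.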

Alternatives

No response

Additional context

No response

cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @fegin @XilunWu @wanchaol @fduwjj @wz337 @tianyu-l @wconstab @yf225 @chauhang @d4l3k @LucasLLC

@mvpatel2000
Contributor Author

@pytorchbot label "oncall: distributed"

pytorch-bot added the oncall: distributed label Apr 27, 2024
@fegin
Contributor

fegin commented May 1, 2024

@LucasLLC Any thoughts on this request?

LucasLLC added the triaged label May 7, 2024
@LucasLLC
Contributor

LucasLLC commented May 7, 2024

Sorry for the delay! Taking a look

@LucasLLC
Contributor

LucasLLC commented May 7, 2024

@mvpatel2000 this makes a lot of sense to me. Would you like to submit a PR or should I?

@mvpatel2000
Contributor Author


@LucasLLC if you can submit a PR that would be awesome :)

Projects
None yet
Development

No branches or pull requests

4 participants