FSDP checkpoint saving raises internal deprecation warnings #119802

Open · carmocca opened this issue Feb 13, 2024 · 2 comments

Labels: module: distributed_checkpoint, oncall: distributed, triaged

Comments

carmocca (Contributor) commented Feb 13, 2024

🐛 Describe the bug

Saving an FSDP sharded checkpoint with torch.distributed.checkpoint.save raises two internal deprecation warnings. The messages are:

/home/carlos/nightly-env/lib/python3.10/site-packages/torch/distributed/_shard/sharded_tensor/api.py:1132: UserWarning: Please use DTensor instead and we are deprecating ShardedTensor.
  warnings.warn(DEPRECATE_MSG)

/home/carlos/nightly-env/lib/python3.10/site-packages/torch/distributed/checkpoint/filesystem.py:148: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly.  To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()
  if tensor.storage().size() != tensor.numel():
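
For the second warning, the deprecation points at the tensor.storage() call inside torch/distributed/checkpoint/filesystem.py. Below is a minimal sketch of what an untyped-storage-based check could look like; this is an assumption for illustration, not the actual upstream fix, and it relies on UntypedStorage.size() reporting bytes rather than elements:

# Hypothetical untyped-storage version of the check in filesystem.py
# (sketch only, not the actual PyTorch change).
# UntypedStorage.size() is in bytes, so compare against numel() * element_size().
if tensor.untyped_storage().size() != tensor.numel() * tensor.element_size():
    ...  # same handling as in the current code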

Minimal repro:

import os
import torch.cuda
import torch.nn as nn
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.distributed.fsdp import FullyShardedDataParallel

def get_sharded_state_dict_context(module):
    from torch.distributed.fsdp.api import ShardedOptimStateDictConfig, ShardedStateDictConfig, StateDictType

    # Configure FSDP to return a sharded (per-rank) state dict, offloaded to CPU.
    state_dict_config = ShardedStateDictConfig(offload_to_cpu=True)
    optim_state_dict_config = ShardedOptimStateDictConfig(offload_to_cpu=True)
    state_dict_type_context = FullyShardedDataParallel.state_dict_type(
        module=module,
        state_dict_type=StateDictType.SHARDED_STATE_DICT,
        state_dict_config=state_dict_config,
        optim_state_dict_config=optim_state_dict_config,
    )
    return state_dict_type_context  # type: ignore[return-value]

def work(rank):
    # Each spawned process joins a 2-rank NCCL process group.
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "1234"
    dist.init_process_group("nccl", world_size=2, rank=rank)
    torch.cuda.set_device(rank)
    device = torch.device("cuda", rank)

    model = nn.Linear(100, 50).to(device)
    model = FullyShardedDataParallel(model)
    x = torch.rand(2, 100, device=device)

    y = model(x)

    from torch.distributed.checkpoint import save
    with get_sharded_state_dict_context(model):
        state = {"model": model.state_dict()}
    # Saving the sharded state dict with torch.distributed.checkpoint triggers the warnings above.
    save(state, checkpoint_id="fsdp_model.pt")

def run():
    mp.spawn(work, nprocs=2)

if __name__ == "__main__":
    run()

First reported in Lightning-AI/pytorch-lightning#19462 (comment)

Versions

2.3.0.dev20240212+cu121

cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @fegin @XilunWu @wanchaol @fduwjj @wz337 @tianyu-l @wconstab @yf225 @LucasLLC

colesbury added the oncall: distributed label Feb 14, 2024
fegin added the triaged and module: distributed_checkpoint labels Feb 15, 2024
fegin (Contributor) commented Feb 15, 2024

@LucasLLC We should fix the filesystem.py warning. @carmocca We are switching to DTensor and would encourage users to move to DTensor as well; init_device_mesh was released in beta in 2.2. cc @wz337

fegin assigned wz337 and fegin Feb 15, 2024
wz337 (Contributor) commented Feb 15, 2024

@carmocca If you are interested in finding out more about DTensor, here is a getting-started page: https://pytorch.org/tutorials/recipes/distributed_device_mesh.html
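
For context, here is a minimal sketch of adapting the repro above to a DeviceMesh, based on that tutorial; this is my own illustration (assuming 2 GPUs), not a change confirmed in this thread. As I understand it, passing a device_mesh to FSDP should make the sharded state dict use DTensor instead of the deprecated ShardedTensor:

# Sketch only: build a 1-D DeviceMesh over 2 ranks and hand it to FSDP
# so sharded state dicts are DTensor-based rather than ShardedTensor-based.
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import FullyShardedDataParallel

mesh = init_device_mesh("cuda", (2,))
model = FullyShardedDataParallel(model, device_mesh=mesh)  # instead of FullyShardedDataParallel(model)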
