Allow flexible and easy to configure HSDP #19502

Open
Liyang90 opened this issue Feb 20, 2024 · 8 comments · May be fixed by #19504
Labels
discussion In a discussion stage feature Is an improvement or enhancement strategy: fsdp Fully Sharded Data Parallel

Comments

@Liyang90
Contributor

Liyang90 commented Feb 20, 2024

Description & Motivation

The FSDPStrategy can use a hybrid sharding strategy to shard across smaller sets of ranks within the global dist group. However, it is not flexible enough to let users easily specify the sharding scale.

Pitch

The FSDPStrategy can use a hybrid sharding strategy to shard across smaller sets of ranks within the global dist group. Currently there are two paths to use it in Lightning:

  1. Specify sharding_strategy as one of the hybrid sharding strategies. This shards within one node and replicates across nodes.
  2. Specify sharding_strategy as one of the hybrid sharding strategies, and provide process_group as a kwarg to FSDPStrategy. This lets the user specify how large the sharding scale is. However, it is not easy for users to insert torch dist group creation code and prepare the process_group ahead of time, because Lightning handles torch dist init_process_group automatically in the Trainer or the Fabric launcher (see the sketch at the end of this pitch for roughly what that manual preparation involves).

So I'm looking forward to an easier way to use HSDP within Lightning, like:
FSDPStrategy(..., sharding_strategy="HYBRID_SHARD", fsdp_size=16)
to easily shard at the specified scale and let Lightning handle the process_group preparation for the PyTorch FSDP wrapper.
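
For reference, here is a hedged sketch of roughly what the manual process_group preparation in path 2 involves today (the helper name and layout are illustrative, not a Lightning or PyTorch API). FSDP's hybrid strategies accept process_group as a (sharding group, replication group) pair, and all of this has to run after the default process group has been initialized, which Lightning does on the user's behalf:

import torch.distributed as dist

def build_hybrid_process_groups(fsdp_size: int):
    # Build the (sharding, replication) process-group pair for HYBRID_SHARD.
    # Must be called on every rank, after init_process_group has run.
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    assert world_size % fsdp_size == 0, "world size must be divisible by fsdp_size"

    shard_group = None
    for start in range(0, world_size, fsdp_size):
        ranks = list(range(start, start + fsdp_size))
        group = dist.new_group(ranks)  # collective: every rank calls it
        if rank in ranks:
            shard_group = group

    replicate_group = None
    for offset in range(fsdp_size):
        ranks = list(range(offset, world_size, fsdp_size))
        group = dist.new_group(ranks)
        if rank in ranks:
            replicate_group = group

    return shard_group, replicate_group

# strategy = FSDPStrategy(sharding_strategy="HYBRID_SHARD",
#                         process_group=build_hybrid_process_groups(16))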

Alternatives

No response

Additional context

No response

cc @Borda @awaelchli @carmocca

@Liyang90 Liyang90 added feature Is an improvement or enhancement needs triage Waiting to be triaged by maintainers labels Feb 20, 2024
@Liyang90 Liyang90 linked a pull request Feb 20, 2024 that will close this issue
@awaelchli
Member

awaelchli commented Feb 21, 2024

I agree we need a better way to specify this. PyTorch 2.2 introduced the device mesh, so we should probably use that to specify the size, rather than having the user construct the process group matrix themselves.

Having an argument like you suggest could work, but it might be confusing to have this for anything other than hybrid sharding.

@awaelchli awaelchli added discussion In a discussion stage strategy: fsdp Fully Sharded Data Parallel and removed needs triage Waiting to be triaged by maintainers labels Feb 21, 2024
@awaelchli awaelchli added this to the 2.3 milestone Feb 21, 2024
@awaelchli
Member

awaelchli commented Feb 21, 2024

Since passing in a device mesh already works

from torch.distributed.device_mesh import init_device_mesh
mesh = init_device_mesh("cuda", (2, 4))

strategy = FSDPStrategy(..., device_mesh=mesh)

I suggest that we simplify this by allowing the user to set a tuple device_mesh=(2, 4) (in addition to DeviceMesh) and internally we initialize the device mesh for them if it's a tuple. Then we don't need to introduce a new argument.
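
As a hedged illustration of the suggested call (the tuple form is the proposal here, not the current API):

from lightning.pytorch.strategies import FSDPStrategy

# Proposed: pass a plain shape and let Lightning call init_device_mesh internally.
strategy = FSDPStrategy(sharding_strategy="HYBRID_SHARD", device_mesh=(2, 4))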

@Liyang90
Contributor Author

Liyang90 commented Feb 22, 2024

Since passing in a device mesh already works

from torch.distributed.device_mesh import init_device_mesh
mesh = init_device_mesh("cuda", (2, 4))

strategy = FSDPStrategy(..., device_mesh=mesh)

I suggest that we simplify this by allowing the user to set a tuple device_mesh=(2, 4) (in addition to DeviceMesh) and internally we initialize the device mesh for them if it's a tuple. Then we don't need to introduce a new argument.

This seems reasonable as well, and simpler. In the PyTorch code, process_group and device_mesh end up being handled in the same function: https://github.com/pytorch/pytorch/blob/1d14adfa66e2ae437253eebe223710588648eee7/torch/distributed/fsdp/_init_utils.py#L152C5-L152C47

But the PyTorch FSDP docs do not document the device_mesh argument very well, and users would need to know what the numbers in the device_mesh tuple mean (which one is the FSDP size and which one is the DDP size).
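
One way to make the tuple semantics explicit is to name the mesh dimensions; in the _init_utils code linked above, mesh dim 0 appears to map to the replication (DDP) group and dim 1 to the sharding (FSDP) group, so a hedged example would be:

from torch.distributed.device_mesh import init_device_mesh

# 2 replicas, each sharding across 4 ranks (8 ranks in total).
mesh = init_device_mesh("cuda", (2, 4), mesh_dim_names=("replicate", "shard"))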

@awaelchli
Member

Great to hear you like it. Would you be interested in drafting it? It would be relatively straightforward:

  1. Store the device_mesh argument as an attribute in the strategy
  2. In the strategy's setup, initialize it and pass it to the FSDP wrapper (a rough sketch follows below):
    module = FullyShardedDataParallel(

But the PyTorch FSDP docs do not document the device_mesh argument very well, and users would need to know what the numbers in the device_mesh tuple mean (which one is the FSDP size and which one is the DDP size).

Yes, agreed. This is typical for PyTorch; their distributed features are always very short on docs. I think we would want to document this well on our side (both the API and the user guide). We already have a relatively thorough guide: https://lightning.ai/docs/fabric/stable/advanced/model_parallel/fsdp.html
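
A minimal sketch of the two steps above, assuming a stand-in strategy class (attribute and method names are illustrative, not the actual Lightning internals):

from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import FullyShardedDataParallel

class HypotheticalFSDPStrategy:
    def __init__(self, device_mesh=None, **fsdp_kwargs):
        # 1. Store the device_mesh argument (DeviceMesh or tuple) as an attribute.
        self._device_mesh = device_mesh
        self._fsdp_kwargs = fsdp_kwargs

    def setup_module(self, module):
        # 2. By this point the process group exists, so a plain tuple can be
        #    turned into a DeviceMesh before handing it to the FSDP wrapper.
        if isinstance(self._device_mesh, tuple):
            self._device_mesh = init_device_mesh("cuda", self._device_mesh)
        return FullyShardedDataParallel(
            module, device_mesh=self._device_mesh, **self._fsdp_kwargs
        )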

@carmocca
Member

There's a device_mesh recipe available at https://pytorch.org/tutorials/recipes/distributed_device_mesh.html

@Liyang90
Contributor Author

Great to hear you like it. Would you be interested in drafting it? It would be relatively straightforward:

  1. Store the device_mesh argument as an attribute in the strategy
  2. In the strategy's setup, initialize it and pass it to the FSDP wrapper:
    module = FullyShardedDataParallel(

But the PyTorch FSDP docs do not document the device_mesh argument very well, and users would need to know what the numbers in the device_mesh tuple mean (which one is the FSDP size and which one is the DDP size).

Yes, agreed. This is typical for PyTorch; their distributed features are always very short on docs. I think we would want to document this well on our side (both the API and the user guide). We already have a relatively thorough guide: https://lightning.ai/docs/fabric/stable/advanced/model_parallel/fsdp.html

Sure. I will iterate on the draft PR above when I have some bandwidth.

@Liyang90
Contributor Author

Liyang90 commented Mar 5, 2024

I updated the PR #19504 as suggested.

