
Flexible and easy-to-use HSDP setting #19504

Open · wants to merge 23 commits into master from hybrid_fsdp_stage
Changes from 12 commits
Commits (23)
312caee
Add fsdp_size for FSDPStrategy
Liyang90 Jan 17, 2024
45c1123
fix import
Liyang90 Jan 17, 2024
0ddc51d
Merge branch 'Lightning-AI:master' into hybrid_fsdp_stage
Liyang90 Feb 20, 2024
c952536
Add flexible HSDP in fabric
Liyang90 Feb 20, 2024
8fc2404
minor update
Liyang90 Feb 20, 2024
da3900f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 20, 2024
8311be1
Merge branch 'Lightning-AI:master' into hybrid_fsdp_stage
Liyang90 Mar 1, 2024
d1d719a
Use device_mesh arg to set flexible HSDP with a Tuple
Liyang90 Mar 4, 2024
3315893
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 5, 2024
4652b74
Merge branch 'Lightning-AI:master' into hybrid_fsdp_stage
Liyang90 Mar 5, 2024
4049f60
minor fix
Liyang90 Mar 5, 2024
9c14afe
add simple docs
awaelchli Mar 8, 2024
1f2c3ff
correct doc string
Liyang90 Apr 1, 2024
07f7c1b
set as explicit args in FSDPStrategy
Liyang90 Apr 4, 2024
2ab0423
Merge branch 'Lightning-AI:master' into hybrid_fsdp_stage
Liyang90 Apr 4, 2024
899e032
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 4, 2024
4259df2
update fsdp tests
Liyang90 Apr 18, 2024
dbe22f3
Type check error
Liyang90 Apr 18, 2024
2320a4e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 18, 2024
b0d4783
merge
Liyang90 Apr 18, 2024
9d7dfbe
type check
Liyang90 Apr 18, 2024
483f745
Merge branch 'master' into hybrid_fsdp_stage
Liyang90 May 16, 2024
ba0b10b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 16, 2024
12 changes: 11 additions & 1 deletion src/lightning/fabric/strategies/fsdp.py
@@ -125,10 +125,14 @@ class FSDPStrategy(ParallelStrategy, _Sharded):
- ``"SHARD_GRAD_OP"``: Shards gradients and optimizer states only. Model parameters get replicated.
- ``"NO_SHARD"``: No sharding (identical to regular DDP).
- ``"HYBRID_SHARD"``: Shards model parameters, gradients, and optimizer states within a single machine, but
replicates across machines.
replicates across machines. See also the `device_mesh` parameter below.

Also accepts a :class:`torch.distributed.fsdp.ShardingStrategy` enum value.

device_mesh: A tuple `(replication size, sharding size)` that defines over how many devices to shard and
replicate the model. The product of the two numbers must equal the world size. Only valid in combination
with the `HYBRID_SHARD` sharding strategy.

state_dict_type: The format in which the state of the model and optimizers gets saved into the checkpoint.

- ``"full"``: The full weights and optimizer states get assembled on rank 0 and saved to a single file.
@@ -253,6 +257,12 @@ def setup_environment(self) -> None:
super().setup_environment()
self._setup_distributed()

# if `device_mesh` in `_fsdp_kwargs` was provided as a tuple, convert it into a `DeviceMesh` object here
if isinstance(self._fsdp_kwargs.get("device_mesh"), tuple):
from torch.distributed.device_mesh import init_device_mesh

self._fsdp_kwargs["device_mesh"] = init_device_mesh("cuda", self._fsdp_kwargs["device_mesh"])
Member:
Since the tuple specification is a feature of Lightning, we should list the device_mesh parameter explicitly in the init args (see the docstring I added). The kwargs are for things that we pass down to FSDP directly.

So I suggest setting `self.device_mesh` and updating this attribute here to the actual DeviceMesh 😃

Liyang90 (Contributor, Author) · Mar 13, 2024:

If `device_mesh` is separated from the kwargs, we will fail the check in `_init_sharding_strategy` inside `__init__`.

Liyang90 (Contributor, Author):
Changed now. @awaelchli could you take a look again?
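
A minimal sketch of the suggestion above, using a hypothetical `FSDPStrategySketch` stand-in (the class name, attribute layout, and the simplified `setup_environment` are assumptions for illustration, not the merged implementation):

```python
from typing import Optional, Tuple, Union

from torch.distributed.device_mesh import DeviceMesh, init_device_mesh


class FSDPStrategySketch:
    """Hypothetical, simplified stand-in for FSDPStrategy; not the merged code."""

    def __init__(self, device_mesh: Optional[Union[Tuple[int, int], DeviceMesh]] = None, **kwargs):
        # Keep the user-provided tuple (or DeviceMesh) on the strategy itself,
        # instead of burying it in the kwargs that get forwarded to FSDP.
        self.device_mesh = device_mesh
        self.kwargs = kwargs

    def setup_environment(self) -> None:
        # ... distributed initialization would happen here ...
        if isinstance(self.device_mesh, tuple):
            # Materialize the tuple into a real DeviceMesh once the process
            # group exists, then hand it to FSDP through the kwargs.
            self.device_mesh = init_device_mesh("cuda", self.device_mesh)
        if self.device_mesh is not None:
            self.kwargs["device_mesh"] = self.device_mesh
```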


@override
def setup_module_and_optimizers(
self, module: Module, optimizers: List[Optimizer]
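Usage sketch for the Fabric strategy above (illustrative only, not part of the diff; the 2 nodes × 4 GPUs launch configuration is an assumption):

```python
# Illustrative only: 2 nodes x 4 GPUs (world size 8).
from lightning.fabric import Fabric
from lightning.fabric.strategies import FSDPStrategy

strategy = FSDPStrategy(
    sharding_strategy="HYBRID_SHARD",
    # per the docstring: (replication size, sharding size);
    # 2 replica groups, each sharded across 4 devices, and 2 * 4 == world size.
    device_mesh=(2, 4),
)
fabric = Fabric(accelerator="cuda", devices=4, num_nodes=2, strategy=strategy)
fabric.launch()
```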
12 changes: 11 additions & 1 deletion src/lightning/pytorch/strategies/fsdp.py
@@ -125,10 +125,14 @@ class FSDPStrategy(ParallelStrategy):
- ``"SHARD_GRAD_OP"``: Shards gradients and optimizer states only. Model parameters get replicated.
- ``"NO_SHARD"``: No sharding (identical to regular DDP).
- ``"HYBRID_SHARD"``: Shards model parameters, gradients, and optimizer states within a single machine, but
replicates across machines.
replicates across machines. See also the `device_mesh` parameter below.

Also accepts a :class:`torch.distributed.fsdp.ShardingStrategy` enum value.

device_mesh: A tuple `(replication size, sharding size)` that defines over how many devices to shard and
replicate the model. The product of the two numbers must equal the world size. Only valid in combination
with the `HYBRID_SHARD` sharding strategy.

state_dict_type: The format in which the state of the model and optimizers gets saved into the checkpoint.

- ``"full"``: The full weights and optimizer states get assembled on rank 0 and saved to a single file.
@@ -260,6 +264,12 @@ def setup_environment(self) -> None:
assert self.cluster_environment is not None
_init_dist_connection(self.cluster_environment, self._process_group_backend, timeout=self._timeout)

# if `device_mesh` in `kwargs` was provided as a tuple, convert it into a `DeviceMesh` object here
if isinstance(self.kwargs.get("device_mesh"), tuple):
from torch.distributed.device_mesh import init_device_mesh

self.kwargs["device_mesh"] = init_device_mesh("cuda", self.kwargs["device_mesh"])

def _get_process_group_backend(self) -> str:
return self._process_group_backend or _get_default_process_group_backend_for_device(self.root_device)

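A matching sketch for the Trainer-side strategy, under the same 2 × 4 device assumption (illustrative only):

```python
# Illustrative only: same 2 x 4 device layout as above.
from lightning.pytorch import Trainer
from lightning.pytorch.strategies import FSDPStrategy

strategy = FSDPStrategy(
    sharding_strategy="HYBRID_SHARD",
    device_mesh=(2, 4),  # converted to a DeviceMesh in setup_environment()
)
trainer = Trainer(accelerator="cuda", devices=4, num_nodes=2, strategy=strategy)
```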