Commit 439220d

fixes
awaelchli committed May 7, 2024
1 parent ad5b990 · commit 439220d
Showing 3 changed files with 5 additions and 2 deletions.

src/lightning/fabric/strategies/model_parallel.py (1 change: 1 addition & 0 deletions)

@@ -424,6 +424,7 @@ def _load_checkpoint(
     rank: int,
     strict: bool = True,
 ) -> Dict[str, Any]:
+
     from torch.distributed.checkpoint.state_dict import (
         StateDictOptions,
         get_model_state_dict,
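
Note: as context for the import above, a minimal sketch of how get_model_state_dict and StateDictOptions are typically combined (illustrative only, not code from this commit; assumes torch >= 2.3 and uses a plain module as a stand-in for a parallelized one):

    import torch.nn as nn
    from torch.distributed.checkpoint.state_dict import (
        StateDictOptions,
        get_model_state_dict,
    )

    model = nn.Linear(4, 4)  # stand-in for a parallelized module
    # full_state_dict=True gathers sharded parameters into one unsharded
    # state dict; cpu_offload=True moves the gathered tensors to CPU.
    options = StateDictOptions(full_state_dict=True, cpu_offload=True)
    state_dict = get_model_state_dict(model, options=options)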

tests/tests_fabric/strategies/test_fsdp.py (3 changes: 2 additions & 1 deletion)

@@ -312,7 +312,7 @@ def test_load_checkpoint_no_state(tmp_path):
 
 
 @mock.patch("lightning.fabric.strategies.fsdp.FSDPStrategy.broadcast", lambda _, x: x)
-@mock.patch("lightning.fabric.strategies.fsdp._lazy_load", Mock())
+@mock.patch("lightning.fabric.strategies.model_parallel._lazy_load", Mock())
 def test_load_checkpoint_one_fsdp_module_required(tmp_path):
     """Test that the FSDP strategy can only load one FSDP model per checkpoint."""
     strategy = FSDPStrategy()
@@ -333,6 +333,7 @@ def test_load_checkpoint_one_fsdp_module_required(tmp_path):
 
     # A raw nn.Module instead of a dictionary is ok
     model = Mock(spec=nn.Module)
+    model.parameters.return_value = [torch.zeros(2, 1)]
     path = tmp_path / "full.ckpt"
     path.touch()
     strategy.load_checkpoint(path=path, state=model)
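
Note: the added model.parameters.return_value line presumably exists because load_checkpoint now inspects the mocked model's parameters (for example, to pick a device); a bare Mock(spec=nn.Module) returns another Mock from .parameters(), which is not iterable. A minimal sketch of the difference (illustrative, not from the commit):

    from unittest.mock import Mock

    import torch
    import torch.nn as nn

    model = Mock(spec=nn.Module)
    # .parameters() on a bare Mock returns a Mock, and iterating a plain
    # Mock raises TypeError:
    try:
        next(iter(model.parameters()))
    except TypeError as err:
        print(err)  # 'Mock' object is not iterable

    # Giving the mock a concrete return value makes iteration behave like
    # a real module:
    model.parameters.return_value = [torch.zeros(2, 1)]
    print(next(iter(model.parameters())).device)  # cpu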

tests/tests_fabric/strategies/test_model_parallel.py (3 changes: 2 additions & 1 deletion)

@@ -240,7 +240,7 @@ def test_load_checkpoint_no_state(tmp_path):
 
 @RunIf(min_torch="2.3")
 @mock.patch("lightning.fabric.strategies.model_parallel.ModelParallelStrategy.broadcast", lambda _, x: x)
-@mock.patch("lightning.fabric.strategies.fsdp._lazy_load", Mock())
+@mock.patch("lightning.fabric.strategies.model_parallel._lazy_load", Mock())
 def test_load_checkpoint_one_dist_module_required(tmp_path):
     """Test that the ModelParallelStrategy strategy can only load one distributed model per checkpoint."""
     strategy = ModelParallelStrategy(parallelize_fn=(lambda m, _: m))
@@ -262,6 +262,7 @@ def test_load_checkpoint_one_dist_module_required(tmp_path):
 
     # A raw nn.Module instead of a dictionary is ok
     model = Mock(spec=nn.Module)
+    model.parameters.return_value = [torch.zeros(2, 1)]
     path = tmp_path / "full.ckpt"
     path.touch()
     strategy.load_checkpoint(path=path, state=model)
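
Note: both test files switch the patch target from lightning.fabric.strategies.fsdp._lazy_load to lightning.fabric.strategies.model_parallel._lazy_load, consistent with the standard mock.patch rule: patch the name in the namespace where it is looked up at call time, not only in the module that defines it. A runnable sketch of that rule with two throwaway modules (hypothetical names, not Lightning code):

    import sys
    import types
    from unittest.mock import Mock, patch

    # `defs.helper` is defined in one module and copied into `user`,
    # which is what `from defs import helper` does.
    defs = types.ModuleType("defs")
    defs.helper = lambda: "real"
    sys.modules["defs"] = defs

    user = types.ModuleType("user")
    user.helper = defs.helper
    user.run = lambda: user.helper()
    sys.modules["user"] = user

    # Patching the defining module does NOT affect the copied reference:
    with patch("defs.helper", Mock(return_value="patched")):
        print(user.run())  # real

    # Patching the module where the name is used does:
    with patch("user.helper", Mock(return_value="patched")):
        print(user.run())  # patched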
