
Commit ec23e68
fixes
awaelchli committed Apr 27, 2024
1 parent 5d20f79 commit ec23e68
Showing 1 changed file with 9 additions and 3 deletions.
tests/tests_pytorch/strategies/test_fsdp.py: 12 changes (9 additions, 3 deletions)
@@ -630,7 +630,7 @@ def test_fsdp_strategy_save_optimizer_states(tmp_path, wrap_min_params):
 
 @RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True)
 @pytest.mark.parametrize("wrap_min_params", [2, 1024, 100000000])
-def test_fsdp_strategy_load_optimizer_states(tmp_path, wrap_min_params):
+def test_fsdp_strategy_load_optimizer_states(wrap_min_params, tmp_path):
     """Test to ensure that the full state dict and optimizer states can be loaded when using FSDP strategy.
 
     Based on `wrap_min_params`, the model will be fully wrapped, half wrapped, and not wrapped at all. If the DDP model
@@ -694,14 +694,17 @@ def test_fsdp_strategy_load_optimizer_states(tmp_path, wrap_min_params):
         ("32-true", torch.float32),
     ],
 )
-def test_configure_model(precision, expected_dtype):
+def test_configure_model(precision, expected_dtype, tmp_path):
     """Test that the module under configure_model gets moved to the right device and dtype."""
     trainer = Trainer(
+        default_root_dir=tmp_path,
         accelerator="cuda",
         devices=2,
         strategy=FSDPStrategy(auto_wrap_policy=always_wrap_policy),
         precision=precision,
         max_epochs=1,
+        enable_checkpointing=False,
+        logger=False,
     )
 
     class MyModel(BoringModel):
@@ -899,7 +902,7 @@ def test_fsdp_lazy_load_full_state_dict(_, lazy_load_mock, torch_load_mock, tmp_
         pytest.param("bf16-true", torch.bfloat16, marks=RunIf(bf16_cuda=True)),
     ],
 )
-def test_module_init_context(precision, expected_dtype):
+def test_module_init_context(precision, expected_dtype, tmp_path):
     """Test that the module under the init-context gets moved to the right device and dtype."""
 
     class Model(BoringModel):
@@ -915,12 +918,15 @@ def on_train_start(self):
 
     def _run_setup_assertions(empty_init, expected_device):
         trainer = Trainer(
+            default_root_dir=tmp_path,
             accelerator="cuda",
             devices=2,
             strategy=FSDPStrategy(auto_wrap_policy={torch.nn.Linear}),
             precision=precision,
             max_steps=1,
             barebones=True,
+            enable_checkpointing=False,
+            logger=False,
     )
         with trainer.init_module(empty_init=empty_init):
             model = Model()
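For readers unfamiliar with the pattern this commit applies, here is a minimal standalone sketch (not part of the commit, and the test name and settings below are illustrative): pytest injects its built-in `tmp_path` fixture as a per-test temporary directory, which is handed to the `Trainer` as `default_root_dir` so any checkpoints or logs land there instead of the current working directory; `enable_checkpointing=False` and `logger=False` then suppress those artifacts entirely. The CPU accelerator and single device are simplifications; the real tests run on two CUDA devices with the FSDP strategy.

import pytest
import torch

from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel


@pytest.mark.parametrize(("precision", "expected_dtype"), [("32-true", torch.float32)])
def test_no_artifacts_in_cwd(precision, expected_dtype, tmp_path):
    """Illustrative sketch: route all Trainer output into pytest's temp dir."""
    trainer = Trainer(
        default_root_dir=tmp_path,   # any output goes under tmp_path, not the CWD
        accelerator="cpu",           # the real tests use accelerator="cuda", devices=2
        devices=1,
        precision=precision,
        max_steps=1,
        enable_checkpointing=False,  # don't write .ckpt files
        logger=False,                # don't create a lightning_logs/ directory
    )
    model = BoringModel()
    trainer.fit(model)
    assert next(model.parameters()).dtype == expected_dtype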
