From 89e14c92fd5096b9edb4597785a615a3cc52d5fb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 13 Dec 2022 11:01:33 +0100 Subject: [PATCH] Fix DDPStrategy import in app framework after #14952 (#16029) (cherry picked from commit fcd3195e6817840c37b0c1908660071a6dfbbf1a) --- src/lightning_app/components/multi_node/lite.py | 6 +++--- tests/tests_app/components/multi_node/test_lite.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/lightning_app/components/multi_node/lite.py b/src/lightning_app/components/multi_node/lite.py index 003c6dfb1c75af..5f3351142ad5ef 100644 --- a/src/lightning_app/components/multi_node/lite.py +++ b/src/lightning_app/components/multi_node/lite.py @@ -40,8 +40,8 @@ def run( try: pkg = importlib.import_module(pkg_name) lites.append(pkg.LightningLite) - strategies.append(pkg.strategies.DDPSpawnShardedStrategy) - strategies.append(pkg.strategies.DDPSpawnStrategy) + strategies.append(pkg.strategies.DDPShardedStrategy) + strategies.append(pkg.strategies.DDPStrategy) mps_accelerators.append(pkg.accelerators.MPSAccelerator) except (ImportError, ModuleNotFoundError): continue @@ -81,7 +81,7 @@ def pre_fn(lite, *args, **kwargs): strategy = "ddp" elif strategy == "ddp_sharded_spawn": strategy = "ddp_sharded" - elif isinstance(strategy, tuple(strategies)): + elif isinstance(strategy, tuple(strategies)) and strategy._start_method in ("spawn", "fork"): raise ValueError("DDP Spawned strategies aren't supported yet.") kwargs["strategy"] = strategy diff --git a/tests/tests_app/components/multi_node/test_lite.py b/tests/tests_app/components/multi_node/test_lite.py index 9b8aa29779fd21..2a60cf71df4d06 100644 --- a/tests/tests_app/components/multi_node/test_lite.py +++ b/tests/tests_app/components/multi_node/test_lite.py @@ -97,7 +97,7 @@ def test_trainer_run_executor_arguments_choices(args_given: dict, args_expected: @pytest.mark.skipif(not module_available("lightning"), reason="Lightning not available") def test_lite_run_executor_invalid_strategy_instances(): with pytest.raises(ValueError, match="DDP Spawned strategies aren't supported yet."): - _, _ = _get_args_after_tracer_injection(strategy=ll.strategies.DDPSpawnStrategy()) + _, _ = _get_args_after_tracer_injection(strategy=ll.strategies.DDPStrategy(start_method="spawn")) with pytest.raises(ValueError, match="DDP Spawned strategies aren't supported yet."): - _, _ = _get_args_after_tracer_injection(strategy=ll.strategies.DDPSpawnShardedStrategy()) + _, _ = _get_args_after_tracer_injection(strategy=ll.strategies.DDPShardedStrategy(start_method="spawn"))