diff --git a/src/lightning_app/components/multi_node/lite.py b/src/lightning_app/components/multi_node/lite.py index 003c6dfb1c75af..5f3351142ad5ef 100644 --- a/src/lightning_app/components/multi_node/lite.py +++ b/src/lightning_app/components/multi_node/lite.py @@ -40,8 +40,8 @@ def run( try: pkg = importlib.import_module(pkg_name) lites.append(pkg.LightningLite) - strategies.append(pkg.strategies.DDPSpawnShardedStrategy) - strategies.append(pkg.strategies.DDPSpawnStrategy) + strategies.append(pkg.strategies.DDPShardedStrategy) + strategies.append(pkg.strategies.DDPStrategy) mps_accelerators.append(pkg.accelerators.MPSAccelerator) except (ImportError, ModuleNotFoundError): continue @@ -81,7 +81,7 @@ def pre_fn(lite, *args, **kwargs): strategy = "ddp" elif strategy == "ddp_sharded_spawn": strategy = "ddp_sharded" - elif isinstance(strategy, tuple(strategies)): + elif isinstance(strategy, tuple(strategies)) and strategy._start_method in ("spawn", "fork"): raise ValueError("DDP Spawned strategies aren't supported yet.") kwargs["strategy"] = strategy diff --git a/tests/tests_app/components/multi_node/test_lite.py b/tests/tests_app/components/multi_node/test_lite.py index 9b8aa29779fd21..2a60cf71df4d06 100644 --- a/tests/tests_app/components/multi_node/test_lite.py +++ b/tests/tests_app/components/multi_node/test_lite.py @@ -97,7 +97,7 @@ def test_trainer_run_executor_arguments_choices(args_given: dict, args_expected: @pytest.mark.skipif(not module_available("lightning"), reason="Lightning not available") def test_lite_run_executor_invalid_strategy_instances(): with pytest.raises(ValueError, match="DDP Spawned strategies aren't supported yet."): - _, _ = _get_args_after_tracer_injection(strategy=ll.strategies.DDPSpawnStrategy()) + _, _ = _get_args_after_tracer_injection(strategy=ll.strategies.DDPStrategy(start_method="spawn")) with pytest.raises(ValueError, match="DDP Spawned strategies aren't supported yet."): - _, _ = _get_args_after_tracer_injection(strategy=ll.strategies.DDPSpawnShardedStrategy()) + _, _ = _get_args_after_tracer_injection(strategy=ll.strategies.DDPShardedStrategy(start_method="spawn"))