From 36aecde6957ce0c6ebb4e4b8bf8c1738a4c253b6 Mon Sep 17 00:00:00 2001
From: Justus Schock <12886177+justusschock@users.noreply.github.com>
Date: Thu, 8 Dec 2022 13:05:18 +0100
Subject: [PATCH] Multinode on MPS (#15748)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* Fix restarting attribute for lr finder
* update lite executor
* update trainer executor
* update spawn executor
* add multinode component tests
* add testing helpers
* add lite tests
* add trainer tests
* update changelog
* update trainer
* update workflow
* update tests
* debug
* add reason for skipif
* Apply suggestions from code review
* switch skipif

Co-authored-by: Jirka
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Carlos Mocholí
Co-authored-by: Adrian Wälchli
Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>
---
 .github/workflows/ci-app-tests.yml             |   2 +-
 src/lightning_app/CHANGELOG.md                 |   4 +
 .../components/multi_node/lite.py              |  37 +++++--
 .../components/multi_node/pytorch_spawn.py     |   2 +-
 .../components/multi_node/trainer.py           |  37 +++++--
 .../components/multi_node/__init__.py          |   0
 .../components/multi_node/test_lite.py         | 103 ++++++++++++++++++
 .../components/multi_node/test_trainer.py      |  99 +++++++++++++++++
 .../utilities/packaging/test_build_spec.py     |   4 +-
 9 files changed, 268 insertions(+), 20 deletions(-)
 create mode 100644 tests/tests_app/components/multi_node/__init__.py
 create mode 100644 tests/tests_app/components/multi_node/test_lite.py
 create mode 100644 tests/tests_app/components/multi_node/test_trainer.py

diff --git a/.github/workflows/ci-app-tests.yml b/.github/workflows/ci-app-tests.yml
index d19a408309bc4..b89e3145da12e 100644
--- a/.github/workflows/ci-app-tests.yml
+++ b/.github/workflows/ci-app-tests.yml
@@ -94,7 +94,7 @@ jobs:
 
     - name: Adjust tests
       if: ${{ matrix.pkg-name == 'lightning' }}
-      run: python .actions/assistant.py copy_replace_imports --source_dir="./tests" --source_import="lightning_app" --target_import="lightning.app"
+      run: python .actions/assistant.py copy_replace_imports --source_dir="./tests" --source_import="lightning_app,lightning_lite,pytorch_lightning" --target_import="lightning.app,lightning.lite,lightning.pytorch"
 
     - name: Adjust examples
       if: ${{ matrix.pkg-name != 'lightning' }}
diff --git a/src/lightning_app/CHANGELOG.md b/src/lightning_app/CHANGELOG.md
index 9e750e4132b13..9e7862d329c64 100644
--- a/src/lightning_app/CHANGELOG.md
+++ b/src/lightning_app/CHANGELOG.md
@@ -54,6 +54,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed SSH CLI command listing stopped components ([#15810](https://github.com/Lightning-AI/lightning/pull/15810))
 
+- Fixed MPS error for multinode component (now defaults to CPU on MPS devices, as distributed operations are not supported by PyTorch on MPS) ([#15748](https://github.com/Lightning-AI/lightning/pull/15748))
+
+
 - Fixed the work not stopped when successful when passed directly to the LightningApp ([#15801](https://github.com/Lightning-AI/lightning/pull/15801))
 
 
@@ -111,6 +114,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed bi-directional queues sending delta with Drive Component name changes ([#15642](https://github.com/Lightning-AI/lightning/pull/15642))
 - Fixed CloudRuntime works collection with structures and accelerated multi node startup time ([#15650](https://github.com/Lightning-AI/lightning/pull/15650))
 - Fixed catimage import ([#15712](https://github.com/Lightning-AI/lightning/pull/15712))
+- Fixed setting property to the LightningFlow ([#15750](https://github.com/Lightning-AI/lightning/pull/15750))
 - Parse all lines in app file looking for shebangs to run commands ([#15714](https://github.com/Lightning-AI/lightning/pull/15714))
 
 
diff --git a/src/lightning_app/components/multi_node/lite.py b/src/lightning_app/components/multi_node/lite.py
index 2a9b33b0880d1..36709d409e1a0 100644
--- a/src/lightning_app/components/multi_node/lite.py
+++ b/src/lightning_app/components/multi_node/lite.py
@@ -1,4 +1,6 @@
+import importlib
 import os
+import warnings
 from dataclasses import dataclass
 from typing import Any, Callable, Type
 
@@ -30,8 +32,16 @@ def run(
         node_rank: int,
         nprocs: int,
     ):
-        from lightning.lite import LightningLite
-        from lightning.lite.strategies import DDPSpawnShardedStrategy, DDPSpawnStrategy
+        lites = []
+        strategies = []
+        mps_accelerators = []
+
+        for pkg_name in ("lightning.lite", "lightning_" + "lite"):
+            pkg = importlib.import_module(pkg_name)
+            lites.append(pkg.LightningLite)
+            strategies.append(pkg.strategies.DDPSpawnShardedStrategy)
+            strategies.append(pkg.strategies.DDPSpawnStrategy)
+            mps_accelerators.append(pkg.accelerators.MPSAccelerator)
 
         # Used to configure PyTorch progress group
         os.environ["MASTER_ADDR"] = main_address
@@ -52,7 +62,15 @@ def run(
         def pre_fn(lite, *args, **kwargs):
             kwargs["devices"] = nprocs
             kwargs["num_nodes"] = num_nodes
-            kwargs["accelerator"] = "auto"
+
+            if any(acc.is_available() for acc in mps_accelerators):
+                old_acc_value = kwargs.get("accelerator", "auto")
+                kwargs["accelerator"] = "cpu"
+
+                if old_acc_value != kwargs["accelerator"]:
+                    warnings.warn("Forcing `accelerator=cpu` as MPS does not support distributed training.")
+            else:
+                kwargs["accelerator"] = "auto"
             strategy = kwargs.get("strategy", None)
             if strategy:
                 if isinstance(strategy, str):
@@ -60,15 +78,20 @@ def pre_fn(lite, *args, **kwargs):
                         strategy = "ddp"
                     elif strategy == "ddp_sharded_spawn":
                         strategy = "ddp_sharded"
-                elif isinstance(strategy, (DDPSpawnStrategy, DDPSpawnShardedStrategy)):
-                    raise Exception("DDP Spawned strategies aren't supported yet.")
+                elif isinstance(strategy, tuple(strategies)):
+                    raise ValueError("DDP Spawned strategies aren't supported yet.")
+
+            kwargs["strategy"] = strategy
+
             return {}, args, kwargs
 
         tracer = Tracer()
-        tracer.add_traced(LightningLite, "__init__", pre_fn=pre_fn)
+        for ll in lites:
+            tracer.add_traced(ll, "__init__", pre_fn=pre_fn)
         tracer._instrument()
-        work_run()
+        ret_val = work_run()
         tracer._restore()
+        return ret_val
 
 
 class LiteMultiNode(MultiNode):
diff --git a/src/lightning_app/components/multi_node/pytorch_spawn.py b/src/lightning_app/components/multi_node/pytorch_spawn.py
index 3119ffc51e0b5..013bdbcaec347 100644
--- a/src/lightning_app/components/multi_node/pytorch_spawn.py
+++ b/src/lightning_app/components/multi_node/pytorch_spawn.py
@@ -88,7 +88,7 @@ def run(
         elif world_size > 1:
             raise Exception("Torch distributed should be available.")
 
-        work_run(world_size, node_rank, global_rank, local_rank)
+        return work_run(world_size, node_rank, global_rank, local_rank)
 
 
 class PyTorchSpawnMultiNode(MultiNode):
diff --git a/src/lightning_app/components/multi_node/trainer.py b/src/lightning_app/components/multi_node/trainer.py
index 222f71ce59557..8f25b71d622c1 100644
--- a/src/lightning_app/components/multi_node/trainer.py
+++ b/src/lightning_app/components/multi_node/trainer.py
@@ -1,4 +1,6 @@
+import importlib
 import os
+import warnings
 from dataclasses import dataclass
 from typing import Any, Callable, Type
 
@@ -30,9 +32,16 @@ def run(
         node_rank: int,
         nprocs: int,
    ):
-        from lightning.lite.strategies import DDPSpawnShardedStrategy, DDPSpawnStrategy
-        from lightning.pytorch import Trainer as LTrainer
-        from pytorch_lightning import Trainer as PLTrainer
+        trainers = []
+        strategies = []
+        mps_accelerators = []
+
+        for pkg_name in ("lightning.pytorch", "pytorch_" + "lightning"):
+            pkg = importlib.import_module(pkg_name)
+            trainers.append(pkg.Trainer)
+            strategies.append(pkg.strategies.DDPSpawnShardedStrategy)
+            strategies.append(pkg.strategies.DDPSpawnStrategy)
+            mps_accelerators.append(pkg.accelerators.MPSAccelerator)
 
         # Used to configure PyTorch progress group
         os.environ["MASTER_ADDR"] = main_address
@@ -50,7 +59,15 @@ def run(
         def pre_fn(trainer, *args, **kwargs):
             kwargs["devices"] = nprocs
             kwargs["num_nodes"] = num_nodes
-            kwargs["accelerator"] = "auto"
+            if any(acc.is_available() for acc in mps_accelerators):
+                old_acc_value = kwargs.get("accelerator", "auto")
+                kwargs["accelerator"] = "cpu"
+
+                if old_acc_value != kwargs["accelerator"]:
+                    warnings.warn("Forcing `accelerator=cpu` as MPS does not support distributed training.")
+            else:
+                kwargs["accelerator"] = "auto"
+
             strategy = kwargs.get("strategy", None)
             if strategy:
                 if isinstance(strategy, str):
@@ -58,16 +75,18 @@ def pre_fn(trainer, *args, **kwargs):
                         strategy = "ddp"
                     elif strategy == "ddp_sharded_spawn":
                         strategy = "ddp_sharded"
-                elif isinstance(strategy, (DDPSpawnStrategy, DDPSpawnShardedStrategy)):
-                    raise Exception("DDP Spawned strategies aren't supported yet.")
+                elif isinstance(strategy, tuple(strategies)):
+                    raise ValueError("DDP Spawned strategies aren't supported yet.")
+            kwargs["strategy"] = strategy
             return {}, args, kwargs
 
         tracer = Tracer()
-        tracer.add_traced(PLTrainer, "__init__", pre_fn=pre_fn)
-        tracer.add_traced(LTrainer, "__init__", pre_fn=pre_fn)
+        for trainer in trainers:
+            tracer.add_traced(trainer, "__init__", pre_fn=pre_fn)
         tracer._instrument()
-        work_run()
+        ret_val = work_run()
         tracer._restore()
+        return ret_val
 
 
 class LightningTrainerMultiNode(MultiNode):
diff --git a/tests/tests_app/components/multi_node/__init__.py b/tests/tests_app/components/multi_node/__init__.py
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/tests/tests_app/components/multi_node/test_lite.py b/tests/tests_app/components/multi_node/test_lite.py
new file mode 100644
index 0000000000000..9b8aa29779fd2
--- /dev/null
+++ b/tests/tests_app/components/multi_node/test_lite.py
@@ -0,0 +1,103 @@
+import os
+from copy import deepcopy
+from functools import partial
+from unittest import mock
+
+import pytest
+from lightning_utilities.core.imports import module_available
+from tests_app.helpers.utils import no_warning_call
+
+import lightning_lite as ll
+from lightning_app.components.multi_node.lite import _LiteRunExecutor
+
+
+class DummyLite(ll.LightningLite):
+    def run(self):
+        pass
+
+
+def dummy_callable(**kwargs):
+    lite = DummyLite(**kwargs)
+    return lite._all_passed_kwargs
+
+
+def dummy_init(self, **kwargs):
+    self._all_passed_kwargs = kwargs
+
+
+def _get_args_after_tracer_injection(**kwargs):
+    with mock.patch.object(ll.LightningLite, "__init__", dummy_init):
+        ret_val = _LiteRunExecutor.run(
+            local_rank=0,
+            work_run=partial(dummy_callable, **kwargs),
+            main_address="1.2.3.4",
+            main_port=5,
+            node_rank=6,
+            num_nodes=7,
+            nprocs=8,
+        )
+        env_vars = deepcopy(os.environ)
+    return ret_val, env_vars
+
+
+def check_lightning_lite_mps():
+    if module_available("lightning_lite"):
+        return ll.accelerators.MPSAccelerator.is_available()
+    return False
+
+
+@pytest.mark.skipif(not check_lightning_lite_mps(), reason="Lightning lite not available or mps not available")
+@pytest.mark.parametrize("accelerator_given,accelerator_expected", [("cpu", "cpu"), ("auto", "cpu"), ("gpu", "cpu")])
+def test_lite_run_executor_mps_forced_cpu(accelerator_given, accelerator_expected):
+    warning_str = (
+        r"Forcing accelerator=cpu as other accelerators \(specifically MPS\) are not supported "
+        + "by PyTorch for distributed training on mps capable devices"
+    )
+    if accelerator_expected != accelerator_given:
+        warning_context = pytest.warns(UserWarning, match=warning_str)
+    else:
+        warning_context = no_warning_call(match=warning_str + "*")
+
+    with warning_context:
+        ret_val, env_vars = _get_args_after_tracer_injection(accelerator=accelerator_given)
+        assert ret_val["accelerator"] == accelerator_expected
+
+
+@pytest.mark.parametrize(
+    "args_given,args_expected",
+    [
+        ({"devices": 1, "num_nodes": 1, "accelerator": "gpu"}, {"devices": 8, "num_nodes": 7, "accelerator": "auto"}),
+        ({"strategy": "ddp_spawn"}, {"strategy": "ddp"}),
+        ({"strategy": "ddp_sharded_spawn"}, {"strategy": "ddp_sharded"}),
+    ],
+)
+@pytest.mark.skipif(not module_available("lightning"), reason="Lightning is required for this test")
+def test_trainer_run_executor_arguments_choices(args_given: dict, args_expected: dict):
+
+    # ddp with mps devices not available (tested separately, just patching here for cross-os testing of other args)
+    if ll.accelerators.MPSAccelerator.is_available():
+        args_expected["accelerator"] = "cpu"
+
+    ret_val, env_vars = _get_args_after_tracer_injection(**args_given)
+
+    for k, v in args_expected.items():
+        assert ret_val[k] == v
+
+    assert env_vars["MASTER_ADDR"] == "1.2.3.4"
+    assert env_vars["MASTER_PORT"] == "5"
+    assert env_vars["GROUP_RANK"] == "6"
+    assert env_vars["RANK"] == str(0 + 6 * 8)
+    assert env_vars["LOCAL_RANK"] == "0"
+    assert env_vars["WORLD_SIZE"] == str(7 * 8)
+    assert env_vars["LOCAL_WORLD_SIZE"] == "8"
+    assert env_vars["TORCHELASTIC_RUN_ID"] == "1"
+    assert env_vars["LT_CLI_USED"] == "1"
+
+
+@pytest.mark.skipif(not module_available("lightning"), reason="Lightning not available")
+def test_lite_run_executor_invalid_strategy_instances():
+    with pytest.raises(ValueError, match="DDP Spawned strategies aren't supported yet."):
+        _, _ = _get_args_after_tracer_injection(strategy=ll.strategies.DDPSpawnStrategy())
+
+    with pytest.raises(ValueError, match="DDP Spawned strategies aren't supported yet."):
+        _, _ = _get_args_after_tracer_injection(strategy=ll.strategies.DDPSpawnShardedStrategy())
diff --git a/tests/tests_app/components/multi_node/test_trainer.py b/tests/tests_app/components/multi_node/test_trainer.py
new file mode 100644
index 0000000000000..c86e0968e2ab0
--- /dev/null
+++ b/tests/tests_app/components/multi_node/test_trainer.py
@@ -0,0 +1,99 @@
+import os
+from copy import deepcopy
+from functools import partial
+from unittest import mock
+
+import pytest
+from lightning_utilities.core.imports import module_available
+from tests_app.helpers.utils import no_warning_call
+
+import pytorch_lightning as pl
+from lightning_app.components.multi_node.trainer import _LightningTrainerRunExecutor
+
+
+def dummy_callable(**kwargs):
+    t = pl.Trainer(**kwargs)
+    return t._all_passed_kwargs
+
+
+def dummy_init(self, **kwargs):
+    self._all_passed_kwargs = kwargs
+
+
+def _get_args_after_tracer_injection(**kwargs):
+    with mock.patch.object(pl.Trainer, "__init__", dummy_init):
+        ret_val = _LightningTrainerRunExecutor.run(
+            local_rank=0,
+            work_run=partial(dummy_callable, **kwargs),
+            main_address="1.2.3.4",
+            main_port=5,
+            node_rank=6,
+            num_nodes=7,
+            nprocs=8,
+        )
+        env_vars = deepcopy(os.environ)
+    return ret_val, env_vars
+
+
+def check_lightning_pytorch_and_mps():
+    if module_available("pytorch_lightning"):
+        return pl.accelerators.MPSAccelerator.is_available()
+    return False
+
+
+@pytest.mark.skipif(not check_lightning_pytorch_and_mps(), reason="pytorch_lightning and mps are required")
+@pytest.mark.parametrize("accelerator_given,accelerator_expected", [("cpu", "cpu"), ("auto", "cpu"), ("gpu", "cpu")])
+def test_trainer_run_executor_mps_forced_cpu(accelerator_given, accelerator_expected):
+    warning_str = (
+        r"Forcing accelerator=cpu as other accelerators \(specifically MPS\) are not supported "
+        + "by PyTorch for distributed training on mps capable devices"
+    )
+    if accelerator_expected != accelerator_given:
+        warning_context = pytest.warns(UserWarning, match=warning_str)
+    else:
+        warning_context = no_warning_call(match=warning_str + "*")
+
+    with warning_context:
+        ret_val, env_vars = _get_args_after_tracer_injection(accelerator=accelerator_given)
+        assert ret_val["accelerator"] == accelerator_expected
+
+
+@pytest.mark.parametrize(
+    "args_given,args_expected",
+    [
+        ({"devices": 1, "num_nodes": 1, "accelerator": "gpu"}, {"devices": 8, "num_nodes": 7, "accelerator": "auto"}),
+        ({"strategy": "ddp_spawn"}, {"strategy": "ddp"}),
+        ({"strategy": "ddp_sharded_spawn"}, {"strategy": "ddp_sharded"}),
+    ],
+)
+@pytest.mark.skipif(not module_available("pytorch_lightning"), reason="pytorch_lightning is not available")
+def test_trainer_run_executor_arguments_choices(
+    args_given: dict,
+    args_expected: dict,
+):
+
+    if pl.accelerators.MPSAccelerator.is_available():
+        args_expected.pop("accelerator", None)  # Cross platform tests -> MPS is tested separately
+
+    ret_val, env_vars = _get_args_after_tracer_injection(**args_given)
+
+    for k, v in args_expected.items():
+        assert ret_val[k] == v
+
+    assert env_vars["MASTER_ADDR"] == "1.2.3.4"
+    assert env_vars["MASTER_PORT"] == "5"
+    assert env_vars["GROUP_RANK"] == "6"
+    assert env_vars["RANK"] == str(0 + 6 * 8)
+    assert env_vars["LOCAL_RANK"] == "0"
+    assert env_vars["WORLD_SIZE"] == str(7 * 8)
+    assert env_vars["LOCAL_WORLD_SIZE"] == "8"
+    assert env_vars["TORCHELASTIC_RUN_ID"] == "1"
+
+
+@pytest.mark.skipif(not module_available("lightning"), reason="lightning not available")
+def test_trainer_run_executor_invalid_strategy_instances():
+    with pytest.raises(ValueError, match="DDP Spawned strategies aren't supported yet."):
+        _, _ = _get_args_after_tracer_injection(strategy=pl.strategies.DDPSpawnStrategy())
+
+    with pytest.raises(ValueError, match="DDP Spawned strategies aren't supported yet."):
+        _, _ = _get_args_after_tracer_injection(strategy=pl.strategies.DDPSpawnShardedStrategy())
diff --git a/tests/tests_app/utilities/packaging/test_build_spec.py b/tests/tests_app/utilities/packaging/test_build_spec.py
index ba497a5efbdb4..70c4a60374b67 100644
--- a/tests/tests_app/utilities/packaging/test_build_spec.py
+++ b/tests/tests_app/utilities/packaging/test_build_spec.py
@@ -29,7 +29,7 @@ def test_build_config_requirements_provided():
     assert spec.requirements == [
         "dask",
         "pandas",
-        "pytorch_" + "lightning==1.5.9",  # ugly hack due to replacing `pytorch_lightning string`
+        "pytorch_lightning==1.5.9",
         "git+https://github.com/mit-han-lab/torchsparse.git@v1.4.0",
     ]
     assert spec == BuildConfig.from_dict(spec.to_dict())
@@ -50,7 +50,7 @@ def test_build_config_dockerfile_provided():
     spec = BuildConfig(dockerfile="./projects/Dockerfile.cpu")
     assert not spec.requirements
     # ugly hack due to replacing `pytorch_lightning string
-    assert "pytorchlightning/pytorch_" + "lightning" in spec.dockerfile.data[0]
+    assert "pytorchlightning/pytorch_lightning" in spec.dockerfile.data[0]
 
 
 class DockerfileLightningTestApp(LightningTestApp):
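
Note: the sketch below is illustrative only and is not part of the patch. It restates, as a standalone function, the accelerator-selection rule that both executors apply inside their pre_fn hooks above: when an MPS accelerator is available, the requested accelerator is overridden to "cpu" (warning the user if that changes their choice), because PyTorch's distributed backends do not support MPS. The helper name _force_cpu_if_mps and its mps_available flag are invented for this example.

import warnings
from typing import Any, Dict


def _force_cpu_if_mps(kwargs: Dict[str, Any], mps_available: bool) -> Dict[str, Any]:
    # Mirrors the pre_fn logic above: on MPS-capable machines fall back to CPU,
    # otherwise keep the automatic accelerator selection.
    if mps_available:
        old_acc_value = kwargs.get("accelerator", "auto")
        kwargs["accelerator"] = "cpu"
        if old_acc_value != kwargs["accelerator"]:
            warnings.warn("Forcing `accelerator=cpu` as MPS does not support distributed training.")
    else:
        kwargs["accelerator"] = "auto"
    return kwargs


# Example: on an Apple Silicon machine a request for "gpu" is rewritten to "cpu"
# and the warning above is emitted; on other machines "auto" is kept.
# _force_cpu_if_mps({"accelerator": "gpu"}, mps_available=True)  -> {"accelerator": "cpu"}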