Multinode on MPS #15748

Merged (58 commits, merged Dec 8, 2022)

Commits
f880c73
Fix restarting attribute for lr finder
justusschock Nov 10, 2022
c13c1a6
update lite executor
justusschock Nov 21, 2022
8fceed0
update trainer executor
justusschock Nov 21, 2022
13cb379
update spawn executor
justusschock Nov 21, 2022
49b20b3
add multinode component tests
justusschock Nov 21, 2022
f80f95b
add testing helpers
justusschock Nov 21, 2022
07ad6aa
add lite tests
justusschock Nov 21, 2022
a1b2e61
add trainer tests
justusschock Nov 21, 2022
e59cbba
Merge branch 'master' into fix/multinod_mps
justusschock Nov 21, 2022
21231c1
Revert "Fix restarting attribute for lr finder"
justusschock Nov 10, 2022
fdaee38
Merge remote-tracking branch 'origin/fix/multinod_mps' into fix/multi…
justusschock Nov 21, 2022
0b3157a
update changelog
justusschock Nov 21, 2022
b1709f0
Merge branch 'master' into fix/multinod_mps
justusschock Nov 22, 2022
cb98589
update skip reasons
justusschock Nov 22, 2022
0e7acf0
skipif
Borda Nov 22, 2022
e4dde73
update skip conditions to only use L.lite and L.pytorch if available
justusschock Nov 23, 2022
b64b19e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 23, 2022
e0989d0
typo
justusschock Nov 23, 2022
29ae553
fix ci
justusschock Nov 23, 2022
be15c13
Update src/lightning_app/CHANGELOG.md
justusschock Nov 24, 2022
7b93a09
test workflow
justusschock Nov 24, 2022
2ac55e4
update trainer
justusschock Nov 24, 2022
f87fc0e
update workflow
justusschock Nov 24, 2022
5760141
Merge branch 'master' into fix/multinod_mps
justusschock Nov 24, 2022
e4911e6
update tests
justusschock Nov 24, 2022
44fe438
Update ci-app-tests.yml
justusschock Nov 24, 2022
31ee30b
Update tests/tests_app/components/multi_node/test_lite.py
awaelchli Nov 24, 2022
f598ecb
debug
awaelchli Nov 24, 2022
7914c33
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 24, 2022
1365d90
debug
awaelchli Nov 24, 2022
d0db952
Merge remote-tracking branch 'origin/fix/multinod_mps' into fix/multi…
awaelchli Nov 24, 2022
a794104
update executors to work with standalone and unified
justusschock Nov 24, 2022
1dfe1df
Merge branch 'master' into fix/multinod_mps
justusschock Nov 24, 2022
4bb8580
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 24, 2022
bfd8845
add reason for skipif
justusschock Nov 24, 2022
269952d
Merge branch 'master' into fix/multinod_mps
awaelchli Nov 24, 2022
5295c45
update test
justusschock Nov 25, 2022
b9af23d
update test
justusschock Nov 25, 2022
bb73075
update test
justusschock Nov 25, 2022
7eeb773
update test
justusschock Nov 25, 2022
cfe9226
update test
justusschock Nov 25, 2022
4934ac1
update test
justusschock Nov 25, 2022
63d40b5
Merge branch 'master' into fix/multinod_mps
justusschock Nov 25, 2022
5cb2f45
Merge branch 'master' into fix/multinod_mps
justusschock Nov 25, 2022
2563a8e
Merge branch 'master' into fix/multinod_mps
justusschock Nov 26, 2022
279ac50
Merge branch 'master' into fix/multinod_mps
awaelchli Dec 5, 2022
f0a83e8
Merge branch 'master' into fix/multinod_mps
Borda Dec 6, 2022
95efc21
Merge branch 'master' into fix/multinod_mps
Borda Dec 7, 2022
9d69ce1
Apply suggestions from code review
Borda Dec 7, 2022
ae285c9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 7, 2022
6e0e00b
ll
Borda Dec 7, 2022
24ca54f
switch skipif
Borda Dec 7, 2022
935a162
.
Borda Dec 7, 2022
0c0aa6d
another
Borda Dec 7, 2022
733b6f3
Merge branch 'master' into fix/multinod_mps
Borda Dec 7, 2022
13e832c
Merge branch 'master' into fix/multinod_mps
Borda Dec 8, 2022
c24e1a0
Apply suggestions from code review
Borda Dec 8, 2022
ee0f19a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 8, 2022
Files changed
2 changes: 1 addition & 1 deletion .github/workflows/ci-app-tests.yml
@@ -94,7 +94,7 @@ jobs:

- name: Adjust tests
if: ${{ matrix.pkg-name == 'lightning' }}
run: python .actions/assistant.py copy_replace_imports --source_dir="./tests" --source_import="lightning_app" --target_import="lightning.app"
run: python .actions/assistant.py copy_replace_imports --source_dir="./tests" --source_import="lightning_app,lightning_lite,pytorch_lightning" --target_import="lightning.app,lightning.lite,lightning.pytorch"

- name: Adjust examples
if: ${{ matrix.pkg-name != 'lightning' }}
4 changes: 4 additions & 0 deletions src/lightning_app/CHANGELOG.md
@@ -47,6 +47,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed SSH CLI command listing stopped components ([#15810](https://github.com/Lightning-AI/lightning/pull/15810))

- Fixed MPS error for multinode component (now defaults to CPU on MPS devices, since PyTorch does not support distributed operations on MPS) ([#15748](https://github.com/Lightning-AI/lightning/pull/15748))



- Fixed the work not stopped when successful when passed directly to the LightningApp ([#15801](https://github.com/Lightning-AI/lightning/pull/15801))

@@ -102,6 +105,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed bi-directional queues sending delta with Drive Component name changes ([#15642](https://github.com/Lightning-AI/lightning/pull/15642))
- Fixed CloudRuntime works collection with structures and accelerated multi node startup time ([#15650](https://github.com/Lightning-AI/lightning/pull/15650))
- Fixed catimage import ([#15712](https://github.com/Lightning-AI/lightning/pull/15712))
- Fixed setting property to the LightningFlow ([#15750](https://github.com/Lightning-AI/lightning/pull/15750))
- Parse all lines in app file looking for shebangs to run commands ([#15714](https://github.com/Lightning-AI/lightning/pull/15714))


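For context on the changelog entry above: the fix hinges on detecting MPS availability and falling back to CPU, because PyTorch's distributed backends (gloo/nccl) cannot operate on MPS tensors. A minimal sketch of that check, assuming torch >= 1.12; the helper name pick_accelerator is illustrative and not part of the PR:

import torch


def pick_accelerator() -> str:
    # On Apple-silicon machines MPS may be available, but PyTorch's distributed
    # backends do not support it, so multinode runs fall back to CPU.
    if torch.backends.mps.is_available():
        return "cpu"
    return "auto"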
40 changes: 33 additions & 7 deletions src/lightning_app/components/multi_node/lite.py
@@ -1,4 +1,6 @@
import importlib
import os
import warnings
from dataclasses import dataclass
from typing import Any, Callable, Type

@@ -30,8 +32,16 @@ def run(
node_rank: int,
nprocs: int,
):
from lightning.lite import LightningLite
from lightning.lite.strategies import DDPSpawnShardedStrategy, DDPSpawnStrategy
lites = []
strategies = []
mps_accelerators = []

for pkg_name in ("lightning.lite", "lightning_" + "lite"):
pkg = importlib.import_module(pkg_name)
lites.append(pkg.LightningLite)
strategies.append(pkg.strategies.DDPSpawnShardedStrategy)
strategies.append(pkg.strategies.DDPSpawnStrategy)
mps_accelerators.append(pkg.accelerators.MPSAccelerator)

# Used to configure PyTorch process group
os.environ["MASTER_ADDR"] = main_address
@@ -52,23 +62,39 @@ def run(
def pre_fn(lite, *args, **kwargs):
kwargs["devices"] = nprocs
kwargs["num_nodes"] = num_nodes
kwargs["accelerator"] = "auto"

if any(x.is_available() for x in mps_accelerators):
old_acc_value = kwargs.get("accelerator", "auto")
kwargs["accelerator"] = "cpu"

if old_acc_value != kwargs["accelerator"]:
warnings.warn(
"Forcing accelerator=cpu as other accelerators (specifically MPS) are not supported "
"by PyTorch for distributed training on mps capable devices"
)
else:
kwargs["accelerator"] = "auto"
strategy = kwargs.get("strategy", None)
if strategy:
if isinstance(strategy, str):
if strategy == "ddp_spawn":
strategy = "ddp"
elif strategy == "ddp_sharded_spawn":
strategy = "ddp_sharded"
elif isinstance(strategy, (DDPSpawnStrategy, DDPSpawnShardedStrategy)):
raise Exception("DDP Spawned strategies aren't supported yet.")
elif isinstance(strategy, tuple(strategies)):
raise ValueError("DDP Spawned strategies aren't supported yet.")

kwargs["strategy"] = strategy

return {}, args, kwargs

tracer = Tracer()
tracer.add_traced(LightningLite, "__init__", pre_fn=pre_fn)
for ll in lites:
tracer.add_traced(ll, "__init__", pre_fn=pre_fn)
tracer._instrument()
work_run()
ret_val = work_run()
tracer._restore()
return ret_val


class LiteMultiNode(MultiNode):
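In short, the new pre_fn reduces to a small kwargs-normalization step: pin devices and num_nodes, drop to accelerator="cpu" when MPS is available, and map spawn strategy names to their non-spawn counterparts. A standalone sketch of that logic; the function name _normalize_lite_kwargs and the mps_available flag are illustrative, not part of the PR, and the sketch only handles string strategies (instances are rejected in the diff above):

import warnings
from typing import Any, Dict


def _normalize_lite_kwargs(kwargs: Dict[str, Any], mps_available: bool, nprocs: int, num_nodes: int) -> Dict[str, Any]:
    kwargs = dict(kwargs)  # do not mutate the caller's kwargs
    kwargs["devices"] = nprocs
    kwargs["num_nodes"] = num_nodes
    if mps_available:
        if kwargs.get("accelerator", "auto") != "cpu":
            # same condition as in the PR: warn only if the user asked for something else
            warnings.warn("Forcing accelerator=cpu: PyTorch does not support distributed training on MPS devices")
        kwargs["accelerator"] = "cpu"
    else:
        kwargs["accelerator"] = "auto"
    # spawn strategy names are remapped to their non-spawn counterparts, as in the diff above
    remap = {"ddp_spawn": "ddp", "ddp_sharded_spawn": "ddp_sharded"}
    strategy = kwargs.get("strategy")
    if isinstance(strategy, str):
        kwargs["strategy"] = remap.get(strategy, strategy)
    return kwargs

With mps_available=True, nprocs=8, num_nodes=7 and accelerator="gpu", strategy="ddp_spawn", this yields accelerator="cpu" and strategy="ddp", which matches what the new tests below assert.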
2 changes: 1 addition & 1 deletion src/lightning_app/components/multi_node/pytorch_spawn.py
@@ -88,7 +88,7 @@ def run(
elif world_size > 1:
raise Exception("Torch distributed should be available.")

work_run(world_size, node_rank, global_rank, local_rank)
return work_run(world_size, node_rank, global_rank, local_rank)


class PyTorchSpawnMultiNode(MultiNode):
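The only behavioral change here is that the work's return value is now propagated instead of being discarded. The lite and trainer executors above follow the same instrument/run/restore/return pattern; a tiny generic sketch of it (the try/finally is my simplification, the PR restores after the call returns):

def _run_instrumented(work_run, instrument, restore):
    # Patch the target class, run the user's work, undo the patching,
    # and hand the work's return value back to the caller.
    instrument()
    try:
        return work_run()
    finally:
        restore()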
40 changes: 31 additions & 9 deletions src/lightning_app/components/multi_node/trainer.py
@@ -1,4 +1,6 @@
import importlib
import os
import warnings
from dataclasses import dataclass
from typing import Any, Callable, Type

@@ -30,9 +32,16 @@ def run(
node_rank: int,
nprocs: int,
):
from lightning.lite.strategies import DDPSpawnShardedStrategy, DDPSpawnStrategy
from lightning.pytorch import Trainer as LTrainer
from pytorch_lightning import Trainer as PLTrainer
trainers = []
strategies = []
mps_accelerators = []

for pkg_name in ("lightning.pytorch", "pytorch_" + "lightning"):
pkg = importlib.import_module(pkg_name)
trainers.append(pkg.Trainer)
strategies.append(pkg.strategies.DDPSpawnShardedStrategy)
strategies.append(pkg.strategies.DDPSpawnStrategy)
mps_accelerators.append(pkg.accelerators.MPSAccelerator)

# Used to configure PyTorch process group
os.environ["MASTER_ADDR"] = main_address
@@ -50,24 +59,37 @@ def run(
def pre_fn(trainer, *args, **kwargs):
kwargs["devices"] = nprocs
kwargs["num_nodes"] = num_nodes
kwargs["accelerator"] = "auto"
if any(x.is_available() for x in mps_accelerators):
old_acc_value = kwargs.get("accelerator", "auto")
kwargs["accelerator"] = "cpu"

if old_acc_value != kwargs["accelerator"]:
warnings.warn(
"Forcing accelerator=cpu as other accelerators (specifically MPS) are not supported "
"by PyTorch for distributed training on mps capable devices"
)
else:
kwargs["accelerator"] = "auto"

strategy = kwargs.get("strategy", None)
if strategy:
if isinstance(strategy, str):
if strategy == "ddp_spawn":
strategy = "ddp"
elif strategy == "ddp_sharded_spawn":
strategy = "ddp_sharded"
elif isinstance(strategy, (DDPSpawnStrategy, DDPSpawnShardedStrategy)):
raise Exception("DDP Spawned strategies aren't supported yet.")
elif isinstance(strategy, tuple(strategies)):
raise ValueError("DDP Spawned strategies aren't supported yet.")
kwargs["strategy"] = strategy
return {}, args, kwargs

tracer = Tracer()
tracer.add_traced(PLTrainer, "__init__", pre_fn=pre_fn)
tracer.add_traced(LTrainer, "__init__", pre_fn=pre_fn)
for trainer in trainers:
tracer.add_traced(trainer, "__init__", pre_fn=pre_fn)
tracer._instrument()
work_run()
ret_val = work_run()
tracer._restore()
return ret_val


class LightningTrainerMultiNode(MultiNode):
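For reference, the process-group bookkeeping that the executors export through environment variables, and that the new tests below pin down, boils down to the following arithmetic (the helper name is illustrative; the values mirror the test assertions):

import os


def _configure_process_group_env(main_address: str, main_port: int, num_nodes: int, node_rank: int, nprocs: int, local_rank: int) -> None:
    os.environ["MASTER_ADDR"] = main_address
    os.environ["MASTER_PORT"] = str(main_port)
    os.environ["GROUP_RANK"] = str(node_rank)                  # rank of the node
    os.environ["RANK"] = str(local_rank + node_rank * nprocs)  # global rank of this process
    os.environ["LOCAL_RANK"] = str(local_rank)
    os.environ["WORLD_SIZE"] = str(num_nodes * nprocs)         # total number of processes
    os.environ["LOCAL_WORLD_SIZE"] = str(nprocs)               # processes per node
    os.environ["TORCHELASTIC_RUN_ID"] = "1"                    # constant expected by the tests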
Empty file.
110 changes: 110 additions & 0 deletions tests/tests_app/components/multi_node/test_lite.py
@@ -0,0 +1,110 @@
import os
from copy import deepcopy
from functools import partial
from unittest import mock

import pytest
from lightning_utilities.core.imports import module_available
from tests_app.helpers.utils import no_warning_call

import lightning.lite as ll
from lightning_app.components.multi_node.lite import _LiteRunExecutor


class DummyLite(ll.LightningLite):
def run(self):
pass


def dummy_callable(**kwargs):
lite = DummyLite(**kwargs)
return lite._all_passed_kwargs


def dummy_init(self, **kwargs):
self._all_passed_kwargs = kwargs


def _get_args_after_tracer_injection(**kwargs):
with mock.patch.object(ll.LightningLite, "__init__", dummy_init):
ret_val = _LiteRunExecutor.run(
local_rank=0,
work_run=partial(dummy_callable, **kwargs),
main_address="1.2.3.4",
main_port=5,
node_rank=6,
num_nodes=7,
nprocs=8,
)
env_vars = deepcopy(os.environ)
return ret_val, env_vars


def check_lightning_lite_mps():
if module_available("lightning_lite"):
return ll.accelerators.MPSAccelerator.is_available()
return False


@pytest.mark.skipif(not check_lightning_lite_mps(), reason="Lightning lite not available or mps not available")
@pytest.mark.parametrize("accelerator_given,accelerator_expected", [("cpu", "cpu"), ("auto", "cpu"), ("gpu", "cpu")])
def test_lite_run_executor_mps_forced_cpu(accelerator_given, accelerator_expected):
warning_str = (
r"Forcing accelerator=cpu as other accelerators \(specifically MPS\) are not supported "
+ "by PyTorch for distributed training on mps capable devices"
)
if accelerator_expected != accelerator_given:
warning_context = pytest.warns(UserWarning, match=warning_str)
else:
warning_context = no_warning_call(match=warning_str + "*")

with warning_context:
ret_val, env_vars = _get_args_after_tracer_injection(accelerator=accelerator_given)
assert ret_val["accelerator"] == accelerator_expected


@pytest.mark.parametrize(
"args_given,args_expected",
[
(
{
"devices": 1,
"num_nodes": 1,
"accelerator": "gpu",
},
{"devices": 8, "num_nodes": 7, "accelerator": "auto"},
),
({"strategy": "ddp_spawn"}, {"strategy": "ddp"}),
({"strategy": "ddp_sharded_spawn"}, {"strategy": "ddp_sharded"}),
],
)
@pytest.mark.skipif(not module_available("lightning_lite"), reason="Lightning Lite is required for this test")
def test_trainer_run_executor_arguments_choices(args_given: dict, args_expected: dict):

# ddp with mps devices not available (tested separately, just patching here for cross-os testing of other args)
if ll.accelerators.MPSAccelerator.is_available():
args_expected["accelerator"] = "cpu"

ret_val, env_vars = _get_args_after_tracer_injection(**args_given)

for k, v in args_expected.items():
assert ret_val[k] == v

assert env_vars["MASTER_ADDR"] == "1.2.3.4"
assert env_vars["MASTER_PORT"] == "5"
assert env_vars["GROUP_RANK"] == "6"
assert env_vars["RANK"] == str(0 + 6 * 8)
assert env_vars["LOCAL_RANK"] == "0"
assert env_vars["WORLD_SIZE"] == str(7 * 8)
assert env_vars["LOCAL_WORLD_SIZE"] == "8"
assert env_vars["TORCHELASTIC_RUN_ID"] == "1"
assert env_vars["LT_CLI_USED"] == "1"


@pytest.mark.skipif(not module_available("lightning_lite"), reason="Lightning lite not available")
def test_lite_run_executor_invalid_strategy_instances():
with pytest.raises(ValueError, match="DDP Spawned strategies aren't supported yet."):
_, _ = _get_args_after_tracer_injection(strategy=ll.strategies.DDPSpawnStrategy())

with pytest.raises(ValueError, match="DDP Spawned strategies aren't supported yet."):
_, _ = _get_args_after_tracer_injection(strategy=ll.strategies.DDPSpawnShardedStrategy())
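Both test modules import a no_warning_call helper from tests_app.helpers.utils, whose implementation is not part of this diff. A hypothetical minimal version, shown only to clarify the intent (the real helper may differ):

import re
import warnings
from contextlib import contextmanager


@contextmanager
def no_warning_call(match: str):
    # Record all warnings raised inside the block and fail if any of them
    # matches the given regex pattern.
    with warnings.catch_warnings(record=True) as record:
        warnings.simplefilter("always")
        yield
    for w in record:
        assert not re.search(match, str(w.message)), f"Unexpected warning: {w.message}"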
106 changes: 106 additions & 0 deletions tests/tests_app/components/multi_node/test_trainer.py
@@ -0,0 +1,106 @@
import os
from copy import deepcopy
from functools import partial
from unittest import mock

import pytest
from lightning_utilities.core.imports import module_available
from tests_app.helpers.utils import no_warning_call

import pytorch_lightning as pl
from lightning_app.components.multi_node.trainer import _LightningTrainerRunExecutor


def dummy_callable(**kwargs):
t = pl.Trainer(**kwargs)
return t._all_passed_kwargs


def dummy_init(self, **kwargs):
self._all_passed_kwargs = kwargs


def _get_args_after_tracer_injection(**kwargs):
with mock.patch.object(pl.Trainer, "__init__", dummy_init):
ret_val = _LightningTrainerRunExecutor.run(
local_rank=0,
work_run=partial(dummy_callable, **kwargs),
main_address="1.2.3.4",
main_port=5,
node_rank=6,
num_nodes=7,
nprocs=8,
)
env_vars = deepcopy(os.environ)
return ret_val, env_vars


def check_lightning_pytorch_and_mps():
if module_available("pytorch_lightning"):
return pl.accelerators.MPSAccelerator.is_available()
return False


@pytest.mark.skipif(not check_lightning_pytorch_and_mps(), reason="pytorch_lightning and mps are required")
@pytest.mark.parametrize("accelerator_given,accelerator_expected", [("cpu", "cpu"), ("auto", "cpu"), ("gpu", "cpu")])
def test_trainer_run_executor_mps_forced_cpu(accelerator_given, accelerator_expected):
warning_str = (
r"Forcing accelerator=cpu as other accelerators \(specifically MPS\) are not supported "
+ "by PyTorch for distributed training on mps capable devices"
)
if accelerator_expected != accelerator_given:
warning_context = pytest.warns(UserWarning, match=warning_str)
else:
warning_context = no_warning_call(match=warning_str + "*")

with warning_context:
ret_val, env_vars = _get_args_after_tracer_injection(accelerator=accelerator_given)
assert ret_val["accelerator"] == accelerator_expected


@pytest.mark.parametrize(
"args_given,args_expected",
[
(
{
"devices": 1,
"num_nodes": 1,
"accelerator": "gpu",
},
{"devices": 8, "num_nodes": 7, "accelerator": "auto"},
),
({"strategy": "ddp_spawn"}, {"strategy": "ddp"}),
({"strategy": "ddp_sharded_spawn"}, {"strategy": "ddp_sharded"}),
],
)
@pytest.mark.skipif(not module_available("pytorch_lightning"), reason="Pytorch Lightning is not available")
def test_trainer_run_executor_arguments_choices(
args_given: dict,
args_expected: dict,
):

if pl.accelerators.MPSAccelerator.is_available():
args_expected.pop("accelerator", None) # Cross platform tests -> MPS is tested separately

ret_val, env_vars = _get_args_after_tracer_injection(**args_given)

for k, v in args_expected.items():
assert ret_val[k] == v

assert env_vars["MASTER_ADDR"] == "1.2.3.4"
assert env_vars["MASTER_PORT"] == "5"
assert env_vars["GROUP_RANK"] == "6"
assert env_vars["RANK"] == str(0 + 6 * 8)
assert env_vars["LOCAL_RANK"] == "0"
assert env_vars["WORLD_SIZE"] == str(7 * 8)
assert env_vars["LOCAL_WORLD_SIZE"] == "8"
assert env_vars["TORCHELASTIC_RUN_ID"] == "1"


@pytest.mark.skipif(not module_available("pytorch_lightning"), reason="pytorch_lightning not available")
def test_trainer_run_executor_invalid_strategy_instances():
with pytest.raises(ValueError, match="DDP Spawned strategies aren't supported yet."):
_, _ = _get_args_after_tracer_injection(strategy=pl.strategies.DDPSpawnStrategy())

with pytest.raises(ValueError, match="DDP Spawned strategies aren't supported yet."):
_, _ = _get_args_after_tracer_injection(strategy=pl.strategies.DDPSpawnShardedStrategy())