Multinode on MPS #15748

Merged
58 commits merged on Dec 8, 2022
Changes from 12 commits
Commits
f880c73
Fix restarting attribute for lr finder
justusschock Nov 10, 2022
c13c1a6
update lite executor
justusschock Nov 21, 2022
8fceed0
update trainer executor
justusschock Nov 21, 2022
13cb379
update spawn executor
justusschock Nov 21, 2022
49b20b3
add multinode component tests
justusschock Nov 21, 2022
f80f95b
add testing helpers
justusschock Nov 21, 2022
07ad6aa
add lite tests
justusschock Nov 21, 2022
a1b2e61
add trainer tests
justusschock Nov 21, 2022
e59cbba
Merge branch 'master' into fix/multinod_mps
justusschock Nov 21, 2022
21231c1
Revert "Fix restarting attribute for lr finder"
justusschock Nov 10, 2022
fdaee38
Merge remote-tracking branch 'origin/fix/multinod_mps' into fix/multi…
justusschock Nov 21, 2022
0b3157a
update changelog
justusschock Nov 21, 2022
b1709f0
Merge branch 'master' into fix/multinod_mps
justusschock Nov 22, 2022
cb98589
update skip reasons
justusschock Nov 22, 2022
0e7acf0
skipif
Borda Nov 22, 2022
e4dde73
update skip conditions to only use L.lite and L.pytorch if available
justusschock Nov 23, 2022
b64b19e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 23, 2022
e0989d0
typo
justusschock Nov 23, 2022
29ae553
fix ci
justusschock Nov 23, 2022
be15c13
Update src/lightning_app/CHANGELOG.md
justusschock Nov 24, 2022
7b93a09
test workflow
justusschock Nov 24, 2022
2ac55e4
update trainer
justusschock Nov 24, 2022
f87fc0e
update workflow
justusschock Nov 24, 2022
5760141
Merge branch 'master' into fix/multinod_mps
justusschock Nov 24, 2022
e4911e6
update tests
justusschock Nov 24, 2022
44fe438
Update ci-app-tests.yml
justusschock Nov 24, 2022
31ee30b
Update tests/tests_app/components/multi_node/test_lite.py
awaelchli Nov 24, 2022
f598ecb
debug
awaelchli Nov 24, 2022
7914c33
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 24, 2022
1365d90
debug
awaelchli Nov 24, 2022
d0db952
Merge remote-tracking branch 'origin/fix/multinod_mps' into fix/multi…
awaelchli Nov 24, 2022
a794104
update executors to work with standalone and unified
justusschock Nov 24, 2022
1dfe1df
Merge branch 'master' into fix/multinod_mps
justusschock Nov 24, 2022
4bb8580
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 24, 2022
bfd8845
add reason for skipif
justusschock Nov 24, 2022
269952d
Merge branch 'master' into fix/multinod_mps
awaelchli Nov 24, 2022
5295c45
update test
justusschock Nov 25, 2022
b9af23d
update test
justusschock Nov 25, 2022
bb73075
update test
justusschock Nov 25, 2022
7eeb773
update test
justusschock Nov 25, 2022
cfe9226
update test
justusschock Nov 25, 2022
4934ac1
update test
justusschock Nov 25, 2022
63d40b5
Merge branch 'master' into fix/multinod_mps
justusschock Nov 25, 2022
5cb2f45
Merge branch 'master' into fix/multinod_mps
justusschock Nov 25, 2022
2563a8e
Merge branch 'master' into fix/multinod_mps
justusschock Nov 26, 2022
279ac50
Merge branch 'master' into fix/multinod_mps
awaelchli Dec 5, 2022
f0a83e8
Merge branch 'master' into fix/multinod_mps
Borda Dec 6, 2022
95efc21
Merge branch 'master' into fix/multinod_mps
Borda Dec 7, 2022
9d69ce1
Apply suggestions from code review
Borda Dec 7, 2022
ae285c9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 7, 2022
6e0e00b
ll
Borda Dec 7, 2022
24ca54f
switch skipif
Borda Dec 7, 2022
935a162
.
Borda Dec 7, 2022
0c0aa6d
another
Borda Dec 7, 2022
733b6f3
Merge branch 'master' into fix/multinod_mps
Borda Dec 7, 2022
13e832c
Merge branch 'master' into fix/multinod_mps
Borda Dec 8, 2022
c24e1a0
Apply suggestions from code review
Borda Dec 8, 2022
ee0f19a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 8, 2022
3 changes: 3 additions & 0 deletions src/lightning_app/CHANGELOG.md
@@ -46,6 +46,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed race condition to over-write the frontend with app infos ([#15398](https://github.com/Lightning-AI/lightning/pull/15398))


- Fixed MPS error for multinode component (now defaults to CPU on MPS devices, since distributed operations are not supported by PyTorch on MPS) ([#15748](https://github.com/Lightning-AI/lightning/pull/15748))


-


23 changes: 20 additions & 3 deletions src/lightning_app/components/multi_node/lite.py
@@ -1,4 +1,5 @@
import os
import warnings
from dataclasses import dataclass
from typing import Any, Callable, Type

@@ -31,6 +32,7 @@ def run(
nprocs: int,
):
from lightning.lite import LightningLite
from lightning.lite.accelerators import MPSAccelerator
from lightning.lite.strategies import DDPSpawnShardedStrategy, DDPSpawnStrategy

# Used to configure PyTorch progress group
@@ -52,7 +54,18 @@ def run(
def pre_fn(lite, *args, **kwargs):
kwargs["devices"] = nprocs
kwargs["num_nodes"] = num_nodes
kwargs["accelerator"] = "auto"

if MPSAccelerator.is_available():
old_acc_value = kwargs.get("accelerator", "auto")
kwargs["accelerator"] = "cpu"

if old_acc_value != kwargs["accelerator"]:
warnings.warn(
"Forcing accelerator=cpu as other accelerators (specifically MPS) are not supported "
"by PyTorch for distributed training on mps capable devices"
)
else:
kwargs["accelerator"] = "auto"
strategy = kwargs.get("strategy", None)
if strategy:
if isinstance(strategy, str):
@@ -61,14 +74,18 @@ def pre_fn(lite, *args, **kwargs):
elif strategy == "ddp_sharded_spawn":
strategy = "ddp_sharded"
elif isinstance(strategy, (DDPSpawnStrategy, DDPSpawnShardedStrategy)):
raise Exception("DDP Spawned strategies aren't supported yet.")
raise ValueError("DDP Spawned strategies aren't supported yet.")

kwargs["strategy"] = strategy

return {}, args, kwargs

tracer = Tracer()
tracer.add_traced(LightningLite, "__init__", pre_fn=pre_fn)
tracer._instrument()
work_run()
ret_val = work_run()
tracer._restore()
return ret_val


class LiteMultiNode(MultiNode):
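The kwargs rewrite performed by `pre_fn` above can be summarized in a small, self-contained sketch (the helper name and standalone form are illustrative only, not part of the diff): when MPS is available, any requested accelerator is forced to `cpu` with a warning; otherwise the accelerator is reset to `auto`.

```python
import warnings


def _force_cpu_on_mps(kwargs: dict, mps_available: bool) -> dict:
    # Mirrors the pre_fn logic above: PyTorch does not support distributed
    # operations on MPS, so the accelerator is forced to "cpu" on MPS machines.
    kwargs = dict(kwargs)
    if mps_available:
        old_acc_value = kwargs.get("accelerator", "auto")
        kwargs["accelerator"] = "cpu"
        if old_acc_value != kwargs["accelerator"]:
            warnings.warn(
                "Forcing accelerator=cpu as other accelerators (specifically MPS) are not supported "
                "by PyTorch for distributed training on mps capable devices"
            )
    else:
        kwargs["accelerator"] = "auto"
    return kwargs


# On an MPS machine a "gpu" (or "auto") request collapses to "cpu" and warns;
# on other machines the accelerator is simply reset to "auto".
print(_force_cpu_on_mps({"accelerator": "gpu"}, mps_available=True))   # {'accelerator': 'cpu'}
print(_force_cpu_on_mps({"accelerator": "gpu"}, mps_available=False))  # {'accelerator': 'auto'}
```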
2 changes: 1 addition & 1 deletion src/lightning_app/components/multi_node/pytorch_spawn.py
@@ -88,7 +88,7 @@ def run(
elif world_size > 1:
raise Exception("Torch distributed should be available.")

work_run(world_size, node_rank, global_rank, local_rank)
return work_run(world_size, node_rank, global_rank, local_rank)


class PyTorchSpawnMultiNode(MultiNode):
20 changes: 16 additions & 4 deletions src/lightning_app/components/multi_node/trainer.py
@@ -1,4 +1,5 @@
import os
import warnings
from dataclasses import dataclass
from typing import Any, Callable, Type

@@ -30,8 +31,9 @@ def run(
node_rank: int,
nprocs: int,
):
from lightning.lite.strategies import DDPSpawnShardedStrategy, DDPSpawnStrategy
from lightning.pytorch import Trainer as LTrainer
from lightning.pytorch.accelerators import MPSAccelerator
from lightning.pytorch.strategies import DDPSpawnShardedStrategy, DDPSpawnStrategy
from pytorch_lightning import Trainer as PLTrainer

# Used to configure PyTorch progress group
@@ -50,7 +52,15 @@ def pre_fn(trainer, *args, **kwargs):
def pre_fn(trainer, *args, **kwargs):
kwargs["devices"] = nprocs
kwargs["num_nodes"] = num_nodes
kwargs["accelerator"] = "auto"
if MPSAccelerator.is_available():
old_acc_value = kwargs.get("accelerator", "auto")
kwargs["accelerator"] = "cpu"

if old_acc_value != kwargs["accelerator"]:
warnings.warn(
"Forcing accelerator=cpu as other accelerators (specifically MPS) are not supported "
"by PyTorch for distributed training on mps capable devices"
)
strategy = kwargs.get("strategy", None)
if strategy:
if isinstance(strategy, str):
@@ -59,15 +69,17 @@ def pre_fn(trainer, *args, **kwargs):
elif strategy == "ddp_sharded_spawn":
strategy = "ddp_sharded"
elif isinstance(strategy, (DDPSpawnStrategy, DDPSpawnShardedStrategy)):
raise Exception("DDP Spawned strategies aren't supported yet.")
raise ValueError("DDP Spawned strategies aren't supported yet.")
kwargs["strategy"] = strategy
return {}, args, kwargs

tracer = Tracer()
tracer.add_traced(PLTrainer, "__init__", pre_fn=pre_fn)
tracer.add_traced(LTrainer, "__init__", pre_fn=pre_fn)
tracer._instrument()
work_run()
ret_val = work_run()
tracer._restore()
return ret_val


class LightningTrainerMultiNode(MultiNode):
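The executors' Tracer usage above (add_traced → _instrument → run the work → _restore, now also returning the work's result) follows a patch-and-restore pattern; both `pytorch_lightning.Trainer` and `lightning.pytorch.Trainer` are traced so the kwargs rewrite applies to the standalone and the unified package alike. A rough stand-in using `unittest.mock` — purely illustrative, not the real Tracer API — looks like this:

```python
from contextlib import contextmanager
from unittest import mock


@contextmanager
def _traced_init(cls, pre_fn):
    # Rough stand-in for Tracer.add_traced/_instrument/_restore: wrap
    # cls.__init__ so pre_fn can rewrite args/kwargs before the real call,
    # then restore the original __init__ when the context exits.
    original_init = cls.__init__

    def patched_init(self, *args, **kwargs):
        _, args, kwargs = pre_fn(self, *args, **kwargs)
        original_init(self, *args, **kwargs)

    with mock.patch.object(cls, "__init__", patched_init):
        yield


class _DummyTrainer:
    def __init__(self, **kwargs):
        self.kwargs = kwargs


def _pre_fn(trainer, *args, **kwargs):
    kwargs["devices"] = 8
    return {}, args, kwargs


with _traced_init(_DummyTrainer, _pre_fn):
    ret_val = _DummyTrainer(accelerator="cpu").kwargs
print(ret_val)  # {'accelerator': 'cpu', 'devices': 8}
```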
98 changes: 98 additions & 0 deletions tests/tests_app/components/multi_node/test_lite.py
@@ -0,0 +1,98 @@
import os
from copy import deepcopy
from functools import partial
from unittest import mock

import pytest
from lightning_utilities.core.imports import module_available
from tests_app.helpers.utils import no_warning_call

import lightning as L
from lightning_app.components.multi_node.lite import _LiteRunExecutor


def dummy_callable(**kwargs):
ll = L.lite.LightningLite(**kwargs)
return ll._all_passed_kwargs


def dummy_init(self, **kwargs):
self._all_passed_kwargs = kwargs


def _get_args_after_tracer_injection(**kwargs):
with mock.patch.object(L.lite.LightningLite, "__init__", dummy_init):
ret_val = _LiteRunExecutor.run(
local_rank=0,
work_run=partial(dummy_callable, **kwargs),
main_address="1.2.3.4",
main_port=5,
node_rank=6,
num_nodes=7,
nprocs=8,
)
env_vars = deepcopy(os.environ)
return ret_val, env_vars


@pytest.mark.skipif(not module_available("lightning.lite"), reason="Lightning.lite not available")
@pytest.mark.skipif(not L.lite.accelerators.MPSAccelerator.is_available(), reason="mps not available")
@pytest.mark.parametrize("accelerator_given,accelerator_expected", [("cpu", "cpu"), ("auto", "cpu"), ("gpu", "cpu")])
def test_lite_run_executor_mps_forced_cpu(accelerator_given, accelerator_expected):
warning_str = (
r"Forcing accelerator=cpu as other accelerators \(specifically MPS\) are not supported "
+ "by PyTorch for distributed training on mps capable devices"
)
if accelerator_expected != accelerator_given:
warning_context = pytest.warns(UserWarning, match=warning_str)
else:
warning_context = no_warning_call(match=warning_str + "*")

with warning_context:
ret_val, env_vars = _get_args_after_tracer_injection(accelerator=accelerator_given)
assert ret_val["accelerator"] == accelerator_expected


@pytest.mark.parametrize(
"args_given,args_expected",
[
(
{
"devices": 1,
"num_nodes": 1,
"accelerator": "gpu",
},
{"devices": 8, "num_nodes": 7, "accelerator": "auto"},
),
({"strategy": "ddp_spawn"}, {"strategy": "ddp"}),
({"strategy": "ddp_sharded_spawn"}, {"strategy": "ddp_sharded"}),
],
)
def test_lite_run_executor_arguments_choices(args_given: dict, args_expected: dict):

# ddp with mps devices not available (tested separately, just patching here for cross-os testing of other args)
if L.lite.accelerators.MPSAccelerator.is_available():
args_expected["accelerator"] = "cpu"

ret_val, env_vars = _get_args_after_tracer_injection(**args_given)

for k, v in args_expected.items():
assert ret_val[k] == v

assert env_vars["MASTER_ADDR"] == "1.2.3.4"
assert env_vars["MASTER_PORT"] == "5"
assert env_vars["GROUP_RANK"] == "6"
assert env_vars["RANK"] == str(0 + 6 * 8)
assert env_vars["LOCAL_RANK"] == "0"
assert env_vars["WORLD_SIZE"] == str(7 * 8)
assert env_vars["LOCAL_WORLD_SIZE"] == "8"
assert env_vars["TORCHELASTIC_RUN_ID"] == "1"
assert env_vars["LT_CLI_USED"] == "1"


def test_lite_run_executor_invalid_strategy_instances():
with pytest.raises(ValueError, match="DDP Spawned strategies aren't supported yet."):
_, _ = _get_args_after_tracer_injection(strategy=L.lite.strategies.DDPSpawnStrategy())

with pytest.raises(ValueError, match="DDP Spawned strategies aren't supported yet."):
_, _ = _get_args_after_tracer_injection(strategy=L.lite.strategies.DDPSpawnShardedStrategy())
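The environment-variable assertions in the test above encode how the executor derives the torch.distributed process-group settings from `num_nodes`, `nprocs`, `node_rank`, and `local_rank`. A small sketch of that arithmetic (the helper is illustrative, not part of the diff):

```python
def _expected_dist_env(num_nodes: int, nprocs: int, node_rank: int, local_rank: int) -> dict:
    # Global RANK counts processes across all nodes; WORLD_SIZE is the total
    # number of processes launched over the whole cluster.
    return {
        "GROUP_RANK": str(node_rank),
        "RANK": str(local_rank + node_rank * nprocs),
        "LOCAL_RANK": str(local_rank),
        "WORLD_SIZE": str(num_nodes * nprocs),
        "LOCAL_WORLD_SIZE": str(nprocs),
    }


# Matches the test inputs: num_nodes=7, nprocs=8, node_rank=6, local_rank=0.
env = _expected_dist_env(num_nodes=7, nprocs=8, node_rank=6, local_rank=0)
assert env["RANK"] == "48" and env["WORLD_SIZE"] == "56"
```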
97 changes: 97 additions & 0 deletions tests/tests_app/components/multi_node/test_trainer.py
@@ -0,0 +1,97 @@
import os
from copy import deepcopy
from functools import partial
from unittest import mock

import pytest
from lightning_utilities.core.imports import module_available
from tests_app.helpers.utils import no_warning_call

import lightning as L
from lightning_app.components.multi_node.trainer import _LightningTrainerRunExecutor


def dummy_callable(**kwargs):
t = L.pytorch.Trainer(**kwargs)
return t._all_passed_kwargs


def dummy_init(self, **kwargs):
self._all_passed_kwargs = kwargs


def _get_args_after_tracer_injection(**kwargs):
with mock.patch.object(L.pytorch.Trainer, "__init__", dummy_init):
ret_val = _LightningTrainerRunExecutor.run(
local_rank=0,
work_run=partial(dummy_callable, **kwargs),
main_address="1.2.3.4",
main_port=5,
node_rank=6,
num_nodes=7,
nprocs=8,
)
env_vars = deepcopy(os.environ)
return ret_val, env_vars


@pytest.mark.skipif(not module_available("lightning.pytorch"))
@pytest.mark.skipif(not L.pytorch.accelerators.MPSAccelerator.is_available())
@pytest.mark.parametrize("accelerator_given,accelerator_expected", [("cpu", "cpu"), ("auto", "cpu"), ("gpu", "cpu")])
def test_trainer_run_executor_mps_forced_cpu(accelerator_given, accelerator_expected):
warning_str = (
r"Forcing accelerator=cpu as other accelerators \(specifically MPS\) are not supported "
+ "by PyTorch for distributed training on mps capable devices"
)
if accelerator_expected != accelerator_given:
warning_context = pytest.warns(UserWarning, match=warning_str)
else:
warning_context = no_warning_call(match=warning_str + "*")

with warning_context:
ret_val, env_vars = _get_args_after_tracer_injection(accelerator=accelerator_given)
assert ret_val["accelerator"] == accelerator_expected


@pytest.mark.parametrize(
"args_given,args_expected",
[
(
{
"devices": 1,
"num_nodes": 1,
"accelerator": "gpu",
},
{"devices": 8, "num_nodes": 7, "accelerator": "auto"},
),
({"strategy": "ddp_spawn"}, {"strategy": "ddp"}),
({"strategy": "ddp_sharded_spawn"}, {"strategy": "ddp_sharded"}),
],
)
def test_trainer_run_executor_arguments_choices(args_given: dict, args_expected: dict):

# ddp with mps devices not available (tested separately, just patching here for cross-os testing of other args)
if L.pytorch.accelerators.MPSAccelerator.is_available():
args_expected["accelerator"] = "cpu"

ret_val, env_vars = _get_args_after_tracer_injection(**args_given)

for k, v in args_expected.items():
assert ret_val[k] == v

assert env_vars["MASTER_ADDR"] == "1.2.3.4"
assert env_vars["MASTER_PORT"] == "5"
assert env_vars["GROUP_RANK"] == "6"
assert env_vars["RANK"] == str(0 + 6 * 8)
assert env_vars["LOCAL_RANK"] == "0"
assert env_vars["WORLD_SIZE"] == str(7 * 8)
assert env_vars["LOCAL_WORLD_SIZE"] == "8"
assert env_vars["TORCHELASTIC_RUN_ID"] == "1"


def test_trainer_run_executor_invalid_strategy_instances():
with pytest.raises(ValueError, match="DDP Spawned strategies aren't supported yet."):
_, _ = _get_args_after_tracer_injection(strategy=L.pytorch.strategies.DDPSpawnStrategy())

with pytest.raises(ValueError, match="DDP Spawned strategies aren't supported yet."):
_, _ = _get_args_after_tracer_injection(strategy=L.pytorch.strategies.DDPSpawnShardedStrategy())
27 changes: 27 additions & 0 deletions tests/tests_app/helpers/utils.py
@@ -0,0 +1,27 @@
import re
from contextlib import contextmanager
from typing import Optional, Type

import pytest


@contextmanager
def no_warning_call(expected_warning: Type[Warning] = UserWarning, match: Optional[str] = None):
with pytest.warns(None) as record:
yield

if match is None:
try:
w = record.pop(expected_warning)
except AssertionError:
# no warning raised
return
else:
for w in record.list:
if w.category is expected_warning and re.compile(match).search(w.message.args[0]):
break
else:
return

msg = "A warning" if expected_warning is None else f"`{expected_warning.__name__}`"
raise AssertionError(f"{msg} was raised: {w}")
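A short usage sketch of the helper (illustrative only; it assumes a pytest version that still accepts `pytest.warns(None)`, which the helper itself relies on): the context passes when no warning matching the pattern is emitted, and raises `AssertionError` when a matching one is.

```python
import warnings

import pytest


# Passes: a warning is emitted, but it does not match the given pattern.
with no_warning_call(UserWarning, match="deprecated"):
    warnings.warn("something unrelated")

# Fails: a matching UserWarning is emitted, so no_warning_call raises.
with pytest.raises(AssertionError):
    with no_warning_call(UserWarning, match="deprecated"):
        warnings.warn("this API is deprecated")
```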