test_trainer.py
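"""Tests for ``_LightningTrainerRunExecutor`` from ``lightning_app.components.multi_node.trainer``.

These tests patch ``pl.Trainer.__init__`` so the executor can run without building a real Trainer;
they then check how the Trainer arguments are rewritten for multi-node execution and which
torch.distributed environment variables are exported.
"""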
import os
from copy import deepcopy
from functools import partial
from unittest import mock

import pytest
from lightning_utilities.core.imports import module_available
from tests_app.helpers.utils import no_warning_call

import pytorch_lightning as pl
from lightning_app.components.multi_node.trainer import _LightningTrainerRunExecutor


def dummy_callable(**kwargs):
    t = pl.Trainer(**kwargs)
    return t._all_passed_kwargs


def dummy_init(self, **kwargs):
    self._all_passed_kwargs = kwargs
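

# ``dummy_init`` is patched in below as ``pl.Trainer.__init__``, so constructing a Trainer only
# records the kwargs it received; ``dummy_callable`` then returns that record so the tests can
# inspect the arguments the executor forwarded.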
def _get_args_after_tracer_injection(**kwargs):
    with mock.patch.object(pl.Trainer, "__init__", dummy_init):
        ret_val = _LightningTrainerRunExecutor.run(
            local_rank=0,
            work_run=partial(dummy_callable, **kwargs),
            main_address="1.2.3.4",
            main_port=5,
            node_rank=6,
            num_nodes=7,
            nprocs=8,
        )
        env_vars = deepcopy(os.environ)  # snapshot the environment variables set by the executor
    return ret_val, env_vars
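

# With the fixed arguments above (local_rank=0, node_rank=6, num_nodes=7, nprocs=8) the executor
# is expected to export the usual torch.distributed variables, e.g.
#   RANK       = local_rank + node_rank * nprocs -> "48"
#   WORLD_SIZE = num_nodes * nprocs              -> "56"
# which is what ``test_trainer_run_executor_arguments_choices`` asserts below.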


def check_lightning_pytorch_and_mps():
    if module_available("pytorch_lightning"):
        return pl.accelerators.MPSAccelerator.is_available()
    return False


@pytest.mark.skipif(not check_lightning_pytorch_and_mps(), reason="pytorch_lightning and mps are required")
@pytest.mark.parametrize("accelerator_given,accelerator_expected", [("cpu", "cpu"), ("auto", "cpu"), ("gpu", "cpu")])
def test_trainer_run_executor_mps_forced_cpu(accelerator_given, accelerator_expected):
    warning_str = (
        r"Forcing accelerator=cpu as other accelerators \(specifically MPS\) are not supported "
        + "by PyTorch for distributed training on mps capable devices"
    )
    if accelerator_expected != accelerator_given:
        warning_context = pytest.warns(UserWarning, match=warning_str)
    else:
        warning_context = no_warning_call(match=warning_str + "*")
    with warning_context:
        ret_val, env_vars = _get_args_after_tracer_injection(accelerator=accelerator_given)
        assert ret_val["accelerator"] == accelerator_expected


@pytest.mark.parametrize(
    "args_given,args_expected",
    [
        ({"devices": 1, "num_nodes": 1, "accelerator": "gpu"}, {"devices": 8, "num_nodes": 7, "accelerator": "auto"}),
        ({"strategy": "ddp_spawn"}, {"strategy": "ddp"}),
        ({"strategy": "ddp_sharded_spawn"}, {"strategy": "ddp_sharded"}),
    ],
)
@pytest.mark.skipif(not module_available("torch"), reason="PyTorch is not available")
def test_trainer_run_executor_arguments_choices(
    args_given: dict,
    args_expected: dict,
):
    if pl.accelerators.MPSAccelerator.is_available():
        args_expected.pop("accelerator", None)  # Cross platform tests -> MPS is tested separately
    ret_val, env_vars = _get_args_after_tracer_injection(**args_given)
    for k, v in args_expected.items():
        assert ret_val[k] == v
    assert env_vars["MASTER_ADDR"] == "1.2.3.4"
    assert env_vars["MASTER_PORT"] == "5"
    assert env_vars["GROUP_RANK"] == "6"
    assert env_vars["RANK"] == str(0 + 6 * 8)
    assert env_vars["LOCAL_RANK"] == "0"
    assert env_vars["WORLD_SIZE"] == str(7 * 8)
    assert env_vars["LOCAL_WORLD_SIZE"] == "8"
    assert env_vars["TORCHELASTIC_RUN_ID"] == "1"


@pytest.mark.skipif(not module_available("lightning"), reason="lightning not available")
def test_trainer_run_executor_invalid_strategy_instances():
    with pytest.raises(ValueError, match="DDP Spawned strategies aren't supported yet."):
        _, _ = _get_args_after_tracer_injection(strategy=pl.strategies.DDPSpawnStrategy())

    with pytest.raises(ValueError, match="DDP Spawned strategies aren't supported yet."):
        _, _ = _get_args_after_tracer_injection(strategy=pl.strategies.DDPSpawnShardedStrategy())