-
Notifications
You must be signed in to change notification settings - Fork 3.3k
/
test_trainer.py
97 lines (77 loc) · 3.49 KB
/
test_trainer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import os
from copy import deepcopy
from functools import partial
from unittest import mock
import pytest
from lightning_utilities.core.imports import module_available
from tests_app.helpers.utils import no_warning_call
import lightning as L
from lightning_app.components.multi_node.trainer import _LightningTrainerRunExecutor
def dummy_callable(**kwargs):
    """Build a Trainer and return the kwargs its (patched) constructor saw.

    Only meaningful while ``Trainer.__init__`` is monkey-patched with
    ``dummy_init``, which stashes the keyword arguments on the instance.
    """
    trainer = L.pytorch.Trainer(**kwargs)
    return trainer._all_passed_kwargs
def dummy_init(self, **kwargs):
    """Replacement ``Trainer.__init__`` that records the received kwargs on the instance."""
    self._all_passed_kwargs = dict(kwargs)
def _get_args_after_tracer_injection(**kwargs):
    """Run the multi-node trainer executor with ``Trainer.__init__`` stubbed out.

    Returns a tuple of:
      * the kwargs the Trainer constructor received after the executor's
        argument injection/rewriting, and
      * a snapshot of ``os.environ`` taken right after the run (so the
        distributed env vars set by the executor can be asserted on).
    """
    with mock.patch.object(L.pytorch.Trainer, "__init__", dummy_init):
        trainer_kwargs = _LightningTrainerRunExecutor.run(
            local_rank=0,
            work_run=partial(dummy_callable, **kwargs),
            main_address="1.2.3.4",
            main_port=5,
            node_rank=6,
            num_nodes=7,
            nprocs=8,
        )
        env_snapshot = deepcopy(os.environ)
    return trainer_kwargs, env_snapshot
# NOTE: pytest.mark.skipif with a boolean condition REQUIRES reason=; without it
# pytest errors at collection time ("you need to specify reason=STRING when
# using booleans as conditions"), so these markers previously broke the test.
@pytest.mark.skipif(not module_available("lightning.pytorch"), reason="lightning.pytorch is required")
@pytest.mark.skipif(
    not L.pytorch.accelerators.MPSAccelerator.is_available(), reason="MPS accelerator is required"
)
@pytest.mark.parametrize("accelerator_given,accelerator_expected", [("cpu", "cpu"), ("auto", "cpu"), ("gpu", "cpu")])
def test_trainer_run_executor_mps_forced_cpu(accelerator_given, accelerator_expected):
    """On MPS-capable machines the executor must force accelerator='cpu'.

    A UserWarning is expected exactly when the accelerator is rewritten;
    when the user already asked for 'cpu', no warning may be emitted.
    """
    warning_str = (
        r"Forcing accelerator=cpu as other accelerators \(specifically MPS\) are not supported "
        + "by PyTorch for distributed training on mps capable devices"
    )
    if accelerator_expected != accelerator_given:
        warning_context = pytest.warns(UserWarning, match=warning_str)
    else:
        warning_context = no_warning_call(match=warning_str + "*")
    with warning_context:
        ret_val, env_vars = _get_args_after_tracer_injection(accelerator=accelerator_given)
        assert ret_val["accelerator"] == accelerator_expected
@pytest.mark.parametrize(
    "args_given,args_expected",
    [
        (
            {
                "devices": 1,
                "num_nodes": 1,
                "accelerator": "gpu",
            },
            {"devices": 8, "num_nodes": 7, "accelerator": "auto"},
        ),
        ({"strategy": "ddp_spawn"}, {"strategy": "ddp"}),
        ({"strategy": "ddp_sharded_spawn"}, {"strategy": "ddp_sharded"}),
    ],
)
def test_trainer_run_executor_arguments_choices(args_given: dict, args_expected: dict):
    """Verify Trainer-kwarg rewriting and the distributed env vars set by the executor."""
    # ddp with mps devices not available (tested separately, just patching here for cross-os testing of other args)
    if L.pytorch.accelerators.MPSAccelerator.is_available():
        args_expected["accelerator"] = "cpu"

    ret_val, env_vars = _get_args_after_tracer_injection(**args_given)

    for key, expected in args_expected.items():
        assert ret_val[key] == expected

    # Values derived from the fixed run() arguments in _get_args_after_tracer_injection:
    # main_address=1.2.3.4, main_port=5, node_rank=6, num_nodes=7, nprocs=8, local_rank=0.
    expected_env = {
        "MASTER_ADDR": "1.2.3.4",
        "MASTER_PORT": "5",
        "GROUP_RANK": "6",
        "RANK": str(0 + 6 * 8),
        "LOCAL_RANK": "0",
        "WORLD_SIZE": str(7 * 8),
        "LOCAL_WORLD_SIZE": "8",
        "TORCHELASTIC_RUN_ID": "1",
    }
    for name, value in expected_env.items():
        assert env_vars[name] == value
def test_trainer_run_executor_invalid_strategy_instances():
    """Spawn-based strategy *instances* must be rejected with a ValueError."""
    spawn_strategy_classes = (
        L.pytorch.strategies.DDPSpawnStrategy,
        L.pytorch.strategies.DDPSpawnShardedStrategy,
    )
    for strategy_cls in spawn_strategy_classes:
        with pytest.raises(ValueError, match="DDP Spawned strategies aren't supported yet."):
            _get_args_after_tracer_injection(strategy=strategy_cls())