lite.py
import importlib
import os
import warnings
from dataclasses import dataclass
from typing import Any, Callable, Type

from typing_extensions import Protocol, runtime_checkable

from lightning_app.components.multi_node.base import MultiNode
from lightning_app.components.multi_node.pytorch_spawn import _PyTorchSpawnRunExecutor
from lightning_app.core.work import LightningWork
from lightning_app.utilities.packaging.cloud_compute import CloudCompute
from lightning_app.utilities.tracer import Tracer


@runtime_checkable
class _LiteWorkProtocol(Protocol):
    @staticmethod
    def run() -> None:
        ...


@dataclass
class _LiteRunExecutor(_PyTorchSpawnRunExecutor):
    @staticmethod
    def run(
        local_rank: int,
        work_run: Callable,
        main_address: str,
        main_port: int,
        num_nodes: int,
        node_rank: int,
        nprocs: int,
    ):
        lites = []
        strategies = []
        mps_accelerators = []

        # Collect the Lite class, the spawn strategies, and the MPS accelerator from
        # whichever package layout is installed (unified `lightning` or standalone).
        for pkg_name in ("lightning.lite", "lightning_" + "lite"):
            try:
                pkg = importlib.import_module(pkg_name)
                lites.append(pkg.LightningLite)
                strategies.append(pkg.strategies.DDPSpawnShardedStrategy)
                strategies.append(pkg.strategies.DDPSpawnStrategy)
                mps_accelerators.append(pkg.accelerators.MPSAccelerator)
            except (ImportError, ModuleNotFoundError):
                continue

        # Used to configure the PyTorch process group.
        os.environ["MASTER_ADDR"] = main_address
        os.environ["MASTER_PORT"] = str(main_port)

        # Used to hijack the TorchElastic cluster environment.
        os.environ["GROUP_RANK"] = str(node_rank)
        os.environ["RANK"] = str(local_rank + node_rank * nprocs)
        os.environ["LOCAL_RANK"] = str(local_rank)
        os.environ["WORLD_SIZE"] = str(num_nodes * nprocs)
        os.environ["LOCAL_WORLD_SIZE"] = str(nprocs)
        os.environ["TORCHELASTIC_RUN_ID"] = "1"

        # Used to force Lite to set up the distributed environment.
        os.environ["LT_CLI_USED"] = "1"

        # Used to pass information to Lite directly.
        def pre_fn(lite, *args, **kwargs):
            kwargs["devices"] = nprocs
            kwargs["num_nodes"] = num_nodes

            if any(acc.is_available() for acc in mps_accelerators):
                old_acc_value = kwargs.get("accelerator", "auto")
                kwargs["accelerator"] = "cpu"

                if old_acc_value != kwargs["accelerator"]:
                    warnings.warn("Forcing `accelerator=cpu` as MPS does not support distributed training.")
            else:
                kwargs["accelerator"] = "auto"

            strategy = kwargs.get("strategy", None)
            if strategy:
                if isinstance(strategy, str):
                    if strategy == "ddp_spawn":
                        strategy = "ddp"
                    elif strategy == "ddp_sharded_spawn":
                        strategy = "ddp_sharded"
                elif isinstance(strategy, tuple(strategies)):
                    raise ValueError("DDP Spawned strategies aren't supported yet.")
                kwargs["strategy"] = strategy
            return {}, args, kwargs

        tracer = Tracer()
        for ll in lites:
            tracer.add_traced(ll, "__init__", pre_fn=pre_fn)
        tracer._instrument()
        ret_val = work_run()
        tracer._restore()
        return ret_val


class LiteMultiNode(MultiNode):
    def __init__(
        self,
        work_cls: Type["LightningWork"],
        cloud_compute: "CloudCompute",
        num_nodes: int,
        *work_args: Any,
        **work_kwargs: Any,
    ) -> None:
        assert issubclass(work_cls, _LiteWorkProtocol)

        # Note: private way to modify the work's run executor.
        # It might be exposed to users in the future if needed.
        work_cls._run_executor_cls = _LiteRunExecutor

        super().__init__(
            work_cls,
            *work_args,
            num_nodes=num_nodes,
            cloud_compute=cloud_compute,
            **work_kwargs,
        )
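
A minimal usage sketch, assuming this module's path and the top-level `lightning_app` exports; the `LiteWork` class and the `"gpu-fast-multi"` compute name are illustrative, not part of this file. Note that `devices`, `num_nodes`, `accelerator`, and `strategy` are injected into `LightningLite.__init__` by `_LiteRunExecutor` at runtime.

from lightning_app import LightningApp
from lightning_app.components.multi_node.lite import LiteMultiNode
from lightning_app.core.work import LightningWork
from lightning_app.utilities.packaging.cloud_compute import CloudCompute


class LiteWork(LightningWork):
    def run(self):
        from lightning.lite import LightningLite

        # devices/num_nodes/accelerator/strategy are injected through the tracer,
        # so a bare LightningLite() is enough here.
        lite = LightningLite()
        lite.print("hello from global rank", lite.global_rank)


app = LightningApp(
    LiteMultiNode(
        LiteWork,
        num_nodes=2,
        cloud_compute=CloudCompute("gpu-fast-multi"),
    )
)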