/
training.py
188 lines (159 loc) · 6.49 KB
/
training.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
import os
from typing import Any, Dict, List, Optional, Tuple, Type, Union
from lightning_app import structures
from lightning_app.components.python import TracerPythonScript
from lightning_app.core.flow import LightningFlow
from lightning_app.storage.path import Path
from lightning_app.utilities.app_helpers import Logger
from lightning_app.utilities.packaging.cloud_compute import CloudCompute

# Module-level logger (lightning_app helper wrapping the stdlib logging module).
_logger = Logger(__name__)
class PyTorchLightningScriptRunner(TracerPythonScript):
    """Runs a user-provided PyTorch Lightning training script on one node of a
    (possibly multi-node) distributed training job.

    Arguments:
        script_path: Path to the training script to execute.
        script_args: Arguments passed through to the script.
        node_rank: Rank of this node within the job (0 is the coordinating node).
        num_nodes: Total number of nodes participating in the job.
        sanity_serving: Whether to attach a ``ServableModuleValidator`` callback
            (rank 0 only) validating that the model implements the ServableModule API.
        cloud_compute: Compute requirements used when running in the cloud.
        parallel: Whether this work runs in parallel with the rest of the app.
        raise_exception: Whether exceptions raised by the script are re-raised.
        env: Extra environment variables set before launching the script.
    """

    def __init__(
        self,
        script_path: str,
        script_args: Optional[Union[list, str]] = None,
        node_rank: int = 1,
        num_nodes: int = 1,
        sanity_serving: bool = False,
        cloud_compute: Optional[CloudCompute] = None,
        parallel: bool = True,
        raise_exception: bool = True,
        env: Optional[Dict[str, Any]] = None,
        **kwargs,
    ):
        super().__init__(
            script_path,
            script_args,
            raise_exception=raise_exception,
            parallel=parallel,
            cloud_compute=cloud_compute,
            **kwargs,
        )
        self.node_rank = node_rank
        self.num_nodes = num_nodes
        # Populated by ``on_after_run`` once the traced script has completed.
        self.best_model_path: Optional[Path] = None
        self.best_model_score: Optional[float] = None
        self.monitor: Optional[str] = None
        self.sanity_serving = sanity_serving
        self.has_finished = False
        self.env = env

    def configure_tracer(self):
        from pytorch_lightning import Trainer

        tracer = super().configure_tracer()
        # Intercept ``Trainer.__init__`` so extra callbacks can be injected
        # without modifying the user's script.
        tracer.add_traced(Trainer, "__init__", pre_fn=self._trainer_init_pre_middleware)
        return tracer

    def run(self, internal_urls: Optional[List[Tuple[str, str]]] = None, **kwargs) -> None:
        """Launch the traced script once every node's ``(host, port)`` is known.

        Arguments:
            internal_urls: One ``(internal_ip, port)`` pair per node; entry 0
                is used as the master address. When falsy, the node only
                announces itself and waits.
        """
        if not internal_urls:
            # Note: This is called only once.
            _logger.info(f"The node {self.node_rank} started !")
            return None

        if self.env:
            os.environ.update(self.env)

        # PyTorch Lightning reads these variables to configure its
        # distributed environment; node 0 acts as the master.
        distributed_env_vars = {
            "MASTER_ADDR": internal_urls[0][0],
            "MASTER_PORT": str(internal_urls[0][1]),
            "NODE_RANK": str(self.node_rank),
            "PL_TRAINER_NUM_NODES": str(self.num_nodes),
            "PL_TRAINER_DEVICES": "auto",
            "PL_TRAINER_ACCELERATOR": "auto",
        }
        os.environ.update(distributed_env_vars)
        return super().run(**kwargs)

    def on_after_run(self, script_globals):
        """Collect checkpoint information once the script has finished.

        Scans the script's globals for a ``LightningCLI`` or a bare
        ``Trainer`` and records the best (or last) checkpoint path and score.

        Raises:
            RuntimeError: If no trainer instance is found in the globals.
        """
        from pytorch_lightning import Trainer
        from pytorch_lightning.cli import LightningCLI

        for v in script_globals.values():
            if isinstance(v, LightningCLI):
                trainer = v.trainer
                break
            elif isinstance(v, Trainer):
                trainer = v
                break
        else:
            raise RuntimeError("No trainer instance found.")

        # ``trainer.checkpoint_callback`` is ``None`` when checkpointing is
        # disabled; guard so ``has_finished`` is still set and the flow that
        # waits on it can stop the works.
        checkpoint_callback = trainer.checkpoint_callback
        if checkpoint_callback is not None:
            self.monitor = checkpoint_callback.monitor
            if checkpoint_callback.best_model_score:
                self.best_model_path = Path(checkpoint_callback.best_model_path)
                self.best_model_score = float(checkpoint_callback.best_model_score)
            else:
                # No scored checkpoint yet; fall back to the last one saved.
                self.best_model_path = Path(checkpoint_callback.last_model_path)
        self.has_finished = True

    def _trainer_init_pre_middleware(self, trainer, *args, **kwargs):
        """Pre-hook for ``Trainer.__init__``: optionally injects the serving
        validator callback on the coordinating node (rank 0)."""
        if self.node_rank != 0:
            return {}, args, kwargs

        from pytorch_lightning.serve import ServableModuleValidator

        callbacks = kwargs.get("callbacks", [])
        if self.sanity_serving:
            # Copy instead of appending in place to avoid mutating the
            # caller-provided list.
            callbacks = callbacks + [ServableModuleValidator()]
        kwargs["callbacks"] = callbacks
        return {}, args, kwargs

    @property
    def is_running_in_cloud(self) -> bool:
        # The cloud runtime injects this variable into the environment.
        return "LIGHTNING_APP_STATE_URL" in os.environ
class LightningTrainerScript(LightningFlow):
    def __init__(
        self,
        script_path: str,
        script_args: Optional[Union[list, str]] = None,
        num_nodes: int = 1,
        # NOTE(review): mutable default is evaluated once at import time, so all
        # instances omitting ``cloud_compute`` share this object — kept as-is
        # for signature compatibility.
        cloud_compute: CloudCompute = CloudCompute("default"),
        sanity_serving: bool = False,
        script_runner: Type[TracerPythonScript] = PyTorchLightningScriptRunner,
        **script_runner_kwargs,
    ):
        """This component enables performing distributed multi-node multi-device training.

        Example::

            from lightning_app import LightningApp
            from lightning_app.components.training import LightningTrainerScript
            from lightning_app.utilities.packaging.cloud_compute import CloudCompute

            app = LightningApp(
                LightningTrainerScript(
                    "train.py",
                    num_nodes=2,
                    cloud_compute=CloudCompute("gpu"),
                ),
            )

        Arguments:
            script_path: Path to the script to be executed.
            script_args: The arguments to be pass to the script.
            num_nodes: Number of nodes.
            cloud_compute: The cloud compute object used in the cloud.
            sanity_serving: Whether to validate that the model correctly implements
                the ServableModule API
            script_runner: Work class instantiated once per node to run the script.
        """
        super().__init__()
        self.script_path = script_path
        self.script_args = script_args
        self.num_nodes = num_nodes
        self.sanity_serving = sanity_serving
        self._script_runner = script_runner
        self._script_runner_kwargs = script_runner_kwargs

        # One script-runner work per node, each knowing its own rank.
        self.ws = structures.List()
        for node_rank in range(self.num_nodes):
            self.ws.append(
                self._script_runner(
                    script_path=self.script_path,
                    script_args=self.script_args,
                    cloud_compute=cloud_compute,
                    node_rank=node_rank,
                    sanity_serving=self.sanity_serving,
                    num_nodes=self.num_nodes,
                    **self._script_runner_kwargs,
                )
            )

    def run(self, **run_kwargs):
        """Drive the per-node works: start training once every node has an
        internal IP, and stop all works when every node reports finished."""
        # Hoisted out of the per-work loop: the readiness check and the URL
        # list are identical for every work (the original recomputed them per
        # iteration, which was accidentally O(n^2)).
        if all(w.internal_ip for w in self.ws):
            internal_urls = [(w.internal_ip, w.port) for w in self.ws]
            for work in self.ws:
                work.run(internal_urls=internal_urls, **run_kwargs)
            if all(w.has_finished for w in self.ws):
                for w in self.ws:
                    w.stop()
        else:
            # Not all IPs assigned yet: run without URLs so each node just
            # announces itself.
            for work in self.ws:
                work.run()

    @property
    def best_model_score(self) -> Optional[float]:
        # Rank 0 aggregates the checkpoint score.
        return self.ws[0].best_model_score

    @property
    def best_model_paths(self) -> List[Optional[Path]]:
        # Fixed typo: the works expose ``best_model_path``, not ``best_mode_path``
        # (the original raised AttributeError on access).
        return [self.ws[node_idx].best_model_path for node_idx in range(len(self.ws))]