-
Notifications
You must be signed in to change notification settings - Fork 3.3k
/
tracer.py
186 lines (144 loc) · 6.64 KB
/
tracer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
import os
import signal
import sys
from copy import deepcopy
from typing import Any, Dict, List, Optional, Union
from typing_extensions import TypedDict
from lightning_app.core.work import LightningWork
from lightning_app.storage.drive import Drive
from lightning_app.storage.payload import Payload
from lightning_app.utilities.app_helpers import _collect_child_process_pids, Logger
from lightning_app.utilities.packaging.tarfile import clean_tarfile, extract_tarfile
from lightning_app.utilities.tracer import Tracer
# Module-level logger, created with this module's dotted name.
logger = Logger(__name__)
class Code(TypedDict):
    """Typed mapping that tells ``TracerPythonScript`` where its packaged source lives.

    ``TracerPythonScript.run`` fetches the archive called ``name`` from ``drive`` and
    extracts it (opened with mode ``"r:gz"``) before executing the script.
    """

    # Drive that is listed and read from to obtain the archive.
    drive: Drive
    # File name of the archive inside the drive (and of the local copy once fetched).
    name: str
class TracerPythonScript(LightningWork):
    """LightningWork that executes a python script through a :class:`~lightning_app.utilities.tracer.Tracer`."""

    # NOTE(review): presumably read by LightningWork to pick the process start method - confirm.
    _start_method = "spawn"

    def on_before_run(self):
        """Called before the python script is executed."""

    def on_after_run(self, res: Any):
        """Called after the python script is executed.

        Arguments:
            res: The result returned by the tracer. Every name listed in ``outputs`` is looked up in it
                (``res[name]``) and stored on the work wrapped in a ``Payload``.
        """
        for name in self.outputs:
            setattr(self, name, Payload(res[name]))

    def configure_tracer(self) -> Tracer:
        """Override this hook to customize your tracer when running PythonScript."""
        return Tracer()

    def __init__(
        self,
        script_path: str,
        script_args: Optional[Union[list, str]] = None,
        outputs: Optional[List[str]] = None,
        env: Optional[Dict] = None,
        code: Optional[Code] = None,
        **kwargs,
    ):
        """The TracerPythonScript class makes it easy to run a python script.

        When subclassing this class, you can configure your own :class:`~lightning_app.utilities.tracer.Tracer`
        by :meth:`~lightning_app.components.python.tracer.TracerPythonScript.configure_tracer` method.

        The tracer is quite a magical class. It enables you to inject code into a script execution without changing it.

        Arguments:
            script_path: Path of the python script to run.
            script_args: The arguments to be passed to the script. A string is split on whitespace.
            outputs: Collection of object names to collect after the script execution.
            env: Environment variables to be passed to the script.
            code: Optional mapping with a ``drive`` and a ``name``. When provided, ``run`` fetches the archive
                called ``name`` from the drive and extracts it before executing the script.
            kwargs: LightningWork Keyword arguments.

        Note:
            The existence of ``script_path`` is not checked here: ``run`` raises ``FileNotFoundError``
            if the script can't be found at execution time.

        **How does it work?**

        It works by executing the python script with python built-in `runpy
        <https://docs.python.org/3/library/runpy.html>`_ run_path method.
        This method takes any python globals before executing the script,
        e.g., you can modify classes or function from the script.

        .. doctest::

            >>> from lightning_app.components.python import TracerPythonScript
            >>> f = open("a.py", "w")
            >>> f.write("print('Hello World !')")
            22
            >>> f.close()
            >>> python_script = TracerPythonScript("a.py")
            >>> python_script.run()
            Hello World !
            >>> os.remove("a.py")

        In the example below, we subclass the :class:`~lightning_app.components.python.TracerPythonScript`
        component and override its configure_tracer method.

        Using the Tracer, we are patching the ``__init__`` method of the PyTorch Lightning Trainer.
        Once the script starts running and if a Trainer is instantiated, the provided ``pre_fn`` is
        called and we inject a Lightning callback.

        This callback has a reference to the work and on every batch end, we are capturing the
        trainer ``global_step`` and ``best_model_path``.

        Even more interesting, this component works for ANY PyTorch Lightning script and
        its state can be used in real time in a UI.

        .. literalinclude:: ../../../examples/app_components/python/component_tracer.py
            :language: python

        Once implemented, this component can easily be integrated within a larger app
        to execute a specific python script.

        .. literalinclude:: ../../../examples/app_components/python/app.py
            :language: python
        """
        super().__init__(**kwargs)
        self.script_path = str(script_path)
        if isinstance(script_args, str):
            # Whitespace split: ``split(" ")`` would turn "" or repeated spaces into empty arguments.
            script_args = script_args.split()
        # Copy so that later mutations of the caller's list don't leak into the work.
        self.script_args = list(script_args) if script_args else []
        # Pristine arguments; ``run`` rebuilds ``script_args`` from these whenever ``params`` are given.
        self.original_args = deepcopy(self.script_args)
        self.env = env
        self.outputs = outputs or []
        # Declare every output up-front; ``on_after_run`` fills them in.
        for name in self.outputs:
            setattr(self, name, None)
        self.params = None
        self.drive = code.get("drive") if code else None
        self.code_name = code.get("name") if code else None
        self.restart_count = 0

    def run(
        self,
        params: Optional[Dict[str, Any]] = None,
        restart_count: Optional[int] = None,
        code_dir: Optional[str] = ".",
        **kwargs,
    ):
        """Execute the script, optionally fetching and extracting its source from the drive first.

        Arguments:
            params: A dictionary of arguments to be added to script_args (each one as ``key=value``).
            restart_count: Passes an incrementing counter to enable the re-execution of LightningWorks.
            code_dir: A path string determining where the source is extracted, default is current directory.
                The script is executed with this directory as working directory.
            kwargs: Extra names made available in the globals handed to the tracer. ``Payload`` values are
                replaced by their ``value``.

        Returns:
            The value returned by :meth:`on_after_run`.

        Raises:
            FileNotFoundError: If ``script_path`` doesn't exist once ``code_dir`` is the working directory.
        """
        if restart_count:
            self.restart_count = restart_count

        if params:
            self.params = params
            self.script_args = self.original_args + [self._to_script_args(k, v) for k, v in params.items()]

        if self.drive:
            assert self.code_name, "`code` requires a `name` when a `drive` is provided."
            # Drop any stale local copy before fetching a fresh archive.
            if os.path.exists(self.code_name):
                clean_tarfile(self.code_name, "r:gz")

            if self.code_name in self.drive.list():
                self.drive.get(self.code_name)
                extract_tarfile(self.code_name, code_dir, "r:gz")

        prev_cwd = os.getcwd()
        os.chdir(code_dir)
        try:
            if not os.path.exists(self.script_path):
                raise FileNotFoundError(f"The provided `script_path` `{self.script_path}` wasn't found.")

            kwargs = {k: v.value if isinstance(v, Payload) else v for k, v in kwargs.items()}

            # Build a fresh namespace instead of updating ``globals()`` in place: mutating this module's
            # globals would leak kwargs across runs and could shadow names such as ``os`` or ``Payload``.
            init_globals = {**globals(), **kwargs}

            self.on_before_run()
            # Snapshot taken after the hook, so environment changes made by ``on_before_run`` are kept.
            env_snapshot = os.environ.copy()
            if self.env:
                os.environ.update(self.env)
            try:
                res = self._run_tracer(init_globals)
            finally:
                self._restore_environ(env_snapshot)
        finally:
            # Always go back to the original directory, even if the script (or the path check) failed.
            os.chdir(prev_cwd)
        return self.on_after_run(res)

    @staticmethod
    def _restore_environ(snapshot: Dict[str, str]) -> None:
        """Bring ``os.environ`` back to ``snapshot`` in place.

        ``os.environ`` must be mutated rather than rebound: rebinding it to a plain dict would leave the
        variables set for the run in the real process environment and stop later changes from reaching it.
        """
        for key in list(os.environ):
            if key not in snapshot:
                del os.environ[key]
        for key, value in snapshot.items():
            if os.environ.get(key) != value:
                os.environ[key] = value

    def _run_tracer(self, init_globals):
        """Run ``script_path`` with ``script_args`` through the tracer from :meth:`configure_tracer`.

        Arguments:
            init_globals: Mapping forwarded to the tracer as ``init_globals``.

        Returns:
            Whatever ``Tracer.trace`` returns.
        """
        # Note: ``sys.argv`` is overwritten for the whole process and is not restored afterwards.
        sys.argv = [self.script_path]
        tracer = self.configure_tracer()
        return tracer.trace(self.script_path, *self.script_args, init_globals=init_globals)

    def on_exit(self):
        """Send ``SIGTERM`` to every child process collected for the current process."""
        for child_pid in _collect_child_process_pids(os.getpid()):
            try:
                os.kill(child_pid, signal.SIGTERM)
            except ProcessLookupError:
                # The child already exited; keep terminating the remaining ones.
                continue

    @staticmethod
    def _to_script_args(k: str, v: str) -> str:
        """Format one ``params`` entry as a ``key=value`` script argument."""
        return f"{k}={v}"
# Public API of this module: only the component class is exported.
__all__ = ["TracerPythonScript"]