-
Notifications
You must be signed in to change notification settings - Fork 618
/
wandb_manager.py
187 lines (149 loc) · 5.65 KB
/
wandb_manager.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
"""Manage wandb processes.
Create a manager channel to the wandb service (grpc or tcp transport).
"""
import atexit
import os
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional
import wandb
from wandb import env, trigger
from wandb.sdk.lib.exit_hooks import ExitHooks
from wandb.sdk.lib.import_hooks import unregister_all_post_import_hooks
from wandb.sdk.lib.proto_util import settings_dict_from_pbmap
if TYPE_CHECKING:
from wandb.sdk.service import service
from wandb.sdk.service.service_base import ServiceInterface
from wandb.sdk.wandb_settings import Settings
class _ManagerToken:
    """Parsed form of the service token string.

    A token has the shape ``<version>-<pid>-<transport>-<host>-<port>``.
    It can be built from explicit parameters (``from_params``) or read back
    from the environment variable named by ``env.SERVICE``
    (``from_environment``), and is validated on construction.
    """

    _version = "2"
    _supported_transports = {"grpc", "tcp"}
    _token_str: str
    _pid: int
    _transport: str
    _host: str
    _port: int

    def __init__(self, token: str) -> None:
        """Store ``token`` and parse it into its components.

        Raises:
            AssertionError: if the token is empty, has too few parts, or
                carries an unsupported version or transport.
            ValueError: if the pid or port component is not an integer.
        """
        self._token_str = token
        self._parse()

    @classmethod
    def from_environment(cls) -> Optional["_ManagerToken"]:
        """Return the token stored in the environment, or None if unset/empty."""
        token = os.environ.get(env.SERVICE)
        if not token:
            return None
        return cls(token=token)

    @classmethod
    def from_params(cls, transport: str, host: str, port: int) -> "_ManagerToken":
        """Build a token for the current process from its components."""
        version = cls._version
        pid = os.getpid()
        token = "-".join([version, str(pid), transport, host, str(port)])
        return cls(token=token)

    def set_environment(self) -> None:
        """Publish the raw token string in the environment."""
        os.environ[env.SERVICE] = self._token_str

    def _parse(self) -> None:
        """Split the raw token into pid, transport, host and port.

        Explicit ``raise AssertionError`` is used instead of ``assert`` so the
        checks are not stripped under ``python -O`` while the exception type
        stays the same as before.
        """
        token = self._token_str
        if not token:
            raise AssertionError("token must be a non-empty string")

        # Version, pid and transport never contain "-", and neither does the
        # port, so take three fields from the left and one from the right.
        # Whatever remains in the middle is the host, which may itself
        # contain "-" (e.g. "my-host.example").
        head = token.split("-", 3)
        remainder = head[3] if len(head) == 4 else ""
        host, sep, port_str = remainder.rpartition("-")
        if not sep:
            raise AssertionError(f"token must have 5 parts: {token.split('-')}")

        version, pid_str, transport = head[:3]
        if version != self._version:
            raise AssertionError(f"unsupported token version: {version!r}")
        if transport not in self._supported_transports:
            raise AssertionError(f"unsupported token transport: {transport!r}")

        self._pid = int(pid_str)
        self._transport = transport
        self._host = host
        self._port = int(port_str)

    def reset_environment(self) -> None:
        """Remove the token from the environment if it is present."""
        os.environ.pop(env.SERVICE, None)

    @property
    def token(self) -> str:
        """The raw token string."""
        return self._token_str

    @property
    def pid(self) -> int:
        """Process id recorded in the token."""
        return self._pid

    @property
    def transport(self) -> str:
        """Transport name recorded in the token ("grpc" or "tcp")."""
        return self._transport

    @property
    def host(self) -> str:
        """Host recorded in the token."""
        return self._host

    @property
    def port(self) -> int:
        """Port recorded in the token."""
        return self._port
class _Manager:
    """Start or attach to the wandb service and relay lifecycle messages to it.

    If no service token is present in the environment, the constructor starts
    the service, publishes a token and registers an atexit teardown.
    Otherwise it reuses the token found in the environment. In both cases it
    then connects the service interface to the token's port.
    """

    _token: _ManagerToken
    _atexit_lambda: Optional[Callable[[], None]]
    _hooks: Optional[ExitHooks]
    _settings: "Settings"
    _service: "service._Service"

    def __init__(self, settings: "Settings", _use_grpc: bool = False) -> None:
        """Create the service handle and connect to it.

        Args:
            settings: settings object kept for use during teardown.
            _use_grpc: select the "grpc" transport instead of "tcp".
        """
        # TODO: warn if user doesn't have grpc installed
        from wandb.sdk.service import service

        self._settings = settings
        self._atexit_lambda = None
        self._hooks = None
        self._service = service._Service(_use_grpc=_use_grpc)

        active_token = _ManagerToken.from_environment()
        if not active_token:
            # No token published yet: start the service and publish one.
            self._service.start()
            if _use_grpc:
                transport, port = "grpc", self._service.grpc_port
            else:
                transport, port = "tcp", self._service.sock_port
            assert port
            active_token = _ManagerToken.from_params(
                transport=transport,
                host="localhost",
                port=port,
            )
            active_token.set_environment()
            self._atexit_setup()
        self._token = active_token

        connect_port = self._token.port
        interface = self._get_service_interface()
        interface._svc_connect(port=connect_port)

    def _atexit_setup(self) -> None:
        """Install exit hooks and register the teardown callback with atexit."""

        def _run_teardown() -> None:
            self._atexit_teardown()

        self._atexit_lambda = _run_teardown
        self._hooks = ExitHooks()
        self._hooks.hook()
        atexit.register(self._atexit_lambda)

    def _atexit_teardown(self) -> None:
        """atexit callback: fire "on_finished" and tear down with the exit code."""
        trigger.call("on_finished")
        if self._hooks:
            code = self._hooks.exit_code
        else:
            code = 0
        self._teardown(code)

    def _teardown(self, exit_code: int) -> None:
        """Inform the service of teardown, join it, and clear the token.

        Any exception raised while informing or joining is logged rather than
        propagated. The token is removed from the environment in all cases
        where control returns from this method.
        """
        unregister_all_post_import_hooks()

        registered = self._atexit_lambda
        if registered:
            atexit.unregister(registered)
            self._atexit_lambda = None

        try:
            self._inform_teardown(exit_code)
            status = self._service.join()
            if status:
                if not self._settings._jupyter:
                    os._exit(status)
        except Exception as err:
            message = f"While tearing down the service manager. The following error has occured: {err}"
            wandb.termlog(message, repeat=False)
        finally:
            self._token.reset_environment()

    def _get_service(self) -> "service._Service":
        """Return the underlying service object."""
        return self._service

    def _get_service_interface(self) -> "ServiceInterface":
        """Return the service's interface, asserting that both exist."""
        assert self._service
        interface = self._service.service_interface
        assert interface
        return interface

    def _inform_init(self, settings: "Settings", run_id: str) -> None:
        """Forward an init notification for ``run_id`` to the service."""
        self._get_service_interface()._svc_inform_init(
            settings=settings, run_id=run_id
        )

    def _inform_start(self, settings: "Settings", run_id: str) -> None:
        """Forward a start notification for ``run_id`` to the service."""
        self._get_service_interface()._svc_inform_start(
            settings=settings, run_id=run_id
        )

    def _inform_attach(self, attach_id: str) -> Dict[str, Any]:
        """Forward an attach request and return the reply's settings as a dict."""
        reply = self._get_service_interface()._svc_inform_attach(attach_id=attach_id)
        return settings_dict_from_pbmap(reply._settings_map)

    def _inform_finish(self, run_id: str = None) -> None:
        """Forward a finish notification for ``run_id`` to the service."""
        self._get_service_interface()._svc_inform_finish(run_id=run_id)

    def _inform_teardown(self, exit_code: int) -> None:
        """Forward a teardown notification carrying ``exit_code`` to the service."""
        self._get_service_interface()._svc_inform_teardown(exit_code)