-
Notifications
You must be signed in to change notification settings - Fork 618
/
printer.py
300 lines (243 loc) · 9.09 KB
/
printer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
# Note: this is a helper printer class; this file might go away once we switch to rich console printing.
from abc import abstractmethod
import itertools
import platform
import sys
from typing import Callable, List, Optional, Tuple, Union
import click
import wandb
from . import ipython, sparkline
# Numeric severity levels. The values match the standard ``logging`` module,
# so these can be compared against ``logging`` levels directly.
CRITICAL = 50
ERROR = 40
WARNING = 30
INFO = 20
DEBUG = 10
NOTSET = 0

# Aliases, again mirroring ``logging``.
FATAL = CRITICAL
WARN = WARNING

# Every accepted (upper-case) level name, aliases included.
# Keep this order: _Printer._sanitize_level lists these keys in its error message.
_name_to_level = {
    "CRITICAL": CRITICAL,
    "FATAL": FATAL,
    "ERROR": ERROR,
    "WARN": WARNING,
    "WARNING": WARNING,
    "INFO": INFO,
    "DEBUG": DEBUG,
    "NOTSET": NOTSET,
}

# Reverse lookup: exactly one canonical name per numeric level (aliases dropped).
_level_to_name = {
    level: label
    for label, level in _name_to_level.items()
    if label not in ("FATAL", "WARN")
}
class _Printer:
    """Shared interface for the terminal (``PrinterTerm``) and notebook (``PrinterJupyter``) printers.

    Subclasses implement ``_display`` plus the formatting helpers (``code``,
    ``name``, ``link``, ...), each of which returns a string styled for the
    subclass's output medium.
    """

    def sparklines(self, series: List[Union[int, float]]) -> Optional[str]:
        """Render ``series`` as a sparkline string, or return None when stdout is not unicode-safe."""
        # Only print sparklines if the terminal is utf-8
        if wandb.util.is_unicode_safe(sys.stdout):
            return sparkline.sparkify(series)
        return None

    def abort(
        self,
    ) -> str:
        """Return how the abort keystroke is spelled on this platform ("Ctrl-C" on Windows)."""
        return "Control-C" if platform.system() != "Windows" else "Ctrl-C"

    def display(
        self,
        text: Union[str, List[str], Tuple[str]],
        *,
        level: Optional[Union[str, int]] = None,
        off: Optional[bool] = None,
        default_text: Optional[Union[str, List[str], Tuple[str]]] = None,
    ) -> None:
        """Output ``text`` at ``level`` unless ``off`` is truthy.

        Args:
            text: A message, or a list/tuple of lines.
            level: Level name (case-insensitive) or number; None means INFO.
            off: When truthy, nothing is output.
            default_text: Passed through to ``_display``.
        """
        if off:
            return
        self._display(text, level=level, default_text=default_text)

    @abstractmethod
    def _display(
        self,
        text: Union[str, List[str], Tuple[str]],
        *,
        level: Optional[Union[str, int]] = None,
        default_text: Optional[Union[str, List[str], Tuple[str]]] = None,
    ) -> None:
        """Medium-specific output; implemented by subclasses."""
        raise NotImplementedError

    @staticmethod
    def _sanitize_level(name_or_level: Optional[Union[str, int]]) -> int:
        """Normalize a level name or number to an integer level.

        Strings are looked up case-insensitively in ``_name_to_level``, ints are
        returned unchanged, and None maps to INFO.

        Raises:
            ValueError: if the name is unknown, or the value is not str/int/None.
        """
        if isinstance(name_or_level, str):
            try:
                return _name_to_level[name_or_level.upper()]
            except KeyError:
                # ``from None``: the internal KeyError adds nothing for the caller.
                raise ValueError(
                    f"Unknown level name: {name_or_level}, "
                    f"supported levels: {list(_name_to_level)}"
                ) from None
        if isinstance(name_or_level, int):
            return name_or_level
        if name_or_level is None:
            return INFO
        raise ValueError(f"Unknown status level {name_or_level}")

    @abstractmethod
    def code(self, text: str) -> str:
        """Return ``text`` styled as code."""
        raise NotImplementedError

    @abstractmethod
    def name(self, text: str) -> str:
        """Return ``text`` styled as a highlighted name."""
        raise NotImplementedError

    @abstractmethod
    def link(self, link: str, text: Optional[str] = None) -> str:
        """Return ``link`` styled as a hyperlink; ``text`` is optional display text."""
        raise NotImplementedError

    @abstractmethod
    def emoji(self, name: str) -> str:
        """Return the emoji called ``name``, or an empty string if unavailable."""
        raise NotImplementedError

    @abstractmethod
    def status(self, text: str, failure: Optional[bool] = None) -> str:
        """Return ``text`` styled as a success, or as a failure when ``failure`` is truthy."""
        raise NotImplementedError

    @abstractmethod
    def files(self, text: str) -> str:
        """Return ``text`` styled for file names."""
        raise NotImplementedError

    @abstractmethod
    def grid(self, rows: List[List[str]], title: Optional[str] = None) -> str:
        """Return ``rows`` formatted as a table, optionally preceded by ``title``."""
        raise NotImplementedError

    @abstractmethod
    def panel(self, columns: List[str]) -> str:
        """Return ``columns`` combined into a single panel."""
        raise NotImplementedError
class PrinterTerm(_Printer):
    """Printer for plain terminals: styles text with ``click`` and writes it via ``wandb.term*``."""

    def __init__(self) -> None:
        super().__init__()
        # Terminal output is plain text, not HTML.
        self._html = False
        # Spinner frames, advanced by one on each progress_update() call.
        self._progress = itertools.cycle(["-", "\\", "|", "/"])

    def _display(
        self,
        text: Union[str, List[str], Tuple[str]],
        *,
        level: Optional[Union[str, int]] = None,
        default_text: Optional[Union[str, List[str], Tuple[str]]] = None,
    ) -> None:
        """Join list/tuple input with newlines and write it with the level's term function.

        When ``default_text`` is given (joined the same way), it replaces an
        empty ``text``.
        """
        text = "\n".join(text) if isinstance(text, (list, tuple)) else text
        if default_text is not None:
            default_text = (
                "\n".join(default_text)
                if isinstance(default_text, (list, tuple))
                else default_text
            )
            text = text or default_text
        self._display_fn_mapping(level)(text)

    @staticmethod
    def _display_fn_mapping(level: Optional[Union[str, int]]) -> Callable[[str], None]:
        """Choose the wandb terminal function for ``level``.

        ERROR and above (CRITICAL included) use ``termerror``, WARNING up to
        ERROR uses ``termwarn``, and anything lower uses ``termlog``.

        Raises:
            ValueError: if ``level`` is not a recognised level.
        """
        level = _Printer._sanitize_level(level)
        if level >= ERROR:
            return wandb.termerror
        if level >= WARNING:
            return wandb.termwarn
        return wandb.termlog

    def progress_update(self, text: str, percentage: Optional[float] = None) -> None:
        """Print ``text`` after the next spinner frame, without a trailing newline.

        ``percentage`` is accepted but not shown in the terminal.
        """
        wandb.termlog(f"{next(self._progress)} {text}", newline=False)

    def progress_close(self) -> None:
        """Finish progress output."""
        # NOTE(review): presumably the 79 spaces blank out the last spinner line; confirm.
        wandb.termlog(" " * 79)

    def code(self, text: str) -> str:
        """Return ``text`` in bold."""
        ret: str = click.style(text, bold=True)
        return ret

    def name(self, text: str) -> str:
        """Return ``text`` in yellow."""
        ret: str = click.style(text, fg="yellow")
        return ret

    def link(self, link: str, text: Optional[str] = None) -> str:
        """Return ``link`` in underlined blue; ``text`` is ignored in the terminal."""
        ret: str = click.style(link, fg="blue", underline=True)
        return ret

    def emoji(self, name: str) -> str:
        """Return the emoji for ``name``; empty on Windows, non-unicode stdout, or unknown names."""
        emojis = {}
        if platform.system() != "Windows" and wandb.util.is_unicode_safe(sys.stdout):
            emojis = {"star": "⭐️", "broom": "🧹", "rocket": "🚀"}
        return emojis.get(name, "")

    def status(self, text: str, failure: Optional[bool] = None) -> str:
        """Return ``text`` in red when ``failure`` is truthy, otherwise in green."""
        color = "red" if failure else "green"
        ret: str = click.style(text, fg=color)
        return ret

    def files(self, text: str) -> str:
        """Return ``text`` in bold magenta."""
        ret: str = click.style(text, fg="magenta", bold=True)
        return ret

    def grid(self, rows: List[List[str]], title: Optional[str] = None) -> str:
        """Format ``rows`` as text: first cell right-aligned, remaining cells space-separated.

        Args:
            rows: Table rows; every row is formatted with the column count of ``rows[0]``.
            title: Optional line printed above the grid.

        Returns:
            The (titled) grid followed by a newline. Empty ``rows`` gives an
            empty grid instead of raising.
        """
        if rows:
            key_width = max(len(row[0]) for row in rows)
            # Put a space between *every* pair of columns. The previous
            # "{}" * (n - 1) pattern ran columns 2..n together.
            row_format = " ".join(["{:>{max_len}}"] + ["{}"] * (len(rows[0]) - 1))
            grid = "\n".join(row_format.format(*row, max_len=key_width) for row in rows)
        else:
            grid = ""
        if title:
            return f"{title}\n{grid}\n"
        return f"{grid}\n"

    def panel(self, columns: List[str]) -> str:
        """Return ``columns`` stacked on separate lines, preceded by a newline."""
        return "\n" + "\n".join(columns)
class PrinterJupyter(_Printer):
    """Printer for Jupyter notebooks: builds HTML strings and renders them via ``ipython.display_html``."""

    def __init__(self) -> None:
        super().__init__()
        # Output is rendered as HTML.
        self._html = True
        # NOTE(review): may be falsy (the progress methods guard on it); confirm in ipython.
        self._progress = ipython.jupyter_progress_bar()

    def _display(
        self,
        text: Union[str, List[str], Tuple[str]],
        *,
        level: Optional[Union[str, int]] = None,
        default_text: Optional[Union[str, List[str], Tuple[str]]] = None,
    ) -> None:
        """Join list/tuple input with ``<br/>`` and render it as HTML.

        When ``default_text`` is given (joined the same way), it replaces an
        empty ``text``.
        """
        text = "<br/>".join(text) if isinstance(text, (list, tuple)) else text
        if default_text is not None:
            default_text = (
                "<br/>".join(default_text)
                if isinstance(default_text, (list, tuple))
                else default_text
            )
            text = text or default_text
        self._display_fn_mapping(level)(text)

    @staticmethod
    def _display_fn_mapping(level: Optional[Union[str, int]]) -> Callable[[str], None]:
        """Return the render function for ``level``; every level uses ``ipython.display_html``.

        Raises:
            ValueError: if ``level`` is not a recognised level.
        """
        # Validate the level even though it does not change the output function.
        _Printer._sanitize_level(level)
        return ipython.display_html

    def code(self, text: str) -> str:
        """Return ``text`` wrapped in a ``<code>`` element."""
        # Closing tag was previously "<code>", which left the element unclosed.
        return f"<code>{text}</code>"

    def name(self, text: str) -> str:
        """Return ``text`` in bold yellow."""
        return f'<strong style="color:#cdcd00">{text}</strong>'

    def link(self, link: str, text: Optional[str] = None) -> str:
        """Return an anchor to ``link`` that opens in a new tab, labelled ``text`` or the URL."""
        return f'<a href="{link}" target="_blank">{text or link}</a>'

    def emoji(self, name: str) -> str:
        """Return an empty string; emoji are not rendered in notebook output."""
        return ""

    def status(self, text: str, failure: Optional[bool] = None) -> str:
        """Return ``text`` in bold red when ``failure`` is truthy, otherwise bold green."""
        color = "red" if failure else "green"
        return f'<strong style="color:{color}">{text}</strong>'

    def files(self, text: str) -> str:
        """Return ``text`` wrapped in a ``<code>`` element."""
        return f"<code>{text}</code>"

    # NOTE(review): signature differs from PrinterTerm.progress_update
    # (``percentage: Optional[float] = None``); callers must pass a value positionally.
    def progress_update(self, text: str, percent_done: float) -> None:
        """Advance the notebook progress bar to ``percent_done`` with label ``text``."""
        if self._progress:
            self._progress.update(percent_done, text)

    def progress_close(self) -> None:
        """Close the notebook progress bar, if one exists."""
        if self._progress:
            self._progress.close()

    def grid(self, rows: List[List[str]], title: Optional[str] = None) -> str:
        """Render ``rows`` as an HTML ``<table class="wandb">``, optionally preceded by an ``<h3>`` title.

        Every row uses the column count of ``rows[0]``. Empty ``rows`` yields an
        empty table instead of raising.
        """
        n_cols = len(rows[0]) if rows else 0
        row_format = "".join(["<tr>", "<td>{}</td>" * n_cols, "</tr>"])
        grid = "".join(row_format.format(*row) for row in rows)
        grid = f'<table class="wandb">{grid}</table>'
        if title:
            return f"<h3>{title}</h3><br/>{grid}<br/>"
        return f"{grid}<br/>"

    def panel(self, columns: List[str]) -> str:
        """Lay out ``columns`` side by side in a ``wandb-row`` div, prefixed by the table styles."""
        row = "".join([f'<div class="wandb-col">{col}</div>' for col in columns])
        return f'{ipython.TABLE_STYLES}<div class="wandb-row">{row}</div>'
def get_printer(_jupyter: Optional[bool] = None) -> Union[PrinterTerm, PrinterJupyter]:
    """Pick the printer implementation for the current environment.

    Args:
        _jupyter: Request notebook (HTML) output. It is honoured only when this
            is truthy *and* ``ipython.in_jupyter()`` is truthy; the latter is
            not consulted at all when ``_jupyter`` is falsy.

    Returns:
        A ``PrinterJupyter`` in that case, otherwise a ``PrinterTerm``.
    """
    wants_html = bool(_jupyter) and ipython.in_jupyter()
    return PrinterJupyter() if wants_html else PrinterTerm()