-
Notifications
You must be signed in to change notification settings - Fork 618
/
wandb_settings.py
1568 lines (1408 loc) · 58.6 KB
/
wandb_settings.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
import configparser
import enum
import getpass
import json
import multiprocessing
import os
import platform
import re
import socket
import sys
import tempfile
import time
from datetime import datetime
from distutils.util import strtobool
from functools import reduce
from typing import (
Any,
Callable,
Dict,
FrozenSet,
ItemsView,
Iterable,
Mapping,
Optional,
Sequence,
Set,
Tuple,
Union,
no_type_check,
)
from urllib.parse import quote, urlencode, urlparse, urlsplit
import wandb
import wandb.env
from wandb import util
from wandb.apis.internal import Api
from wandb.errors import UsageError
from wandb.sdk.wandb_config import Config
from wandb.sdk.wandb_setup import _EarlyLogger
from .lib import apikey
from .lib.git import GitRepo
from .lib.ipython import _get_python_type
from .lib.runid import generate_id
if sys.version_info < (3, 7):
    # Minimal stand-ins for the typing introspection helpers, which do not
    # exist before Python 3.7: each one reads the dunder attribute the real
    # helper is built on and falls back to an "empty" result.

    def get_args(obj: Any) -> Optional[Any]:
        """Return ``obj.__args__`` if it exists, otherwise ``None``."""
        return getattr(obj, "__args__", None)

    def get_origin(obj: Any) -> Optional[Any]:
        """Return ``obj.__origin__`` if it exists, otherwise ``None``."""
        return getattr(obj, "__origin__", None)

    def get_type_hints(obj: Any) -> Dict[str, Any]:
        """Return a copy of ``obj.__annotations__`` (an empty dict if absent)."""
        return dict(getattr(obj, "__annotations__", {}))

elif sys.version_info < (3, 8):
    # Python 3.7: the helpers come from the typing_extensions backport.
    from typing_extensions import get_args, get_origin, get_type_hints
else:
    from typing import get_args, get_origin, get_type_hints
def _get_wandb_dir(root_dir: str) -> str:
"""
Get the full path to the wandb directory.
The setting exposed to users as `dir=` or `WANDB_DIR` is the `root_dir`.
We add the `__stage_dir__` to it to get the full `wandb_dir`
"""
# We use the hidden version if it already exists, otherwise non-hidden.
if os.path.exists(os.path.join(root_dir, ".wandb")):
__stage_dir__ = ".wandb" + os.sep
else:
__stage_dir__ = "wandb" + os.sep
path = os.path.join(root_dir, __stage_dir__)
if not os.access(root_dir or ".", os.W_OK):
wandb.termwarn(
f"Path {path} wasn't writable, using system temp directory.",
repeat=False,
)
path = os.path.join(tempfile.gettempdir(), __stage_dir__ or ("wandb" + os.sep))
return os.path.expanduser(path)
# todo: should either return bool or error out. fix once confident.
def _str_as_bool(val: Union[str, bool]) -> bool:
"""
Parse a string as a bool.
"""
if isinstance(val, bool):
return val
try:
ret_val = bool(strtobool(str(val)))
return ret_val
except (AttributeError, ValueError):
pass
# todo: remove this and only raise error once we are confident.
wandb.termwarn(
f"Could not parse value {val} as a bool. ",
repeat=False,
)
raise UsageError(f"Could not parse value {val} as a bool.")
def _redact_dict(
d: Dict[str, Any],
unsafe_keys: Union[Set[str], FrozenSet[str]] = frozenset({"api_key"}),
redact_str: str = "***REDACTED***",
) -> Dict[str, Any]:
"""Redact a dict of unsafe values specified by their key."""
if not d or unsafe_keys.isdisjoint(d):
return d
safe_dict = d.copy()
safe_dict.update({k: redact_str for k in unsafe_keys.intersection(d)})
return safe_dict
def _get_program() -> Optional[Any]:
    """
    Best-effort guess of the program being run.

    Resolution order:
    1. the environment variable named by `wandb.env.PROGRAM`, if set;
    2. `__main__.__file__` when `__main__` has no module spec;
    3. `"-m <module name>"` when `__main__` has a spec (likely `python -m ...`).

    Returns None if `__main__` cannot be imported or lacks the attribute needed.
    """
    from_env = os.getenv(wandb.env.PROGRAM)
    if from_env is not None:
        return from_env

    try:
        import __main__

        main_spec = __main__.__spec__
        if main_spec is not None:
            # likely run as `python -m ...`
            return f"-m {main_spec.name}"
        return __main__.__file__
    except (ImportError, AttributeError):
        return None
def _get_program_relpath_from_gitrepo(
    program: str, _logger: Optional[_EarlyLogger] = None
) -> Optional[str]:
    """
    Return the path of `program` relative to the git repository root.

    `program` is resolved against the current working directory. If no git
    root is found, the current working directory acts as the root.

    Args:
        program: path to the program as given to the interpreter.
        _logger: optional logger used to explain why None is returned.

    Returns:
        The relative path, or None if the program does not exist or lies
        above the root.
    """
    repo = GitRepo()
    root = repo.root or os.getcwd()

    full_path_to_program = os.path.join(
        root, os.path.relpath(os.getcwd(), root), program
    )
    if not os.path.exists(full_path_to_program):
        if _logger is not None:
            _logger.warning(f"Could not find program at {program}")
        return None

    relative_path = os.path.relpath(full_path_to_program, start=root)
    # relpath normalizes its result, so a program outside `root` shows up as a
    # leading parent-directory component. Compare path components (using the
    # platform separator) instead of searching for "../", which misses
    # Windows-style "..\" paths.
    if relative_path.split(os.sep)[0] == os.pardir:
        if _logger is not None:
            _logger.warning(f"Could not save program above cwd: {program}")
        return None
    return relative_path
@enum.unique
class Source(enum.IntEnum):
    """
    Where a setting's value came from.

    Members are compared numerically by `Property.update`. OVERRIDE (0)
    always wins. Otherwise a policy property keeps the lowest source and a
    regular property the highest.
    """

    OVERRIDE = 0
    BASE = 1  # todo: audit this
    ORG = 2
    ENTITY = 3
    PROJECT = 4
    USER = 5
    SYSTEM = 6
    WORKSPACE = 7
    ENV = 8
    SETUP = 9
    LOGIN = 10
    INIT = 11
    SETTINGS = 12
    ARGS = 13
    RUN = 14
@enum.unique
class SettingsConsole(enum.IntEnum):
    """
    Console-capture modes that the `console` setting resolves to.

    The values are kept explicit, so reordering members cannot silently
    change them.
    """

    OFF = 0
    WRAP = 1
    REDIRECT = 2
    WRAP_RAW = 3
    WRAP_EMU = 4
class Property:
    """
    A single attribute (individual setting) of the Settings object.

    - Encapsulates how values of the setting are preprocessed and validated
      throughout the lifetime of an instance.
    - Allows runtime modification of the value with hooks, e.g. when a setting
      depends on another setting. Hooks run every time `value` is read.
    - `update()` is the only supported way to change the stored value.
    - `is_policy` determines the source priority when updating the value.
      If True, the smallest `Source` value takes precedence; otherwise the
      largest does. `Source.OVERRIDE` always applies. Once a value comes from
      `Source.OVERRIDE`, only another `Source.OVERRIDE` update can replace it.
    - A frozen property rejects `update()` and any attribute assignment.
    """

    # todo: this is a temporary measure to bypass validation of the settings
    # whose validation was not previously enforced to make sure we don't break anything.
    # For settings listed here, a validator returning a falsy value raises
    # ValueError; for all others it only emits a warning.
    __strict_validate_settings = {
        "project",
        "start_method",
        "mode",
        "console",
        "problem",
        "anonymous",
        "strict",
        "silent",
        "show_info",
        "show_warnings",
        "show_errors",
        "base_url",
        "login_timeout",
    }

    def __init__(  # pylint: disable=unused-argument
        self,
        name: str,
        value: Optional[Any] = None,
        preprocessor: Union[Callable, Sequence[Callable], None] = None,
        # validators allow programming by contract
        validator: Union[Callable, Sequence[Callable], None] = None,
        # runtime converter (hook): properties can be e.g. tied to other properties
        hook: Union[Callable, Sequence[Callable], None] = None,
        # always apply hook even if value is None. can be used to replace @property's
        auto_hook: bool = False,
        is_policy: bool = False,
        frozen: bool = False,
        source: int = Source.BASE,
        **kwargs: Any,
    ):
        """
        Create a property and preprocess/validate its initial value.

        Args:
            name: setting name, used in warnings and error messages.
            value: initial value. None skips preprocessing and validation.
            preprocessor: callable or sequence of callables applied in order.
            validator: callable or sequence of callables. Each must return a
                truthy value for the value to be considered valid.
            hook: callable or sequence of callables applied in order to the
                stored value whenever `value` is read.
            auto_hook: apply the hook(s) even when the stored value is None.
            is_policy: if True, lower `Source` values take precedence on update.
            frozen: if True, the property cannot be modified after creation.
            source: the `Source` of the initial value.
            **kwargs: accepted and ignored.

        Raises:
            ValueError: if a validator rejects the value of a setting listed in
                the strict-validation set. Validators may also raise their own
                exceptions, which propagate unchanged.
        """
        self.name = name
        self._preprocessor = preprocessor
        self._validator = validator
        self._hook = hook
        self._auto_hook = auto_hook
        self._is_policy = is_policy
        self._source = source

        # todo: this is a temporary measure to collect stats on failed preprocessing and validation
        self.__failed_preprocessing: bool = False
        self.__failed_validation: bool = False

        # preprocess and validate value
        self._value = self._validate(self._preprocess(value))

        # Set last: once this attribute exists and is True, __setattr__
        # rejects every further assignment.
        self.__frozen = frozen

    @property
    def value(self) -> Any:
        """Apply the runtime modifier(s) (if any) and return the value."""
        _value = self._value
        if (_value is not None or self._auto_hook) and self._hook is not None:
            _hook = [self._hook] if callable(self._hook) else self._hook
            for h in _hook:
                _value = h(_value)
        return _value

    @property
    def is_policy(self) -> bool:
        return self._is_policy

    @property
    def source(self) -> int:
        return self._source

    def _preprocess(self, value: Any) -> Any:
        """
        Run the preprocessor(s) over a non-None value.

        If a preprocessor raises UsageError or ValueError, a warning is
        printed, the failure is recorded and the remaining preprocessors are
        skipped. The value produced so far is returned.
        """
        if value is not None and self._preprocessor is not None:
            _preprocessor = (
                [self._preprocessor]
                if callable(self._preprocessor)
                else self._preprocessor
            )
            for p in _preprocessor:
                try:
                    value = p(value)
                except (UsageError, ValueError):
                    wandb.termwarn(
                        f"Unable to preprocess value for property {self.name}: {value}. "
                        "This will raise an error in the future.",
                        repeat=False,
                    )
                    self.__failed_preprocessing = True
                    break
        return value

    def _validate(self, value: Any) -> Any:
        """
        Run the validator(s) over a non-None value and return it unchanged.

        A falsy validator result raises ValueError for strictly validated
        settings. For other settings it prints a warning, records the
        failure and stops further validation.
        """
        self.__failed_validation = False  # todo: this is a temporary measure
        if value is not None and self._validator is not None:
            _validator = (
                [self._validator] if callable(self._validator) else self._validator
            )
            for v in _validator:
                if not v(value):
                    # todo: this is a temporary measure to bypass validation of certain settings.
                    #  remove this once we are confident
                    if self.name in self.__strict_validate_settings:
                        raise ValueError(
                            f"Invalid value for property {self.name}: {value}"
                        )
                    else:
                        wandb.termwarn(
                            f"Invalid value for property {self.name}: {value}. "
                            "This will raise an error in the future.",
                            repeat=False,
                        )
                        self.__failed_validation = True
                        break
        return value

    def update(self, value: Any, source: int = Source.OVERRIDE) -> None:
        """
        Update the value of the property.

        Raises:
            TypeError: if the property is frozen.
        """
        if self.__frozen:
            raise TypeError("Property object is frozen")
        # - always update value if source == Source.OVERRIDE
        # - if not previously overridden:
        #   - update value if source is lower than or equal to current source and property is policy
        #   - update value if source is higher than or equal to current source and property is not policy
        if (
            (source == Source.OVERRIDE)
            or (
                self._is_policy
                and self._source != Source.OVERRIDE
                and source <= self._source
            )
            or (
                not self._is_policy
                and self._source != Source.OVERRIDE
                and source >= self._source
            )
        ):
            self._value = self._validate(self._preprocess(value))
            self._source = source

    def __setattr__(self, key: str, value: Any) -> None:
        # `_Property__frozen` is the mangled name of `self.__frozen`. It only
        # exists once __init__ has finished, so construction is never blocked.
        if "_Property__frozen" in self.__dict__ and self.__frozen:
            raise TypeError(f"Property object {self.name} is frozen")
        if key == "value":
            raise AttributeError("Use update() to update property value")
        self.__dict__[key] = value

    def __str__(self) -> str:
        # Evaluate `value` once: it runs every hook, and hooks may be costly
        # (e.g. building URLs) or depend on mutable state.
        value = self.value
        return f"'{value}'" if isinstance(value, str) else f"{value}"

    def __repr__(self) -> str:
        return (
            f"<Property {self.name}: value={self.value} "
            f"_value={self._value} source={self._source} is_policy={self._is_policy}>"
        )
class Settings:
    """
    Settings for the wandb client.

    Every setting is declared below as a class-level annotation. Per the
    `_default_props` docstring, these names are used together with that
    method's keys to create one `Property` object per setting, so the names
    must match.
    """

    # settings are declared as class attributes for static type checking purposes
    # and to help with IDE autocomplete.
    _args: Sequence[str]
    _cli_only_mode: bool  # Avoid running any code specific for runs
    _colab: bool  # computed: whether "google.colab" is in sys.modules
    _config_dict: Config
    _console: SettingsConsole  # computed from `console` by _convert_console()
    _cuda: str
    _disable_meta: bool
    _disable_stats: bool
    _disable_viewer: bool  # Prevent early viewer query
    _except_exit: bool
    _executable: str
    _internal_check_process: Union[int, float]
    _internal_queue_timeout: Union[int, float]
    _jupyter: bool  # computed: python type reported as something other than "python"
    _jupyter_name: str
    _jupyter_path: str
    _jupyter_root: str
    _kaggle: bool  # computed via util._is_likely_kaggle()
    _live_policy_rate_limit: int
    _live_policy_wait_time: int
    _log_level: int
    _noop: bool  # computed: mode == "disabled"
    _offline: bool  # computed: `disabled`, or mode is "dryrun"/"offline"
    _os: str
    _platform: str
    _python: str
    _require_service: str
    _runqueue_item_id: str
    _save_requirements: bool
    _service_transport: str
    _start_datetime: datetime
    _start_time: float
    _stats_pid: int  # (internal) base pid for system stats
    _stats_sample_rate_seconds: float
    _stats_samples_to_average: int
    _stats_join_assets: bool  # join metrics from different assets before sending to backend
    _tmp_code_dir: str
    _tracelog: str
    _unsaved_keys: Sequence[str]
    _windows: bool  # computed: platform.system() == "Windows"
    allow_val_change: bool
    anonymous: str
    api_key: str
    base_url: str  # The base url for the wandb api
    code_dir: str
    config_paths: Sequence[str]
    console: str
    deployment: str  # computed: "local" if is_local else "cloud"
    disable_code: bool
    disable_git: bool
    disable_hints: bool
    disabled: bool  # Alias for mode=dryrun, not supported yet
    docker: str
    email: str
    enable_job_creation: bool
    entity: str
    files_dir: str
    force: bool
    git_commit: str
    git_remote: str
    git_remote_url: str
    git_root: str
    heartbeat_seconds: int
    host: str
    ignore_globs: Tuple[str]
    init_timeout: int
    is_local: bool  # computed: base_url is set and differs from https://api.wandb.ai
    label_disable: bool
    launch: bool
    launch_config_path: str
    log_dir: str
    log_internal: str
    log_symlink_internal: str
    log_symlink_user: str
    log_user: str
    login_timeout: float
    magic: Union[str, bool, dict]
    mode: str
    notebook_name: str
    problem: str
    program: str
    program_relpath: str
    project: str
    project_url: str  # computed by _project_url()
    quiet: bool
    reinit: bool
    relogin: bool
    resume: Union[str, int, bool]
    resume_fname: str
    resumed: bool  # indication from the server about the state of the run (different from resume - user provided flag)
    root_dir: str
    run_group: str
    run_id: str
    run_job_type: str
    run_mode: str  # computed: "offline-run" if _offline else "run"
    run_name: str
    run_notes: str
    run_tags: Tuple[str]
    run_url: str  # computed by _run_url()
    sagemaker_disable: bool
    save_code: bool
    settings_system: str
    settings_workspace: str
    show_colors: bool
    show_emoji: bool
    show_errors: bool
    show_info: bool
    show_warnings: bool
    silent: bool
    start_method: str
    strict: bool
    summary_errors: int
    summary_timeout: int
    summary_warnings: int  # policy setting: the lowest Source wins on update
    sweep_id: str
    sweep_param_path: str
    sweep_url: str  # computed by _sweep_url()
    symlink: bool
    sync_dir: str
    sync_file: str
    sync_symlink_latest: str
    system_sample: int
    system_sample_seconds: int
    timespec: str  # computed: _start_datetime formatted as "%Y%m%d_%H%M%S"
    tmp_dir: str
    username: str
    wandb_dir: str  # computed: _get_wandb_dir(root_dir)
    table_raise_on_max_row_limit_exceeded: bool
def _default_props(self) -> Dict[str, Dict[str, Any]]:
    """
    Helper method that is used in `__init__` together with the class attributes
    to initialize instance attributes (individual settings) as Property objects.
    Note that key names must be the same as the class attribute names.

    Each value is a dict of keyword arguments for `Property`:
    - "value": the default value;
    - "preprocessor": callable(s) applied to incoming non-None values;
    - "validator": callable(s) the preprocessed value must satisfy;
    - "hook": callable(s) applied to the stored value whenever it is read;
    - "auto_hook": run the hook even when the stored value is None. This
      turns the setting into a computed one, usually derived from other
      settings through `self`;
    - "is_policy": lower `Source` values take precedence on update.

    NOTE(review): settings missing from this dict presumably get a Property
    with no default or processing. TODO confirm against `__init__`.
    """
    return dict(
        _disable_meta={"preprocessor": _str_as_bool},
        _disable_stats={"preprocessor": _str_as_bool},
        _disable_viewer={"preprocessor": _str_as_bool},
        # computed: True once google.colab has been imported in this process
        _colab={
            "hook": lambda _: "google.colab" in sys.modules,
            "auto_hook": True,
        },
        # computed: `console` resolved to a SettingsConsole member
        _console={"hook": lambda _: self._convert_console(), "auto_hook": True},
        _internal_check_process={"value": 8},
        _internal_queue_timeout={"value": 2},
        _jupyter={
            "hook": lambda _: str(_get_python_type()) != "python",
            "auto_hook": True,
        },
        _kaggle={"hook": lambda _: util._is_likely_kaggle(), "auto_hook": True},
        _noop={"hook": lambda _: self.mode == "disabled", "auto_hook": True},
        # computed: offline when `disabled` is set or mode is dryrun/offline
        _offline={
            "hook": (
                lambda _: True
                if self.disabled or (self.mode in ("dryrun", "offline"))
                else False
            ),
            "auto_hook": True,
        },
        _platform={"value": util.get_platform_name()},
        _save_requirements={"value": True, "preprocessor": _str_as_bool},
        _stats_sample_rate_seconds={"value": 2.0, "preprocessor": float},
        _stats_samples_to_average={"value": 15},
        _stats_join_assets={"value": True, "preprocessor": _str_as_bool},
        # path-like settings: the stored value is a relative name that the
        # hook joins onto a base directory (and `~`-expands) on every read
        _tmp_code_dir={
            "value": "code",
            "hook": lambda x: self._path_convert(self.tmp_dir, x),
        },
        _windows={
            "hook": lambda _: platform.system() == "Windows",
            "auto_hook": True,
        },
        anonymous={"validator": self._validate_anonymous},
        api_key={"validator": self._validate_api_key},
        base_url={
            "value": "https://api.wandb.ai",
            # normalize: drop surrounding whitespace and trailing slashes
            "preprocessor": lambda x: str(x).strip().rstrip("/"),
            "validator": self._validate_base_url,
        },
        console={"value": "auto", "validator": self._validate_console},
        deployment={
            "hook": lambda _: "local" if self.is_local else "cloud",
            "auto_hook": True,
        },
        disable_code={"preprocessor": _str_as_bool},
        disable_hints={"preprocessor": _str_as_bool},
        disable_git={"preprocessor": _str_as_bool},
        disabled={"value": False, "preprocessor": _str_as_bool},
        enable_job_creation={"preprocessor": _str_as_bool},
        # per-run files live under <wandb_dir>/<run_mode>-<timespec>-<run_id>/
        files_dir={
            "value": "files",
            "hook": lambda x: self._path_convert(
                self.wandb_dir, f"{self.run_mode}-{self.timespec}-{self.run_id}", x
            ),
        },
        force={"preprocessor": _str_as_bool},
        git_remote={"value": "origin"},
        heartbeat_seconds={"value": 30},
        ignore_globs={
            "value": tuple(),
            "preprocessor": lambda x: tuple(x) if not isinstance(x, tuple) else x,
        },
        init_timeout={"value": 60, "preprocessor": lambda x: int(x)},
        # computed: any base_url other than the public cloud API counts as local
        is_local={
            "hook": (
                lambda _: self.base_url != "https://api.wandb.ai"
                if self.base_url is not None
                else False
            ),
            "auto_hook": True,
        },
        label_disable={"preprocessor": _str_as_bool},
        launch={"preprocessor": _str_as_bool},
        log_dir={
            "value": "logs",
            "hook": lambda x: self._path_convert(
                self.wandb_dir, f"{self.run_mode}-{self.timespec}-{self.run_id}", x
            ),
        },
        log_internal={
            "value": "debug-internal.log",
            "hook": lambda x: self._path_convert(self.log_dir, x),
        },
        log_symlink_internal={
            "value": "debug-internal.log",
            "hook": lambda x: self._path_convert(self.wandb_dir, x),
        },
        log_symlink_user={
            "value": "debug.log",
            "hook": lambda x: self._path_convert(self.wandb_dir, x),
        },
        log_user={
            "value": "debug.log",
            "hook": lambda x: self._path_convert(self.log_dir, x),
        },
        login_timeout={"preprocessor": lambda x: float(x)},
        mode={"value": "online", "validator": self._validate_mode},
        problem={"value": "fatal", "validator": self._validate_problem},
        project={"validator": self._validate_project},
        project_url={"hook": lambda _: self._project_url(), "auto_hook": True},
        quiet={"preprocessor": _str_as_bool},
        reinit={"preprocessor": _str_as_bool},
        relogin={"preprocessor": _str_as_bool},
        resume_fname={
            "value": "wandb-resume.json",
            "hook": lambda x: self._path_convert(self.wandb_dir, x),
        },
        # string defaults here are converted to bool by the preprocessor
        resumed={"value": "False", "preprocessor": _str_as_bool},
        root_dir={
            "preprocessor": lambda x: str(x),
            "value": os.path.abspath(os.getcwd()),
        },
        run_mode={
            "hook": lambda _: "offline-run" if self._offline else "run",
            "auto_hook": True,
        },
        run_tags={
            "preprocessor": lambda x: tuple(x) if not isinstance(x, tuple) else x,
        },
        run_url={"hook": lambda _: self._run_url(), "auto_hook": True},
        sagemaker_disable={"preprocessor": _str_as_bool},
        save_code={"preprocessor": _str_as_bool},
        settings_system={
            "value": os.path.join("~", ".config", "wandb", "settings"),
            "hook": lambda x: self._path_convert(x),
        },
        settings_workspace={
            "value": "settings",
            "hook": lambda x: self._path_convert(self.wandb_dir, x),
        },
        show_colors={"preprocessor": _str_as_bool},
        show_emoji={"preprocessor": _str_as_bool},
        show_errors={"value": "True", "preprocessor": _str_as_bool},
        show_info={"value": "True", "preprocessor": _str_as_bool},
        show_warnings={"value": "True", "preprocessor": _str_as_bool},
        silent={"value": "False", "preprocessor": _str_as_bool},
        start_method={"validator": self._validate_start_method},
        strict={"preprocessor": _str_as_bool},
        summary_timeout={"value": 60, "preprocessor": lambda x: int(x)},
        # policy setting: on update, the lowest Source takes precedence
        summary_warnings={
            "value": 5,
            "preprocessor": lambda x: int(x),
            "is_policy": True,
        },
        sweep_url={"hook": lambda _: self._sweep_url(), "auto_hook": True},
        symlink={"preprocessor": _str_as_bool},
        sync_dir={
            "hook": [
                lambda _: self._path_convert(
                    self.wandb_dir, f"{self.run_mode}-{self.timespec}-{self.run_id}"
                )
            ],
            "auto_hook": True,
        },
        sync_file={
            "hook": lambda _: self._path_convert(
                self.sync_dir, f"run-{self.run_id}.wandb"
            ),
            "auto_hook": True,
        },
        sync_symlink_latest={
            "value": "latest-run",
            "hook": lambda x: self._path_convert(self.wandb_dir, x),
        },
        system_sample={"value": 15},
        system_sample_seconds={"value": 2},
        table_raise_on_max_row_limit_exceeded={
            "value": False,
            "preprocessor": _str_as_bool,
        },
        # computed: run start time as "YYYYmmdd_HHMMSS", None until it is set
        timespec={
            "hook": (
                lambda _: (
                    datetime.strftime(self._start_datetime, "%Y%m%d_%H%M%S")
                    if self._start_datetime
                    else None
                )
            ),
            "auto_hook": True,
        },
        # NOTE(review): _path_convert returns an os.path.join of non-empty
        # components, so the `or tempfile.gettempdir()` fallback looks
        # unreachable. Confirm it is intentional.
        tmp_dir={
            "value": "tmp",
            "hook": lambda x: (
                self._path_convert(
                    self.wandb_dir,
                    f"{self.run_mode}-{self.timespec}-{self.run_id}",
                    x,
                )
                or tempfile.gettempdir()
            ),
        },
        wandb_dir={
            "hook": lambda _: _get_wandb_dir(self.root_dir or ""),
            "auto_hook": True,
        },
    )
# helper methods for validating values
@staticmethod
def _validator_factory(hint: Any) -> Callable[[Any], bool]:
"""
Factory for type validators, given a type hint:
Convert the type hint of a setting into a function
that checks if the argument is of the correct type
"""
origin, args = get_origin(hint), get_args(hint)
def helper(x: Any) -> bool:
if origin is None:
return isinstance(x, hint)
elif origin is Union:
return isinstance(x, args) if args is not None else True
else:
return (
isinstance(x, origin) and all(isinstance(y, args) for y in x)
if args is not None
else isinstance(x, origin)
)
return helper
@staticmethod
def _validate_mode(value: str) -> bool:
choices: Set[str] = {"dryrun", "run", "offline", "online", "disabled"}
if value not in choices:
raise UsageError(f"Settings field `mode`: '{value}' not in {choices}")
return True
@staticmethod
def _validate_project(value: Optional[str]) -> bool:
invalid_chars_list = list("/\\#?%:")
if value is not None:
if len(value) > 128:
raise UsageError(
f'Invalid project name "{value}": exceeded 128 characters'
)
invalid_chars = {char for char in invalid_chars_list if char in value}
if invalid_chars:
raise UsageError(
f'Invalid project name "{value}": '
f"cannot contain characters \"{','.join(invalid_chars_list)}\", "
f"found \"{','.join(invalid_chars)}\""
)
return True
@staticmethod
def _validate_start_method(value: str) -> bool:
available_methods = ["thread"]
if hasattr(multiprocessing, "get_all_start_methods"):
available_methods += multiprocessing.get_all_start_methods()
if value not in available_methods:
raise UsageError(
f"Settings field `start_method`: '{value}' not in {available_methods}"
)
return True
@staticmethod
def _validate_console(value: str) -> bool:
# choices = {"auto", "redirect", "off", "file", "iowrap", "notebook"}
choices: Set[str] = {
"auto",
"redirect",
"off",
"wrap",
# internal console states
"wrap_emu",
"wrap_raw",
}
if value not in choices:
# do not advertise internal console states
choices -= {"wrap_emu", "wrap_raw"}
raise UsageError(f"Settings field `console`: '{value}' not in {choices}")
return True
@staticmethod
def _validate_problem(value: str) -> bool:
choices: Set[str] = {"fatal", "warn", "silent"}
if value not in choices:
raise UsageError(f"Settings field `problem`: '{value}' not in {choices}")
return True
@staticmethod
def _validate_anonymous(value: str) -> bool:
choices: Set[str] = {"allow", "must", "never", "false", "true"}
if value not in choices:
raise UsageError(f"Settings field `anonymous`: '{value}' not in {choices}")
return True
@staticmethod
def _validate_api_key(value: str) -> bool:
if len(value) > len(value.strip()):
raise UsageError("API key cannot start or end with whitespace")
# if value.startswith("local") and not self.is_local:
# raise UsageError(
# "Attempting to use a local API key to connect to https://api.wandb.ai"
# )
# todo: move here the logic from sdk/lib/apikey.py
return True
@staticmethod
def _validate_base_url(value: Optional[str]) -> bool:
    """
    Validate the base url of the wandb server.

    param value: URL to validate. None (not set) is accepted.
    returns: True if the URL passes every check.
    raises: UsageError describing the first check that failed.

    Based on the Django URLValidator, but with a few additional checks.

    Copyright (c) Django Software Foundation and individual contributors.
    All rights reserved.
    Redistribution and use in source and binary forms, with or without modification,
    are permitted provided that the following conditions are met:
        1. Redistributions of source code must retain the above copyright notice,
           this list of conditions and the following disclaimer.
        2. Redistributions in binary form must reproduce the above copyright
           notice, this list of conditions and the following disclaimer in the
           documentation and/or other materials provided with the distribution.
        3. Neither the name of Django nor the names of its contributors may be used
           to endorse or promote products derived from this software without
           specific prior written permission.
    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
    ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
    WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
    DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
    ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
    (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
    LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
    ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
    (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
    """
    if value is None:
        return True

    ul = "\u00a1-\uffff"  # Unicode letters range (must not be a raw string).

    # IP patterns
    ipv4_re = (
        r"(?:0|25[0-5]|2[0-4][0-9]|1[0-9]?[0-9]?|[1-9][0-9]?)"
        r"(?:\.(?:0|25[0-5]|2[0-4][0-9]|1[0-9]?[0-9]?|[1-9][0-9]?)){3}"
    )
    # NOTE(review): unlike Django, this function does no further IPv6 check;
    # only the hostname length is checked at the end.
    ipv6_re = r"\[[0-9a-f:.]+\]"  # (simple regex, validated later)

    # Host patterns
    hostname_re = (
        r"[a-z" + ul + r"0-9](?:[a-z" + ul + r"0-9-]{0,61}[a-z" + ul + r"0-9])?"
    )
    # Max length for domain name labels is 63 characters per RFC 1034 sec. 3.1
    domain_re = r"(?:\.(?!-)[a-z" + ul + r"0-9-]{1,63}(?<!-))*"
    tld_re = (
        r"\."  # dot
        r"(?!-)"  # can't start with a dash
        r"(?:[a-z" + ul + "-]{2,63}"  # domain label
        r"|xn--[a-z0-9]{1,59})"  # or punycode label
        r"(?<!-)"  # can't end with a dash
        r"\.?"  # may have a trailing dot
    )
    # Unlike Django, the TLD is optional here, so single-label hostnames
    # (e.g. an internal server name) are accepted.
    # host_re = "(" + hostname_re + domain_re + tld_re + "|localhost)"
    # todo?: allow hostname to be just a hostname (no tld)?
    host_re = "(" + hostname_re + domain_re + f"({tld_re})?" + "|localhost)"

    regex = re.compile(
        r"^(?:[a-z0-9.+-]*)://"  # scheme is validated separately
        r"(?:[^\s:@/]+(?::[^\s:@/]*)?@)?"  # user:pass authentication
        r"(?:" + ipv4_re + "|" + ipv6_re + "|" + host_re + ")"
        r"(?::[0-9]{1,5})?"  # port
        r"(?:[/?#][^\s]*)?"  # resource path
        r"\Z",
        re.IGNORECASE,
    )
    schemes = {"http", "https"}
    unsafe_chars = frozenset("\t\r\n")

    scheme = value.split("://")[0].lower()
    # NOTE(review): urlsplit/urlparse may raise ValueError on malformed
    # netlocs (e.g. an unclosed "["). That escapes as ValueError rather than
    # UsageError; confirm callers expect this.
    split_url = urlsplit(value)
    parsed_url = urlparse(value)

    # Checks run in order; the first failure determines the error message.
    if re.match(r".*wandb\.ai[^\.]*$", value) and "api." not in value:
        # user might guess app.wandb.ai or wandb.ai is the default cloud server
        raise UsageError(
            f"{value} is not a valid server address, did you mean https://api.wandb.ai?"
        )
    elif re.match(r".*wandb\.ai[^\.]*$", value) and scheme != "https":
        # the public cloud must be reached over https
        raise UsageError("http is not secure, please use https://api.wandb.ai")
    elif parsed_url.netloc == "":
        raise UsageError(f"Invalid URL: {value}")
    elif unsafe_chars.intersection(value):
        raise UsageError("URL cannot contain unsafe characters")
    elif scheme not in schemes:
        raise UsageError("URL must start with `http(s)://`")
    elif not regex.search(value):
        raise UsageError(f"{value} is not a valid server address")
    elif split_url.hostname is None or len(split_url.hostname) > 253:
        # 253 characters is the maximum length of a full domain name
        raise UsageError("hostname is invalid")

    return True
# other helper methods
@staticmethod
def _path_convert(*args: str) -> str:
"""
Join path and apply os.path.expanduser to it.
"""
return os.path.expanduser(os.path.join(*args))
def _convert_console(self) -> SettingsConsole:
    """
    Resolve the `console` setting into a `SettingsConsole` member.

    "auto" becomes "wrap" in these cases, and "redirect" otherwise:
    - running in Jupyter;
    - start method is "thread";
    - `_require_service` is set;
    - running on Windows.
    Any other value maps directly to the member of the same name.
    """
    console: str = str(self.console)
    if console == "auto":
        prefer_wrap = (
            self._jupyter
            or (self.start_method == "thread")
            or self._require_service
            or self._windows
        )
        console = "wrap" if prefer_wrap else "redirect"

    by_name: Dict[str, SettingsConsole] = {
        "off": SettingsConsole.OFF,
        "wrap": SettingsConsole.WRAP,
        "wrap_raw": SettingsConsole.WRAP_RAW,
        "wrap_emu": SettingsConsole.WRAP_EMU,
        "redirect": SettingsConsole.REDIRECT,
    }
    return by_name[console]
def _get_url_query_string(self) -> str:
    """
    Return "?apiKey=<key>" when the API settings report anonymous == "true".

    Otherwise return an empty string.
    """
    # TODO(settings) use `wandb_setting` (if self.anonymous != "true":)
    is_anonymous = Api().settings().get("anonymous") == "true"
    if not is_anonymous:
        return ""
    query = urlencode({"apiKey": apikey.api_key(settings=self)})
    return "?" + query
def _project_url_base(self) -> str:
    """Return "<app url>/<entity>/<project>", or "" if either is unset."""
    entity, project = self.entity, self.project
    if not (entity and project):
        return ""
    app_url = wandb.util.app_url(self.base_url)
    return "/".join((app_url, quote(entity), quote(project)))
def _project_url(self) -> str:
    """Return the project URL plus any anonymous query string, or ""."""
    base = self._project_url_base()
    if base:
        return base + self._get_url_query_string()
    return ""
def _run_url(self) -> str:
    """
    Return the run url: "<project url>/runs/<run_id>" plus any query string.

    Returns "" if either the project URL or the run id is missing.
    """
    base = self._project_url_base()
    run_id = self.run_id
    if not (base and run_id):
        return ""
    query = self._get_url_query_string()
    return f"{base}/runs/{quote(run_id)}{query}"
def _set_run_start_time(self, source: int = Source.BASE) -> None:
    """
    Set the time stamps for the settings.

    Called once the run is initialized. One `time.time()` reading is used for
    both `_start_time` (epoch seconds) and `_start_datetime` (local datetime).

    Args:
        source: the `Source` passed along to `update()`.
    """
    start_time = time.time()
    start_datetime = datetime.fromtimestamp(start_time)
    # NOTE(review): these raw names use a single underscore after "_Settings",
    # so they are not the name-mangled form of `__start_datetime`
    # ("_Settings__start_datetime"). Confirm whether anything reads them.
    for raw_name, raw_value in (
        ("_Settings_start_datetime", start_datetime),
        ("_Settings_start_time", start_time),
    ):
        object.__setattr__(self, raw_name, raw_value)
    self.update(
        _start_datetime=start_datetime,
        _start_time=start_time,
        source=source,
    )
def _sweep_url(self) -> str:
    """
    Return the sweep url: "<project url>/sweeps/<sweep_id>" plus any query string.

    Returns "" if either the project URL or the sweep id is missing.
    """
    base = self._project_url_base()
    sweep_id = self.sweep_id
    if not (base and sweep_id):
        return ""
    query = self._get_url_query_string()
    return f"{base}/sweeps/{quote(sweep_id)}{query}"
def __init__(self, **kwargs: Any) -> None: