From 1439824992db7acb83046159faca5e62c7d6a69c Mon Sep 17 00:00:00 2001 From: Jeff Raubitschek Date: Fri, 19 Aug 2022 16:24:17 -0700 Subject: [PATCH 1/5] Sanitize numpy generics in keys --- tests/unit_tests/test_util.py | 35 +++++++++++++++++++ wandb/util.py | 65 ++++++++++++++++++++++++++++------- 2 files changed, 87 insertions(+), 13 deletions(-) diff --git a/tests/unit_tests/test_util.py b/tests/unit_tests/test_util.py index cc8523144aa..f08530f3002 100644 --- a/tests/unit_tests/test_util.py +++ b/tests/unit_tests/test_util.py @@ -688,3 +688,38 @@ def test_resolve_aliases(): aliases = util._resolve_aliases("boom") assert aliases == ["boom", "latest"] + + +# Compute recursive dicts for tests +d_recursive1i = {1: 2, 3: {4: 5}} +d_recursive1i["_"] = d_recursive1i +d_recursive2i = {1: 2, 3: {np.int64(44): 5}} +d_recursive2i["_"] = d_recursive2i +d_recursive2o = {1: 2, 3: {44: 5}} +d_recursive2o["_"] = d_recursive2o + + +@pytest.mark.parametrize( + "dict_input, dict_output", + [ + ({}, None), + ({1: 2}, None), + ({1: np.int64(3)}, None), # dont care about values + ({np.int64(3): 4}, {3: 4}), # top-level + ({1: {np.int64(3): 4}}, {1: {3: 4}}), # nested key + ({1: {np.int32(2): 4}}, {1: {2: 4}}), # nested key + (d_recursive1i, None), # recursive, no numpy + (d_recursive2i, d_recursive2o), # recursive, numpy + ], +) +def test_sanitize_numpy_keys(dict_input, dict_output): + dict_output = dict_output.copy() if dict_output is not None else None + output, converted = util._sanitize_numpy_keys(dict_input) + assert converted == (dict_output is not None) + + # pytest assert cant handle recursive dicts + if dict_output and "_" in dict_output: + output.pop("_") + dict_output.pop("_") + + assert output == (dict_output or dict_input) diff --git a/wandb/util.py b/wandb/util.py index 8b02d2e03a6..af8987c64f4 100644 --- a/wandb/util.py +++ b/wandb/util.py @@ -40,6 +40,7 @@ List, Mapping, Optional, + Set, Sequence, TextIO, Tuple, @@ -573,6 +574,54 @@ def matplotlib_contains_images(obj: Any) -> bool: return any(len(ax.images) > 0 for ax in obj.axes) +def _numpy_generic_convert(obj: Any) -> Any: + obj = obj.item() + if isinstance(obj, float) and math.isnan(obj): + obj = None + elif isinstance(obj, np.generic) and ( + obj.dtype.kind == "f" or obj.dtype == "bfloat16" + ): + # obj is a numpy float with precision greater than that of native python float + # (i.e., float96 or float128) or it is of custom type such as bfloat16. + # in these cases, obj.item() does not return a native + # python float (in the first case - to avoid loss of precision, + # so we need to explicitly cast this down to a 64bit float) + obj = float(obj) + return obj + + +def _find_all_matching_keys( + d: Dict, + check: Callable[[Any], bool], + visited: Set[int] = None, + path: Tuple[Any, ...] = (), +) -> Generator[Tuple[Tuple[Any, ...], Any], None, None]: + if visited is None: + visited = set() + me = id(d) + if me not in visited: + visited.add(me) + for key, value in d.items(): + if check(key): + yield path, key + if isinstance(value, dict): + yield from _find_all_matching_keys( + value, check, visited=visited, path=tuple(list(path) + [key]) + ) + + +def _sanitize_numpy_keys(d: Dict) -> Tuple[Dict, bool]: + np_keys = list(_find_all_matching_keys(d, lambda k: isinstance(k, np.generic))) + if not np_keys: + return d, False + for path, key in np_keys: + ptr = d + for k in path: + ptr = ptr[k] + ptr[_numpy_generic_convert(key)] = ptr.pop(key) + return d, True + + def json_friendly( # noqa: C901 obj: Any, ) -> Union[Tuple[Any, bool], Tuple[Union[None, str, float], bool]]: # noqa: C901 @@ -612,19 +661,7 @@ def json_friendly( # noqa: C901 elif obj.size <= 32: obj = obj.tolist() elif np and isinstance(obj, np.generic): - obj = obj.item() - if isinstance(obj, float) and math.isnan(obj): - obj = None - elif isinstance(obj, np.generic) and ( - obj.dtype.kind == "f" or obj.dtype == "bfloat16" - ): - # obj is a numpy float with precision greater than that of native python float - # (i.e., float96 or float128) or it is of custom type such as bfloat16. - # in these cases, obj.item() does not return a native - # python float (in the first case - to avoid loss of precision, - # so we need to explicitly cast this down to a 64bit float) - obj = float(obj) - + obj = _numpy_generic_convert(obj) elif isinstance(obj, bytes): obj = obj.decode("utf-8") elif isinstance(obj, (datetime, date)): @@ -637,6 +674,8 @@ def json_friendly( # noqa: C901 ) elif isinstance(obj, float) and math.isnan(obj): obj = None + elif isinstance(obj, dict) and np: + obj, converted = _sanitize_numpy_keys(obj) else: converted = False if getsizeof(obj) > VALUE_BYTES_LIMIT: From d220872102ef2d2d5e3d1726da2f3d94dc150616 Mon Sep 17 00:00:00 2001 From: Jeff Raubitschek Date: Fri, 19 Aug 2022 16:58:10 -0700 Subject: [PATCH 2/5] add yea test --- .../t0_main/core/t7_config_numpy_keys.py | 26 +++++++++++++++++++ 1 file changed, 26 insertions(+) create mode 100644 tests/functional_tests/t0_main/core/t7_config_numpy_keys.py diff --git a/tests/functional_tests/t0_main/core/t7_config_numpy_keys.py b/tests/functional_tests/t0_main/core/t7_config_numpy_keys.py new file mode 100644 index 00000000000..6f95929a5f4 --- /dev/null +++ b/tests/functional_tests/t0_main/core/t7_config_numpy_keys.py @@ -0,0 +1,26 @@ +#!/usr/bin/env python +""" +Use numpy in nested dict keys + +--- +plugin: + - wandb +assert: + - :wandb:runs_len: 1 + - :wandb:runs[0][config][this]: 2 + - :wandb:runs[0][config][ok]: {"3": 4} + - :wandb:runs[0][config][deeper][again]: {"9": 34} + - :wandb:runs[0][config][bad]: {"22": 4} + - :wandb:runs[0][exitcode]: 0 +""" +import wandb +import numpy as np + +wandb.init( + config={ + "this": 2, + "ok": {3: 4}, + "deeper": {"again": {9: 34}}, + "bad": {np.int64(22): 4}, + } +) From e4527407132008b7822b869f11d549f8b877cbc1 Mon Sep 17 00:00:00 2001 From: Jeff Raubitschek Date: Fri, 19 Aug 2022 17:08:46 -0700 Subject: [PATCH 3/5] add docstring --- wandb/util.py | 25 ++++++++++++++++++------- 1 file changed, 18 insertions(+), 7 deletions(-) diff --git a/wandb/util.py b/wandb/util.py index af8987c64f4..1d8ec7d61e7 100644 --- a/wandb/util.py +++ b/wandb/util.py @@ -592,21 +592,32 @@ def _numpy_generic_convert(obj: Any) -> Any: def _find_all_matching_keys( d: Dict, - check: Callable[[Any], bool], + match_fn: Callable[[Any], bool], visited: Set[int] = None, - path: Tuple[Any, ...] = (), + key_path: Tuple[Any, ...] = (), ) -> Generator[Tuple[Tuple[Any, ...], Any], None, None]: + """Recursively find all keys that satisfies a match function. + + Args: + d: The dict to search. + match_fn: The function to determine if the key is a match. + visited: Keep track of visited nodes so we dont recurse forever. + key_path: Keep track of all the keys to get to the current node. + Yields: + (key_path, key): The location where the key was found, and the key + """ + if visited is None: visited = set() me = id(d) if me not in visited: visited.add(me) for key, value in d.items(): - if check(key): - yield path, key + if match_fn(key): + yield key_path, key if isinstance(value, dict): yield from _find_all_matching_keys( - value, check, visited=visited, path=tuple(list(path) + [key]) + value, match_fn, visited=visited, key_path=tuple(list(key_path) + [key]) ) @@ -614,9 +625,9 @@ def _sanitize_numpy_keys(d: Dict) -> Tuple[Dict, bool]: np_keys = list(_find_all_matching_keys(d, lambda k: isinstance(k, np.generic))) if not np_keys: return d, False - for path, key in np_keys: + for key_path, key in np_keys: ptr = d - for k in path: + for k in key_path: ptr = ptr[k] ptr[_numpy_generic_convert(key)] = ptr.pop(key) return d, True From a0a25e7702efd9818e24881e66aad2856655212f Mon Sep 17 00:00:00 2001 From: Jeff Raubitschek Date: Fri, 19 Aug 2022 17:09:03 -0700 Subject: [PATCH 4/5] update --- wandb/util.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/wandb/util.py b/wandb/util.py index 1d8ec7d61e7..d995283727c 100644 --- a/wandb/util.py +++ b/wandb/util.py @@ -617,7 +617,10 @@ def _find_all_matching_keys( yield key_path, key if isinstance(value, dict): yield from _find_all_matching_keys( - value, match_fn, visited=visited, key_path=tuple(list(key_path) + [key]) + value, + match_fn, + visited=visited, + key_path=tuple(list(key_path) + [key]), ) From 419867f2a29def0fea820882d1dfe41263388374 Mon Sep 17 00:00:00 2001 From: Jeff Raubitschek Date: Fri, 19 Aug 2022 17:13:23 -0700 Subject: [PATCH 5/5] fix isort --- tests/functional_tests/t0_main/core/t7_config_numpy_keys.py | 2 +- wandb/util.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/functional_tests/t0_main/core/t7_config_numpy_keys.py b/tests/functional_tests/t0_main/core/t7_config_numpy_keys.py index 6f95929a5f4..fbcbe6c0d5d 100644 --- a/tests/functional_tests/t0_main/core/t7_config_numpy_keys.py +++ b/tests/functional_tests/t0_main/core/t7_config_numpy_keys.py @@ -13,8 +13,8 @@ - :wandb:runs[0][config][bad]: {"22": 4} - :wandb:runs[0][exitcode]: 0 """ -import wandb import numpy as np +import wandb wandb.init( config={ diff --git a/wandb/util.py b/wandb/util.py index d995283727c..377222cf6b2 100644 --- a/wandb/util.py +++ b/wandb/util.py @@ -40,8 +40,8 @@ List, Mapping, Optional, - Set, Sequence, + Set, TextIO, Tuple, Type,