diff --git a/tests/functional_tests/t0_main/core/t7_config_numpy_keys.py b/tests/functional_tests/t0_main/core/t7_config_numpy_keys.py
new file mode 100644
index 00000000000..fbcbe6c0d5d
--- /dev/null
+++ b/tests/functional_tests/t0_main/core/t7_config_numpy_keys.py
@@ -0,0 +1,26 @@
+#!/usr/bin/env python
+"""
+Use numpy in nested dict keys
+
+---
+plugin:
+  - wandb
+assert:
+  - :wandb:runs_len: 1
+  - :wandb:runs[0][config][this]: 2
+  - :wandb:runs[0][config][ok]: {"3": 4}
+  - :wandb:runs[0][config][deeper][again]: {"9": 34}
+  - :wandb:runs[0][config][bad]: {"22": 4}
+  - :wandb:runs[0][exitcode]: 0
+"""
+import numpy as np
+import wandb
+
+wandb.init(
+    config={
+        "this": 2,
+        "ok": {3: 4},
+        "deeper": {"again": {9: 34}},
+        "bad": {np.int64(22): 4},
+    }
+)
diff --git a/tests/unit_tests/test_util.py b/tests/unit_tests/test_util.py
index cc8523144aa..f08530f3002 100644
--- a/tests/unit_tests/test_util.py
+++ b/tests/unit_tests/test_util.py
@@ -688,3 +688,39 @@ def test_resolve_aliases():
 
     aliases = util._resolve_aliases("boom")
     assert aliases == ["boom", "latest"]
+
+
+# Compute recursive dicts for tests
+d_recursive1i = {1: 2, 3: {4: 5}}
+d_recursive1i["_"] = d_recursive1i
+d_recursive2i = {1: 2, 3: {np.int64(44): 5}}
+d_recursive2i["_"] = d_recursive2i
+d_recursive2o = {1: 2, 3: {44: 5}}
+d_recursive2o["_"] = d_recursive2o
+
+
+@pytest.mark.parametrize(
+    "dict_input, dict_output",
+    [
+        ({}, None),
+        ({1: 2}, None),
+        ({1: np.int64(3)}, None),  # don't care about values
+        ({np.int64(3): 4}, {3: 4}),  # top-level
+        ({1: {np.int64(3): 4}}, {1: {3: 4}}),  # nested key
+        ({1: {np.int32(2): 4}}, {1: {2: 4}}),  # nested key
+        (d_recursive1i, None),  # recursive, no numpy
+        (d_recursive2i, d_recursive2o),  # recursive, numpy
+        ({np.float64("nan"): {np.int64(1): 2}}, {None: {1: 2}}),  # child of NaN key
+    ],
+)
+def test_sanitize_numpy_keys(dict_input, dict_output):
+    dict_output = dict_output.copy() if dict_output is not None else None
+    output, converted = util._sanitize_numpy_keys(dict_input)
+    assert converted == (dict_output is not None)
+
+    # pytest assert can't handle recursive dicts
+    if dict_output and "_" in dict_output:
+        output.pop("_")
+        dict_output.pop("_")
+
+    assert output == (dict_output or dict_input)
diff --git a/wandb/util.py b/wandb/util.py
index 200dcee5c9f..70e898b1742 100644
--- a/wandb/util.py
+++ b/wandb/util.py
@@ -41,6 +41,7 @@
     Mapping,
     Optional,
     Sequence,
+    Set,
     TextIO,
     Tuple,
     Type,
@@ -575,6 +576,88 @@ def matplotlib_contains_images(obj: Any) -> bool:
     return any(len(ax.images) > 0 for ax in obj.axes)
 
 
+def _numpy_generic_convert(obj: Any) -> Any:
+    """Convert a numpy scalar to its closest native python equivalent.
+
+    NaN floats become None; extended-precision and custom floats (for which
+    `item()` still returns a numpy scalar) are cast down to a python float.
+    """
+    obj = obj.item()
+    if isinstance(obj, float) and math.isnan(obj):
+        obj = None
+    elif isinstance(obj, np.generic) and (
+        obj.dtype.kind == "f" or obj.dtype == "bfloat16"
+    ):
+        # obj is a numpy float with precision greater than that of native python float
+        # (i.e., float96 or float128) or it is of custom type such as bfloat16.
+        # in these cases, obj.item() does not return a native
+        # python float (in the first case - to avoid loss of precision,
+        # so we need to explicitly cast this down to a 64bit float)
+        obj = float(obj)
+    return obj
+
+
+def _find_all_matching_keys(
+    d: Dict,
+    match_fn: Callable[[Any], bool],
+    visited: Optional[Set[int]] = None,
+    key_path: Tuple[Any, ...] = (),
+) -> Generator[Tuple[Tuple[Any, ...], Any], None, None]:
+    """Recursively find all keys that satisfy a match function.
+
+    Args:
+        d: The dict to search.
+        match_fn: The function to determine if the key is a match.
+        visited: Keep track of visited nodes so we don't recurse forever.
+        key_path: Keep track of all the keys to get to the current node.
+
+    Yields:
+        (key_path, key): The location where the key was found, and the key.
+            A key is always yielded before any key nested underneath it.
+    """
+    if visited is None:
+        visited = set()
+    me = id(d)
+    if me in visited:
+        return
+    visited.add(me)
+    for key, value in d.items():
+        if match_fn(key):
+            yield key_path, key
+        if isinstance(value, dict):
+            yield from _find_all_matching_keys(
+                value,
+                match_fn,
+                visited=visited,
+                key_path=(*key_path, key),
+            )
+
+
+def _sanitize_numpy_keys(d: Dict) -> Tuple[Dict, bool]:
+    """Replace numpy scalar keys with native python keys, at any nesting depth.
+
+    The dict is modified in place.
+
+    Args:
+        d: The dict to sanitize.
+
+    Returns:
+        (d, converted): The same dict object, and whether any key was replaced.
+    """
+    np_keys = list(_find_all_matching_keys(d, lambda k: isinstance(k, np.generic)))
+    if not np_keys:
+        return d, False
+    # Rename the deepest keys first: every key_path is resolved through the
+    # original key objects, and a renamed ancestor may no longer be reachable
+    # through its old key (e.g. a NaN key becomes None).
+    for key_path, key in reversed(np_keys):
+        ptr = d
+        for k in key_path:
+            ptr = ptr[k]
+        ptr[_numpy_generic_convert(key)] = ptr.pop(key)
+    return d, True
+
+
 def json_friendly(  # noqa: C901
     obj: Any,
 ) -> Union[Tuple[Any, bool], Tuple[Union[None, str, float], bool]]:  # noqa: C901
@@ -614,19 +697,7 @@ def json_friendly(  # noqa: C901
         elif obj.size <= 32:
             obj = obj.tolist()
     elif np and isinstance(obj, np.generic):
-        obj = obj.item()
-        if isinstance(obj, float) and math.isnan(obj):
-            obj = None
-        elif isinstance(obj, np.generic) and (
-            obj.dtype.kind == "f" or obj.dtype == "bfloat16"
-        ):
-            # obj is a numpy float with precision greater than that of native python float
-            # (i.e., float96 or float128) or it is of custom type such as bfloat16.
-            # in these cases, obj.item() does not return a native
-            # python float (in the first case - to avoid loss of precision,
-            # so we need to explicitly cast this down to a 64bit float)
-            obj = float(obj)
-
+        obj = _numpy_generic_convert(obj)
     elif isinstance(obj, bytes):
         obj = obj.decode("utf-8")
     elif isinstance(obj, (datetime, date)):
@@ -639,6 +710,8 @@ def json_friendly(  # noqa: C901
         )
     elif isinstance(obj, float) and math.isnan(obj):
         obj = None
+    elif isinstance(obj, dict) and np:
+        obj, converted = _sanitize_numpy_keys(obj)
     else:
         converted = False
     if getsizeof(obj) > VALUE_BYTES_LIMIT: