Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sanitize numpy generics in keys #4146

Merged
merged 7 commits into from Sep 1, 2022
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
26 changes: 26 additions & 0 deletions tests/functional_tests/t0_main/core/t7_config_numpy_keys.py
@@ -0,0 +1,26 @@
#!/usr/bin/env python
"""
Use numpy in nested dict keys

---
plugin:
- wandb
assert:
- :wandb:runs_len: 1
- :wandb:runs[0][config][this]: 2
- :wandb:runs[0][config][ok]: {"3": 4}
- :wandb:runs[0][config][deeper][again]: {"9": 34}
- :wandb:runs[0][config][bad]: {"22": 4}
- :wandb:runs[0][exitcode]: 0
"""
import numpy as np
import wandb

wandb.init(
config={
"this": 2,
"ok": {3: 4},
"deeper": {"again": {9: 34}},
"bad": {np.int64(22): 4},
}
)
35 changes: 35 additions & 0 deletions tests/unit_tests/test_util.py
Expand Up @@ -688,3 +688,38 @@ def test_resolve_aliases():

aliases = util._resolve_aliases("boom")
assert aliases == ["boom", "latest"]


# Compute recursive dicts for tests
d_recursive1i = {1: 2, 3: {4: 5}}
d_recursive1i["_"] = d_recursive1i
d_recursive2i = {1: 2, 3: {np.int64(44): 5}}
d_recursive2i["_"] = d_recursive2i
d_recursive2o = {1: 2, 3: {44: 5}}
d_recursive2o["_"] = d_recursive2o


@pytest.mark.parametrize(
"dict_input, dict_output",
[
({}, None),
({1: 2}, None),
({1: np.int64(3)}, None), # dont care about values
({np.int64(3): 4}, {3: 4}), # top-level
({1: {np.int64(3): 4}}, {1: {3: 4}}), # nested key
({1: {np.int32(2): 4}}, {1: {2: 4}}), # nested key
(d_recursive1i, None), # recursive, no numpy
(d_recursive2i, d_recursive2o), # recursive, numpy
],
)
def test_sanitize_numpy_keys(dict_input, dict_output):
dict_output = dict_output.copy() if dict_output is not None else None
output, converted = util._sanitize_numpy_keys(dict_input)
assert converted == (dict_output is not None)

# pytest assert cant handle recursive dicts
if dict_output and "_" in dict_output:
output.pop("_")
dict_output.pop("_")

assert output == (dict_output or dict_input)
79 changes: 66 additions & 13 deletions wandb/util.py
Expand Up @@ -41,6 +41,7 @@
Mapping,
Optional,
Sequence,
Set,
TextIO,
Tuple,
Type,
Expand Down Expand Up @@ -573,6 +574,68 @@ def matplotlib_contains_images(obj: Any) -> bool:
return any(len(ax.images) > 0 for ax in obj.axes)


def _numpy_generic_convert(obj: Any) -> Any:
obj = obj.item()
if isinstance(obj, float) and math.isnan(obj):
obj = None
elif isinstance(obj, np.generic) and (
obj.dtype.kind == "f" or obj.dtype == "bfloat16"
):
# obj is a numpy float with precision greater than that of native python float
# (i.e., float96 or float128) or it is of custom type such as bfloat16.
# in these cases, obj.item() does not return a native
# python float (in the first case - to avoid loss of precision,
# so we need to explicitly cast this down to a 64bit float)
obj = float(obj)
return obj


def _find_all_matching_keys(
d: Dict,
match_fn: Callable[[Any], bool],
visited: Set[int] = None,
key_path: Tuple[Any, ...] = (),
) -> Generator[Tuple[Tuple[Any, ...], Any], None, None]:
"""Recursively find all keys that satisfies a match function.

Args:
d: The dict to search.
match_fn: The function to determine if the key is a match.
visited: Keep track of visited nodes so we dont recurse forever.
key_path: Keep track of all the keys to get to the current node.
Yields:
(key_path, key): The location where the key was found, and the key
"""

if visited is None:
visited = set()
me = id(d)
if me not in visited:
visited.add(me)
for key, value in d.items():
if match_fn(key):
yield key_path, key
if isinstance(value, dict):
yield from _find_all_matching_keys(
value,
match_fn,
visited=visited,
key_path=tuple(list(key_path) + [key]),
)


def _sanitize_numpy_keys(d: Dict) -> Tuple[Dict, bool]:
np_keys = list(_find_all_matching_keys(d, lambda k: isinstance(k, np.generic)))
if not np_keys:
return d, False
for key_path, key in np_keys:
ptr = d
for k in key_path:
ptr = ptr[k]
ptr[_numpy_generic_convert(key)] = ptr.pop(key)
return d, True


def json_friendly( # noqa: C901
obj: Any,
) -> Union[Tuple[Any, bool], Tuple[Union[None, str, float], bool]]: # noqa: C901
Expand Down Expand Up @@ -612,19 +675,7 @@ def json_friendly( # noqa: C901
elif obj.size <= 32:
obj = obj.tolist()
elif np and isinstance(obj, np.generic):
obj = obj.item()
if isinstance(obj, float) and math.isnan(obj):
obj = None
elif isinstance(obj, np.generic) and (
obj.dtype.kind == "f" or obj.dtype == "bfloat16"
):
# obj is a numpy float with precision greater than that of native python float
# (i.e., float96 or float128) or it is of custom type such as bfloat16.
# in these cases, obj.item() does not return a native
# python float (in the first case - to avoid loss of precision,
# so we need to explicitly cast this down to a 64bit float)
obj = float(obj)

obj = _numpy_generic_convert(obj)
elif isinstance(obj, bytes):
obj = obj.decode("utf-8")
elif isinstance(obj, (datetime, date)):
Expand All @@ -637,6 +688,8 @@ def json_friendly( # noqa: C901
)
elif isinstance(obj, float) and math.isnan(obj):
obj = None
elif isinstance(obj, dict) and np:
obj, converted = _sanitize_numpy_keys(obj)
else:
converted = False
if getsizeof(obj) > VALUE_BYTES_LIMIT:
Expand Down