From 0ee07e3cc1ccd822227545936c1be0c94f84ad54 Mon Sep 17 00:00:00 2001 From: Fabien Aulaire <306648+faulaire@users.noreply.github.com> Date: Fri, 24 Jun 2022 16:08:26 +0200 Subject: [PATCH] Enum deterministic hashing (#9212) --- dask/base.py | 6 ++++++ dask/tests/test_base.py | 11 +++++++++++ 2 files changed, 17 insertions(+) diff --git a/dask/base.py b/dask/base.py index 2e3e88aea12..60b0f711ef5 100644 --- a/dask/base.py +++ b/dask/base.py @@ -13,6 +13,7 @@ from collections.abc import Callable, Iterator, Mapping from concurrent.futures import Executor from contextlib import contextmanager +from enum import Enum from functools import partial from numbers import Integral, Number from operator import getitem @@ -999,6 +1000,11 @@ def normalize_range(r): return list(map(normalize_token, [r.start, r.stop, r.step])) +@normalize_token.register(Enum) +def normalize_enum(e): + return type(e).__name__, e.name, e.value + + @normalize_token.register(object) def normalize_object(o): method = getattr(o, "__dask_tokenize__", None) diff --git a/dask/tests/test_base.py b/dask/tests/test_base.py index 8852d4cd4fc..709c745e933 100644 --- a/dask/tests/test_base.py +++ b/dask/tests/test_base.py @@ -7,6 +7,7 @@ import time from collections import OrderedDict from concurrent.futures import Executor +from enum import Enum, Flag, IntEnum, IntFlag from operator import add, mul from typing import Union @@ -425,6 +426,16 @@ def test_tokenize_ordered_dict(): assert tokenize(a) != tokenize(c) +@pytest.mark.parametrize("enum_type", [Enum, IntEnum, IntFlag, Flag]) +def test_tokenize_enum(enum_type): + class Color(enum_type): + RED = 1 + BLUE = 2 + + assert tokenize(Color.RED) == tokenize(Color.RED) + assert tokenize(Color.RED) != tokenize(Color.BLUE) + + ADataClass = dataclasses.make_dataclass("ADataClass", [("a", int)]) BDataClass = dataclasses.make_dataclass("BDataClass", [("a", Union[int, float])]) # type: ignore