From b803a4584563fb8a80720e0d4730e8bc0dd1e446 Mon Sep 17 00:00:00 2001 From: Fabien Aulaire Date: Fri, 24 Jun 2022 13:50:12 +0200 Subject: [PATCH 1/2] Enum deterministic hashing --- dask/base.py | 6 ++++++ dask/tests/test_base.py | 10 ++++++++++ 2 files changed, 16 insertions(+) diff --git a/dask/base.py b/dask/base.py index 2e3e88aea12..60b0f711ef5 100644 --- a/dask/base.py +++ b/dask/base.py @@ -13,6 +13,7 @@ from collections.abc import Callable, Iterator, Mapping from concurrent.futures import Executor from contextlib import contextmanager +from enum import Enum from functools import partial from numbers import Integral, Number from operator import getitem @@ -999,6 +1000,11 @@ def normalize_range(r): return list(map(normalize_token, [r.start, r.stop, r.step])) +@normalize_token.register(Enum) +def normalize_enum(e): + return type(e).__name__, e.name, e.value + + @normalize_token.register(object) def normalize_object(o): method = getattr(o, "__dask_tokenize__", None) diff --git a/dask/tests/test_base.py b/dask/tests/test_base.py index 8852d4cd4fc..e1283a00d7c 100644 --- a/dask/tests/test_base.py +++ b/dask/tests/test_base.py @@ -7,6 +7,7 @@ import time from collections import OrderedDict from concurrent.futures import Executor +from enum import Enum, Flag, IntEnum, IntFlag from operator import add, mul from typing import Union @@ -425,6 +426,15 @@ def test_tokenize_ordered_dict(): assert tokenize(a) != tokenize(c) +@pytest.mark.parametrize("enum_type", [Enum, IntEnum, IntFlag, Flag]) +def test_tokenize_enum(enum_type): + class Color(enum_type): + RED = 1 + BLUE = 2 + + assert tokenize(Color.RED) == tokenize(Color.RED) + + ADataClass = dataclasses.make_dataclass("ADataClass", [("a", int)]) BDataClass = dataclasses.make_dataclass("BDataClass", [("a", Union[int, float])]) # type: ignore From 27be1545478e758bf47488dc4bdf190107f368da Mon Sep 17 00:00:00 2001 From: Fabien Aulaire Date: Fri, 24 Jun 2022 14:04:40 +0200 Subject: [PATCH 2/2] Enum hashing : Add another test case --- dask/tests/test_base.py | 1 + 1 file changed, 1 insertion(+) diff --git a/dask/tests/test_base.py b/dask/tests/test_base.py index e1283a00d7c..709c745e933 100644 --- a/dask/tests/test_base.py +++ b/dask/tests/test_base.py @@ -433,6 +433,7 @@ class Color(enum_type): BLUE = 2 assert tokenize(Color.RED) == tokenize(Color.RED) + assert tokenize(Color.RED) != tokenize(Color.BLUE) ADataClass = dataclasses.make_dataclass("ADataClass", [("a", int)])