From 1088096fee41f8b3826a2d12c75d503c87a09fe2 Mon Sep 17 00:00:00 2001 From: PrettyWood Date: Tue, 1 Dec 2020 01:05:56 +0100 Subject: [PATCH] fix: stdlib dataclass automatically converted into pydantic dataclass can still equal its stdlib dataclass equivalent closes #2162 --- changes/2162-PrettyWood.md | 1 + pydantic/dataclasses.py | 10 +++++++++ tests/test_dataclasses.py | 46 +++++++++++++++++++++++++++++++++++++- 3 files changed, 56 insertions(+), 1 deletion(-) create mode 100644 changes/2162-PrettyWood.md diff --git a/changes/2162-PrettyWood.md b/changes/2162-PrettyWood.md new file mode 100644 index 00000000000..091121500f9 --- /dev/null +++ b/changes/2162-PrettyWood.md @@ -0,0 +1 @@ +fix: stdlib `dataclass`, which is automatically converted into _pydantic_ `dataclass`, can still equal its stdlib `dataclass` equivalent \ No newline at end of file diff --git a/pydantic/dataclasses.py b/pydantic/dataclasses.py index 53f1427d824..7d3bbd30e7f 100644 --- a/pydantic/dataclasses.py +++ b/pydantic/dataclasses.py @@ -110,6 +110,15 @@ def _pydantic_post_init(self: 'Dataclass', *initvars: Any) -> None: if post_init_post_parse is not None: post_init_post_parse(self, *initvars) + def pydantic_dataclass_eq(self: 'Dataclass', other: Any) -> bool: + """ + To allow to still do equality between pydantic dataclasses and stdlib dataclasses, + we add a custom `__eq__` method + """ + stdlib_dc = self.__class__.__bases__[0] + dc_fields = {k: v for k, v in self.__dict__.items() if k != '__initialised__'} + return stdlib_dc(**dc_fields) == other + # If the class is already a dataclass, __post_init__ will not be called automatically # so no validation will be added. # We hence create dynamically a new dataclass: @@ -129,6 +138,7 @@ def _pydantic_post_init(self: 'Dataclass', *initvars: Any) -> None: (_cls,), { '__annotations__': _cls.__annotations__, + '__eq__': pydantic_dataclass_eq, '__post_init__': _pydantic_post_init, # attrs for pickle to find this class '__module__': __name__, diff --git a/tests/test_dataclasses.py b/tests/test_dataclasses.py index e087455b6e3..49840671bff 100644 --- a/tests/test_dataclasses.py +++ b/tests/test_dataclasses.py @@ -3,7 +3,7 @@ from collections.abc import Hashable from datetime import datetime from pathlib import Path -from typing import ClassVar, Dict, FrozenSet, List, Optional +from typing import Any, ClassVar, Dict, FrozenSet, Generator, List, Optional, Tuple import pytest @@ -833,3 +833,47 @@ class Config: # ensure the restored dataclass is still a pydantic dataclass with pytest.raises(ValidationError, match='value\n +value is not a valid integer'): restored_obj.dataclass.value = 'value of a wrong type' + + +def gen_dataclasses_tuple() -> Generator[Tuple[Any, Any, bool], None, None]: + @dataclasses.dataclass(frozen=True) + class StdLibFoo: + a: str + b: int + + @pydantic.dataclasses.dataclass(frozen=True) + class PydanticFoo: + a: str + b: int + + @dataclasses.dataclass(frozen=True) + class StdLibBar: + c: StdLibFoo + + @pydantic.dataclasses.dataclass(frozen=True) + class PydanticBar: + c: PydanticFoo + + @dataclasses.dataclass(frozen=True) + class StdLibBaz: + c: PydanticFoo + + @pydantic.dataclasses.dataclass(frozen=True) + class PydanticBaz: + c: StdLibFoo + + yield StdLibFoo, StdLibBar, True + yield PydanticFoo, PydanticBar, True + yield PydanticFoo, StdLibBaz, True + yield StdLibFoo, PydanticBaz, False + + +@pytest.mark.parametrize('Dataclass1,Dataclass2,is_identical', gen_dataclasses_tuple()) +def test_dataclass_equality(Dataclass1, Dataclass2, is_identical): + foo = Dataclass1(a='Foo', b=1) + bar = Dataclass2(c=foo) + + assert dataclasses.asdict(foo) == dataclasses.asdict(bar.c) + assert dataclasses.astuple(foo) == dataclasses.astuple(bar.c) + assert (foo is bar.c) is is_identical + assert foo == bar.c