From c4a0de276fee815b8642ab68b7d7fcc084c23c2e Mon Sep 17 00:00:00 2001 From: Eric Jolibois Date: Mon, 5 Sep 2022 15:09:31 +0200 Subject: [PATCH] fix: always call `__post_init__` when set in dataclasses --- changes/4477-PrettyWood.md | 1 + pydantic/dataclasses.py | 20 +++++++++++--- tests/test_dataclasses.py | 54 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 72 insertions(+), 3 deletions(-) create mode 100644 changes/4477-PrettyWood.md diff --git a/changes/4477-PrettyWood.md b/changes/4477-PrettyWood.md new file mode 100644 index 00000000000..7800bb091ba --- /dev/null +++ b/changes/4477-PrettyWood.md @@ -0,0 +1 @@ +fix: always call __post_init__ when set in dataclasses diff --git a/pydantic/dataclasses.py b/pydantic/dataclasses.py index f60f04b239a..7da17d50336 100644 --- a/pydantic/dataclasses.py +++ b/pydantic/dataclasses.py @@ -256,7 +256,6 @@ def _add_pydantic_validation_attributes( # noqa: C901 (ignore complexity) """ init = dc_cls.__init__ - @wraps(init) def handle_extra_init(self: 'Dataclass', *args: Any, **kwargs: Any) -> None: if config.extra == Extra.ignore: init(self, *args, **{k: v for k, v in kwargs.items() if k in self.__dataclass_fields__}) @@ -285,8 +284,23 @@ def new_post_init(self: 'Dataclass', *args: Any, **kwargs: Any) -> None: if config.post_init_call == 'after_validation': post_init(self, *args, **kwargs) - setattr(dc_cls, '__init__', handle_extra_init) - setattr(dc_cls, '__post_init__', new_post_init) + if is_builtin_dataclass(dc_cls.__bases__[0]) and not hasattr(dc_cls.__bases__[0], '__post_init__'): + # `__post_init__` won't be called so we need to call it manually + @wraps(init) + def new_init(self: 'Dataclass', *args: Any, **kwargs: Any) -> None: + handle_extra_init(self, *args, **kwargs) + new_post_init(self, *args, **kwargs) + + setattr(dc_cls, '__init__', new_init) + setattr(dc_cls, '__post_init__', new_post_init) + else: + + @wraps(init) + def new_init(self: 'Dataclass', *args: Any, **kwargs: Any) -> None: + handle_extra_init(self, *args, **kwargs) + + setattr(dc_cls, '__init__', handle_extra_init) + setattr(dc_cls, '__post_init__', new_post_init) else: diff --git a/tests/test_dataclasses.py b/tests/test_dataclasses.py index f29bdfc3c80..bf57fda0ab5 100644 --- a/tests/test_dataclasses.py +++ b/tests/test_dataclasses.py @@ -1395,3 +1395,57 @@ class Bar: @pydantic.dataclasses.dataclass class Foo: a: List[Bar(a=1)] + + +def test_parent_post_init(): + @dataclasses.dataclass + class A: + a: float = 1.1 + + def __post_init__(self): + self.a *= 2 + + @pydantic.dataclasses.dataclass + class B(A): + @validator('a') + def validate_a(cls, value): + value += 1 + return value + + assert B().a == 3.2 # 2 * 1.1 + 1 + + +def test_subclass_post_init_post_parse(): + @dataclasses.dataclass + class A: + a: float = 1.1 + + @pydantic.dataclasses.dataclass + class B(A): + def __post_init_post_parse__(self): + self.a *= 2 + + @validator('a') + def validate_a(cls, value): + value += 1 + return value + + assert B().a == 4.2 # (1.1 + 1) * 2 + + +def test_subclass_post_init(): + @dataclasses.dataclass + class A: + a: float = 1.1 + + @pydantic.dataclasses.dataclass + class B(A): + def __post_init__(self): + self.a *= 2 + + @validator('a') + def validate_a(cls, value): + value += 1 + return value + + assert B().a == 3.2 # 2 * 1.1 + 1