Skip to content

Commit

Permalink
fix: dataclass wrapper was not always called
Browse files Browse the repository at this point in the history
  • Loading branch information
PrettyWood committed Sep 5, 2022
1 parent eccd85e commit d5fe6f8
Show file tree
Hide file tree
Showing 3 changed files with 103 additions and 2 deletions.
1 change: 1 addition & 0 deletions changes/4477-PrettyWood.md
@@ -0,0 +1 @@
fix: dataclass wrapper was not always called
27 changes: 25 additions & 2 deletions pydantic/dataclasses.py
Expand Up @@ -34,7 +34,20 @@ class M:
import sys
from contextlib import contextmanager
from functools import wraps
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Dict, Generator, Optional, Type, TypeVar, Union, overload
from typing import (
TYPE_CHECKING,
Any,
Callable,
ClassVar,
Dict,
Generator,
Optional,
Set,
Type,
TypeVar,
Union,
overload,
)

from typing_extensions import dataclass_transform

Expand Down Expand Up @@ -184,7 +197,9 @@ def dataclass(
def wrap(cls: Type[Any]) -> 'DataclassClassOrWrapper':
import dataclasses

if is_builtin_dataclass(cls):
has_extra_args_or_methods = _extra_dc_args(_cls) == _extra_dc_args(_cls.__bases__[0]) # type: ignore

if is_builtin_dataclass(cls) and not has_extra_args_or_methods:
dc_cls_doc = ''
dc_cls = DataclassProxy(cls)
default_validate_on_init = False
Expand Down Expand Up @@ -418,6 +433,14 @@ def _dataclass_validate_assignment_setattr(self: 'Dataclass', name: str, value:
object.__setattr__(self, name, value)


def _extra_dc_args(cls: Type[Any]) -> Set[str]:
return {
x
for x in dir(cls)
if x not in getattr(cls, '__dataclass_fields__', {}) and not (x.startswith('__') and x.endswith('__'))
}


def is_builtin_dataclass(_cls: Type[Any]) -> bool:
"""
Whether a class is a stdlib dataclass
Expand Down
77 changes: 77 additions & 0 deletions tests/test_dataclasses.py
Expand Up @@ -1395,3 +1395,80 @@ class Bar:
@pydantic.dataclasses.dataclass
class Foo:
a: List[Bar(a=1)]


def test_parent_post_init():
@dataclasses.dataclass
class A:
a: float = 1

def __post_init__(self):
self.a *= 2

@pydantic.dataclasses.dataclass
class B(A):
@validator('a')
def validate_a(cls, value):
value += 3
return value

assert B().a == 5 # 1 * 2 + 3


def test_subclass_post_init_post_parse():
@dataclasses.dataclass
class A:
a: float = 1

@pydantic.dataclasses.dataclass
class B(A):
def __post_init_post_parse__(self):
self.a *= 2

@validator('a')
def validate_a(cls, value):
value += 3
return value

assert B().a == 8 # (1 + 3) * 2


def test_subclass_post_init():
@dataclasses.dataclass
class A:
a: int = 1

@pydantic.dataclasses.dataclass
class B(A):
def __post_init__(self):
self.a *= 2

@validator('a')
def validate_a(cls, value):
value += 3
return value

assert B().a == 5 # 1 * 2 + 3


def test_subclass_post_init_inheritance():
@dataclasses.dataclass
class A:
a: int = 1

@pydantic.dataclasses.dataclass
class B(A):
def __post_init__(self):
self.a *= 2

@validator('a')
def validate_a(cls, value):
value += 3
return value

@pydantic.dataclasses.dataclass
class C(B):
def __post_init__(self):
self.a *= 3

assert C().a == 6 # 1 * 3 + 3

0 comments on commit d5fe6f8

Please sign in to comment.