Skip to content

Commit

Permalink
Fix issue with self-referencing dataclass (#3713)
Browse files Browse the repository at this point in the history
* Fix issue with self-referencing dataclass

* Fix mypy issue
  • Loading branch information
uriyyo committed May 11, 2022
1 parent faee330 commit 42acd8f
Show file tree
Hide file tree
Showing 5 changed files with 30 additions and 5 deletions.
1 change: 1 addition & 0 deletions changes/3675-uriyyo.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix issue with self-referencing dataclass
9 changes: 8 additions & 1 deletion pydantic/dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,12 @@ def _process_class(

validators = gather_all_validators(cls)
cls.__pydantic_model__ = create_model(
cls.__name__, __config__=config, __module__=_cls.__module__, __validators__=validators, **field_definitions
cls.__name__,
__config__=config,
__module__=_cls.__module__,
__validators__=validators,
__cls_kwargs__={'__resolve_forward_refs__': False},
**field_definitions,
)

cls.__initialised__ = False
Expand All @@ -196,6 +201,8 @@ def _process_class(
if cls.__pydantic_model__.__config__.validate_assignment and not frozen:
cls.__setattr__ = setattr_validate_assignment # type: ignore[assignment]

cls.__pydantic_model__.__try_update_forward_refs__(**{cls.__name__: cls})

return cls


Expand Down
1 change: 1 addition & 0 deletions pydantic/generics.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ def __class_getitem__(cls: Type[GenericModelT], params: Union[Type[Any], Tuple[T
__base__=(cls,) + tuple(cls.__parameterized_bases__(typevars_map)),
__config__=None,
__validators__=validators,
__cls_kwargs__=None,
**fields,
),
)
Expand Down
16 changes: 12 additions & 4 deletions pydantic/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ def __new__(mcs, name, bases, namespace, **kwargs): # noqa C901
class_vars.update(base.__class_vars__)
hash_func = base.__hash__

resolve_forward_refs = kwargs.pop('__resolve_forward_refs__', True)
allowed_config_kwargs: SetStr = {
key
for key in dir(config)
Expand Down Expand Up @@ -289,7 +290,8 @@ def is_untouched(v: Any) -> bool:
cls = super().__new__(mcs, name, bases, new_namespace, **kwargs)
# set __signature__ attr only for model class, but not for its instances
cls.__signature__ = ClassAttribute('__signature__', generate_model_signature(cls.__init__, fields, config))
cls.__try_update_forward_refs__()
if resolve_forward_refs:
cls.__try_update_forward_refs__()

return cls

Expand Down Expand Up @@ -765,12 +767,12 @@ def _get_value(
return v

@classmethod
def __try_update_forward_refs__(cls) -> None:
def __try_update_forward_refs__(cls, **localns: Any) -> None:
"""
Same as update_forward_refs but will not raise exception
when forward references are not defined.
"""
update_model_forward_refs(cls, cls.__fields__.values(), cls.__config__.json_encoders, {}, (NameError,))
update_model_forward_refs(cls, cls.__fields__.values(), cls.__config__.json_encoders, localns, (NameError,))

@classmethod
def update_forward_refs(cls, **localns: Any) -> None:
Expand Down Expand Up @@ -892,6 +894,7 @@ def create_model(
__base__: None = None,
__module__: str = __name__,
__validators__: Dict[str, 'AnyClassMethod'] = None,
__cls_kwargs__: Dict[str, Any] = None,
**field_definitions: Any,
) -> Type['BaseModel']:
...
Expand All @@ -905,6 +908,7 @@ def create_model(
__base__: Union[Type['Model'], Tuple[Type['Model'], ...]],
__module__: str = __name__,
__validators__: Dict[str, 'AnyClassMethod'] = None,
__cls_kwargs__: Dict[str, Any] = None,
**field_definitions: Any,
) -> Type['Model']:
...
Expand All @@ -917,6 +921,7 @@ def create_model(
__base__: Union[None, Type['Model'], Tuple[Type['Model'], ...]] = None,
__module__: str = __name__,
__validators__: Dict[str, 'AnyClassMethod'] = None,
__cls_kwargs__: Dict[str, Any] = None,
**field_definitions: Any,
) -> Type['Model']:
"""
Expand All @@ -926,6 +931,7 @@ def create_model(
:param __base__: base class for the new model to inherit from
:param __module__: module of the created model
:param __validators__: a dict of method names and @validator class methods
:param __cls_kwargs__: a dict for class creation
:param field_definitions: fields of the model (or extra fields if a base is supplied)
in the format `<name>=(<type>, <default default>)` or `<name>=<default value>, e.g.
`foobar=(str, ...)` or `foobar=123`, or, for complex use-cases, in the format
Expand All @@ -940,6 +946,8 @@ def create_model(
else:
__base__ = (cast(Type['Model'], BaseModel),)

__cls_kwargs__ = __cls_kwargs__ or {}

fields = {}
annotations = {}

Expand Down Expand Up @@ -969,7 +977,7 @@ def create_model(
if __config__:
namespace['Config'] = inherit_config(__config__, BaseConfig)

return type(__model_name, __base__, namespace)
return type(__model_name, __base__, namespace, **__cls_kwargs__)


_missing = object()
Expand Down
8 changes: 8 additions & 0 deletions tests/test_dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -989,3 +989,11 @@ def __new__(cls, *args, **kwargs):
instance = cls(a=test_string)
assert instance._special_property == 1
assert instance.a == test_string


def test_self_reference_dataclass():
@pydantic.dataclasses.dataclass
class MyDataclass:
self_reference: 'MyDataclass'

assert MyDataclass.__pydantic_model__.__fields__['self_reference'].type_ is MyDataclass

0 comments on commit 42acd8f

Please sign in to comment.