Skip to content

Commit

Permalink
fix: prevent RecursionError while using recursive GenericModels (#…
Browse files Browse the repository at this point in the history
…2338)

Co-authored-by: Samuel Colvin <samcolvin@gmail.com>

Co-authored-by: Samuel Colvin <samcolvin@gmail.com>
  • Loading branch information
xppt and samuelcolvin committed Feb 26, 2021
1 parent 90df33c commit 8f0980e
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 21 deletions.
1 change: 1 addition & 0 deletions changes/1370-xppt.md
@@ -0,0 +1 @@
fix: prevent `RecursionError` while using recursive `GenericModel`s
10 changes: 9 additions & 1 deletion pydantic/fields.py
Expand Up @@ -427,7 +427,7 @@ def prepare(self) -> None:
e.g. calling it it multiple times may modify the field and configure it incorrectly.
"""
self._set_default_and_type()
if self.type_.__class__ == ForwardRef:
if self.type_.__class__ is ForwardRef or self.type_.__class__ is DeferredType:
# self.type_ is currently a ForwardRef and there's nothing we can do now,
# user will need to call model.update_forward_refs()
return
Expand Down Expand Up @@ -676,6 +676,8 @@ def validate(
self, v: Any, values: Dict[str, Any], *, loc: 'LocStr', cls: Optional['ModelOrDc'] = None
) -> 'ValidateReturn':

assert self.type_.__class__ is not DeferredType

if self.type_.__class__ is ForwardRef:
assert cls is not None
raise ConfigError(
Expand Down Expand Up @@ -983,3 +985,9 @@ def PrivateAttr(
default,
default_factory=default_factory,
)


class DeferredType:
"""
Used to postpone field preparation, while creating recursive generic models.
"""
61 changes: 41 additions & 20 deletions pydantic/generics.py
Expand Up @@ -9,6 +9,7 @@
Iterable,
Iterator,
List,
Mapping,
Optional,
Tuple,
Type,
Expand All @@ -19,7 +20,7 @@
)

from .class_validators import gather_all_validators
from .fields import FieldInfo, ModelField
from .fields import DeferredType
from .main import BaseModel, create_model
from .typing import display_as_type, get_args, get_origin, typing_base
from .utils import all_identical, lenient_issubclass
Expand Down Expand Up @@ -69,19 +70,15 @@ def __class_getitem__(cls: Type[GenericModelT], params: Union[Type[Any], Tuple[T
if all_identical(typevars_map.keys(), typevars_map.values()) and typevars_map:
return cls # if arguments are equal to parameters it's the same object

# Recursively walk class type hints and replace generic typevars
# with concrete types that were passed.
# Create new model with original model as parent inserting fields with DeferredType.
model_name = cls.__concrete_name__(params)
validators = gather_all_validators(cls)

type_hints = get_type_hints(cls).items()
instance_type_hints = {k: v for k, v in type_hints if get_origin(v) is not ClassVar}
concrete_type_hints: Dict[str, Type[Any]] = {
k: replace_types(v, typevars_map) for k, v in instance_type_hints.items()
}

# Create new model with original model as parent inserting fields with
# updated type hints.
model_name = cls.__concrete_name__(params)
validators = gather_all_validators(cls)
fields = _build_generic_fields(cls.__fields__, concrete_type_hints)
fields = {k: (DeferredType(), cls.__fields__[k].field_info) for k in instance_type_hints if k in cls.__fields__}

model_module, called_globally = get_caller_frame_info()
created_model = cast(
Type[GenericModel], # casting ensures mypy is aware of the __concrete__ and __parameters__ attributes
Expand Down Expand Up @@ -121,6 +118,11 @@ def __class_getitem__(cls: Type[GenericModelT], params: Union[Type[Any], Tuple[T
_generic_types_cache[(cls, params)] = created_model
if len(params) == 1:
_generic_types_cache[(cls, params[0])] = created_model

# Recursively walk class type hints and replace generic typevars
# with concrete types that were passed.
_prepare_model_fields(created_model, fields, instance_type_hints, typevars_map)

return created_model

@classmethod
Expand All @@ -140,11 +142,11 @@ def __concrete_name__(cls: Type[Any], params: Tuple[Type[Any], ...]) -> str:
return f'{cls.__name__}[{params_component}]'


def replace_types(type_: Any, type_map: Dict[Any, Any]) -> Any:
def replace_types(type_: Any, type_map: Mapping[Any, Any]) -> Any:
"""Return type with all occurances of `type_map` keys recursively replaced with their values.
:param type_: Any type, class or generic alias
:type_map: Mapping from `TypeVar` instance to concrete types.
:param type_map: Mapping from `TypeVar` instance to concrete types.
:return: New type representing the basic structure of `type_` with all
`typevar_map` keys recursively replaced.
Expand Down Expand Up @@ -218,13 +220,6 @@ def iter_contained_typevars(v: Any) -> Iterator[TypeVarType]:
yield from iter_contained_typevars(arg)


def _build_generic_fields(
raw_fields: Dict[str, ModelField],
concrete_type_hints: Dict[str, Type[Any]],
) -> Dict[str, Tuple[Type[Any], FieldInfo]]:
return {k: (v, raw_fields[k].field_info) for k, v in concrete_type_hints.items() if k in raw_fields}


def get_caller_frame_info() -> Tuple[Optional[str], bool]:
"""
Used inside a function to check whether it was called globally
Expand All @@ -241,3 +236,29 @@ def get_caller_frame_info() -> Tuple[Optional[str], bool]:
return None, False
frame_globals = previous_caller_frame.f_globals
return frame_globals.get('__name__'), previous_caller_frame.f_locals is frame_globals


def _prepare_model_fields(
created_model: Type[GenericModel],
fields: Mapping[str, Any],
instance_type_hints: Mapping[str, type],
typevars_map: Mapping[Any, type],
) -> None:
"""
Replace DeferredType fields with concrete type hints and prepare them.
"""

for key, field in created_model.__fields__.items():
if key not in fields:
assert field.type_.__class__ is not DeferredType
# https://github.com/nedbat/coveragepy/issues/198
continue # pragma: no cover

assert field.type_.__class__ is DeferredType, field.type_.__class__

field_type_hint = instance_type_hints[key]
concrete_type = replace_types(field_type_hint, typevars_map)
field.type_ = concrete_type
field.outer_type_ = concrete_type
field.prepare()
created_model.__annotations__[key] = concrete_type
24 changes: 24 additions & 0 deletions tests/test_generics.py
Expand Up @@ -1015,3 +1015,27 @@ class Model(GenericModel, Generic[T, U]):
Model[str, U].__concrete__ is False
Model[str, U].__parameters__ == [U]
Model[str, int].__concrete__ is False


@skip_36
def test_generic_recursive_models(create_module):
@create_module
def module():
from typing import Generic, TypeVar, Union

from pydantic.generics import GenericModel

T = TypeVar('T')

class Model1(GenericModel, Generic[T]):
ref: 'Model2[T]'

class Model2(GenericModel, Generic[T]):
ref: Union[T, Model1[T]]

Model1.update_forward_refs()

Model1 = module.Model1
Model2 = module.Model2
result = Model1[str].parse_obj(dict(ref=dict(ref=dict(ref=dict(ref=123)))))
assert result == Model1(ref=Model2(ref=Model1(ref=Model2(ref='123'))))

0 comments on commit 8f0980e

Please sign in to comment.