Skip to content

Commit

Permalink
Basic fix of GenericModel cache to detect order of args in Union mode…
Browse files Browse the repository at this point in the history
  • Loading branch information
sveinugu committed Sep 5, 2022
1 parent 317bef3 commit f9a303f
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 3 deletions.
1 change: 1 addition & 0 deletions changes/4474-sveinugu.md
@@ -0,0 +1 @@
Basic fix of GenericModel cache to detect order of arguments in Union models
10 changes: 7 additions & 3 deletions pydantic/generics.py
Expand Up @@ -62,7 +62,11 @@ def __class_getitem__(cls: Type[GenericModelT], params: Union[Type[Any], Tuple[T
returned as is.
"""
cached = _generic_types_cache.get((cls, params))

def _cache_key(_params: Any) -> Tuple[Type[GenericModelT], Any, Tuple[Any, ...]]:
return cls, _params, get_args(_params)

cached = _generic_types_cache.get(_cache_key(params))
if cached is not None:
return cached
if cls.__concrete__ and Generic not in cls.__bases__:
Expand Down Expand Up @@ -128,9 +132,9 @@ def __class_getitem__(cls: Type[GenericModelT], params: Union[Type[Any], Tuple[T

# Save created model in cache so we don't end up creating duplicate
# models that should be identical.
_generic_types_cache[(cls, params)] = created_model
_generic_types_cache[_cache_key(params)] = created_model
if len(params) == 1:
_generic_types_cache[(cls, params[0])] = created_model
_generic_types_cache[_cache_key(params[0])] = created_model

# Recursively walk class type hints and replace generic typevars
# with concrete types that were passed.
Expand Down
46 changes: 46 additions & 0 deletions tests/test_generics.py
Expand Up @@ -688,6 +688,52 @@ class Model(BaseModel): # same name, but type different, so it's not in cache
assert globals()['MyGeneric[Model]__'] is third_concrete


def test_generic_model_caching_detect_order_of_union_args_basic(create_module):
# Basic variant of https://github.com/pydantic/pydantic/issues/4474
@create_module
def module():
from typing import Generic, TypeVar, Union

from pydantic.generics import GenericModel

t = TypeVar('t')

class Model(GenericModel, Generic[t]):
data: t

int_or_float_model = Model[Union[int, float]]
float_or_int_model = Model[Union[float, int]]

assert type(int_or_float_model(data='1').data) is int
assert type(float_or_int_model(data='1').data) is float


@pytest.mark.skip(
reason="""
Depends on similar issue in CPython itself: https://github.com/python/cpython/issues/86483
Documented and skipped for possible fix later.
"""
)
def test_generic_model_caching_detect_order_of_union_args_nested(create_module):
# Nested variant of https://github.com/pydantic/pydantic/issues/4474
@create_module
def module():
from typing import Generic, List, TypeVar, Union

from pydantic.generics import GenericModel

t = TypeVar('t')

class Model(GenericModel, Generic[t]):
data: t

int_or_float_model = Model[List[Union[int, float]]]
float_or_int_model = Model[List[Union[float, int]]]

assert type(int_or_float_model(data=['1']).data[0]) is int
assert type(float_or_int_model(data=['1']).data[0]) is float


def test_get_caller_frame_info(create_module):
@create_module
def module():
Expand Down

0 comments on commit f9a303f

Please sign in to comment.