Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Basic fix of GenericModel cache to detect order of arguments in Union models #4482

Merged
merged 1 commit into from Sep 5, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions changes/4474-sveinugu.md
@@ -0,0 +1 @@
Basic fix of GenericModel cache to detect order of arguments in Union models
10 changes: 7 additions & 3 deletions pydantic/generics.py
Expand Up @@ -62,7 +62,11 @@ def __class_getitem__(cls: Type[GenericModelT], params: Union[Type[Any], Tuple[T
returned as is.

"""
cached = _generic_types_cache.get((cls, params))

def _cache_key(_params: Any) -> Tuple[Type[GenericModelT], Any, Tuple[Any, ...]]:
return cls, _params, get_args(_params)

cached = _generic_types_cache.get(_cache_key(params))
if cached is not None:
return cached
if cls.__concrete__ and Generic not in cls.__bases__:
Expand Down Expand Up @@ -128,9 +132,9 @@ def __class_getitem__(cls: Type[GenericModelT], params: Union[Type[Any], Tuple[T

# Save created model in cache so we don't end up creating duplicate
# models that should be identical.
_generic_types_cache[(cls, params)] = created_model
_generic_types_cache[_cache_key(params)] = created_model
if len(params) == 1:
_generic_types_cache[(cls, params[0])] = created_model
_generic_types_cache[_cache_key(params[0])] = created_model

# Recursively walk class type hints and replace generic typevars
# with concrete types that were passed.
Expand Down
46 changes: 46 additions & 0 deletions tests/test_generics.py
Expand Up @@ -688,6 +688,52 @@ class Model(BaseModel): # same name, but type different, so it's not in cache
assert globals()['MyGeneric[Model]__'] is third_concrete


def test_generic_model_caching_detect_order_of_union_args_basic(create_module):
# Basic variant of https://github.com/pydantic/pydantic/issues/4474
@create_module
def module():
from typing import Generic, TypeVar, Union

from pydantic.generics import GenericModel

t = TypeVar('t')

class Model(GenericModel, Generic[t]):
data: t

int_or_float_model = Model[Union[int, float]]
float_or_int_model = Model[Union[float, int]]

assert type(int_or_float_model(data='1').data) is int
assert type(float_or_int_model(data='1').data) is float


@pytest.mark.skip(
reason="""
Depends on similar issue in CPython itself: https://github.com/python/cpython/issues/86483
Documented and skipped for possible fix later.
"""
)
def test_generic_model_caching_detect_order_of_union_args_nested(create_module):
# Nested variant of https://github.com/pydantic/pydantic/issues/4474
@create_module
def module():
from typing import Generic, List, TypeVar, Union

from pydantic.generics import GenericModel

t = TypeVar('t')

class Model(GenericModel, Generic[t]):
data: t

int_or_float_model = Model[List[Union[int, float]]]
float_or_int_model = Model[List[Union[float, int]]]

assert type(int_or_float_model(data=['1']).data[0]) is int
assert type(float_or_int_model(data=['1']).data[0]) is float


def test_get_caller_frame_info(create_module):
@create_module
def module():
Expand Down