From f9a303fc19b2c7d6803316b845d2a242f66c0ccc Mon Sep 17 00:00:00 2001 From: Sveinung Gundersen Date: Mon, 5 Sep 2022 14:25:17 +0200 Subject: [PATCH] Basic fix of GenericModel cache to detect order of args in Union models [#4474] --- changes/4474-sveinugu.md | 1 + pydantic/generics.py | 10 ++++++--- tests/test_generics.py | 46 ++++++++++++++++++++++++++++++++++++++++ 3 files changed, 54 insertions(+), 3 deletions(-) create mode 100644 changes/4474-sveinugu.md diff --git a/changes/4474-sveinugu.md b/changes/4474-sveinugu.md new file mode 100644 index 0000000000..891100b10a --- /dev/null +++ b/changes/4474-sveinugu.md @@ -0,0 +1 @@ +Basic fix of GenericModel cache to detect order of arguments in Union models diff --git a/pydantic/generics.py b/pydantic/generics.py index c43ac3e3e6..a3f52bfee9 100644 --- a/pydantic/generics.py +++ b/pydantic/generics.py @@ -62,7 +62,11 @@ def __class_getitem__(cls: Type[GenericModelT], params: Union[Type[Any], Tuple[T returned as is. """ - cached = _generic_types_cache.get((cls, params)) + + def _cache_key(_params: Any) -> Tuple[Type[GenericModelT], Any, Tuple[Any, ...]]: + return cls, _params, get_args(_params) + + cached = _generic_types_cache.get(_cache_key(params)) if cached is not None: return cached if cls.__concrete__ and Generic not in cls.__bases__: @@ -128,9 +132,9 @@ def __class_getitem__(cls: Type[GenericModelT], params: Union[Type[Any], Tuple[T # Save created model in cache so we don't end up creating duplicate # models that should be identical. - _generic_types_cache[(cls, params)] = created_model + _generic_types_cache[_cache_key(params)] = created_model if len(params) == 1: - _generic_types_cache[(cls, params[0])] = created_model + _generic_types_cache[_cache_key(params[0])] = created_model # Recursively walk class type hints and replace generic typevars # with concrete types that were passed. diff --git a/tests/test_generics.py b/tests/test_generics.py index d65c0196a8..39adc45f20 100644 --- a/tests/test_generics.py +++ b/tests/test_generics.py @@ -688,6 +688,52 @@ class Model(BaseModel): # same name, but type different, so it's not in cache assert globals()['MyGeneric[Model]__'] is third_concrete +def test_generic_model_caching_detect_order_of_union_args_basic(create_module): + # Basic variant of https://github.com/pydantic/pydantic/issues/4474 + @create_module + def module(): + from typing import Generic, TypeVar, Union + + from pydantic.generics import GenericModel + + t = TypeVar('t') + + class Model(GenericModel, Generic[t]): + data: t + + int_or_float_model = Model[Union[int, float]] + float_or_int_model = Model[Union[float, int]] + + assert type(int_or_float_model(data='1').data) is int + assert type(float_or_int_model(data='1').data) is float + + +@pytest.mark.skip( + reason=""" +Depends on similar issue in CPython itself: https://github.com/python/cpython/issues/86483 +Documented and skipped for possible fix later. +""" +) +def test_generic_model_caching_detect_order_of_union_args_nested(create_module): + # Nested variant of https://github.com/pydantic/pydantic/issues/4474 + @create_module + def module(): + from typing import Generic, List, TypeVar, Union + + from pydantic.generics import GenericModel + + t = TypeVar('t') + + class Model(GenericModel, Generic[t]): + data: t + + int_or_float_model = Model[List[Union[int, float]]] + float_or_int_model = Model[List[Union[float, int]]] + + assert type(int_or_float_model(data=['1']).data[0]) is int + assert type(float_or_int_model(data=['1']).data[0]) is float + + def test_get_caller_frame_info(create_module): @create_module def module():