Skip to content

Commit

Permalink
Fix TypeError for GenericModel with Callable param (#4653)
Browse files Browse the repository at this point in the history
* Fix TypeError for GenericModel with Callable param

Introduced in 1.10.2, a TypeError would be raised upon creation of a
GenericModel class that used a Callable type parameter. This would
happen when `typing.get_args` returned a list for the first type
agruments in a Callable and pydantic would try to use the value as a
dictionary key. To avoid the issue, we convert the list to a tuple
before using it as a key.

The possible approach of modifying pydantic's `get_args` function to
return a tuple instead of a list didn't work out because the return
values are used in more places, some of which expect the list and not a
tuple.

Fixes #4551

* change as markdown

Co-authored-by: Samuel Colvin <samcolvin@gmail.com>
  • Loading branch information
mfulgo and samuelcolvin committed Oct 31, 2022
1 parent e43e455 commit 83cf464
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 1 deletion.
1 change: 1 addition & 0 deletions changes/4551-mfulgo.md
@@ -0,0 +1 @@
Fix `GenericModel` with `Callable` param raising a `TypeError`
6 changes: 5 additions & 1 deletion pydantic/generics.py
Expand Up @@ -64,7 +64,11 @@ def __class_getitem__(cls: Type[GenericModelT], params: Union[Type[Any], Tuple[T
"""

def _cache_key(_params: Any) -> Tuple[Type[GenericModelT], Any, Tuple[Any, ...]]:
return cls, _params, get_args(_params)
args = get_args(_params)
# python returns a list for Callables, which is not hashable
if len(args) == 2 and isinstance(args[0], list):
args = (tuple(args[0]), args[1])
return cls, _params, args

cached = _generic_types_cache.get(_cache_key(params))
if cached is not None:
Expand Down
27 changes: 27 additions & 0 deletions tests/test_generics.py
Expand Up @@ -7,6 +7,7 @@
ClassVar,
Dict,
Generic,
Iterable,
List,
Mapping,
Optional,
Expand Down Expand Up @@ -234,6 +235,32 @@ class Model(GenericModel, Generic[T]):
assert len(_generic_types_cache) == cache_size + 2


def test_cache_keys_are_hashable():
cache_size = len(_generic_types_cache)
T = TypeVar('T')
C = Callable[[str, Dict[str, Any]], Iterable[str]]

class MyGenericModel(GenericModel, Generic[T]):
t: T

# Callable's first params get converted to a list, which is not hashable.
# Make sure we can handle that special case
Simple = MyGenericModel[Callable[[int], str]]
assert len(_generic_types_cache) == cache_size + 2
# Nested Callables
MyGenericModel[Callable[[C], Iterable[str]]]
assert len(_generic_types_cache) == cache_size + 4
MyGenericModel[Callable[[Simple], Iterable[int]]]
assert len(_generic_types_cache) == cache_size + 6
MyGenericModel[Callable[[MyGenericModel[C]], Iterable[int]]]
assert len(_generic_types_cache) == cache_size + 10

class Model(BaseModel):
x: MyGenericModel[Callable[[C], Iterable[str]]] = Field(...)

assert len(_generic_types_cache) == cache_size + 10


def test_generic_config():
data_type = TypeVar('data_type')

Expand Down

0 comments on commit 83cf464

Please sign in to comment.