Skip to content

Commit

Permalink
Add fix for pydantic#4551 - ensure cache keys are hashable
Browse files Browse the repository at this point in the history
  • Loading branch information
David Montague committed Jan 22, 2023
1 parent 183582a commit ad97c39
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 2 deletions.
6 changes: 5 additions & 1 deletion pydantic/main.py
Expand Up @@ -602,7 +602,11 @@ def __repr_args__(self) -> _repr.ReprArgs:

def __class_getitem__(cls, typevar_values: type[Any] | tuple[type[Any], ...]) -> type[Any]:
def _cache_key(_params: Any) -> tuple[type[Any], Any, tuple[Any, ...]]:
return cls, _params, typing_extensions.get_args(_params)
args = typing.get_args(_params)
# python returns a list for Callables, which is not hashable
if len(args) == 2 and isinstance(args[0], list):
args = (tuple(args[0]), args[1])
return cls, _params, args

cached = _generic_types_cache.get(_cache_key(typevar_values))
if cached is not None:
Expand Down
30 changes: 29 additions & 1 deletion tests/test_generics.py
Expand Up @@ -7,6 +7,7 @@
ClassVar,
Dict,
Generic,
Iterable,
List,
Mapping,
Optional,
Expand All @@ -21,7 +22,8 @@
from typing_extensions import Annotated, Literal

from pydantic import BaseModel, Field, Json, ValidationError, root_validator, validator
from pydantic.generics import GenericModel, _generic_types_cache, iter_contained_typevars, replace_types
from pydantic.generics import GenericModel, iter_contained_typevars, replace_types
from pydantic.main import _generic_types_cache


def test_generic_name():
Expand Down Expand Up @@ -241,6 +243,32 @@ class Model(GenericModel, Generic[T]):
assert len(_generic_types_cache) == cache_size + 2


def test_cache_keys_are_hashable():
cache_size = len(_generic_types_cache)
T = TypeVar('T')
C = Callable[[str, Dict[str, Any]], Iterable[str]]

class MyGenericModel(BaseModel, Generic[T]):
t: T

# Callable's first params get converted to a list, which is not hashable.
# Make sure we can handle that special case
Simple = MyGenericModel[Callable[[int], str]]
assert len(_generic_types_cache) == cache_size + 2
# Nested Callables
MyGenericModel[Callable[[C], Iterable[str]]]
assert len(_generic_types_cache) == cache_size + 4
MyGenericModel[Callable[[Simple], Iterable[int]]]
assert len(_generic_types_cache) == cache_size + 6
MyGenericModel[Callable[[MyGenericModel[C]], Iterable[int]]]
assert len(_generic_types_cache) == cache_size + 10

class Model(BaseModel):
x: MyGenericModel[Callable[[C], Iterable[str]]] = Field(...)

assert len(_generic_types_cache) == cache_size + 10


@pytest.mark.xfail(reason='working on V2')
def test_generic_config():
data_type = TypeVar('data_type')
Expand Down

0 comments on commit ad97c39

Please sign in to comment.