Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve generic subclass support #2549

Merged
1 change: 1 addition & 0 deletions changes/2007-diabolo-dan.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add parameterised subclasses to `__bases__` when constructing new parameterised classes, so that `A <: B => A[int] <: B[int]`.
77 changes: 76 additions & 1 deletion pydantic/generics.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,15 @@
GenericModelT = TypeVar('GenericModelT', bound='GenericModel')
TypeVarType = Any # since mypy doesn't allow the use of TypeVar as a type

Parametrization = Mapping[TypeVarType, Type[Any]]

# _assigned_parameters is a Mapping from parametrized version of generic models to assigned types of parametrizations
# as captured during construction of the class (not instances).
# E.g., for generic model `Model[A, B]`, when parametrized model `Model[int, str]` is created,
# `Model[int, str]`: {A: int, B: str}` will be stored in `_assigned_parameters`.
# (This information is only otherwise available after creation from the class name string).
_assigned_parameters: Dict[Type[Any], Parametrization] = {}
samuelcolvin marked this conversation as resolved.
Show resolved Hide resolved


class GenericModel(BaseModel):
__slots__ = ()
Expand Down Expand Up @@ -86,13 +95,15 @@ def __class_getitem__(cls: Type[GenericModelT], params: Union[Type[Any], Tuple[T
create_model(
model_name,
__module__=model_module or cls.__module__,
__base__=cls,
__base__=(cls,) + tuple(cls.__parameterized_bases__(typevars_map)),
__config__=None,
__validators__=validators,
**fields,
),
)

_assigned_parameters[created_model] = typevars_map

if called_globally: # create global reference and therefore allow pickling
object_by_reference = None
reference_name = model_name
Expand Down Expand Up @@ -142,6 +153,70 @@ def __concrete_name__(cls: Type[Any], params: Tuple[Type[Any], ...]) -> str:
params_component = ', '.join(param_names)
return f'{cls.__name__}[{params_component}]'

@classmethod
def __parameterized_bases__(cls, typevars_map: Parametrization) -> Iterator[Type[Any]]:
"""
Returns unbound bases of cls parameterised to given type variables

:param typevars_map: Dictionary of type applications for binding subclasses.
Given a generic class `Model` with 2 type variables [S, T]
and a concrete model `Model[str, int]`,
the value `{S: str, T: int}` would be passed to `typevars_map`.
:return: an iterator of generic sub classes, parameterised by `typevars_map`
and other assigned parameters of `cls`

e.g.:
```
class A(GenericModel, Generic[T]):
...

class B(A[V], Generic[V]):
...

assert A[int] in B.__parameterized_bases__({V: int})
```
diabolo-dan marked this conversation as resolved.
Show resolved Hide resolved
"""

def build_base_model(
base_model: Type[GenericModel], mapped_types: Parametrization
) -> Iterator[Type[GenericModel]]:
base_parameters = tuple([mapped_types[param] for param in base_model.__parameters__])
parameterized_base = base_model.__class_getitem__(base_parameters)
if parameterized_base is base_model or parameterized_base is cls:
# Avoid duplication in MRO
return
yield parameterized_base

for base_model in cls.__bases__:
if not issubclass(base_model, GenericModel):
# not a class that can be meaningfully parameterized
continue
elif not getattr(base_model, '__parameters__', None):
# base_model is "GenericModel" (and has no __parameters__)
# or
# base_model is already concrete, and will be included transitively via cls.
continue
elif cls in _assigned_parameters:
if base_model in _assigned_parameters:
# cls is partially parameterised but not from base_model
# e.g. cls = B[S], base_model = A[S]
# B[S][int] should subclass A[int], (and will be transitively via B[int])
# but it's not viable to consistently subclass types with arbitrary construction
# So don't attempt to include A[S][int]
continue
else: # base_model not in _assigned_parameters:
# cls is partially parameterized, base_model is original generic
# e.g. cls = B[str, T], base_model = B[S, T]
# Need to determine the mapping for the base_model parameters
mapped_types: Parametrization = {
key: typevars_map.get(value, value) for key, value in _assigned_parameters[cls].items()
}
yield from build_base_model(base_model, mapped_types)
else:
# cls is base generic, so base_class has a distinct base
# can construct the Parameterised base model using typevars_map directly
yield from build_base_model(base_model, typevars_map)


def replace_types(type_: Any, type_map: Mapping[Any, Any]) -> Any:
"""Return type with all occurrences of `type_map` keys recursively replaced with their values.
Expand Down
10 changes: 6 additions & 4 deletions pydantic/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -877,7 +877,7 @@ def create_model(
__model_name: str,
*,
__config__: Optional[Type[BaseConfig]] = None,
__base__: Type['Model'],
__base__: Union[Type['Model'], Tuple[Type['Model'], ...]],
__module__: str = __name__,
__validators__: Dict[str, classmethod] = None,
**field_definitions: Any,
Expand All @@ -889,7 +889,7 @@ def create_model(
__model_name: str,
*,
__config__: Optional[Type[BaseConfig]] = None,
__base__: Optional[Type['Model']] = None,
__base__: Union[None, Type['Model'], Tuple[Type['Model'], ...]] = None,
__module__: str = __name__,
__validators__: Dict[str, classmethod] = None,
**field_definitions: Any,
Expand All @@ -910,8 +910,10 @@ def create_model(
if __base__ is not None:
if __config__ is not None:
raise ConfigError('to avoid confusion __config__ and __base__ cannot be used together')
if not isinstance(__base__, tuple):
__base__ = (__base__,)
else:
__base__ = cast(Type['Model'], BaseModel)
__base__ = (cast(Type['Model'], BaseModel),)

fields = {}
annotations = {}
Expand Down Expand Up @@ -942,7 +944,7 @@ def create_model(
if __config__:
namespace['Config'] = inherit_config(__config__, BaseConfig)

return type(__model_name, (__base__,), namespace)
return type(__model_name, __base__, namespace)


_missing = object()
Expand Down
86 changes: 86 additions & 0 deletions tests/test_generics.py
Original file line number Diff line number Diff line change
Expand Up @@ -1160,6 +1160,92 @@ class SomeGenericModel(GenericModel, Generic[T]):
SomeGenericModel[str](the_alias='qwe')


@skip_36
def test_generic_subclass():
T = TypeVar('T')

class A(GenericModel, Generic[T]):
...

class B(A[T], Generic[T]):
...

assert B[int].__name__ == 'B[int]'
assert issubclass(B[int], B)
assert issubclass(B[int], A[int])
diabolo-dan marked this conversation as resolved.
Show resolved Hide resolved
assert not issubclass(B[int], A[str])


@skip_36
def test_generic_subclass_with_partial_application():
T = TypeVar('T')
S = TypeVar('S')

class A(GenericModel, Generic[T]):
...

class B(A[S], Generic[T, S]):
...

PartiallyAppliedB = B[str, T]
assert issubclass(PartiallyAppliedB[int], A[int])
assert not issubclass(PartiallyAppliedB[int], A[str])
assert not issubclass(PartiallyAppliedB[str], A[int])


@skip_36
def test_multilevel_generic_binding():
T = TypeVar('T')
S = TypeVar('S')

class A(GenericModel, Generic[T, S]):
...

class B(A[str, T], Generic[T]):
...

assert B[int].__name__ == 'B[int]'
assert issubclass(B[int], A[str, int])
assert not issubclass(B[str], A[str, int])


@skip_36
def test_generic_subclass_with_extra_type():
T = TypeVar('T')
S = TypeVar('S')

class A(GenericModel, Generic[T]):
...

class B(A[S], Generic[T, S]):
...

assert B[int, str].__name__ == 'B[int, str]', B[int, str].__name__
assert issubclass(B[str, int], B)
assert issubclass(B[str, int], A[int])
assert not issubclass(B[int, str], A[int])


@skip_36
def test_multi_inheritance_generic_binding():
T = TypeVar('T')

class A(GenericModel, Generic[T]):
...

class B(A[int], Generic[T]):
...

class C(B[str], Generic[T]):
...

assert C[float].__name__ == 'C[float]'
assert issubclass(C[float], B[str])
assert not issubclass(C[float], B[int])
assert issubclass(C[float], A[int])
assert not issubclass(C[float], A[str])


@skip_36
def test_parse_generic_json():
T = TypeVar('T')
Expand Down