Skip to content

Commit

Permalink
Improve generic subclass support (#2549)
Browse files Browse the repository at this point in the history
* Derive concrete subclasses for parameterised generics

* Resolve type issues

* Add negative assertions to generic subclass tests

* Remove incorrect subclassing of partial.

The type was incorrectly being picked up for this style of subclassing,
and it can be regardless inferred through cls.

* Apply feedback:

* Improve parameterisation explanation
* fix typos
* Alias Parameterisation type

* Apply suggestions from code review

* start docstring with newline.
* Use None as default over empty tuple.

Co-authored-by: Samuel Colvin <samcolvin@gmail.com>

* Combine _assigned_parameters cases in __paramaterized_bases__ of generics

* Add description for the `_assigned_parameters` variable.

Co-authored-by: Samuel Colvin <samcolvin@gmail.com>
Co-authored-by: Samuel Colvin <s@muelcolvin.com>
  • Loading branch information
3 people committed Dec 5, 2021
1 parent a35cde9 commit e71f53d
Show file tree
Hide file tree
Showing 4 changed files with 169 additions and 5 deletions.
1 change: 1 addition & 0 deletions changes/2007-diabolo-dan.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add parameterised subclasses to `__bases__` when constructing new parameterised classes, so that `A <: B => A[int] <: B[int]`.
77 changes: 76 additions & 1 deletion pydantic/generics.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,15 @@
GenericModelT = TypeVar('GenericModelT', bound='GenericModel')
TypeVarType = Any # since mypy doesn't allow the use of TypeVar as a type

Parametrization = Mapping[TypeVarType, Type[Any]]

# _assigned_parameters is a Mapping from parametrized version of generic models to assigned types of parametrizations
# as captured during construction of the class (not instances).
# E.g., for generic model `Model[A, B]`, when parametrized model `Model[int, str]` is created,
# `Model[int, str]`: {A: int, B: str}` will be stored in `_assigned_parameters`.
# (This information is only otherwise available after creation from the class name string).
_assigned_parameters: Dict[Type[Any], Parametrization] = {}


class GenericModel(BaseModel):
__slots__ = ()
Expand Down Expand Up @@ -86,13 +95,15 @@ def __class_getitem__(cls: Type[GenericModelT], params: Union[Type[Any], Tuple[T
create_model(
model_name,
__module__=model_module or cls.__module__,
__base__=cls,
__base__=(cls,) + tuple(cls.__parameterized_bases__(typevars_map)),
__config__=None,
__validators__=validators,
**fields,
),
)

_assigned_parameters[created_model] = typevars_map

if called_globally: # create global reference and therefore allow pickling
object_by_reference = None
reference_name = model_name
Expand Down Expand Up @@ -142,6 +153,70 @@ def __concrete_name__(cls: Type[Any], params: Tuple[Type[Any], ...]) -> str:
params_component = ', '.join(param_names)
return f'{cls.__name__}[{params_component}]'

@classmethod
def __parameterized_bases__(cls, typevars_map: Parametrization) -> Iterator[Type[Any]]:
"""
Returns unbound bases of cls parameterised to given type variables
:param typevars_map: Dictionary of type applications for binding subclasses.
Given a generic class `Model` with 2 type variables [S, T]
and a concrete model `Model[str, int]`,
the value `{S: str, T: int}` would be passed to `typevars_map`.
:return: an iterator of generic sub classes, parameterised by `typevars_map`
and other assigned parameters of `cls`
e.g.:
```
class A(GenericModel, Generic[T]):
...
class B(A[V], Generic[V]):
...
assert A[int] in B.__parameterized_bases__({V: int})
```
"""

def build_base_model(
base_model: Type[GenericModel], mapped_types: Parametrization
) -> Iterator[Type[GenericModel]]:
base_parameters = tuple([mapped_types[param] for param in base_model.__parameters__])
parameterized_base = base_model.__class_getitem__(base_parameters)
if parameterized_base is base_model or parameterized_base is cls:
# Avoid duplication in MRO
return
yield parameterized_base

for base_model in cls.__bases__:
if not issubclass(base_model, GenericModel):
# not a class that can be meaningfully parameterized
continue
elif not getattr(base_model, '__parameters__', None):
# base_model is "GenericModel" (and has no __parameters__)
# or
# base_model is already concrete, and will be included transitively via cls.
continue
elif cls in _assigned_parameters:
if base_model in _assigned_parameters:
# cls is partially parameterised but not from base_model
# e.g. cls = B[S], base_model = A[S]
# B[S][int] should subclass A[int], (and will be transitively via B[int])
# but it's not viable to consistently subclass types with arbitrary construction
# So don't attempt to include A[S][int]
continue
else: # base_model not in _assigned_parameters:
# cls is partially parameterized, base_model is original generic
# e.g. cls = B[str, T], base_model = B[S, T]
# Need to determine the mapping for the base_model parameters
mapped_types: Parametrization = {
key: typevars_map.get(value, value) for key, value in _assigned_parameters[cls].items()
}
yield from build_base_model(base_model, mapped_types)
else:
# cls is base generic, so base_class has a distinct base
# can construct the Parameterised base model using typevars_map directly
yield from build_base_model(base_model, typevars_map)


def replace_types(type_: Any, type_map: Mapping[Any, Any]) -> Any:
"""Return type with all occurrences of `type_map` keys recursively replaced with their values.
Expand Down
10 changes: 6 additions & 4 deletions pydantic/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -877,7 +877,7 @@ def create_model(
__model_name: str,
*,
__config__: Optional[Type[BaseConfig]] = None,
__base__: Type['Model'],
__base__: Union[Type['Model'], Tuple[Type['Model'], ...]],
__module__: str = __name__,
__validators__: Dict[str, classmethod] = None,
**field_definitions: Any,
Expand All @@ -889,7 +889,7 @@ def create_model(
__model_name: str,
*,
__config__: Optional[Type[BaseConfig]] = None,
__base__: Optional[Type['Model']] = None,
__base__: Union[None, Type['Model'], Tuple[Type['Model'], ...]] = None,
__module__: str = __name__,
__validators__: Dict[str, classmethod] = None,
**field_definitions: Any,
Expand All @@ -910,8 +910,10 @@ def create_model(
if __base__ is not None:
if __config__ is not None:
raise ConfigError('to avoid confusion __config__ and __base__ cannot be used together')
if not isinstance(__base__, tuple):
__base__ = (__base__,)
else:
__base__ = cast(Type['Model'], BaseModel)
__base__ = (cast(Type['Model'], BaseModel),)

fields = {}
annotations = {}
Expand Down Expand Up @@ -942,7 +944,7 @@ def create_model(
if __config__:
namespace['Config'] = inherit_config(__config__, BaseConfig)

return type(__model_name, (__base__,), namespace)
return type(__model_name, __base__, namespace)


_missing = object()
Expand Down
86 changes: 86 additions & 0 deletions tests/test_generics.py
Original file line number Diff line number Diff line change
Expand Up @@ -1160,6 +1160,92 @@ class SomeGenericModel(GenericModel, Generic[T]):
SomeGenericModel[str](the_alias='qwe')


@skip_36
def test_generic_subclass():
T = TypeVar('T')

class A(GenericModel, Generic[T]):
...

class B(A[T], Generic[T]):
...

assert B[int].__name__ == 'B[int]'
assert issubclass(B[int], B)
assert issubclass(B[int], A[int])
assert not issubclass(B[int], A[str])


@skip_36
def test_generic_subclass_with_partial_application():
T = TypeVar('T')
S = TypeVar('S')

class A(GenericModel, Generic[T]):
...

class B(A[S], Generic[T, S]):
...

PartiallyAppliedB = B[str, T]
assert issubclass(PartiallyAppliedB[int], A[int])
assert not issubclass(PartiallyAppliedB[int], A[str])
assert not issubclass(PartiallyAppliedB[str], A[int])


@skip_36
def test_multilevel_generic_binding():
T = TypeVar('T')
S = TypeVar('S')

class A(GenericModel, Generic[T, S]):
...

class B(A[str, T], Generic[T]):
...

assert B[int].__name__ == 'B[int]'
assert issubclass(B[int], A[str, int])
assert not issubclass(B[str], A[str, int])


@skip_36
def test_generic_subclass_with_extra_type():
T = TypeVar('T')
S = TypeVar('S')

class A(GenericModel, Generic[T]):
...

class B(A[S], Generic[T, S]):
...

assert B[int, str].__name__ == 'B[int, str]', B[int, str].__name__
assert issubclass(B[str, int], B)
assert issubclass(B[str, int], A[int])
assert not issubclass(B[int, str], A[int])


@skip_36
def test_multi_inheritance_generic_binding():
T = TypeVar('T')

class A(GenericModel, Generic[T]):
...

class B(A[int], Generic[T]):
...

class C(B[str], Generic[T]):
...

assert C[float].__name__ == 'C[float]'
assert issubclass(C[float], B[str])
assert not issubclass(C[float], B[int])
assert issubclass(C[float], A[int])
assert not issubclass(C[float], A[str])


@skip_36
def test_parse_generic_json():
T = TypeVar('T')
Expand Down

0 comments on commit e71f53d

Please sign in to comment.