Skip to content

Commit

Permalink
Support for variadic type aliases (#15219)
Browse files Browse the repository at this point in the history
Fixes #15062 

Implementing "happy path" took like couple dozen lines, but there are *a
lot* of edge cases, e.g. where we need to fail gracefully. Also of
course I checked my implementation (mostly) works for recursive variadic
aliases :-) see test.

It looks like several pieces of support for proper variadic types (i.e.
non-aliases, instances etc) are still missing, so I tried to fill in
something where I needed it for type aliases, but not everywhere, some
notable examples:
* Type variable bound checks for instances are still broken, see TODO
item in `semanal_typeargs.py`
* I think type argument count check is still broken for instances (I
think I fixed it for type aliases), there can be fewer than
`len(type_vars) - 1` type arguments, e.g. if one of them is an unpack.
* We should only prohibit multiple *variadic* unpacks in a type list,
multiple fixed length unpacks are fine (I think I fixed this both for
aliases and instances)

Btw I was thinking about an example below, what should we do in such
cases?
```python
from typing import Tuple, TypeVar
from typing_extensions import TypeVarTuple, Unpack

T = TypeVar("T")
S = TypeVar("S")
Ts = TypeVarTuple("Ts")

Alias = Tuple[T, S, Unpack[Ts], S]

def foo(*x: Unpack[Ts]) -> None:
    y: Alias[Unpack[Ts], int, str]
    reveal_type(y)  # <-- what is this type?

# Ts = () => Tuple[int, str, str]
# Ts = (bool) => Tuple[bool, int, str, int]
# Ts = (bool, float) => Tuple[bool, float, int, str, float]
```

Finally, I noticed there is already some code duplication, and I am not
improving it. I am open to suggestions on how to reduce the code
duplication.

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
ilevkivskyi and pre-commit-ci[bot] committed May 21, 2023
1 parent 391ed85 commit 0334ebc
Show file tree
Hide file tree
Showing 12 changed files with 507 additions and 61 deletions.
48 changes: 44 additions & 4 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,11 +151,13 @@
UninhabitedType,
UnionType,
UnpackType,
flatten_nested_tuples,
flatten_nested_unions,
get_proper_type,
get_proper_types,
has_recursive_types,
is_named_instance,
split_with_prefix_and_suffix,
)
from mypy.types_utils import is_generic_instance, is_optional, is_self_type_like, remove_optional
from mypy.typestate import type_state
Expand Down Expand Up @@ -4070,6 +4072,35 @@ class LongName(Generic[T]): ...
# The _SpecialForm type can be used in some runtime contexts (e.g. it may have __or__).
return self.named_type("typing._SpecialForm")

def split_for_callable(
self, t: CallableType, args: Sequence[Type], ctx: Context
) -> list[Type]:
"""Handle directly applying type arguments to a variadic Callable.
This is needed in situations where e.g. variadic class object appears in
runtime context. For example:
class C(Generic[T, Unpack[Ts]]): ...
x = C[int, str]()
We simply group the arguments that need to go into Ts variable into a TupleType,
similar to how it is done in other places using split_with_prefix_and_suffix().
"""
vars = t.variables
if not vars or not any(isinstance(v, TypeVarTupleType) for v in vars):
return list(args)

prefix = next(i for (i, v) in enumerate(vars) if isinstance(v, TypeVarTupleType))
suffix = len(vars) - prefix - 1
args = flatten_nested_tuples(args)
if len(args) < len(vars) - 1:
self.msg.incompatible_type_application(len(vars), len(args), ctx)
return [AnyType(TypeOfAny.from_error)] * len(vars)

tvt = vars[prefix]
assert isinstance(tvt, TypeVarTupleType)
start, middle, end = split_with_prefix_and_suffix(tuple(args), prefix, suffix)
return list(start) + [TupleType(list(middle), tvt.tuple_fallback)] + list(end)

def apply_type_arguments_to_callable(
self, tp: Type, args: Sequence[Type], ctx: Context
) -> Type:
Expand All @@ -4083,19 +4114,28 @@ def apply_type_arguments_to_callable(
tp = get_proper_type(tp)

if isinstance(tp, CallableType):
if len(tp.variables) != len(args):
if len(tp.variables) != len(args) and not any(
isinstance(v, TypeVarTupleType) for v in tp.variables
):
if tp.is_type_obj() and tp.type_object().fullname == "builtins.tuple":
# TODO: Specialize the callable for the type arguments
return tp
self.msg.incompatible_type_application(len(tp.variables), len(args), ctx)
return AnyType(TypeOfAny.from_error)
return self.apply_generic_arguments(tp, args, ctx)
return self.apply_generic_arguments(tp, self.split_for_callable(tp, args, ctx), ctx)
if isinstance(tp, Overloaded):
for it in tp.items:
if len(it.variables) != len(args):
if len(it.variables) != len(args) and not any(
isinstance(v, TypeVarTupleType) for v in it.variables
):
self.msg.incompatible_type_application(len(it.variables), len(args), ctx)
return AnyType(TypeOfAny.from_error)
return Overloaded([self.apply_generic_arguments(it, args, ctx) for it in tp.items])
return Overloaded(
[
self.apply_generic_arguments(it, self.split_for_callable(it, args, ctx), ctx)
for it in tp.items
]
)
return AnyType(TypeOfAny.special_form)

def visit_list_expr(self, e: ListExpr) -> Type:
Expand Down
14 changes: 5 additions & 9 deletions mypy/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from __future__ import annotations

from typing import TYPE_CHECKING, Iterable, List, Sequence
from typing import TYPE_CHECKING, Iterable, List, Sequence, cast
from typing_extensions import Final

import mypy.subtypes
Expand Down Expand Up @@ -46,15 +46,11 @@
has_recursive_types,
has_type_vars,
is_named_instance,
split_with_prefix_and_suffix,
)
from mypy.types_utils import is_union_with_any
from mypy.typestate import type_state
from mypy.typevartuples import (
extract_unpack,
find_unpack_in_list,
split_with_mapped_and_template,
split_with_prefix_and_suffix,
)
from mypy.typevartuples import extract_unpack, find_unpack_in_list, split_with_mapped_and_template

if TYPE_CHECKING:
from mypy.infer import ArgumentInferContext
Expand Down Expand Up @@ -669,7 +665,7 @@ def visit_instance(self, template: Instance) -> list[Constraint]:
instance.type.type_var_tuple_prefix,
instance.type.type_var_tuple_suffix,
)
tvars = list(tvars_prefix + tvars_suffix)
tvars = cast("list[TypeVarLikeType]", list(tvars_prefix + tvars_suffix))
else:
mapped_args = mapped.args
instance_args = instance.args
Expand Down Expand Up @@ -738,7 +734,7 @@ def visit_instance(self, template: Instance) -> list[Constraint]:
template.type.type_var_tuple_prefix,
template.type.type_var_tuple_suffix,
)
tvars = list(tvars_prefix + tvars_suffix)
tvars = cast("list[TypeVarLikeType]", list(tvars_prefix + tvars_suffix))
else:
mapped_args = mapped.args
template_args = template.args
Expand Down
40 changes: 26 additions & 14 deletions mypy/expandtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,12 @@
UninhabitedType,
UnionType,
UnpackType,
flatten_nested_tuples,
flatten_nested_unions,
get_proper_type,
)
from mypy.typevartuples import (
find_unpack_in_list,
split_with_instance,
split_with_prefix_and_suffix,
)
from mypy.typevartuples import find_unpack_in_list, split_with_instance

# WARNING: these functions should never (directly or indirectly) depend on
# is_subtype(), meet_types(), join_types() etc.
Expand Down Expand Up @@ -115,6 +113,7 @@ def expand_type_by_instance(typ: Type, instance: Instance) -> Type:
instance_args = instance.args

for binder, arg in zip(tvars, instance_args):
assert isinstance(binder, TypeVarLikeType)
variables[binder.id] = arg

return expand_type(typ, variables)
Expand Down Expand Up @@ -282,12 +281,14 @@ def visit_type_var_tuple(self, t: TypeVarTupleType) -> Type:
raise NotImplementedError

def visit_unpack_type(self, t: UnpackType) -> Type:
# It is impossible to reasonally implement visit_unpack_type, because
# It is impossible to reasonably implement visit_unpack_type, because
# unpacking inherently expands to something more like a list of types.
#
# Relevant sections that can call unpack should call expand_unpack()
# instead.
assert False, "Mypy bug: unpacking must happen at a higher level"
# However, if the item is a variadic tuple, we can simply carry it over.
# it is hard to assert this without getting proper type.
return UnpackType(t.type.accept(self))

def expand_unpack(self, t: UnpackType) -> list[Type] | Instance | AnyType | None:
return expand_unpack_with_variables(t, self.variables)
Expand Down Expand Up @@ -356,7 +357,15 @@ def interpolate_args_for_unpack(

# Extract the typevartuple so we can get a tuple fallback from it.
expanded_unpacked_tvt = expanded_unpack.type
assert isinstance(expanded_unpacked_tvt, TypeVarTupleType)
if isinstance(expanded_unpacked_tvt, TypeVarTupleType):
fallback = expanded_unpacked_tvt.tuple_fallback
else:
# This can happen when tuple[Any, ...] is used to "patch" a variadic
# generic type without type arguments provided.
assert isinstance(expanded_unpacked_tvt, ProperType)
assert isinstance(expanded_unpacked_tvt, Instance)
assert expanded_unpacked_tvt.type.fullname == "builtins.tuple"
fallback = expanded_unpacked_tvt

prefix_len = expanded_unpack_index
arg_names = t.arg_names[:star_index] + [None] * prefix_len + t.arg_names[star_index:]
Expand All @@ -368,11 +377,7 @@ def interpolate_args_for_unpack(
+ expanded_items[:prefix_len]
# Constructing the Unpack containing the tuple without the prefix.
+ [
UnpackType(
TupleType(
expanded_items[prefix_len:], expanded_unpacked_tvt.tuple_fallback
)
)
UnpackType(TupleType(expanded_items[prefix_len:], fallback))
if len(expanded_items) - prefix_len > 1
else expanded_items[0]
]
Expand Down Expand Up @@ -456,9 +461,12 @@ def expand_types_with_unpack(
indicates use of Any or some error occurred earlier. In this case callers should
simply propagate the resulting type.
"""
# TODO: this will cause a crash on aliases like A = Tuple[int, Unpack[A]].
# Although it is unlikely anyone will write this, we should fail gracefully.
typs = flatten_nested_tuples(typs)
items: list[Type] = []
for item in typs:
if isinstance(item, UnpackType):
if isinstance(item, UnpackType) and isinstance(item.type, TypeVarTupleType):
unpacked_items = self.expand_unpack(item)
if unpacked_items is None:
# TODO: better error, something like tuple of unknown?
Expand Down Expand Up @@ -523,7 +531,11 @@ def visit_type_type(self, t: TypeType) -> Type:
def visit_type_alias_type(self, t: TypeAliasType) -> Type:
# Target of the type alias cannot contain type variables (not bound by the type
# alias itself), so we just expand the arguments.
return t.copy_modified(args=self.expand_types(t.args))
args = self.expand_types_with_unpack(t.args)
if isinstance(args, list):
return t.copy_modified(args=args)
else:
return args

def expand_types(self, types: Iterable[Type]) -> list[Type]:
a: list[Type] = []
Expand Down
5 changes: 5 additions & 0 deletions mypy/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -3471,6 +3471,7 @@ def f(x: B[T]) -> T: ... # without T, Any would be used here
"normalized",
"_is_recursive",
"eager",
"tvar_tuple_index",
)

__match_args__ = ("name", "target", "alias_tvars", "no_args")
Expand Down Expand Up @@ -3498,6 +3499,10 @@ def __init__(
# it is the cached value.
self._is_recursive: bool | None = None
self.eager = eager
self.tvar_tuple_index = None
for i, t in enumerate(alias_tvars):
if isinstance(t, mypy.types.TypeVarTupleType):
self.tvar_tuple_index = i
super().__init__(line, column)

@classmethod
Expand Down
13 changes: 12 additions & 1 deletion mypy/semanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,7 @@
TypeOfAny,
TypeType,
TypeVarLikeType,
TypeVarTupleType,
TypeVarType,
UnboundType,
UnpackType,
Expand Down Expand Up @@ -3424,8 +3425,18 @@ def analyze_alias(
allowed_alias_tvars=tvar_defs,
)

# There can be only one variadic variable at most, the error is reported elsewhere.
new_tvar_defs = []
variadic = False
for td in tvar_defs:
if isinstance(td, TypeVarTupleType):
if variadic:
continue
variadic = True
new_tvar_defs.append(td)

qualified_tvars = [node.fullname for _name, node in found_type_vars]
return analyzed, tvar_defs, depends_on, qualified_tvars
return analyzed, new_tvar_defs, depends_on, qualified_tvars

def is_pep_613(self, s: AssignmentStmt) -> bool:
if s.unanalyzed_type is not None and isinstance(s.unanalyzed_type, UnboundType):
Expand Down
50 changes: 46 additions & 4 deletions mypy/semanal_typeargs.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from mypy.options import Options
from mypy.scope import Scope
from mypy.subtypes import is_same_type, is_subtype
from mypy.typeanal import set_any_tvars
from mypy.types import (
AnyType,
Instance,
Expand All @@ -32,8 +33,10 @@
TypeVarType,
UnboundType,
UnpackType,
flatten_nested_tuples,
get_proper_type,
get_proper_types,
split_with_prefix_and_suffix,
)


Expand Down Expand Up @@ -79,10 +82,34 @@ def visit_type_alias_type(self, t: TypeAliasType) -> None:
self.seen_aliases.add(t)
# Some recursive aliases may produce spurious args. In principle this is not very
# important, as we would simply ignore them when expanding, but it is better to keep
# correct aliases.
if t.alias and len(t.args) != len(t.alias.alias_tvars):
t.args = [AnyType(TypeOfAny.from_error) for _ in t.alias.alias_tvars]
# correct aliases. Also, variadic aliases are better to check when fully analyzed,
# so we do this here.
assert t.alias is not None, f"Unfixed type alias {t.type_ref}"
args = flatten_nested_tuples(t.args)
if t.alias.tvar_tuple_index is not None:
correct = len(args) >= len(t.alias.alias_tvars) - 1
if any(
isinstance(a, UnpackType) and isinstance(get_proper_type(a.type), Instance)
for a in args
):
correct = True
else:
correct = len(args) == len(t.alias.alias_tvars)
if not correct:
if t.alias.tvar_tuple_index is not None:
exp_len = f"at least {len(t.alias.alias_tvars) - 1}"
else:
exp_len = f"{len(t.alias.alias_tvars)}"
self.fail(
f"Bad number of arguments for type alias, expected: {exp_len}, given: {len(args)}",
t,
code=codes.TYPE_ARG,
)
t.args = set_any_tvars(
t.alias, t.line, t.column, self.options, from_error=True, fail=self.fail
).args
else:
t.args = args
is_error = self.validate_args(t.alias.name, t.args, t.alias.alias_tvars, t)
if not is_error:
# If there was already an error for the alias itself, there is no point in checking
Expand All @@ -101,6 +128,17 @@ def visit_instance(self, t: Instance) -> None:
def validate_args(
self, name: str, args: Sequence[Type], type_vars: list[TypeVarLikeType], ctx: Context
) -> bool:
# TODO: we need to do flatten_nested_tuples and validate arg count for instances
# similar to how do we do this for type aliases above, but this may have perf penalty.
if any(isinstance(v, TypeVarTupleType) for v in type_vars):
prefix = next(i for (i, v) in enumerate(type_vars) if isinstance(v, TypeVarTupleType))
tvt = type_vars[prefix]
assert isinstance(tvt, TypeVarTupleType)
start, middle, end = split_with_prefix_and_suffix(
tuple(args), prefix, len(type_vars) - prefix - 1
)
args = list(start) + [TupleType(list(middle), tvt.tuple_fallback)] + list(end)

is_error = False
for (i, arg), tvar in zip(enumerate(args), type_vars):
if isinstance(tvar, TypeVarType):
Expand Down Expand Up @@ -167,7 +205,11 @@ def visit_unpack_type(self, typ: UnpackType) -> None:
return
if isinstance(proper_type, Instance) and proper_type.type.fullname == "builtins.tuple":
return
if isinstance(proper_type, AnyType) and proper_type.type_of_any == TypeOfAny.from_error:
if (
isinstance(proper_type, UnboundType)
or isinstance(proper_type, AnyType)
and proper_type.type_of_any == TypeOfAny.from_error
):
return

# TODO: Infer something when it can't be unpacked to allow rest of
Expand Down
2 changes: 2 additions & 0 deletions mypy/subtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -661,6 +661,8 @@ def visit_type_var_tuple(self, left: TypeVarTupleType) -> bool:
def visit_unpack_type(self, left: UnpackType) -> bool:
if isinstance(self.right, UnpackType):
return self._is_subtype(left.type, self.right.type)
if isinstance(self.right, Instance) and self.right.type.fullname == "builtins.object":
return True
return False

def visit_parameters(self, left: Parameters) -> bool:
Expand Down

0 comments on commit 0334ebc

Please sign in to comment.