Skip to content

Commit

Permalink
Speed up finding function type variables (#16562)
Browse files Browse the repository at this point in the history
Merge two visitors into a single visitor that is a bit more optimized
than the old visitors.

This speeds ups tests, in particular -- `mypy/test/testcheck.py` is
about 4% faster and `mypy/test/testpythoneval.py` is about 3% faster.

Also self-check is about 1% faster, both interpreted and compiled.

This adds more code, but the new code is largely boilerplate, so the
difficulty of maintenance seems roughly the same.
  • Loading branch information
JukkaL committed Dec 28, 2023
1 parent f79ae69 commit 761965d
Show file tree
Hide file tree
Showing 2 changed files with 176 additions and 89 deletions.
13 changes: 9 additions & 4 deletions mypy/semanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,9 +224,9 @@
from mypy.tvar_scope import TypeVarLikeScope
from mypy.typeanal import (
SELF_TYPE_NAMES,
FindTypeVarVisitor,
TypeAnalyser,
TypeVarLikeList,
TypeVarLikeQuery,
analyze_type_alias,
check_for_explicit_any,
detect_diverging_alias,
Expand Down Expand Up @@ -2034,6 +2034,11 @@ def analyze_unbound_tvar_impl(
assert isinstance(sym.node, TypeVarExpr)
return t.name, sym.node

def find_type_var_likes(self, t: Type) -> TypeVarLikeList:
visitor = FindTypeVarVisitor(self, self.tvar_scope)
t.accept(visitor)
return visitor.type_var_likes

def get_all_bases_tvars(
self, base_type_exprs: list[Expression], removed: list[int]
) -> TypeVarLikeList:
Expand All @@ -2046,7 +2051,7 @@ def get_all_bases_tvars(
except TypeTranslationError:
# This error will be caught later.
continue
base_tvars = base.accept(TypeVarLikeQuery(self, self.tvar_scope))
base_tvars = self.find_type_var_likes(base)
tvars.extend(base_tvars)
return remove_dups(tvars)

Expand All @@ -2064,7 +2069,7 @@ def get_and_bind_all_tvars(self, type_exprs: list[Expression]) -> list[TypeVarLi
except TypeTranslationError:
# This error will be caught later.
continue
base_tvars = base.accept(TypeVarLikeQuery(self, self.tvar_scope))
base_tvars = self.find_type_var_likes(base)
tvars.extend(base_tvars)
tvars = remove_dups(tvars) # Variables are defined in order of textual appearance.
tvar_defs = []
Expand Down Expand Up @@ -3490,7 +3495,7 @@ def analyze_alias(
)
return None, [], set(), [], False

found_type_vars = typ.accept(TypeVarLikeQuery(self, self.tvar_scope))
found_type_vars = self.find_type_var_likes(typ)
tvar_defs: list[TypeVarLikeType] = []
namespace = self.qualified_name(name)
with self.tvar_scope_frame(self.tvar_scope.class_frame(namespace)):
Expand Down
252 changes: 167 additions & 85 deletions mypy/typeanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -1570,32 +1570,32 @@ def tvar_scope_frame(self) -> Iterator[None]:
yield
self.tvar_scope = old_scope

def find_type_var_likes(self, t: Type, include_callables: bool = True) -> TypeVarLikeList:
return t.accept(
TypeVarLikeQuery(self.api, self.tvar_scope, include_callables=include_callables)
)

def infer_type_variables(self, type: CallableType) -> list[tuple[str, TypeVarLikeExpr]]:
"""Return list of unique type variables referred to in a callable."""
names: list[str] = []
tvars: list[TypeVarLikeExpr] = []
def find_type_var_likes(self, t: Type) -> TypeVarLikeList:
visitor = FindTypeVarVisitor(self.api, self.tvar_scope)
t.accept(visitor)
return visitor.type_var_likes

def infer_type_variables(
self, type: CallableType
) -> tuple[list[tuple[str, TypeVarLikeExpr]], bool]:
"""Infer type variables from a callable.
Return tuple with these items:
- list of unique type variables referred to in a callable
- whether there is a reference to the Self type
"""
visitor = FindTypeVarVisitor(self.api, self.tvar_scope)
for arg in type.arg_types:
for name, tvar_expr in self.find_type_var_likes(arg):
if name not in names:
names.append(name)
tvars.append(tvar_expr)
arg.accept(visitor)

# When finding type variables in the return type of a function, don't
# look inside Callable types. Type variables only appearing in
# functions in the return type belong to those functions, not the
# function we're currently analyzing.
for name, tvar_expr in self.find_type_var_likes(type.ret_type, include_callables=False):
if name not in names:
names.append(name)
tvars.append(tvar_expr)
visitor.include_callables = False
type.ret_type.accept(visitor)

if not names:
return [] # Fast path
return list(zip(names, tvars))
return visitor.type_var_likes, visitor.has_self_type

def bind_function_type_variables(
self, fun_type: CallableType, defn: Context
Expand All @@ -1615,10 +1615,7 @@ def bind_function_type_variables(
binding = self.tvar_scope.bind_new(var.name, var_expr)
defs.append(binding)
return defs, has_self_type
typevars = self.infer_type_variables(fun_type)
has_self_type = find_self_type(
fun_type, lambda name: self.api.lookup_qualified(name, defn, suppress_errors=True)
)
typevars, has_self_type = self.infer_type_variables(fun_type)
# Do not define a new type variable if already defined in scope.
typevars = [
(name, tvar) for name, tvar in typevars if not self.is_defined_type_var(name, defn)
Expand Down Expand Up @@ -2062,67 +2059,6 @@ def flatten_tvars(lists: list[list[T]]) -> list[T]:
return result


class TypeVarLikeQuery(TypeQuery[TypeVarLikeList]):
"""Find TypeVar and ParamSpec references in an unbound type."""

def __init__(
self,
api: SemanticAnalyzerCoreInterface,
scope: TypeVarLikeScope,
*,
include_callables: bool = True,
) -> None:
super().__init__(flatten_tvars)
self.api = api
self.scope = scope
self.include_callables = include_callables
# Only include type variables in type aliases args. This would be anyway
# that case if we expand (as target variables would be overridden with args)
# and it may cause infinite recursion on invalid (diverging) recursive aliases.
self.skip_alias_target = True

def _seems_like_callable(self, type: UnboundType) -> bool:
if not type.args:
return False
return isinstance(type.args[0], (EllipsisType, TypeList, ParamSpecType))

def visit_unbound_type(self, t: UnboundType) -> TypeVarLikeList:
name = t.name
node = None
# Special case P.args and P.kwargs for ParamSpecs only.
if name.endswith("args"):
if name.endswith(".args") or name.endswith(".kwargs"):
base = ".".join(name.split(".")[:-1])
n = self.api.lookup_qualified(base, t)
if n is not None and isinstance(n.node, ParamSpecExpr):
node = n
name = base
if node is None:
node = self.api.lookup_qualified(name, t)
if (
node
and isinstance(node.node, TypeVarLikeExpr)
and self.scope.get_binding(node) is None
):
assert isinstance(node.node, TypeVarLikeExpr)
return [(name, node.node)]
elif not self.include_callables and self._seems_like_callable(t):
return []
elif node and node.fullname in LITERAL_TYPE_NAMES:
return []
elif node and node.fullname in ANNOTATED_TYPE_NAMES and t.args:
# Don't query the second argument to Annotated for TypeVars
return self.query_types([t.args[0]])
else:
return super().visit_unbound_type(t)

def visit_callable_type(self, t: CallableType) -> TypeVarLikeList:
if self.include_callables:
return super().visit_callable_type(t)
else:
return []


class DivergingAliasDetector(TrivialSyntheticTypeTranslator):
"""See docstring of detect_diverging_alias() for details."""

Expand Down Expand Up @@ -2359,3 +2295,149 @@ def unknown_unpack(t: Type) -> bool:
if isinstance(unpacked, AnyType) and unpacked.type_of_any == TypeOfAny.special_form:
return True
return False


class FindTypeVarVisitor(SyntheticTypeVisitor[None]):
"""Type visitor that looks for type variable types and self types."""

def __init__(self, api: SemanticAnalyzerCoreInterface, scope: TypeVarLikeScope) -> None:
self.api = api
self.scope = scope
self.type_var_likes: list[tuple[str, TypeVarLikeExpr]] = []
self.has_self_type = False
self.seen_aliases: set[TypeAliasType] | None = None
self.include_callables = True

def _seems_like_callable(self, type: UnboundType) -> bool:
if not type.args:
return False
return isinstance(type.args[0], (EllipsisType, TypeList, ParamSpecType))

def visit_unbound_type(self, t: UnboundType) -> None:
name = t.name
node = None

# Special case P.args and P.kwargs for ParamSpecs only.
if name.endswith("args"):
if name.endswith(".args") or name.endswith(".kwargs"):
base = ".".join(name.split(".")[:-1])
n = self.api.lookup_qualified(base, t)
if n is not None and isinstance(n.node, ParamSpecExpr):
node = n
name = base
if node is None:
node = self.api.lookup_qualified(name, t)
if node and node.fullname in SELF_TYPE_NAMES:
self.has_self_type = True
if (
node
and isinstance(node.node, TypeVarLikeExpr)
and self.scope.get_binding(node) is None
):
if (name, node.node) not in self.type_var_likes:
self.type_var_likes.append((name, node.node))
elif not self.include_callables and self._seems_like_callable(t):
if find_self_type(
t, lambda name: self.api.lookup_qualified(name, t, suppress_errors=True)
):
self.has_self_type = True
return
elif node and node.fullname in LITERAL_TYPE_NAMES:
return
elif node and node.fullname in ANNOTATED_TYPE_NAMES and t.args:
# Don't query the second argument to Annotated for TypeVars
self.process_types([t.args[0]])
elif t.args:
self.process_types(t.args)

def visit_type_list(self, t: TypeList) -> None:
self.process_types(t.items)

def visit_callable_argument(self, t: CallableArgument) -> None:
t.typ.accept(self)

def visit_any(self, t: AnyType) -> None:
pass

def visit_uninhabited_type(self, t: UninhabitedType) -> None:
pass

def visit_none_type(self, t: NoneType) -> None:
pass

def visit_erased_type(self, t: ErasedType) -> None:
pass

def visit_deleted_type(self, t: DeletedType) -> None:
pass

def visit_type_var(self, t: TypeVarType) -> None:
self.process_types([t.upper_bound, t.default] + t.values)

def visit_param_spec(self, t: ParamSpecType) -> None:
self.process_types([t.upper_bound, t.default])

def visit_type_var_tuple(self, t: TypeVarTupleType) -> None:
self.process_types([t.upper_bound, t.default])

def visit_unpack_type(self, t: UnpackType) -> None:
self.process_types([t.type])

def visit_parameters(self, t: Parameters) -> None:
self.process_types(t.arg_types)

def visit_partial_type(self, t: PartialType) -> None:
pass

def visit_instance(self, t: Instance) -> None:
self.process_types(t.args)

def visit_callable_type(self, t: CallableType) -> None:
# FIX generics
self.process_types(t.arg_types)
t.ret_type.accept(self)

def visit_tuple_type(self, t: TupleType) -> None:
self.process_types(t.items)

def visit_typeddict_type(self, t: TypedDictType) -> None:
self.process_types(list(t.items.values()))

def visit_raw_expression_type(self, t: RawExpressionType) -> None:
pass

def visit_literal_type(self, t: LiteralType) -> None:
pass

def visit_union_type(self, t: UnionType) -> None:
self.process_types(t.items)

def visit_overloaded(self, t: Overloaded) -> None:
self.process_types(t.items) # type: ignore[arg-type]

def visit_type_type(self, t: TypeType) -> None:
t.item.accept(self)

def visit_ellipsis_type(self, t: EllipsisType) -> None:
pass

def visit_placeholder_type(self, t: PlaceholderType) -> None:
return self.process_types(t.args)

def visit_type_alias_type(self, t: TypeAliasType) -> None:
# Skip type aliases in already visited types to avoid infinite recursion.
if self.seen_aliases is None:
self.seen_aliases = set()
elif t in self.seen_aliases:
return
self.seen_aliases.add(t)
self.process_types(t.args)

def process_types(self, types: list[Type] | tuple[Type, ...]) -> None:
# Redundant type check helps mypyc.
if isinstance(types, list):
for t in types:
t.accept(self)
else:
for t in types:
t.accept(self)

0 comments on commit 761965d

Please sign in to comment.