From ac1137462a02eb27ffa20a484ee3cab234dbb392 Mon Sep 17 00:00:00 2001 From: Hugues Bruant Date: Sun, 1 May 2022 00:16:10 -0700 Subject: [PATCH] checkexpr: cache type of container literals when possible When a container (list, set, tuple, or dict) literal expression is used as an argument to an overloaded function it will get repeatedly typechecked. This becomes particularly problematic when the expression is somewhat large, as seen in #9427 To avoid repeated work, add a new field in the relevant AST nodes to cache the resolved type of the expression. Right now the cache is only used in the fast path, although it could conceivably be leveraged for the slow path as well in a follow-up commit. To further reduce duplicate work, when the fast-path doesn't work, we use the cache to make a note of that, to avoid repeatedly attempting to take the fast path. Fixes #9427 --- mypy/checkexpr.py | 39 ++++++++++++++++++++++++++------------- mypy/nodes.py | 16 ++++++++++++---- 2 files changed, 38 insertions(+), 17 deletions(-) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 9dfc0e2a64587..cc788170f7d5e 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -3259,13 +3259,13 @@ def apply_type_arguments_to_callable( def visit_list_expr(self, e: ListExpr) -> Type: """Type check a list expression [...].""" - return self.check_lst_expr(e.items, 'builtins.list', '', e) + return self.check_lst_expr(e, 'builtins.list', '') def visit_set_expr(self, e: SetExpr) -> Type: - return self.check_lst_expr(e.items, 'builtins.set', '', e) + return self.check_lst_expr(e, 'builtins.set', '') def fast_container_type( - self, items: List[Expression], container_fullname: str + self, e: Union[ListExpr, SetExpr, TupleExpr], container_fullname: str ) -> Optional[Type]: """ Fast path to determine the type of a list or set literal, @@ -3280,21 +3280,27 @@ def fast_container_type( ctx = self.type_context[-1] if ctx: return None + if e._resolved_type is not None: + return e._resolved_type if isinstance(e._resolved_type, Instance) else None values: List[Type] = [] - for item in items: + for item in e.items: if isinstance(item, StarExpr): # fallback to slow path + e._resolved_type = NoneType() return None values.append(self.accept(item)) vt = join.join_type_list(values) if not allow_fast_container_literal(vt): + e._resolved_type = NoneType() return None - return self.chk.named_generic_type(container_fullname, [vt]) + ct = self.chk.named_generic_type(container_fullname, [vt]) + e._resolved_type = ct + return ct - def check_lst_expr(self, items: List[Expression], fullname: str, - tag: str, context: Context) -> Type: + def check_lst_expr(self, e: Union[ListExpr, SetExpr, TupleExpr], fullname: str, + tag: str) -> Type: # fast path - t = self.fast_container_type(items, fullname) + t = self.fast_container_type(e, fullname) if t: return t @@ -3313,10 +3319,10 @@ def check_lst_expr(self, items: List[Expression], fullname: str, variables=[tv]) out = self.check_call(constructor, [(i.expr if isinstance(i, StarExpr) else i) - for i in items], + for i in e.items], [(nodes.ARG_STAR if isinstance(i, StarExpr) else nodes.ARG_POS) - for i in items], - context)[0] + for i in e.items], + e)[0] return remove_instance_last_known_values(out) def visit_tuple_expr(self, e: TupleExpr) -> Type: @@ -3366,7 +3372,7 @@ def visit_tuple_expr(self, e: TupleExpr) -> Type: else: # A star expression that's not a Tuple. # Treat the whole thing as a variable-length tuple. - return self.check_lst_expr(e.items, 'builtins.tuple', '', e) + return self.check_lst_expr(e, 'builtins.tuple', '') else: if not type_context_items or j >= len(type_context_items): tt = self.accept(item) @@ -3392,6 +3398,8 @@ def fast_dict_type(self, e: DictExpr) -> Optional[Type]: ctx = self.type_context[-1] if ctx: return None + if e._resolved_type is not None: + return e._resolved_type if isinstance(e._resolved_type, Instance) else None keys: List[Type] = [] values: List[Type] = [] stargs: Optional[Tuple[Type, Type]] = None @@ -3405,6 +3413,7 @@ def fast_dict_type(self, e: DictExpr) -> Optional[Type]: ): stargs = (st.args[0], st.args[1]) else: + e._resolved_type = NoneType() return None else: keys.append(self.accept(key)) @@ -3412,10 +3421,14 @@ def fast_dict_type(self, e: DictExpr) -> Optional[Type]: kt = join.join_type_list(keys) vt = join.join_type_list(values) if not (allow_fast_container_literal(kt) and allow_fast_container_literal(vt)): + e._resolved_type = NoneType() return None if stargs and (stargs[0] != kt or stargs[1] != vt): + e._resolved_type = NoneType() return None - return self.chk.named_generic_type('builtins.dict', [kt, vt]) + dt = self.chk.named_generic_type('builtins.dict', [kt, vt]) + e._resolved_type = dt + return dt def visit_dict_expr(self, e: DictExpr) -> Type: """Type check a dict expression. diff --git a/mypy/nodes.py b/mypy/nodes.py index 4ffa3116a1189..45c3e59b23fc5 100644 --- a/mypy/nodes.py +++ b/mypy/nodes.py @@ -2026,13 +2026,15 @@ def is_dynamic(self) -> bool: class ListExpr(Expression): """List literal expression [...].""" - __slots__ = ('items',) + __slots__ = ('items', '_resolved_type') items: List[Expression] + _resolved_type: Optional["mypy.types.ProperType"] def __init__(self, items: List[Expression]) -> None: super().__init__() self.items = items + self._resolved_type = None def accept(self, visitor: ExpressionVisitor[T]) -> T: return visitor.visit_list_expr(self) @@ -2041,13 +2043,15 @@ def accept(self, visitor: ExpressionVisitor[T]) -> T: class DictExpr(Expression): """Dictionary literal expression {key: value, ...}.""" - __slots__ = ('items',) + __slots__ = ('items', '_resolved_type') items: List[Tuple[Optional[Expression], Expression]] + _resolved_type: Optional["mypy.types.ProperType"] def __init__(self, items: List[Tuple[Optional[Expression], Expression]]) -> None: super().__init__() self.items = items + self._resolved_type = None def accept(self, visitor: ExpressionVisitor[T]) -> T: return visitor.visit_dict_expr(self) @@ -2058,13 +2062,15 @@ class TupleExpr(Expression): Also lvalue sequences (..., ...) and [..., ...]""" - __slots__ = ('items',) + __slots__ = ('items', '_resolved_type') items: List[Expression] + _resolved_type: Optional["mypy.types.ProperType"] def __init__(self, items: List[Expression]) -> None: super().__init__() self.items = items + self._resolved_type = None def accept(self, visitor: ExpressionVisitor[T]) -> T: return visitor.visit_tuple_expr(self) @@ -2073,13 +2079,15 @@ def accept(self, visitor: ExpressionVisitor[T]) -> T: class SetExpr(Expression): """Set literal expression {value, ...}.""" - __slots__ = ('items',) + __slots__ = ('items', '_resolved_type') items: List[Expression] + _resolved_type: Optional["mypy.types.ProperType"] def __init__(self, items: List[Expression]) -> None: super().__init__() self.items = items + self._resolved_type = None def accept(self, visitor: ExpressionVisitor[T]) -> T: return visitor.visit_set_expr(self)