diff --git a/mypy/checker.py b/mypy/checker.py index 109a3b1f15d2..e5abcfcf4541 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -293,6 +293,7 @@ def reset(self) -> None: self._type_maps[1:] = [] self._type_maps[0].clear() self.temp_type_map = None + self.expr_checker.reset() assert self.inferred_attribute_types is None assert self.partial_types == [] diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index bfbe961adc7a..05d5ea122a78 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -177,6 +177,9 @@ class ExpressionChecker(ExpressionVisitor[Type]): # Type context for type inference type_context: List[Optional[Type]] + # cache resolved types in some cases + resolved_type: Dict[Expression, ProperType] + strfrm_checker: StringFormatterChecker plugin: Plugin @@ -197,6 +200,11 @@ def __init__(self, self.type_overrides: Dict[Expression, Type] = {} self.strfrm_checker = StringFormatterChecker(self, self.chk, self.msg) + self.resolved_type = {} + + def reset(self) -> None: + self.resolved_type = {} + def visit_name_expr(self, e: NameExpr) -> Type: """Type check a name expression. @@ -3269,13 +3277,13 @@ def apply_type_arguments_to_callable( def visit_list_expr(self, e: ListExpr) -> Type: """Type check a list expression [...].""" - return self.check_lst_expr(e.items, 'builtins.list', '', e) + return self.check_lst_expr(e, 'builtins.list', '') def visit_set_expr(self, e: SetExpr) -> Type: - return self.check_lst_expr(e.items, 'builtins.set', '', e) + return self.check_lst_expr(e, 'builtins.set', '') def fast_container_type( - self, items: List[Expression], container_fullname: str + self, e: Union[ListExpr, SetExpr, TupleExpr], container_fullname: str ) -> Optional[Type]: """ Fast path to determine the type of a list or set literal, @@ -3290,21 +3298,28 @@ def fast_container_type( ctx = self.type_context[-1] if ctx: return None + rt = self.resolved_type.get(e, None) + if rt is not None: + return rt if isinstance(rt, Instance) else None values: List[Type] = [] - for item in items: + for item in e.items: if isinstance(item, StarExpr): # fallback to slow path + self.resolved_type[e] = NoneType() return None values.append(self.accept(item)) vt = join.join_type_list(values) if not allow_fast_container_literal(vt): + self.resolved_type[e] = NoneType() return None - return self.chk.named_generic_type(container_fullname, [vt]) + ct = self.chk.named_generic_type(container_fullname, [vt]) + self.resolved_type[e] = ct + return ct - def check_lst_expr(self, items: List[Expression], fullname: str, - tag: str, context: Context) -> Type: + def check_lst_expr(self, e: Union[ListExpr, SetExpr, TupleExpr], fullname: str, + tag: str) -> Type: # fast path - t = self.fast_container_type(items, fullname) + t = self.fast_container_type(e, fullname) if t: return t @@ -3323,10 +3338,10 @@ def check_lst_expr(self, items: List[Expression], fullname: str, variables=[tv]) out = self.check_call(constructor, [(i.expr if isinstance(i, StarExpr) else i) - for i in items], + for i in e.items], [(nodes.ARG_STAR if isinstance(i, StarExpr) else nodes.ARG_POS) - for i in items], - context)[0] + for i in e.items], + e)[0] return remove_instance_last_known_values(out) def visit_tuple_expr(self, e: TupleExpr) -> Type: @@ -3376,7 +3391,7 @@ def visit_tuple_expr(self, e: TupleExpr) -> Type: else: # A star expression that's not a Tuple. # Treat the whole thing as a variable-length tuple. - return self.check_lst_expr(e.items, 'builtins.tuple', '', e) + return self.check_lst_expr(e, 'builtins.tuple', '') else: if not type_context_items or j >= len(type_context_items): tt = self.accept(item) @@ -3402,6 +3417,9 @@ def fast_dict_type(self, e: DictExpr) -> Optional[Type]: ctx = self.type_context[-1] if ctx: return None + rt = self.resolved_type.get(e, None) + if rt is not None: + return rt if isinstance(rt, Instance) else None keys: List[Type] = [] values: List[Type] = [] stargs: Optional[Tuple[Type, Type]] = None @@ -3415,6 +3433,7 @@ def fast_dict_type(self, e: DictExpr) -> Optional[Type]: ): stargs = (st.args[0], st.args[1]) else: + self.resolved_type[e] = NoneType() return None else: keys.append(self.accept(key)) @@ -3422,10 +3441,14 @@ def fast_dict_type(self, e: DictExpr) -> Optional[Type]: kt = join.join_type_list(keys) vt = join.join_type_list(values) if not (allow_fast_container_literal(kt) and allow_fast_container_literal(vt)): + self.resolved_type[e] = NoneType() return None if stargs and (stargs[0] != kt or stargs[1] != vt): + self.resolved_type[e] = NoneType() return None - return self.chk.named_generic_type('builtins.dict', [kt, vt]) + dt = self.chk.named_generic_type('builtins.dict', [kt, vt]) + self.resolved_type[e] = dt + return dt def visit_dict_expr(self, e: DictExpr) -> Type: """Type check a dict expression.