diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index bd69c1427dce1..90e71d8b7303d 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -3249,13 +3249,13 @@ def apply_type_arguments_to_callable( def visit_list_expr(self, e: ListExpr) -> Type: """Type check a list expression [...].""" - return self.check_lst_expr(e.items, 'builtins.list', '', e) + return self.check_lst_expr(e, 'builtins.list', '') def visit_set_expr(self, e: SetExpr) -> Type: - return self.check_lst_expr(e.items, 'builtins.set', '', e) + return self.check_lst_expr(e, 'builtins.set', '') def fast_container_type( - self, items: List[Expression], container_fullname: str + self, e: Union[ListExpr, SetExpr, TupleExpr], container_fullname: str ) -> Optional[Type]: """ Fast path to determine the type of a list or set literal, @@ -3270,21 +3270,27 @@ def fast_container_type( ctx = self.type_context[-1] if ctx: return None + if e._resolved_type is not None: + return e._resolved_type if isinstance(e._resolved_type, Instance) else None values: List[Type] = [] - for item in items: + for item in e.items: if isinstance(item, StarExpr): # fallback to slow path + e._resolved_type = NoneType() return None values.append(self.accept(item)) vt = join.join_type_list(values) if not isinstance(vt, Instance): + e._resolved_type = NoneType() return None - return self.chk.named_generic_type(container_fullname, [vt]) + ct = self.chk.named_generic_type(container_fullname, [vt]) + e._resolved_type = ct + return ct - def check_lst_expr(self, items: List[Expression], fullname: str, - tag: str, context: Context) -> Type: + def check_lst_expr(self, e: Union[ListExpr, SetExpr, TupleExpr], fullname: str, + tag: str) -> Type: # fast path - t = self.fast_container_type(items, fullname) + t = self.fast_container_type(e, fullname) if t: return t @@ -3303,10 +3309,10 @@ def check_lst_expr(self, items: List[Expression], fullname: str, variables=[tv]) out = self.check_call(constructor, [(i.expr if isinstance(i, StarExpr) else i) - for i in items], + for i in e.items], [(nodes.ARG_STAR if isinstance(i, StarExpr) else nodes.ARG_POS) - for i in items], - context)[0] + for i in e.items], + e)[0] return remove_instance_last_known_values(out) def visit_tuple_expr(self, e: TupleExpr) -> Type: @@ -3356,7 +3362,7 @@ def visit_tuple_expr(self, e: TupleExpr) -> Type: else: # A star expression that's not a Tuple. # Treat the whole thing as a variable-length tuple. - return self.check_lst_expr(e.items, 'builtins.tuple', '', e) + return self.check_lst_expr(e, 'builtins.tuple', '') else: if not type_context_items or j >= len(type_context_items): tt = self.accept(item) @@ -3382,6 +3388,8 @@ def fast_dict_type(self, e: DictExpr) -> Optional[Type]: ctx = self.type_context[-1] if ctx: return None + if e._resolved_type is not None: + return e._resolved_type if isinstance(e._resolved_type, Instance) else None keys: List[Type] = [] values: List[Type] = [] stargs: Optional[Tuple[Type, Type]] = None @@ -3395,6 +3403,7 @@ def fast_dict_type(self, e: DictExpr) -> Optional[Type]: ): stargs = (st.args[0], st.args[1]) else: + e._resolved_type = NoneType() return None else: keys.append(self.accept(key)) @@ -3402,10 +3411,14 @@ def fast_dict_type(self, e: DictExpr) -> Optional[Type]: kt = join.join_type_list(keys) vt = join.join_type_list(values) if not (isinstance(kt, Instance) and isinstance(vt, Instance)): + e._resolved_type = NoneType() return None if stargs and (stargs[0] != kt or stargs[1] != vt): + e._resolved_type = NoneType() return None - return self.chk.named_generic_type('builtins.dict', [kt, vt]) + dt = self.chk.named_generic_type('builtins.dict', [kt, vt]) + e._resolved_type = dt + return dt def visit_dict_expr(self, e: DictExpr) -> Type: """Type check a dict expression. diff --git a/mypy/nodes.py b/mypy/nodes.py index 5a27783e97e1e..d8ae78759669c 100644 --- a/mypy/nodes.py +++ b/mypy/nodes.py @@ -2026,13 +2026,15 @@ def is_dynamic(self) -> bool: class ListExpr(Expression): """List literal expression [...].""" - __slots__ = ('items',) + __slots__ = ('items', '_resolved_type') items: List[Expression] + _resolved_type: Optional["mypy.types.ProperType"] def __init__(self, items: List[Expression]) -> None: super().__init__() self.items = items + self._resolved_type = None def accept(self, visitor: ExpressionVisitor[T]) -> T: return visitor.visit_list_expr(self) @@ -2041,13 +2043,15 @@ def accept(self, visitor: ExpressionVisitor[T]) -> T: class DictExpr(Expression): """Dictionary literal expression {key: value, ...}.""" - __slots__ = ('items',) + __slots__ = ('items', '_resolved_type') items: List[Tuple[Optional[Expression], Expression]] + _resolved_type: Optional["mypy.types.ProperType"] def __init__(self, items: List[Tuple[Optional[Expression], Expression]]) -> None: super().__init__() self.items = items + self._resolved_type = None def accept(self, visitor: ExpressionVisitor[T]) -> T: return visitor.visit_dict_expr(self) @@ -2058,13 +2062,15 @@ class TupleExpr(Expression): Also lvalue sequences (..., ...) and [..., ...]""" - __slots__ = ('items',) + __slots__ = ('items', '_resolved_type') items: List[Expression] + _resolved_type: Optional["mypy.types.ProperType"] def __init__(self, items: List[Expression]) -> None: super().__init__() self.items = items + self._resolved_type = None def accept(self, visitor: ExpressionVisitor[T]) -> T: return visitor.visit_tuple_expr(self) @@ -2073,13 +2079,15 @@ def accept(self, visitor: ExpressionVisitor[T]) -> T: class SetExpr(Expression): """Set literal expression {value, ...}.""" - __slots__ = ('items',) + __slots__ = ('items', '_resolved_type') items: List[Expression] + _resolved_type: Optional["mypy.types.ProperType"] def __init__(self, items: List[Expression]) -> None: super().__init__() self.items = items + self._resolved_type = None def accept(self, visitor: ExpressionVisitor[T]) -> T: return visitor.visit_set_expr(self)