Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

checkexpr: cache type of container literals when possible #12707

Merged
merged 1 commit into from May 20, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions mypy/checker.py
Expand Up @@ -293,6 +293,7 @@ def reset(self) -> None:
self._type_maps[1:] = []
self._type_maps[0].clear()
self.temp_type_map = None
self.expr_checker.reset()

assert self.inferred_attribute_types is None
assert self.partial_types == []
Expand Down
49 changes: 36 additions & 13 deletions mypy/checkexpr.py
Expand Up @@ -177,6 +177,9 @@ class ExpressionChecker(ExpressionVisitor[Type]):
# Type context for type inference
type_context: List[Optional[Type]]

# cache resolved types in some cases
resolved_type: Dict[Expression, ProperType]

strfrm_checker: StringFormatterChecker
plugin: Plugin

Expand All @@ -197,6 +200,11 @@ def __init__(self,
self.type_overrides: Dict[Expression, Type] = {}
self.strfrm_checker = StringFormatterChecker(self, self.chk, self.msg)

self.resolved_type = {}

def reset(self) -> None:
self.resolved_type = {}

def visit_name_expr(self, e: NameExpr) -> Type:
"""Type check a name expression.

Expand Down Expand Up @@ -3269,13 +3277,13 @@ def apply_type_arguments_to_callable(

def visit_list_expr(self, e: ListExpr) -> Type:
"""Type check a list expression [...]."""
return self.check_lst_expr(e.items, 'builtins.list', '<list>', e)
return self.check_lst_expr(e, 'builtins.list', '<list>')

def visit_set_expr(self, e: SetExpr) -> Type:
return self.check_lst_expr(e.items, 'builtins.set', '<set>', e)
return self.check_lst_expr(e, 'builtins.set', '<set>')

def fast_container_type(
self, items: List[Expression], container_fullname: str
self, e: Union[ListExpr, SetExpr, TupleExpr], container_fullname: str
) -> Optional[Type]:
"""
Fast path to determine the type of a list or set literal,
Expand All @@ -3290,21 +3298,28 @@ def fast_container_type(
ctx = self.type_context[-1]
if ctx:
return None
rt = self.resolved_type.get(e, None)
if rt is not None:
return rt if isinstance(rt, Instance) else None
values: List[Type] = []
for item in items:
for item in e.items:
if isinstance(item, StarExpr):
# fallback to slow path
self.resolved_type[e] = NoneType()
return None
values.append(self.accept(item))
vt = join.join_type_list(values)
if not allow_fast_container_literal(vt):
self.resolved_type[e] = NoneType()
return None
return self.chk.named_generic_type(container_fullname, [vt])
ct = self.chk.named_generic_type(container_fullname, [vt])
self.resolved_type[e] = ct
return ct

def check_lst_expr(self, items: List[Expression], fullname: str,
tag: str, context: Context) -> Type:
def check_lst_expr(self, e: Union[ListExpr, SetExpr, TupleExpr], fullname: str,
tag: str) -> Type:
# fast path
t = self.fast_container_type(items, fullname)
t = self.fast_container_type(e, fullname)
if t:
return t

Expand All @@ -3323,10 +3338,10 @@ def check_lst_expr(self, items: List[Expression], fullname: str,
variables=[tv])
out = self.check_call(constructor,
[(i.expr if isinstance(i, StarExpr) else i)
for i in items],
for i in e.items],
[(nodes.ARG_STAR if isinstance(i, StarExpr) else nodes.ARG_POS)
for i in items],
context)[0]
for i in e.items],
e)[0]
return remove_instance_last_known_values(out)

def visit_tuple_expr(self, e: TupleExpr) -> Type:
Expand Down Expand Up @@ -3376,7 +3391,7 @@ def visit_tuple_expr(self, e: TupleExpr) -> Type:
else:
# A star expression that's not a Tuple.
# Treat the whole thing as a variable-length tuple.
return self.check_lst_expr(e.items, 'builtins.tuple', '<tuple>', e)
return self.check_lst_expr(e, 'builtins.tuple', '<tuple>')
else:
if not type_context_items or j >= len(type_context_items):
tt = self.accept(item)
Expand All @@ -3402,6 +3417,9 @@ def fast_dict_type(self, e: DictExpr) -> Optional[Type]:
ctx = self.type_context[-1]
if ctx:
return None
rt = self.resolved_type.get(e, None)
if rt is not None:
return rt if isinstance(rt, Instance) else None
keys: List[Type] = []
values: List[Type] = []
stargs: Optional[Tuple[Type, Type]] = None
Expand All @@ -3415,17 +3433,22 @@ def fast_dict_type(self, e: DictExpr) -> Optional[Type]:
):
stargs = (st.args[0], st.args[1])
else:
self.resolved_type[e] = NoneType()
return None
else:
keys.append(self.accept(key))
values.append(self.accept(value))
kt = join.join_type_list(keys)
vt = join.join_type_list(values)
if not (allow_fast_container_literal(kt) and allow_fast_container_literal(vt)):
self.resolved_type[e] = NoneType()
return None
if stargs and (stargs[0] != kt or stargs[1] != vt):
self.resolved_type[e] = NoneType()
return None
return self.chk.named_generic_type('builtins.dict', [kt, vt])
dt = self.chk.named_generic_type('builtins.dict', [kt, vt])
self.resolved_type[e] = dt
return dt

def visit_dict_expr(self, e: DictExpr) -> Type:
"""Type check a dict expression.
Expand Down