Skip to content

Commit

Permalink
checkexpr: cache type of container literals when possible
Browse files Browse the repository at this point in the history
When a container (list, set, tuple, or dict) literal expression is
used as an argument to an overloaded function it will get repeatedly
typechecked. This becomes particularly problematic when the expression
is somewhat large, as seen in #9427

To avoid repeated work, add a new cache in ExprChecker, mapping the AST
node to the resolved type of the expression. Right now the cache is
only used in the fast path, although it could conceivably be leveraged
for the slow path as well in a follow-up commit.

To further reduce duplicate work, when the fast-path doesn't work, we
use the cache to make a note of that, to avoid repeatedly attempting to
take the fast path.

Fixes #9427
  • Loading branch information
huguesb committed May 16, 2022
1 parent 7fbf4de commit af675fe
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 13 deletions.
1 change: 1 addition & 0 deletions mypy/checker.py
Expand Up @@ -293,6 +293,7 @@ def reset(self) -> None:
self._type_maps[1:] = []
self._type_maps[0].clear()
self.temp_type_map = None
self.expr_checker.reset()

assert self.inferred_attribute_types is None
assert self.partial_types == []
Expand Down
49 changes: 36 additions & 13 deletions mypy/checkexpr.py
Expand Up @@ -177,6 +177,9 @@ class ExpressionChecker(ExpressionVisitor[Type]):
# Type context for type inference
type_context: List[Optional[Type]]

# cache resolved types in some cases
resolved_type: Dict[Expression, ProperType]

strfrm_checker: StringFormatterChecker
plugin: Plugin

Expand All @@ -197,6 +200,11 @@ def __init__(self,
self.type_overrides: Dict[Expression, Type] = {}
self.strfrm_checker = StringFormatterChecker(self, self.chk, self.msg)

self.resolved_type = {}

def reset(self) -> None:
self.resolved_type = {}

def visit_name_expr(self, e: NameExpr) -> Type:
"""Type check a name expression.
Expand Down Expand Up @@ -3269,13 +3277,13 @@ def apply_type_arguments_to_callable(

def visit_list_expr(self, e: ListExpr) -> Type:
"""Type check a list expression [...]."""
return self.check_lst_expr(e.items, 'builtins.list', '<list>', e)
return self.check_lst_expr(e, 'builtins.list', '<list>')

def visit_set_expr(self, e: SetExpr) -> Type:
return self.check_lst_expr(e.items, 'builtins.set', '<set>', e)
return self.check_lst_expr(e, 'builtins.set', '<set>')

def fast_container_type(
self, items: List[Expression], container_fullname: str
self, e: Union[ListExpr, SetExpr, TupleExpr], container_fullname: str
) -> Optional[Type]:
"""
Fast path to determine the type of a list or set literal,
Expand All @@ -3290,21 +3298,28 @@ def fast_container_type(
ctx = self.type_context[-1]
if ctx:
return None
rt = self.resolved_type.get(e, None)
if rt is not None:
return rt if isinstance(rt, Instance) else None
values: List[Type] = []
for item in items:
for item in e.items:
if isinstance(item, StarExpr):
# fallback to slow path
self.resolved_type[e] = NoneType()
return None
values.append(self.accept(item))
vt = join.join_type_list(values)
if not allow_fast_container_literal(vt):
self.resolved_type[e] = NoneType()
return None
return self.chk.named_generic_type(container_fullname, [vt])
ct = self.chk.named_generic_type(container_fullname, [vt])
self.resolved_type[e] = ct
return ct

def check_lst_expr(self, items: List[Expression], fullname: str,
tag: str, context: Context) -> Type:
def check_lst_expr(self, e: Union[ListExpr, SetExpr, TupleExpr], fullname: str,
tag: str) -> Type:
# fast path
t = self.fast_container_type(items, fullname)
t = self.fast_container_type(e, fullname)
if t:
return t

Expand All @@ -3323,10 +3338,10 @@ def check_lst_expr(self, items: List[Expression], fullname: str,
variables=[tv])
out = self.check_call(constructor,
[(i.expr if isinstance(i, StarExpr) else i)
for i in items],
for i in e.items],
[(nodes.ARG_STAR if isinstance(i, StarExpr) else nodes.ARG_POS)
for i in items],
context)[0]
for i in e.items],
e)[0]
return remove_instance_last_known_values(out)

def visit_tuple_expr(self, e: TupleExpr) -> Type:
Expand Down Expand Up @@ -3376,7 +3391,7 @@ def visit_tuple_expr(self, e: TupleExpr) -> Type:
else:
# A star expression that's not a Tuple.
# Treat the whole thing as a variable-length tuple.
return self.check_lst_expr(e.items, 'builtins.tuple', '<tuple>', e)
return self.check_lst_expr(e, 'builtins.tuple', '<tuple>')
else:
if not type_context_items or j >= len(type_context_items):
tt = self.accept(item)
Expand All @@ -3402,6 +3417,9 @@ def fast_dict_type(self, e: DictExpr) -> Optional[Type]:
ctx = self.type_context[-1]
if ctx:
return None
rt = self.resolved_type.get(e, None)
if rt is not None:
return rt if isinstance(rt, Instance) else None
keys: List[Type] = []
values: List[Type] = []
stargs: Optional[Tuple[Type, Type]] = None
Expand All @@ -3415,17 +3433,22 @@ def fast_dict_type(self, e: DictExpr) -> Optional[Type]:
):
stargs = (st.args[0], st.args[1])
else:
self.resolved_type[e] = NoneType()
return None
else:
keys.append(self.accept(key))
values.append(self.accept(value))
kt = join.join_type_list(keys)
vt = join.join_type_list(values)
if not (allow_fast_container_literal(kt) and allow_fast_container_literal(vt)):
self.resolved_type[e] = NoneType()
return None
if stargs and (stargs[0] != kt or stargs[1] != vt):
self.resolved_type[e] = NoneType()
return None
return self.chk.named_generic_type('builtins.dict', [kt, vt])
dt = self.chk.named_generic_type('builtins.dict', [kt, vt])
self.resolved_type[e] = dt
return dt

def visit_dict_expr(self, e: DictExpr) -> Type:
"""Type check a dict expression.
Expand Down

0 comments on commit af675fe

Please sign in to comment.