Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Exhaustiveness checking for match statements #12267

Merged
merged 23 commits into from Mar 7, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
85 changes: 62 additions & 23 deletions mypy/checker.py
Expand Up @@ -4089,36 +4089,57 @@ def visit_match_stmt(self, s: MatchStmt) -> None:
if isinstance(subject_type, DeletedType):
self.msg.deleted_as_rvalue(subject_type, s)

# We infer types of patterns twice. The first pass is used
# to infer the types of capture variables. The type of a
# capture variable may depend on multiple patterns (it
# will be a union of all capture types). This pass ignores
# guard expressions.
pattern_types = [self.pattern_checker.accept(p, subject_type) for p in s.patterns]

type_maps: List[TypeMap] = [t.captures for t in pattern_types]
self.infer_variable_types_from_type_maps(type_maps)
inferred_types = self.infer_variable_types_from_type_maps(type_maps)

for pattern_type, g, b in zip(pattern_types, s.guards, s.bodies):
# The second pass narrows down the types and type checks bodies.
for p, g, b in zip(s.patterns, s.guards, s.bodies):
current_subject_type = self.expr_checker.narrow_type_from_binder(s.subject,
subject_type)
pattern_type = self.pattern_checker.accept(p, current_subject_type)
with self.binder.frame_context(can_skip=True, fall_through=2):
if b.is_unreachable or isinstance(get_proper_type(pattern_type.type),
UninhabitedType):
self.push_type_map(None)
else_map: TypeMap = {}
else:
self.binder.put(s.subject, pattern_type.type)
pattern_map, else_map = conditional_types_to_typemaps(
s.subject,
pattern_type.type,
pattern_type.rest_type
)
self.remove_capture_conflicts(pattern_type.captures,
inferred_types)
self.push_type_map(pattern_map)
self.push_type_map(pattern_type.captures)
if g is not None:
gt = get_proper_type(self.expr_checker.accept(g))
with self.binder.frame_context(can_skip=True, fall_through=3):
gt = get_proper_type(self.expr_checker.accept(g))

if isinstance(gt, DeletedType):
self.msg.deleted_as_rvalue(gt, s)
if isinstance(gt, DeletedType):
self.msg.deleted_as_rvalue(gt, s)

if_map, _ = self.find_isinstance_check(g)
guard_map, guard_else_map = self.find_isinstance_check(g)
else_map = or_conditional_maps(else_map, guard_else_map)

self.push_type_map(if_map)
self.accept(b)
self.push_type_map(guard_map)
self.accept(b)
else:
self.accept(b)
self.push_type_map(else_map)

# This is needed due to a quirk in frame_context. Without it types will stay narrowed
# after the match.
with self.binder.frame_context(can_skip=False, fall_through=2):
pass

def infer_variable_types_from_type_maps(self, type_maps: List[TypeMap]) -> None:
def infer_variable_types_from_type_maps(self, type_maps: List[TypeMap]) -> Dict[Var, Type]:
all_captures: Dict[Var, List[Tuple[NameExpr, Type]]] = defaultdict(list)
for tm in type_maps:
if tm is not None:
Expand All @@ -4128,28 +4149,38 @@ def infer_variable_types_from_type_maps(self, type_maps: List[TypeMap]) -> None:
assert isinstance(node, Var)
all_captures[node].append((expr, typ))

inferred_types: Dict[Var, Type] = {}
for var, captures in all_captures.items():
conflict = False
already_exists = False
types: List[Type] = []
for expr, typ in captures:
types.append(typ)

previous_type, _, inferred = self.check_lvalue(expr)
previous_type, _, _ = self.check_lvalue(expr)
if previous_type is not None:
conflict = True
self.check_subtype(typ, previous_type, expr,
msg=message_registry.INCOMPATIBLE_TYPES_IN_CAPTURE,
subtype_label="pattern captures type",
supertype_label="variable has type")
for type_map in type_maps:
if type_map is not None and expr in type_map:
del type_map[expr]

if not conflict:
already_exists = True
if self.check_subtype(typ, previous_type, expr,
msg=message_registry.INCOMPATIBLE_TYPES_IN_CAPTURE,
subtype_label="pattern captures type",
supertype_label="variable has type"):
inferred_types[var] = previous_type

if not already_exists:
new_type = UnionType.make_union(types)
# Infer the union type at the first occurrence
first_occurrence, _ = captures[0]
inferred_types[var] = new_type
self.infer_variable_type(var, first_occurrence, new_type, first_occurrence)
return inferred_types

def remove_capture_conflicts(self, type_map: TypeMap, inferred_types: Dict[Var, Type]) -> None:
if type_map:
for expr, typ in list(type_map.items()):
if isinstance(expr, NameExpr):
node = expr.node
assert isinstance(node, Var)
if node not in inferred_types or not is_subtype(typ, inferred_types[node]):
del type_map[expr]

def make_fake_typeinfo(self,
curr_module_fullname: str,
Expand Down Expand Up @@ -5637,6 +5668,14 @@ def conditional_types(current_type: Type,
None means no new information can be inferred. If default is set it is returned
instead."""
if proposed_type_ranges:
if len(proposed_type_ranges) == 1:
target = proposed_type_ranges[0].item
target = get_proper_type(target)
if isinstance(target, LiteralType) and (target.is_enum_literal()
or isinstance(target.value, bool)):
enum_name = target.fallback.type.fullname
current_type = try_expanding_sum_type_to_union(current_type,
enum_name)
proposed_items = [type_range.item for type_range in proposed_type_ranges]
proposed_type = make_simplified_union(proposed_items)
if isinstance(proposed_type, AnyType):
Expand Down
38 changes: 30 additions & 8 deletions mypy/checkpattern.py
@@ -1,4 +1,5 @@
"""Pattern checker. This file is conceptually part of TypeChecker."""

from collections import defaultdict
from typing import List, Optional, Tuple, Dict, NamedTuple, Set, Union
from typing_extensions import Final
Expand All @@ -19,7 +20,8 @@
)
from mypy.plugin import Plugin
from mypy.subtypes import is_subtype
from mypy.typeops import try_getting_str_literals_from_type, make_simplified_union
from mypy.typeops import try_getting_str_literals_from_type, make_simplified_union, \
coerce_to_literal
from mypy.types import (
ProperType, AnyType, TypeOfAny, Instance, Type, UninhabitedType, get_proper_type,
TypedDictType, TupleType, NoneType, UnionType
Expand Down Expand Up @@ -55,7 +57,7 @@
'PatternType',
[
('type', Type), # The type the match subject can be narrowed to
('rest_type', Type), # For exhaustiveness checking. Not used yet
('rest_type', Type), # The remaining type if the pattern didn't match
('captures', Dict[Expression, Type]), # The variables captured by the pattern
])

Expand Down Expand Up @@ -177,6 +179,7 @@ def visit_or_pattern(self, o: OrPattern) -> PatternType:
def visit_value_pattern(self, o: ValuePattern) -> PatternType:
current_type = self.type_context[-1]
typ = self.chk.expr_checker.accept(o.expr)
typ = coerce_to_literal(typ)
narrowed_type, rest_type = self.chk.conditional_types_with_intersection(
current_type,
[get_type_range(typ)],
Expand Down Expand Up @@ -259,6 +262,9 @@ def visit_sequence_pattern(self, o: SequencePattern) -> PatternType:
new_inner_types = self.expand_starred_pattern_types(contracted_new_inner_types,
star_position,
len(inner_types))
rest_inner_types = self.expand_starred_pattern_types(contracted_rest_inner_types,
star_position,
len(inner_types))

#
# Calculate new type
Expand Down Expand Up @@ -287,15 +293,20 @@ def visit_sequence_pattern(self, o: SequencePattern) -> PatternType:

if all(is_uninhabited(typ) for typ in inner_rest_types):
# All subpatterns always match, so we can apply negative narrowing
new_type, rest_type = self.chk.conditional_types_with_intersection(
current_type, [get_type_range(new_type)], o, default=current_type
)
rest_type = TupleType(rest_inner_types, current_type.partial_fallback)
else:
new_inner_type = UninhabitedType()
for typ in new_inner_types:
new_inner_type = join_types(new_inner_type, typ)
new_type = self.construct_sequence_child(current_type, new_inner_type)
if not is_subtype(new_type, current_type):
if is_subtype(new_type, current_type):
new_type, _ = self.chk.conditional_types_with_intersection(
current_type,
[get_type_range(new_type)],
o,
default=current_type
)
else:
new_type = current_type
return PatternType(new_type, rest_type, captures)

Expand Down Expand Up @@ -344,8 +355,7 @@ def expand_starred_pattern_types(self,
star_pos: Optional[int],
num_types: int
) -> List[Type]:
"""
Undoes the contraction done by contract_starred_pattern_types.
"""Undoes the contraction done by contract_starred_pattern_types.

For example if the sequence pattern is [a, *b, c] and types [bool, int, str] are extended
to lenght 4 the result is [bool, int, int, str].
Expand Down Expand Up @@ -639,6 +649,13 @@ def construct_sequence_child(self, outer_type: Type, inner_type: Type) -> Type:
For example:
construct_sequence_child(List[int], str) = List[str]
"""
proper_type = get_proper_type(outer_type)
if isinstance(proper_type, UnionType):
types = [
self.construct_sequence_child(item, inner_type) for item in proper_type.items
if self.can_match_sequence(get_proper_type(item))
]
return make_simplified_union(types)
sequence = self.chk.named_generic_type("typing.Sequence", [inner_type])
if is_subtype(outer_type, self.chk.named_type("typing.Sequence")):
proper_type = get_proper_type(outer_type)
Expand Down Expand Up @@ -676,6 +693,11 @@ def get_var(expr: Expression) -> Var:


def get_type_range(typ: Type) -> 'mypy.checker.TypeRange':
typ = get_proper_type(typ)
if (isinstance(typ, Instance)
and typ.last_known_value
and isinstance(typ.last_known_value.value, bool)):
typ = typ.last_known_value
return mypy.checker.TypeRange(typ, is_upper_bound=False)


Expand Down
5 changes: 5 additions & 0 deletions mypy/patterns.py
Expand Up @@ -21,6 +21,7 @@ def accept(self, visitor: PatternVisitor[T]) -> T:


class AsPattern(Pattern):
"""The pattern <pattern> as <name>"""
# The python ast, and therefore also our ast merges capture, wildcard and as patterns into one
# for easier handling.
# If pattern is None this is a capture pattern. If name and pattern are both none this is a
Expand All @@ -39,6 +40,7 @@ def accept(self, visitor: PatternVisitor[T]) -> T:


class OrPattern(Pattern):
"""The pattern <pattern> | <pattern> | ..."""
patterns: List[Pattern]

def __init__(self, patterns: List[Pattern]) -> None:
Expand All @@ -50,6 +52,7 @@ def accept(self, visitor: PatternVisitor[T]) -> T:


class ValuePattern(Pattern):
"""The pattern x.y (or x.y.z, ...)"""
expr: Expression

def __init__(self, expr: Expression):
Expand All @@ -73,6 +76,7 @@ def accept(self, visitor: PatternVisitor[T]) -> T:


class SequencePattern(Pattern):
"""The pattern [<pattern>, ...]"""
patterns: List[Pattern]

def __init__(self, patterns: List[Pattern]):
Expand Down Expand Up @@ -114,6 +118,7 @@ def accept(self, visitor: PatternVisitor[T]) -> T:


class ClassPattern(Pattern):
"""The pattern Cls(...)"""
class_ref: RefExpr
positionals: List[Pattern]
keyword_keys: List[str]
Expand Down