Skip to content

Commit

Permalink
Exhaustiveness checking for match statements (python#12267)
Browse files Browse the repository at this point in the history
Closes python#12010.

Mypy can now detect if a match statement covers all the possible values.
Example:
```
def f(x: int | str) -> int:
    match x:
        case str():
            return 0
        case int():
            return 1
    # Mypy knows that we can't reach here
```

Most of the work was done by @freundTech. I did various minor updates
and changes to tests.

This doesn't handle some cases properly, including these:
1. We don't recognize that `match [*args]` fully covers a list type
2. Fake intersections don't work quite right (some tests are skipped)
3. We assume enums don't have custom `__eq__` methods

Co-authored-by: Adrian Freund <adrian@freund.io>
  • Loading branch information
2 people authored and cdce8p committed Mar 10, 2022
1 parent 0fff609 commit 7d81e4e
Show file tree
Hide file tree
Showing 4 changed files with 520 additions and 142 deletions.
85 changes: 62 additions & 23 deletions mypy/checker.py
Expand Up @@ -4089,36 +4089,57 @@ def visit_match_stmt(self, s: MatchStmt) -> None:
if isinstance(subject_type, DeletedType):
self.msg.deleted_as_rvalue(subject_type, s)

# We infer types of patterns twice. The first pass is used
# to infer the types of capture variables. The type of a
# capture variable may depend on multiple patterns (it
# will be a union of all capture types). This pass ignores
# guard expressions.
pattern_types = [self.pattern_checker.accept(p, subject_type) for p in s.patterns]

type_maps: List[TypeMap] = [t.captures for t in pattern_types]
self.infer_variable_types_from_type_maps(type_maps)
inferred_types = self.infer_variable_types_from_type_maps(type_maps)

for pattern_type, g, b in zip(pattern_types, s.guards, s.bodies):
# The second pass narrows down the types and type checks bodies.
for p, g, b in zip(s.patterns, s.guards, s.bodies):
current_subject_type = self.expr_checker.narrow_type_from_binder(s.subject,
subject_type)
pattern_type = self.pattern_checker.accept(p, current_subject_type)
with self.binder.frame_context(can_skip=True, fall_through=2):
if b.is_unreachable or isinstance(get_proper_type(pattern_type.type),
UninhabitedType):
self.push_type_map(None)
else_map: TypeMap = {}
else:
self.binder.put(s.subject, pattern_type.type)
pattern_map, else_map = conditional_types_to_typemaps(
s.subject,
pattern_type.type,
pattern_type.rest_type
)
self.remove_capture_conflicts(pattern_type.captures,
inferred_types)
self.push_type_map(pattern_map)
self.push_type_map(pattern_type.captures)
if g is not None:
gt = get_proper_type(self.expr_checker.accept(g))
with self.binder.frame_context(can_skip=True, fall_through=3):
gt = get_proper_type(self.expr_checker.accept(g))

if isinstance(gt, DeletedType):
self.msg.deleted_as_rvalue(gt, s)
if isinstance(gt, DeletedType):
self.msg.deleted_as_rvalue(gt, s)

if_map, _ = self.find_isinstance_check(g)
guard_map, guard_else_map = self.find_isinstance_check(g)
else_map = or_conditional_maps(else_map, guard_else_map)

self.push_type_map(if_map)
self.accept(b)
self.push_type_map(guard_map)
self.accept(b)
else:
self.accept(b)
self.push_type_map(else_map)

# This is needed due to a quirk in frame_context. Without it types will stay narrowed
# after the match.
with self.binder.frame_context(can_skip=False, fall_through=2):
pass

def infer_variable_types_from_type_maps(self, type_maps: List[TypeMap]) -> None:
def infer_variable_types_from_type_maps(self, type_maps: List[TypeMap]) -> Dict[Var, Type]:
all_captures: Dict[Var, List[Tuple[NameExpr, Type]]] = defaultdict(list)
for tm in type_maps:
if tm is not None:
Expand All @@ -4128,28 +4149,38 @@ def infer_variable_types_from_type_maps(self, type_maps: List[TypeMap]) -> None:
assert isinstance(node, Var)
all_captures[node].append((expr, typ))

inferred_types: Dict[Var, Type] = {}
for var, captures in all_captures.items():
conflict = False
already_exists = False
types: List[Type] = []
for expr, typ in captures:
types.append(typ)

previous_type, _, inferred = self.check_lvalue(expr)
previous_type, _, _ = self.check_lvalue(expr)
if previous_type is not None:
conflict = True
self.check_subtype(typ, previous_type, expr,
msg=message_registry.INCOMPATIBLE_TYPES_IN_CAPTURE,
subtype_label="pattern captures type",
supertype_label="variable has type")
for type_map in type_maps:
if type_map is not None and expr in type_map:
del type_map[expr]

if not conflict:
already_exists = True
if self.check_subtype(typ, previous_type, expr,
msg=message_registry.INCOMPATIBLE_TYPES_IN_CAPTURE,
subtype_label="pattern captures type",
supertype_label="variable has type"):
inferred_types[var] = previous_type

if not already_exists:
new_type = UnionType.make_union(types)
# Infer the union type at the first occurrence
first_occurrence, _ = captures[0]
inferred_types[var] = new_type
self.infer_variable_type(var, first_occurrence, new_type, first_occurrence)
return inferred_types

def remove_capture_conflicts(self, type_map: TypeMap, inferred_types: Dict[Var, Type]) -> None:
if type_map:
for expr, typ in list(type_map.items()):
if isinstance(expr, NameExpr):
node = expr.node
assert isinstance(node, Var)
if node not in inferred_types or not is_subtype(typ, inferred_types[node]):
del type_map[expr]

def make_fake_typeinfo(self,
curr_module_fullname: str,
Expand Down Expand Up @@ -5638,6 +5669,14 @@ def conditional_types(current_type: Type,
None means no new information can be inferred. If default is set it is returned
instead."""
if proposed_type_ranges:
if len(proposed_type_ranges) == 1:
target = proposed_type_ranges[0].item
target = get_proper_type(target)
if isinstance(target, LiteralType) and (target.is_enum_literal()
or isinstance(target.value, bool)):
enum_name = target.fallback.type.fullname
current_type = try_expanding_sum_type_to_union(current_type,
enum_name)
proposed_items = [type_range.item for type_range in proposed_type_ranges]
proposed_type = make_simplified_union(proposed_items)
if isinstance(proposed_type, AnyType):
Expand Down
38 changes: 30 additions & 8 deletions mypy/checkpattern.py
@@ -1,4 +1,5 @@
"""Pattern checker. This file is conceptually part of TypeChecker."""

from collections import defaultdict
from typing import List, Optional, Tuple, Dict, NamedTuple, Set, Union
from typing_extensions import Final
Expand All @@ -19,7 +20,8 @@
)
from mypy.plugin import Plugin
from mypy.subtypes import is_subtype
from mypy.typeops import try_getting_str_literals_from_type, make_simplified_union
from mypy.typeops import try_getting_str_literals_from_type, make_simplified_union, \
coerce_to_literal
from mypy.types import (
ProperType, AnyType, TypeOfAny, Instance, Type, UninhabitedType, get_proper_type,
TypedDictType, TupleType, NoneType, UnionType
Expand Down Expand Up @@ -55,7 +57,7 @@
'PatternType',
[
('type', Type), # The type the match subject can be narrowed to
('rest_type', Type), # For exhaustiveness checking. Not used yet
('rest_type', Type), # The remaining type if the pattern didn't match
('captures', Dict[Expression, Type]), # The variables captured by the pattern
])

Expand Down Expand Up @@ -177,6 +179,7 @@ def visit_or_pattern(self, o: OrPattern) -> PatternType:
def visit_value_pattern(self, o: ValuePattern) -> PatternType:
current_type = self.type_context[-1]
typ = self.chk.expr_checker.accept(o.expr)
typ = coerce_to_literal(typ)
narrowed_type, rest_type = self.chk.conditional_types_with_intersection(
current_type,
[get_type_range(typ)],
Expand Down Expand Up @@ -259,6 +262,9 @@ def visit_sequence_pattern(self, o: SequencePattern) -> PatternType:
new_inner_types = self.expand_starred_pattern_types(contracted_new_inner_types,
star_position,
len(inner_types))
rest_inner_types = self.expand_starred_pattern_types(contracted_rest_inner_types,
star_position,
len(inner_types))

#
# Calculate new type
Expand Down Expand Up @@ -287,15 +293,20 @@ def visit_sequence_pattern(self, o: SequencePattern) -> PatternType:

if all(is_uninhabited(typ) for typ in inner_rest_types):
# All subpatterns always match, so we can apply negative narrowing
new_type, rest_type = self.chk.conditional_types_with_intersection(
current_type, [get_type_range(new_type)], o, default=current_type
)
rest_type = TupleType(rest_inner_types, current_type.partial_fallback)
else:
new_inner_type = UninhabitedType()
for typ in new_inner_types:
new_inner_type = join_types(new_inner_type, typ)
new_type = self.construct_sequence_child(current_type, new_inner_type)
if not is_subtype(new_type, current_type):
if is_subtype(new_type, current_type):
new_type, _ = self.chk.conditional_types_with_intersection(
current_type,
[get_type_range(new_type)],
o,
default=current_type
)
else:
new_type = current_type
return PatternType(new_type, rest_type, captures)

Expand Down Expand Up @@ -344,8 +355,7 @@ def expand_starred_pattern_types(self,
star_pos: Optional[int],
num_types: int
) -> List[Type]:
"""
Undoes the contraction done by contract_starred_pattern_types.
"""Undoes the contraction done by contract_starred_pattern_types.
For example if the sequence pattern is [a, *b, c] and types [bool, int, str] are extended
to lenght 4 the result is [bool, int, int, str].
Expand Down Expand Up @@ -639,6 +649,13 @@ def construct_sequence_child(self, outer_type: Type, inner_type: Type) -> Type:
For example:
construct_sequence_child(List[int], str) = List[str]
"""
proper_type = get_proper_type(outer_type)
if isinstance(proper_type, UnionType):
types = [
self.construct_sequence_child(item, inner_type) for item in proper_type.items
if self.can_match_sequence(get_proper_type(item))
]
return make_simplified_union(types)
sequence = self.chk.named_generic_type("typing.Sequence", [inner_type])
if is_subtype(outer_type, self.chk.named_type("typing.Sequence")):
proper_type = get_proper_type(outer_type)
Expand Down Expand Up @@ -676,6 +693,11 @@ def get_var(expr: Expression) -> Var:


def get_type_range(typ: Type) -> 'mypy.checker.TypeRange':
typ = get_proper_type(typ)
if (isinstance(typ, Instance)
and typ.last_known_value
and isinstance(typ.last_known_value.value, bool)):
typ = typ.last_known_value
return mypy.checker.TypeRange(typ, is_upper_bound=False)


Expand Down
5 changes: 5 additions & 0 deletions mypy/patterns.py
Expand Up @@ -21,6 +21,7 @@ def accept(self, visitor: PatternVisitor[T]) -> T:


class AsPattern(Pattern):
"""The pattern <pattern> as <name>"""
# The python ast, and therefore also our ast merges capture, wildcard and as patterns into one
# for easier handling.
# If pattern is None this is a capture pattern. If name and pattern are both none this is a
Expand All @@ -39,6 +40,7 @@ def accept(self, visitor: PatternVisitor[T]) -> T:


class OrPattern(Pattern):
"""The pattern <pattern> | <pattern> | ..."""
patterns: List[Pattern]

def __init__(self, patterns: List[Pattern]) -> None:
Expand All @@ -50,6 +52,7 @@ def accept(self, visitor: PatternVisitor[T]) -> T:


class ValuePattern(Pattern):
"""The pattern x.y (or x.y.z, ...)"""
expr: Expression

def __init__(self, expr: Expression):
Expand All @@ -73,6 +76,7 @@ def accept(self, visitor: PatternVisitor[T]) -> T:


class SequencePattern(Pattern):
"""The pattern [<pattern>, ...]"""
patterns: List[Pattern]

def __init__(self, patterns: List[Pattern]):
Expand Down Expand Up @@ -114,6 +118,7 @@ def accept(self, visitor: PatternVisitor[T]) -> T:


class ClassPattern(Pattern):
"""The pattern Cls(...)"""
class_ref: RefExpr
positionals: List[Pattern]
keyword_keys: List[str]
Expand Down

0 comments on commit 7d81e4e

Please sign in to comment.