From 8454267e1f84be3e3b2a9c95cf9c6dd810769b93 Mon Sep 17 00:00:00 2001 From: Edward Paget Date: Fri, 5 Apr 2024 16:41:52 -0700 Subject: [PATCH] Fix narrowing unions in tuple pattern matches Allows sequence pattern matching of tuple types with union members to use an unmatched pattern to narrow the possible type of union member to the type that was not matched in the sequence pattern. For example given a type of `tuple[int, int | None]` a pattern match like: ``` match tuple_type: case a, None: return case t: reveal_type(t) # narrows tuple type to tuple[int, int] return ``` The case ..., None sequence pattern match can now narrow the type in further match statements to rule out the None side of the union. This is implemented by moving the original implementation of tuple narrowing in sequence pattern matching since its functionality should be to the special case where a tuple only has length of one. This implementation does not hold for tuples of length greater than one since it does not account all combinations of alternative types. This replace that implementation with a new one that builds the rest type by iterating over the potential rest type members preserving narrowed types if they are available and replacing any uninhabited types with the original type of tuple member since these matches are only exhaustive if all members of the tuple are matched. Fixes #14731 --- mypy/checkpattern.py | 48 +++++++++++++++-------- test-data/unit/check-python310.test | 60 +++++++++++++++++++++++++++++ 2 files changed, 93 insertions(+), 15 deletions(-) diff --git a/mypy/checkpattern.py b/mypy/checkpattern.py index a23be464b825..1ef2cdd5d6f0 100644 --- a/mypy/checkpattern.py +++ b/mypy/checkpattern.py @@ -302,24 +302,42 @@ def visit_sequence_pattern(self, o: SequencePattern) -> PatternType: new_type: Type rest_type: Type = current_type if isinstance(current_type, TupleType) and unpack_index is None: - narrowed_inner_types = [] - inner_rest_types = [] - for inner_type, new_inner_type in zip(inner_types, new_inner_types): - (narrowed_inner_type, inner_rest_type) = ( - self.chk.conditional_types_with_intersection( - new_inner_type, [get_type_range(inner_type)], o, default=new_inner_type - ) - ) - narrowed_inner_types.append(narrowed_inner_type) - inner_rest_types.append(inner_rest_type) - if all(not is_uninhabited(typ) for typ in narrowed_inner_types): - new_type = TupleType(narrowed_inner_types, current_type.partial_fallback) + if all(not is_uninhabited(typ) for typ in new_inner_types): + new_type = TupleType(new_inner_types, current_type.partial_fallback) else: new_type = UninhabitedType() - if all(is_uninhabited(typ) for typ in inner_rest_types): - # All subpatterns always match, so we can apply negative narrowing - rest_type = TupleType(rest_inner_types, current_type.partial_fallback) + if all(is_uninhabited(typ) for typ in rest_inner_types): + # If all types are uninhabited there is no other pattern that can + # match this tuple + rest_type = UninhabitedType() + elif any(is_uninhabited(typ) for typ in rest_inner_types): + # If at least one rest type is uninhabited the rest type can be narrowed + narrowed_types: list[Type] = [] + for inner_type, rest_type in zip(inner_types, rest_inner_types): + # if the narrowed rest type is Uninhabited that means that that the next + # pattern could match any of the original inner types of the tuple. + if is_uninhabited(rest_type): + narrowed_types.append(inner_type) + else: + narrowed_types.append(rest_type) + rest_type = TupleType(narrowed_types, current_type.partial_fallback) + elif len(rest_inner_types) == 1: + # Otherwise we need can apply negative narrowing if the alternative type + # is totally disjoint from the pattern narrowed type. And there is only + # one field in the tuple. + narrowed_inner_types = [] + inner_rest_types = [] + for inner_type, new_inner_type in zip(inner_types, new_inner_types): + (narrowed_inner_type, inner_rest_type) = ( + self.chk.conditional_types_with_intersection( + new_inner_type, [get_type_range(inner_type)], o, default=new_inner_type + ) + ) + narrowed_inner_types.append(narrowed_inner_type) + inner_rest_types.append(inner_rest_type) + if all(is_uninhabited(typ) for typ in inner_rest_types): + rest_type = TupleType(rest_inner_types, current_type.partial_fallback) elif isinstance(current_type, TupleType): # For variadic tuples it is too tricky to match individual items like for fixed # tuples, so we instead try to narrow the entire type. diff --git a/test-data/unit/check-python310.test b/test-data/unit/check-python310.test index 2b56d2db07a9..e50618b65285 100644 --- a/test-data/unit/check-python310.test +++ b/test-data/unit/check-python310.test @@ -407,6 +407,66 @@ match m, (n, o): pass [builtins fixtures/tuple.pyi] +[case testPartialMatchTuplePatternNarrowsUnion] +from typing import Tuple, Union + +m: Tuple[Union[int, None], Union[int, None]] + +match m: + case a, None: + reveal_type(a) # N: Revealed type is "Union[builtins.int, None]" + reveal_type(m) # N: Revealed type is "Tuple[Union[builtins.int, None], None]" + case None, b: + reveal_type(b) # N: Revealed type is "builtins.int" + reveal_type(m) # N: Revealed type is "Tuple[None, builtins.int]" + case t: + reveal_type(t) # N: Revealed type is "Tuple[builtins.int, builtins.int]" +[builtins fixtures/tuple.pyi] + +[case testPartialMatchTuplePatternNarrowsUnionWithFullMatch] +from typing import Tuple, Union + +m: Tuple[Union[int, None], Union[int, None]] + +match m: + case None, None: + reveal_type(m) # N: Revealed type is "Tuple[None, None]" + case a, None: + reveal_type(a) # N: Revealed type is "Union[builtins.int, None]" + reveal_type(m) # N: Revealed type is "Tuple[Union[builtins.int, None], None]" + case None, b: + reveal_type(b) # N: Revealed type is "builtins.int" + reveal_type(m) # N: Revealed type is "Tuple[None, builtins.int]" + case t: + reveal_type(t) # N: Revealed type is "Tuple[builtins.int, builtins.int]" +[builtins fixtures/tuple.pyi] + +[case testPartialMatchTuplePatternNarrowsUnionWithUninhibitedMatch] +from typing import Tuple, Union + +m: Tuple[Union[int, None], Union[int, None]] + +match m: + case a, b: + reveal_type(m) # N: Revealed type is "Tuple[Union[builtins.int, None], Union[builtins.int, None]]" + case t: + reveal_type(t) +[builtins fixtures/tuple.pyi] + +[case testMatchTuplePatterNarrowsWithOrMatch] +from typing import Tuple, Union + +m: Tuple[Union[int, None], Union[int, None]] + +match m: + case (None, _) | (_, None): + reveal_type(m) # N: Revealed type is "Union[Tuple[None, Union[builtins.int, None]], Tuple[builtins.int, None]]" + case a, b: + reveal_type(a) # N: Revealed type is "builtins.int" + reveal_type(b) # N: Revealed type is "builtins.int" +[builtins fixtures/tuple.pyi] + + -- Mapping Pattern -- [case testMatchMappingPatternCaptures]