From 05f38571f1b7af46dbadaeba0f74626d3a9e07c9 Mon Sep 17 00:00:00 2001 From: Shantanu <12621235+hauntsaninja@users.noreply.github.com> Date: Mon, 13 Sep 2021 23:42:05 -0700 Subject: [PATCH] Fix type guard crashes (#11061) Fixes #11007, fixes #10899, fixes #10647 Since the initial implementation of TypeGuard, there have been several fixes quickly applied to make mypy not crash on various TypeGuard things. This includes #10496, #10683 and #11015. We'll discuss how this PR relates to each of these three changes. In particular, #10496 seems incorrect. As A5rocks discusses in #10899 , it introduces confusion between a type guarded variable and a TypeGuard[T]. This PR basically walks back that change entirely and renames TypeGuardType to TypeGuardedType to reduce that possible confusion. Now, we still have the issue that TypeGuardedTypes are getting everywhere and causing unhappiness. I see two high level solutions to this: a) Make TypeGuardedType a proper type, then delegate to the wrapped type in a bunch of type visitors and arbitrary amounts of other places where multiple types interact, and hope that we got all of them, b) Make TypeGuardedType as an improper type (as it was in the original implementation)! Similar to TypeAliasType, it's just a wrapper for another type, so we unwrap it in get_proper_type. This is the approach this PR takes. This might feel controversial, but I think it could be the better option. It also means that if we type check we won't get type guard crashes. #10683 is basically "remove call that leads to crash from the stacktrace". I think the join here (that ends up being with the wrapped type of the TypeGuardedType) is actually fine: if it's different, it tells us that the type changed, which is what we want to know. So seems fine to remove the special casing. Finally, #11015. This is the other contentious part of this PR. I liked the idea of moving the core "type guard overrides narrowing" idea into meet.py, so I kept that. But my changes ended up regressing a reveal_type testTypeGuardNestedRestrictionAny test that was added. But it's not really clear to me how that worked or really, what it tested. I tried writing a simpler version of what I thought the test was meant to test (this is testTypeGuardMultipleCondition added in this PR), but that fails on master. Anyway, this should at least fix the type guard crashes that have been coming up. --- mypy/binder.py | 16 ++-------- mypy/checker.py | 4 +-- mypy/constraints.py | 5 +--- mypy/erasetype.py | 5 +--- mypy/expandtype.py | 5 +--- mypy/fixup.py | 5 +--- mypy/indirection.py | 3 -- mypy/join.py | 5 +--- mypy/meet.py | 27 ++++++++--------- mypy/sametypes.py | 8 +---- mypy/server/astdiff.py | 5 +--- mypy/server/astmerge.py | 5 +--- mypy/server/deps.py | 5 +--- mypy/subtypes.py | 13 +------- mypy/type_visitor.py | 12 +------- mypy/typeanal.py | 5 +--- mypy/types.py | 22 ++++++-------- mypy/typetraverser.py | 5 +--- test-data/unit/check-typeguard.test | 46 ++++++++++++++++++++++++++++- 19 files changed, 84 insertions(+), 117 deletions(-) diff --git a/mypy/binder.py b/mypy/binder.py index de12d153f621..523367a7685c 100644 --- a/mypy/binder.py +++ b/mypy/binder.py @@ -5,7 +5,7 @@ from typing_extensions import DefaultDict from mypy.types import ( - Type, AnyType, PartialType, UnionType, TypeOfAny, NoneType, TypeGuardType, get_proper_type + Type, AnyType, PartialType, UnionType, TypeOfAny, NoneType, get_proper_type ) from mypy.subtypes import is_subtype from mypy.join import join_simple @@ -210,9 +210,7 @@ def update_from_options(self, frames: List[Frame]) -> bool: else: for other in resulting_values[1:]: assert other is not None - # Ignore the error about using get_proper_type(). - if not contains_type_guard(other): - type = join_simple(self.declarations[key], type, other) + type = join_simple(self.declarations[key], type, other) if current_value is None or not is_same_type(type, current_value): self._put(key, type) changed = True @@ -440,13 +438,3 @@ def get_declaration(expr: BindableExpression) -> Optional[Type]: if not isinstance(type, PartialType): return type return None - - -def contains_type_guard(other: Type) -> bool: - # Ignore the error about using get_proper_type(). - if isinstance(other, TypeGuardType): # type: ignore[misc] - return True - other = get_proper_type(other) - if isinstance(other, UnionType): - return any(contains_type_guard(item) for item in other.relevant_items()) - return False diff --git a/mypy/checker.py b/mypy/checker.py index 7c11ef939b76..af223506ecd3 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -36,7 +36,7 @@ UnionType, TypeVarId, TypeVarType, PartialType, DeletedType, UninhabitedType, is_named_instance, union_items, TypeQuery, LiteralType, is_optional, remove_optional, TypeTranslator, StarType, get_proper_type, ProperType, - get_proper_types, is_literal_type, TypeAliasType, TypeGuardType) + get_proper_types, is_literal_type, TypeAliasType, TypeGuardedType) from mypy.sametypes import is_same_type from mypy.messages import ( MessageBuilder, make_inferred_type_note, append_invariance_notes, pretty_seq, @@ -4265,7 +4265,7 @@ def find_isinstance_check_helper(self, node: Expression) -> Tuple[TypeMap, TypeM # considered "always right" (i.e. even if the types are not overlapping). # Also note that a care must be taken to unwrap this back at read places # where we use this to narrow down declared type. - return {expr: TypeGuardType(node.callee.type_guard)}, {} + return {expr: TypeGuardedType(node.callee.type_guard)}, {} elif isinstance(node, ComparisonExpr): # Step 1: Obtain the types of each operand and whether or not we can # narrow their types. (For example, we shouldn't try narrowing the diff --git a/mypy/constraints.py b/mypy/constraints.py index f85353cff1c0..d8dad95a3430 100644 --- a/mypy/constraints.py +++ b/mypy/constraints.py @@ -7,7 +7,7 @@ CallableType, Type, TypeVisitor, UnboundType, AnyType, NoneType, TypeVarType, Instance, TupleType, TypedDictType, UnionType, Overloaded, ErasedType, PartialType, DeletedType, UninhabitedType, TypeType, TypeVarId, TypeQuery, is_named_instance, TypeOfAny, LiteralType, - ProperType, get_proper_type, TypeAliasType, TypeGuardType + ProperType, get_proper_type, TypeAliasType ) from mypy.maptype import map_instance_to_supertype import mypy.subtypes @@ -544,9 +544,6 @@ def visit_union_type(self, template: UnionType) -> List[Constraint]: def visit_type_alias_type(self, template: TypeAliasType) -> List[Constraint]: assert False, "This should be never called, got {}".format(template) - def visit_type_guard_type(self, template: TypeGuardType) -> List[Constraint]: - assert False, "This should be never called, got {}".format(template) - def infer_against_any(self, types: Iterable[Type], any_type: AnyType) -> List[Constraint]: res: List[Constraint] = [] for t in types: diff --git a/mypy/erasetype.py b/mypy/erasetype.py index 70b7c3b6de32..7a56eceacf5f 100644 --- a/mypy/erasetype.py +++ b/mypy/erasetype.py @@ -4,7 +4,7 @@ Type, TypeVisitor, UnboundType, AnyType, NoneType, TypeVarId, Instance, TypeVarType, CallableType, TupleType, TypedDictType, UnionType, Overloaded, ErasedType, PartialType, DeletedType, TypeTranslator, UninhabitedType, TypeType, TypeOfAny, LiteralType, ProperType, - get_proper_type, TypeAliasType, TypeGuardType + get_proper_type, TypeAliasType ) from mypy.nodes import ARG_STAR, ARG_STAR2 @@ -90,9 +90,6 @@ def visit_union_type(self, t: UnionType) -> ProperType: from mypy.typeops import make_simplified_union return make_simplified_union(erased_items) - def visit_type_guard_type(self, t: TypeGuardType) -> ProperType: - return TypeGuardType(t.type_guard.accept(self)) - def visit_type_type(self, t: TypeType) -> ProperType: return TypeType.make_normalized(t.item.accept(self), line=t.line) diff --git a/mypy/expandtype.py b/mypy/expandtype.py index abb5a0e6836b..8b7d434d1e31 100644 --- a/mypy/expandtype.py +++ b/mypy/expandtype.py @@ -1,7 +1,7 @@ from typing import Dict, Iterable, List, TypeVar, Mapping, cast from mypy.types import ( - Type, Instance, CallableType, TypeGuardType, TypeVisitor, UnboundType, AnyType, + Type, Instance, CallableType, TypeVisitor, UnboundType, AnyType, NoneType, TypeVarType, Overloaded, TupleType, TypedDictType, UnionType, ErasedType, PartialType, DeletedType, UninhabitedType, TypeType, TypeVarId, FunctionLike, TypeVarType, LiteralType, get_proper_type, ProperType, @@ -129,9 +129,6 @@ def visit_union_type(self, t: UnionType) -> Type: from mypy.typeops import make_simplified_union # asdf return make_simplified_union(self.expand_types(t.items), t.line, t.column) - def visit_type_guard_type(self, t: TypeGuardType) -> ProperType: - return TypeGuardType(t.type_guard.accept(self)) - def visit_partial_type(self, t: PartialType) -> Type: return t diff --git a/mypy/fixup.py b/mypy/fixup.py index 2ad9fe5e8fcb..b8c494292d28 100644 --- a/mypy/fixup.py +++ b/mypy/fixup.py @@ -9,7 +9,7 @@ TypeVarExpr, ClassDef, Block, TypeAlias, ) from mypy.types import ( - CallableType, Instance, Overloaded, TupleType, TypeGuardType, TypedDictType, + CallableType, Instance, Overloaded, TupleType, TypedDictType, TypeVarType, UnboundType, UnionType, TypeVisitor, LiteralType, TypeType, NOT_READY, TypeAliasType, AnyType, TypeOfAny ) @@ -254,9 +254,6 @@ def visit_union_type(self, ut: UnionType) -> None: for it in ut.items: it.accept(self) - def visit_type_guard_type(self, t: TypeGuardType) -> None: - t.type_guard.accept(self) - def visit_void(self, o: Any) -> None: pass # Nothing to descend into. diff --git a/mypy/indirection.py b/mypy/indirection.py index 9d47d9af7889..96992285c90f 100644 --- a/mypy/indirection.py +++ b/mypy/indirection.py @@ -97,9 +97,6 @@ def visit_literal_type(self, t: types.LiteralType) -> Set[str]: def visit_union_type(self, t: types.UnionType) -> Set[str]: return self._visit(t.items) - def visit_type_guard_type(self, t: types.TypeGuardType) -> Set[str]: - return self._visit(t.type_guard) - def visit_partial_type(self, t: types.PartialType) -> Set[str]: return set() diff --git a/mypy/join.py b/mypy/join.py index edf512e1be11..e18bc6281fde 100644 --- a/mypy/join.py +++ b/mypy/join.py @@ -7,7 +7,7 @@ Type, AnyType, NoneType, TypeVisitor, Instance, UnboundType, TypeVarType, CallableType, TupleType, TypedDictType, ErasedType, UnionType, FunctionLike, Overloaded, LiteralType, PartialType, DeletedType, UninhabitedType, TypeType, TypeOfAny, get_proper_type, - ProperType, get_proper_types, TypeAliasType, PlaceholderType, TypeGuardType + ProperType, get_proper_types, TypeAliasType, PlaceholderType ) from mypy.maptype import map_instance_to_supertype from mypy.subtypes import ( @@ -432,9 +432,6 @@ def visit_type_type(self, t: TypeType) -> ProperType: def visit_type_alias_type(self, t: TypeAliasType) -> ProperType: assert False, "This should be never called, got {}".format(t) - def visit_type_guard_type(self, t: TypeGuardType) -> ProperType: - assert False, "This should be never called, got {}".format(t) - def join(self, s: Type, t: Type) -> ProperType: return join_types(s, t) diff --git a/mypy/meet.py b/mypy/meet.py index f94ec4e65bf7..f89c1fc7b16f 100644 --- a/mypy/meet.py +++ b/mypy/meet.py @@ -5,7 +5,7 @@ Type, AnyType, TypeVisitor, UnboundType, NoneType, TypeVarType, Instance, CallableType, TupleType, TypedDictType, ErasedType, UnionType, PartialType, DeletedType, UninhabitedType, TypeType, TypeOfAny, Overloaded, FunctionLike, LiteralType, - ProperType, get_proper_type, get_proper_types, TypeAliasType, TypeGuardType + ProperType, get_proper_type, get_proper_types, TypeAliasType, TypeGuardedType ) from mypy.subtypes import is_equivalent, is_subtype, is_callable_compatible, is_proper_subtype from mypy.erasetype import erase_type @@ -51,16 +51,16 @@ def meet_types(s: Type, t: Type) -> ProperType: def narrow_declared_type(declared: Type, narrowed: Type) -> Type: """Return the declared type narrowed down to another type.""" # TODO: check infinite recursion for aliases here. + if isinstance(narrowed, TypeGuardedType): # type: ignore[misc] + # A type guard forces the new type even if it doesn't overlap the old. + return narrowed.type_guard + declared = get_proper_type(declared) narrowed = get_proper_type(narrowed) if declared == narrowed: return declared - # Ignore the error about using get_proper_type(). - if isinstance(narrowed, TypeGuardType): - # A type guard forces the new type even if it doesn't overlap the old. - return narrowed.type_guard - elif isinstance(declared, UnionType): + if isinstance(declared, UnionType): return make_simplified_union([narrow_declared_type(x, narrowed) for x in declared.relevant_items()]) elif not is_overlapping_types(declared, narrowed, @@ -146,6 +146,13 @@ def is_overlapping_types(left: Type, If 'prohibit_none_typevar_overlap' is True, we disallow None from overlapping with TypeVars (in both strict-optional and non-strict-optional mode). """ + if ( + isinstance(left, TypeGuardedType) # type: ignore[misc] + or isinstance(right, TypeGuardedType) # type: ignore[misc] + ): + # A type guard forces the new type even if it doesn't overlap the old. + return True + left, right = get_proper_types((left, right)) def _is_overlapping_types(left: Type, right: Type) -> bool: @@ -161,11 +168,6 @@ def _is_overlapping_types(left: Type, right: Type) -> bool: if isinstance(left, PartialType) or isinstance(right, PartialType): assert False, "Unexpectedly encountered partial type" - # Ignore the error about using get_proper_type(). - if isinstance(left, TypeGuardType) or isinstance(right, TypeGuardType): - # A type guard forces the new type even if it doesn't overlap the old. - return True - # We should also never encounter these types, but it's possible a few # have snuck through due to unrelated bugs. For now, we handle these # in the same way we handle 'Any'. @@ -657,9 +659,6 @@ def visit_type_type(self, t: TypeType) -> ProperType: def visit_type_alias_type(self, t: TypeAliasType) -> ProperType: assert False, "This should be never called, got {}".format(t) - def visit_type_guard_type(self, t: TypeGuardType) -> ProperType: - assert False, "This should be never called, got {}".format(t) - def meet(self, s: Type, t: Type) -> ProperType: return meet_types(s, t) diff --git a/mypy/sametypes.py b/mypy/sametypes.py index d389b8fd2581..020bda775b59 100644 --- a/mypy/sametypes.py +++ b/mypy/sametypes.py @@ -1,7 +1,7 @@ from typing import Sequence from mypy.types import ( - Type, TypeGuardType, UnboundType, AnyType, NoneType, TupleType, TypedDictType, + Type, UnboundType, AnyType, NoneType, TupleType, TypedDictType, UnionType, CallableType, TypeVarType, Instance, TypeVisitor, ErasedType, Overloaded, PartialType, DeletedType, UninhabitedType, TypeType, LiteralType, ProperType, get_proper_type, TypeAliasType) @@ -151,12 +151,6 @@ def visit_union_type(self, left: UnionType) -> bool: else: return False - def visit_type_guard_type(self, left: TypeGuardType) -> bool: - if isinstance(self.right, TypeGuardType): - return is_same_type(left.type_guard, self.right.type_guard) - else: - return False - def visit_overloaded(self, left: Overloaded) -> bool: if isinstance(self.right, Overloaded): return is_same_types(left.items, self.right.items) diff --git a/mypy/server/astdiff.py b/mypy/server/astdiff.py index ac3f0d35aace..789c7c909911 100644 --- a/mypy/server/astdiff.py +++ b/mypy/server/astdiff.py @@ -57,7 +57,7 @@ class level -- these are handled at attribute level (say, 'mod.Cls.method' FuncBase, OverloadedFuncDef, FuncItem, MypyFile, UNBOUND_IMPORTED ) from mypy.types import ( - Type, TypeGuardType, TypeVisitor, UnboundType, AnyType, NoneType, UninhabitedType, + Type, TypeVisitor, UnboundType, AnyType, NoneType, UninhabitedType, ErasedType, DeletedType, Instance, TypeVarType, CallableType, TupleType, TypedDictType, UnionType, Overloaded, PartialType, TypeType, LiteralType, TypeAliasType ) @@ -335,9 +335,6 @@ def visit_union_type(self, typ: UnionType) -> SnapshotItem: normalized = tuple(sorted(items)) return ('UnionType', normalized) - def visit_type_guard_type(self, typ: TypeGuardType) -> SnapshotItem: - return ('TypeGuardType', snapshot_type(typ.type_guard)) - def visit_overloaded(self, typ: Overloaded) -> SnapshotItem: return ('Overloaded', snapshot_types(typ.items)) diff --git a/mypy/server/astmerge.py b/mypy/server/astmerge.py index 107be94987cb..5d6a810264b6 100644 --- a/mypy/server/astmerge.py +++ b/mypy/server/astmerge.py @@ -59,7 +59,7 @@ Type, SyntheticTypeVisitor, Instance, AnyType, NoneType, CallableType, ErasedType, DeletedType, TupleType, TypeType, TypeVarType, TypedDictType, UnboundType, UninhabitedType, UnionType, Overloaded, TypeVarType, TypeList, CallableArgument, EllipsisType, StarType, LiteralType, - RawExpressionType, PartialType, PlaceholderType, TypeAliasType, TypeGuardType + RawExpressionType, PartialType, PlaceholderType, TypeAliasType ) from mypy.util import get_prefix, replace_object_state from mypy.typestate import TypeState @@ -389,9 +389,6 @@ def visit_erased_type(self, t: ErasedType) -> None: def visit_deleted_type(self, typ: DeletedType) -> None: pass - def visit_type_guard_type(self, typ: TypeGuardType) -> None: - raise RuntimeError - def visit_partial_type(self, typ: PartialType) -> None: raise RuntimeError diff --git a/mypy/server/deps.py b/mypy/server/deps.py index c0d504c080c1..f80673fdb7d4 100644 --- a/mypy/server/deps.py +++ b/mypy/server/deps.py @@ -99,7 +99,7 @@ class 'mod.Cls'. This can also refer to an attribute inherited from a Type, Instance, AnyType, NoneType, TypeVisitor, CallableType, DeletedType, PartialType, TupleType, TypeType, TypeVarType, TypedDictType, UnboundType, UninhabitedType, UnionType, FunctionLike, Overloaded, TypeOfAny, LiteralType, ErasedType, get_proper_type, ProperType, - TypeAliasType, TypeGuardType + TypeAliasType ) from mypy.server.trigger import make_trigger, make_wildcard_trigger from mypy.util import correct_relative_import @@ -967,9 +967,6 @@ def visit_unbound_type(self, typ: UnboundType) -> List[str]: def visit_uninhabited_type(self, typ: UninhabitedType) -> List[str]: return [] - def visit_type_guard_type(self, typ: TypeGuardType) -> List[str]: - return typ.type_guard.accept(self) - def visit_union_type(self, typ: UnionType) -> List[str]: triggers = [] for item in typ.items: diff --git a/mypy/subtypes.py b/mypy/subtypes.py index 0afa29d00527..4ac99d3fa8e2 100644 --- a/mypy/subtypes.py +++ b/mypy/subtypes.py @@ -4,7 +4,7 @@ from typing_extensions import Final from mypy.types import ( - Type, AnyType, TypeGuardType, UnboundType, TypeVisitor, FormalArgument, NoneType, + Type, AnyType, UnboundType, TypeVisitor, FormalArgument, NoneType, Instance, TypeVarType, CallableType, TupleType, TypedDictType, UnionType, Overloaded, ErasedType, PartialType, DeletedType, UninhabitedType, TypeType, is_named_instance, FunctionLike, TypeOfAny, LiteralType, get_proper_type, TypeAliasType @@ -475,9 +475,6 @@ def visit_overloaded(self, left: Overloaded) -> bool: def visit_union_type(self, left: UnionType) -> bool: return all(self._is_subtype(item, self.orig_right) for item in left.items) - def visit_type_guard_type(self, left: TypeGuardType) -> bool: - raise RuntimeError("TypeGuard should not appear here") - def visit_partial_type(self, left: PartialType) -> bool: # This is indeterminate as we don't really know the complete type yet. raise RuntimeError @@ -1377,14 +1374,6 @@ def visit_overloaded(self, left: Overloaded) -> bool: def visit_union_type(self, left: UnionType) -> bool: return all([self._is_proper_subtype(item, self.orig_right) for item in left.items]) - def visit_type_guard_type(self, left: TypeGuardType) -> bool: - if isinstance(self.right, TypeGuardType): - # TypeGuard[bool] is a subtype of TypeGuard[int] - return self._is_proper_subtype(left.type_guard, self.right.type_guard) - else: - # TypeGuards aren't a subtype of anything else for now (but see #10489) - return False - def visit_partial_type(self, left: PartialType) -> bool: # TODO: What's the right thing to do here? return False diff --git a/mypy/type_visitor.py b/mypy/type_visitor.py index 019709dc5e26..2b4ebffb93e0 100644 --- a/mypy/type_visitor.py +++ b/mypy/type_visitor.py @@ -19,7 +19,7 @@ T = TypeVar('T') from mypy.types import ( - Type, AnyType, CallableType, Overloaded, TupleType, TypeGuardType, TypedDictType, LiteralType, + Type, AnyType, CallableType, Overloaded, TupleType, TypedDictType, LiteralType, RawExpressionType, Instance, NoneType, TypeType, UnionType, TypeVarType, PartialType, DeletedType, UninhabitedType, TypeVarLikeType, UnboundType, ErasedType, StarType, EllipsisType, TypeList, CallableArgument, @@ -103,10 +103,6 @@ def visit_type_type(self, t: TypeType) -> T: def visit_type_alias_type(self, t: TypeAliasType) -> T: pass - @abstractmethod - def visit_type_guard_type(self, t: TypeGuardType) -> T: - pass - @trait @mypyc_attr(allow_interpreted_subclasses=True) @@ -224,9 +220,6 @@ def visit_union_type(self, t: UnionType) -> Type: def translate_types(self, types: Iterable[Type]) -> List[Type]: return [t.accept(self) for t in types] - def visit_type_guard_type(self, t: TypeGuardType) -> Type: - return TypeGuardType(t.type_guard.accept(self)) - def translate_variables(self, variables: Sequence[TypeVarLikeType]) -> Sequence[TypeVarLikeType]: return variables @@ -326,9 +319,6 @@ def visit_star_type(self, t: StarType) -> T: def visit_union_type(self, t: UnionType) -> T: return self.query_types(t.items) - def visit_type_guard_type(self, t: TypeGuardType) -> T: - return t.type_guard.accept(self) - def visit_overloaded(self, t: Overloaded) -> T: return self.query_types(t.items) diff --git a/mypy/typeanal.py b/mypy/typeanal.py index 07dc704e42ea..6e2cb350f6e6 100644 --- a/mypy/typeanal.py +++ b/mypy/typeanal.py @@ -14,7 +14,7 @@ from mypy.types import ( Type, UnboundType, TypeVarType, TupleType, TypedDictType, UnionType, Instance, AnyType, CallableType, NoneType, ErasedType, DeletedType, TypeList, TypeVarType, SyntheticTypeVisitor, - StarType, PartialType, EllipsisType, UninhabitedType, TypeType, TypeGuardType, TypeVarLikeType, + StarType, PartialType, EllipsisType, UninhabitedType, TypeType, TypeVarLikeType, CallableArgument, TypeQuery, union_items, TypeOfAny, LiteralType, RawExpressionType, PlaceholderType, Overloaded, get_proper_type, TypeAliasType, TypeVarLikeType, ParamSpecType ) @@ -547,9 +547,6 @@ def visit_callable_type(self, t: CallableType, nested: bool = True) -> Type: ) return ret - def visit_type_guard_type(self, t: TypeGuardType) -> Type: - return t - def anal_type_guard(self, t: Type) -> Optional[Type]: if isinstance(t, UnboundType): sym = self.lookup_qualified(t.name, t) diff --git a/mypy/types.py b/mypy/types.py index e16ea858efe3..685d6af26db9 100644 --- a/mypy/types.py +++ b/mypy/types.py @@ -276,14 +276,7 @@ def copy_modified(self, *, self.line, self.column) -class ProperType(Type): - """Not a type alias. - - Every type except TypeAliasType must inherit from this type. - """ - - -class TypeGuardType(ProperType): +class TypeGuardedType(Type): """Only used by find_instance_check() etc.""" def __init__(self, type_guard: Type): super().__init__(line=type_guard.line, column=type_guard.column) @@ -292,8 +285,12 @@ def __init__(self, type_guard: Type): def __repr__(self) -> str: return "TypeGuard({})".format(self.type_guard) - def accept(self, visitor: 'TypeVisitor[T]') -> T: - return visitor.visit_type_guard_type(self) + +class ProperType(Type): + """Not a type alias. + + Every type except TypeAliasType must inherit from this type. + """ class TypeVarId: @@ -1947,6 +1944,8 @@ def get_proper_type(typ: Optional[Type]) -> Optional[ProperType]: """ if typ is None: return None + if isinstance(typ, TypeGuardedType): # type: ignore[misc] + typ = typ.type_guard while isinstance(typ, TypeAliasType): typ = typ._expand_once() assert isinstance(typ, ProperType), typ @@ -2146,9 +2145,6 @@ def visit_union_type(self, t: UnionType) -> str: s = self.list_str(t.items) return 'Union[{}]'.format(s) - def visit_type_guard_type(self, t: TypeGuardType) -> str: - return 'TypeGuard[{}]'.format(t.type_guard.accept(self)) - def visit_partial_type(self, t: PartialType) -> str: if t.type is None: return '' diff --git a/mypy/typetraverser.py b/mypy/typetraverser.py index 174e97a93c31..3bebd3831971 100644 --- a/mypy/typetraverser.py +++ b/mypy/typetraverser.py @@ -6,7 +6,7 @@ Type, SyntheticTypeVisitor, AnyType, UninhabitedType, NoneType, ErasedType, DeletedType, TypeVarType, LiteralType, Instance, CallableType, TupleType, TypedDictType, UnionType, Overloaded, TypeType, CallableArgument, UnboundType, TypeList, StarType, EllipsisType, - PlaceholderType, PartialType, RawExpressionType, TypeAliasType, TypeGuardType + PlaceholderType, PartialType, RawExpressionType, TypeAliasType ) @@ -62,9 +62,6 @@ def visit_typeddict_type(self, t: TypedDictType) -> None: def visit_union_type(self, t: UnionType) -> None: self.traverse_types(t.items) - def visit_type_guard_type(self, t: TypeGuardType) -> None: - t.type_guard.accept(self) - def visit_overloaded(self, t: Overloaded) -> None: self.traverse_types(t.items) diff --git a/test-data/unit/check-typeguard.test b/test-data/unit/check-typeguard.test index c4f88ca3f018..fb26f0d3d537 100644 --- a/test-data/unit/check-typeguard.test +++ b/test-data/unit/check-typeguard.test @@ -401,7 +401,27 @@ def test(x: object) -> None: g(reveal_type(x)) # N: Revealed type is "Union[__main__.A, __main__.B]" [builtins fixtures/tuple.pyi] -[case testTypeGuardNestedRestrictionUnionIsInstance] +[case testTypeGuardComprehensionSubtype] +from typing import List +from typing_extensions import TypeGuard + +class Base: ... +class Foo(Base): ... +class Bar(Base): ... + +def is_foo(item: object) -> TypeGuard[Foo]: + return isinstance(item, Foo) + +def is_bar(item: object) -> TypeGuard[Bar]: + return isinstance(item, Bar) + +def foobar(items: List[object]): + a: List[Base] = [x for x in items if is_foo(x) or is_bar(x)] + b: List[Base] = [x for x in items if is_foo(x)] + c: List[Bar] = [x for x in items if is_foo(x)] # E: List comprehension has incompatible type List[Foo]; expected List[Bar] +[builtins fixtures/tuple.pyi] + +[case testTypeGuardNestedRestrictionUnionIsInstance-xfail] from typing_extensions import TypeGuard from typing import Any, List @@ -414,3 +434,27 @@ def test(x: List[object]) -> None: return g(reveal_type(x)) # N: Revealed type is "Union[builtins.list[builtins.str], __main__.]" [builtins fixtures/tuple.pyi] + +[case testTypeGuardMultipleCondition-xfail] +from typing_extensions import TypeGuard +from typing import Any, List + +class Foo: ... +class Bar: ... + +def is_foo(item: object) -> TypeGuard[Foo]: + return isinstance(item, Foo) + +def is_bar(item: object) -> TypeGuard[Bar]: + return isinstance(item, Bar) + +def foobar(x: object): + if not isinstance(x, Foo) or not isinstance(x, Bar): + return + reveal_type(x) # N: Revealed type is "__main__." + +def foobar_typeguard(x: object): + if not is_foo(x) or not is_bar(x): + return + reveal_type(x) # N: Revealed type is "__main__." +[builtins fixtures/tuple.pyi]