diff --git a/mypy/binder.py b/mypy/binder.py index c1b6862c9e6d..394c5ccf987c 100644 --- a/mypy/binder.py +++ b/mypy/binder.py @@ -5,7 +5,7 @@ from typing_extensions import DefaultDict from mypy.types import ( - Type, AnyType, PartialType, UnionType, TypeOfAny, NoneType, get_proper_type + Type, AnyType, PartialType, UnionType, TypeOfAny, NoneType, TypeGuardType, get_proper_type ) from mypy.subtypes import is_subtype from mypy.join import join_simple @@ -202,7 +202,9 @@ def update_from_options(self, frames: List[Frame]) -> bool: else: for other in resulting_values[1:]: assert other is not None - type = join_simple(self.declarations[key], type, other) + # Ignore the error about using get_proper_type(). + if not isinstance(other, TypeGuardType): # type: ignore[misc] + type = join_simple(self.declarations[key], type, other) if current_value is None or not is_same_type(type, current_value): self._put(key, type) changed = True diff --git a/test-data/unit/check-typeguard.test b/test-data/unit/check-typeguard.test index fa340cb04044..c19e3c812e29 100644 --- a/test-data/unit/check-typeguard.test +++ b/test-data/unit/check-typeguard.test @@ -82,6 +82,7 @@ def is_str_list(a: List[object]) -> TypeGuard[List[str]]: pass def main(a: List[object]): if is_str_list(a): reveal_type(a) # N: Revealed type is "builtins.list[builtins.str]" + reveal_type(a) # N: Revealed type is "builtins.list[builtins.object]" [builtins fixtures/tuple.pyi] [case testTypeGuardUnionIn] @@ -91,6 +92,7 @@ def is_foo(a: Union[int, str]) -> TypeGuard[str]: pass def main(a: Union[str, int]) -> None: if is_foo(a): reveal_type(a) # N: Revealed type is "builtins.str" + reveal_type(a) # N: Revealed type is "Union[builtins.str, builtins.int]" [builtins fixtures/tuple.pyi] [case testTypeGuardUnionOut] @@ -315,3 +317,50 @@ def coverage(obj: Any) -> bool: return True return False [builtins fixtures/classmethod.pyi] + +[case testAssignToTypeGuardedVariable1] +from typing_extensions import TypeGuard + +class A: pass +class B(A): pass + +def guard(a: A) -> TypeGuard[B]: + pass + +a = A() +if not guard(a): + a = A() +[builtins fixtures/tuple.pyi] + +[case testAssignToTypeGuardedVariable2] +from typing_extensions import TypeGuard + +class A: pass +class B: pass + +def guard(a: A) -> TypeGuard[B]: + pass + +a = A() +if not guard(a): + a = A() +[builtins fixtures/tuple.pyi] + +[case testAssignToTypeGuardedVariable3] +from typing_extensions import TypeGuard + +class A: pass +class B: pass + +def guard(a: A) -> TypeGuard[B]: + pass + +a = A() +if guard(a): + reveal_type(a) # N: Revealed type is "__main__.B" + a = B() # E: Incompatible types in assignment (expression has type "B", variable has type "A") + reveal_type(a) # N: Revealed type is "__main__.B" + a = A() + reveal_type(a) # N: Revealed type is "__main__.A" +reveal_type(a) # N: Revealed type is "__main__.A" +[builtins fixtures/tuple.pyi]