Skip to content

Commit

Permalink
Fixes generic inference in functions with TypeGuard (#11797)
Browse files Browse the repository at this point in the history
Fixes #11780, fixes #11428
  • Loading branch information
sobolevn committed May 8, 2022
1 parent 49d5cc9 commit fb11c98
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 9 deletions.
7 changes: 7 additions & 0 deletions mypy/applytype.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,11 +103,18 @@ def apply_generic_arguments(
# Apply arguments to argument types.
arg_types = [expand_type(at, id_to_type) for at in callable.arg_types]

# Apply arguments to TypeGuard if any.
if callable.type_guard is not None:
type_guard = expand_type(callable.type_guard, id_to_type)
else:
type_guard = None

# The callable may retain some type vars if only some were applied.
remaining_tvars = [tv for tv in tvars if tv.id not in id_to_type]

return callable.copy_modified(
arg_types=arg_types,
ret_type=expand_type(callable.ret_type, id_to_type),
variables=remaining_tvars,
type_guard=type_guard,
)
22 changes: 13 additions & 9 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,11 +344,6 @@ def visit_call_expr_inner(self, e: CallExpr, allow_none_return: bool = False) ->
ret_type=self.object_type(),
fallback=self.named_type('builtins.function'))
callee_type = get_proper_type(self.accept(e.callee, type_context, always_allow_any=True))
if (isinstance(e.callee, RefExpr)
and isinstance(callee_type, CallableType)
and callee_type.type_guard is not None):
# Cache it for find_isinstance_check()
e.callee.type_guard = callee_type.type_guard
if (self.chk.options.disallow_untyped_calls and
self.chk.in_checked_function() and
isinstance(callee_type, CallableType)
Expand Down Expand Up @@ -886,10 +881,19 @@ def check_call_expr_with_callee_type(self,
# Unions are special-cased to allow plugins to act on each item in the union.
elif member is not None and isinstance(object_type, UnionType):
return self.check_union_call_expr(e, object_type, member)
return self.check_call(callee_type, e.args, e.arg_kinds, e,
e.arg_names, callable_node=e.callee,
callable_name=callable_name,
object_type=object_type)[0]
ret_type, callee_type = self.check_call(
callee_type, e.args, e.arg_kinds, e,
e.arg_names, callable_node=e.callee,
callable_name=callable_name,
object_type=object_type,
)
proper_callee = get_proper_type(callee_type)
if (isinstance(e.callee, RefExpr)
and isinstance(proper_callee, CallableType)
and proper_callee.type_guard is not None):
# Cache it for find_isinstance_check()
e.callee.type_guard = proper_callee.type_guard
return ret_type

def check_union_call_expr(self, e: CallExpr, object_type: UnionType, member: str) -> Type:
""""Type check calling a member expression where the base type is a union."""
Expand Down
50 changes: 50 additions & 0 deletions test-data/unit/check-typeguard.test
Original file line number Diff line number Diff line change
Expand Up @@ -547,3 +547,53 @@ accepts_typeguard(with_typeguard_a) # E: Argument 1 to "accepts_typeguard" has
accepts_typeguard(with_typeguard_b)
accepts_typeguard(with_typeguard_c)
[builtins fixtures/tuple.pyi]

[case testTypeGuardWithIdentityGeneric]
from typing import TypeVar
from typing_extensions import TypeGuard

_T = TypeVar("_T")

def identity(val: _T) -> TypeGuard[_T]:
pass

def func1(name: _T):
reveal_type(name) # N: Revealed type is "_T`-1"
if identity(name):
reveal_type(name) # N: Revealed type is "_T`-1"

def func2(name: str):
reveal_type(name) # N: Revealed type is "builtins.str"
if identity(name):
reveal_type(name) # N: Revealed type is "builtins.str"
[builtins fixtures/tuple.pyi]

[case testTypeGuardWithGenericInstance]
from typing import TypeVar, List
from typing_extensions import TypeGuard

_T = TypeVar("_T")

def is_list_of_str(val: _T) -> TypeGuard[List[_T]]:
pass

def func(name: str):
reveal_type(name) # N: Revealed type is "builtins.str"
if is_list_of_str(name):
reveal_type(name) # N: Revealed type is "builtins.list[builtins.str]"
[builtins fixtures/tuple.pyi]

[case testTypeGuardWithTupleGeneric]
from typing import TypeVar, Tuple
from typing_extensions import TypeGuard

_T = TypeVar("_T")

def is_two_element_tuple(val: Tuple[_T, ...]) -> TypeGuard[Tuple[_T, _T]]:
pass

def func(names: Tuple[str, ...]):
reveal_type(names) # N: Revealed type is "builtins.tuple[builtins.str, ...]"
if is_two_element_tuple(names):
reveal_type(names) # N: Revealed type is "Tuple[builtins.str, builtins.str]"
[builtins fixtures/tuple.pyi]

0 comments on commit fb11c98

Please sign in to comment.