Skip to content

Commit

Permalink
Support TypeGuard (PEP 647) (#9865)
Browse files Browse the repository at this point in the history
PEP 647 is still in draft mode, but it is likely to be accepted, and this helps solve some real issues.
  • Loading branch information
gvanrossum committed Jan 18, 2021
1 parent 734e4ad commit fffbe88
Show file tree
Hide file tree
Showing 13 changed files with 408 additions and 9 deletions.
11 changes: 10 additions & 1 deletion mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
UnionType, TypeVarId, TypeVarType, PartialType, DeletedType, UninhabitedType, TypeVarDef,
is_named_instance, union_items, TypeQuery, LiteralType,
is_optional, remove_optional, TypeTranslator, StarType, get_proper_type, ProperType,
get_proper_types, is_literal_type, TypeAliasType)
get_proper_types, is_literal_type, TypeAliasType, TypeGuardType)
from mypy.sametypes import is_same_type
from mypy.messages import (
MessageBuilder, make_inferred_type_note, append_invariance_notes, pretty_seq,
Expand Down Expand Up @@ -3957,6 +3957,7 @@ def find_isinstance_check(self, node: Expression
) -> Tuple[TypeMap, TypeMap]:
"""Find any isinstance checks (within a chain of ands). Includes
implicit and explicit checks for None and calls to callable.
Also includes TypeGuard functions.
Return value is a map of variables to their types if the condition
is true and a map of variables to their types if the condition is false.
Expand Down Expand Up @@ -4001,6 +4002,14 @@ def find_isinstance_check_helper(self, node: Expression) -> Tuple[TypeMap, TypeM
if literal(expr) == LITERAL_TYPE:
vartype = type_map[expr]
return self.conditional_callable_type_map(expr, vartype)
elif isinstance(node.callee, RefExpr):
if node.callee.type_guard is not None:
# TODO: Follow keyword args or *args, **kwargs
if node.arg_kinds[0] != nodes.ARG_POS:
self.fail("Type guard requires positional argument", node)
return {}, {}
if literal(expr) == LITERAL_TYPE:
return {expr: TypeGuardType(node.callee.type_guard)}, {}
elif isinstance(node, ComparisonExpr):
# Step 1: Obtain the types of each operand and whether or not we can
# narrow their types. (For example, we shouldn't try narrowing the
Expand Down
11 changes: 10 additions & 1 deletion mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
make_optional_type,
)
from mypy.types import (
Type, AnyType, CallableType, Overloaded, NoneType, TypeVarDef,
Type, AnyType, CallableType, Overloaded, NoneType, TypeGuardType, TypeVarDef,
TupleType, TypedDictType, Instance, TypeVarType, ErasedType, UnionType,
PartialType, DeletedType, UninhabitedType, TypeType, TypeOfAny, LiteralType, LiteralValue,
is_named_instance, FunctionLike,
Expand Down Expand Up @@ -317,6 +317,11 @@ def visit_call_expr_inner(self, e: CallExpr, allow_none_return: bool = False) ->
ret_type=self.object_type(),
fallback=self.named_type('builtins.function'))
callee_type = get_proper_type(self.accept(e.callee, type_context, always_allow_any=True))
if (isinstance(e.callee, RefExpr)
and isinstance(callee_type, CallableType)
and callee_type.type_guard is not None):
# Cache it for find_isinstance_check()
e.callee.type_guard = callee_type.type_guard
if (self.chk.options.disallow_untyped_calls and
self.chk.in_checked_function() and
isinstance(callee_type, CallableType)
Expand Down Expand Up @@ -4163,6 +4168,10 @@ def narrow_type_from_binder(self, expr: Expression, known_type: Type,
"""
if literal(expr) >= LITERAL_TYPE:
restriction = self.chk.binder.get(expr)
# Ignore the error about using get_proper_type().
if isinstance(restriction, TypeGuardType): # type: ignore[misc]
# A type guard forces the new type even if it doesn't overlap the old.
return restriction.type_guard
# If the current node is deferred, some variables may get Any types that they
# otherwise wouldn't have. We don't want to narrow down these since it may
# produce invalid inferred Optional[Any] types, at least.
Expand Down
7 changes: 6 additions & 1 deletion mypy/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,7 +457,12 @@ def visit_callable_type(self, template: CallableType) -> List[Constraint]:
for t, a in zip(template.arg_types, cactual.arg_types):
# Negate direction due to function argument type contravariance.
res.extend(infer_constraints(t, a, neg_op(self.direction)))
res.extend(infer_constraints(template.ret_type, cactual.ret_type,
template_ret_type, cactual_ret_type = template.ret_type, cactual.ret_type
if template.type_guard is not None:
template_ret_type = template.type_guard
if cactual.type_guard is not None:
cactual_ret_type = cactual.type_guard
res.extend(infer_constraints(template_ret_type, cactual_ret_type,
self.direction))
return res
elif isinstance(self.actual, AnyType):
Expand Down
4 changes: 3 additions & 1 deletion mypy/expandtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,9 @@ def visit_type_var(self, t: TypeVarType) -> Type:

def visit_callable_type(self, t: CallableType) -> Type:
return t.copy_modified(arg_types=self.expand_types(t.arg_types),
ret_type=t.ret_type.accept(self))
ret_type=t.ret_type.accept(self),
type_guard=(t.type_guard.accept(self)
if t.type_guard is not None else None))

def visit_overloaded(self, t: Overloaded) -> Type:
items = [] # type: List[CallableType]
Expand Down
2 changes: 2 additions & 0 deletions mypy/fixup.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,8 @@ def visit_callable_type(self, ct: CallableType) -> None:
for arg in ct.bound_args:
if arg:
arg.accept(self)
if ct.type_guard is not None:
ct.type_guard.accept(self)

def visit_overloaded(self, t: Overloaded) -> None:
for ct in t.items():
Expand Down
5 changes: 4 additions & 1 deletion mypy/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1448,7 +1448,8 @@ def accept(self, visitor: ExpressionVisitor[T]) -> T:
class RefExpr(Expression):
"""Abstract base class for name-like constructs"""

__slots__ = ('kind', 'node', 'fullname', 'is_new_def', 'is_inferred_def', 'is_alias_rvalue')
__slots__ = ('kind', 'node', 'fullname', 'is_new_def', 'is_inferred_def', 'is_alias_rvalue',
'type_guard')

def __init__(self) -> None:
super().__init__()
Expand All @@ -1467,6 +1468,8 @@ def __init__(self) -> None:
self.is_inferred_def = False
# Is this expression appears as an rvalue of a valid type alias definition?
self.is_alias_rvalue = False
# Cache type guard from callable_type.type_guard
self.type_guard = None # type: Optional[mypy.types.Type]


class NameExpr(RefExpr):
Expand Down
1 change: 1 addition & 0 deletions mypy/test/testcheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@
'check-annotated.test',
'check-parameter-specification.test',
'check-generic-alias.test',
'check-typeguard.test',
]

# Tests that use Python 3.8-only AST features (like expression-scoped ignores):
Expand Down
24 changes: 23 additions & 1 deletion mypy/typeanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,9 @@ def try_analyze_special_unbound_type(self, t: UnboundType, fullname: str) -> Opt
" and at least one annotation", t)
return AnyType(TypeOfAny.from_error)
return self.anal_type(t.args[0])
elif self.anal_type_guard_arg(t, fullname) is not None:
# In most contexts, TypeGuard[...] acts as an alias for bool (ignoring its args)
return self.named_type('builtins.bool')
return None

def get_omitted_any(self, typ: Type, fullname: Optional[str] = None) -> AnyType:
Expand Down Expand Up @@ -524,15 +527,34 @@ def visit_callable_type(self, t: CallableType, nested: bool = True) -> Type:
variables = t.variables
else:
variables = self.bind_function_type_variables(t, t)
special = self.anal_type_guard(t.ret_type)
ret = t.copy_modified(arg_types=self.anal_array(t.arg_types, nested=nested),
ret_type=self.anal_type(t.ret_type, nested=nested),
# If the fallback isn't filled in yet,
# its type will be the falsey FakeInfo
fallback=(t.fallback if t.fallback.type
else self.named_type('builtins.function')),
variables=self.anal_var_defs(variables))
variables=self.anal_var_defs(variables),
type_guard=special,
)
return ret

def anal_type_guard(self, t: Type) -> Optional[Type]:
if isinstance(t, UnboundType):
sym = self.lookup_qualified(t.name, t)
if sym is not None and sym.node is not None:
return self.anal_type_guard_arg(t, sym.node.fullname)
# TODO: What if it's an Instance? Then use t.type.fullname?
return None

def anal_type_guard_arg(self, t: UnboundType, fullname: str) -> Optional[Type]:
if fullname in ('typing_extensions.TypeGuard', 'typing.TypeGuard'):
if len(t.args) != 1:
self.fail("TypeGuard must have exactly one type argument", t)
return AnyType(TypeOfAny.from_error)
return self.anal_type(t.args[0])
return None

def visit_overloaded(self, t: Overloaded) -> Type:
# Overloaded types are manually constructed in semanal.py by analyzing the
# AST and combining together the Callable types this visitor converts.
Expand Down
30 changes: 27 additions & 3 deletions mypy/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,16 @@ def copy_modified(self, *,
self.line, self.column)


class TypeGuardType(Type):
"""Only used by find_instance_check() etc."""
def __init__(self, type_guard: Type):
super().__init__(line=type_guard.line, column=type_guard.column)
self.type_guard = type_guard

def __repr__(self) -> str:
return "TypeGuard({})".format(self.type_guard)


class ProperType(Type):
"""Not a type alias.
Expand Down Expand Up @@ -1005,6 +1015,7 @@ class CallableType(FunctionLike):
# tools that consume mypy ASTs
'def_extras', # Information about original definition we want to serialize.
# This is used for more detailed error messages.
'type_guard', # T, if -> TypeGuard[T] (ret_type is bool in this case).
)

def __init__(self,
Expand All @@ -1024,6 +1035,7 @@ def __init__(self,
from_type_type: bool = False,
bound_args: Sequence[Optional[Type]] = (),
def_extras: Optional[Dict[str, Any]] = None,
type_guard: Optional[Type] = None,
) -> None:
super().__init__(line, column)
assert len(arg_types) == len(arg_kinds) == len(arg_names)
Expand Down Expand Up @@ -1058,6 +1070,7 @@ def __init__(self,
not definition.is_static else None}
else:
self.def_extras = {}
self.type_guard = type_guard

def copy_modified(self,
arg_types: Bogus[Sequence[Type]] = _dummy,
Expand All @@ -1075,7 +1088,9 @@ def copy_modified(self,
special_sig: Bogus[Optional[str]] = _dummy,
from_type_type: Bogus[bool] = _dummy,
bound_args: Bogus[List[Optional[Type]]] = _dummy,
def_extras: Bogus[Dict[str, Any]] = _dummy) -> 'CallableType':
def_extras: Bogus[Dict[str, Any]] = _dummy,
type_guard: Bogus[Optional[Type]] = _dummy,
) -> 'CallableType':
return CallableType(
arg_types=arg_types if arg_types is not _dummy else self.arg_types,
arg_kinds=arg_kinds if arg_kinds is not _dummy else self.arg_kinds,
Expand All @@ -1094,6 +1109,7 @@ def copy_modified(self,
from_type_type=from_type_type if from_type_type is not _dummy else self.from_type_type,
bound_args=bound_args if bound_args is not _dummy else self.bound_args,
def_extras=def_extras if def_extras is not _dummy else dict(self.def_extras),
type_guard=type_guard if type_guard is not _dummy else self.type_guard,
)

def var_arg(self) -> Optional[FormalArgument]:
Expand Down Expand Up @@ -1255,6 +1271,8 @@ def __eq__(self, other: object) -> bool:
def serialize(self) -> JsonDict:
# TODO: As an optimization, leave out everything related to
# generic functions for non-generic functions.
assert (self.type_guard is None
or isinstance(get_proper_type(self.type_guard), Instance)), str(self.type_guard)
return {'.class': 'CallableType',
'arg_types': [t.serialize() for t in self.arg_types],
'arg_kinds': self.arg_kinds,
Expand All @@ -1269,6 +1287,7 @@ def serialize(self) -> JsonDict:
'bound_args': [(None if t is None else t.serialize())
for t in self.bound_args],
'def_extras': dict(self.def_extras),
'type_guard': self.type_guard.serialize() if self.type_guard is not None else None,
}

@classmethod
Expand All @@ -1286,7 +1305,9 @@ def deserialize(cls, data: JsonDict) -> 'CallableType':
implicit=data['implicit'],
bound_args=[(None if t is None else deserialize_type(t))
for t in data['bound_args']],
def_extras=data['def_extras']
def_extras=data['def_extras'],
type_guard=(deserialize_type(data['type_guard'])
if data['type_guard'] is not None else None),
)


Expand Down Expand Up @@ -2097,7 +2118,10 @@ def visit_callable_type(self, t: CallableType) -> str:
s = '({})'.format(s)

if not isinstance(get_proper_type(t.ret_type), NoneType):
s += ' -> {}'.format(t.ret_type.accept(self))
if t.type_guard is not None:
s += ' -> TypeGuard[{}]'.format(t.type_guard.accept(self))
else:
s += ' -> {}'.format(t.ret_type.accept(self))

if t.variables:
vs = []
Expand Down
9 changes: 9 additions & 0 deletions test-data/unit/check-python38.test
Original file line number Diff line number Diff line change
Expand Up @@ -392,3 +392,12 @@ def func() -> None:
class Foo:
def __init__(self) -> None:
self.x = 123

[case testWalrusTypeGuard]
from typing_extensions import TypeGuard
def is_float(a: object) -> TypeGuard[float]: pass
def main(a: object) -> None:
if is_float(x := a):
reveal_type(x) # N: Revealed type is 'builtins.float'
reveal_type(a) # N: Revealed type is 'builtins.object'
[builtins fixtures/tuple.pyi]
15 changes: 15 additions & 0 deletions test-data/unit/check-serialize.test
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,21 @@ def f(x: int) -> int: pass
tmp/a.py:2: note: Revealed type is 'builtins.str'
tmp/a.py:3: error: Unexpected keyword argument "x" for "f"

[case testSerializeTypeGuardFunction]
import a
[file a.py]
import b
[file a.py.2]
import b
reveal_type(b.guard(''))
reveal_type(b.guard)
[file b.py]
from typing_extensions import TypeGuard
def guard(a: object) -> TypeGuard[str]: pass
[builtins fixtures/tuple.pyi]
[out2]
tmp/a.py:2: note: Revealed type is 'builtins.bool'
tmp/a.py:3: note: Revealed type is 'def (a: builtins.object) -> TypeGuard[builtins.str]'
--
-- Classes
--
Expand Down

0 comments on commit fffbe88

Please sign in to comment.