Skip to content

Commit

Permalink
make typeguard work on python3.9
Browse files Browse the repository at this point in the history
  • Loading branch information
deathowl committed Oct 7, 2020
1 parent 55fbcb5 commit 3e36538
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 13 deletions.
12 changes: 11 additions & 1 deletion tests/test_typeguard.py
Expand Up @@ -10,7 +10,17 @@
Container, Generic, BinaryIO, TextIO, Generator, Iterator, AbstractSet, AnyStr, Type)

import pytest
from typing_extensions import NoReturn, Protocol, Literal, TypedDict, runtime_checkable
try:
from typing_extensions import NoReturn, Protocol, Literal, TypedDict, runtime_checkable
except ImportError:
try:
from typing import NoReturn, Protocol, Literal, TypedDict, runtime_checkable
except ImportError:
NoReturn = None
Protocol = None
Literal = None
TypedDict = None
runtime_checkable = None

from typeguard import (
typechecked, check_argument_types, qualified_name, TypeChecker, TypeWarning, function_name,
Expand Down
2 changes: 1 addition & 1 deletion tox.ini
@@ -1,6 +1,6 @@
[tox]
minversion = 3.3.0
envlist = pypy3, py35, py36, py37, py38, flake8
envlist = pypy3, py35, py36, py37, py38, py39, flake8
skip_missing_interpreters = true
isolated_build = true

Expand Down
42 changes: 31 additions & 11 deletions typeguard/__init__.py
Expand Up @@ -173,7 +173,11 @@ def __init__(self, func: Callable, frame_locals: Optional[Dict[str, Any]] = None

def resolve_forwardref(maybe_ref, memo: _TypeCheckMemo):
if isinstance(maybe_ref, ForwardRef):
return evaluate_forwardref(maybe_ref, memo.globals, memo.locals)
if sys.version_info < (3, 9, 0):
return evaluate_forwardref(maybe_ref, memo.globals, memo.locals)
else:
return evaluate_forwardref(maybe_ref, memo.globals, memo.locals, frozenset())

else:
return maybe_ref

Expand Down Expand Up @@ -247,7 +251,7 @@ def check_callable(argname: str, value, expected_type, memo: _TypeCheckMemo) ->
if not callable(value):
raise TypeError('{} must be a callable'.format(argname))

if expected_type.__args__:
if hasattr(expected_type, "__args__") and expected_type.__args__:
try:
signature = inspect.signature(value)
except (TypeError, ValueError):
Expand Down Expand Up @@ -297,7 +301,8 @@ def check_dict(argname: str, value, expected_type, memo: _TypeCheckMemo) -> None
format(argname, qualified_name(value)))

if expected_type is not dict:
if expected_type.__args__ not in (None, expected_type.__parameters__):
if hasattr(expected_type, "__args__") and expected_type.__args__ not in \
(None, expected_type.__parameters__):
key_type, value_type = expected_type.__args__
if key_type is not Any or value_type is not Any:
for k, v in value.items():
Expand Down Expand Up @@ -332,7 +337,8 @@ def check_list(argname: str, value, expected_type, memo: _TypeCheckMemo) -> None
format(argname, qualified_name(value)))

if expected_type is not list:
if expected_type.__args__ not in (None, expected_type.__parameters__):
if hasattr(expected_type, "__args__") and expected_type.__args__ not in \
(None, expected_type.__parameters__):
value_type = expected_type.__args__[0]
if value_type is not Any:
for i, v in enumerate(value):
Expand All @@ -344,7 +350,8 @@ def check_sequence(argname: str, value, expected_type, memo: _TypeCheckMemo) ->
raise TypeError('type of {} must be a sequence; got {} instead'.
format(argname, qualified_name(value)))

if expected_type.__args__ not in (None, expected_type.__parameters__):
if hasattr(expected_type, "__args__") and expected_type.__args__ not in \
(None, expected_type.__parameters__):
value_type = expected_type.__args__[0]
if value_type is not Any:
for i, v in enumerate(value):
Expand All @@ -357,7 +364,8 @@ def check_set(argname: str, value, expected_type, memo: _TypeCheckMemo) -> None:
format(argname, qualified_name(value)))

if expected_type is not set:
if expected_type.__args__ not in (None, expected_type.__parameters__):
if hasattr(expected_type, "__args__") and expected_type.__args__ not in \
(None, expected_type.__parameters__):
value_type = expected_type.__args__[0]
if value_type is not Any:
for v in value:
Expand All @@ -366,12 +374,22 @@ def check_set(argname: str, value, expected_type, memo: _TypeCheckMemo) -> None:

def check_tuple(argname: str, value, expected_type, memo: _TypeCheckMemo) -> None:
# Specialized check for NamedTuples
if hasattr(expected_type, '_field_types'):
is_named_tuple = False
if sys.version_info < (3, 8, 0):
is_named_tuple = hasattr(expected_type, '_field_types') # deprecated since python 3.8
else:
is_named_tuple = hasattr(expected_type, '__annotations__')

if is_named_tuple:
if not isinstance(value, expected_type):
raise TypeError('type of {} must be a named tuple of type {}; got {} instead'.
format(argname, qualified_name(expected_type), qualified_name(value)))
if sys.version_info < (3, 8, 0):
field_types = expected_type._field_types
else:
field_types = expected_type.__annotations__

for name, field_type in expected_type._field_types.items():
for name, field_type in field_types.items():
check_type('{}.{}'.format(argname, name), getattr(value, name), field_type, memo)

return
Expand Down Expand Up @@ -436,7 +454,8 @@ def check_class(argname: str, value, expected_type, memo: _TypeCheckMemo) -> Non
if expected_type is Type:
return

expected_class = expected_type.__args__[0] if expected_type.__args__ else None
expected_class = expected_type.__args__[0] if hasattr(
expected_type, "__args__") and expected_type.__args__ else None
if expected_class:
if expected_class is Any:
return
Expand Down Expand Up @@ -728,10 +747,11 @@ def check_argument_types(memo: Optional[_CallMemo] = None) -> bool:

class TypeCheckedGenerator:
def __init__(self, wrapped: Generator, memo: _CallMemo):
rtype_args = memo.type_hints['return'].__args__
rtype_args = memo.type_hints['return'].__args__ if hasattr(
memo.type_hints['return'], "__args__") else []
self.__wrapped = wrapped
self.__memo = memo
self.__yield_type = rtype_args[0]
self.__yield_type = rtype_args[0] if rtype_args else Any
self.__send_type = rtype_args[1] if len(rtype_args) > 1 else Any
self.__return_type = rtype_args[2] if len(rtype_args) > 2 else Any
self.__initialized = False
Expand Down

0 comments on commit 3e36538

Please sign in to comment.