/
__init__.py
1110 lines (893 loc) · 43.7 KB
/
__init__.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Public API of this module: the names exported by ``from <this package> import *``.
# NOTE(review): TypeWarning and TypeChecker are not defined in the visible part of this file;
# presumably they are defined further down -- confirm.
__all__ = ('ForwardRefPolicy', 'TypeHintWarning', 'typechecked', 'check_return_type',
           'check_argument_types', 'check_type', 'TypeWarning', 'TypeChecker')
import collections.abc
import gc
import inspect
import sys
import threading
from collections import OrderedDict
from enum import Enum
from functools import wraps, partial
from inspect import Parameter, isclass, isfunction, isgeneratorfunction
from io import TextIOBase, RawIOBase, IOBase, BufferedIOBase
from traceback import extract_stack, print_stack
from types import CodeType, FunctionType
from typing import (
Callable, Any, Union, Dict, List, TypeVar, Tuple, Set, Sequence, get_type_hints, TextIO,
Optional, IO, BinaryIO, Type, Generator, overload, Iterable, AsyncIterable, Iterator,
AsyncIterator, AbstractSet)
from unittest.mock import Mock
from warnings import warn
from weakref import WeakKeyDictionary, WeakValueDictionary
# Python 3.8+
try:
from typing_extensions import Literal
except ImportError:
try:
from typing import Literal
except ImportError:
Literal = None
# Python 3.5.4+ / 3.6.2+
try:
from typing_extensions import NoReturn
except ImportError:
try:
from typing import NoReturn
except ImportError:
NoReturn = None
# Python 3.6+: asynchronous generators (and their typing alias) exist.  When any of the imports
# below is unavailable, fall back to stand-ins that treat nothing as an asynchronous generator.
try:
    from inspect import isasyncgenfunction, isasyncgen
    from typing import AsyncGenerator
except ImportError:
    AsyncGenerator = None

    def isasyncgen(obj):
        """Fallback for :func:`inspect.isasyncgen`; always reports ``False``."""
        return False

    def isasyncgenfunction(func):
        """Fallback for :func:`inspect.isasyncgenfunction`; always reports ``False``."""
        return False
# Python 3.8+
try:
from typing import ForwardRef
evaluate_forwardref = ForwardRef._evaluate
except ImportError:
from typing import _ForwardRef as ForwardRef
evaluate_forwardref = ForwardRef._eval_type
# Per-function cache of resolved type hints, filled in by _CallMemo.  Weak keys mean the cache
# does not keep functions alive.
_type_hints_map = WeakKeyDictionary()  # type: Dict[FunctionType, Dict[str, Any]]
# Cache of code object -> function lookups made by find_function().  Weak values mean the cache
# does not keep functions alive.
_functions_map = WeakValueDictionary()  # type: Dict[CodeType, FunctionType]
# Sentinel meaning "key not present", distinct from a stored None (used by check_typed_dict()).
_missing = object()
# typechecked() returns the same kind of object (a callable or a class) that it was given.
T_CallableOrType = TypeVar('T_CallableOrType', Callable, Type[Any])
class ForwardRefPolicy(Enum):
    """
    Strategy for handling forward references (string annotations) that cannot be resolved.
    """

    #: re-raise the :exc:`NameError` coming from :func:`~typing.get_type_hints`
    ERROR = 1
    #: drop the offending annotation and emit a :class:`TypeHintWarning`
    WARN = 2
    #: if the argument's class has a qualified name equal to the unresolved name, use that class
    #: as the annotation; otherwise drop the annotation
    GUESS = 3
class TypeHintWarning(UserWarning):
    """Emitted when a type hint given as a string cannot be resolved to an actual type."""
class _TypeCheckMemo:
__slots__ = 'globals', 'locals', 'typevars'
def __init__(self, globals: Dict[str, Any], locals: Dict[str, Any]):
self.globals = globals
self.locals = locals
self.typevars = {} # type: Dict[Any, type]
class _CallMemo(_TypeCheckMemo):
    """
    Type-checking state for a single call of ``func``.

    On top of the namespaces inherited from :class:`_TypeCheckMemo`, this records the argument
    values of the call and the resolved type hints of ``func``.  Resolved hints are cached per
    function in ``_type_hints_map``, so they are computed only once per function.

    :param func: the function being called
    :param frame_locals: local namespace used to resolve forward references; also used as the
        argument mapping when ``args`` or ``kwargs`` is ``None``
    :param args: positional arguments of the call
    :param kwargs: keyword arguments of the call
    :param forward_refs_policy: how to handle forward references that cannot be resolved
    :raises NameError: if a forward reference cannot be resolved and the policy is
        :attr:`ForwardRefPolicy.ERROR`, or the unresolved name cannot be traced to an annotation
    """

    __slots__ = 'func', 'func_name', 'arguments', 'is_generator', 'type_hints'

    def __init__(self, func: Callable, frame_locals: Optional[Dict[str, Any]] = None,
                 args: tuple = None, kwargs: Dict[str, Any] = None,
                 forward_refs_policy=ForwardRefPolicy.ERROR):
        super().__init__(func.__globals__, frame_locals)
        self.func = func
        self.func_name = function_name(func)
        self.is_generator = isgeneratorfunction(func)
        signature = inspect.signature(func)

        if args is not None and kwargs is not None:
            self.arguments = signature.bind(*args, **kwargs).arguments
        else:
            assert frame_locals is not None, 'frame must be specified if args or kwargs is None'
            self.arguments = frame_locals

        self.type_hints = _type_hints_map.get(func)
        if self.type_hints is None:
            hints = self._get_hints(frame_locals, forward_refs_policy)
            self.type_hints = OrderedDict()
            for name, parameter in signature.parameters.items():
                if name in hints:
                    annotated_type = hints[name]

                    # Implicit Optional for a None default: PEP 484 discourages it, but some
                    # type checkers accept it
                    if parameter.default is None:
                        annotated_type = Optional[annotated_type]

                    if parameter.kind == Parameter.VAR_POSITIONAL:
                        self.type_hints[name] = Tuple[annotated_type, ...]
                    elif parameter.kind == Parameter.VAR_KEYWORD:
                        self.type_hints[name] = Dict[str, annotated_type]
                    else:
                        self.type_hints[name] = annotated_type

            if 'return' in hints:
                self.type_hints['return'] = hints['return']

            _type_hints_map[func] = self.type_hints

    def _get_hints(self, localns, forward_refs_policy) -> Dict[str, Any]:
        """
        Return :func:`~typing.get_type_hints` for ``self.func``, applying ``forward_refs_policy``.

        Under WARN/GUESS, each annotation (including the return annotation) whose name cannot
        be resolved is removed from ``func.__annotations__``.  Under GUESS it may instead be
        replaced with the argument's class.  Resolution is then retried.
        """
        func = self.func
        if localns is not None and sys.version_info < (3, 5, 3):
            localns = dict(localns)

        while True:
            try:
                return get_type_hints(func, localns=localns)
            except NameError as exc:
                if forward_refs_policy is ForwardRefPolicy.ERROR:
                    raise

                # Expected message format: "name 'Foo' is not defined"
                parts = str(exc).split("'", 2)
                if len(parts) < 3:
                    raise

                typename = parts[1]
                # Search the *current* annotations rather than a signature computed earlier:
                # annotations handled in previous iterations are already gone or replaced, so
                # a second annotation using the same name is found next instead of the same one
                annotations = func.__annotations__
                name = next((key for key, annotation in annotations.items()
                             if annotation == typename), None)
                if name is None:
                    raise

                if forward_refs_policy is ForwardRefPolicy.GUESS and name in self.arguments:
                    argtype = self.arguments[name].__class__
                    if typename == argtype.__qualname__:
                        annotations[name] = argtype
                        msg = ('Replaced forward declaration {!r} in {} with {!r}'
                               .format(typename, self.func_name, argtype))
                        warn(TypeHintWarning(msg))
                        continue

                msg = 'Could not resolve type hint {!r} on {}: {}'.format(
                    typename, self.func_name, exc)
                warn(TypeHintWarning(msg))
                del annotations[name]
def resolve_forwardref(maybe_ref, memo: _TypeCheckMemo):
    """
    Evaluate ``maybe_ref`` if it is a forward reference; otherwise return it unchanged.

    The reference is evaluated against ``memo.globals`` and ``memo.locals``.  The private
    ``ForwardRef._evaluate`` API used here changed its signature across Python versions:

    * before 3.9: ``(globalns, localns)``
    * 3.9 to 3.12.3: ``(globalns, localns, recursive_guard)``
    * 3.12.4 and later: ``(globalns, localns, type_params=..., *, recursive_guard)``
    """
    if not isinstance(maybe_ref, ForwardRef):
        return maybe_ref

    if sys.version_info < (3, 9, 0):
        return evaluate_forwardref(maybe_ref, memo.globals, memo.locals)
    elif sys.version_info < (3, 12, 4):
        return evaluate_forwardref(maybe_ref, memo.globals, memo.locals, frozenset())
    else:
        # recursive_guard is keyword-only here.  type_params is passed explicitly because newer
        # versions emit a DeprecationWarning when it is omitted
        return evaluate_forwardref(maybe_ref, memo.globals, memo.locals, type_params=(),
                                   recursive_guard=frozenset())
def get_type_name(type_):
    """
    Return a short human readable name for ``type_``.

    Tries ``__name__``, then ``_name`` (typing special forms lack ``__name__`` on Python 3.7+),
    and finally falls back to ``str(type_)``.
    """
    for attribute in ('__name__', '_name'):
        name = getattr(type_, attribute, None)
        if name:
            return name

    return str(type_)
def find_function(frame) -> Optional[Callable]:
    """
    Locate the function object whose code object is executing in ``frame``.

    The lookup scans the garbage collector's referrers of ``frame.f_code``.  This is unreliable
    when several function objects share one code object, so ``None`` is returned in that case.
    Unambiguous results are cached in ``_functions_map`` for future lookups.

    :param frame: a frame object
    :return: the matching function, or ``None`` if more than one function uses the frame's
        code object
    :raises LookupError: if no function object uses the frame's code object
    """
    code = frame.f_code
    cached = _functions_map.get(code)
    if cached is not None:
        return cached

    candidates = (obj for obj in gc.get_referrers(code) if inspect.isfunction(obj))
    found = next(candidates, None)
    if found is None:
        raise LookupError('target function not found')

    if next(candidates, None) is not None:
        # A second function shares this code object -- we cannot tell which one is running
        return None

    _functions_map[code] = found
    return found
def qualified_name(obj) -> str:
    """
    Return the qualified name (e.g. ``package.module.Type``) of ``obj`` if it is a class, or
    of its class otherwise.

    The module prefix is omitted for classes from :mod:`builtins` and :mod:`typing`.
    """
    cls = obj if inspect.isclass(obj) else type(obj)
    if cls.__module__ in ('typing', 'builtins'):
        return cls.__qualname__

    return '{}.{}'.format(cls.__module__, cls.__qualname__)
def function_name(func: Callable) -> str:
    """
    Return the qualified name of the given function (e.g. ``package.module.Class.method``).

    The module prefix is omitted for builtins, and for callables that have no ``__module__``.
    Callables without a ``__qualname__`` (e.g. :func:`functools.partial` objects) are
    represented by their ``repr()``.
    """
    # Some callables (partial objects, certain C-level callables) lack __qualname__ and/or
    # __module__; fall back instead of raising AttributeError
    module = getattr(func, '__module__', None)
    qualname = getattr(func, '__qualname__', repr(func))
    if module in (None, 'builtins'):
        return qualname

    return '{}.{}'.format(module, qualname)
def check_callable(argname: str, value, expected_type, memo: _TypeCheckMemo) -> None:
    """
    Check that ``value`` is callable and, when ``expected_type`` declares argument types, that
    its signature can be called with that many positional arguments.

    Callables whose signature cannot be introspected only get the :func:`callable` check.

    :raises TypeError: if ``value`` is not callable, has mandatory keyword-only arguments, or
        cannot accept the declared number of positional arguments
    """
    if not callable(value):
        raise TypeError('{} must be a callable'.format(argname))

    if not getattr(expected_type, "__args__", None):
        return

    try:
        signature = inspect.signature(value)
    except (TypeError, ValueError):
        return

    if hasattr(expected_type, '__result__'):
        # Python 3.5
        argument_types = expected_type.__args__
        check_args = argument_types is not Ellipsis
    else:
        # Python 3.6+: __args__ is (*argument_types, return_type)
        argument_types = expected_type.__args__[:-1]
        check_args = argument_types != (Ellipsis,)

    if not check_args:
        return

    parameters = signature.parameters.values()

    # The callable must not have keyword-only arguments without defaults.  Defaults are compared
    # by identity: "==" may raise or return a non-bool for some objects (e.g. arrays)
    unfulfilled_kwonlyargs = [
        param.name for param in parameters
        if param.kind == Parameter.KEYWORD_ONLY and param.default is Parameter.empty]
    if unfulfilled_kwonlyargs:
        raise TypeError(
            'callable passed as {} has mandatory keyword-only arguments in its '
            'declaration: {}'.format(argname, ', '.join(unfulfilled_kwonlyargs)))

    positional_params = [
        param for param in parameters
        if param.kind in (Parameter.POSITIONAL_ONLY, Parameter.POSITIONAL_OR_KEYWORD)]
    num_mandatory_args = sum(1 for param in positional_params
                             if param.default is Parameter.empty)
    has_varargs = any(param.kind == Parameter.VAR_POSITIONAL for param in parameters)
    if num_mandatory_args > len(argument_types):
        raise TypeError(
            'callable passed as {} has too many arguments in its declaration; expected {} '
            'but {} argument(s) declared'.format(argname, len(argument_types),
                                                 num_mandatory_args))
    # Optional positional parameters can receive the declared arguments as well
    elif not has_varargs and len(positional_params) < len(argument_types):
        raise TypeError(
            'callable passed as {} has too few arguments in its declaration; expected {} '
            'but {} argument(s) declared'.format(argname, len(argument_types),
                                                 len(positional_params)))
def check_dict(argname: str, value, expected_type, memo: _TypeCheckMemo) -> None:
    """
    Check that ``value`` is a dict.  If ``expected_type`` is parametrized with key/value types
    other than ``Any``, also check every key and value.
    """
    if not isinstance(value, dict):
        raise TypeError('type of {} must be a dict; got {} instead'.
                        format(argname, qualified_name(value)))

    if expected_type is dict or not hasattr(expected_type, "__args__"):
        return

    type_args = expected_type.__args__
    # Unparametrized aliases carry None or their own type parameters as __args__
    if type_args in (None, expected_type.__parameters__):
        return

    key_type, value_type = type_args
    if key_type is Any and value_type is Any:
        return

    for key, item in value.items():
        check_type('keys of {}'.format(argname), key, key_type, memo)
        check_type('{}[{!r}]'.format(argname, key), item, value_type, memo)
def check_typed_dict(argname: str, value, expected_type, memo: _TypeCheckMemo) -> None:
    """
    Check that ``value`` is a dict matching the TypedDict class ``expected_type``.

    The check rejects non-dict values, keys that the TypedDict does not declare, and missing
    required keys.  It then type checks the value of every declared key that is present.

    :raises TypeError: on any mismatch
    """
    if not isinstance(value, dict):
        raise TypeError('type of {} must be a dict; got {} instead'.
                        format(argname, qualified_name(value)))

    expected_keys = frozenset(expected_type.__annotations__)
    existing_keys = frozenset(value)
    extra_keys = existing_keys - expected_keys
    if extra_keys:
        # key=str keeps the ordering of string keys while tolerating mixed key types
        keys_formatted = ', '.join('"{}"'.format(key) for key in sorted(extra_keys, key=str))
        raise TypeError('extra key(s) ({}) in {}'.format(keys_formatted, argname))

    # __required_keys__ (Python 3.9+) stays correct when a TypedDict inherits from a base with a
    # different ``total`` setting; otherwise fall back to the class-wide __total__ flag
    required_keys = getattr(expected_type, '__required_keys__', None)
    if required_keys is None:
        required_keys = expected_keys if expected_type.__total__ else frozenset()

    missing_keys = frozenset(required_keys) - existing_keys
    if missing_keys:
        keys_formatted = ', '.join('"{}"'.format(key) for key in sorted(missing_keys, key=str))
        raise TypeError('required key(s) ({}) missing from {}'.format(keys_formatted, argname))

    for key, argtype in get_type_hints(expected_type).items():
        argvalue = value.get(key, _missing)
        if argvalue is not _missing:
            check_type('dict item "{}" for {}'.format(key, argname), argvalue, argtype, memo)
def check_list(argname: str, value, expected_type, memo: _TypeCheckMemo) -> None:
    """
    Check that ``value`` is a list.  If ``expected_type`` is parametrized with an element type
    other than ``Any``, also check every element.
    """
    if not isinstance(value, list):
        raise TypeError('type of {} must be a list; got {} instead'.
                        format(argname, qualified_name(value)))

    if expected_type is list or not hasattr(expected_type, "__args__"):
        return

    type_args = expected_type.__args__
    # Unparametrized aliases carry None or their own type parameters as __args__
    if type_args in (None, expected_type.__parameters__):
        return

    element_type = type_args[0]
    if element_type is Any:
        return

    for index, element in enumerate(value):
        check_type('{}[{}]'.format(argname, index), element, element_type, memo)
def check_sequence(argname: str, value, expected_type, memo: _TypeCheckMemo) -> None:
    """
    Check that ``value`` is a :class:`collections.abc.Sequence`.  If ``expected_type`` is
    parametrized with an element type other than ``Any``, also check every element.
    """
    if not isinstance(value, collections.abc.Sequence):
        raise TypeError('type of {} must be a sequence; got {} instead'.
                        format(argname, qualified_name(value)))

    if not hasattr(expected_type, "__args__"):
        return

    type_args = expected_type.__args__
    if type_args in (None, expected_type.__parameters__):
        return

    element_type = type_args[0]
    if element_type is Any:
        return

    for position, element in enumerate(value):
        check_type('{}[{}]'.format(argname, position), element, element_type, memo)
def check_set(argname: str, value, expected_type, memo: _TypeCheckMemo) -> None:
    """
    Check that ``value`` is a set (any :class:`~typing.AbstractSet`).  If ``expected_type`` is
    parametrized with an element type other than ``Any``, also check every member.
    """
    if not isinstance(value, AbstractSet):
        raise TypeError('type of {} must be a set; got {} instead'.
                        format(argname, qualified_name(value)))

    if expected_type is set or not hasattr(expected_type, "__args__"):
        return

    type_args = expected_type.__args__
    if type_args in (None, expected_type.__parameters__):
        return

    member_type = type_args[0]
    if member_type is Any:
        return

    description = 'elements of {}'.format(argname)
    for member in value:
        check_type(description, member, member_type, memo)
def check_tuple(argname: str, value, expected_type, memo: _TypeCheckMemo) -> None:
    """
    Check that ``value`` is a tuple matching ``expected_type``.

    * Named tuple classes: ``value`` must be an instance of the class, and each annotated field
      is checked against its annotation.
    * ``Tuple[X, ...]``: every element is checked against ``X``.
    * ``Tuple[()]``: ``value`` must be empty.
    * ``Tuple[X, Y, ...]``: the length must match and each element is checked.
    * Unparametrized ``Tuple``/``tuple``: only the isinstance check is done.

    :raises TypeError: on any mismatch
    """
    # Field annotations of typing.NamedTuple classes
    if sys.version_info < (3, 8, 0):
        field_types = getattr(expected_type, '_field_types', None)  # deprecated since 3.8
    else:
        field_types = getattr(expected_type, '__annotations__', None)

    if (field_types is None and isclass(expected_type) and issubclass(expected_type, tuple)
            and hasattr(expected_type, '_fields')):
        # collections.namedtuple classes have no field annotations, but the value must still
        # be an instance of the class
        field_types = {}

    if field_types is not None:
        if not isinstance(value, expected_type):
            raise TypeError('type of {} must be a named tuple of type {}; got {} instead'.
                            format(argname, qualified_name(expected_type), qualified_name(value)))

        for name, field_type in field_types.items():
            check_type('{}.{}'.format(argname, name), getattr(value, name), field_type, memo)

        return
    elif not isinstance(value, tuple):
        raise TypeError('type of {} must be a tuple; got {} instead'.
                        format(argname, qualified_name(value)))

    type_args = getattr(expected_type, '__args__', None)
    if getattr(expected_type, '__tuple_params__', None):
        # Python 3.5
        use_ellipsis = expected_type.__tuple_use_ellipsis__
        tuple_params = expected_type.__tuple_params__
    elif type_args:
        # Python 3.6+
        use_ellipsis = type_args[-1] is Ellipsis
        tuple_params = type_args[:-1 if use_ellipsis else None]
    elif (type_args == () and getattr(expected_type, '__origin__', None) is tuple
            and not getattr(expected_type, '_special', False)):
        # Tuple[()] on Python 3.11+ (and tuple[()]) has __args__ == () rather than ((),).  The
        # unparametrized Tuple of Python 3.7/3.8 also has empty __args__ but is flagged _special
        use_ellipsis = False
        tuple_params = ((),)
    else:
        # Unparametrized Tuple or plain tuple
        return

    if use_ellipsis:
        element_type = tuple_params[0]
        for i, element in enumerate(value):
            check_type('{}[{}]'.format(argname, i), element, element_type, memo)
    elif tuple_params == ((),):
        if value != ():
            raise TypeError('{} is not an empty tuple but one was expected'.format(argname))
    else:
        if len(value) != len(tuple_params):
            raise TypeError('{} has wrong number of elements (expected {}, got {} instead)'
                            .format(argname, len(tuple_params), len(value)))

        for i, (element, element_type) in enumerate(zip(value, tuple_params)):
            check_type('{}[{}]'.format(argname, i), element, element_type, memo)
def check_union(argname: str, value, expected_type, memo: _TypeCheckMemo) -> None:
    """
    Check that ``value`` matches at least one member of the union ``expected_type``.

    :raises TypeError: if no member accepts ``value``
    """
    # Python 3.5 keeps the members in __union_params__; 3.6+ uses __args__
    union_params = (expected_type.__union_params__
                    if hasattr(expected_type, '__union_params__') else expected_type.__args__)

    for member in union_params:
        try:
            check_type(argname, value, member, memo)
        except TypeError:
            continue

        return

    typelist = ', '.join(map(get_type_name, union_params))
    raise TypeError('type of {} must be one of ({}); got {} instead'.
                    format(argname, typelist, qualified_name(value)))
def check_class(argname: str, value, expected_type, memo: _TypeCheckMemo) -> None:
    """
    Check that ``value`` is a class compatible with ``Type[...]`` / ``type[...]``.

    * The bare ``Type`` and ``Type[Any]`` accept any class.
    * ``Type[T]``, where ``T`` is a type variable, is delegated to :func:`check_typevar`.
    * ``Type[Union[A, B]]`` accepts a subclass of any member.
    * Otherwise ``value`` must be a subclass of the argument.

    Forward references in the argument are resolved first.

    :raises TypeError: if ``value`` is not a class, or not an acceptable one
    """
    if not isclass(value):
        raise TypeError('type of {} must be a type; got {} instead'.format(
            argname, qualified_name(value)))

    # Needed on Python 3.7+
    if expected_type is Type:
        return

    type_args = getattr(expected_type, '__args__', None)
    if not type_args:
        return

    expected_class = resolve_forwardref(type_args[0], memo)
    if not expected_class or expected_class is Any:
        return

    if isinstance(expected_class, TypeVar):
        check_typevar(argname, value, expected_class, memo, True)
    elif getattr(expected_class, '__origin__', None) is Union:
        # issubclass() cannot take a typing.Union, so try each member separately
        members = expected_class.__args__
        for member in members:
            try:
                check_class(argname, value, Type[member], memo)
            except TypeError:
                continue

            return

        raise TypeError('{} must be a subclass of one of ({}); got {} instead'.format(
            argname, ', '.join(get_type_name(member) for member in members),
            qualified_name(value)))
    elif not issubclass(value, expected_class):
        raise TypeError('{} must be a subclass of {}; got {} instead'.format(
            argname, qualified_name(expected_class), qualified_name(value)))
def check_typevar(argname: str, value, typevar: TypeVar, memo: _TypeCheckMemo,
                  subclass_check: bool = False) -> None:
    """
    Check ``value`` (or, unless ``subclass_check`` is set, its type) against ``typevar``.

    If the type variable has already been bound in ``memo``, that binding is used (otherwise its
    ``__bound__``).  The comparison depends on variance: covariant or bounded variables accept
    subclasses, contravariant ones accept superclasses, and invariant ones require the exact
    type.  Without any binding, only the type variable's constraints (if any) are enforced.
    The first concrete type seen is recorded in ``memo.typevars``.

    :param subclass_check: ``True`` when ``value`` is itself a class (``Type[T]`` checks)
    :raises TypeError: on mismatch
    """
    bound = resolve_forwardref(memo.typevars.get(typevar, typevar.__bound__), memo)
    if subclass_check:
        actual, subject = value, argname
    else:
        actual, subject = type(value), 'type of ' + argname

    if bound is not None:
        if typevar.__covariant__ or typevar.__bound__:
            compatible = issubclass(actual, bound)
            template = '{} must be {} or one of its subclasses; got {} instead'
        elif typevar.__contravariant__:
            compatible = issubclass(bound, actual)
            template = '{} must be {} or one of its superclasses; got {} instead'
        else:  # invariant
            compatible = actual is bound
            template = '{} must be exactly {}; got {} instead'

        if not compatible:
            raise TypeError(template.format(subject, qualified_name(bound),
                                            qualified_name(actual)))
    elif typevar.__constraints__:
        # Not bound yet: the type must be one of the declared constraints
        constraints = [resolve_forwardref(c, memo) for c in typevar.__constraints__]
        if actual not in constraints:
            names = ', '.join(get_type_name(c) for c in constraints if c is not object)
            raise TypeError('{} must be one of ({}); got {} instead'.
                            format(subject, names, qualified_name(actual)))

    # Bind the type variable to the first concrete type seen
    memo.typevars.setdefault(typevar, actual)
def check_literal(argname: str, value, expected_type, memo: _TypeCheckMemo):
    """
    Check that ``value`` is one of the values of the ``Literal[...]`` type ``expected_type``.

    Nested ``Literal`` types are flattened.  A value matches only if both its type and its value
    are equal to a literal, so ``True`` does not satisfy ``Literal[1]``.

    :raises TypeError: if ``value`` is not allowed, or the literal contains an illegal value
    """
    def get_args(literal):
        try:
            args = literal.__args__
        except AttributeError:
            # Instance of Literal from typing_extensions
            args = literal.__values__

        retval = []
        for arg in args:
            if isinstance(arg, Literal.__class__) or getattr(arg, '__origin__', None) is Literal:
                # The first check works on py3.6 and lower, the second one on py3.7+
                retval.extend(get_args(arg))
            elif isinstance(arg, (int, str, bytes, bool, type(None), Enum)):
                retval.append(arg)
            else:
                raise TypeError('Illegal literal value: {}'.format(arg))

        return retval

    final_args = tuple(get_args(expected_type))
    # Plain "in" would treat 1, True and 1.0 as equal; compare types first (this also avoids
    # calling __eq__ on values of unrelated types)
    if not any(type(arg) is type(value) and arg == value for arg in final_args):
        raise TypeError('the value of {} must be one of {}; got {} instead'.
                        format(argname, final_args, value))
def check_number(argname: str, value, expected_type):
    """
    Check ``value`` against a numeric class following the PEP 484 numeric tower shortcuts.

    ``int`` is accepted where ``float`` is expected, and ``int``/``float`` where ``complex`` is
    expected.  Any other subclass of ``float``/``complex`` (e.g. a NumPy scalar type) gets a
    plain :func:`isinstance` check.

    :raises TypeError: if ``value`` is not acceptable
    """
    if expected_type is complex:
        if not isinstance(value, (complex, float, int)):
            raise TypeError('type of {} must be either complex, float or int; got {} instead'.
                            format(argname, qualified_name(value.__class__)))
    elif expected_type is float:
        if not isinstance(value, (float, int)):
            raise TypeError('type of {} must be either float or int; got {} instead'.
                            format(argname, qualified_name(value.__class__)))
    elif not isinstance(value, expected_type):
        # Previously any value was silently accepted for subclasses of float/complex
        raise TypeError('type of {} must be {}; got {} instead'.
                        format(argname, qualified_name(expected_type),
                               qualified_name(value.__class__)))
def check_io(argname: str, value, expected_type):
    """
    Check that ``value`` is an I/O object of the kind ``expected_type`` requires.

    ``TextIO`` requires a :class:`io.TextIOBase`, ``BinaryIO`` a raw or buffered binary stream,
    and any other ``IO`` type any :class:`io.IOBase`.

    :raises TypeError: if ``value`` is not such an object
    """
    if expected_type is TextIO:
        required = TextIOBase
        template = 'type of {} must be a text based I/O object; got {} instead'
    elif expected_type is BinaryIO:
        required = (RawIOBase, BufferedIOBase)
        template = 'type of {} must be a binary I/O object; got {} instead'
    else:
        required = IOBase
        template = 'type of {} must be an I/O object; got {} instead'

    if not isinstance(value, required):
        raise TypeError(template.format(argname, qualified_name(value.__class__)))
def check_protocol(argname: str, value, expected_type):
    """
    Check ``value`` against a protocol class.

    Only protocols decorated as runtime checkable are verified, using :func:`isinstance`;
    other protocols accept any value.

    :raises TypeError: if ``value`` does not satisfy a runtime-checkable protocol
    """
    # TODO: implement proper compatibility checking and support non-runtime protocols
    if not getattr(expected_type, '_is_runtime_protocol', False):
        return

    if isinstance(value, expected_type):
        return

    raise TypeError('type of {} ({}) is not compatible with the {} protocol'.
                    format(argname, type(value).__qualname__, expected_type.__qualname__))
# Maps a generic type's __origin__ (looked up with dict.get, i.e. by hash/equality) to the
# function that checks values against it.  Both the typing aliases and their runtime
# counterparts are listed because __origin__ may be either, depending on the Python version.
# Equality checks are applied to these
origin_type_checkers = {
    AbstractSet: check_set,
    Callable: check_callable,
    collections.abc.Callable: check_callable,
    dict: check_dict,
    Dict: check_dict,
    list: check_list,
    List: check_list,
    Sequence: check_sequence,
    collections.abc.Sequence: check_sequence,
    collections.abc.Set: check_set,
    set: check_set,
    Set: check_set,
    tuple: check_tuple,
    Tuple: check_tuple,
    type: check_class,
    Type: check_class,
    Union: check_union
}
# Whether Union types can be detected with issubclass() in check_type().  Presumably only true on
# old typing versions that expose __union_set_params__ -- TODO confirm which
_subclass_check_unions = hasattr(Union, '__union_set_params__')
# Literal is None when neither typing_extensions nor typing provides it
if Literal is not None:
    origin_type_checkers[Literal] = check_literal

# __origin__ values of return annotations for which typechecked() wraps the returned generator
# in a TypeCheckedGenerator
generator_origin_types = (Generator, collections.abc.Generator,
                          Iterator, collections.abc.Iterator,
                          Iterable, collections.abc.Iterable)
# Same for asynchronous generators, which get wrapped in a TypeCheckedAsyncGenerator
asyncgen_origin_types = (AsyncIterator, collections.abc.AsyncIterator,
                         AsyncIterable, collections.abc.AsyncIterable)
# Both AsyncGenerator variants may be unavailable on older Python versions
if AsyncGenerator is not None:
    asyncgen_origin_types += (AsyncGenerator,)
if hasattr(collections.abc, 'AsyncGenerator'):
    asyncgen_origin_types += (collections.abc.AsyncGenerator,)
# PEP 604 unions (``int | str``) are instances of types.UnionType on Python 3.10+
try:
    from types import UnionType as _UnionType
except ImportError:
    _UnionType = None


def check_type(argname: str, value, expected_type, memo: Optional[_TypeCheckMemo] = None, *,
               globals: Optional[Dict[str, Any]] = None,
               locals: Optional[Dict[str, Any]] = None) -> None:
    """
    Ensure that ``value`` matches ``expected_type``.

    The types from the :mod:`typing` module do not support :func:`isinstance` or :func:`issubclass`
    so a number of type specific checks are required. This function knows which checker to call
    for which type.

    :param argname: name of the argument to check; used for error messages
    :param value: value to be checked against ``expected_type``
    :param expected_type: a class or generic type instance
    :param memo: shared checking state; created from ``globals``/``locals`` when omitted
    :param globals: dictionary of global variables to use for resolving forward references
        (defaults to the calling frame's globals)
    :param locals: dictionary of local variables to use for resolving forward references
        (defaults to the calling frame's locals)
    :raises TypeError: if ``value`` does not match ``expected_type``
    """
    if expected_type is Any or isinstance(value, Mock):
        return

    if expected_type is None:
        # None is shorthand for type(None) in annotations
        expected_type = type(None)

    if memo is None:
        frame = sys._getframe(1)
        if globals is None:
            globals = frame.f_globals
        if locals is None:
            locals = frame.f_locals

        memo = _TypeCheckMemo(globals, locals)

    expected_type = resolve_forwardref(expected_type, memo)
    origin_type = getattr(expected_type, '__origin__', None)
    if origin_type is not None:
        checker_func = origin_type_checkers.get(origin_type)
        if checker_func:
            checker_func(argname, value, expected_type, memo)
        else:
            check_type(argname, value, origin_type, memo)
    elif _UnionType is not None and isinstance(expected_type, _UnionType):
        # ``X | Y`` has no __origin__ but exposes its members in __args__
        check_union(argname, value, expected_type, memo)
    elif isclass(expected_type):
        if issubclass(expected_type, Tuple):
            check_tuple(argname, value, expected_type, memo)
        elif issubclass(expected_type, (float, complex)):
            check_number(argname, value, expected_type)
        elif _subclass_check_unions and issubclass(expected_type, Union):
            check_union(argname, value, expected_type, memo)
        elif isinstance(expected_type, TypeVar):
            check_typevar(argname, value, expected_type, memo)
        elif issubclass(expected_type, IO):
            check_io(argname, value, expected_type)
        elif issubclass(expected_type, dict) and hasattr(expected_type, '__total__'):
            # TypedDict classes.  Plain dict subclasses (which also expose __annotations__ on
            # Python 3.10+) fall through to the isinstance() check below
            check_typed_dict(argname, value, expected_type, memo)
        elif getattr(expected_type, '_is_protocol', False):
            check_protocol(argname, value, expected_type)
        else:
            # Older typing generic classes expose their runtime counterpart in __extra__
            expected_class = getattr(expected_type, '__extra__', None) or expected_type
            # As per https://github.com/python/typing/issues/552, bytearray is acceptable
            # where bytes is expected
            accepted = (bytearray, bytes) if expected_class is bytes else expected_class
            if not isinstance(value, accepted):
                raise TypeError(
                    'type of {} must be {}; got {} instead'.
                    format(argname, qualified_name(expected_class), qualified_name(value)))
    elif isinstance(expected_type, TypeVar):
        # Only happens on < 3.6
        check_typevar(argname, value, expected_type, memo)
    elif (Literal is not None and isinstance(expected_type, Literal.__class__) and
          (getattr(expected_type, '__args__', None) or
           getattr(expected_type, '__values__', None))):
        # Only happens on < 3.7 when using Literal from typing_extensions.  Other special forms
        # sharing Literal's class (NoReturn, Self, LiteralString, ...) carry no values and are
        # not checked here
        check_literal(argname, value, expected_type, memo)
    elif hasattr(expected_type, '__supertype__'):
        # typing.NewType (a function before Python 3.10, an instance of typing.NewType since):
        # check against the supertype (recursively)
        return check_type(argname, value, expected_type.__supertype__, memo)
def check_return_type(retval, memo: Optional[_CallMemo] = None) -> bool:
    """
    Check that the return value is compatible with the return value annotation in the function.

    Without ``memo``, the function is looked up from the calling frame.  If that lookup fails
    or is ambiguous, no check is performed.

    :param retval: the value about to be returned from the call
    :param memo: call state of the function; built from the calling frame when omitted
    :return: ``True``
    :raises TypeError: if there is a type mismatch
    """
    if memo is None:
        # faster than inspect.currentframe(), but not officially
        # supported in all python implementations
        frame = sys._getframe(1)
        try:
            func = find_function(frame)
        except LookupError:
            return True  # This can happen with the Pydev/PyCharm debugger extension installed

        if func is None:
            # Several functions share the frame's code object, so we cannot tell which one runs
            return True

        memo = _CallMemo(func, frame.f_locals)

    if 'return' in memo.type_hints:
        if memo.type_hints['return'] is NoReturn:
            raise TypeError('{}() was declared never to return but it did'.format(memo.func_name))

        try:
            check_type('the return value', retval, memo.type_hints['return'], memo)
        except TypeError as exc:  # suppress unnecessarily long tracebacks
            raise exc from None

    return True
def check_argument_types(memo: Optional[_CallMemo] = None) -> bool:
    """
    Check that the argument values match the annotated types.

    Unless ``memo`` is provided, the information is retrieved from the previous stack frame
    (i.e. from the function that called this).  If that function cannot be identified
    unambiguously, no check is performed.

    :param memo: call state of the function; built from the calling frame when omitted
    :return: ``True``
    :raises TypeError: if there is an argument type mismatch
    """
    if memo is None:
        # faster than inspect.currentframe(), but not officially
        # supported in all python implementations
        frame = sys._getframe(1)
        try:
            func = find_function(frame)
        except LookupError:
            return True  # This can happen with the Pydev/PyCharm debugger extension installed

        if func is None:
            # Several functions share the frame's code object, so we cannot tell which one runs
            return True

        memo = _CallMemo(func, frame.f_locals)

    for argname, expected_type in memo.type_hints.items():
        if argname != 'return' and argname in memo.arguments:
            value = memo.arguments[argname]
            description = 'argument "{}"'.format(argname)
            try:
                check_type(description, value, expected_type, memo)
            except TypeError as exc:  # suppress unnecessarily long tracebacks
                raise exc from None

    return True
class TypeCheckedGenerator:
    """
    Generator proxy that type checks the values flowing through a wrapped generator.

    The yield, send and return types are read, in that order, from the ``__args__`` of the
    function's return annotation; any that are missing default to ``Any``.  Attributes not
    defined here are looked up on the wrapped generator.
    """

    def __init__(self, wrapped: Generator, memo: _CallMemo):
        hint = memo.type_hints['return']
        type_args = hint.__args__ if hasattr(hint, "__args__") else []
        count = len(type_args)
        self.__wrapped = wrapped
        self.__memo = memo
        self.__yield_type = type_args[0] if count > 0 else Any
        self.__send_type = type_args[1] if count > 1 else Any
        self.__return_type = type_args[2] if count > 2 else Any
        # The first send() must be None and is therefore not type checked
        self.__initialized = False

    def __iter__(self):
        return self

    def __next__(self):
        return self.send(None)

    def __getattr__(self, name: str) -> Any:
        return getattr(self.__wrapped, name)

    def throw(self, *args):
        return self.__wrapped.throw(*args)

    def close(self):
        self.__wrapped.close()

    def send(self, obj):
        if not self.__initialized:
            self.__initialized = True
        else:
            check_type('value sent to generator', obj, self.__send_type, memo=self.__memo)

        try:
            value = self.__wrapped.send(obj)
        except StopIteration as exc:
            # The generator finished: its return value travels in StopIteration.value
            check_type('return value', exc.value, self.__return_type, memo=self.__memo)
            raise

        check_type('value yielded from generator', value, self.__yield_type, memo=self.__memo)
        return value
class TypeCheckedAsyncGenerator:
    """
    Asynchronous generator proxy that type checks yielded and sent values.

    The yield and send types are read from the ``__args__`` of the function's return
    annotation; missing ones default to ``Any``.  Attributes not defined here are looked up on
    the wrapped asynchronous generator.
    """

    def __init__(self, wrapped: AsyncGenerator, memo: _CallMemo):
        # Unparametrized annotations (e.g. a bare AsyncIterator) may have no __args__ at all
        rtype_args = getattr(memo.type_hints['return'], '__args__', None) or ()
        self.__wrapped = wrapped
        self.__memo = memo
        self.__yield_type = rtype_args[0] if rtype_args else Any
        self.__send_type = rtype_args[1] if len(rtype_args) > 1 else Any
        # The first asend() must be None and is therefore not type checked
        self.__initialized = False

    def __aiter__(self):
        # Must be a plain method returning the iterator itself: since Python 3.7, ``async for``
        # rejects an __aiter__ that returns an awaitable (as an ``async def`` would)
        return self

    def __anext__(self):
        return self.asend(None)

    def __getattr__(self, name: str) -> Any:
        return getattr(self.__wrapped, name)

    def athrow(self, *args):
        return self.__wrapped.athrow(*args)

    def aclose(self):
        return self.__wrapped.aclose()

    async def asend(self, obj):
        if self.__initialized:
            check_type('value sent to generator', obj, self.__send_type, memo=self.__memo)
        else:
            self.__initialized = True

        value = await self.__wrapped.asend(obj)
        check_type('value yielded from generator', value, self.__yield_type, memo=self.__memo)
        return value
# Typing-only signatures of typechecked().  Called without a function, it returns a decorator;
# called with a function or class, it returns an object of the same type.
@overload
def typechecked(*, always: bool = False) -> Callable[[T_CallableOrType], T_CallableOrType]:
    ...


@overload
def typechecked(func: T_CallableOrType, *, always: bool = False) -> T_CallableOrType:
    ...
def typechecked(func=None, *, always=False, _localns: Optional[Dict[str, Any]] = None):
    """
    Perform runtime type checking on the arguments that are passed to the wrapped function.

    The return value is also checked against the return annotation if any. If the function
    returns a generator or async generator and its return annotation has a generator origin
    type, the returned generator is wrapped so that the values flowing through it are checked
    too.

    If the ``__debug__`` global variable is set to ``False``, no wrapping and therefore no type
    checking is done, unless ``always`` is ``True``.

    This can also be used as a class decorator. This will wrap all type annotated methods in the
    class with this decorator.

    Callables without annotations, or without an associated Python code object, are returned
    unchanged with a warning.

    :param func: the function or class to enable type checking for
    :param always: ``True`` to enable type checks even in optimized mode
    :param _localns: internal; namespace used to resolve forward references in annotations
        (defaults to the local namespace of the frame that applied the decorator)
    """
    if func is None:
        # Used as @typechecked(...): return a decorator bound to the given options
        return partial(typechecked, always=always, _localns=_localns)

    if not __debug__ and not always:  # pragma: no cover
        return func

    if isclass(func):
        prefix = func.__qualname__ + '.'
        for key, attr in func.__dict__.items():
            if inspect.isfunction(attr) or inspect.ismethod(attr) or inspect.isclass(attr):
                # Only wrap annotated members defined in this class body, not ones that were
                # merely assigned from elsewhere
                if attr.__qualname__.startswith(prefix) and getattr(attr, '__annotations__', None):
                    setattr(func, key, typechecked(attr, always=always, _localns=func.__dict__))
            elif isinstance(attr, (classmethod, staticmethod)):
                if getattr(attr.__func__, '__annotations__', None):
                    wrapped = typechecked(attr.__func__, always=always, _localns=func.__dict__)
                    setattr(func, key, type(attr)(wrapped))

        return func

    # Find the frame in which the function was declared, for resolving forward references later.
    # This must stay directly inside typechecked() so that frame 1 is the decorating caller.
    if _localns is None:
        _localns = sys._getframe(1).f_locals

    # Find either the first Python wrapper or the actual function
    python_func = inspect.unwrap(func, stop=lambda f: hasattr(f, '__code__'))

    if not getattr(func, '__annotations__', None):
        warn('no type annotations present -- not typechecking {}'.format(function_name(func)))
        return func

    # The already-wrapped check at the end reads python_func.__code__; a callable with no code
    # object anywhere in its unwrap chain would otherwise crash with AttributeError
    if not getattr(python_func, '__code__', None):
        warn('no code associated -- not typechecking {}'.format(function_name(func)))
        return func

    def wrapper(*args, **kwargs):
        memo = _CallMemo(python_func, _localns, args=args, kwargs=kwargs)
        check_argument_types(memo)
        retval = func(*args, **kwargs)
        check_return_type(retval, memo)

        # If a generator is returned, wrap it if its yield/send/return types can be checked
        if inspect.isgenerator(retval) or isasyncgen(retval):
            return_type = memo.type_hints.get('return')
            if return_type:
                origin = getattr(return_type, '__origin__', None)
                if origin in generator_origin_types:
                    return TypeCheckedGenerator(retval, memo)
                elif origin is not None and origin in asyncgen_origin_types:
                    return TypeCheckedAsyncGenerator(retval, memo)

        return retval

    async def async_wrapper(*args, **kwargs):
        memo = _CallMemo(python_func, _localns, args=args, kwargs=kwargs)
        check_argument_types(memo)
        retval = await func(*args, **kwargs)
        check_return_type(retval, memo)
        return retval

    # Every wrapper/async_wrapper closure shares one code object, so a matching __code__ means
    # python_func is already one of our own wrappers and must not be wrapped a second time
    if inspect.iscoroutinefunction(func):
        if python_func.__code__ is not async_wrapper.__code__:
            return wraps(func)(async_wrapper)
    else:
        if python_func.__code__ is not wrapper.__code__:
            return wraps(func)(wrapper)

    # the target callable was already wrapped
    return func
class TypeWarning(UserWarning):
    """
    A warning that is emitted when a type check fails.

    :ivar str event: ``call`` or ``return``
    :ivar Callable func: the function in which the violation occurred (the called function if event
        is ``call``, or the function where a value of the wrong type was returned from if event is
        ``return``)
    :ivar str error: the error message contained by the caught :class:`TypeError`
    :ivar frame: the frame in which the violation occurred
    """

    # Keep in sync with the attributes assigned in __init__
    __slots__ = ('func', 'event', 'error', 'frame')

    def __init__(self, memo: Optional[_CallMemo], event: str, frame,
                 exception: Union[str, TypeError]):  # pragma: no cover
        # NOTE(review): memo is annotated Optional, but memo.func is read unconditionally, so
        # passing None raises AttributeError -- confirm whether any caller passes None.
        self.func = memo.func
        self.event = event
        self.error = str(exception)
        self.frame = frame

        # For calls, report the location of the caller (the frame's parent); for returns, report
        # the location inside the function that returned the offending value
        if self.event == 'call':
            caller_frame = self.frame.f_back
            location = 'call to {}() from {}:{}'.format(
                function_name(self.func), caller_frame.f_code.co_filename, caller_frame.f_lineno)
        else:
            location = 'return from {}() at {}:{}'.format(
                function_name(self.func), self.frame.f_code.co_filename, self.frame.f_lineno)

        super().__init__('[{thread_name}] {event}: {self.error}'.format(
            thread_name=threading.current_thread().name, event=location, self=self))

    @property
    def stack(self):
        """Return the stack where the last frame is from the target function."""
        return extract_stack(self.frame)

    def print_stack(self, file: Optional[TextIO] = None, limit: Optional[int] = None) -> None:
        """
        Print the traceback from the stack frame where the target function was run.

        :param file: an open file to print to (prints to ``sys.stderr`` if omitted)
        :param limit: the maximum number of stack frames to print
        """
        print_stack(self.frame, limit, file)
class TypeChecker:
"""
A type checker that collects type violations by hooking into :func:`sys.setprofile`.
:param packages: list of top level modules and packages or modules to include for type checking
:param all_threads: ``True`` to check types in all threads created while the checker is
running, ``False`` to only check in the current one
:param forward_refs_policy: how to handle unresolvable forward references in annotations
.. deprecated:: 2.6
Use :func:`~.importhook.install_import_hook` instead. This class will be removed in v3.0.
"""
def __init__(self, packages: Union[str, Sequence[str]], *, all_threads: bool = True,
forward_refs_policy: ForwardRefPolicy = ForwardRefPolicy.ERROR):
assert check_argument_types()
warn('TypeChecker has been deprecated and will be removed in v3.0. '
'Use install_import_hook() or the pytest plugin instead.', DeprecationWarning)
self.all_threads = all_threads
self.annotation_policy = forward_refs_policy
self._call_memos = {} # type: Dict[Any, _CallMemo]
self._previous_profiler = None
self._previous_thread_profiler = None
self._active = False
if isinstance(packages, str):
self._packages = (packages,)
else:
self._packages = tuple(packages)
@property
def active(self) -> bool: