Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Interaction with get_type_hints and Annotated <3.11. #312

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/test_typing_extensions.py
Expand Up @@ -5255,7 +5255,7 @@ def test_typing_extensions_defers_when_possible(self):
if sys.version_info < (3, 10, 1):
exclude |= {"Literal"}
if sys.version_info < (3, 11):
exclude |= {'final', 'Any', 'NewType'}
exclude |= {'final', 'Any', 'NewType', 'ForwardRef'}
if sys.version_info < (3, 12):
exclude |= {
'SupportsAbs', 'SupportsBytes',
Expand Down
264 changes: 254 additions & 10 deletions src/typing_extensions.py
Expand Up @@ -1042,7 +1042,143 @@ def greet(name: str) -> None:

if hasattr(typing, "Required"): # 3.11+
get_type_hints = typing.get_type_hints
ForwardRef = typing.ForwardRef

else: # <=3.10
if hasattr(_types, "GenericAlias"):
_generic_alias_bases = (typing._GenericAlias, _types.GenericAlias)
else:
_generic_alias_bases = (typing._GenericAlias,)


def _is_union_type(t):
return hasattr(_types, "UnionType") and isinstance(t, _types.UnionType)

class ForwardRef(typing._Final, _root=True):
"""Internal wrapper to hold a forward reference."""

__slots__ = ('__forward_arg__', '__forward_code__',
'__forward_evaluated__', '__forward_value__',
'__forward_is_argument__', '__forward_is_class__',
'__forward_module__')

def __init__(self, arg, is_argument=True, module=None, *, is_class=False):
if not isinstance(arg, str):
raise TypeError(f"Forward reference must be a string -- got {arg!r}")

# If we do `def f(*args: *Ts)`, then we'll have `arg = '*Ts'`.
# Unfortunately, this isn't a valid expression on its own, so we
# do the unpacking manually.
if arg[0] == '*':
arg_to_compile = f'({arg},)[0]' # E.g. (*Ts,)[0] or (*tuple[int, int],)[0]
else:
arg_to_compile = arg
try:
code = compile(arg_to_compile, '<string>', 'eval')
except SyntaxError:
raise SyntaxError(f"Forward reference must be an expression -- got {arg!r}")

self.__forward_arg__ = arg
self.__forward_code__ = code
self.__forward_evaluated__ = False
self.__forward_value__ = None
self.__forward_is_argument__ = is_argument
self.__forward_is_class__ = is_class
self.__forward_module__ = module

def _evaluate(self, globalns, localns, recursive_guard):
if self.__forward_arg__ in recursive_guard:
return self
if not self.__forward_evaluated__ or localns is not globalns:
if globalns is None and localns is None:
globalns = localns = {}
elif globalns is None:
globalns = localns
elif localns is None:
localns = globalns
if self.__forward_module__ is not None:
globalns = getattr(
sys.modules.get(self.__forward_module__, None), '__dict__', globalns
)
type_ = _type_check(
eval(self.__forward_code__, globalns, localns),
"Forward references must evaluate to types.",
is_argument=self.__forward_is_argument__,
allow_special_forms=self.__forward_is_class__,
)
self.__forward_value__ = _eval_type(
type_, globalns, localns, recursive_guard | {self.__forward_arg__}
)
self.__forward_evaluated__ = True
return self.__forward_value__

def __eq__(self, other):
if not isinstance(other, (ForwardRef, typing.ForwardRef)):
return NotImplemented
if self.__forward_evaluated__ and other.__forward_evaluated__:
return (self.__forward_arg__ == other.__forward_arg__ and
self.__forward_value__ == other.__forward_value__)
return (self.__forward_arg__ == other.__forward_arg__ and
self.__forward_module__ == other.__forward_module__)

def __hash__(self):
return hash((self.__forward_arg__, self.__forward_module__))

def __or__(self, other):
return Union[self, other]

def __ror__(self, other):
return Union[other, self]

def __repr__(self):
if self.__forward_module__ is None:
module_repr = ''
else:
module_repr = f', module={self.__forward_module__!r}'
return f'ForwardRef({self.__forward_arg__!r}{module_repr})'


def _type_check(arg, msg, is_argument=True, module=None, *, allow_special_forms=False):
"""Check that the argument is a type, and return it (internal helper).

As a special case, accept None and return type(None) instead. Also wrap strings
into ForwardRef instances. Consider several corner cases, for example plain
special forms like Union are not valid, while Union[int, str] is OK, etc.
The msg argument is a human-readable error message, e.g.::

"Union[arg, ...]: arg should be a type."

We append the repr() of the actual value (truncated to 100 chars).
"""
invalid_generic_forms = (Generic, Protocol)
if not allow_special_forms:
invalid_generic_forms += (ClassVar,)
if is_argument:
invalid_generic_forms += (Final,)

arg = _type_convert(arg, module=module, allow_special_forms=allow_special_forms)
if (isinstance(arg, typing._GenericAlias) and
arg.__origin__ in invalid_generic_forms):
raise TypeError(f"{arg} is not valid as type argument")
if arg in (Any, LiteralString, NoReturn, Never, Self, TypeAlias):
return arg
if allow_special_forms and arg in (ClassVar, Final):
return arg
if isinstance(arg, _SpecialForm) or arg in (Generic, Protocol):
raise TypeError(f"Plain {arg} is not valid as type argument")
if type(arg) is tuple:
raise TypeError(f"{msg} Got {arg!r:.100}.")
return arg


def _type_convert(arg, module=None, *, allow_special_forms=False):
"""For converting None to type(None), and strings to ForwardRef."""
if arg is None:
return type(None)
if isinstance(arg, str):
return ForwardRef(arg, module=module, is_class=allow_special_forms)
return arg

# replaces _strip_annotations()
def _strip_extras(t):
"""Strips Annotated, Required and NotRequired from a given type."""
Expand All @@ -1060,7 +1196,7 @@ def _strip_extras(t):
if stripped_args == t.__args__:
return t
return _types.GenericAlias(t.__origin__, stripped_args)
if hasattr(_types, "UnionType") and isinstance(t, _types.UnionType):
if _is_union_type(t):
stripped_args = tuple(_strip_extras(a) for a in t.__args__)
if stripped_args == t.__args__:
return t
Expand Down Expand Up @@ -1100,15 +1236,124 @@ def get_type_hints(obj, globalns=None, localns=None, include_extras=False):
- If two dict arguments are passed, they specify globals and
locals, respectively.
"""
if hasattr(typing, "Annotated"): # 3.9+
hint = typing.get_type_hints(
obj, globalns=globalns, localns=localns, include_extras=True
if getattr(obj, "__no_type_check__", None):
return {}
# Classes require a special treatment.
if isinstance(obj, type):
hints = {}
for base in reversed(obj.__mro__):
if globalns is None:
base_globals = getattr(
sys.modules.get(base.__module__, None), "__dict__", {}
)
else:
base_globals = globalns
ann = base.__dict__.get("__annotations__", {})
if isinstance(ann, _types.GetSetDescriptorType):
ann = {}
base_locals = dict(vars(base)) if localns is None else localns
if localns is None and globalns is None:
# This is surprising, but required. Before Python 3.10,
# get_type_hints only evaluated the globalns of
# a class. To maintain backwards compatibility, we reverse
# the globalns and localns order so that eval() looks into
# *base_globals* first rather than *base_locals*.
# This only affects ForwardRefs.
base_globals, base_locals = base_locals, base_globals
for name, value in ann.items():
if value is None:
value = type(None)
if isinstance(value, str):
value = ForwardRef(value, is_argument=False, is_class=True)
value = _eval_type(value, base_globals, base_locals)
hints[name] = value
return (
hints
if include_extras
else {k: _strip_extras(t) for k, t in hints.items()}
)
else: # 3.8
hint = typing.get_type_hints(obj, globalns=globalns, localns=localns)
if include_extras:
return hint
return {k: _strip_extras(t) for k, t in hint.items()}

if globalns is None:
if isinstance(obj, _types.ModuleType):
globalns = obj.__dict__
else:
nsobj = obj
# Find globalns for the unwrapped object.
while hasattr(nsobj, "__wrapped__"):
nsobj = nsobj.__wrapped__
globalns = getattr(nsobj, "__globals__", {})
if localns is None:
localns = globalns
elif localns is None:
localns = globalns
hints = getattr(obj, "__annotations__", None)
if hints is None:
# Return empty annotations for something that _could_ have them.
if isinstance(obj, typing._allowed_types):
return {}
else:
raise TypeError(
"{!r} is not a module, class, method, " "or function.".format(obj)
)
hints = dict(hints)
for name, value in hints.items():
if value is None:
value = type(None)
if isinstance(value, str):
# class-level forward refs were handled above, this must be either
# a module-level annotation or a function argument annotation
value = ForwardRef(
value,
is_argument=not isinstance(obj, _types.ModuleType),
is_class=False,
)
hints[name] = _eval_type(value, globalns, localns)
return (
hints if include_extras else {k: _strip_extras(t) for k, t in hints.items()}
)

def _eval_type(t, globalns, localns, recursive_guard=frozenset()):
if isinstance(t, (ForwardRef, typing.ForwardRef)):
if not hasattr(t, '__forward_is_class__'):
return t._evaluate(globalns, localns)
return t._evaluate(globalns, localns, recursive_guard)


if isinstance(t, _generic_alias_bases) or _is_union_type(t):
if hasattr(_types, "GenericAlias") and isinstance(t, _types.GenericAlias):
args = tuple(
ForwardRef(arg) if isinstance(arg, str) else arg
for arg in t.__args__
)
is_unpacked = t.__unpacked__
if _should_unflatten_callable_args(t, args):
t = t.__origin__[(args[:-1], args[-1])]
else:
t = t.__origin__[args]
if is_unpacked:
t = Unpack[t]
ev_args = tuple(_eval_type(a, globalns, localns, recursive_guard) for a in t.__args__)
if ev_args == t.__args__:
return t
if hasattr(_types, "GenericAlias") and isinstance(t, _types.GenericAlias):
return typing.GenericAlias(t.__origin__, ev_args)
if _is_union_type(t):
return functools.reduce(operator.or_, ev_args)
else:
return t.copy_with(ev_args)
return t

def _should_unflatten_callable_args(typ, args):
return (
typ.__origin__ is collections.abc.Callable
and not (len(args) == 2 and _is_param_expr(args[0]))
)

def _is_param_expr(arg):
return arg is ... or isinstance(arg,
(tuple, list, ParamSpec, _ConcatenateGenericAlias))




# Python 3.9+ has PEP 593 (Annotated)
Expand Down Expand Up @@ -3021,7 +3266,6 @@ def __eq__(self, other: object) -> bool:
Collection = typing.Collection
Container = typing.Container
Dict = typing.Dict
ForwardRef = typing.ForwardRef
FrozenSet = typing.FrozenSet
Generator = typing.Generator
Generic = typing.Generic
Expand Down