Skip to content

Commit

Permalink
Add get_overloads() (#1140)
Browse files Browse the repository at this point in the history
Co-authored-by: Alex Waygood <Alex.Waygood@Gmail.com>
  • Loading branch information
JelleZijlstra and AlexWaygood committed Apr 16, 2022
1 parent 2acaa5a commit 35dff91
Show file tree
Hide file tree
Showing 4 changed files with 152 additions and 3 deletions.
5 changes: 4 additions & 1 deletion typing_extensions/CHANGELOG
@@ -1,6 +1,9 @@
# Unreleased

- Add `typing.assert_type`. Backport from bpo-46480.
- Add `typing_extensions.get_overloads` and
`typing_extensions.clear_overloads`, and add registry support to
`typing_extensions.overload`. Backport from python/cpython#89263.
- Add `typing_extensions.assert_type`. Backport from bpo-46480.
- Drop support for Python 3.6. Original patch by Adam Turner (@AA-Turner).

# Release 4.1.1 (February 13, 2022)
Expand Down
6 changes: 6 additions & 0 deletions typing_extensions/README.rst
Expand Up @@ -47,6 +47,8 @@ This module currently contains the following:

- ``assert_never``
- ``assert_type``
- ``clear_overloads``
- ``get_overloads``
- ``LiteralString`` (see PEP 675)
- ``Never``
- ``NotRequired`` (see PEP 655)
Expand Down Expand Up @@ -122,6 +124,10 @@ Certain objects were changed after they were added to ``typing``, and
Python 3.8 and lack support for ``ParamSpecArgs`` and ``ParamSpecKwargs``
in 3.9.
- ``@final`` was changed in Python 3.11 to set the ``.__final__`` attribute.
- ``@overload`` was changed in Python 3.11 to make function overloads
introspectable at runtime. In order to access overloads with
``typing_extensions.get_overloads()``, you must use
``@typing_extensions.overload``.

There are a few types whose interface was modified between different
versions of typing. For example, ``typing.Sequence`` was modified to
Expand Down
74 changes: 73 additions & 1 deletion typing_extensions/src/test_typing_extensions.py
Expand Up @@ -3,13 +3,15 @@
import abc
import contextlib
import collections
from collections import defaultdict
import collections.abc
from functools import lru_cache
import inspect
import pickle
import subprocess
import types
from unittest import TestCase, main, skipUnless, skipIf
from unittest.mock import patch
from test import ann_module, ann_module2, ann_module3
import typing
from typing import TypeVar, Optional, Union, Any, AnyStr
Expand All @@ -21,9 +23,10 @@
from typing_extensions import NoReturn, ClassVar, Final, IntVar, Literal, Type, NewType, TypedDict, Self
from typing_extensions import TypeAlias, ParamSpec, Concatenate, ParamSpecArgs, ParamSpecKwargs, TypeGuard
from typing_extensions import Awaitable, AsyncIterator, AsyncContextManager, Required, NotRequired
from typing_extensions import Protocol, runtime, runtime_checkable, Annotated, overload, final, is_typeddict
from typing_extensions import Protocol, runtime, runtime_checkable, Annotated, final, is_typeddict
from typing_extensions import TypeVarTuple, Unpack, dataclass_transform, reveal_type, Never, assert_never, LiteralString
from typing_extensions import assert_type, get_type_hints, get_origin, get_args
from typing_extensions import clear_overloads, get_overloads, overload

# Flags used to mark tests that only apply after a specific
# version of the typing module.
Expand Down Expand Up @@ -403,6 +406,20 @@ def test_no_multiple_subscripts(self):
Literal[1][1]


class MethodHolder:
@classmethod
def clsmethod(cls): ...
@staticmethod
def stmethod(): ...
def method(self): ...


if TYPING_3_11_0:
registry_holder = typing
else:
registry_holder = typing_extensions


class OverloadTests(BaseTestCase):

def test_overload_fails(self):
Expand All @@ -424,6 +441,61 @@ def blah():

blah()

def set_up_overloads(self):
def blah():
pass

overload1 = blah
overload(blah)

def blah():
pass

overload2 = blah
overload(blah)

def blah():
pass

return blah, [overload1, overload2]

# Make sure we don't clear the global overload registry
@patch(
f"{registry_holder.__name__}._overload_registry",
defaultdict(lambda: defaultdict(dict))
)
def test_overload_registry(self):
registry = registry_holder._overload_registry
# The registry starts out empty
self.assertEqual(registry, {})

impl, overloads = self.set_up_overloads()
self.assertNotEqual(registry, {})
self.assertEqual(list(get_overloads(impl)), overloads)

def some_other_func(): pass
overload(some_other_func)
other_overload = some_other_func
def some_other_func(): pass
self.assertEqual(list(get_overloads(some_other_func)), [other_overload])

# Make sure that after we clear all overloads, the registry is
# completely empty.
clear_overloads()
self.assertEqual(registry, {})
self.assertEqual(get_overloads(impl), [])

# Querying a function with no overloads shouldn't change the registry.
def the_only_one(): pass
self.assertEqual(get_overloads(the_only_one), [])
self.assertEqual(registry, {})

def test_overload_registry_repeated(self):
for _ in range(2):
impl, overloads = self.set_up_overloads()

self.assertEqual(list(get_overloads(impl)), overloads)


class AssertTypeTests(BaseTestCase):

Expand Down
70 changes: 69 additions & 1 deletion typing_extensions/src/typing_extensions.py
@@ -1,6 +1,7 @@
import abc
import collections
import collections.abc
import functools
import operator
import sys
import types as _types
Expand Down Expand Up @@ -46,7 +47,9 @@
'Annotated',
'assert_never',
'assert_type',
'clear_overloads',
'dataclass_transform',
'get_overloads',
'final',
'get_args',
'get_origin',
Expand Down Expand Up @@ -249,7 +252,72 @@ def __getitem__(self, parameters):


_overload_dummy = typing._overload_dummy # noqa
overload = typing.overload


if hasattr(typing, "get_overloads"): # 3.11+
overload = typing.overload
get_overloads = typing.get_overloads
clear_overloads = typing.clear_overloads
else:
# {module: {qualname: {firstlineno: func}}}
_overload_registry = collections.defaultdict(
functools.partial(collections.defaultdict, dict)
)

def overload(func):
"""Decorator for overloaded functions/methods.
In a stub file, place two or more stub definitions for the same
function in a row, each decorated with @overload. For example:
@overload
def utf8(value: None) -> None: ...
@overload
def utf8(value: bytes) -> bytes: ...
@overload
def utf8(value: str) -> bytes: ...
In a non-stub file (i.e. a regular .py file), do the same but
follow it with an implementation. The implementation should *not*
be decorated with @overload. For example:
@overload
def utf8(value: None) -> None: ...
@overload
def utf8(value: bytes) -> bytes: ...
@overload
def utf8(value: str) -> bytes: ...
def utf8(value):
# implementation goes here
The overloads for a function can be retrieved at runtime using the
get_overloads() function.
"""
# classmethod and staticmethod
f = getattr(func, "__func__", func)
try:
_overload_registry[f.__module__][f.__qualname__][
f.__code__.co_firstlineno
] = func
except AttributeError:
# Not a normal function; ignore.
pass
return _overload_dummy

def get_overloads(func):
"""Return all defined overloads for *func* as a sequence."""
# classmethod and staticmethod
f = getattr(func, "__func__", func)
if f.__module__ not in _overload_registry:
return []
mod_dict = _overload_registry[f.__module__]
if f.__qualname__ not in mod_dict:
return []
return list(mod_dict[f.__qualname__].values())

def clear_overloads():
"""Clear all overloads in the registry."""
_overload_registry.clear()


# This is not a real generic class. Don't use outside annotations.
Expand Down

0 comments on commit 35dff91

Please sign in to comment.