Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add get_overloads() #1140

Merged
merged 5 commits into from Apr 16, 2022
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
5 changes: 4 additions & 1 deletion typing_extensions/CHANGELOG
@@ -1,6 +1,9 @@
# Unreleased

- Add `typing.assert_type`. Backport from bpo-46480.
- Add `typing_extensions.get_overloads` and
`typing_extensions.clear_overloads`, and add registry support to
`typing_extensions.overload`. Backport from python/cpython#89263.
- Add `typing_extensions.assert_type`. Backport from bpo-46480.
- Drop support for Python 3.6. Original patch by Adam Turner (@AA-Turner).

# Release 4.1.1 (February 13, 2022)
Expand Down
5 changes: 5 additions & 0 deletions typing_extensions/README.rst
Expand Up @@ -47,6 +47,8 @@ This module currently contains the following:

- ``assert_never``
- ``assert_type``
- ``clear_overloads``
- ``get_overloads``
- ``LiteralString`` (see PEP 675)
- ``Never``
- ``NotRequired`` (see PEP 655)
Expand Down Expand Up @@ -122,6 +124,9 @@ Certain objects were changed after they were added to ``typing``, and
Python 3.8 and lack support for ``ParamSpecArgs`` and ``ParamSpecKwargs``
in 3.9.
- ``@final`` was changed in Python 3.11 to set the ``.__final__`` attribute.
- ``@overload`` was changed in Python 3.11 to register overload function.
In order to access overloads with ``typing_extensions.get_overloads()``,
you must use ``@typing_extensions.overload``.
JelleZijlstra marked this conversation as resolved.
Show resolved Hide resolved

There are a few types whose interface was modified between different
versions of typing. For example, ``typing.Sequence`` was modified to
Expand Down
74 changes: 73 additions & 1 deletion typing_extensions/src/test_typing_extensions.py
Expand Up @@ -3,13 +3,15 @@
import abc
import contextlib
import collections
from collections import defaultdict
import collections.abc
from functools import lru_cache
import inspect
import pickle
import subprocess
import types
from unittest import TestCase, main, skipUnless, skipIf
from unittest.mock import patch
from test import ann_module, ann_module2, ann_module3
import typing
from typing import TypeVar, Optional, Union, Any, AnyStr
Expand All @@ -21,9 +23,10 @@
from typing_extensions import NoReturn, ClassVar, Final, IntVar, Literal, Type, NewType, TypedDict, Self
from typing_extensions import TypeAlias, ParamSpec, Concatenate, ParamSpecArgs, ParamSpecKwargs, TypeGuard
from typing_extensions import Awaitable, AsyncIterator, AsyncContextManager, Required, NotRequired
from typing_extensions import Protocol, runtime, runtime_checkable, Annotated, overload, final, is_typeddict
from typing_extensions import Protocol, runtime, runtime_checkable, Annotated, final, is_typeddict
from typing_extensions import TypeVarTuple, Unpack, dataclass_transform, reveal_type, Never, assert_never, LiteralString
from typing_extensions import assert_type, get_type_hints, get_origin, get_args
from typing_extensions import clear_overloads, get_overloads, overload

# Flags used to mark tests that only apply after a specific
# version of the typing module.
Expand Down Expand Up @@ -403,6 +406,20 @@ def test_no_multiple_subscripts(self):
Literal[1][1]


class MethodHolder:
@classmethod
def clsmethod(cls): ...
@staticmethod
def stmethod(): ...
def method(self): ...


if TYPING_3_11_0:
registry_holder = typing
else:
registry_holder = typing_extensions


class OverloadTests(BaseTestCase):

def test_overload_fails(self):
Expand All @@ -424,6 +441,61 @@ def blah():

blah()

def set_up_overloads(self):
def blah():
pass

overload1 = blah
overload(blah)

def blah():
pass

overload2 = blah
overload(blah)

def blah():
pass

return blah, [overload1, overload2]

# Make sure we don't clear the global overload registry
@patch(
f"{registry_holder.__name__}._overload_registry",
defaultdict(lambda: defaultdict(dict))
)
def test_overload_registry(self):
registry = registry_holder._overload_registry
# The registry starts out empty
self.assertEqual(registry, {})

impl, overloads = self.set_up_overloads()
self.assertNotEqual(registry, {})
self.assertEqual(list(get_overloads(impl)), overloads)

def some_other_func(): pass
overload(some_other_func)
other_overload = some_other_func
def some_other_func(): pass
self.assertEqual(list(get_overloads(some_other_func)), [other_overload])

# Make sure that after we clear all overloads, the registry is
# completely empty.
clear_overloads()
self.assertEqual(registry, {})
self.assertEqual(get_overloads(impl), [])

# Querying a function with no overloads shouldn't change the registry.
def the_only_one(): pass
self.assertEqual(get_overloads(the_only_one), [])
self.assertEqual(registry, {})

def test_overload_registry_repeated(self):
for _ in range(2):
impl, overloads = self.set_up_overloads()

self.assertEqual(list(get_overloads(impl)), overloads)


class AssertTypeTests(BaseTestCase):

Expand Down
70 changes: 69 additions & 1 deletion typing_extensions/src/typing_extensions.py
@@ -1,6 +1,7 @@
import abc
import collections
import collections.abc
import functools
import operator
import sys
import types as _types
Expand Down Expand Up @@ -46,7 +47,9 @@
'Annotated',
'assert_never',
'assert_type',
'clear_overloads',
'dataclass_transform',
'get_overloads',
'final',
'get_args',
'get_origin',
Expand Down Expand Up @@ -249,7 +252,72 @@ def __getitem__(self, parameters):


_overload_dummy = typing._overload_dummy # noqa
overload = typing.overload


if hasattr(typing, "get_overloads"): # 3.11+
overload = typing.overload
get_overloads = typing.get_overloads
clear_overloads = typing.clear_overloads
else:
# {module: {qualname: {firstlineno: func}}}
_overload_registry = collections.defaultdict(
functools.partial(collections.defaultdict, dict)
)

def overload(func):
"""Decorator for overloaded functions/methods.

In a stub file, place two or more stub definitions for the same
function in a row, each decorated with @overload. For example:

@overload
def utf8(value: None) -> None: ...
@overload
def utf8(value: bytes) -> bytes: ...
@overload
def utf8(value: str) -> bytes: ...

In a non-stub file (i.e. a regular .py file), do the same but
follow it with an implementation. The implementation should *not*
be decorated with @overload. For example:

@overload
def utf8(value: None) -> None: ...
@overload
def utf8(value: bytes) -> bytes: ...
@overload
def utf8(value: str) -> bytes: ...
def utf8(value):
# implementation goes here

The overloads for a function can be retrieved at runtime using the
get_overloads() function.
"""
# classmethod and staticmethod
f = getattr(func, "__func__", func)
try:
_overload_registry[f.__module__][f.__qualname__][
f.__code__.co_firstlineno
] = func
except AttributeError:
# Not a normal function; ignore.
pass
return _overload_dummy

def get_overloads(func):
"""Return all defined overloads for *func* as a sequence."""
# classmethod and staticmethod
f = getattr(func, "__func__", func)
if f.__module__ not in _overload_registry:
return []
mod_dict = _overload_registry[f.__module__]
if f.__qualname__ not in mod_dict:
return []
return list(mod_dict[f.__qualname__].values())

def clear_overloads():
"""Clear all overloads in the registry."""
_overload_registry.clear()


# This is not a real generic class. Don't use outside annotations.
Expand Down