Skip to content

Commit

Permalink
gh-89263: Add typing.get_overloads (GH-31716)
Browse files Browse the repository at this point in the history
Based on suggestions by Guido van Rossum, Spencer Brown, and Alex Waygood.

Co-authored-by: Alex Waygood <Alex.Waygood@Gmail.com>
Co-authored-by: Guido van Rossum <gvanrossum@gmail.com>
Co-authored-by: Ken Jin <kenjin4096@gmail.com>
  • Loading branch information
4 people committed Apr 16, 2022
1 parent 9300b6d commit 055760e
Show file tree
Hide file tree
Showing 4 changed files with 133 additions and 4 deletions.
29 changes: 29 additions & 0 deletions Doc/library/typing.rst
Expand Up @@ -2407,6 +2407,35 @@ Functions and decorators

See :pep:`484` for details and comparison with other typing semantics.

.. versionchanged:: 3.11
Overloaded functions can now be introspected at runtime using
:func:`get_overloads`.


.. function:: get_overloads(func)

Return a sequence of :func:`@overload <overload>`-decorated definitions for
*func*. *func* is the function object for the implementation of the
overloaded function. For example, given the definition of ``process`` in
the documentation for :func:`@overload <overload>`,
``get_overloads(process)`` will return a sequence of three function objects
for the three defined overloads. If called on a function with no overloads,
``get_overloads`` returns an empty sequence.

``get_overloads`` can be used for introspecting an overloaded function at
runtime.

.. versionadded:: 3.11


.. function:: clear_overloads()

Clear all registered overloads in the internal registry. This can be used
to reclaim the memory used by the registry.

.. versionadded:: 3.11


.. decorator:: final

A decorator to indicate to type checkers that the decorated method
Expand Down
72 changes: 68 additions & 4 deletions Lib/test/test_typing.py
@@ -1,15 +1,18 @@
import contextlib
import collections
from collections import defaultdict
from functools import lru_cache
import inspect
import pickle
import re
import sys
import warnings
from unittest import TestCase, main, skipUnless, skip
from unittest.mock import patch
from copy import copy, deepcopy

from typing import Any, NoReturn, Never, assert_never
from typing import overload, get_overloads, clear_overloads
from typing import TypeVar, TypeVarTuple, Unpack, AnyStr
from typing import T, KT, VT # Not in __all__.
from typing import Union, Optional, Literal
Expand Down Expand Up @@ -3890,11 +3893,22 @@ def test_or(self):
self.assertEqual("x" | X, Union["x", X])


@lru_cache()
def cached_func(x, y):
return 3 * x + y


class MethodHolder:
@classmethod
def clsmethod(cls): ...
@staticmethod
def stmethod(): ...
def method(self): ...


class OverloadTests(BaseTestCase):

def test_overload_fails(self):
from typing import overload

with self.assertRaises(RuntimeError):

@overload
Expand All @@ -3904,8 +3918,6 @@ def blah():
blah()

def test_overload_succeeds(self):
from typing import overload

@overload
def blah():
pass
Expand All @@ -3915,6 +3927,58 @@ def blah():

blah()

def set_up_overloads(self):
def blah():
pass

overload1 = blah
overload(blah)

def blah():
pass

overload2 = blah
overload(blah)

def blah():
pass

return blah, [overload1, overload2]

# Make sure we don't clear the global overload registry
@patch("typing._overload_registry",
defaultdict(lambda: defaultdict(dict)))
def test_overload_registry(self):
# The registry starts out empty
self.assertEqual(typing._overload_registry, {})

impl, overloads = self.set_up_overloads()
self.assertNotEqual(typing._overload_registry, {})
self.assertEqual(list(get_overloads(impl)), overloads)

def some_other_func(): pass
overload(some_other_func)
other_overload = some_other_func
def some_other_func(): pass
self.assertEqual(list(get_overloads(some_other_func)), [other_overload])

# Make sure that after we clear all overloads, the registry is
# completely empty.
clear_overloads()
self.assertEqual(typing._overload_registry, {})
self.assertEqual(get_overloads(impl), [])

# Querying a function with no overloads shouldn't change the registry.
def the_only_one(): pass
self.assertEqual(get_overloads(the_only_one), [])
self.assertEqual(typing._overload_registry, {})

def test_overload_registry_repeated(self):
for _ in range(2):
impl, overloads = self.set_up_overloads()

self.assertEqual(list(get_overloads(impl)), overloads)


# Definitions needed for features introduced in Python 3.6

Expand Down
34 changes: 34 additions & 0 deletions Lib/typing.py
Expand Up @@ -21,6 +21,7 @@

from abc import abstractmethod, ABCMeta
import collections
from collections import defaultdict
import collections.abc
import contextlib
import functools
Expand Down Expand Up @@ -121,9 +122,11 @@ def _idfunc(_, x):
'assert_type',
'assert_never',
'cast',
'clear_overloads',
'final',
'get_args',
'get_origin',
'get_overloads',
'get_type_hints',
'is_typeddict',
'LiteralString',
Expand Down Expand Up @@ -2450,6 +2453,10 @@ def _overload_dummy(*args, **kwds):
"by an implementation that is not @overload-ed.")


# {module: {qualname: {firstlineno: func}}}
_overload_registry = defaultdict(functools.partial(defaultdict, dict))


def overload(func):
"""Decorator for overloaded functions/methods.
Expand All @@ -2475,10 +2482,37 @@ def utf8(value: bytes) -> bytes: ...
def utf8(value: str) -> bytes: ...
def utf8(value):
# implementation goes here
The overloads for a function can be retrieved at runtime using the
get_overloads() function.
"""
# classmethod and staticmethod
f = getattr(func, "__func__", func)
try:
_overload_registry[f.__module__][f.__qualname__][f.__code__.co_firstlineno] = func
except AttributeError:
# Not a normal function; ignore.
pass
return _overload_dummy


def get_overloads(func):
"""Return all defined overloads for *func* as a sequence."""
# classmethod and staticmethod
f = getattr(func, "__func__", func)
if f.__module__ not in _overload_registry:
return []
mod_dict = _overload_registry[f.__module__]
if f.__qualname__ not in mod_dict:
return []
return list(mod_dict[f.__qualname__].values())


def clear_overloads():
"""Clear all overloads in the registry."""
_overload_registry.clear()


def final(f):
"""A decorator to indicate final methods and final classes.
Expand Down
@@ -0,0 +1,2 @@
Add :func:`typing.get_overloads` and :func:`typing.clear_overloads`.
Patch by Jelle Zijlstra.

0 comments on commit 055760e

Please sign in to comment.