Skip to content

Commit

Permalink
Backport performance improvements to runtime-checkable protocols (#137)
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexWaygood committed Apr 12, 2023
1 parent 4dfc5c5 commit 6c93956
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 21 deletions.
12 changes: 12 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,18 @@
(originally by Yurii Karabas), ensuring that `isinstance()` calls on
protocols raise `TypeError` when the protocol is not decorated with
`@runtime_checkable`. Patch by Alex Waygood.
- Backport several significant performance improvements to runtime-checkable
protocols that have been made in Python 3.12 (see
https://github.com/python/cpython/issues/74690 for details). Patch by Alex
Waygood.

A side effect of one of the performance improvements is that the members of
a runtime-checkable protocol are now considered “frozen” at runtime as soon
as the class has been created. Monkey-patching attributes onto a
runtime-checkable protocol will still work, but will have no impact on
`isinstance()` checks comparing objects to the protocol. See
["What's New in Python 3.12"](https://docs.python.org/3.12/whatsnew/3.12.html#typing)
for more details.

# Release 4.5.0 (February 14, 2023)

Expand Down
4 changes: 3 additions & 1 deletion src/test_typing_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3452,9 +3452,11 @@ def test_typing_extensions_defers_when_possible(self):
'is_typeddict',
}
if sys.version_info < (3, 10):
exclude |= {'get_args', 'get_origin', 'Protocol', 'runtime_checkable'}
exclude |= {'get_args', 'get_origin'}
if sys.version_info < (3, 11):
exclude |= {'final', 'NamedTuple', 'Any'}
if sys.version_info < (3, 12):
exclude |= {'Protocol', 'runtime_checkable'}
for item in typing_extensions.__all__:
if item not in exclude and hasattr(typing, item):
self.assertIs(
Expand Down
55 changes: 35 additions & 20 deletions src/typing_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,7 @@ def clear_overloads():
"_is_runtime_protocol", "__dict__", "__slots__", "__parameters__",
"__orig_bases__", "__module__", "_MutableMapping__marker", "__doc__",
"__subclasshook__", "__orig_class__", "__init__", "__new__",
"__protocol_attrs__", "__callable_proto_members_only__",
}

if sys.version_info < (3, 8):
Expand All @@ -420,19 +421,15 @@ def clear_overloads():
def _get_protocol_attrs(cls):
attrs = set()
for base in cls.__mro__[:-1]: # without object
if base.__name__ in ('Protocol', 'Generic'):
if base.__name__ in {'Protocol', 'Generic'}:
continue
annotations = getattr(base, '__annotations__', {})
for attr in list(base.__dict__.keys()) + list(annotations.keys()):
for attr in (*base.__dict__, *annotations):
if (not attr.startswith('_abc_') and attr not in _EXCLUDED_ATTRS):
attrs.add(attr)
return attrs


def _is_callable_members_only(cls):
return all(callable(getattr(cls, attr, None)) for attr in _get_protocol_attrs(cls))


def _maybe_adjust_parameters(cls):
"""Helper function used in Protocol.__init_subclass__ and _TypedDictMeta.__new__.
Expand All @@ -442,7 +439,7 @@ def _maybe_adjust_parameters(cls):
"""
tvars = []
if '__orig_bases__' in cls.__dict__:
tvars = typing._collect_type_vars(cls.__orig_bases__)
tvars = _collect_type_vars(cls.__orig_bases__)
# Look for Generic[T1, ..., Tn] or Protocol[T1, ..., Tn].
# If found, tvars must be a subset of it.
# If not found, tvars is it.
Expand Down Expand Up @@ -480,9 +477,9 @@ def _caller(depth=2):
return None


# A bug in runtime-checkable protocols was fixed in 3.10+,
# but we backport it to all versions
if sys.version_info >= (3, 10):
# The performance of runtime-checkable protocols is significantly improved on Python 3.12,
# so we backport the 3.12 version of Protocol to Python <=3.11
if sys.version_info >= (3, 12):
Protocol = typing.Protocol
runtime_checkable = typing.runtime_checkable
else:
Expand All @@ -500,6 +497,15 @@ def _no_init(self, *args, **kwargs):
class _ProtocolMeta(abc.ABCMeta):
# This metaclass is a bit unfortunate and exists only because of the lack
# of __instancehook__.
def __init__(cls, *args, **kwargs):
super().__init__(*args, **kwargs)
cls.__protocol_attrs__ = _get_protocol_attrs(cls)
# PEP 544 prohibits using issubclass()
# with protocols that have non-method members.
cls.__callable_proto_members_only__ = all(
callable(getattr(cls, attr, None)) for attr in cls.__protocol_attrs__
)

def __instancecheck__(cls, instance):
# We need this method for situations where attributes are
# assigned in __init__.
Expand All @@ -511,17 +517,22 @@ def __instancecheck__(cls, instance):
):
raise TypeError("Instance and class checks can only be used with"
" @runtime_checkable protocols")
if ((not is_protocol_cls or
_is_callable_members_only(cls)) and
issubclass(instance.__class__, cls)):

if super().__instancecheck__(instance):
return True

if is_protocol_cls:
if all(hasattr(instance, attr) and
(not callable(getattr(cls, attr, None)) or
getattr(instance, attr) is not None)
for attr in _get_protocol_attrs(cls)):
for attr in cls.__protocol_attrs__:
try:
val = getattr(instance, attr)
except AttributeError:
break
if val is None and callable(getattr(cls, attr, None)):
break
else:
return True
return super().__instancecheck__(instance)

return False

class Protocol(metaclass=_ProtocolMeta):
# There is quite a lot of overlapping code with typing.Generic.
Expand Down Expand Up @@ -613,15 +624,15 @@ def _proto_hook(other):
return NotImplemented
raise TypeError("Instance and class checks can only be used with"
" @runtime protocols")
if not _is_callable_members_only(cls):
if not cls.__callable_proto_members_only__:
if _allow_reckless_class_checks():
return NotImplemented
raise TypeError("Protocols with non-method members"
" don't support issubclass()")
if not isinstance(other, type):
# Same error as for issubclass(1, int)
raise TypeError('issubclass() arg 1 must be a class')
for attr in _get_protocol_attrs(cls):
for attr in cls.__protocol_attrs__:
for base in other.__mro__:
if attr in base.__dict__:
if base.__dict__[attr] is None:
Expand Down Expand Up @@ -1819,6 +1830,10 @@ class Movie(TypedDict):

if hasattr(typing, "Unpack"): # 3.11+
Unpack = typing.Unpack

def _is_unpack(obj):
return get_origin(obj) is Unpack

elif sys.version_info[:2] >= (3, 9):
class _UnpackSpecialForm(typing._SpecialForm, _root=True):
def __repr__(self):
Expand Down

0 comments on commit 6c93956

Please sign in to comment.