diff --git a/doc/release/upcoming_changes/22357.improvement.rst b/doc/release/upcoming_changes/22357.improvement.rst new file mode 100644 index 000000000000..a0332eaa07ba --- /dev/null +++ b/doc/release/upcoming_changes/22357.improvement.rst @@ -0,0 +1,6 @@ +``numpy.typing`` protocols are now runtime checkable +---------------------------------------------------- + +The protocols used in `~numpy.typing.ArrayLike` and `~numpy.typing.DTypeLike` +are now properly marked as runtime checkable, making them easier to use for +runtime type checkers. diff --git a/numpy/_typing/_array_like.py b/numpy/_typing/_array_like.py index 2e5684b0bfb6..67d67ce19c31 100644 --- a/numpy/_typing/_array_like.py +++ b/numpy/_typing/_array_like.py @@ -3,7 +3,7 @@ # NOTE: Import `Sequence` from `typing` as we it is needed for a type-alias, # not an annotation from collections.abc import Collection, Callable -from typing import Any, Sequence, Protocol, Union, TypeVar +from typing import Any, Sequence, Protocol, Union, TypeVar, runtime_checkable from numpy import ( ndarray, dtype, @@ -33,10 +33,12 @@ # array. # Concrete implementations of the protocol are responsible for adding # any and all remaining overloads +@runtime_checkable class _SupportsArray(Protocol[_DType_co]): def __array__(self) -> ndarray[Any, _DType_co]: ... +@runtime_checkable class _SupportsArrayFunc(Protocol): """A protocol class representing `~class.__array_function__`.""" def __array_function__( @@ -146,7 +148,7 @@ def __array_function__( # Used as the first overload, should only match NDArray[Any], # not any actual types. # https://github.com/numpy/numpy/pull/22193 -class _UnknownType: +class _UnknownType: ... diff --git a/numpy/_typing/_dtype_like.py b/numpy/_typing/_dtype_like.py index b705d82fdb07..e92e17dd20cf 100644 --- a/numpy/_typing/_dtype_like.py +++ b/numpy/_typing/_dtype_like.py @@ -8,6 +8,7 @@ TypeVar, Protocol, TypedDict, + runtime_checkable, ) import numpy as np @@ -80,6 +81,7 @@ class _DTypeDict(_DTypeDictBase, total=False): # A protocol for anything with the dtype attribute +@runtime_checkable class _SupportsDType(Protocol[_DType_co]): @property def dtype(self) -> _DType_co: ... diff --git a/numpy/_typing/_nested_sequence.py b/numpy/_typing/_nested_sequence.py index 7c12c4a873db..6e9dded44530 100644 --- a/numpy/_typing/_nested_sequence.py +++ b/numpy/_typing/_nested_sequence.py @@ -8,6 +8,7 @@ overload, TypeVar, Protocol, + runtime_checkable, ) __all__ = ["_NestedSequence"] @@ -15,6 +16,7 @@ _T_co = TypeVar("_T_co", covariant=True) +@runtime_checkable class _NestedSequence(Protocol[_T_co]): """A protocol for representing nested sequences. diff --git a/numpy/typing/tests/test_runtime.py b/numpy/typing/tests/test_runtime.py index 5b5df49dc571..44d069006320 100644 --- a/numpy/typing/tests/test_runtime.py +++ b/numpy/typing/tests/test_runtime.py @@ -3,11 +3,19 @@ from __future__ import annotations import sys -from typing import get_type_hints, Union, NamedTuple, get_args, get_origin +from typing import ( + get_type_hints, + Union, + NamedTuple, + get_args, + get_origin, + Any, +) import pytest import numpy as np import numpy.typing as npt +import numpy._typing as _npt class TypeTup(NamedTuple): @@ -80,3 +88,26 @@ def test_keys() -> None: keys = TYPES.keys() ref = set(npt.__all__) assert keys == ref + + +PROTOCOLS: dict[str, tuple[type[Any], object]] = { + "_SupportsDType": (_npt._SupportsDType, np.int64(1)), + "_SupportsArray": (_npt._SupportsArray, np.arange(10)), + "_SupportsArrayFunc": (_npt._SupportsArrayFunc, np.arange(10)), + "_NestedSequence": (_npt._NestedSequence, [1]), +} + + +@pytest.mark.parametrize("cls,obj", PROTOCOLS.values(), ids=PROTOCOLS.keys()) +class TestRuntimeProtocol: + def test_isinstance(self, cls: type[Any], obj: object) -> None: + assert isinstance(obj, cls) + assert not isinstance(None, cls) + + def test_issubclass(self, cls: type[Any], obj: object) -> None: + if cls is _npt._SupportsDType: + pytest.xfail( + "Protocols with non-method members don't support issubclass()" + ) + assert issubclass(type(obj), cls) + assert not issubclass(type(None), cls)