Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TYP,ENH: Mark numpy.typing protocols as runtime checkable #22388

Merged
merged 1 commit into from Oct 6, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
6 changes: 6 additions & 0 deletions doc/release/upcoming_changes/22357.improvement.rst
@@ -0,0 +1,6 @@
``numpy.typing`` protocols are now runtime checkable
----------------------------------------------------

The protocols used in `~numpy.typing.ArrayLike` and `~numpy.typing.DTypeLike`
are now properly marked as runtime checkable, making them easier to use for
runtime type checkers.
6 changes: 4 additions & 2 deletions numpy/_typing/_array_like.py
Expand Up @@ -3,7 +3,7 @@
# NOTE: Import `Sequence` from `typing` as we it is needed for a type-alias,
# not an annotation
from collections.abc import Collection, Callable
from typing import Any, Sequence, Protocol, Union, TypeVar
from typing import Any, Sequence, Protocol, Union, TypeVar, runtime_checkable
from numpy import (
ndarray,
dtype,
Expand Down Expand Up @@ -33,10 +33,12 @@
# array.
# Concrete implementations of the protocol are responsible for adding
# any and all remaining overloads
@runtime_checkable
class _SupportsArray(Protocol[_DType_co]):
def __array__(self) -> ndarray[Any, _DType_co]: ...


@runtime_checkable
class _SupportsArrayFunc(Protocol):
"""A protocol class representing `~class.__array_function__`."""
def __array_function__(
Expand Down Expand Up @@ -146,7 +148,7 @@ def __array_function__(
# Used as the first overload, should only match NDArray[Any],
# not any actual types.
# https://github.com/numpy/numpy/pull/22193
class _UnknownType:
class _UnknownType:
...


Expand Down
2 changes: 2 additions & 0 deletions numpy/_typing/_dtype_like.py
Expand Up @@ -8,6 +8,7 @@
TypeVar,
Protocol,
TypedDict,
runtime_checkable,
)

import numpy as np
Expand Down Expand Up @@ -80,6 +81,7 @@ class _DTypeDict(_DTypeDictBase, total=False):


# A protocol for anything with the dtype attribute
@runtime_checkable
class _SupportsDType(Protocol[_DType_co]):
@property
def dtype(self) -> _DType_co: ...
Expand Down
2 changes: 2 additions & 0 deletions numpy/_typing/_nested_sequence.py
Expand Up @@ -8,13 +8,15 @@
overload,
TypeVar,
Protocol,
runtime_checkable,
)

__all__ = ["_NestedSequence"]

_T_co = TypeVar("_T_co", covariant=True)


@runtime_checkable
class _NestedSequence(Protocol[_T_co]):
"""A protocol for representing nested sequences.

Expand Down
33 changes: 32 additions & 1 deletion numpy/typing/tests/test_runtime.py
Expand Up @@ -3,11 +3,19 @@
from __future__ import annotations

import sys
from typing import get_type_hints, Union, NamedTuple, get_args, get_origin
from typing import (
get_type_hints,
Union,
NamedTuple,
get_args,
get_origin,
Any,
)

import pytest
import numpy as np
import numpy.typing as npt
import numpy._typing as _npt


class TypeTup(NamedTuple):
Expand Down Expand Up @@ -80,3 +88,26 @@ def test_keys() -> None:
keys = TYPES.keys()
ref = set(npt.__all__)
assert keys == ref


PROTOCOLS: dict[str, tuple[type[Any], object]] = {
"_SupportsDType": (_npt._SupportsDType, np.int64(1)),
"_SupportsArray": (_npt._SupportsArray, np.arange(10)),
"_SupportsArrayFunc": (_npt._SupportsArrayFunc, np.arange(10)),
"_NestedSequence": (_npt._NestedSequence, [1]),
}


@pytest.mark.parametrize("cls,obj", PROTOCOLS.values(), ids=PROTOCOLS.keys())
class TestRuntimeProtocol:
def test_isinstance(self, cls: type[Any], obj: object) -> None:
assert isinstance(obj, cls)
assert not isinstance(None, cls)

def test_issubclass(self, cls: type[Any], obj: object) -> None:
if cls is _npt._SupportsDType:
pytest.xfail(
"Protocols with non-method members don't support issubclass()"
)
assert issubclass(type(obj), cls)
assert not issubclass(type(None), cls)