diff --git a/numpy/typing/_array_like.py b/numpy/typing/_array_like.py index 9f57b22956cc..cfd9aacb4927 100644 --- a/numpy/typing/_array_like.py +++ b/numpy/typing/_array_like.py @@ -77,7 +77,7 @@ def __array__(self) -> ndarray[Any, _DType_co]: ... ArrayLike = Union[ _RecursiveSequence, _ArrayLike[ - "dtype[Any]", + dtype, Union[bool, int, float, complex, str, bytes] ], ] diff --git a/numpy/typing/_dtype_like.py b/numpy/typing/_dtype_like.py index a41e2f358d97..636e2209b45f 100644 --- a/numpy/typing/_dtype_like.py +++ b/numpy/typing/_dtype_like.py @@ -1,5 +1,15 @@ import sys -from typing import Any, List, Sequence, Tuple, Union, Type, TypeVar, TYPE_CHECKING +from typing import ( + Any, + List, + Sequence, + Tuple, + Union, + Type, + TypeVar, + Generic, + TYPE_CHECKING, +) import numpy as np from ._shape import _ShapeLike @@ -81,7 +91,9 @@ def dtype(self) -> _DType_co: ... else: _DTypeDict = Any - _SupportsDType = Any + + class _SupportsDType(Generic[_DType_co]): + pass # Would create a dtype[np.void] @@ -112,7 +124,7 @@ def dtype(self) -> _DType_co: ... # array-scalar types and generic types type, # TODO: enumerate these when we add type hints for numpy scalars # anything with a dtype attribute - "_SupportsDType[np.dtype[Any]]", + _SupportsDType[np.dtype], # character codes, type strings or comma-separated fields, e.g., 'float64' str, _VoidDTypeLike, diff --git a/numpy/typing/tests/test_runtime.py b/numpy/typing/tests/test_runtime.py new file mode 100644 index 000000000000..e82b08ac26a0 --- /dev/null +++ b/numpy/typing/tests/test_runtime.py @@ -0,0 +1,90 @@ +"""Test the runtime usage of `numpy.typing`.""" + +from __future__ import annotations + +import sys +from typing import get_type_hints, Union, Tuple, NamedTuple + +import pytest +import numpy as np +import numpy.typing as npt + +try: + from typing_extensions import get_args, get_origin + SKIP = False +except ImportError: + SKIP = True + + +class TypeTup(NamedTuple): + typ: type + args: Tuple[type, ...] + origin: None | type + + +if sys.version_info >= (3, 9): + NDArrayTup = TypeTup(npt.NDArray, npt.NDArray.__args__, np.ndarray) +else: + NDArrayTup = TypeTup(npt.NDArray, (), None) + +TYPES = { + "ArrayLike": TypeTup(npt.ArrayLike, npt.ArrayLike.__args__, Union), + "DTypeLike": TypeTup(npt.DTypeLike, npt.DTypeLike.__args__, Union), + "NBitBase": TypeTup(npt.NBitBase, (), None), + "NDArray": NDArrayTup, +} + + +@pytest.mark.parametrize("name,tup", TYPES.items(), ids=TYPES.keys()) +@pytest.mark.skipif(SKIP, reason="requires typing-extensions") +def test_get_args(name: type, tup: TypeTup) -> None: + """Test `typing.get_args`.""" + typ, ref = tup.typ, tup.args + out = get_args(typ) + assert out == ref + + +@pytest.mark.parametrize("name,tup", TYPES.items(), ids=TYPES.keys()) +@pytest.mark.skipif(SKIP, reason="requires typing-extensions") +def test_get_origin(name: type, tup: TypeTup) -> None: + """Test `typing.get_origin`.""" + typ, ref = tup.typ, tup.origin + out = get_origin(typ) + assert out == ref + + +@pytest.mark.parametrize("name,tup", TYPES.items(), ids=TYPES.keys()) +def test_get_type_hints(name: type, tup: TypeTup) -> None: + """Test `typing.get_type_hints`.""" + typ = tup.typ + + # Explicitly set `__annotations__` in order to circumvent the + # stringification performed by `from __future__ import annotations` + def func(a): pass + func.__annotations__ = {"a": typ, "return": None} + + out = get_type_hints(func) + ref = {"a": typ, "return": type(None)} + assert out == ref + + +@pytest.mark.parametrize("name,tup", TYPES.items(), ids=TYPES.keys()) +def test_get_type_hints_str(name: type, tup: TypeTup) -> None: + """Test `typing.get_type_hints` with string-representation of types.""" + typ_str, typ = f"npt.{name}", tup.typ + + # Explicitly set `__annotations__` in order to circumvent the + # stringification performed by `from __future__ import annotations` + def func(a): pass + func.__annotations__ = {"a": typ_str, "return": None} + + out = get_type_hints(func) + ref = {"a": typ, "return": type(None)} + assert out == ref + + +def test_keys() -> None: + """Test that ``TYPES.keys()`` and ``numpy.typing.__all__`` are synced.""" + keys = TYPES.keys() + ref = set(npt.__all__) + assert keys == ref