Skip to content

Commit

Permalink
TST: Remove TypeVarTuple from tests
Browse files Browse the repository at this point in the history
They cannot be packed into _ShapeType because they are not bound
to ints.

Also make docs more explicit about the correct way to use npt.Array and
when to use npt.NDArray
  • Loading branch information
Jacob-Stevens-Haas committed May 6, 2024
1 parent b0d838f commit 8f8a973
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 20 deletions.
29 changes: 25 additions & 4 deletions doc/release/upcoming_changes/26081.improvement.rst
Original file line number Diff line number Diff line change
Expand Up @@ -33,20 +33,36 @@ shape by a tuple of ``int`` subtypes, which are reflected in the type of
reveal_type(transpose(arr).shape) # tuple[Series, Time]
This PR also provides the ``np.typing.Array`` type alias in the style of
``NDArray``. It allows specifying a shaped array with simpler dtype syntax.
By relying on variadic type parameters, it also allows typing functions that
handle an unknown number of axes:
``NDArray``. It allows specifying a shaped array with simpler dtype syntax:

.. code::
from typing import Literal, TypeVar, TypeVarTuple
import numpy as np
import numpy.typing as npt
Float = np.floating[npt.NBitBase]
DType = TypeVar("DType", bound=np.generic)
Float1D = npt.Array[int, Float]
Float2D = npt.Array[int, int, Float]
def stack_1Ds(*args: Float1D) -> Float2D:
return np.stack(args)
arr = np.arange(12, dtype=np.uint16)
stacked = stack_1Ds(arr, arr, arr)
reveal_type(stacked) #
Note that, if shape is unknown or arbitrary, it is still recommended to use
``NDArray``. Because ``TypeVarTuple`` cannot be bound, not all type checkers
allow arbitrarily shape arrays. e.g.:

.. code::
from typing import TypeVarTuple
Shapes = TypeVarTuple("Shapes")
def stack(
def stack_two_arrays(
a: npt.Array[*Shapes, DType],
b: npt.Array[*Shapes, DType],
) -> npt.Array[Literal[2], *Shapes, DType]: ...
Expand All @@ -56,3 +72,8 @@ handle an unknown number of axes:
double_arr: npt.Array[Literal[2], Literal[3], Literal[4], np.uint16]
double_arr = stack(arr, arr)
In mypy, this generates the error:
``Type argument "tuple[*Shapes]" of "ndarray" must be a subtype of
"tuple[int, ...]" [type-var]``
3 changes: 2 additions & 1 deletion numpy/_typing/_add_docstring.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,8 @@ def _parse_docstrings() -> str:
Can be used during runtime for typing arrays with a given dtype
and specified shape. Particularly useful for functions which modify number
of axes. If it is important to enforce integer axis sizes, use np.ndarray
for typing.
for typing. If axes dimensions are variable, it is best to use
`npt.NDArray <numpy.typing.NDArray>`.
.. versionadded:: 2.1
Expand Down
39 changes: 24 additions & 15 deletions numpy/typing/tests/data/pass/shape311.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,25 +3,34 @@

import numpy as np
import numpy.typing as npt
from typing_extensions import assert_type
from typing_extensions import assert_type, cast, reveal_type

if sys.version_info >= (3, 11):
DType = TypeVar("DType", bound=np.generic)
Shapes = TypeVarTuple("Shapes")

def stack(
a: npt.Array[*Shapes, DType],
b: npt.Array[*Shapes, DType],
) -> npt.Array[Literal[2], *Shapes, DType]:
return np.stack((a, b))
# Check that typevartuple in alias is packed correctly
Length = NewType("Length", int)
Width = NewType("Width", int)
arr: npt.Array[Length, Width, np.int8] = np.array([[0]])
assert_type(arr, np.ndarray[tuple[Length, Width], np.dtype[np.int8]])

arr: npt.Array[Literal[3], Literal[4], np.uint16]
arr = np.arange(12, dtype=np.uint16).reshape((3, 4))
# Check that typevartuple in alias is unpacked correctly
M = TypeVar("M", bound=int)
N = TypeVar("N", bound=int)
T = TypeVar("T", bound=np.generic)

double_arr = stack(arr, arr)
assert_type(double_arr, npt.Array[Literal[2], Literal[3], Literal[4], np.uint16])

Length = NewType("Length", int)
Width = NewType("Width", int)
arr2: npt.Array[Length, Width, np.int8] = np.array([[0]])
assert_type(arr2, npt.Array[Length, Width, np.int8])
def mult(vec: npt.Array[N, T], mat: npt.Array[M, N, T]) -> npt.Array[M, T]:
return mat @ vec # type: ignore


arr2: np.ndarray[tuple[Width], np.dtype[np.int8]] = np.array([0])
assert_type(mult(arr2, arr), np.ndarray[tuple[Length], np.dtype[np.int8]])


# Check that shape works
def return_shp(a: npt.Array[M, N, DType]) -> tuple[M, N]:
return a.shape

shp = return_shp(arr)
assert_type(shp, tuple[Length, Width])

0 comments on commit 8f8a973

Please sign in to comment.