Skip to content

Commit

Permalink
BUG: Introduce Unknown array type to deal with NDArray[Any] more grac…
Browse files Browse the repository at this point in the history
…efully.
  • Loading branch information
iantra authored and charris committed Sep 7, 2022
1 parent 5a5b650 commit 336f3a4
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 14 deletions.
1 change: 1 addition & 0 deletions numpy/_typing/__init__.py
Expand Up @@ -198,6 +198,7 @@ class _8Bit(_16Bit): # type: ignore[misc]
_ArrayLikeVoid_co as _ArrayLikeVoid_co,
_ArrayLikeStr_co as _ArrayLikeStr_co,
_ArrayLikeBytes_co as _ArrayLikeBytes_co,
_ArrayLikeUnknown as _ArrayLikeUnknown,
)
from ._generic_alias import (
NDArray as NDArray,
Expand Down
13 changes: 13 additions & 0 deletions numpy/_typing/_array_like.py
Expand Up @@ -141,3 +141,16 @@ def __array_function__(
"dtype[integer[Any]]",
int,
]

# Extra ArrayLike type so that pyright can deal with NDArray[Any]
# Used as the first overload, should only match NDArray[Any],
# not any actual types.
# https://github.com/numpy/numpy/pull/22193
class _UnknownType:
...


_ArrayLikeUnknown = _DualArrayLike[
"dtype[_UnknownType]",
_UnknownType,
]
62 changes: 48 additions & 14 deletions numpy/core/numeric.pyi
Expand Up @@ -43,6 +43,7 @@ from numpy._typing import (
_ArrayLikeComplex_co,
_ArrayLikeTD64_co,
_ArrayLikeObject_co,
_ArrayLikeUnknown,
)

_T = TypeVar("_T")
Expand Down Expand Up @@ -257,10 +258,10 @@ def flatnonzero(a: ArrayLike) -> NDArray[intp]: ...

@overload
def correlate(
a: _ArrayLikeObject_co,
v: _ArrayLikeObject_co,
a: _ArrayLikeUnknown,
v: _ArrayLikeUnknown,
mode: _CorrelateMode = ...,
) -> NDArray[object_]: ...
) -> NDArray[Any]: ...
@overload
def correlate(
a: _ArrayLikeBool_co,
Expand Down Expand Up @@ -297,13 +298,19 @@ def correlate(
v: _ArrayLikeTD64_co,
mode: _CorrelateMode = ...,
) -> NDArray[timedelta64]: ...

@overload
def convolve(
def correlate(
a: _ArrayLikeObject_co,
v: _ArrayLikeObject_co,
mode: _CorrelateMode = ...,
) -> NDArray[object_]: ...

@overload
def convolve(
a: _ArrayLikeUnknown,
v: _ArrayLikeUnknown,
mode: _CorrelateMode = ...,
) -> NDArray[Any]: ...
@overload
def convolve(
a: _ArrayLikeBool_co,
Expand Down Expand Up @@ -340,13 +347,19 @@ def convolve(
v: _ArrayLikeTD64_co,
mode: _CorrelateMode = ...,
) -> NDArray[timedelta64]: ...
@overload
def convolve(
a: _ArrayLikeObject_co,
v: _ArrayLikeObject_co,
mode: _CorrelateMode = ...,
) -> NDArray[object_]: ...

@overload
def outer(
a: _ArrayLikeObject_co,
b: _ArrayLikeObject_co,
a: _ArrayLikeUnknown,
b: _ArrayLikeUnknown,
out: None = ...,
) -> NDArray[object_]: ...
) -> NDArray[Any]: ...
@overload
def outer(
a: _ArrayLikeBool_co,
Expand Down Expand Up @@ -384,6 +397,12 @@ def outer(
out: None = ...,
) -> NDArray[timedelta64]: ...
@overload
def outer(
a: _ArrayLikeObject_co,
b: _ArrayLikeObject_co,
out: None = ...,
) -> NDArray[object_]: ...
@overload
def outer(
a: _ArrayLikeComplex_co | _ArrayLikeTD64_co | _ArrayLikeObject_co,
b: _ArrayLikeComplex_co | _ArrayLikeTD64_co | _ArrayLikeObject_co,
Expand All @@ -392,10 +411,10 @@ def outer(

@overload
def tensordot(
a: _ArrayLikeObject_co,
b: _ArrayLikeObject_co,
a: _ArrayLikeUnknown,
b: _ArrayLikeUnknown,
axes: int | tuple[_ShapeLike, _ShapeLike] = ...,
) -> NDArray[object_]: ...
) -> NDArray[Any]: ...
@overload
def tensordot(
a: _ArrayLikeBool_co,
Expand Down Expand Up @@ -432,6 +451,12 @@ def tensordot(
b: _ArrayLikeTD64_co,
axes: int | tuple[_ShapeLike, _ShapeLike] = ...,
) -> NDArray[timedelta64]: ...
@overload
def tensordot(
a: _ArrayLikeObject_co,
b: _ArrayLikeObject_co,
axes: int | tuple[_ShapeLike, _ShapeLike] = ...,
) -> NDArray[object_]: ...

@overload
def roll(
Expand Down Expand Up @@ -460,13 +485,13 @@ def moveaxis(

@overload
def cross(
a: _ArrayLikeObject_co,
b: _ArrayLikeObject_co,
a: _ArrayLikeUnknown,
b: _ArrayLikeUnknown,
axisa: int = ...,
axisb: int = ...,
axisc: int = ...,
axis: None | int = ...,
) -> NDArray[object_]: ...
) -> NDArray[Any]: ...
@overload
def cross(
a: _ArrayLikeBool_co,
Expand Down Expand Up @@ -512,6 +537,15 @@ def cross(
axisc: int = ...,
axis: None | int = ...,
) -> NDArray[complexfloating[Any, Any]]: ...
@overload
def cross(
a: _ArrayLikeObject_co,
b: _ArrayLikeObject_co,
axisa: int = ...,
axisb: int = ...,
axisc: int = ...,
axis: None | int = ...,
) -> NDArray[object_]: ...

@overload
def indices(
Expand Down

0 comments on commit 336f3a4

Please sign in to comment.