Skip to content

Commit

Permalink
Merge pull request #22220 from charris/backport-22193
Browse files Browse the repository at this point in the history
BUG: change overloads to play nice with pyright.
  • Loading branch information
charris committed Sep 7, 2022
2 parents e18dd98 + 336f3a4 commit bf46295
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 0 deletions.
1 change: 1 addition & 0 deletions numpy/_typing/__init__.py
Expand Up @@ -198,6 +198,7 @@ class _8Bit(_16Bit): # type: ignore[misc]
_ArrayLikeVoid_co as _ArrayLikeVoid_co,
_ArrayLikeStr_co as _ArrayLikeStr_co,
_ArrayLikeBytes_co as _ArrayLikeBytes_co,
_ArrayLikeUnknown as _ArrayLikeUnknown,
)
from ._generic_alias import (
NDArray as NDArray,
Expand Down
13 changes: 13 additions & 0 deletions numpy/_typing/_array_like.py
Expand Up @@ -141,3 +141,16 @@ def __array_function__(
"dtype[integer[Any]]",
int,
]

# Extra ArrayLike type so that pyright can deal with NDArray[Any]
# Used as the first overload, should only match NDArray[Any],
# not any actual types.
# https://github.com/numpy/numpy/pull/22193
class _UnknownType:
...


_ArrayLikeUnknown = _DualArrayLike[
"dtype[_UnknownType]",
_UnknownType,
]
34 changes: 34 additions & 0 deletions numpy/core/numeric.pyi
Expand Up @@ -43,6 +43,7 @@ from numpy._typing import (
_ArrayLikeComplex_co,
_ArrayLikeTD64_co,
_ArrayLikeObject_co,
_ArrayLikeUnknown,
)

_T = TypeVar("_T")
Expand Down Expand Up @@ -255,6 +256,12 @@ def argwhere(a: ArrayLike) -> NDArray[intp]: ...

def flatnonzero(a: ArrayLike) -> NDArray[intp]: ...

@overload
def correlate(
a: _ArrayLikeUnknown,
v: _ArrayLikeUnknown,
mode: _CorrelateMode = ...,
) -> NDArray[Any]: ...
@overload
def correlate(
a: _ArrayLikeBool_co,
Expand Down Expand Up @@ -298,6 +305,12 @@ def correlate(
mode: _CorrelateMode = ...,
) -> NDArray[object_]: ...

@overload
def convolve(
a: _ArrayLikeUnknown,
v: _ArrayLikeUnknown,
mode: _CorrelateMode = ...,
) -> NDArray[Any]: ...
@overload
def convolve(
a: _ArrayLikeBool_co,
Expand Down Expand Up @@ -341,6 +354,12 @@ def convolve(
mode: _CorrelateMode = ...,
) -> NDArray[object_]: ...

@overload
def outer(
a: _ArrayLikeUnknown,
b: _ArrayLikeUnknown,
out: None = ...,
) -> NDArray[Any]: ...
@overload
def outer(
a: _ArrayLikeBool_co,
Expand Down Expand Up @@ -390,6 +409,12 @@ def outer(
out: _ArrayType,
) -> _ArrayType: ...

@overload
def tensordot(
a: _ArrayLikeUnknown,
b: _ArrayLikeUnknown,
axes: int | tuple[_ShapeLike, _ShapeLike] = ...,
) -> NDArray[Any]: ...
@overload
def tensordot(
a: _ArrayLikeBool_co,
Expand Down Expand Up @@ -458,6 +483,15 @@ def moveaxis(
destination: _ShapeLike,
) -> NDArray[_SCT]: ...

@overload
def cross(
a: _ArrayLikeUnknown,
b: _ArrayLikeUnknown,
axisa: int = ...,
axisb: int = ...,
axisc: int = ...,
axis: None | int = ...,
) -> NDArray[Any]: ...
@overload
def cross(
a: _ArrayLikeBool_co,
Expand Down

0 comments on commit bf46295

Please sign in to comment.