Skip to content

Commit

Permalink
Merge pull request #19745 from BvB93/astype
Browse files Browse the repository at this point in the history
ENH: Add dtype-support to 3 `generic`/`ndarray` methods
  • Loading branch information
charris committed Aug 24, 2021
2 parents 38be4cb + 8c3d526 commit 75f55a2
Show file tree
Hide file tree
Showing 3 changed files with 143 additions and 47 deletions.
125 changes: 106 additions & 19 deletions numpy/__init__.pyi
Expand Up @@ -1254,20 +1254,9 @@ class _ArrayOrScalarCommon:
def __deepcopy__(self: _ArraySelf, __memo: Optional[dict] = ...) -> _ArraySelf: ...
def __eq__(self, other): ...
def __ne__(self, other): ...
def astype(
self: _ArraySelf,
dtype: DTypeLike,
order: _OrderKACF = ...,
casting: _Casting = ...,
subok: bool = ...,
copy: bool = ...,
) -> _ArraySelf: ...
def copy(self: _ArraySelf, order: _OrderKACF = ...) -> _ArraySelf: ...
def dump(self, file: str) -> None: ...
def dumps(self) -> bytes: ...
def getfield(
self: _ArraySelf, dtype: DTypeLike, offset: int = ...
) -> _ArraySelf: ...
def tobytes(self, order: _OrderKACF = ...) -> bytes: ...
# NOTE: `tostring()` is deprecated and therefore excluded
# def tostring(self, order=...): ...
Expand All @@ -1276,14 +1265,6 @@ class _ArrayOrScalarCommon:
) -> None: ...
# generics and 0d arrays return builtin scalars
def tolist(self) -> Any: ...
@overload
def view(self, type: Type[_NdArraySubClass]) -> _NdArraySubClass: ...
@overload
def view(self: _ArraySelf, dtype: DTypeLike = ...) -> _ArraySelf: ...
@overload
def view(
self, dtype: DTypeLike, type: Type[_NdArraySubClass]
) -> _NdArraySubClass: ...

# TODO: Add proper signatures
def __getitem__(self, key) -> Any: ...
Expand Down Expand Up @@ -1665,6 +1646,12 @@ _T_co = TypeVar("_T_co", covariant=True)
_2Tuple = Tuple[_T, _T]
_Casting = L["no", "equiv", "safe", "same_kind", "unsafe"]

_DTypeLike = Union[
dtype[_ScalarType],
Type[_ScalarType],
_SupportsDType[dtype[_ScalarType]],
]

_ArrayUInt_co = NDArray[Union[bool_, unsignedinteger[Any]]]
_ArrayInt_co = NDArray[Union[bool_, integer[Any]]]
_ArrayFloat_co = NDArray[Union[bool_, integer[Any], floating[Any]]]
Expand Down Expand Up @@ -1914,6 +1901,53 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeType, _DType_co]):
self, *shape: SupportsIndex, order: _OrderACF = ...
) -> ndarray[Any, _DType_co]: ...

@overload
def astype(
self,
dtype: _DTypeLike[_ScalarType],
order: _OrderKACF = ...,
casting: _Casting = ...,
subok: bool = ...,
copy: bool = ...,
) -> NDArray[_ScalarType]: ...
@overload
def astype(
self,
dtype: DTypeLike,
order: _OrderKACF = ...,
casting: _Casting = ...,
subok: bool = ...,
copy: bool = ...,
) -> NDArray[Any]: ...

@overload
def view(self: _ArraySelf) -> _ArraySelf: ...
@overload
def view(self, type: Type[_NdArraySubClass]) -> _NdArraySubClass: ...
@overload
def view(self, dtype: _DTypeLike[_ScalarType]) -> NDArray[_ScalarType]: ...
@overload
def view(self, dtype: DTypeLike) -> NDArray[Any]: ...
@overload
def view(
self,
dtype: DTypeLike,
type: Type[_NdArraySubClass],
) -> _NdArraySubClass: ...

@overload
def getfield(
self,
dtype: _DTypeLike[_ScalarType],
offset: SupportsIndex = ...
) -> NDArray[_ScalarType]: ...
@overload
def getfield(
self,
dtype: DTypeLike,
offset: SupportsIndex = ...
) -> NDArray[Any]: ...

# Dispatch to the underlying `generic` via protocols
def __int__(
self: ndarray[Any, dtype[SupportsInt]], # type: ignore[type-var]
Expand Down Expand Up @@ -2886,6 +2920,59 @@ class generic(_ArrayOrScalarCommon):
def byteswap(self: _ScalarType, inplace: L[False] = ...) -> _ScalarType: ...
@property
def flat(self: _ScalarType) -> flatiter[ndarray[Any, dtype[_ScalarType]]]: ...

@overload
def astype(
self,
dtype: _DTypeLike[_ScalarType],
order: _OrderKACF = ...,
casting: _Casting = ...,
subok: bool = ...,
copy: bool = ...,
) -> _ScalarType: ...
@overload
def astype(
self,
dtype: DTypeLike,
order: _OrderKACF = ...,
casting: _Casting = ...,
subok: bool = ...,
copy: bool = ...,
) -> Any: ...

# NOTE: `view` will perform a 0D->scalar cast,
# thus the array `type` is irrelevant to the output type
@overload
def view(
self: _ScalarType,
type: Type[ndarray[Any, Any]] = ...,
) -> _ScalarType: ...
@overload
def view(
self,
dtype: _DTypeLike[_ScalarType],
type: Type[ndarray[Any, Any]] = ...,
) -> _ScalarType: ...
@overload
def view(
self,
dtype: DTypeLike,
type: Type[ndarray[Any, Any]] = ...,
) -> Any: ...

@overload
def getfield(
self,
dtype: _DTypeLike[_ScalarType],
offset: SupportsIndex = ...
) -> _ScalarType: ...
@overload
def getfield(
self,
dtype: DTypeLike,
offset: SupportsIndex = ...
) -> Any: ...

def item(
self,
__args: Union[L[0], Tuple[()], Tuple[L[0]]] = ...,
Expand Down
53 changes: 25 additions & 28 deletions numpy/typing/tests/data/reveal/ndarray_conversion.py
@@ -1,12 +1,13 @@
import numpy as np
import numpy.typing as npt

nd = np.array([[1, 2], [3, 4]])
nd: npt.NDArray[np.int_] = np.array([[1, 2], [3, 4]])

# item
reveal_type(nd.item()) # E: Any
reveal_type(nd.item(1)) # E: Any
reveal_type(nd.item(0, 1)) # E: Any
reveal_type(nd.item((0, 1))) # E: Any
reveal_type(nd.item()) # E: int
reveal_type(nd.item(1)) # E: int
reveal_type(nd.item(0, 1)) # E: int
reveal_type(nd.item((0, 1))) # E: int

# tolist
reveal_type(nd.tolist()) # E: Any
Expand All @@ -19,36 +20,32 @@
# dumps is pretty simple

# astype
reveal_type(nd.astype("float")) # E: numpy.ndarray
reveal_type(nd.astype(float)) # E: numpy.ndarray
reveal_type(nd.astype(float, "K")) # E: numpy.ndarray
reveal_type(nd.astype(float, "K", "unsafe")) # E: numpy.ndarray
reveal_type(nd.astype(float, "K", "unsafe", True)) # E: numpy.ndarray
reveal_type(nd.astype(float, "K", "unsafe", True, True)) # E: numpy.ndarray
reveal_type(nd.astype("float")) # E: numpy.ndarray[Any, numpy.dtype[Any]]
reveal_type(nd.astype(float)) # E: numpy.ndarray[Any, numpy.dtype[Any]]
reveal_type(nd.astype(np.float64)) # E: numpy.ndarray[Any, numpy.dtype[{float64}]]
reveal_type(nd.astype(np.float64, "K")) # E: numpy.ndarray[Any, numpy.dtype[{float64}]]
reveal_type(nd.astype(np.float64, "K", "unsafe")) # E: numpy.ndarray[Any, numpy.dtype[{float64}]]
reveal_type(nd.astype(np.float64, "K", "unsafe", True)) # E: numpy.ndarray[Any, numpy.dtype[{float64}]]
reveal_type(nd.astype(np.float64, "K", "unsafe", True, True)) # E: numpy.ndarray[Any, numpy.dtype[{float64}]]

# byteswap
reveal_type(nd.byteswap()) # E: numpy.ndarray
reveal_type(nd.byteswap(True)) # E: numpy.ndarray
reveal_type(nd.byteswap()) # E: numpy.ndarray[Any, numpy.dtype[{int_}]]
reveal_type(nd.byteswap(True)) # E: numpy.ndarray[Any, numpy.dtype[{int_}]]

# copy
reveal_type(nd.copy()) # E: numpy.ndarray
reveal_type(nd.copy("C")) # E: numpy.ndarray
reveal_type(nd.copy()) # E: numpy.ndarray[Any, numpy.dtype[{int_}]]
reveal_type(nd.copy("C")) # E: numpy.ndarray[Any, numpy.dtype[{int_}]]

# view
class SubArray(np.ndarray):
pass


reveal_type(nd.view()) # E: numpy.ndarray
reveal_type(nd.view(np.int64)) # E: numpy.ndarray
# replace `Any` with `numpy.matrix` when `matrix` will be added to stubs
reveal_type(nd.view(np.int64, np.matrix)) # E: Any
reveal_type(nd.view(np.int64, SubArray)) # E: SubArray
reveal_type(nd.view()) # E: numpy.ndarray[Any, numpy.dtype[{int_}]]
reveal_type(nd.view(np.float64)) # E: numpy.ndarray[Any, numpy.dtype[{float64}]]
reveal_type(nd.view(float)) # E: numpy.ndarray[Any, numpy.dtype[Any]]
reveal_type(nd.view(np.float64, np.matrix)) # E: numpy.matrix[Any, Any]

# getfield
reveal_type(nd.getfield("float")) # E: numpy.ndarray
reveal_type(nd.getfield(float)) # E: numpy.ndarray
reveal_type(nd.getfield(float, 8)) # E: numpy.ndarray
reveal_type(nd.getfield("float")) # E: numpy.ndarray[Any, numpy.dtype[Any]]
reveal_type(nd.getfield(float)) # E: numpy.ndarray[Any, numpy.dtype[Any]]
reveal_type(nd.getfield(np.float64)) # E: numpy.ndarray[Any, numpy.dtype[{float64}]]
reveal_type(nd.getfield(np.float64, 8)) # E: numpy.ndarray[Any, numpy.dtype[{float64}]]

# setflags does not return a value
# fill does not return a value
12 changes: 12 additions & 0 deletions numpy/typing/tests/data/reveal/scalars.py
Expand Up @@ -144,3 +144,15 @@
if sys.version_info >= (3, 9):
reveal_type(f8.__ceil__()) # E: int
reveal_type(f8.__floor__()) # E: int

reveal_type(i8.astype(float)) # E: Any
reveal_type(i8.astype(np.float64)) # E: {float64}

reveal_type(i8.view()) # E: {int64}
reveal_type(i8.view(np.float64)) # E: {float64}
reveal_type(i8.view(float)) # E: Any
reveal_type(i8.view(np.float64, np.ndarray)) # E: {float64}

reveal_type(i8.getfield(float)) # E: Any
reveal_type(i8.getfield(np.float64)) # E: {float64}
reveal_type(i8.getfield(np.float64, 8)) # E: {float64}

0 comments on commit 75f55a2

Please sign in to comment.