Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MAINT: Add missing dtype overloads for object- and ctypes-based inputs #19503

Merged
merged 4 commits into from Jul 17, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
78 changes: 46 additions & 32 deletions numpy/__init__.pyi
@@ -1,6 +1,9 @@
import builtins
import os
import sys
import mmap
import ctypes as ct
import array as _array
import datetime as dt
from abc import abstractmethod
from types import TracebackType
Expand Down Expand Up @@ -920,7 +923,7 @@ class dtype(Generic[_DTypeScalar_co]):
# other special cases. Order is sometimes important because of the
# subtype relationships
#
# bool < int < float < complex
# bool < int < float < complex < object
#
# so we have to make sure the overloads for the narrowest type is
# first.
Expand All @@ -938,51 +941,54 @@ class dtype(Generic[_DTypeScalar_co]):
@overload
def __new__(cls, dtype: Type[bytes], align: bool = ..., copy: bool = ...) -> dtype[bytes_]: ...

# `unsignedinteger` string-based representations
# `unsignedinteger` string-based representations and ctypes
@overload
def __new__(cls, dtype: _UInt8Codes, align: bool = ..., copy: bool = ...) -> dtype[uint8]: ...
def __new__(cls, dtype: _UInt8Codes | Type[ct.c_uint8], align: bool = ..., copy: bool = ...) -> dtype[uint8]: ...
@overload
def __new__(cls, dtype: _UInt16Codes, align: bool = ..., copy: bool = ...) -> dtype[uint16]: ...
def __new__(cls, dtype: _UInt16Codes | Type[ct.c_uint16], align: bool = ..., copy: bool = ...) -> dtype[uint16]: ...
@overload
def __new__(cls, dtype: _UInt32Codes, align: bool = ..., copy: bool = ...) -> dtype[uint32]: ...
def __new__(cls, dtype: _UInt32Codes | Type[ct.c_uint32], align: bool = ..., copy: bool = ...) -> dtype[uint32]: ...
@overload
def __new__(cls, dtype: _UInt64Codes, align: bool = ..., copy: bool = ...) -> dtype[uint64]: ...
def __new__(cls, dtype: _UInt64Codes | Type[ct.c_uint64], align: bool = ..., copy: bool = ...) -> dtype[uint64]: ...
@overload
def __new__(cls, dtype: _UByteCodes, align: bool = ..., copy: bool = ...) -> dtype[ubyte]: ...
def __new__(cls, dtype: _UByteCodes | Type[ct.c_ubyte], align: bool = ..., copy: bool = ...) -> dtype[ubyte]: ...
@overload
def __new__(cls, dtype: _UShortCodes, align: bool = ..., copy: bool = ...) -> dtype[ushort]: ...
def __new__(cls, dtype: _UShortCodes | Type[ct.c_ushort], align: bool = ..., copy: bool = ...) -> dtype[ushort]: ...
@overload
def __new__(cls, dtype: _UIntCCodes, align: bool = ..., copy: bool = ...) -> dtype[uintc]: ...
def __new__(cls, dtype: _UIntCCodes | Type[ct.c_uint], align: bool = ..., copy: bool = ...) -> dtype[uintc]: ...

# NOTE: We're assuming here that `uint_ptr_t == size_t`,
# an assumption that does not hold in rare cases (same for `ssize_t`)
@overload
def __new__(cls, dtype: _UIntPCodes, align: bool = ..., copy: bool = ...) -> dtype[uintp]: ...
def __new__(cls, dtype: _UIntPCodes | Type[ct.c_void_p] | Type[ct.c_size_t], align: bool = ..., copy: bool = ...) -> dtype[uintp]: ...
@overload
def __new__(cls, dtype: _UIntCodes, align: bool = ..., copy: bool = ...) -> dtype[uint]: ...
def __new__(cls, dtype: _UIntCodes | Type[ct.c_ulong], align: bool = ..., copy: bool = ...) -> dtype[uint]: ...
@overload
def __new__(cls, dtype: _ULongLongCodes, align: bool = ..., copy: bool = ...) -> dtype[ulonglong]: ...
def __new__(cls, dtype: _ULongLongCodes | Type[ct.c_ulonglong], align: bool = ..., copy: bool = ...) -> dtype[ulonglong]: ...

# `signedinteger` string-based representations
# `signedinteger` string-based representations and ctypes
@overload
def __new__(cls, dtype: _Int8Codes, align: bool = ..., copy: bool = ...) -> dtype[int8]: ...
def __new__(cls, dtype: _Int8Codes | Type[ct.c_int8], align: bool = ..., copy: bool = ...) -> dtype[int8]: ...
@overload
def __new__(cls, dtype: _Int16Codes, align: bool = ..., copy: bool = ...) -> dtype[int16]: ...
def __new__(cls, dtype: _Int16Codes | Type[ct.c_int16], align: bool = ..., copy: bool = ...) -> dtype[int16]: ...
@overload
def __new__(cls, dtype: _Int32Codes, align: bool = ..., copy: bool = ...) -> dtype[int32]: ...
def __new__(cls, dtype: _Int32Codes | Type[ct.c_int32], align: bool = ..., copy: bool = ...) -> dtype[int32]: ...
@overload
def __new__(cls, dtype: _Int64Codes, align: bool = ..., copy: bool = ...) -> dtype[int64]: ...
def __new__(cls, dtype: _Int64Codes | Type[ct.c_int64], align: bool = ..., copy: bool = ...) -> dtype[int64]: ...
@overload
def __new__(cls, dtype: _ByteCodes, align: bool = ..., copy: bool = ...) -> dtype[byte]: ...
def __new__(cls, dtype: _ByteCodes | Type[ct.c_byte], align: bool = ..., copy: bool = ...) -> dtype[byte]: ...
@overload
def __new__(cls, dtype: _ShortCodes, align: bool = ..., copy: bool = ...) -> dtype[short]: ...
def __new__(cls, dtype: _ShortCodes | Type[ct.c_short], align: bool = ..., copy: bool = ...) -> dtype[short]: ...
@overload
def __new__(cls, dtype: _IntCCodes, align: bool = ..., copy: bool = ...) -> dtype[intc]: ...
def __new__(cls, dtype: _IntCCodes | Type[ct.c_int], align: bool = ..., copy: bool = ...) -> dtype[intc]: ...
@overload
def __new__(cls, dtype: _IntPCodes, align: bool = ..., copy: bool = ...) -> dtype[intp]: ...
def __new__(cls, dtype: _IntPCodes | Type[ct.c_ssize_t], align: bool = ..., copy: bool = ...) -> dtype[intp]: ...
@overload
def __new__(cls, dtype: _IntCodes, align: bool = ..., copy: bool = ...) -> dtype[int_]: ...
def __new__(cls, dtype: _IntCodes | Type[ct.c_long], align: bool = ..., copy: bool = ...) -> dtype[int_]: ...
@overload
def __new__(cls, dtype: _LongLongCodes, align: bool = ..., copy: bool = ...) -> dtype[longlong]: ...
def __new__(cls, dtype: _LongLongCodes | Type[ct.c_longlong], align: bool = ..., copy: bool = ...) -> dtype[longlong]: ...

# `floating` string-based representations
# `floating` string-based representations and ctypes
@overload
def __new__(cls, dtype: _Float16Codes, align: bool = ..., copy: bool = ...) -> dtype[float16]: ...
@overload
Expand All @@ -992,11 +998,11 @@ class dtype(Generic[_DTypeScalar_co]):
@overload
def __new__(cls, dtype: _HalfCodes, align: bool = ..., copy: bool = ...) -> dtype[half]: ...
@overload
def __new__(cls, dtype: _SingleCodes, align: bool = ..., copy: bool = ...) -> dtype[single]: ...
def __new__(cls, dtype: _SingleCodes | Type[ct.c_float], align: bool = ..., copy: bool = ...) -> dtype[single]: ...
@overload
def __new__(cls, dtype: _DoubleCodes, align: bool = ..., copy: bool = ...) -> dtype[double]: ...
def __new__(cls, dtype: _DoubleCodes | Type[ct.c_double], align: bool = ..., copy: bool = ...) -> dtype[double]: ...
@overload
def __new__(cls, dtype: _LongDoubleCodes, align: bool = ..., copy: bool = ...) -> dtype[longdouble]: ...
def __new__(cls, dtype: _LongDoubleCodes | Type[ct.c_longdouble], align: bool = ..., copy: bool = ...) -> dtype[longdouble]: ...

# `complexfloating` string-based representations
@overload
Expand All @@ -1010,21 +1016,21 @@ class dtype(Generic[_DTypeScalar_co]):
@overload
def __new__(cls, dtype: _CLongDoubleCodes, align: bool = ..., copy: bool = ...) -> dtype[clongdouble]: ...

# Miscellaneous string-based representations
# Miscellaneous string-based representations and ctypes
@overload
def __new__(cls, dtype: _BoolCodes, align: bool = ..., copy: bool = ...) -> dtype[bool_]: ...
def __new__(cls, dtype: _BoolCodes | Type[ct.c_bool], align: bool = ..., copy: bool = ...) -> dtype[bool_]: ...
@overload
def __new__(cls, dtype: _TD64Codes, align: bool = ..., copy: bool = ...) -> dtype[timedelta64]: ...
@overload
def __new__(cls, dtype: _DT64Codes, align: bool = ..., copy: bool = ...) -> dtype[datetime64]: ...
@overload
def __new__(cls, dtype: _StrCodes, align: bool = ..., copy: bool = ...) -> dtype[str_]: ...
@overload
def __new__(cls, dtype: _BytesCodes, align: bool = ..., copy: bool = ...) -> dtype[bytes_]: ...
def __new__(cls, dtype: _BytesCodes | Type[ct.c_char], align: bool = ..., copy: bool = ...) -> dtype[bytes_]: ...
@overload
def __new__(cls, dtype: _VoidCodes, align: bool = ..., copy: bool = ...) -> dtype[void]: ...
@overload
def __new__(cls, dtype: _ObjectCodes, align: bool = ..., copy: bool = ...) -> dtype[object_]: ...
def __new__(cls, dtype: _ObjectCodes | Type[ct.py_object], align: bool = ..., copy: bool = ...) -> dtype[object_]: ...

# dtype of a dtype is the same dtype
@overload
Expand All @@ -1049,14 +1055,22 @@ class dtype(Generic[_DTypeScalar_co]):
align: bool = ...,
copy: bool = ...,
) -> dtype[Any]: ...
# Catchall overload
# Catchall overload for void-likes
@overload
def __new__(
cls,
dtype: _VoidDTypeLike,
align: bool = ...,
copy: bool = ...,
) -> dtype[void]: ...
# Catchall overload for object-likes
@overload
def __new__(
cls,
dtype: Type[object],
align: bool = ...,
copy: bool = ...,
) -> dtype[object_]: ...

@overload
def __getitem__(self: dtype[void], key: List[str]) -> dtype[void]: ...
Expand Down
2 changes: 0 additions & 2 deletions numpy/typing/tests/data/fail/dtype.py
Expand Up @@ -18,5 +18,3 @@ class Test2:
"field2": (int, 3),
}
)

np.dtype[np.float64](np.int64) # E: Argument 1 to "dtype" has incompatible type
10 changes: 10 additions & 0 deletions numpy/typing/tests/data/reveal/dtype.py
@@ -1,3 +1,4 @@
import ctypes as ct
import numpy as np

dtype_obj: np.dtype[np.str_]
Expand All @@ -22,6 +23,15 @@
reveal_type(np.dtype(bool)) # E: numpy.dtype[numpy.bool_]
reveal_type(np.dtype(str)) # E: numpy.dtype[numpy.str_]
reveal_type(np.dtype(bytes)) # E: numpy.dtype[numpy.bytes_]
reveal_type(np.dtype(object)) # E: numpy.dtype[numpy.object_]

# ctypes
reveal_type(np.dtype(ct.c_double)) # E: numpy.dtype[{double}]
reveal_type(np.dtype(ct.c_longlong)) # E: numpy.dtype[{longlong}]
reveal_type(np.dtype(ct.c_uint32)) # E: numpy.dtype[{uint32}]
reveal_type(np.dtype(ct.c_bool)) # E: numpy.dtype[numpy.bool_]
reveal_type(np.dtype(ct.c_char)) # E: numpy.dtype[numpy.bytes_]
reveal_type(np.dtype(ct.py_object)) # E: numpy.dtype[numpy.object_]

# Special case for None
reveal_type(np.dtype(None)) # E: numpy.dtype[{double}]
Expand Down