Skip to content

Commit

Permalink
Merge pull request #19503 from charris/backport-19468
Browse files Browse the repository at this point in the history
MAINT: Add missing dtype overloads for object- and ctypes-based inputs
  • Loading branch information
charris committed Jul 17, 2021
2 parents 8f07655 + cadd4ba commit 7b27ba9
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 34 deletions.
78 changes: 46 additions & 32 deletions numpy/__init__.pyi
@@ -1,6 +1,9 @@
import builtins
import os
import sys
import mmap
import ctypes as ct
import array as _array
import datetime as dt
from abc import abstractmethod
from types import TracebackType
Expand Down Expand Up @@ -920,7 +923,7 @@ class dtype(Generic[_DTypeScalar_co]):
# other special cases. Order is sometimes important because of the
# subtype relationships
#
# bool < int < float < complex
# bool < int < float < complex < object
#
# so we have to make sure the overloads for the narrowest type is
# first.
Expand All @@ -938,51 +941,54 @@ class dtype(Generic[_DTypeScalar_co]):
@overload
def __new__(cls, dtype: Type[bytes], align: bool = ..., copy: bool = ...) -> dtype[bytes_]: ...

# `unsignedinteger` string-based representations
# `unsignedinteger` string-based representations and ctypes
@overload
def __new__(cls, dtype: _UInt8Codes, align: bool = ..., copy: bool = ...) -> dtype[uint8]: ...
def __new__(cls, dtype: _UInt8Codes | Type[ct.c_uint8], align: bool = ..., copy: bool = ...) -> dtype[uint8]: ...
@overload
def __new__(cls, dtype: _UInt16Codes, align: bool = ..., copy: bool = ...) -> dtype[uint16]: ...
def __new__(cls, dtype: _UInt16Codes | Type[ct.c_uint16], align: bool = ..., copy: bool = ...) -> dtype[uint16]: ...
@overload
def __new__(cls, dtype: _UInt32Codes, align: bool = ..., copy: bool = ...) -> dtype[uint32]: ...
def __new__(cls, dtype: _UInt32Codes | Type[ct.c_uint32], align: bool = ..., copy: bool = ...) -> dtype[uint32]: ...
@overload
def __new__(cls, dtype: _UInt64Codes, align: bool = ..., copy: bool = ...) -> dtype[uint64]: ...
def __new__(cls, dtype: _UInt64Codes | Type[ct.c_uint64], align: bool = ..., copy: bool = ...) -> dtype[uint64]: ...
@overload
def __new__(cls, dtype: _UByteCodes, align: bool = ..., copy: bool = ...) -> dtype[ubyte]: ...
def __new__(cls, dtype: _UByteCodes | Type[ct.c_ubyte], align: bool = ..., copy: bool = ...) -> dtype[ubyte]: ...
@overload
def __new__(cls, dtype: _UShortCodes, align: bool = ..., copy: bool = ...) -> dtype[ushort]: ...
def __new__(cls, dtype: _UShortCodes | Type[ct.c_ushort], align: bool = ..., copy: bool = ...) -> dtype[ushort]: ...
@overload
def __new__(cls, dtype: _UIntCCodes, align: bool = ..., copy: bool = ...) -> dtype[uintc]: ...
def __new__(cls, dtype: _UIntCCodes | Type[ct.c_uint], align: bool = ..., copy: bool = ...) -> dtype[uintc]: ...

# NOTE: We're assuming here that `uint_ptr_t == size_t`,
# an assumption that does not hold in rare cases (same for `ssize_t`)
@overload
def __new__(cls, dtype: _UIntPCodes, align: bool = ..., copy: bool = ...) -> dtype[uintp]: ...
def __new__(cls, dtype: _UIntPCodes | Type[ct.c_void_p] | Type[ct.c_size_t], align: bool = ..., copy: bool = ...) -> dtype[uintp]: ...
@overload
def __new__(cls, dtype: _UIntCodes, align: bool = ..., copy: bool = ...) -> dtype[uint]: ...
def __new__(cls, dtype: _UIntCodes | Type[ct.c_ulong], align: bool = ..., copy: bool = ...) -> dtype[uint]: ...
@overload
def __new__(cls, dtype: _ULongLongCodes, align: bool = ..., copy: bool = ...) -> dtype[ulonglong]: ...
def __new__(cls, dtype: _ULongLongCodes | Type[ct.c_ulonglong], align: bool = ..., copy: bool = ...) -> dtype[ulonglong]: ...

# `signedinteger` string-based representations
# `signedinteger` string-based representations and ctypes
@overload
def __new__(cls, dtype: _Int8Codes, align: bool = ..., copy: bool = ...) -> dtype[int8]: ...
def __new__(cls, dtype: _Int8Codes | Type[ct.c_int8], align: bool = ..., copy: bool = ...) -> dtype[int8]: ...
@overload
def __new__(cls, dtype: _Int16Codes, align: bool = ..., copy: bool = ...) -> dtype[int16]: ...
def __new__(cls, dtype: _Int16Codes | Type[ct.c_int16], align: bool = ..., copy: bool = ...) -> dtype[int16]: ...
@overload
def __new__(cls, dtype: _Int32Codes, align: bool = ..., copy: bool = ...) -> dtype[int32]: ...
def __new__(cls, dtype: _Int32Codes | Type[ct.c_int32], align: bool = ..., copy: bool = ...) -> dtype[int32]: ...
@overload
def __new__(cls, dtype: _Int64Codes, align: bool = ..., copy: bool = ...) -> dtype[int64]: ...
def __new__(cls, dtype: _Int64Codes | Type[ct.c_int64], align: bool = ..., copy: bool = ...) -> dtype[int64]: ...
@overload
def __new__(cls, dtype: _ByteCodes, align: bool = ..., copy: bool = ...) -> dtype[byte]: ...
def __new__(cls, dtype: _ByteCodes | Type[ct.c_byte], align: bool = ..., copy: bool = ...) -> dtype[byte]: ...
@overload
def __new__(cls, dtype: _ShortCodes, align: bool = ..., copy: bool = ...) -> dtype[short]: ...
def __new__(cls, dtype: _ShortCodes | Type[ct.c_short], align: bool = ..., copy: bool = ...) -> dtype[short]: ...
@overload
def __new__(cls, dtype: _IntCCodes, align: bool = ..., copy: bool = ...) -> dtype[intc]: ...
def __new__(cls, dtype: _IntCCodes | Type[ct.c_int], align: bool = ..., copy: bool = ...) -> dtype[intc]: ...
@overload
def __new__(cls, dtype: _IntPCodes, align: bool = ..., copy: bool = ...) -> dtype[intp]: ...
def __new__(cls, dtype: _IntPCodes | Type[ct.c_ssize_t], align: bool = ..., copy: bool = ...) -> dtype[intp]: ...
@overload
def __new__(cls, dtype: _IntCodes, align: bool = ..., copy: bool = ...) -> dtype[int_]: ...
def __new__(cls, dtype: _IntCodes | Type[ct.c_long], align: bool = ..., copy: bool = ...) -> dtype[int_]: ...
@overload
def __new__(cls, dtype: _LongLongCodes, align: bool = ..., copy: bool = ...) -> dtype[longlong]: ...
def __new__(cls, dtype: _LongLongCodes | Type[ct.c_longlong], align: bool = ..., copy: bool = ...) -> dtype[longlong]: ...

# `floating` string-based representations
# `floating` string-based representations and ctypes
@overload
def __new__(cls, dtype: _Float16Codes, align: bool = ..., copy: bool = ...) -> dtype[float16]: ...
@overload
Expand All @@ -992,11 +998,11 @@ class dtype(Generic[_DTypeScalar_co]):
@overload
def __new__(cls, dtype: _HalfCodes, align: bool = ..., copy: bool = ...) -> dtype[half]: ...
@overload
def __new__(cls, dtype: _SingleCodes, align: bool = ..., copy: bool = ...) -> dtype[single]: ...
def __new__(cls, dtype: _SingleCodes | Type[ct.c_float], align: bool = ..., copy: bool = ...) -> dtype[single]: ...
@overload
def __new__(cls, dtype: _DoubleCodes, align: bool = ..., copy: bool = ...) -> dtype[double]: ...
def __new__(cls, dtype: _DoubleCodes | Type[ct.c_double], align: bool = ..., copy: bool = ...) -> dtype[double]: ...
@overload
def __new__(cls, dtype: _LongDoubleCodes, align: bool = ..., copy: bool = ...) -> dtype[longdouble]: ...
def __new__(cls, dtype: _LongDoubleCodes | Type[ct.c_longdouble], align: bool = ..., copy: bool = ...) -> dtype[longdouble]: ...

# `complexfloating` string-based representations
@overload
Expand All @@ -1010,21 +1016,21 @@ class dtype(Generic[_DTypeScalar_co]):
@overload
def __new__(cls, dtype: _CLongDoubleCodes, align: bool = ..., copy: bool = ...) -> dtype[clongdouble]: ...

# Miscellaneous string-based representations
# Miscellaneous string-based representations and ctypes
@overload
def __new__(cls, dtype: _BoolCodes, align: bool = ..., copy: bool = ...) -> dtype[bool_]: ...
def __new__(cls, dtype: _BoolCodes | Type[ct.c_bool], align: bool = ..., copy: bool = ...) -> dtype[bool_]: ...
@overload
def __new__(cls, dtype: _TD64Codes, align: bool = ..., copy: bool = ...) -> dtype[timedelta64]: ...
@overload
def __new__(cls, dtype: _DT64Codes, align: bool = ..., copy: bool = ...) -> dtype[datetime64]: ...
@overload
def __new__(cls, dtype: _StrCodes, align: bool = ..., copy: bool = ...) -> dtype[str_]: ...
@overload
def __new__(cls, dtype: _BytesCodes, align: bool = ..., copy: bool = ...) -> dtype[bytes_]: ...
def __new__(cls, dtype: _BytesCodes | Type[ct.c_char], align: bool = ..., copy: bool = ...) -> dtype[bytes_]: ...
@overload
def __new__(cls, dtype: _VoidCodes, align: bool = ..., copy: bool = ...) -> dtype[void]: ...
@overload
def __new__(cls, dtype: _ObjectCodes, align: bool = ..., copy: bool = ...) -> dtype[object_]: ...
def __new__(cls, dtype: _ObjectCodes | Type[ct.py_object], align: bool = ..., copy: bool = ...) -> dtype[object_]: ...

# dtype of a dtype is the same dtype
@overload
Expand All @@ -1049,14 +1055,22 @@ class dtype(Generic[_DTypeScalar_co]):
align: bool = ...,
copy: bool = ...,
) -> dtype[Any]: ...
# Catchall overload
# Catchall overload for void-likes
@overload
def __new__(
cls,
dtype: _VoidDTypeLike,
align: bool = ...,
copy: bool = ...,
) -> dtype[void]: ...
# Catchall overload for object-likes
@overload
def __new__(
cls,
dtype: Type[object],
align: bool = ...,
copy: bool = ...,
) -> dtype[object_]: ...

@overload
def __getitem__(self: dtype[void], key: List[str]) -> dtype[void]: ...
Expand Down
2 changes: 0 additions & 2 deletions numpy/typing/tests/data/fail/dtype.py
Expand Up @@ -18,5 +18,3 @@ class Test2:
"field2": (int, 3),
}
)

np.dtype[np.float64](np.int64) # E: Argument 1 to "dtype" has incompatible type
10 changes: 10 additions & 0 deletions numpy/typing/tests/data/reveal/dtype.py
@@ -1,3 +1,4 @@
import ctypes as ct
import numpy as np

dtype_obj: np.dtype[np.str_]
Expand All @@ -22,6 +23,15 @@
reveal_type(np.dtype(bool)) # E: numpy.dtype[numpy.bool_]
reveal_type(np.dtype(str)) # E: numpy.dtype[numpy.str_]
reveal_type(np.dtype(bytes)) # E: numpy.dtype[numpy.bytes_]
reveal_type(np.dtype(object)) # E: numpy.dtype[numpy.object_]

# ctypes
reveal_type(np.dtype(ct.c_double)) # E: numpy.dtype[{double}]
reveal_type(np.dtype(ct.c_longlong)) # E: numpy.dtype[{longlong}]
reveal_type(np.dtype(ct.c_uint32)) # E: numpy.dtype[{uint32}]
reveal_type(np.dtype(ct.c_bool)) # E: numpy.dtype[numpy.bool_]
reveal_type(np.dtype(ct.c_char)) # E: numpy.dtype[numpy.bytes_]
reveal_type(np.dtype(ct.py_object)) # E: numpy.dtype[numpy.object_]

# Special case for None
reveal_type(np.dtype(None)) # E: numpy.dtype[{double}]
Expand Down

0 comments on commit 7b27ba9

Please sign in to comment.