Skip to content

Commit

Permalink
BUG: Make sure that NumPy scalars are supported by can_cast
Browse files Browse the repository at this point in the history
The main issue here was the order of the checks, since float64 is
a subclass of float the error path was taken even though it should
not have been.

This also avoids converting to an array (which is very slow) when
possible.  I opted to use `scalar.dtype` since that may be a bit
easier for potential future user dtype.
That may not be quite ideal (I would like to not force `np.generic`
as a base-class for user scalars), but is probably pretty close
and more complicated fixes are probably not good for  backport.
  • Loading branch information
seberg authored and charris committed May 6, 2024
1 parent 3ebc9be commit 9aace32
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 10 deletions.
3 changes: 2 additions & 1 deletion numpy/_core/src/multiarray/descriptor.c
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "conversion_utils.h" /* for PyArray_TypestrConvert */
#include "templ_common.h" /* for npy_mul_sizes_with_overflow */
#include "descriptor.h"
#include "multiarraymodule.h"
#include "alloc.h"
#include "assert.h"
#include "npy_buffer.h"
Expand Down Expand Up @@ -2696,7 +2697,7 @@ arraydescr_reduce(PyArray_Descr *self, PyObject *NPY_UNUSED(args))
Py_DECREF(ret);
return NULL;
}
obj = PyObject_GetAttrString(mod, "dtype");
obj = PyObject_GetAttr(mod, npy_ma_str_dtype);
Py_DECREF(mod);
if (obj == NULL) {
Py_DECREF(ret);
Expand Down
44 changes: 35 additions & 9 deletions numpy/_core/src/multiarray/multiarraymodule.c
Original file line number Diff line number Diff line change
Expand Up @@ -3488,6 +3488,36 @@ array_can_cast_safely(PyObject *NPY_UNUSED(self),
if (PyArray_Check(from_obj)) {
ret = PyArray_CanCastArrayTo((PyArrayObject *)from_obj, d2, casting);
}
else if (PyArray_IsScalar(from_obj, Generic)) {
/*
* TODO: `PyArray_IsScalar` should not be required for new dtypes.
* weak-promotion branch is in practice identical to dtype one.
*/
if (npy_promotion_state == NPY_USE_WEAK_PROMOTION) {
PyObject *descr = PyObject_GetAttr(from_obj, npy_ma_str_dtype);
if (descr == NULL) {
goto finish;
}
if (!PyArray_DescrCheck(descr)) {
Py_DECREF(descr);
PyErr_SetString(PyExc_TypeError,
"numpy_scalar.dtype did not return a dtype instance.");
goto finish;
}
ret = PyArray_CanCastTypeTo((PyArray_Descr *)descr, d2, casting);
Py_DECREF(descr);
}
else {
/* need to convert to object to consider old value-based logic */
PyArrayObject *arr;
arr = (PyArrayObject *)PyArray_FROM_O(from_obj);
if (arr == NULL) {
goto finish;
}
ret = PyArray_CanCastArrayTo(arr, d2, casting);
Py_DECREF(arr);
}
}
else if (PyArray_IsPythonNumber(from_obj)) {
PyErr_SetString(PyExc_TypeError,
"can_cast() does not support Python ints, floats, and "
Expand All @@ -3496,15 +3526,6 @@ array_can_cast_safely(PyObject *NPY_UNUSED(self),
"explicitly allow them again in the future.");
goto finish;
}
else if (PyArray_IsScalar(from_obj, Generic)) {
PyArrayObject *arr;
arr = (PyArrayObject *)PyArray_FROM_O(from_obj);
if (arr == NULL) {
goto finish;
}
ret = PyArray_CanCastArrayTo(arr, d2, casting);
Py_DECREF(arr);
}
/* Otherwise use CanCastTypeTo */
else {
if (!PyArray_DescrConverter2(from_obj, &d1) || d1 == NULL) {
Expand Down Expand Up @@ -4772,6 +4793,7 @@ NPY_VISIBILITY_HIDDEN PyObject * npy_ma_str_convert = NULL;
NPY_VISIBILITY_HIDDEN PyObject * npy_ma_str_preserve = NULL;
NPY_VISIBILITY_HIDDEN PyObject * npy_ma_str_convert_if_no_array = NULL;
NPY_VISIBILITY_HIDDEN PyObject * npy_ma_str_cpu = NULL;
NPY_VISIBILITY_HIDDEN PyObject * npy_ma_str_dtype = NULL;
NPY_VISIBILITY_HIDDEN PyObject * npy_ma_str_array_err_msg_substr = NULL;

static int
Expand Down Expand Up @@ -4850,6 +4872,10 @@ intern_strings(void)
if (npy_ma_str_cpu == NULL) {
return -1;
}
npy_ma_str_dtype = PyUnicode_InternFromString("dtype");
if (npy_ma_str_dtype == NULL) {
return -1;
}
npy_ma_str_array_err_msg_substr = PyUnicode_InternFromString(
"__array__() got an unexpected keyword argument 'copy'");
if (npy_ma_str_array_err_msg_substr == NULL) {
Expand Down
1 change: 1 addition & 0 deletions numpy/_core/src/multiarray/multiarraymodule.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ NPY_VISIBILITY_HIDDEN extern PyObject * npy_ma_str_convert;
NPY_VISIBILITY_HIDDEN extern PyObject * npy_ma_str_preserve;
NPY_VISIBILITY_HIDDEN extern PyObject * npy_ma_str_convert_if_no_array;
NPY_VISIBILITY_HIDDEN extern PyObject * npy_ma_str_cpu;
NPY_VISIBILITY_HIDDEN extern PyObject * npy_ma_str_dtype;
NPY_VISIBILITY_HIDDEN extern PyObject * npy_ma_str_array_err_msg_substr;

#endif /* NUMPY_CORE_SRC_MULTIARRAY_MULTIARRAYMODULE_H_ */
11 changes: 11 additions & 0 deletions numpy/_core/tests/test_numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -1441,6 +1441,17 @@ def test_can_cast_values(self):
assert_(np.can_cast(fi.min, dt))
assert_(np.can_cast(fi.max, dt))

@pytest.mark.parametrize("dtype",
list("?bhilqBHILQefdgFDG") + [rational])
def test_can_cast_scalars(self, dtype):
# Basic test to ensure that scalars are supported in can-cast
# (does not check behavior exhaustively).
dtype = np.dtype(dtype)
scalar = dtype.type(0)

assert np.can_cast(scalar, "int64") == np.can_cast(dtype, "int64")
assert np.can_cast(scalar, "float32", casting="unsafe")


# Custom exception class to test exception propagation in fromiter
class NIterError(Exception):
Expand Down

0 comments on commit 9aace32

Please sign in to comment.