Skip to content

Commit

Permalink
Merge pull request #19269 from charris/backport-19228
Browse files Browse the repository at this point in the history
BUG: Invalid dtypes comparison should not raise TypeError
  • Loading branch information
charris committed Jun 17, 2021
2 parents a070e5d + d80e473 commit 143d45f
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 3 deletions.
2 changes: 0 additions & 2 deletions numpy/__init__.pyi
Expand Up @@ -1075,8 +1075,6 @@ class dtype(Generic[_DTypeScalar_co]):
# literals as of mypy 0.800. Set the return-type to `Any` for now.
def __rmul__(self, value: int) -> Any: ...

def __eq__(self, other: DTypeLike) -> bool: ...
def __ne__(self, other: DTypeLike) -> bool: ...
def __gt__(self, other: DTypeLike) -> bool: ...
def __ge__(self, other: DTypeLike) -> bool: ...
def __lt__(self, other: DTypeLike) -> bool: ...
Expand Down
4 changes: 3 additions & 1 deletion numpy/core/src/multiarray/descriptor.c
Expand Up @@ -3228,7 +3228,9 @@ arraydescr_richcompare(PyArray_Descr *self, PyObject *other, int cmp_op)
{
PyArray_Descr *new = _convert_from_any(other, 0);
if (new == NULL) {
return NULL;
/* Cannot convert `other` to dtype */
PyErr_Clear();
Py_RETURN_NOTIMPLEMENTED;
}

npy_bool ret;
Expand Down
18 changes: 18 additions & 0 deletions numpy/core/tests/test_dtype.py
Expand Up @@ -88,6 +88,24 @@ def test_invalid_types(self):
assert_raises(TypeError, np.dtype, 'q8')
assert_raises(TypeError, np.dtype, 'Q8')

def test_richcompare_invalid_dtype_equality(self):
# Make sure objects that cannot be converted to valid
# dtypes results in False/True when compared to valid dtypes.
# Here 7 cannot be converted to dtype. No exceptions should be raised

assert not np.dtype(np.int32) == 7, "dtype richcompare failed for =="
assert np.dtype(np.int32) != 7, "dtype richcompare failed for !="

@pytest.mark.parametrize(
'operation',
[operator.le, operator.lt, operator.ge, operator.gt])
def test_richcompare_invalid_dtype_comparison(self, operation):
# Make sure TypeError is raised for comparison operators
# for invalid dtypes. Here 7 is an invalid dtype.

with pytest.raises(TypeError):
operation(np.dtype(np.int32), 7)

@pytest.mark.parametrize("dtype",
['Bool', 'Complex32', 'Complex64', 'Float16', 'Float32', 'Float64',
'Int8', 'Int16', 'Int32', 'Int64', 'Object0', 'Timedelta64',
Expand Down

0 comments on commit 143d45f

Please sign in to comment.