Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BUG: Invalid dtypes comparison should not raise TypeError #19269

Merged
merged 5 commits into from Jun 17, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 0 additions & 2 deletions numpy/__init__.pyi
Expand Up @@ -1075,8 +1075,6 @@ class dtype(Generic[_DTypeScalar_co]):
# literals as of mypy 0.800. Set the return-type to `Any` for now.
def __rmul__(self, value: int) -> Any: ...

def __eq__(self, other: DTypeLike) -> bool: ...
def __ne__(self, other: DTypeLike) -> bool: ...
def __gt__(self, other: DTypeLike) -> bool: ...
def __ge__(self, other: DTypeLike) -> bool: ...
def __lt__(self, other: DTypeLike) -> bool: ...
Expand Down
4 changes: 3 additions & 1 deletion numpy/core/src/multiarray/descriptor.c
Expand Up @@ -3228,7 +3228,9 @@ arraydescr_richcompare(PyArray_Descr *self, PyObject *other, int cmp_op)
{
PyArray_Descr *new = _convert_from_any(other, 0);
if (new == NULL) {
return NULL;
/* Cannot convert `other` to dtype */
PyErr_Clear();
Py_RETURN_NOTIMPLEMENTED;
}

npy_bool ret;
Expand Down
18 changes: 18 additions & 0 deletions numpy/core/tests/test_dtype.py
Expand Up @@ -88,6 +88,24 @@ def test_invalid_types(self):
assert_raises(TypeError, np.dtype, 'q8')
assert_raises(TypeError, np.dtype, 'Q8')

def test_richcompare_invalid_dtype_equality(self):
# Make sure objects that cannot be converted to valid
# dtypes results in False/True when compared to valid dtypes.
# Here 7 cannot be converted to dtype. No exceptions should be raised

assert not np.dtype(np.int32) == 7, "dtype richcompare failed for =="
assert np.dtype(np.int32) != 7, "dtype richcompare failed for !="

@pytest.mark.parametrize(
'operation',
[operator.le, operator.lt, operator.ge, operator.gt])
def test_richcompare_invalid_dtype_comparison(self, operation):
# Make sure TypeError is raised for comparison operators
# for invalid dtypes. Here 7 is an invalid dtype.

with pytest.raises(TypeError):
operation(np.dtype(np.int32), 7)

@pytest.mark.parametrize("dtype",
['Bool', 'Complex32', 'Complex64', 'Float16', 'Float32', 'Float64',
'Int8', 'Int16', 'Int32', 'Int64', 'Object0', 'Timedelta64',
Expand Down