Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add dask.array.array_equal #10740

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions dask/array/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@
argwhere,
around,
array,
array_equal,
atleast_1d,
atleast_2d,
atleast_3d,
Expand Down
33 changes: 32 additions & 1 deletion dask/array/routines.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from dask.array.creation import arange, diag, empty, indices, tri
from dask.array.einsumfuncs import einsum # noqa
from dask.array.reductions import reduction
from dask.array.ufunc import multiply, sqrt
from dask.array.ufunc import isnan, multiply, sqrt
from dask.array.utils import (
array_safe,
asarray_safe,
Expand Down Expand Up @@ -2107,6 +2107,37 @@ def where(condition, x=None, y=None):
return elemwise(np.where, condition, x, y)


_no_nan_types = {
type(np.dtype(t))
for t in (
np.bool_,
np.int8,
np.int16,
np.int32,
np.int64,
np.uint8,
np.uint16,
np.uint32,
np.uint64,
)
}


@derived_from(np)
def array_equal(a1, a2, equal_nan=False, split_every=None):
if a1.shape != a2.shape:
return array([np.False_])[0]
cannot_have_nan = (
type(a1.dtype) in _no_nan_types and type(a2.dtype) in _no_nan_types
)
if (equal_nan or cannot_have_nan) and (a1 is a2):
return array([np.True_])[0]
equal = a1 == a2
if equal_nan and not cannot_have_nan:
equal = where(isnan(a1) & isnan(a2), True, equal)
return equal.all(split_every=split_every)


@derived_from(np)
def count_nonzero(a, axis=None):
return isnonzero(asarray(a)).astype(np.intp).sum(axis=axis)
Expand Down
143 changes: 143 additions & 0 deletions dask/array/tests/test_routines.py
Original file line number Diff line number Diff line change
Expand Up @@ -1896,6 +1896,149 @@ def test_where_incorrect_args():
assert "either both or neither of x and y should be given" in str(e)


def _test_array_equal_parametrizations():
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you use a proper fixture here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure I understand, could you explain in a bit more detail what you would like to see, please? I read through all the documentation on pytest.mark.parametrize and fixtures, but I cannot find how using a fixture would facilitate parametrizing a test. Or do you mean to drop pytest.mark.parametrize and write a single test that loops over all the test cases? That would prevent running the tests in parallel though.

As a side note: this test setup was copied from numpy/_core/tests/test_numeric.py, so it might be easier for future maintenance to keep modifications to a minimum.

"""
we pre-create arrays as we sometime want to pass the same instance
and sometime not. Passing the same instances may not mean the array are
equal, especially when containing None
"""
# Test cases copied from `numpy/numpy/_core/tests/test_numeric.py`

# those are 0-d arrays, it used to be a special case
# where (e0 == e0).all() would raise
e0 = np.array(0, dtype="int")
e1 = np.array(1, dtype="float")
# a1, a2, equal_nan
yield (e0, e0.copy(), None)
yield (e0, e0.copy(), False)
yield (e0, e0.copy(), True)

#
yield (e1, e1.copy(), None)
yield (e1, e1.copy(), False)
yield (e1, e1.copy(), True)

# Non-nanable - those cannot hold nans
a12 = np.array([1, 2])
a12b = a12.copy()
a123 = np.array([1, 2, 3])
a13 = np.array([1, 3])
a34 = np.array([3, 4])

aS1 = np.array(["a"], dtype="S1")
aS1b = aS1.copy()

yield (a12, a12b, None)
yield (a12, a12, None)
yield (a12, a123, None)
yield (a12, a34, None)
yield (a12, a13, None)
yield (aS1, aS1b, None)
yield (aS1, aS1, None)

# Non-float dtype - equal_nan should have no effect,
yield (a123, a123, None)
yield (a123, a123, False)
yield (a123, a123, True)
yield (a123, a123.copy(), None)
yield (a123, a123.copy(), False)
yield (a123, a123.copy(), True)
yield (a123.astype("float"), a123.astype("float"), None)
yield (a123.astype("float"), a123.astype("float"), False)
yield (a123.astype("float"), a123.astype("float"), True)

# these can hold None
b1 = np.array([1, 2, np.nan])
b2 = np.array([1, np.nan, 2])
b3 = np.array([1, 2, np.inf])
b4 = np.array(np.nan)

# instances are the same
yield (b1, b1, None)
yield (b1, b1, False)
yield (b1, b1, True)

# equal but not same instance
yield (b1, b1.copy(), None)
yield (b1, b1.copy(), False)
yield (b1, b1.copy(), True)

# same once stripped of Nan
yield (b1, b2, None)
yield (b1, b2, False)
yield (b1, b2, True)

# nan's not conflated with inf's
yield (b1, b3, None)
yield (b1, b3, False)
yield (b1, b3, True)

# all Nan
yield (b4, b4, None)
yield (b4, b4, False)
yield (b4, b4, True)
yield (b4, b4.copy(), None)
yield (b4, b4.copy(), False)
yield (b4, b4.copy(), True)

t1 = b1.astype("timedelta64")
t2 = b2.astype("timedelta64")

# Timedeltas are particular
yield (t1, t1, None)
yield (t1, t1, False)
yield (t1, t1, True)

yield (t1, t1.copy(), None)
yield (t1, t1.copy(), False)
yield (t1, t1.copy(), True)

yield (t1, t2, None)
yield (t1, t2, False)
yield (t1, t2, True)

# Multi-dimensional array
md1 = np.array([[0, 1], [np.nan, 1]])

yield (md1, md1, None)
yield (md1, md1, False)
yield (md1, md1, True)
yield (md1, md1.copy(), None)
yield (md1, md1.copy(), False)
yield (md1, md1.copy(), True)
# both complexes are nan+nan.j but the same instance
cplx1, cplx2 = [np.array([np.nan + np.nan * 1j])] * 2

# only real or img are nan.
cplx3, cplx4 = np.complex64(1 + 1j * np.nan), np.complex64(np.nan + 1j)

# Complex values
yield (cplx1, cplx2, None)
yield (cplx1, cplx2, False)
yield (cplx1, cplx2, True)

# Complex values, 1+nan, nan+1j
yield (cplx3, cplx4, None)
yield (cplx3, cplx4, False)
yield (cplx3, cplx4, True)


@pytest.mark.parametrize("a1,a2,equal_nan", _test_array_equal_parametrizations())
def test_array_equal(a1, a2, equal_nan):
d1 = da.asarray(a1, chunks=2)
if a1 is a2:
d2 = d1
else:
d2 = da.asarray(a2, chunks=1)
if equal_nan is None:
np_eq = np.array_equal(a1, a2)
da_eq = da.array_equal(d1, d2)
else:
np_eq = np.array_equal(a1, a2, equal_nan=equal_nan)
da_eq = da.array_equal(d1, d2, equal_nan=equal_nan)
assert_eq(np_eq, da_eq)


def test_count_nonzero():
for shape, chunks in [(0, ()), ((0, 0), (0, 0)), ((15, 16), (4, 5))]:
x = np.random.default_rng().integers(10, size=shape)
Expand Down
1 change: 1 addition & 0 deletions docs/source/array-api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ Top level functions
argwhere
around
array
array_equal
asanyarray
asarray
atleast_1d
Expand Down