Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add np.allclose and np.isclose support ref. issue #4074 #6286

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
2 changes: 2 additions & 0 deletions docs/source/reference/numpysupported.rst
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,7 @@ Other functions

The following top-level functions are supported:

* :func:`numpy.allclose`
* :func:`numpy.append`
* :func:`numpy.arange`
* :func:`numpy.argsort` (``kind`` key word argument supported for values
Expand Down Expand Up @@ -388,6 +389,7 @@ The following top-level functions are supported:
* :func:`numpy.identity`
* :func:`numpy.kaiser`
* :func:`numpy.interp` (only the 3 first arguments; requires NumPy >= 1.10)
* :func:`numpy.isclose` (will return 1d array for 0d array input)
* :func:`numpy.linspace` (only the 3-argument form)
* :class:`numpy.ndenumerate`
* :class:`numpy.ndindex`
Expand Down
84 changes: 84 additions & 0 deletions numba/np/arraymath.py
Original file line number Diff line number Diff line change
Expand Up @@ -4436,3 +4436,87 @@ def impl(a, b):
return _cross2d_operation(a_, b_)

return impl


@register_jitable
def _broadcastable_(a, b):
s = min(a.ndim, b.ndim)
for i in range(1, s + 1):
m = a.shape[-i]
n = b.shape[-i]
if not((m == n) or (m == 1) or (n == 1)):
return False
return True


@register_jitable
def _close_operation(a, b, rtol=1e-5, atol=1e-8, equal_nan=False):
if np.isnan(a) or np.isnan(b):
if equal_nan and np.isnan(a) and np.isnan(b):
return True
else:
return False

if abs(a - b) > (atol + rtol * abs(b)):
return False
return True


@overload(np.isclose)
def np_isclose(a, b, rtol=1.e-5, atol=1.e-8, equal_nan=False):
if not type_can_asarray(a) or not type_can_asarray(b):
raise TypingError("Inputs must be array-like.")

def impl(a, b, rtol=1.e-5, atol=1.e-8, equal_nan=False):
x = np.atleast_1d(a)
y = np.atleast_1d(b)

def within_tol(x, y, rtol, atol):
return np.abs(x - y) <= (atol + rtol * np.abs(y))

xfin = np.isfinite(x)
yfin = np.isfinite(y)

if np.all(xfin) and np.all(yfin):
res = within_tol(x, y, rtol, atol)
else:
finite = xfin & yfin
cond = np.zeros_like(finite)

finite_f = finite.ravel()

x = x * np.ones_like(cond)
y = y * np.ones_like(cond)

x_f = x.ravel()
y_f = y.ravel()
cond_f = cond.ravel()

cond_f[finite_f] = within_tol(x_f[finite_f], y_f[finite_f],
rtol, atol)

cond_f[~finite_f] = (x_f[~finite_f] == y_f[~finite_f])

if equal_nan:
both_nan = (np.isnan(x) & np.isnan(y)).ravel()
cond_f[both_nan] = both_nan[both_nan]

res = cond[()]

return res

return impl


@overload(np.allclose)
def np_allclose(a, b, rtol=1.e-5, atol=1.e-8, equal_nan=False):
if not type_can_asarray(a) or not type_can_asarray(b):
raise TypingError("Inputs must be array-like.")

def impl(a, b, rtol=1.e-5, atol=1.e-8, equal_nan=False):

return np.all(
np.isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)
)

return impl
85 changes: 85 additions & 0 deletions numba/tests/test_np_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,14 @@ def np_cross(a, b):
return np.cross(a, b)


def np_isclose(a, b, rtol=1e-5, atol=1e-8, equal_nan=False):
return np.isclose(a, b, rtol, atol, equal_nan)


def np_allclose(a, b, rtol=1e-5, atol=1e-8, equal_nan=False):
return np.allclose(a, b, rtol, atol, equal_nan)


def flip_lr(a):
return np.fliplr(a)

Expand Down Expand Up @@ -3712,6 +3720,83 @@ def test_cross2d_exceptions(self):
str(raises.exception)
)

def test_isclose(self):
pyfunc = np_isclose
cfunc = jit(nopython=True)(pyfunc)

args_list = [
([1e10, 1e-7], [1.00001e10, 1e-8], {}),
([1e10, 1e-8], [1.00001e10, 1e-9], {}),
([1e10, 1e-8], [1.0001e10, 1e-9], {}),
([1.0, np.nan], [1.0, np.nan], {}),
([1.0, np.nan], [1.0, np.nan], {'equal_nan': True}),
(np.random.rand(4, 5), np.random.rand(4, 5), {'atol': 2.0}),
(np.random.rand(4, 5), np.random.rand(4, 5), {})
]

for args in args_list:
x, y = args[:-1]
x = np.asarray(x)
y = np.asarray(y)
kwargs = args[-1]
expected = pyfunc(x, y, **kwargs)
got = cfunc(x, y, **kwargs)
self.assertPreciseEqual(expected, got)

def test_isclose_exceptions(self):
cfunc = jit(nopython=True)(np_isclose)
self.disable_leak_check()

# test incompatible dimensions
with self.assertRaises(ValueError) as raises:
cfunc(
np.random.rand(4, 3),
np.random.rand(3, 5)
)
self.assertIn(
'unable to broadcast',
str(raises.exception)
)

def test_allclose(self):
pyfunc = np_allclose
cfunc = jit(nopython=True)(pyfunc)

args_list = [
([1e10, 1e-7], [1.00001e10, 1e-8], {}),
([1e10, 1e-8], [1.00001e10, 1e-9], {}),
([1e10, 1e-8], [1.0001e10, 1e-9], {}),
([1.0, np.nan], [1.0, np.nan], {}),
([1.0, np.nan], [1.0, np.nan], {'equal_nan': True}),
(np.random.rand(4, 5), np.random.rand(4, 5), {'atol': 2.0}),
(np.random.rand(4, 5), np.random.rand(4, 5), {}),
(1, 1, {})
]

for args in args_list:
x, y = args[:-1]
x = np.asarray(x)
y = np.asarray(y)
kwargs = args[-1]
expected = pyfunc(x, y, **kwargs)
got = cfunc(x, y, **kwargs)
self.assertEqual(expected, got)

def test_allclose_exceptions(self):
cfunc = jit(nopython=True)(np_allclose)
self.disable_leak_check()

# test incompatible dimensions
with self.assertRaises(ValueError) as raises:
cfunc(
np.random.rand(4, 3),
np.random.rand(3, 5)
)
self.assertIn(
'unable to broadcast',
str(raises.exception)
)


class TestNPMachineParameters(TestCase):
# tests np.finfo, np.iinfo, np.MachAr
Expand Down