Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for np.diag_indices() and np.diag_indices_from() #9115

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/source/reference/numpysupported.rst
Original file line number Diff line number Diff line change
Expand Up @@ -467,6 +467,8 @@ The following top-level functions are supported:

* :func:`numpy.delete` (only the 2 first arguments)
* :func:`numpy.diag`
* :func:`numpy.diag_indices` (returns a NumPy array instead of a tuple)
* :func:`numpy.diag_indices_from` (returns a NumPy array instead of a tuple)
* :func:`numpy.digitize`
* :func:`numpy.dstack`
* :func:`numpy.dtype` (only the first argument)
Expand Down
4 changes: 4 additions & 0 deletions docs/upcoming_changes/9115.np_support.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@

Add np.diag_indices() and np.diag_indices_from()
================================================
Support for np.diag_indices() and np.diag_indices_from() is added.
39 changes: 39 additions & 0 deletions numba/np/arrayobj.py
Original file line number Diff line number Diff line change
Expand Up @@ -4523,6 +4523,45 @@ def impl(N, M=None, k=0, dtype=float):
return impl


@overload(np.diag_indices)
def np_diag_indices(n, ndim=2):
if not isinstance(n, types.Integer):
msg = 'The argument "n" must be an integer'
raise errors.TypingError(msg)

if not isinstance(ndim, (int, types.Integer)):
msg = 'The argument "ndim" must be an integer'
raise errors.TypingError(msg)

def impl(n, ndim=2):
res = np.arange(n * ndim)
for i in range(n * ndim):
res[i] = res[i] % n
return res.reshape((ndim, n))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Numba has a function called tuple_setitem from cpython/unsafe/tuple.py where you can use to set items in a tuple at runtime.

Suggested change
def impl(n, ndim=2):
res = np.arange(n * ndim)
for i in range(n * ndim):
res[i] = res[i] % n
return res.reshape((ndim, n))
# tup_init is used only to get the correct tuple type
tup_init = (np.arange(1),) * 2
from numba.cpython.unsafe.tuple import tuple_setitem
def impl(n, ndim=2):
res = np.arange(n * ndim)
for i in range(n * ndim):
res[i] = res[i] % n
x = res.reshape((ndim, n))
tup = tup_init
for i in range(n):
tup = tuple_setitem(tup, i, x[i])
return tup

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @guilhermeleobas, thank you for your suggestion. The only issue with this implementation is that the array elements that are being returned are immutable, which is different to what the NumPy function np.diagflat returns:

import numpy as np
from numba import njit
@njit
def foodiagindices(n, ndim=2):
    res = np.diag_indices(n, ndim)
    return res
a=np.diag_indices(5)
print(a)
a[0][0] = 5
print(a)
c=foodiagindices(5)
print(c)
c[0][0] = 5

Returns

(array([0, 1, 2, 3, 4]), array([0, 1, 2, 3, 4]))
(array([5, 1, 2, 3, 4]), array([5, 1, 2, 3, 4]))
(array([0, 1, 2, 3, 4]), array([0, 1, 2, 3, 4]))
Traceback (most recent call last):
File "/home/kristian/Desktop/numba/a.py", line 13, in
c[0][0] = 5
ValueError: assignment destination is read-only

Another issue is that your proposed implementation implicitly assumes that ndim is equal to 2 at runtime by setting tup_init = (np.arange(1),) * 2. Hence, the following code:

import numpy as np
from numba import njit
@njit
def foodiagindices(n, ndim=2):
    res = np.diag_indices(n, ndim)
    return res
c = foodiagindices(5, 3)

throws a segmentation fault.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Another issue is that your proposed implementation implicitly assumes that ndim is equal to 2 at runtime by setting tup_init = (np.arange(1),) * 2. Hence, the following code

Oh, I know. My code was just an example on how to use tuple_setitem


return impl


@overload(np.diag_indices_from)
def np_diag_indices_from(arr):
if not isinstance(arr, types.Array):
msg = 'The argument "arr" must be an array'
raise errors.TypingError(msg)

def impl(arr):
if not arr.ndim >= 2:
raise ValueError("Input array must be at least 2-d")
# For more than d=2, the strided formula is only valid for arrays with
# all dimensions equal, so we check first.
s = np.asarray(arr.shape)
if not np.all(np.diff(s) == 0):
raise ValueError("All dimensions of input must be of equal length")

return np.diag_indices(arr.shape[0], arr.ndim)

return impl


@overload(np.diag)
def impl_np_diag(v, k=0):
if not type_can_asarray(v):
Expand Down
109 changes: 109 additions & 0 deletions numba/tests/test_np_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,6 +432,18 @@ def nan_to_num(X, copy=True, nan=0.0):
return np.nan_to_num(X, copy=copy, nan=nan)


def diag_indices1(n):
return np.diag_indices(n)


def diag_indices2(n, ndim):
return np.diag_indices(n, ndim)


def diag_indices_from(arr):
return np.diag_indices_from(arr)


class TestNPFunctions(MemoryLeakMixin, TestCase):
"""
Tests for various Numpy functions.
Expand Down Expand Up @@ -5547,6 +5559,103 @@ def test_nan_to_num_invalid_argument(self):
self.assertIn("The first argument must be a scalar or an array-like",
str(raises.exception))

def test_diag_indices_basic(self):
pyfunc1 = diag_indices1
cfunc1 = njit(pyfunc1)
pyfunc2 = diag_indices2
cfunc2 = njit(pyfunc2)

def inputs():
# Based on https://github.com/numpy/numpy/blob/...
# 09ee7daa9113ecb503552913d76185d0ad3f0f31/numpy/lib/tests/...
# test_index_tricks.py#L483-L508
yield 4, 2
yield 2, 3

for n, ndim in inputs():
self.assertPreciseEqual(np.asarray(pyfunc1(n)), cfunc1(n))
self.assertPreciseEqual(np.asarray(pyfunc2(n, ndim)),
cfunc2(n, ndim))

def test_diag_indices1_exception(self):
pyfunc = diag_indices1
cfunc = njit(pyfunc)

self.disable_leak_check()

with self.assertRaises(TypingError) as raises:
cfunc("abc")
self.assertIn('The argument "n" must be an integer',
str(raises.exception))

with self.assertRaises(TypingError) as raises:
cfunc(3.0)
self.assertIn('The argument "n" must be an integer',
str(raises.exception))

def test_diag_indices2_exception(self):
pyfunc = diag_indices2
cfunc = njit(pyfunc)

self.disable_leak_check()

with self.assertRaises(TypingError) as raises:
cfunc("abc", 2)
self.assertIn('The argument "n" must be an integer',
str(raises.exception))

with self.assertRaises(TypingError) as raises:
cfunc(4, "abc")
self.assertIn('The argument "ndim" must be an integer',
str(raises.exception))

with self.assertRaises(TypingError) as raises:
cfunc(4, 3.0)
self.assertIn('The argument "ndim" must be an integer',
str(raises.exception))

def test_diag_indices_from_basic(self):
pyfunc = diag_indices_from
cfunc = njit(pyfunc)

def inputs():
# Taken from https://github.com/numpy/numpy/blob/...
# 09ee7daa9113ecb503552913d76185d0ad3f0f31/numpy/lib/tests/...
# test_index_tricks.py#L513-L517
KrisMinchev marked this conversation as resolved.
Show resolved Hide resolved
yield np.arange(16).reshape((4, 4))

for arr in inputs():
self.assertPreciseEqual(np.asarray(pyfunc(arr)), cfunc(arr))

def test_diag_indices_from_exception(self):
pyfunc = diag_indices_from
cfunc = njit(pyfunc)

self.disable_leak_check()

with self.assertRaises(TypingError) as raises:
cfunc("abc")
self.assertIn('The argument "arr" must be an array',
str(raises.exception))

with self.assertRaises(TypingError) as raises:
cfunc("abc")
self.assertIn('The argument "arr" must be an array',
str(raises.exception))

# Based on tests from https://github.com/numpy/numpy/blob/...
# 09ee7daa9113ecb503552913d76185d0ad3f0f31/numpy/lib/tests/...
# test_index_tricks.py#L519-L527
KrisMinchev marked this conversation as resolved.
Show resolved Hide resolved
with self.assertRaises(ValueError) as raises:
cfunc(np.ones(7))
self.assertIn('Input array must be at least 2-d',
str(raises.exception))

with self.assertRaises(ValueError) as raises:
cfunc(np.zeros((3, 3, 2, 3)))
self.assertIn('All dimensions of input must be of equal length',
str(raises.exception))


class TestNPMachineParameters(TestCase):
# tests np.finfo, np.iinfo, np.MachAr
Expand Down