Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for np.diag_indices() and np.diag_indices_from() #9115

Closed
wants to merge 5 commits into from

Conversation

KrisMinchev
Copy link
Contributor

This PR improves on #4074 by implementing the functions np.diag_indices() and np.diag_indices_from(). Note that the NumPy versions return tuples whereas the implementations presented here return NumPy arrays.

@KrisMinchev KrisMinchev marked this pull request as ready for review August 2, 2023 10:23
Copy link
Collaborator

@guilhermeleobas guilhermeleobas left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @KrisMinchev, great work on adding these functions to Numba.

Note that the NumPy versions return tuples whereas the implementations presented here return NumPy arrays.

Numba tries to keep compatibility with NumPy as close as possible. I've left a suggestion on how you can return tuples instead of NumPy arrays. The idea is to create a tuple at compile-time with the same type as the tuple you will return at runtime, and use tuple_setitem to set each tuple item.

Comment on lines 4535 to 4540

def impl(n, ndim=2):
res = np.arange(n * ndim)
for i in range(n * ndim):
res[i] = res[i] % n
return res.reshape((ndim, n))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Numba has a function called tuple_setitem from cpython/unsafe/tuple.py where you can use to set items in a tuple at runtime.

Suggested change
def impl(n, ndim=2):
res = np.arange(n * ndim)
for i in range(n * ndim):
res[i] = res[i] % n
return res.reshape((ndim, n))
# tup_init is used only to get the correct tuple type
tup_init = (np.arange(1),) * 2
from numba.cpython.unsafe.tuple import tuple_setitem
def impl(n, ndim=2):
res = np.arange(n * ndim)
for i in range(n * ndim):
res[i] = res[i] % n
x = res.reshape((ndim, n))
tup = tup_init
for i in range(n):
tup = tuple_setitem(tup, i, x[i])
return tup

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @guilhermeleobas, thank you for your suggestion. The only issue with this implementation is that the array elements that are being returned are immutable, which is different to what the NumPy function np.diagflat returns:

import numpy as np
from numba import njit
@njit
def foodiagindices(n, ndim=2):
    res = np.diag_indices(n, ndim)
    return res
a=np.diag_indices(5)
print(a)
a[0][0] = 5
print(a)
c=foodiagindices(5)
print(c)
c[0][0] = 5

Returns

(array([0, 1, 2, 3, 4]), array([0, 1, 2, 3, 4]))
(array([5, 1, 2, 3, 4]), array([5, 1, 2, 3, 4]))
(array([0, 1, 2, 3, 4]), array([0, 1, 2, 3, 4]))
Traceback (most recent call last):
File "/home/kristian/Desktop/numba/a.py", line 13, in
c[0][0] = 5
ValueError: assignment destination is read-only

Another issue is that your proposed implementation implicitly assumes that ndim is equal to 2 at runtime by setting tup_init = (np.arange(1),) * 2. Hence, the following code:

import numpy as np
from numba import njit
@njit
def foodiagindices(n, ndim=2):
    res = np.diag_indices(n, ndim)
    return res
c = foodiagindices(5, 3)

throws a segmentation fault.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Another issue is that your proposed implementation implicitly assumes that ndim is equal to 2 at runtime by setting tup_init = (np.arange(1),) * 2. Hence, the following code

Oh, I know. My code was just an example on how to use tuple_setitem

numba/tests/test_np_functions.py Outdated Show resolved Hide resolved
numba/tests/test_np_functions.py Outdated Show resolved Hide resolved
@guilhermeleobas guilhermeleobas added 4 - Waiting on author Waiting for author to respond to review numpy Effort - medium Medium size effort needed and removed 2 - In Progress labels Aug 2, 2023
@guilhermeleobas guilhermeleobas added the discussion An issue requiring discussion label Aug 4, 2023
@KrisMinchev
Copy link
Contributor Author

The current implementation does not work as intended. It is hard to keep the behaviour of the NumPy function np.diag_indices(n, ndim) since the size of the tuple being is equal to ndim and hence cannot be determined at compile time. Since np.diag_indices_from() calls np.diag_indices(), the same issue is present there as well.

@gmarkall gmarkall added this to the 0.59.0-rc1 milestone Aug 8, 2023
@guilhermeleobas
Copy link
Collaborator

In today's public meeting (notes), it was decided to support diag_indices with ndim as a literal argument.

Kristian, do you have any other question regarding supporting this function?

@guilhermeleobas guilhermeleobas removed the discussion An issue requiring discussion label Aug 8, 2023
@guilhermeleobas guilhermeleobas added 2 - In Progress and removed 4 - Waiting on author Waiting for author to respond to review labels Aug 28, 2023
@guilhermeleobas guilhermeleobas added abandoned PR is abandoned (no reason required) and removed 2 - In Progress labels Sep 15, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
abandoned PR is abandoned (no reason required) Effort - medium Medium size effort needed numpy
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants