New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add support for np.diag_indices() and np.diag_indices_from() #9115
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi @KrisMinchev, great work on adding these functions to Numba.
Note that the NumPy versions return tuples whereas the implementations presented here return NumPy arrays.
Numba tries to keep compatibility with NumPy as close as possible. I've left a suggestion on how you can return tuples instead of NumPy arrays. The idea is to create a tuple at compile-time with the same type as the tuple you will return at runtime, and use tuple_setitem
to set each tuple item.
numba/np/arrayobj.py
Outdated
|
||
def impl(n, ndim=2): | ||
res = np.arange(n * ndim) | ||
for i in range(n * ndim): | ||
res[i] = res[i] % n | ||
return res.reshape((ndim, n)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Numba has a function called tuple_setitem
from cpython/unsafe/tuple.py
where you can use to set items in a tuple at runtime.
def impl(n, ndim=2): | |
res = np.arange(n * ndim) | |
for i in range(n * ndim): | |
res[i] = res[i] % n | |
return res.reshape((ndim, n)) | |
# tup_init is used only to get the correct tuple type | |
tup_init = (np.arange(1),) * 2 | |
from numba.cpython.unsafe.tuple import tuple_setitem | |
def impl(n, ndim=2): | |
res = np.arange(n * ndim) | |
for i in range(n * ndim): | |
res[i] = res[i] % n | |
x = res.reshape((ndim, n)) | |
tup = tup_init | |
for i in range(n): | |
tup = tuple_setitem(tup, i, x[i]) | |
return tup |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi @guilhermeleobas, thank you for your suggestion. The only issue with this implementation is that the array elements that are being returned are immutable, which is different to what the NumPy function np.diagflat
returns:
import numpy as np
from numba import njit
@njit
def foodiagindices(n, ndim=2):
res = np.diag_indices(n, ndim)
return res
a=np.diag_indices(5)
print(a)
a[0][0] = 5
print(a)
c=foodiagindices(5)
print(c)
c[0][0] = 5
Returns
(array([0, 1, 2, 3, 4]), array([0, 1, 2, 3, 4]))
(array([5, 1, 2, 3, 4]), array([5, 1, 2, 3, 4]))
(array([0, 1, 2, 3, 4]), array([0, 1, 2, 3, 4]))
Traceback (most recent call last):
File "/home/kristian/Desktop/numba/a.py", line 13, in
c[0][0] = 5
ValueError: assignment destination is read-only
Another issue is that your proposed implementation implicitly assumes that ndim
is equal to 2 at runtime by setting tup_init = (np.arange(1),) * 2
. Hence, the following code:
import numpy as np
from numba import njit
@njit
def foodiagindices(n, ndim=2):
res = np.diag_indices(n, ndim)
return res
c = foodiagindices(5, 3)
throws a segmentation fault.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Another issue is that your proposed implementation implicitly assumes that ndim is equal to 2 at runtime by setting tup_init = (np.arange(1),) * 2. Hence, the following code
Oh, I know. My code was just an example on how to use tuple_setitem
The current implementation does not work as intended. It is hard to keep the behaviour of the NumPy function |
In today's public meeting (notes), it was decided to support Kristian, do you have any other question regarding supporting this function? |
This PR improves on #4074 by implementing the functions
np.diag_indices()
andnp.diag_indices_from()
. Note that the NumPy versions return tuples whereas the implementations presented here return NumPy arrays.