-
Notifications
You must be signed in to change notification settings - Fork 1.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add axis support to np.take #9297
base: main
Are you sure you want to change the base?
Add axis support to np.take #9297
Conversation
numba/np/arrayobj.py
Outdated
if kind == 'getitem': | ||
fn = ''' | ||
@register_jitable | ||
def _getitem(a, idx, axis): | ||
if axis == 0: | ||
return a[idx, ...] | ||
''' |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This could be rewritten as:
def _get_take_getitem(ndim):
@register_jitable
def _take_getitem(a, idx, axis):
if axis == 0:
return a[idx, ...]
elif ndim >= 2 and axis == 1:
return a[:, idx, ...]
elif ndim >= 3 and axis == 2:
return a[:, :, idx, ...]
elif ndim >= 4 and axis == 3:
return a[:, :, :, idx, ...]
elif ndim >= 5 and axis == 4:
return a[:, :, :, :, idx, ...]
elif ndim >= 6 and axis == 2:
return a[:,:,:,:,:, idx, ...]
elif ndim >= 7 and axis == 2:
return a[:,:,:,:,:,:, idx, ...]
elif ndim >= 8 and axis == 2:
return a[:,:,:,:,:,:,:, idx, ...]
elif ndim >= 9 and axis == 2:
return a[:, :, :, :, :, :, :, idx, ...]
elif ndim >= 11 and axis == 10:
...
elif ndim >= 12 and axis == 10:
elif ndim >= 13 and axis == 10:
elif ndim >= 14 and axis == 10:
elif ndim >= 15 and axis == 10:
elif ndim >= 16 and axis == 10:
elif ndim >= 17 and axis == 10:
elif ndim >= 18 and axis == 10:
elif ndim >= 19 and axis == 10:
elif ndim >= 20 and axis == 20:
elif ndim >= 21 and axis == 20:
elif ndim >= 22 and axis == 20:
elif ndim >= 23 and axis == 20:
elif ndim >= 24 and axis == 20:
elif ndim >= 25 and axis == 20:
elif ndim >= 26 and axis == 20:
elif ndim >= 27 and axis == 20:
elif ndim >= 28 and axis == 20:
elif ndim >= 29 and axis == 20:
elif ndim >= 30 and axis == 30:
elif ndim >= 31 and axis == 31:
elif ndim >= 32 and axis == 32:
return _take_getitem
Dead branch prune would take care to remove branches that are dead.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The string itself should not contain a register_jitable
, the output function of this string should be jitted instead.
3bd19d6
to
72d2b83
Compare
numba/np/arrayobj.py
Outdated
if kind == 'getitem': | ||
fn = ''' | ||
@register_jitable | ||
def _getitem(a, idx, axis): | ||
if axis == 0: | ||
return a[idx, ...] | ||
''' |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The string itself should not contain a register_jitable
, the output function of this string should be jitted instead.
numba/np/arrayobj.py
Outdated
tup = tuple(t) | ||
j = 0 | ||
for s in r.shape: | ||
if s != 1: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The squeeze implementation needs to exclude the axis based on the axis provided not the value 1
:
For instance:
import time
from numba import njit
import numpy as np
@njit
def foo(a, indices, axis):
return np.take(a, indices, axis)
def measure_time(fn, *args):
start = time.time()
res = fn(*args)
end = time.time()
print(f'{fn.__name__}: {end - start}s')
print(res.shape)
print("ndim:", res.ndim)
shape = (10, 1, 11, 1, 12, 1, 13)
a = np.ones(shape)
indices = 0
axis=1
measure_time(foo.py_func, a, indices, axis)
# foo: 0.00011563301086425781s
# (10, 11, 1, 12, 1, 13)
# ndim: 6
measure_time(foo, a, indices, axis) # fails
shape = (10, 1, 11, 1, 12, 1, 13)
a = np.ones(shape)
indices = 0
axis = 2
measure_time(foo.py_func, a, indices, axis)
# foo: 6.4373016357421875e-06s
# (10, 1, 1, 12, 1, 13)
# ndim: 6
measure_time(foo, a, indices, axis) # fails
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good catch. I'll fix this
out = np.empty(shape, dtype=a.dtype) | ||
for i in range(len(indices)): | ||
y = _getitem(a, indices[i], axis) | ||
_setitem(out, i, axis, y) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is potentially very wasteful. If we're anyways generating the getitem
and setitem
specifically for this function we may want to combine them into a single function that copies the elements directly in the out
resultant. Is that something you've tried out ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've split getitem from setitem because the latter is used in #9313
CI failure is not related to this PR. |
I've given the CI a prod. |
Non-scientific script to measure time: