With reference to issue #3663 #5827
It is showing invalid syntax at axis=0, and I am unable to fix the issue. Can someone help me with it?
Thanks for the PR @Saanidhyavats. Could you add a test to dask/array/tests/test_slicing.py to demonstrate that this function gives the same result as NumPy's take_along_axis function? I'd recommend looking at existing tests in the same module to get a sense of how to go about this (e.g. test_empty_list).
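A minimal sketch of the kind of test being requested, assuming only NumPy is available. `take_along_axis_candidate` is a hypothetical stand-in for the function under review, implemented here with plain NumPy indexing so the snippet runs on its own; in the real test it would be the Dask function, with its result computed before the comparison.

```python
import numpy as np


def take_along_axis_candidate(arr, indices, axis):
    # Hypothetical stand-in for the function under review.
    # Builds one index array per dimension: `indices` along `axis`,
    # and a broadcastable arange along every other dimension.
    arr = np.asarray(arr)
    indices = np.asarray(indices)
    axis = axis % arr.ndim
    fancy_index = []
    for dim in range(arr.ndim):
        if dim == axis:
            fancy_index.append(indices)
        else:
            shape = [1] * arr.ndim
            shape[dim] = indices.shape[dim]
            fancy_index.append(np.arange(indices.shape[dim]).reshape(shape))
    return arr[tuple(fancy_index)]


def test_take_along_axis_matches_numpy():
    a = np.array([[10, 30, 20], [60, 40, 50]])
    for axis in (0, 1, -1):
        # argsort always yields in-bounds indices for the chosen axis
        idx = np.argsort(a, axis=axis)
        expected = np.take_along_axis(a, idx, axis=axis)
        result = take_along_axis_candidate(a, idx, axis)
        np.testing.assert_array_equal(result, expected)
```

Using `np.argsort` to generate the indices keeps them valid for whichever axis is being tested, which avoids hand-written index arrays that go out of bounds.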
dask/array/slicing.py
Outdated
@@ -1240,3 +1240,38 @@ def cached_cumsum(seq, initial_zero=False):
    if not initial_zero:
        result = result[1:]
    return result


def take_along_axis(arr,indices,axis=None):
We'd like to match NumPy's signature for this function, which doesn't have any default values:
- def take_along_axis(arr,indices,axis=None):
+ def take_along_axis(arr,indices,axis):
After removing the None default, I think the "if axis is None:" branch is also not required, but that line is present in NumPy. Shall I remove both?
I think that instead of "if axis is None" we can use "if axis == 0" and remove the "axis=None" default from line 1245. Is that correct?
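For context on this question: in NumPy, axis=None and axis=0 behave differently, so swapping one check for the other would change what the function does. A small illustration using only NumPy's own np.take_along_axis (the arrays are made up for this example):

```python
import numpy as np

a = np.array([[10, 30, 20], [60, 40, 50]])

# axis=None: `a` is treated as flattened, so the indices are 1-D
# positions into [10, 30, 20, 60, 40, 50].
flat_idx = np.array([0, 5, 2])
flat_result = np.take_along_axis(a, flat_idx, axis=None)

# axis=0: the indices have the same number of dimensions as `a`, and
# each entry selects a row within its own column.
row_idx = np.array([[0, 1, 1], [1, 0, 0]])
axis0_result = np.take_along_axis(a, row_idx, axis=0)

print(flat_result)   # [10 50 20]
print(axis0_result)  # [[10 40 50], [60 30 20]] as a 2x3 array
```

Note that the axis argument is passed explicitly in both calls, which matches the reviewer's point about not giving it a default.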
dask/array/slicing.py
Outdated
        raise IndexError('indices must be an integer array')
    if axis is None:
        arr = arr.flatten
        arr_shape = (np.prod((arr.shape),)
This line has a mismatch in the number of open, (, and close, ), parentheses. That's what is leading to the SyntaxError in the CI builds.
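A quick way to confirm this kind of problem locally is to ask Python to parse the line without running it. The sketch below uses the standard-library ast module; `parses` is a hypothetical helper, and the balanced version of the line is only one plausible reading of the intent, not necessarily the fix that belongs in the PR.

```python
import ast

# The line as it appears in the diff: three "(" but only two ")".
broken = "arr_shape = (np.prod((arr.shape),)"

# One plausible balanced form (an assumption about the intent).
balanced = "arr_shape = (np.prod(arr.shape),)"


def parses(source):
    """Return True if `source` is syntactically valid Python."""
    try:
        ast.parse(source)
    except SyntaxError:
        return False
    return True


print(parses(broken))    # False
print(parses(balanced))  # True
```

Parsing only checks syntax, so names such as `np` and `arr` do not need to be defined for this check to work.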
Thank you, I will look into it.
dask/array/slicing.py
Outdated
"""
a =np.array([[10, 30, 20], [60, 40, 50]])
ai=np.array([[0, 2, 1],[1, 2, 0]])
y=take_along_axis(a,ai,axis=0)
print(y.compute())
"""
Is this supposed to be here? I suspect it can be removed
@jrbourbeau and @jakirkham, could you confirm whether the suggested changes are correct?
I'll defer to @jrbourbeau 🙂 (Thanks btw for working on this @Saanidhyavats! 😄)
I was referring
@jakirkham I would actually usually defer to you for dask array problems like this :)
@Saanidhyavats I encourage you to read through our development guidelines here: https://docs.dask.org/en/latest/develop.html. They should have information about tests, style, and so on.
Thank you @mrocklin, I will look into it.
IIUC Saanidhya was asking about James' review comments. Hence why I deferred to him 😉
I have updated the test file. What other changes are needed before this request can be merged?
Two of the tests are failing. @jrbourbeau and @mrocklin, can you tell me how to run these tests locally?
@Saanidhyavats see https://docs.dask.org/en/latest/develop.html#run-tests. You also have some linting errors. I'd recommend installing pre-commit: https://docs.dask.org/en/latest/develop.html#code-formatting. It would have caught the first CI failure I see.
Ok, I will look into it.
On running pytest I am getting:
Co-Authored-By: Tom Augspurger <TomAugspurger@users.noreply.github.com>
I am thinking to import
I think the function is now in working condition.
I apologize for the lack of feedback here @Saanidhyavats. Looking at the actual algorithm that you've written down, it looks like you're writing this in a way that makes sense for NumPy arrays, but not for parallel Dask arrays. When writing algorithms for Dask we can't do things like iterate through all of the elements, or assume that the result will fit into a concrete NumPy array. If you want to learn more about constructing Dask array algorithms, I encourage you to read through https://docs.dask.org/en/latest/array-design.html. In the meantime I am going to close this PR. My apologies that it took this long.
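To illustrate the point about not iterating over every element, here is a rough NumPy-only sketch of the block-at-a-time style that chunked array algorithms rely on. `take_along_axis_blockwise` and `block_size` are hypothetical names, this is not Dask's API or necessarily the approach Dask would take, and it assumes `indices` has the same shape as `arr` along the dimension being split.

```python
import numpy as np


def take_along_axis_blockwise(arr, indices, axis, block_size):
    """Apply np.take_along_axis block by block (illustrative only).

    Blocks are cut along a dimension other than `axis`, so every block
    holds complete lanes along `axis` and can be processed independently.
    """
    axis = axis % arr.ndim
    if arr.ndim < 2:
        raise ValueError("need at least two dimensions to split off-axis")
    # Pick a dimension other than `axis` to split along.
    split_dim = 0 if axis != 0 else 1
    pieces = []
    for start in range(0, arr.shape[split_dim], block_size):
        block = [slice(None)] * arr.ndim
        block[split_dim] = slice(start, start + block_size)
        block = tuple(block)
        # Each block only needs its own slice of the data and indices.
        pieces.append(np.take_along_axis(arr[block], indices[block], axis=axis))
    return np.concatenate(pieces, axis=split_dim)


a = np.arange(24).reshape(4, 6)[:, ::-1]
idx = np.argsort(a, axis=1)
blockwise = take_along_axis_blockwise(a, idx, axis=1, block_size=3)
whole = np.take_along_axis(a, idx, axis=1)
```

The loop here runs over blocks, not elements, and no step needs the whole array at once. Handling the case where the indexed axis itself is split across several blocks is the harder part and is not covered by this sketch.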
black dask / flake8 dask
Added a new function which is similar to NumPy's take_along_axis (related to issue Add NumPy's new take_along_axis #3663)