Describe the issue:

Using `numpy.array_api.vecdot`, I get results with the wrong shape.
According to the specification, `vecdot` computes dot products along the specified `axis` and removes that axis, so N-dimensional inputs produce an (N-1)-dimensional output: https://data-apis.org/array-api/latest/API_specification/linear_algebra_functions.html#vecdot-x1-x2-axis-1
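To make the specified semantics concrete, here is a minimal sketch (not the `numpy.array_api` code) that assumes real-valued inputs whose shapes broadcast against each other. The vector dot product along `axis` is an element-wise product followed by a sum over that axis, which removes exactly one dimension. `vecdot_reference` is a hypothetical name used only for illustration.

```python
import numpy as np


def vecdot_reference(x1, x2, axis=-1):
    """Hypothetical reference vecdot for real-valued inputs.

    Multiplies element-wise (broadcasting the non-contracted axes) and
    sums over ``axis`` of the broadcast shape, so the result has one
    dimension fewer than the broadcast inputs.
    """
    return np.sum(x1 * x2, axis=axis)


a = np.random.rand(9, 10, 11)
b = np.random.rand(9, 10, 11)
c = vecdot_reference(a, b, axis=1)
print(c.shape)  # (9, 11)
print(np.allclose(c, np.einsum('ijk,ijk->ik', a, b)))  # True
```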
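However, because the implementation uses `tensordot`, the non-contracted axes of `x1` and `x2` are not paired up element-wise. Instead the output contains every combination of them (an outer product over those axes), so for two N-dimensional inputs it has 2(N-1) dimensions instead of N-1.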
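The difference can be seen directly on small arrays. With a single pair of contracted axes, `np.tensordot` combines every slice of the first input with every slice of the second, so the free axes of both inputs appear in the output. The values `vecdot` should return sit on the "diagonal" of that result, where the indices of the two inputs coincide. This sketch extracts them with repeated `einsum` subscripts purely to illustrate the relationship; it is wasteful as an implementation.

```python
import numpy as np

a = np.random.rand(2, 3, 4)
b = np.random.rand(2, 3, 4)

# Contract axis 1 of a with axis 1 of b. The free axes (2, 4) of a and
# (2, 4) of b are all kept, giving an outer product over them.
t = np.tensordot(a, b, axes=((1,), (1,)))
print(t.shape)  # (2, 4, 2, 4)

# t[i, k, j, l] = sum over m of a[i, m, k] * b[j, m, l].
# vecdot only wants the entries where i == j and k == l.
diag = np.einsum('ikik->ik', t)
expected = np.einsum('ijk,ijk->ik', a, b)
print(diag.shape, np.allclose(diag, expected))  # (2, 4) True
```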
Reproduce the code example:

```python
import numpy as np

# recreated implementation (see `numpy/array_api/linalg.py`)
def vecdot(x1, x2, axis=-1):
    return np.tensordot(x1, x2, axes=((axis,), (axis,)))

a = np.random.rand(9, 10, 11)
b = np.random.rand(9, 10, 11)
c = vecdot(a, b, axis=1)
d = np.einsum('ijk,ijk->ik', a, b)
assert d.shape == (9, 11)
# This will fail, the output shape will be (9, 11, 9, 11)
assert c.shape == (9, 11), c.shape
```
Error message:

```
Traceback (most recent call last):
  File "/home/nicpa/test.py", line 13, in <module>
    assert c.shape == (9, 11), c.shape
AssertionError: (9, 11, 9, 11)
```
NumPy/Python version information:

NumPy 1.21.5, Python 3.9.9
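For reference, one way to obtain the specified behavior while still using a BLAS-backed routine is to broadcast the inputs, move the contracted axis to the end, and run a batched `matmul` of (1, n) by (n, 1) matrices. This is a sketch assuming real-valued array inputs; `vecdot_matmul` is a hypothetical name, and this is not necessarily the code that landed in #21928. The size check reflects the specification's requirement that the contracted axis have the same size in both inputs rather than being broadcast.

```python
import numpy as np


def vecdot_matmul(x1, x2, axis=-1):
    """Hypothetical vecdot via batched matmul (real-valued inputs assumed)."""
    x1 = np.asarray(x1)
    x2 = np.asarray(x2)
    # Left-pad shapes with 1s so `axis` indexes the broadcast shape, then
    # require equal sizes along the contracted axis (no broadcasting there).
    ndim = max(x1.ndim, x2.ndim)
    s1 = (1,) * (ndim - x1.ndim) + x1.shape
    s2 = (1,) * (ndim - x2.ndim) + x2.shape
    if s1[axis] != s2[axis]:
        raise ValueError("x1 and x2 must have the same size along axis")
    # Broadcast the non-contracted axes so both inputs share one shape.
    x1, x2 = np.broadcast_arrays(x1, x2)
    # Move the contracted axis to the end: shape (..., n).
    x1 = np.moveaxis(x1, axis, -1)
    x2 = np.moveaxis(x2, axis, -1)
    # (..., 1, n) @ (..., n, 1) -> (..., 1, 1): one dot product per batch entry.
    res = x1[..., None, :] @ x2[..., :, None]
    # Drop the two singleton matrix dimensions.
    return res[..., 0, 0]


a = np.random.rand(9, 10, 11)
b = np.random.rand(9, 10, 11)
c = vecdot_matmul(a, b, axis=1)
print(c.shape)  # (9, 11)
```

Compared with `tensordot`, the batched `matmul` treats the leading axes as independent batch dimensions, so each output element combines only matching slices of `x1` and `x2`.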
@charris I think this can be closed. It was fixed in #21928.
Thanks for the heads up, @AnirudhDagar.