Skip to content

Commit

Permalink
Use moveaxis + matmul instead of einsum in vecdot
Browse files Browse the repository at this point in the history
  • Loading branch information
asmeurer committed Jul 7, 2022
1 parent ddee5d9 commit 285084a
Showing 1 changed file with 4 additions and 6 deletions.
10 changes: 4 additions & 6 deletions numpy/array_api/linalg.py
Expand Up @@ -384,15 +384,13 @@ def vecdot(x1: Array, x2: Array, /, *, axis: int = -1) -> Array:
x2_shape = (1,)*(ndim - x2.ndim) + tuple(x2.shape)
if x1_shape[axis] != x2_shape[axis]:
raise ValueError("x1 and x2 must have the same size along the given axis")
if axis < 0:
axis += ndim

x1_, x2_ = np.broadcast_arrays(x1._array, x2._array)
in_indices = list(range(ndim))
out_indices = list(range(ndim))
out_indices.pop(axis)
x1_ = np.moveaxis(x1_, axis, -1)
x2_ = np.moveaxis(x2_, axis, -1)

return Array._new(np.einsum(x1_, in_indices, x2_, in_indices, out_indices))
res = x1_[..., None, :] @ x2_[..., None]
return Array._new(res[..., 0, 0])


# Note: the name here is different from norm(). The array API norm is split
Expand Down

0 comments on commit 285084a

Please sign in to comment.