From 6cc41836e1d2d12eea8591fe5252ca484f54d6af Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Wed, 7 Sep 2022 11:13:22 -0500 Subject: [PATCH] BUG: Fix the implementation of numpy.array_api.vecdot (#21928) * Fix the implementation of numpy.array_api.vecdot See https://data-apis.org/array-api/latest/API_specification/generated/signatures.linear_algebra_functions.vecdot.html * Use moveaxis + matmul instead of einsum in vecdot --- numpy/array_api/linalg.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/numpy/array_api/linalg.py b/numpy/array_api/linalg.py index f422e1c2767f..335c2b13ba73 100644 --- a/numpy/array_api/linalg.py +++ b/numpy/array_api/linalg.py @@ -376,7 +376,18 @@ def trace(x: Array, /, *, offset: int = 0) -> Array: def vecdot(x1: Array, x2: Array, /, *, axis: int = -1) -> Array: if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: raise TypeError('Only numeric dtypes are allowed in vecdot') - return tensordot(x1, x2, axes=((axis,), (axis,))) + ndim = max(x1.ndim, x2.ndim) + x1_shape = (1,)*(ndim - x1.ndim) + tuple(x1.shape) + x2_shape = (1,)*(ndim - x2.ndim) + tuple(x2.shape) + if x1_shape[axis] != x2_shape[axis]: + raise ValueError("x1 and x2 must have the same size along the given axis") + + x1_, x2_ = np.broadcast_arrays(x1._array, x2._array) + x1_ = np.moveaxis(x1_, axis, -1) + x2_ = np.moveaxis(x2_, axis, -1) + + res = x1_[..., None, :] @ x2_[..., None] + return Array._new(res[..., 0, 0]) # Note: the name here is different from norm(). The array API norm is split