From 9d61d5e1a10c1e0072bcf8d9a53bce440a258fe6 Mon Sep 17 00:00:00 2001 From: Sebastian Berg Date: Mon, 1 Aug 2022 10:10:33 +0530 Subject: [PATCH] MAINT: simplify axis check in array_api.vecdot --- numpy/array_api/linalg.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/numpy/array_api/linalg.py b/numpy/array_api/linalg.py index d214046effd3..e0bf3bd390fa 100644 --- a/numpy/array_api/linalg.py +++ b/numpy/array_api/linalg.py @@ -379,16 +379,14 @@ def trace(x: Array, /, *, offset: int = 0) -> Array: def vecdot(x1: Array, x2: Array, /, *, axis: int = -1) -> Array: if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: raise TypeError('Only numeric dtypes are allowed in vecdot') - ndim = max(x1.ndim, x2.ndim) - x1_shape = (1,)*(ndim - x1.ndim) + tuple(x1.shape) - x2_shape = (1,)*(ndim - x2.ndim) + tuple(x2.shape) - if x1_shape[axis] != x2_shape[axis]: - raise ValueError("x1 and x2 must have the same size along the given axis") x1_, x2_ = np.broadcast_arrays(x1._array, x2._array) x1_ = np.moveaxis(x1_, axis, -1) x2_ = np.moveaxis(x2_, axis, -1) + if x1_shape[-1] != x2_shape[-1]: + raise ValueError("x1 and x2 must have the same size along the given axis") + res = x1_[..., None, :] @ x2_[..., None] return Array._new(res[..., 0, 0])