From 55b8fb0fd7c0f08d528afcf05e5e2307e61491b4 Mon Sep 17 00:00:00 2001 From: Athan Date: Mon, 5 Sep 2022 10:19:38 -0700 Subject: [PATCH] Clarify broadcasting behavior in `vecdot` (#473) This resolves https://github.com/data-apis/array-api/issues/471. The existing spec provides conflicting guidance saying both that the axes over which to compute the dot product must be equal, while also saying that input arrays must be broadcast compatible without qualification, thus implying that the contracted axis could also broadcast. This commit explicitly defines broadcast behavior for only the contracted axes, thus bringing vecdot inline with broadcasting behavior in tensordot. --- .../array_api/linear_algebra_functions.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/spec/API_specification/array_api/linear_algebra_functions.py b/spec/API_specification/array_api/linear_algebra_functions.py index 275d91856..1284c9c4e 100644 --- a/spec/API_specification/array_api/linear_algebra_functions.py +++ b/spec/API_specification/array_api/linear_algebra_functions.py @@ -92,14 +92,18 @@ def vecdot(x1: array, x2: array, /, *, axis: int = -1) -> array: x1: array first input array. Should have a real-valued data type. x2: array - second input array. Must be compatible with ``x1`` (see :ref:`broadcasting`). Should have a real-valued data type. + second input array. Should have a real-valued data type. Corresponding contracted axes of ``x1`` and ``x2`` must be equal. + + .. note:: + Contracted axes (dimensions) must not be broadcasted. + axis:int axis over which to compute the dot product. Must be an integer on the interval ``[-N, N)``, where ``N`` is the rank (number of dimensions) of the shape determined according to :ref:`broadcasting`. If specified as a negative integer, the function must determine the axis along which to compute the dot product by counting backward from the last dimension (where ``-1`` refers to the last dimension). By default, the function must compute the dot product over the last axis. Default: ``-1``. Returns ------- out: array - if ``x1`` and ``x2`` are both one-dimensional arrays, a zero-dimensional containing the dot product; otherwise, a non-zero-dimensional array containing the dot products and having rank ``N-1``, where ``N`` is the rank (number of dimensions) of the shape determined according to :ref:`broadcasting`. The returned array must have a data type determined by :ref:`type-promotion`. + if ``x1`` and ``x2`` are both one-dimensional arrays, a zero-dimensional containing the dot product; otherwise, a non-zero-dimensional array containing the dot products and having rank ``N-1``, where ``N`` is the rank (number of dimensions) of the shape determined according to :ref:`broadcasting` along the non-contracted axes. The returned array must have a data type determined by :ref:`type-promotion`. **Raises**