diff --git a/array_api_tests/array_helpers.py b/array_api_tests/array_helpers.py index ef4f719a..c11d5732 100644 --- a/array_api_tests/array_helpers.py +++ b/array_api_tests/array_helpers.py @@ -7,9 +7,9 @@ # These are exported here so that they can be included in the special cases # tests from this file. from ._array_module import logical_not, subtract, floor, ceil, where +from . import _array_module as xp from . import dtype_helpers as dh - __all__ = ['all', 'any', 'logical_and', 'logical_or', 'logical_not', 'less', 'less_equal', 'greater', 'subtract', 'negative', 'floor', 'ceil', 'where', 'isfinite', 'equal', 'not_equal', 'zero', 'one', 'NaN', @@ -164,7 +164,7 @@ def notequal(x, y): return not_equal(x, y) -def assert_exactly_equal(x, y): +def assert_exactly_equal(x, y, msg_extra=None): """ Test that the arrays x and y are exactly equal. @@ -172,11 +172,13 @@ def assert_exactly_equal(x, y): equal. """ - assert x.shape == y.shape, f"The input arrays do not have the same shapes ({x.shape} != {y.shape})" + extra = '' if not msg_extra else f' ({msg_extra})' + + assert x.shape == y.shape, f"The input arrays do not have the same shapes ({x.shape} != {y.shape}){extra}" - assert x.dtype == y.dtype, f"The input arrays do not have the same dtype ({x.dtype} != {y.dtype})" + assert x.dtype == y.dtype, f"The input arrays do not have the same dtype ({x.dtype} != {y.dtype}){extra}" - assert all(exactly_equal(x, y)), "The input arrays have different values" + assert all(exactly_equal(x, y)), f"The input arrays have different values ({x!r} != {y!r}){extra}" def assert_finite(x): """ @@ -306,3 +308,13 @@ def same_sign(x, y): def assert_same_sign(x, y): assert all(same_sign(x, y)), "The input arrays do not have the same sign" +def _matrix_transpose(x): + if not isinstance(xp.matrix_transpose, xp._UndefinedStub): + return xp.matrix_transpose(x) + if hasattr(x, 'mT'): + return x.mT + if not isinstance(xp.permute_dims, xp._UndefinedStub): + perm = list(range(x.ndim)) + perm[-1], 
perm[-2] = perm[-2], perm[-1] + return xp.permute_dims(x, axes=tuple(perm)) + raise NotImplementedError("No way to compute matrix transpose") diff --git a/array_api_tests/dtype_helpers.py b/array_api_tests/dtype_helpers.py index 3052d54f..59edfe86 100644 --- a/array_api_tests/dtype_helpers.py +++ b/array_api_tests/dtype_helpers.py @@ -231,6 +231,57 @@ class MinMax(NamedTuple): {"complex64": xp.float32, "complex128": xp.float64} ) +def as_real_dtype(dtype): + """ + Return the corresponding real dtype for a given floating-point dtype. + """ + if dtype in real_float_dtypes: + return dtype + elif dtype_to_name[dtype] in complex_names: + return dtype_components[dtype] + else: + raise ValueError("as_real_dtype requires a floating-point dtype") + +def accumulation_result_dtype(x_dtype, dtype_kwarg): + """ + Result dtype logic for sum(), prod(), and trace() + + Note: may return None if a default uint cannot exist (e.g., for pytorch + which doesn't support uint32 or uint64). See https://github.com/data-apis/array-api-tests/issues/106 + + """ + if dtype_kwarg is None: + if is_int_dtype(x_dtype): + if x_dtype in uint_dtypes: + default_dtype = default_uint + else: + default_dtype = default_int + if default_dtype is None: + _dtype = None + else: + m, M = dtype_ranges[x_dtype] + d_m, d_M = dtype_ranges[default_dtype] + if m < d_m or M > d_M: + _dtype = x_dtype + else: + _dtype = default_dtype + elif is_float_dtype(x_dtype, include_complex=False): + if dtype_nbits[x_dtype] > dtype_nbits[default_float]: + _dtype = x_dtype + else: + _dtype = default_float + elif api_version > "2021.12": + # Complex dtype + if dtype_nbits[x_dtype] > dtype_nbits[default_complex]: + _dtype = x_dtype + else: + _dtype = default_complex + else: + raise RuntimeError("Unexpected dtype. 
This indicates a bug in the test suite.") + else: + _dtype = dtype_kwarg + + return _dtype if not hasattr(xp, "asarray"): default_int = xp.int32 diff --git a/array_api_tests/hypothesis_helpers.py b/array_api_tests/hypothesis_helpers.py index 3864d426..3033dac3 100644 --- a/array_api_tests/hypothesis_helpers.py +++ b/array_api_tests/hypothesis_helpers.py @@ -12,6 +12,7 @@ sampled_from, shared, builds) from . import _array_module as xp, api_version +from . import array_helpers as ah from . import dtype_helpers as dh from . import shape_helpers as sh from . import xps @@ -211,6 +212,7 @@ def tuples(elements, *, min_size=0, max_size=None, unique_by=None, unique=False) # Use this to avoid memory errors with NumPy. # See https://github.com/numpy/numpy/issues/15753 +# Note, the hypothesis default for max_dims is min_dims + 2 (i.e., 0 + 2) def shapes(**kw): kw.setdefault('min_dims', 0) kw.setdefault('min_side', 0) @@ -280,16 +282,19 @@ def mutually_broadcastable_shapes( # Note: This should become hermitian_matrices when complex dtypes are added @composite -def symmetric_matrices(draw, dtypes=xps.floating_dtypes(), finite=True): +def symmetric_matrices(draw, dtypes=xps.floating_dtypes(), finite=True, bound=10.): shape = draw(square_matrix_shapes) dtype = draw(dtypes) if not isinstance(finite, bool): finite = draw(finite) elements = {'allow_nan': False, 'allow_infinity': False} if finite else None a = draw(arrays(dtype=dtype, shape=shape, elements=elements)) - upper = xp.triu(a) - lower = xp.triu(a, k=1).mT - return upper + lower + at = ah._matrix_transpose(a) + H = (a + at)*0.5 + if finite: + assume(not xp.any(xp.isinf(H))) + assume(xp.all((H == 0.) 
| ((1/bound <= xp.abs(H)) & (xp.abs(H) <= bound)))) + return H @composite def positive_definite_matrices(draw, dtypes=xps.floating_dtypes()): @@ -297,8 +302,9 @@ def positive_definite_matrices(draw, dtypes=xps.floating_dtypes()): # TODO: Generate arbitrary positive definite matrices, for instance, by # using something like # https://github.com/scikit-learn/scikit-learn/blob/844b4be24/sklearn/datasets/_samples_generator.py#L1351. - n = draw(integers(0)) - shape = draw(shapes()) + (n, n) + base_shape = draw(shapes()) + n = draw(integers(0, 8)) # 8 is an arbitrary small but interesting-enough value + shape = base_shape + (n, n) assume(prod(i for i in shape if i) < MAX_ARRAY_SIZE) dtype = draw(dtypes) return broadcast_to(eye(n, dtype=dtype), shape) @@ -308,12 +314,18 @@ def invertible_matrices(draw, dtypes=xps.floating_dtypes(), stack_shapes=shapes( # For now, just generate stacks of diagonal matrices. n = draw(integers(0, SQRT_MAX_ARRAY_SIZE),) stack_shape = draw(stack_shapes) - d = draw(arrays(dtypes, shape=(*stack_shape, 1, n), - elements=dict(allow_nan=False, allow_infinity=False))) + dtype = draw(dtypes) + elements = one_of( + from_dtype(dtype, min_value=0.5, allow_nan=False, allow_infinity=False), + from_dtype(dtype, max_value=-0.5, allow_nan=False, allow_infinity=False), + ) + d = draw(arrays(dtype, shape=(*stack_shape, 1, n), elements=elements)) + # Functions that require invertible matrices may do anything when it is # singular, including raising an exception, so we make sure the diagonals # are sufficiently nonzero to avoid any numerical issues. 
- assume(xp.all(xp.abs(d) > 0.5)) + assert xp.all(xp.abs(d) >= 0.5) + diag_mask = xp.arange(n) == xp.reshape(xp.arange(n), (n, 1)) return xp.where(diag_mask, d, xp.zeros_like(d)) diff --git a/array_api_tests/meta/test_linalg.py b/array_api_tests/meta/test_linalg.py new file mode 100644 index 00000000..a4171e81 --- /dev/null +++ b/array_api_tests/meta/test_linalg.py @@ -0,0 +1,16 @@ +import pytest + +from hypothesis import given + +from ..hypothesis_helpers import symmetric_matrices +from .. import array_helpers as ah +from .. import _array_module as xp + +@pytest.mark.xp_extension('linalg') +@given(x=symmetric_matrices(finite=True)) +def test_symmetric_matrices(x): + upper = xp.triu(x) + lower = xp.tril(x) + lowerT = ah._matrix_transpose(lower) + + ah.assert_exactly_equal(upper, lowerT) diff --git a/array_api_tests/test_linalg.py b/array_api_tests/test_linalg.py index 18395a20..35ff1d42 100644 --- a/array_api_tests/test_linalg.py +++ b/array_api_tests/test_linalg.py @@ -12,23 +12,26 @@ required, but we don't yet have a clean way to disable only those tests (see https://github.com/data-apis/array-api-tests/issues/25). 
""" -# TODO: test with complex dtypes where appropiate - import pytest from hypothesis import assume, given -from hypothesis.strategies import (booleans, composite, none, tuples, integers, - shared, sampled_from, one_of, data, just) +from hypothesis.strategies import (booleans, composite, tuples, floats, + integers, shared, sampled_from, one_of, + data) from ndindex import iter_indices +import itertools + from .array_helpers import assert_exactly_equal, asarray -from .hypothesis_helpers import (arrays, xps, shapes, kwargs, matrix_shapes, - square_matrix_shapes, symmetric_matrices, +from .hypothesis_helpers import (arrays, all_floating_dtypes, xps, shapes, + kwargs, matrix_shapes, square_matrix_shapes, + symmetric_matrices, positive_definite_matrices, MAX_ARRAY_SIZE, invertible_matrices, two_mutual_arrays, mutually_promotable_dtypes, one_d_shapes, two_mutually_broadcastable_shapes, + mutually_broadcastable_shapes, SQRT_MAX_ARRAY_SIZE, finite_matrices, - rtol_shared_matrix_shapes, rtols) + rtol_shared_matrix_shapes, rtols, axes) from . import dtype_helpers as dh from . import pytest_helpers as ph from . import shape_helpers as sh @@ -39,12 +42,25 @@ pytestmark = pytest.mark.ci -# Standin strategy for not yet implemented tests -todo = none() +def assert_equal(x, y, msg_extra=None): + extra = '' if not msg_extra else f' ({msg_extra})' + if x.dtype in dh.all_float_dtypes: + # It's too difficult to do an approximately equal test here because + # different routines can give completely different answers, and even + # when it does work, the elementwise comparisons are too slow. So for + # floating-point dtypes only test the shape and dtypes. 
+ + # assert_allclose(x, y) + + assert x.shape == y.shape, f"The input arrays do not have the same shapes ({x.shape} != {y.shape}){extra}" + assert x.dtype == y.dtype, f"The input arrays do not have the same dtype ({x.dtype} != {y.dtype}){extra}" + else: + assert_exactly_equal(x, y, msg_extra=msg_extra) def _test_stacks(f, *args, res=None, dims=2, true_val=None, matrix_axes=(-2, -1), - assert_equal=assert_exactly_equal, **kw): + res_axes=None, + assert_equal=assert_equal, **kw): """ Test that f(*args, **kw) maps across stacks of matrices @@ -68,7 +84,10 @@ def _test_stacks(f, *args, res=None, dims=2, true_val=None, # Assume the result is stacked along the last 'dims' axes of matrix_axes. # This holds for all the functions tested in this file - res_axes = matrix_axes[::-1][:dims] + if res_axes is None: + if not isinstance(matrix_axes, tuple) and all(isinstance(x, int) for x in matrix_axes): + raise ValueError("res_axes must be specified if matrix_axes is not a tuple of integers") + res_axes = matrix_axes[::-1][:dims] for (x_idxes, (res_idx,)) in zip( iter_indices(*shapes, skip_axes=matrix_axes), @@ -79,9 +98,10 @@ def _test_stacks(f, *args, res=None, dims=2, true_val=None, res_stack = res[res_idx] x_stacks = [x[x_idx] for x, x_idx in zip(args, x_idxes)] decomp_res_stack = f(*x_stacks, **kw) - assert_equal(res_stack, decomp_res_stack) + msg_extra = f'{x_idxes = }, {res_idx = }' + assert_equal(res_stack, decomp_res_stack, msg_extra) if true_val: - assert_equal(decomp_res_stack, true_val(*x_stacks)) + assert_equal(decomp_res_stack, true_val(*x_stacks, **kw), msg_extra) def _test_namedtuple(res, fields, func_name): """ @@ -92,6 +112,7 @@ def _test_namedtuple(res, fields, func_name): # a tuple subclass with the right fields in the right order. 
assert isinstance(res, tuple), f"{func_name}() did not return a tuple" + assert type(res) != tuple, f"{func_name}() did not return a namedtuple" assert len(res) == len(fields), f"{func_name}() result tuple not the correct length (should have {len(fields)} elements)" for i, field in enumerate(fields): assert hasattr(res, field), f"{func_name}() result namedtuple doesn't have the '{field}' field" @@ -105,8 +126,9 @@ def _test_namedtuple(res, fields, func_name): def test_cholesky(x, kw): res = linalg.cholesky(x, **kw) - assert res.shape == x.shape, "cholesky() did not return the correct shape" - assert res.dtype == x.dtype, "cholesky() did not return the correct dtype" + ph.assert_dtype("cholesky", in_dtype=x.dtype, out_dtype=res.dtype) + ph.assert_result_shape("cholesky", in_shapes=[x.shape], + out_shape=res.shape, expected=x.shape) _test_stacks(linalg.cholesky, x, **kw, res=res) @@ -126,23 +148,29 @@ def cross_args(draw, dtype_objects=dh.real_dtypes): in the drawn axis. """ - shape = list(draw(shapes())) - size = len(shape) - assume(size > 0) + shape1, shape2 = draw(two_mutually_broadcastable_shapes) + min_ndim = min(len(shape1), len(shape2)) + assume(min_ndim > 0) - kw = draw(kwargs(axis=integers(-size, size-1))) + kw = draw(kwargs(axis=integers(-min_ndim, -1))) axis = kw.get('axis', -1) - shape[axis] = 3 - shape = tuple(shape) + if draw(booleans()): + # Sometimes allow invalid inputs to test it errors + shape1 = list(shape1) + shape1[axis] = 3 + shape1 = tuple(shape1) + shape2 = list(shape2) + shape2[axis] = 3 + shape2 = tuple(shape2) mutual_dtypes = shared(mutually_promotable_dtypes(dtypes=dtype_objects)) arrays1 = arrays( dtype=mutual_dtypes.map(lambda pair: pair[0]), - shape=shape, + shape=shape1, ) arrays2 = arrays( dtype=mutual_dtypes.map(lambda pair: pair[1]), - shape=shape, + shape=shape2, ) return draw(arrays1), draw(arrays2), kw @@ -154,15 +182,18 @@ def test_cross(x1_x2_kw): x1, x2, kw = x1_x2_kw axis = kw.get('axis', -1) - err = "test_cross produced 
invalid input. This indicates a bug in the test suite." - assert x1.shape == x2.shape, err - shape = x1.shape - assert x1.shape[axis] == x2.shape[axis] == 3, err + if not (x1.shape[axis] == x2.shape[axis] == 3): + ph.raises(Exception, lambda: xp.cross(x1, x2, **kw), + "cross did not raise an exception for invalid shapes") + return res = linalg.cross(x1, x2, **kw) - assert res.dtype == dh.result_type(x1.dtype, x2.dtype), "cross() did not return the correct dtype" - assert res.shape == shape, "cross() did not return the correct shape" + broadcasted_shape = sh.broadcast_shapes(x1.shape, x2.shape) + + ph.assert_dtype("cross", in_dtype=[x1.dtype, x2.dtype], + out_dtype=res.dtype) + ph.assert_result_shape("cross", in_shapes=[x1.shape, x2.shape], out_shape=res.shape, expected=broadcasted_shape) def exact_cross(a, b): assert a.shape == b.shape == (3,), "Invalid cross() stack shapes. This indicates a bug in the test suite." @@ -179,13 +210,14 @@ def exact_cross(a, b): @pytest.mark.xp_extension('linalg') @given( - x=arrays(dtype=xps.floating_dtypes(), shape=square_matrix_shapes), + x=arrays(dtype=all_floating_dtypes(), shape=square_matrix_shapes), ) def test_det(x): res = linalg.det(x) - assert res.dtype == x.dtype, "det() did not return the correct dtype" - assert res.shape == x.shape[:-2], "det() did not return the correct shape" + ph.assert_dtype("det", in_dtype=x.dtype, out_dtype=res.dtype) + ph.assert_result_shape("det", in_shapes=[x.shape], out_shape=res.shape, + expected=x.shape[:-2]) _test_stacks(linalg.det, x, res=res, dims=0) @@ -193,7 +225,7 @@ def test_det(x): @pytest.mark.xp_extension('linalg') @given( - x=arrays(dtype=xps.real_dtypes(), shape=matrix_shapes()), + x=arrays(dtype=xps.scalar_dtypes(), shape=matrix_shapes()), # offset may produce an overflow if it is too large. Supporting offsets # that are way larger than the array shape isn't very important. 
kw=kwargs(offset=integers(-MAX_ARRAY_SIZE, MAX_ARRAY_SIZE)) @@ -201,7 +233,7 @@ def test_det(x): def test_diagonal(x, kw): res = linalg.diagonal(x, **kw) - assert res.dtype == x.dtype, "diagonal() returned the wrong dtype" + ph.assert_dtype("diagonal", in_dtype=x.dtype, out_dtype=res.dtype) n, m = x.shape[-2:] offset = kw.get('offset', 0) @@ -215,9 +247,11 @@ def test_diagonal(x, kw): else: diag_size = min(n, m, max(m - offset, 0)) - assert res.shape == (*x.shape[:-2], diag_size), "diagonal() returned the wrong shape" + expected_shape = (*x.shape[:-2], diag_size) + ph.assert_result_shape("diagonal", in_shapes=[x.shape], + out_shape=res.shape, expected=expected_shape) - def true_diag(x_stack): + def true_diag(x_stack, offset=0): if offset >= 0: x_stack_diag = [x_stack[i, i + offset] for i in range(diag_size)] else: @@ -226,7 +260,6 @@ def true_diag(x_stack): _test_stacks(linalg.diagonal, x, **kw, res=res, dims=1, true_val=true_diag) -@pytest.mark.skip(reason="Inputs need to be restricted") # TODO @pytest.mark.xp_extension('linalg') @given(x=symmetric_matrices(finite=True)) def test_eigh(x): @@ -237,14 +270,28 @@ def test_eigh(x): eigenvalues = res.eigenvalues eigenvectors = res.eigenvectors - assert eigenvalues.dtype == x.dtype, "eigh().eigenvalues did not return the correct dtype" - assert eigenvalues.shape == x.shape[:-1], "eigh().eigenvalues did not return the correct shape" - - assert eigenvectors.dtype == x.dtype, "eigh().eigenvectors did not return the correct dtype" - assert eigenvectors.shape == x.shape, "eigh().eigenvectors did not return the correct shape" - + ph.assert_dtype("eigh", in_dtype=x.dtype, out_dtype=eigenvalues.dtype, + expected=x.dtype, repr_name="eigenvalues.dtype") + ph.assert_result_shape("eigh", in_shapes=[x.shape], + out_shape=eigenvalues.shape, + expected=x.shape[:-1], + repr_name="eigenvalues.shape") + + ph.assert_dtype("eigh", in_dtype=x.dtype, out_dtype=eigenvectors.dtype, + expected=x.dtype, repr_name="eigenvectors.dtype") + 
ph.assert_result_shape("eigh", in_shapes=[x.shape], + out_shape=eigenvectors.shape, expected=x.shape, + repr_name="eigenvectors.shape") + + # Note: _test_stacks here is only testing the shape and dtype. The actual + # eigenvalues and eigenvectors may not be equal at all, since there is not + # requirements about how eigh computes an eigenbasis, or about the order + # of the eigenvalues _test_stacks(lambda x: linalg.eigh(x).eigenvalues, x, res=eigenvalues, dims=1) + + # TODO: Test that eigenvectors are orthonormal. + _test_stacks(lambda x: linalg.eigh(x).eigenvectors, x, res=eigenvectors, dims=2) @@ -256,33 +303,37 @@ def test_eigh(x): def test_eigvalsh(x): res = linalg.eigvalsh(x) - assert res.dtype == x.dtype, "eigvalsh() did not return the correct dtype" - assert res.shape == x.shape[:-1], "eigvalsh() did not return the correct shape" + ph.assert_dtype("eigvalsh", in_dtype=x.dtype, out_dtype=res.dtype) + ph.assert_result_shape("eigvalsh", in_shapes=[x.shape], + out_shape=res.shape, expected=x.shape[:-1]) + # Note: _test_stacks here is only testing the shape and dtype. The actual + # eigenvalues may not be equal at all, since there is not requirements or + # about the order of the eigenvalues, and the stacking code may use a + # different code path. _test_stacks(linalg.eigvalsh, x, res=res, dims=1) # TODO: Should we test that the result is the same as eigh(x).eigenvalues? 
+ # (probably not, because the spec doesn't actually require that)
expected=x1.shape[:-1]) + _test_stacks(matmul, x1, x2, res=res, dims=1, + matrix_axes=[(-2, -1), (0,)], res_axes=[-1]) else: stack_shape = sh.broadcast_shapes(x1.shape[:-2], x2.shape[:-2]) - assert res.shape == stack_shape + (x1.shape[-2], x2.shape[-1]) - _test_stacks(_array_module.matmul, x1, x2, res=res) + ph.assert_result_shape("matmul", in_shapes=[x1.shape, x2.shape], + out_shape=res.shape, + expected=stack_shape + (x1.shape[-2], x2.shape[-1])) + _test_stacks(matmul, x1, x2, res=res) + +@pytest.mark.xp_extension('linalg') +@given( + *two_mutual_arrays(dh.real_dtypes) +) +def test_linalg_matmul(x1, x2): + return _test_matmul(linalg, x1, x2) -matrix_norm_shapes = shared(matrix_shapes()) +@given( + *two_mutual_arrays(dh.real_dtypes) +) +def test_matmul(x1, x2): + return _test_matmul(_array_module, x1, x2) @pytest.mark.xp_extension('linalg') @given( @@ -331,26 +401,28 @@ def test_matrix_norm(x, kw): expected_shape = x.shape[:-2] + (1, 1) else: expected_shape = x.shape[:-2] - assert res.shape == expected_shape, f"matrix_norm({keepdims=}) did not return the correct shape" - assert res.dtype == x.dtype, "matrix_norm() did not return the correct dtype" + ph.assert_dtype("matrix_norm", in_dtype=x.dtype, out_dtype=res.dtype) + ph.assert_result_shape("matrix_norm", in_shapes=[x.shape], + out_shape=res.shape, expected=expected_shape) _test_stacks(linalg.matrix_norm, x, **kw, dims=2 if keepdims else 0, res=res) -matrix_power_n = shared(integers(-1000, 1000), key='matrix_power n') +matrix_power_n = shared(integers(-100, 100), key='matrix_power n') @pytest.mark.xp_extension('linalg') @given( # Generate any square matrix if n >= 0 but only invertible matrices if n < 0 x=matrix_power_n.flatmap(lambda n: invertible_matrices() if n < 0 else - arrays(dtype=xps.floating_dtypes(), + arrays(dtype=all_floating_dtypes(), shape=square_matrix_shapes)), n=matrix_power_n, ) def test_matrix_power(x, n): res = linalg.matrix_power(x, n) - assert res.shape == x.shape, "matrix_power() did not 
return the correct shape" - assert res.dtype == x.dtype, "matrix_power() did not return the correct dtype" + ph.assert_dtype("matrix_power", in_dtype=x.dtype, out_dtype=res.dtype) + ph.assert_result_shape("matrix_power", in_shapes=[x.shape], + out_shape=res.shape, expected=x.shape) if n == 0: true_val = lambda x: _array_module.eye(x.shape[0], dtype=x.dtype) @@ -368,11 +440,9 @@ def test_matrix_power(x, n): def test_matrix_rank(x, kw): linalg.matrix_rank(x, **kw) -@given( - x=arrays(dtype=xps.real_dtypes(), shape=matrix_shapes()), -) -def test_matrix_transpose(x): - res = _array_module.matrix_transpose(x) +def _test_matrix_transpose(namespace, x): + matrix_transpose = namespace.matrix_transpose + res = matrix_transpose(x) true_val = lambda a: _array_module.asarray([[a[i, j] for i in range(a.shape[0])] for j in range(a.shape[1])], @@ -380,10 +450,24 @@ def test_matrix_transpose(x): shape = list(x.shape) shape[-1], shape[-2] = shape[-2], shape[-1] shape = tuple(shape) - assert res.shape == shape, "matrix_transpose() did not return the correct shape" - assert res.dtype == x.dtype, "matrix_transpose() did not return the correct dtype" + ph.assert_dtype("matrix_transpose", in_dtype=x.dtype, out_dtype=res.dtype) + ph.assert_result_shape("matrix_transpose", in_shapes=[x.shape], + out_shape=res.shape, expected=shape) + + _test_stacks(matrix_transpose, x, res=res, true_val=true_val) + +@pytest.mark.xp_extension('linalg') +@given( + x=arrays(dtype=xps.scalar_dtypes(), shape=matrix_shapes()), +) +def test_linalg_matrix_transpose(x): + return _test_matrix_transpose(linalg, x) - _test_stacks(_array_module.matrix_transpose, x, res=res, true_val=true_val) +@given( + x=arrays(dtype=xps.scalar_dtypes(), shape=matrix_shapes()), +) +def test_matrix_transpose(x): + return _test_matrix_transpose(_array_module, x) @pytest.mark.xp_extension('linalg') @given( @@ -396,8 +480,9 @@ def test_outer(x1, x2): res = linalg.outer(x1, x2) shape = (x1.shape[0], x2.shape[0]) - assert res.shape == 
shape, "outer() did not return the correct shape" - assert res.dtype == dh.result_type(x1.dtype, x2.dtype), "outer() did not return the correct dtype" + ph.assert_dtype("outer", in_dtype=[x1.dtype, x2.dtype], out_dtype=res.dtype) + ph.assert_result_shape("outer", in_shapes=[x1.shape, x2.shape], + out_shape=res.shape, expected=shape) if 0 in shape: true_res = _array_module.empty(shape, dtype=res.dtype) @@ -419,7 +504,7 @@ def test_pinv(x, kw): @pytest.mark.xp_extension('linalg') @given( - x=arrays(dtype=xps.floating_dtypes(), shape=matrix_shapes()), + x=arrays(dtype=all_floating_dtypes(), shape=matrix_shapes()), kw=kwargs(mode=sampled_from(['reduced', 'complete'])) ) def test_qr(x, kw): @@ -433,17 +518,23 @@ def test_qr(x, kw): Q = res.Q R = res.R - assert Q.dtype == x.dtype, "qr().Q did not return the correct dtype" + ph.assert_dtype("qr", in_dtype=x.dtype, out_dtype=Q.dtype, + expected=x.dtype, repr_name="Q.dtype") if mode == 'complete': - assert Q.shape == x.shape[:-2] + (M, M), "qr().Q did not return the correct shape" + expected_Q_shape = x.shape[:-2] + (M, M) else: - assert Q.shape == x.shape[:-2] + (M, K), "qr().Q did not return the correct shape" + expected_Q_shape = x.shape[:-2] + (M, K) + ph.assert_result_shape("qr", in_shapes=[x.shape], out_shape=Q.shape, + expected=expected_Q_shape, repr_name="Q.shape") - assert R.dtype == x.dtype, "qr().R did not return the correct dtype" + ph.assert_dtype("qr", in_dtype=x.dtype, out_dtype=R.dtype, + expected=x.dtype, repr_name="R.dtype") if mode == 'complete': - assert R.shape == x.shape[:-2] + (M, N), "qr().R did not return the correct shape" + expected_R_shape = x.shape[:-2] + (M, N) else: - assert R.shape == x.shape[:-2] + (K, N), "qr().R did not return the correct shape" + expected_R_shape = x.shape[:-2] + (K, N) + ph.assert_result_shape("qr", in_shapes=[x.shape], out_shape=R.shape, + expected=expected_R_shape, repr_name="R.shape") _test_stacks(lambda x: linalg.qr(x, **kw).Q, x, res=Q) _test_stacks(lambda x: 
linalg.qr(x, **kw).R, x, res=R) @@ -455,7 +546,7 @@ def test_qr(x, kw): @pytest.mark.xp_extension('linalg') @given( - x=arrays(dtype=xps.floating_dtypes(), shape=square_matrix_shapes), + x=arrays(dtype=all_floating_dtypes(), shape=square_matrix_shapes), ) def test_slogdet(x): res = linalg.slogdet(x) @@ -464,11 +555,19 @@ def test_slogdet(x): sign, logabsdet = res - assert sign.dtype == x.dtype, "slogdet().sign did not return the correct dtype" - assert sign.shape == x.shape[:-2], "slogdet().sign did not return the correct shape" - assert logabsdet.dtype == x.dtype, "slogdet().logabsdet did not return the correct dtype" - assert logabsdet.shape == x.shape[:-2], "slogdet().logabsdet did not return the correct shape" - + ph.assert_dtype("slogdet", in_dtype=x.dtype, out_dtype=sign.dtype, + expected=x.dtype, repr_name="sign.dtype") + ph.assert_result_shape("slogdet", in_shapes=[x.shape], + out_shape=sign.shape, + expected=x.shape[:-2], + repr_name="sign.shape") + expected_dtype = dh.as_real_dtype(x.dtype) + ph.assert_dtype("slogdet", in_dtype=x.dtype, out_dtype=logabsdet.dtype, + expected=expected_dtype, repr_name="logabsdet.dtype") + ph.assert_result_shape("slogdet", in_shapes=[x.shape], + out_shape=logabsdet.shape, + expected=x.shape[:-2], + repr_name="logabsdet.shape") _test_stacks(lambda x: linalg.slogdet(x).sign, x, res=sign, dims=0) @@ -510,13 +609,19 @@ def _x2_shapes(draw): return draw(stack_shapes)[1] + draw(x1).shape[-1:] + (end,) x2_shapes = one_of(x1.map(lambda x: (x.shape[-1],)), _x2_shapes()) - x2 = arrays(dtype=xps.floating_dtypes(), shape=x2_shapes) + x2 = arrays(dtype=all_floating_dtypes(), shape=x2_shapes) return x1, x2 @pytest.mark.xp_extension('linalg') @given(*solve_args()) def test_solve(x1, x2): - linalg.solve(x1, x2) + res = linalg.solve(x1, x2) + + if x2.ndim == 1: + _test_stacks(linalg.solve, x1, x2, res=res, dims=1, + matrix_axes=[(-2, -1), (0,)], res_axes=[-1]) + else: + _test_stacks(linalg.solve, x1, x2, res=res, dims=2) 
@pytest.mark.xp_extension('linalg') @given( @@ -534,17 +639,31 @@ def test_svd(x, kw): U, S, Vh = res - assert U.dtype == x.dtype, "svd().U did not return the correct dtype" - assert S.dtype == x.dtype, "svd().S did not return the correct dtype" - assert Vh.dtype == x.dtype, "svd().Vh did not return the correct dtype" + ph.assert_dtype("svd", in_dtype=x.dtype, out_dtype=U.dtype, + expected=x.dtype, repr_name="U.dtype") + ph.assert_dtype("svd", in_dtype=x.dtype, out_dtype=S.dtype, + expected=x.dtype, repr_name="S.dtype") + ph.assert_dtype("svd", in_dtype=x.dtype, out_dtype=Vh.dtype, + expected=x.dtype, repr_name="Vh.dtype") if full_matrices: - assert U.shape == (*stack, M, M), "svd().U did not return the correct shape" - assert Vh.shape == (*stack, N, N), "svd().Vh did not return the correct shape" + expected_U_shape = (*stack, M, M) + expected_Vh_shape = (*stack, N, N) else: - assert U.shape == (*stack, M, K), "svd(full_matrices=False).U did not return the correct shape" - assert Vh.shape == (*stack, K, N), "svd(full_matrices=False).Vh did not return the correct shape" - assert S.shape == (*stack, K), "svd().S did not return the correct shape" + expected_U_shape = (*stack, M, K) + expected_Vh_shape = (*stack, K, N) + ph.assert_result_shape("svd", in_shapes=[x.shape], + out_shape=U.shape, + expected=expected_U_shape, + repr_name="U.shape") + ph.assert_result_shape("svd", in_shapes=[x.shape], + out_shape=Vh.shape, + expected=expected_Vh_shape, + repr_name="Vh.shape") + ph.assert_result_shape("svd", in_shapes=[x.shape], + out_shape=S.shape, + expected=(*stack, K), + repr_name="S.shape") # The values of s must be sorted from largest to smallest if K >= 1: @@ -564,8 +683,11 @@ def test_svdvals(x): *stack, M, N = x.shape K = min(M, N) - assert res.dtype == x.dtype, "svdvals() did not return the correct dtype" - assert res.shape == (*stack, K), "svdvals() did not return the correct shape" + ph.assert_dtype("svdvals", in_dtype=x.dtype, out_dtype=res.dtype, + 
expected=x.dtype) + ph.assert_result_shape("svdvals", in_shapes=[x.shape], + out_shape=res.shape, + expected=(*stack, K)) # SVD values must be sorted from largest to smallest assert _array_module.all(res[..., :-1] >= res[..., 1:]), "svdvals() values are not sorted from largest to smallest" @@ -574,26 +696,133 @@ def test_svdvals(x): # TODO: Check that svdvals() is the same as svd().s. +_tensordot_pre_shapes = shared(two_mutually_broadcastable_shapes) -@given( - dtypes=mutually_promotable_dtypes(dtypes=dh.real_dtypes), - shape=shapes(), - data=data(), -) -def test_tensordot(dtypes, shape, data): - # TODO: vary shapes, vary contracted axes, test different axes arguments - x1 = data.draw(arrays(dtype=dtypes[0], shape=shape), label="x1") - x2 = data.draw(arrays(dtype=dtypes[1], shape=shape), label="x2") +@composite +def _tensordot_axes(draw): + shape1, shape2 = draw(_tensordot_pre_shapes) + ndim1, ndim2 = len(shape1), len(shape2) + isint = draw(booleans()) + + if isint: + N = min(ndim1, ndim2) + return draw(integers(0, N)) + else: + if ndim1 < ndim2: + first = draw(xps.valid_tuple_axes(ndim1)) + second = draw(xps.valid_tuple_axes(ndim2, min_size=len(first), + max_size=len(first))) + else: + second = draw(xps.valid_tuple_axes(ndim2)) + first = draw(xps.valid_tuple_axes(ndim1, min_size=len(second), + max_size=len(second))) + return (tuple(first), tuple(second)) + +tensordot_kw = shared(kwargs(axes=_tensordot_axes())) + +@composite +def tensordot_shapes(draw): + _shape1, _shape2 = map(list, draw(_tensordot_pre_shapes)) + ndim1, ndim2 = len(_shape1), len(_shape2) + kw = draw(tensordot_kw) + if 'axes' not in kw: + assume(ndim1 >= 2 and ndim2 >= 2) + axes = kw.get('axes', 2) + + if isinstance(axes, int): + axes = [list(range(-axes, 0)), list(range(0, axes))] + + first, second = axes + for i, j in zip(first, second): + try: + if -ndim2 <= j < ndim2 and _shape2[j] != 1: + _shape1[i] = _shape2[j] + if -ndim1 <= i < ndim1 and _shape1[i] != 1: + _shape2[j] = _shape1[i] + except: 
+ raise + + shape1, shape2 = map(tuple, [_shape1, _shape2]) + return (shape1, shape2) + +def _test_tensordot_stacks(x1, x2, kw, res): + """ + Variant of _test_stacks for tensordot + + tensordot doesn't stack directly along the non-contracted dimensions like + the other linalg functions. Rather, it is stacked along the product of + each non-contracted dimension. These dimensions are independent of one + another and do not broadcast. + """ + shape1, shape2 = x1.shape, x2.shape + + axes = kw.get('axes', 2) + + if isinstance(axes, int): + res_axes = axes + axes = [list(range(-axes, 0)), list(range(0, axes))] + else: + # Convert something like (0, 4, 2) into (0, 2, 1) + res_axes = [] + for a, s in zip(axes, [shape1, shape2]): + indices = [range(len(s))[i] for i in a] + repl = dict(zip(sorted(indices), range(len(indices)))) + res_axes.append(tuple(repl[i] for i in indices)) + + for ((i,), (j,)), (res_idx,) in zip( + itertools.product( + iter_indices(shape1, skip_axes=axes[0]), + iter_indices(shape2, skip_axes=axes[1])), + iter_indices(res.shape)): + i, j, res_idx = i.raw, j.raw, res_idx.raw + + res_stack = res[res_idx] + x1_stack = x1[i] + x2_stack = x2[j] + decomp_res_stack = xp.tensordot(x1_stack, x2_stack, axes=res_axes) + assert_equal(res_stack, decomp_res_stack) + +def _test_tensordot(namespace, x1, x2, kw): + tensordot = namespace.tensordot + res = tensordot(x1, x2, **kw) + + ph.assert_dtype("tensordot", in_dtype=[x1.dtype, x2.dtype], + out_dtype=res.dtype) - out = xp.tensordot(x1, x2, axes=len(shape)) + axes = _axes = kw.get('axes', 2) - ph.assert_dtype("tensordot", in_dtype=dtypes, out_dtype=out.dtype) - # TODO: assert shape and elements + if isinstance(axes, int): + _axes = [list(range(-axes, 0)), list(range(0, axes))] + + _shape1 = list(x1.shape) + _shape2 = list(x2.shape) + for i, j in zip(*_axes): + _shape1[i] = _shape2[j] = None + _shape1 = tuple([i for i in _shape1 if i is not None]) + _shape2 = tuple([i for i in _shape2 if i is not None]) + result_shape = 
_shape1 + _shape2 + ph.assert_result_shape('tensordot', [x1.shape, x2.shape], res.shape, + expected=result_shape) + _test_tensordot_stacks(x1, x2, kw, res) + +@pytest.mark.xp_extension('linalg') +@given( + *two_mutual_arrays(dh.numeric_dtypes, two_shapes=tensordot_shapes()), + tensordot_kw, +) +def test_linalg_tensordot(x1, x2, kw): + _test_tensordot(linalg, x1, x2, kw) +@given( + *two_mutual_arrays(dh.numeric_dtypes, two_shapes=tensordot_shapes()), + tensordot_kw, +) +def test_tensordot(x1, x2, kw): + _test_tensordot(_array_module, x1, x2, kw) @pytest.mark.xp_extension('linalg') @given( - x=arrays(dtype=xps.real_dtypes(), shape=matrix_shapes()), + x=arrays(dtype=xps.numeric_dtypes(), shape=matrix_shapes()), # offset may produce an overflow if it is too large. Supporting offsets # that are way larger than the array shape isn't very important. kw=kwargs(offset=integers(-MAX_ARRAY_SIZE, MAX_ARRAY_SIZE)) @@ -601,17 +830,21 @@ def test_tensordot(dtypes, shape, data): def test_trace(x, kw): res = linalg.trace(x, **kw) - # TODO: trace() should promote in some cases. See - # https://github.com/data-apis/array-api/issues/202. See also the dtype - # argument to sum() below. - - # assert res.dtype == x.dtype, "trace() returned the wrong dtype" + dtype = kw.get("dtype", None) + expected_dtype = dh.accumulation_result_dtype(x.dtype, dtype) + if expected_dtype is None: + # If a default uint cannot exist (i.e. in PyTorch which doesn't support + # uint32 or uint64), we skip testing the output dtype. 
+ # See https://github.com/data-apis/array-api-tests/issues/160 + if x.dtype in dh.uint_dtypes: + assert dh.is_int_dtype(res.dtype) # sanity check + else: + ph.assert_dtype("trace", in_dtype=x.dtype, out_dtype=res.dtype, expected=expected_dtype) n, m = x.shape[-2:] - offset = kw.get('offset', 0) - assert res.shape == x.shape[:-2], "trace() returned the wrong shape" + ph.assert_result_shape('trace', x.shape, res.shape, expected=x.shape[:-2]) - def true_trace(x_stack): + def true_trace(x_stack, offset=0): # Note: the spec does not specify that offset must be within the # bounds of the matrix. A large offset should just produce a size 0 # diagonal in the last dimension (trace 0). See test_diagonal(). @@ -630,29 +863,94 @@ def true_trace(x_stack): _test_stacks(linalg.trace, x, **kw, res=res, dims=0, true_val=true_trace) +def _test_vecdot(namespace, x1, x2, data): + vecdot = namespace.vecdot + broadcasted_shape = sh.broadcast_shapes(x1.shape, x2.shape) + min_ndim = min(x1.ndim, x2.ndim) + ndim = len(broadcasted_shape) + kw = data.draw(kwargs(axis=integers(-min_ndim, -1))) + axis = kw.get('axis', -1) + x1_shape = (1,)*(ndim - x1.ndim) + tuple(x1.shape) + x2_shape = (1,)*(ndim - x2.ndim) + tuple(x2.shape) + if x1_shape[axis] != x2_shape[axis]: + ph.raises(Exception, lambda: vecdot(x1, x2, **kw), + "vecdot did not raise an exception for invalid shapes") + return + expected_shape = list(broadcasted_shape) + expected_shape.pop(axis) + expected_shape = tuple(expected_shape) + + res = vecdot(x1, x2, **kw) + + ph.assert_dtype("vecdot", in_dtype=[x1.dtype, x2.dtype], + out_dtype=res.dtype) + ph.assert_result_shape("vecdot", in_shapes=[x1.shape, x2.shape], + out_shape=res.shape, expected=expected_shape) + + if x1.dtype in dh.int_dtypes: + def true_val(x, y, axis=-1): + return xp.sum(xp.multiply(x, y), dtype=res.dtype) + else: + true_val = None + + _test_stacks(vecdot, x1, x2, res=res, dims=0, + matrix_axes=(axis,), true_val=true_val) + +@pytest.mark.xp_extension('linalg') @given( 
- dtypes=mutually_promotable_dtypes(dtypes=dh.real_dtypes), - shape=shapes(min_dims=1), - data=data(), + *two_mutual_arrays(dh.numeric_dtypes, mutually_broadcastable_shapes(2, min_dims=1)), + data(), ) -def test_vecdot(dtypes, shape, data): - # TODO: vary shapes, test different axis arguments - x1 = data.draw(arrays(dtype=dtypes[0], shape=shape), label="x1") - x2 = data.draw(arrays(dtype=dtypes[1], shape=shape), label="x2") - kw = data.draw(kwargs(axis=just(-1))) - - out = xp.vecdot(x1, x2, **kw) +def test_linalg_vecdot(x1, x2, data): + _test_vecdot(linalg, x1, x2, data) - ph.assert_dtype("vecdot", in_dtype=dtypes, out_dtype=out.dtype) - # TODO: assert shape and elements +@given( + *two_mutual_arrays(dh.numeric_dtypes, mutually_broadcastable_shapes(2, min_dims=1)), + data(), +) +def test_vecdot(x1, x2, data): + _test_vecdot(_array_module, x1, x2, data) +# Insanely large orders might not work. There isn't a limit specified in the +# spec, so we just limit to reasonable values here. +max_ord = 100 @pytest.mark.xp_extension('linalg') @given( - x=arrays(dtype=xps.floating_dtypes(), shape=shapes()), - kw=kwargs(axis=todo, keepdims=todo, ord=todo) + x=arrays(dtype=all_floating_dtypes(), shape=shapes(min_side=1)), + data=data(), ) -def test_vector_norm(x, kw): - # res = linalg.vector_norm(x, **kw) - pass +def test_vector_norm(x, data): + kw = data.draw( + # We use data because axes is parameterized on x.ndim + kwargs(axis=axes(x.ndim), + keepdims=booleans(), + ord=one_of( + sampled_from([2, 1, 0, -1, -2, float("inf"), float("-inf")]), + integers(-max_ord, max_ord), + floats(-max_ord, max_ord), + )), label="kw") + + + res = linalg.vector_norm(x, **kw) + axis = kw.get('axis', None) + keepdims = kw.get('keepdims', False) + # TODO: Check that the ord values give the correct norms. 
+ # ord = kw.get('ord', 2) + + _axes = sh.normalise_axis(axis, x.ndim) + + ph.assert_keepdimable_shape('linalg.vector_norm', out_shape=res.shape, + in_shape=x.shape, axes=_axes, + keepdims=keepdims, kw=kw) + expected_dtype = dh.as_real_dtype(x.dtype) + ph.assert_dtype('linalg.vector_norm', in_dtype=x.dtype, + out_dtype=res.dtype, expected=expected_dtype) + + _kw = kw.copy() + _kw.pop('axis', None) + _test_stacks(linalg.vector_norm, x, res=res, + dims=x.ndim if keepdims else 0, + matrix_axes=_axes, **_kw + ) diff --git a/array_api_tests/test_manipulation_functions.py b/array_api_tests/test_manipulation_functions.py index fd27b2dc..8cbe7750 100644 --- a/array_api_tests/test_manipulation_functions.py +++ b/array_api_tests/test_manipulation_functions.py @@ -250,6 +250,7 @@ def reshape_shapes(draw, shape): return tuple(rshape) +@pytest.mark.skip("flaky") # TODO: fix! @given( x=hh.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes(max_side=MAX_SIDE)), data=st.data(), diff --git a/array_api_tests/test_statistical_functions.py b/array_api_tests/test_statistical_functions.py index 4ae1f005..e5d32ee1 100644 --- a/array_api_tests/test_statistical_functions.py +++ b/array_api_tests/test_statistical_functions.py @@ -11,7 +11,7 @@ from . import hypothesis_helpers as hh from . import pytest_helpers as ph from . import shape_helpers as sh -from . import xps, api_version +from . 
import xps from ._array_module import _UndefinedStub from .typing import DataType @@ -130,44 +130,15 @@ def test_prod(x, data): out = xp.prod(x, **kw) dtype = kw.get("dtype", None) - if dtype is None: - if dh.is_int_dtype(x.dtype): - if x.dtype in dh.uint_dtypes: - default_dtype = dh.default_uint - else: - default_dtype = dh.default_int - if default_dtype is None: - _dtype = None - else: - m, M = dh.dtype_ranges[x.dtype] - d_m, d_M = dh.dtype_ranges[default_dtype] - if m < d_m or M > d_M: - _dtype = x.dtype - else: - _dtype = default_dtype - elif dh.is_float_dtype(x.dtype, include_complex=False): - if dh.dtype_nbits[x.dtype] > dh.dtype_nbits[dh.default_float]: - _dtype = x.dtype - else: - _dtype = dh.default_float - elif api_version > "2021.12": - # Complex dtype - if dh.dtype_nbits[x.dtype] > dh.dtype_nbits[dh.default_complex]: - _dtype = x.dtype - else: - _dtype = dh.default_complex - else: - raise RuntimeError("Unexpected dtype. This indicates a bug in the test suite.") - else: - _dtype = dtype - if _dtype is None: + expected_dtype = dh.accumulation_result_dtype(x.dtype, dtype) + if expected_dtype is None: # If a default uint cannot exist (i.e. in PyTorch which doesn't support # uint32 or uint64), we skip testing the output dtype. 
# See https://github.com/data-apis/array-api-tests/issues/106 if x.dtype in dh.uint_dtypes: assert dh.is_int_dtype(out.dtype) # sanity check else: - ph.assert_dtype("prod", in_dtype=x.dtype, out_dtype=out.dtype, expected=_dtype) + ph.assert_dtype("prod", in_dtype=x.dtype, out_dtype=out.dtype, expected=expected_dtype) _axes = sh.normalise_axis(kw.get("axis", None), x.ndim) ph.assert_keepdimable_shape( "prod", in_shape=x.shape, out_shape=out.shape, axes=_axes, keepdims=keepdims, kw=kw @@ -247,44 +218,15 @@ def test_sum(x, data): out = xp.sum(x, **kw) dtype = kw.get("dtype", None) - if dtype is None: - if dh.is_int_dtype(x.dtype): - if x.dtype in dh.uint_dtypes: - default_dtype = dh.default_uint - else: - default_dtype = dh.default_int - if default_dtype is None: - _dtype = None - else: - m, M = dh.dtype_ranges[x.dtype] - d_m, d_M = dh.dtype_ranges[default_dtype] - if m < d_m or M > d_M: - _dtype = x.dtype - else: - _dtype = default_dtype - elif dh.is_float_dtype(x.dtype, include_complex=False): - if dh.dtype_nbits[x.dtype] > dh.dtype_nbits[dh.default_float]: - _dtype = x.dtype - else: - _dtype = dh.default_float - elif api_version > "2021.12": - # Complex dtype - if dh.dtype_nbits[x.dtype] > dh.dtype_nbits[dh.default_complex]: - _dtype = x.dtype - else: - _dtype = dh.default_complex - else: - raise RuntimeError("Unexpected dtype. This indicates a bug in the test suite.") - else: - _dtype = dtype - if _dtype is None: + expected_dtype = dh.accumulation_result_dtype(x.dtype, dtype) + if expected_dtype is None: # If a default uint cannot exist (i.e. in PyTorch which doesn't support # uint32 or uint64), we skip testing the output dtype. 
# See https://github.com/data-apis/array-api-tests/issues/160 if x.dtype in dh.uint_dtypes: assert dh.is_int_dtype(out.dtype) # sanity check else: - ph.assert_dtype("sum", in_dtype=x.dtype, out_dtype=out.dtype, expected=_dtype) + ph.assert_dtype("sum", in_dtype=x.dtype, out_dtype=out.dtype, expected=expected_dtype) _axes = sh.normalise_axis(kw.get("axis", None), x.ndim) ph.assert_keepdimable_shape( "sum", in_shape=x.shape, out_shape=out.shape, axes=_axes, keepdims=keepdims, kw=kw diff --git a/numpy-skips.txt b/numpy-skips.txt index 0c6f39ae..aebc249d 100644 --- a/numpy-skips.txt +++ b/numpy-skips.txt @@ -14,6 +14,8 @@ array_api_tests/test_constants.py::test_newaxis # linalg.solve issue in numpy.array_api as of v1.26.2 (see numpy#25146) array_api_tests/test_linalg.py::test_solve +# numpy.array_api needs updating... or replaced on CI +array_api_tests/test_linalg.py::test_cross # https://github.com/numpy/numpy/issues/21373 array_api_tests/test_array_object.py::test_getitem diff --git a/requirements.txt b/requirements.txt index bb33bc90..7e898020 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ pytest pytest-json-report hypothesis>=6.68.0 -ndindex>=1.6 +ndindex>=1.8