From 2e959904f496bec9bc04a41a99d5b7d9f2ad449c Mon Sep 17 00:00:00 2001
From: Jake VanderPlas
Date: Mon, 12 Dec 2022 09:38:17 -0800
Subject: [PATCH] [sparse] fix logic bug in GPU testcase

---
 tests/sparse_test.py | 49 +++++++++++++++++++++----------------------------
 1 file changed, 21 insertions(+), 28 deletions(-)

diff --git a/tests/sparse_test.py b/tests/sparse_test.py
index 21d02c6e1fb2..212cff547b79 100644
--- a/tests/sparse_test.py
+++ b/tests/sparse_test.py
@@ -151,7 +151,7 @@ def gpu_dense_conversion_warning_context(self, dtype):
       return self.assertWarns(sparse.CuSparseEfficiencyWarning)
     return contextlib.nullcontext()
 
-  def gpu_matmul_warning_context(self, dtype):
+  def gpu_matmul_dtype_warning_context(self, dtype):
     if jtu.device_under_test() == "gpu" and dtype not in [np.float32, np.float64, np.complex64, np.complex128]:
       return self.assertWarns(sparse.CuSparseEfficiencyWarning)
     return contextlib.nullcontext()
@@ -318,7 +318,7 @@ def test_csr_matvec(self, shape, dtype, transpose):
     matvec = lambda *args: sparse.csr_matvec(*args, shape=M.shape, transpose=transpose)
 
     self.assertAllClose(op(M) @ v, matvec(*args), rtol=MATMUL_TOL)
-    with self.gpu_matmul_warning_context(dtype):
+    with self.gpu_matmul_dtype_warning_context(dtype):
       self.assertAllClose(op(M) @ v, jit(matvec)(*args), rtol=MATMUL_TOL)
 
   @jtu.sample_product(
@@ -338,7 +338,7 @@ def test_csr_matmat(self, shape, dtype, transpose):
     matmat = lambda *args: sparse.csr_matmat(*args, shape=shape, transpose=transpose)
 
     self.assertAllClose(op(M) @ B, matmat(*args), rtol=MATMUL_TOL)
-    with self.gpu_matmul_warning_context(dtype):
+    with self.gpu_matmul_dtype_warning_context(dtype):
       self.assertAllClose(op(M) @ B, jit(matmat)(*args), rtol=MATMUL_TOL)
 
   @jtu.sample_product(
@@ -397,7 +397,7 @@ def test_coo_matvec(self, shape, dtype, transpose):
     matvec = lambda *args: sparse_coo._coo_matvec(*args, spinfo=sparse_coo.COOInfo(shape=M.shape, rows_sorted=True), transpose=transpose)
 
     self.assertAllClose(op(M) @ v, matvec(*args), rtol=MATMUL_TOL)
-    with self.gpu_matmul_warning_context(dtype):
+    with self.gpu_matmul_dtype_warning_context(dtype):
       self.assertAllClose(op(M) @ v, jit(matvec)(*args), rtol=MATMUL_TOL)
 
   @jtu.sample_product(
@@ -418,7 +418,7 @@ def test_coo_matmat(self, shape, dtype, transpose):
     matmat = lambda *args: sparse_coo._coo_matmat(*args, spinfo=sparse_coo.COOInfo(shape=shape, rows_sorted=True), transpose=transpose)
 
     self.assertAllClose(op(M) @ B, matmat(*args), rtol=MATMUL_TOL)
-    with self.gpu_matmul_warning_context(dtype):
+    with self.gpu_matmul_dtype_warning_context(dtype):
       self.assertAllClose(op(M) @ B, jit(matmat)(*args), rtol=MATMUL_TOL)
 
   def test_coo_matmat_layout(self):
@@ -654,6 +654,11 @@ def test_coo_matmul_ad(self, shape, dtype, bshape):
 
 class BCOOTest(sptu.SparseTestCase):
 
+  def gpu_matmul_warning_context(self, msg):
+    if GPU_LOWERING_ENABLED and config.jax_bcoo_cusparse_lowering:
+      return self.assertWarnsRegex(sparse.CuSparseEfficiencyWarning, msg)
+    return contextlib.nullcontext()
+
   def test_vmappable(self):
     """Test does not depend on batching rules of BCOO primitives."""
     M = jnp.arange(9).reshape((3, 3))
@@ -1166,10 +1171,8 @@ def f_sparse(lhs_bcoo, lhs, rhs):
     else:
       lhs_bcoo, lhs, rhs = args_maker()
       matmat_expected = f_dense(lhs_bcoo, lhs, rhs)
-      with self.assertWarnsRegex(
-          sparse.CuSparseEfficiencyWarning,
-          "bcoo_dot_general GPU lowering currently does not support this "
-          "batch-mode computation.*"):
+      with self.gpu_matmul_warning_context(
+          "bcoo_dot_general GPU lowering currently does not support this batch-mode computation.*"):
         matmat_default_lowering_fallback = jit(f_sparse)(lhs_bcoo, lhs, rhs)
       self.assertAllClose(matmat_expected, matmat_default_lowering_fallback,
                           atol=1E-6, rtol=1E-6)
@@ -1204,13 +1207,9 @@ def test_bcoo_batched_matmat_default_lowering(
     sp_matmat = jit(partial(sparse_bcoo.bcoo_dot_general,
                             dimension_numbers=dimension_numbers))
 
-    if config.jax_bcoo_cusparse_lowering:
-      with self.assertWarnsRegex(
-          sparse.CuSparseEfficiencyWarning,
-          "bcoo_dot_general GPU lowering currently does not support this "
-          "batch-mode computation.*"):
-        matmat_default_lowering_fallback = sp_matmat(lhs_bcoo, rhs)
-
+    with self.gpu_matmul_warning_context(
+        "bcoo_dot_general GPU lowering currently does not support this batch-mode computation.*"):
+      matmat_default_lowering_fallback = sp_matmat(lhs_bcoo, rhs)
     self.assertArraysEqual(matmat_expected, matmat_default_lowering_fallback)
 
   @unittest.skipIf(not GPU_LOWERING_ENABLED, "test requires cusparse/hipsparse")
@@ -1236,14 +1235,11 @@ def test_bcoo_dot_general_oob_and_unsorted_indices_cusparse(self):
     matmat_expected = lax.dot_general(lhs_mat_dense, rhs,
                                       dimension_numbers=dimension_numbers_2d)
 
-    if config.jax_bcoo_cusparse_lowering:
-      with self.assertWarnsRegex(
-          sparse.CuSparseEfficiencyWarning,
+    with self.subTest(msg="2D"):
+      with self.gpu_matmul_warning_context(
           "bcoo_dot_general GPU lowering requires matrices with sorted indices*"):
         matmat_unsorted_fallback = sp_matmat(lhs_mat_bcoo_unsorted, rhs)
-
-    with self.subTest(msg="2D"):
-      self.assertArraysEqual(matmat_expected, matmat_unsorted_fallback)
+      self.assertArraysEqual(matmat_expected, matmat_unsorted_fallback)
 
     lhs_vec_dense = jnp.array([0, 1, 0, 2, 0], dtype=jnp.float32)
     lhs_vec_bcoo = sparse.BCOO.fromdense(lhs_vec_dense, nse=5)
@@ -1260,14 +1256,11 @@ def test_bcoo_dot_general_oob_and_unsorted_indices_cusparse(self):
     vecmat_expected = lax.dot_general(lhs_vec_dense, rhs,
                                       dimension_numbers=dimension_numbers_1d)
 
-    if config.jax_bcoo_cusparse_lowering:
-      with self.assertWarnsRegex(
-          sparse.CuSparseEfficiencyWarning,
+    with self.subTest(msg="1D"):
+      with self.gpu_matmul_warning_context(
           "bcoo_dot_general GPU lowering requires matrices with sorted indices*"):
         vecmat_unsorted_fallback = sp_vecmat(lhs_vec_bcoo_unsorted, rhs)
-
-    with self.subTest(msg="1D"):
-      self.assertArraysEqual(vecmat_expected, vecmat_unsorted_fallback)
+      self.assertArraysEqual(vecmat_expected, vecmat_unsorted_fallback)
 
   @jtu.sample_product(
     props=_generate_bcoo_dot_general_properties(
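
Why this is a logic-bug fix, in a nutshell: before the patch, the sparse
matmul only ran inside an "if config.jax_bcoo_cusparse_lowering:" block, so
with that flag disabled the fallback result was never assigned and the later
assertArraysEqual raised a NameError. The new BCOOTest.gpu_matmul_warning_context
helper makes the computation unconditional: it returns a real warning-assertion
context only when cusparse lowering is enabled, and contextlib.nullcontext()
otherwise. A minimal, self-contained sketch of that pattern follows; the names
in it (FakeCuSparseEfficiencyWarning, matmul_with_fallback, WarningContextTest)
are illustrative placeholders, not JAX APIs.

import contextlib
import unittest
import warnings


class FakeCuSparseEfficiencyWarning(UserWarning):
  """Stand-in for sparse.CuSparseEfficiencyWarning (illustration only)."""


def matmul_with_fallback(use_fallback):
  """Stand-in op that warns when it takes a slow fallback path."""
  if use_fallback:
    warnings.warn("falling back to default lowering",
                  FakeCuSparseEfficiencyWarning)
  return 42


class WarningContextTest(unittest.TestCase):

  def gpu_matmul_warning_context(self, expect_warning, msg):
    # The pattern adopted by the patch: return an assertion context only
    # when the warning is expected, and a no-op context otherwise, so the
    # body of the `with` block executes (and binds its results) either way.
    if expect_warning:
      return self.assertWarnsRegex(FakeCuSparseEfficiencyWarning, msg)
    return contextlib.nullcontext()

  def test_matmul_fallback(self):
    for expect_warning in (True, False):
      with self.gpu_matmul_warning_context(expect_warning, "falling back.*"):
        result = matmul_with_fallback(expect_warning)
      # `result` is bound on both paths -- unlike the pre-patch code, where
      # it was only assigned under the `if config...:` guard.
      self.assertEqual(result, 42)


if __name__ == "__main__":
  unittest.main()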