From d8f0d229c98e92347e4c4a7d0d5ceae7352bed3e Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Mon, 12 Dec 2022 09:01:52 -0800 Subject: [PATCH] [sparse] fix logic bug in GPU testcase --- tests/sparse_test.py | 38 +++++++++++++++++++------------------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/tests/sparse_test.py b/tests/sparse_test.py index 21d02c6e1fb2..5a03d1bafd44 100644 --- a/tests/sparse_test.py +++ b/tests/sparse_test.py @@ -1204,13 +1204,15 @@ def test_bcoo_batched_matmat_default_lowering( sp_matmat = jit(partial(sparse_bcoo.bcoo_dot_general, dimension_numbers=dimension_numbers)) + warning_context = contextlib.nullcontext() if config.jax_bcoo_cusparse_lowering: - with self.assertWarnsRegex( - sparse.CuSparseEfficiencyWarning, - "bcoo_dot_general GPU lowering currently does not support this " - "batch-mode computation.*"): - matmat_default_lowering_fallback = sp_matmat(lhs_bcoo, rhs) + warning_context = self.assertWarnsRegex( + sparse.CuSparseEfficiencyWarning, + "bcoo_dot_general GPU lowering currently does not support this " + "batch-mode computation.*") + with warning_context: + matmat_default_lowering_fallback = sp_matmat(lhs_bcoo, rhs) self.assertArraysEqual(matmat_expected, matmat_default_lowering_fallback) @unittest.skipIf(not GPU_LOWERING_ENABLED, "test requires cusparse/hipsparse") @@ -1234,16 +1236,18 @@ def test_bcoo_dot_general_oob_and_unsorted_indices_cusparse(self): sp_matmat = jit(partial(sparse_bcoo.bcoo_dot_general, dimension_numbers=dimension_numbers_2d)) - matmat_expected = lax.dot_general(lhs_mat_dense, rhs, - dimension_numbers=dimension_numbers_2d) + warning_context = contextlib.nullcontext() if config.jax_bcoo_cusparse_lowering: - with self.assertWarnsRegex( + warning_context = self.assertWarnsRegex( sparse.CuSparseEfficiencyWarning, - "bcoo_dot_general GPU lowering requires matrices with sorted indices*"): - matmat_unsorted_fallback = sp_matmat(lhs_mat_bcoo_unsorted, rhs) + "bcoo_dot_general GPU lowering requires matrices with sorted indices*") - with self.subTest(msg="2D"): - self.assertArraysEqual(matmat_expected, matmat_unsorted_fallback) + matmat_expected = lax.dot_general(lhs_mat_dense, rhs, + dimension_numbers=dimension_numbers_2d) + with self.subTest(msg="2D"): + with warning_context: + matmat_unsorted_fallback = sp_matmat(lhs_mat_bcoo_unsorted, rhs) + self.assertArraysEqual(matmat_expected, matmat_unsorted_fallback) lhs_vec_dense = jnp.array([0, 1, 0, 2, 0], dtype=jnp.float32) lhs_vec_bcoo = sparse.BCOO.fromdense(lhs_vec_dense, nse=5) @@ -1260,14 +1264,10 @@ def test_bcoo_dot_general_oob_and_unsorted_indices_cusparse(self): vecmat_expected = lax.dot_general(lhs_vec_dense, rhs, dimension_numbers=dimension_numbers_1d) - if config.jax_bcoo_cusparse_lowering: - with self.assertWarnsRegex( - sparse.CuSparseEfficiencyWarning, - "bcoo_dot_general GPU lowering requires matrices with sorted indices*"): + with self.subTest(msg="1D"): + with warning_context: vecmat_unsorted_fallback = sp_vecmat(lhs_vec_bcoo_unsorted, rhs) - - with self.subTest(msg="1D"): - self.assertArraysEqual(vecmat_expected, vecmat_unsorted_fallback) + self.assertArraysEqual(vecmat_expected, vecmat_unsorted_fallback) @jtu.sample_product( props=_generate_bcoo_dot_general_properties(