Skip to content

Commit

Permalink
[sparse] fix logic bug in GPU testcase
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Dec 12, 2022
1 parent 13c34f9 commit d8f0d22
Showing 1 changed file with 19 additions and 19 deletions.
38 changes: 19 additions & 19 deletions tests/sparse_test.py
Expand Up @@ -1204,13 +1204,15 @@ def test_bcoo_batched_matmat_default_lowering(
sp_matmat = jit(partial(sparse_bcoo.bcoo_dot_general,
dimension_numbers=dimension_numbers))

warning_context = contextlib.nullcontext()
if config.jax_bcoo_cusparse_lowering:
with self.assertWarnsRegex(
sparse.CuSparseEfficiencyWarning,
"bcoo_dot_general GPU lowering currently does not support this "
"batch-mode computation.*"):
matmat_default_lowering_fallback = sp_matmat(lhs_bcoo, rhs)
warning_context = self.assertWarnsRegex(
sparse.CuSparseEfficiencyWarning,
"bcoo_dot_general GPU lowering currently does not support this "
"batch-mode computation.*")

with warning_context:
matmat_default_lowering_fallback = sp_matmat(lhs_bcoo, rhs)
self.assertArraysEqual(matmat_expected, matmat_default_lowering_fallback)

@unittest.skipIf(not GPU_LOWERING_ENABLED, "test requires cusparse/hipsparse")
Expand All @@ -1234,16 +1236,18 @@ def test_bcoo_dot_general_oob_and_unsorted_indices_cusparse(self):
sp_matmat = jit(partial(sparse_bcoo.bcoo_dot_general,
dimension_numbers=dimension_numbers_2d))

matmat_expected = lax.dot_general(lhs_mat_dense, rhs,
dimension_numbers=dimension_numbers_2d)
warning_context = contextlib.nullcontext()
if config.jax_bcoo_cusparse_lowering:
with self.assertWarnsRegex(
warning_context = self.assertWarnsRegex(
sparse.CuSparseEfficiencyWarning,
"bcoo_dot_general GPU lowering requires matrices with sorted indices*"):
matmat_unsorted_fallback = sp_matmat(lhs_mat_bcoo_unsorted, rhs)
"bcoo_dot_general GPU lowering requires matrices with sorted indices*")

with self.subTest(msg="2D"):
self.assertArraysEqual(matmat_expected, matmat_unsorted_fallback)
matmat_expected = lax.dot_general(lhs_mat_dense, rhs,
dimension_numbers=dimension_numbers_2d)
with self.subTest(msg="2D"):
with warning_context:
matmat_unsorted_fallback = sp_matmat(lhs_mat_bcoo_unsorted, rhs)
self.assertArraysEqual(matmat_expected, matmat_unsorted_fallback)

lhs_vec_dense = jnp.array([0, 1, 0, 2, 0], dtype=jnp.float32)
lhs_vec_bcoo = sparse.BCOO.fromdense(lhs_vec_dense, nse=5)
Expand All @@ -1260,14 +1264,10 @@ def test_bcoo_dot_general_oob_and_unsorted_indices_cusparse(self):
vecmat_expected = lax.dot_general(lhs_vec_dense, rhs,
dimension_numbers=dimension_numbers_1d)

if config.jax_bcoo_cusparse_lowering:
with self.assertWarnsRegex(
sparse.CuSparseEfficiencyWarning,
"bcoo_dot_general GPU lowering requires matrices with sorted indices*"):
with self.subTest(msg="1D"):
with warning_context:
vecmat_unsorted_fallback = sp_vecmat(lhs_vec_bcoo_unsorted, rhs)

with self.subTest(msg="1D"):
self.assertArraysEqual(vecmat_expected, vecmat_unsorted_fallback)
self.assertArraysEqual(vecmat_expected, vecmat_unsorted_fallback)

@jtu.sample_product(
props=_generate_bcoo_dot_general_properties(
Expand Down

0 comments on commit d8f0d22

Please sign in to comment.