[sparse] fix logic bug in GPU testcase #13616

Merged: 1 commit, Dec 12, 2022
tests/sparse_test.py (49 changes: 21 additions & 28 deletions)
@@ -151,7 +151,7 @@ def gpu_dense_conversion_warning_context(self, dtype):
       return self.assertWarns(sparse.CuSparseEfficiencyWarning)
     return contextlib.nullcontext()
 
-  def gpu_matmul_warning_context(self, dtype):
+  def gpu_matmul_dtype_warning_context(self, dtype):
     if jtu.device_under_test() == "gpu" and dtype not in [np.float32, np.float64, np.complex64, np.complex128]:
       return self.assertWarns(sparse.CuSparseEfficiencyWarning)
     return contextlib.nullcontext()
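This hunk renames the CSR/COO helper to gpu_matmul_dtype_warning_context, freeing the name gpu_matmul_warning_context for a differently-gated helper added to BCOOTest later in the diff. The underlying pattern is worth spelling out: return an assertWarns context when a warning is expected, and a nullcontext otherwise, so the body of the `with` statement stays identical on both paths. Below is a minimal, self-contained sketch of that pattern; the warning class, function, and test names are illustrative stand-ins, not from JAX.

import contextlib
import unittest
import warnings


class EfficiencyWarning(UserWarning):
  """Stand-in for sparse.CuSparseEfficiencyWarning."""


def maybe_slow_matmul(use_slow_path):
  # Warn only when the slow path is taken, mirroring a fallback lowering.
  if use_slow_path:
    warnings.warn("falling back to a slow path", EfficiencyWarning)
  return 42


class ConditionalWarningTest(unittest.TestCase):

  def warning_context(self, expect_warning):
    # assertWarns only when a warning is actually expected; nullcontext
    # keeps the `with` body unchanged on the quiet path.
    if expect_warning:
      return self.assertWarns(EfficiencyWarning)
    return contextlib.nullcontext()

  def test_both_paths(self):
    for expect_warning in (False, True):
      with self.warning_context(expect_warning):
        self.assertEqual(maybe_slow_matmul(expect_warning), 42)


if __name__ == "__main__":
  unittest.main()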
@@ -318,7 +318,7 @@ def test_csr_matvec(self, shape, dtype, transpose):
     matvec = lambda *args: sparse.csr_matvec(*args, shape=M.shape, transpose=transpose)
 
     self.assertAllClose(op(M) @ v, matvec(*args), rtol=MATMUL_TOL)
-    with self.gpu_matmul_warning_context(dtype):
+    with self.gpu_matmul_dtype_warning_context(dtype):
       self.assertAllClose(op(M) @ v, jit(matvec)(*args), rtol=MATMUL_TOL)
 
   @jtu.sample_product(
@@ -338,7 +338,7 @@ def test_csr_matmat(self, shape, dtype, transpose):
     matmat = lambda *args: sparse.csr_matmat(*args, shape=shape, transpose=transpose)
 
     self.assertAllClose(op(M) @ B, matmat(*args), rtol=MATMUL_TOL)
-    with self.gpu_matmul_warning_context(dtype):
+    with self.gpu_matmul_dtype_warning_context(dtype):
       self.assertAllClose(op(M) @ B, jit(matmat)(*args), rtol=MATMUL_TOL)
 
   @jtu.sample_product(
@@ -397,7 +397,7 @@ def test_coo_matvec(self, shape, dtype, transpose):
     matvec = lambda *args: sparse_coo._coo_matvec(*args, spinfo=sparse_coo.COOInfo(shape=M.shape, rows_sorted=True), transpose=transpose)
 
     self.assertAllClose(op(M) @ v, matvec(*args), rtol=MATMUL_TOL)
-    with self.gpu_matmul_warning_context(dtype):
+    with self.gpu_matmul_dtype_warning_context(dtype):
       self.assertAllClose(op(M) @ v, jit(matvec)(*args), rtol=MATMUL_TOL)
 
   @jtu.sample_product(
@@ -418,7 +418,7 @@ def test_coo_matmat(self, shape, dtype, transpose):
     matmat = lambda *args: sparse_coo._coo_matmat(*args, spinfo=sparse_coo.COOInfo(shape=shape, rows_sorted=True), transpose=transpose)
 
     self.assertAllClose(op(M) @ B, matmat(*args), rtol=MATMUL_TOL)
-    with self.gpu_matmul_warning_context(dtype):
+    with self.gpu_matmul_dtype_warning_context(dtype):
       self.assertAllClose(op(M) @ B, jit(matmat)(*args), rtol=MATMUL_TOL)
 
   def test_coo_matmat_layout(self):
@@ -654,6 +654,11 @@ def test_coo_matmul_ad(self, shape, dtype, bshape):
 
 class BCOOTest(sptu.SparseTestCase):
 
+  def gpu_matmul_warning_context(self, msg):
+    if GPU_LOWERING_ENABLED and config.jax_bcoo_cusparse_lowering:
+      return self.assertWarnsRegex(sparse.CuSparseEfficiencyWarning, msg)
+    return contextlib.nullcontext()
+
   def test_vmappable(self):
     """Test does not depend on batching rules of BCOO primitives."""
     M = jnp.arange(9).reshape((3, 3))
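The new BCOOTest helper added above differs from the dtype-gated one in two ways: it keys on GPU_LOWERING_ENABLED and the jax_bcoo_cusparse_lowering config flag rather than on dtype, and it matches the warning message with assertWarnsRegex instead of assertWarns. A hedged illustration of the regex-matching variant follows; the warning text and test names are made up for the example.

import unittest
import warnings


class RegexWarningTest(unittest.TestCase):

  def test_warning_message_matches(self):
    # assertWarnsRegex checks the warning category *and* that the message
    # matches the given regular expression (via re.search).
    with self.assertWarnsRegex(UserWarning, "batch-mode computation.*"):
      warnings.warn("lowering does not support this batch-mode computation")


if __name__ == "__main__":
  unittest.main()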
@@ -1166,10 +1171,8 @@ def f_sparse(lhs_bcoo, lhs, rhs):
     else:
       lhs_bcoo, lhs, rhs = args_maker()
     matmat_expected = f_dense(lhs_bcoo, lhs, rhs)
-    with self.assertWarnsRegex(
-        sparse.CuSparseEfficiencyWarning,
-        "bcoo_dot_general GPU lowering currently does not support this "
-        "batch-mode computation.*"):
+    with self.gpu_matmul_warning_context(
+        "bcoo_dot_general GPU lowering currently does not support this batch-mode computation.*"):
       matmat_default_lowering_fallback = jit(f_sparse)(lhs_bcoo, lhs, rhs)
     self.assertAllClose(matmat_expected, matmat_default_lowering_fallback,
                         atol=1E-6, rtol=1E-6)
@@ -1204,13 +1207,9 @@ def test_bcoo_batched_matmat_default_lowering(
     sp_matmat = jit(partial(sparse_bcoo.bcoo_dot_general,
                             dimension_numbers=dimension_numbers))
 
-    if config.jax_bcoo_cusparse_lowering:
-      with self.assertWarnsRegex(
-          sparse.CuSparseEfficiencyWarning,
-          "bcoo_dot_general GPU lowering currently does not support this "
-          "batch-mode computation.*"):
-        matmat_default_lowering_fallback = sp_matmat(lhs_bcoo, rhs)
-
+    with self.gpu_matmul_warning_context(
+        "bcoo_dot_general GPU lowering currently does not support this batch-mode computation.*"):
+      matmat_default_lowering_fallback = sp_matmat(lhs_bcoo, rhs)
     self.assertArraysEqual(matmat_expected, matmat_default_lowering_fallback)
 
   @unittest.skipIf(not GPU_LOWERING_ENABLED, "test requires cusparse/hipsparse")
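This hunk is where the logic bug of the PR title shows up most clearly: in the old code, matmat_default_lowering_fallback was assigned only inside the `if config.jax_bcoo_cusparse_lowering:` block, while the assertArraysEqual below it ran unconditionally, so disabling the flag left the name unbound. Routing the call through gpu_matmul_warning_context runs the matmul on every configuration and asserts the warning only when the cuSPARSE lowering is active. A minimal reproduction of the removed hazard (flag and values are illustrative):

def demo(flag_enabled):
  if flag_enabled:
    result = 21 * 2  # bound only on this branch
  return result      # UnboundLocalError when flag_enabled is False


print(demo(True))    # 42
try:
  demo(False)
except UnboundLocalError as err:
  print("bug:", err)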
@@ -1236,14 +1235,11 @@ def test_bcoo_dot_general_oob_and_unsorted_indices_cusparse(self):
 
     matmat_expected = lax.dot_general(lhs_mat_dense, rhs,
                                       dimension_numbers=dimension_numbers_2d)
-    if config.jax_bcoo_cusparse_lowering:
-      with self.assertWarnsRegex(
-          sparse.CuSparseEfficiencyWarning,
+    with self.subTest(msg="2D"):
+      with self.gpu_matmul_warning_context(
           "bcoo_dot_general GPU lowering requires matrices with sorted indices*"):
         matmat_unsorted_fallback = sp_matmat(lhs_mat_bcoo_unsorted, rhs)
-
-      with self.subTest(msg="2D"):
-        self.assertArraysEqual(matmat_expected, matmat_unsorted_fallback)
+      self.assertArraysEqual(matmat_expected, matmat_unsorted_fallback)
 
     lhs_vec_dense = jnp.array([0, 1, 0, 2, 0], dtype=jnp.float32)
     lhs_vec_bcoo = sparse.BCOO.fromdense(lhs_vec_dense, nse=5)
@@ -1260,14 +1256,11 @@ def test_bcoo_dot_general_oob_and_unsorted_indices_cusparse(self):
     vecmat_expected = lax.dot_general(lhs_vec_dense, rhs,
                                       dimension_numbers=dimension_numbers_1d)
 
-    if config.jax_bcoo_cusparse_lowering:
-      with self.assertWarnsRegex(
-          sparse.CuSparseEfficiencyWarning,
+    with self.subTest(msg="1D"):
+      with self.gpu_matmul_warning_context(
           "bcoo_dot_general GPU lowering requires matrices with sorted indices*"):
         vecmat_unsorted_fallback = sp_vecmat(lhs_vec_bcoo_unsorted, rhs)
-
-      with self.subTest(msg="1D"):
-        self.assertArraysEqual(vecmat_expected, vecmat_unsorted_fallback)
+      self.assertArraysEqual(vecmat_expected, vecmat_unsorted_fallback)
 
   @jtu.sample_product(
     props=_generate_bcoo_dot_general_properties(
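The last two hunks also move each assertArraysEqual inside its subTest block, so the "2D" and "1D" cases are reported as separate sub-results and a failure in one does not abort the other. A short sketch of how subTest labels sub-results independently (values here are illustrative):

import unittest


class SubTestDemo(unittest.TestCase):

  def test_shapes(self):
    # Each labelled block is reported on its own, so a "1D" failure does
    # not hide the "2D" outcome.
    for label, n in [("2D", 4), ("1D", 1)]:
      with self.subTest(msg=label):
        self.assertEqual(n * n, n ** 2)


if __name__ == "__main__":
  unittest.main()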