diff --git a/paddle/fluid/platform/dynload/cusparse.cc b/paddle/fluid/platform/dynload/cusparse.cc
index da93455e8bc7d..756737c1a169f 100644
--- a/paddle/fluid/platform/dynload/cusparse.cc
+++ b/paddle/fluid/platform/dynload/cusparse.cc
@@ -28,6 +28,10 @@ CUSPARSE_ROUTINE_EACH(DEFINE_WRAP);
 CUSPARSE_ROUTINE_EACH_R2(DEFINE_WRAP);
 #endif
 
+#ifdef CUSPARSE_ROUTINE_EACH_R3
+CUSPARSE_ROUTINE_EACH_R3(DEFINE_WRAP);
+#endif
+
 }  // namespace dynload
 }  // namespace platform
 }  // namespace paddle
diff --git a/paddle/phi/api/yaml/sparse_api.yaml b/paddle/phi/api/yaml/sparse_api.yaml
index e32ce5b21540b..e816824b82f72 100644
--- a/paddle/phi/api/yaml/sparse_api.yaml
+++ b/paddle/phi/api/yaml/sparse_api.yaml
@@ -297,7 +297,7 @@
   args : (Tensor x, Tensor y, Tensor mask)
   output : Tensor(out)
   kernel :
-    func : csr_masked_matmul{dense, dense, sparse_csr -> sparse_csr}
+    func : masked_matmul_csr{dense, dense, sparse_csr -> sparse_csr}
     layout : x
   backward: masked_matmul_grad
 
@@ -305,10 +305,10 @@
   args : (Tensor x, Tensor y)
   output : Tensor(out)
   kernel :
-    func : csr_dense_matmul{sparse_csr, dense -> dense},
-           csr_csr_matmul{sparse_csr, sparse_csr -> sparse_csr},
-           coo_dense_matmul{sparse_coo, dense -> dense},
-           coo_coo_matmul{sparse_coo, sparse_coo -> sparse_coo}
+    func : matmul_csr_dense {sparse_csr, dense -> dense},
+           matmul_csr_csr {sparse_csr, sparse_csr -> sparse_csr},
+           matmul_coo_dense {sparse_coo, dense -> dense},
+           matmul_coo_coo {sparse_coo, sparse_coo -> sparse_coo}
     layout : x
   backward: matmul_grad
 
diff --git a/paddle/phi/api/yaml/sparse_bw_api.yaml b/paddle/phi/api/yaml/sparse_bw_api.yaml
index 6e3a82a22bcfc..68e6020ac3626 100644
--- a/paddle/phi/api/yaml/sparse_bw_api.yaml
+++ b/paddle/phi/api/yaml/sparse_bw_api.yaml
@@ -125,14 +125,17 @@
   args : (Tensor x, Tensor y, Tensor out_grad)
   output : Tensor(x_grad), Tensor(y_grad)
   kernel :
-    func : csr_masked_matmul_grad{dense, dense, sparse_csr -> dense, dense}
+    func : masked_matmul_csr_grad{dense, dense, sparse_csr -> dense, dense}
 
 - backward_api : matmul_grad
   forward : matmul(Tensor x, Tensor y) -> Tensor(out)
   args : (Tensor x, Tensor y, Tensor out_grad)
   output : Tensor(x_grad), Tensor(y_grad)
   kernel :
-    func : csr_dense_matmul_grad{sparse_csr, dense, dense -> sparse_csr, dense}
+    func : matmul_csr_dense_grad {sparse_csr, dense, dense -> sparse_csr, dense},
+           matmul_csr_csr_grad {sparse_csr, sparse_csr, sparse_csr -> sparse_csr, sparse_csr},
+           matmul_coo_dense_grad {sparse_coo, dense, dense -> sparse_coo, dense},
+           matmul_coo_coo_grad {sparse_coo, sparse_coo, sparse_coo -> sparse_coo, sparse_coo}
 
 - backward_api : multiply_grad
   forward : multiply(Tensor x, Tensor y) -> Tensor(out)
diff --git a/paddle/phi/backends/dynload/cusparse.cc b/paddle/phi/backends/dynload/cusparse.cc
index 013211064b8e4..ce8f87dc3cdfa 100644
--- a/paddle/phi/backends/dynload/cusparse.cc
+++ b/paddle/phi/backends/dynload/cusparse.cc
@@ -30,5 +30,9 @@ CUSPARSE_ROUTINE_EACH(DEFINE_WRAP);
 CUSPARSE_ROUTINE_EACH_R2(DEFINE_WRAP);
 #endif
 
+#ifdef CUSPARSE_ROUTINE_EACH_R3
+CUSPARSE_ROUTINE_EACH_R3(DEFINE_WRAP);
+#endif
+
 }  // namespace dynload
 }  // namespace phi
diff --git a/paddle/phi/kernels/funcs/sparse/sparse_blas_impl.cu.h b/paddle/phi/kernels/funcs/sparse/sparse_blas_impl.cu.h
index 3d92674c92d6e..9f7be26857bdb 100644
--- a/paddle/phi/kernels/funcs/sparse/sparse_blas_impl.cu.h
+++ b/paddle/phi/kernels/funcs/sparse/sparse_blas_impl.cu.h
@@ -298,6 +298,7 @@ class CuSparseDnVecDescriptor {
   cusparseDnVecDescr_t descriptor_;
 };
 
+/************* SPARSE*DENSE->DENSE MATMUL ************/
template <> template void SparseBlas::SPMM(bool transa, @@ -345,6 +346,7 @@ void SparseBlas::SPMM(bool transa, }); } +/************* SPARSE*DENSE->DENSE MV ************/ template <> template void SparseBlas::SPMV(bool transa, @@ -389,6 +391,7 @@ void SparseBlas::SPMV(bool transa, }); } +/************* DENSE*DENSE->SPARSE MATMUL ************/ #if CUDA_VERSION >= 11030 template <> template diff --git a/paddle/phi/kernels/sparse/cpu/matmul_grad_kernel.cc b/paddle/phi/kernels/sparse/cpu/matmul_grad_kernel.cc index cd1665b66431b..2586976b7636c 100644 --- a/paddle/phi/kernels/sparse/cpu/matmul_grad_kernel.cc +++ b/paddle/phi/kernels/sparse/cpu/matmul_grad_kernel.cc @@ -22,7 +22,7 @@ namespace sparse { // TODO(zhouwei25): implement CPU backward kernel of " CSR @ DENSE -> DENSE" template -void CsrDenseMatmulGradKernel(const Context& dev_ctx, +void MatmulCsrDenseGradKernel(const Context& dev_ctx, const SparseCsrTensor& x, const DenseTensor& y, const DenseTensor& dout, @@ -34,7 +34,7 @@ void CsrDenseMatmulGradKernel(const Context& dev_ctx, // TODO(zhouwei25): implement CPU kernel of " DENSE @ DENSE * CSR_MASK -> CSR" template -void CsrMaskedMatmulGradKernel(const Context& dev_ctx, +void MaskedMatmulCsrGradKernel(const Context& dev_ctx, const DenseTensor& x, const DenseTensor& y, const SparseCsrTensor& dout, @@ -47,18 +47,18 @@ void CsrMaskedMatmulGradKernel(const Context& dev_ctx, } // namespace sparse } // namespace phi -PD_REGISTER_KERNEL(csr_dense_matmul_grad, +PD_REGISTER_KERNEL(matmul_csr_dense_grad, CPU, ALL_LAYOUT, - phi::sparse::CsrDenseMatmulGradKernel, + phi::sparse::MatmulCsrDenseGradKernel, float, double) { kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_CSR); } -PD_REGISTER_KERNEL(csr_masked_matmul_grad, +PD_REGISTER_KERNEL(masked_matmul_csr_grad, CPU, ALL_LAYOUT, - phi::sparse::CsrMaskedMatmulGradKernel, + phi::sparse::MaskedMatmulCsrGradKernel, float, double) {} diff --git a/paddle/phi/kernels/sparse/cpu/matmul_kernel.cc b/paddle/phi/kernels/sparse/cpu/matmul_kernel.cc index 0818b8e900a05..8db0ccfd575e5 100644 --- a/paddle/phi/kernels/sparse/cpu/matmul_kernel.cc +++ b/paddle/phi/kernels/sparse/cpu/matmul_kernel.cc @@ -22,7 +22,7 @@ namespace sparse { // TODO(zhouwei25): implement CPU kernel of " CSR @ DENSE -> DENSE" template -void CsrDenseMatmulKernel(const Context& dev_ctx, +void MatmulCsrDenseKernel(const Context& dev_ctx, const SparseCsrTensor& x, const DenseTensor& y, DenseTensor* out) { @@ -32,7 +32,7 @@ void CsrDenseMatmulKernel(const Context& dev_ctx, // TODO(zhouwei25): implement CPU kernel of " DENSE @ DENSE * CSR_MASK -> CSR" template -void CsrMaskedMatmulKernel(const Context& dev_ctx, +void MaskedMatmulCsrKernel(const Context& dev_ctx, const DenseTensor& x, const DenseTensor& y, const SparseCsrTensor& mask, @@ -44,18 +44,18 @@ void CsrMaskedMatmulKernel(const Context& dev_ctx, } // namespace sparse } // namespace phi -PD_REGISTER_KERNEL(csr_dense_matmul, +PD_REGISTER_KERNEL(matmul_csr_dense, CPU, ALL_LAYOUT, - phi::sparse::CsrDenseMatmulKernel, + phi::sparse::MatmulCsrDenseKernel, float, double) { kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_CSR); } -PD_REGISTER_KERNEL(csr_masked_matmul, +PD_REGISTER_KERNEL(masked_matmul_csr, CPU, ALL_LAYOUT, - phi::sparse::CsrMaskedMatmulKernel, + phi::sparse::MaskedMatmulCsrKernel, float, double) {} diff --git a/paddle/phi/kernels/sparse/empty_kernel.cc b/paddle/phi/kernels/sparse/empty_kernel.cc index c1706b9919d90..115611a272d94 100644 --- a/paddle/phi/kernels/sparse/empty_kernel.cc +++ 
b/paddle/phi/kernels/sparse/empty_kernel.cc @@ -26,37 +26,27 @@ template void EmptyLikeCooKernel(const Context& dev_ctx, const SparseCooTensor& x, SparseCooTensor* out) { - const DenseTensor& x_indices = x.non_zero_indices(); + out->set_dims(x.dims()); + *(out->mutable_non_zero_indices()) = x.non_zero_indices(); + const DenseTensor& x_values = x.non_zero_elements(); - DenseTensor* out_indices = out->mutable_non_zero_indices(); DenseTensor* out_values = out->mutable_non_zero_elements(); - - phi::Copy(dev_ctx, x_indices, dev_ctx.GetPlace(), false, out_indices); - out_values->Resize(x_values.dims()); dev_ctx.template Alloc(out_values); - - out->set_dims(x.dims()); } template void EmptyLikeCsrKernel(const Context& dev_ctx, const SparseCsrTensor& x, SparseCsrTensor* out) { - const DenseTensor& x_crows = x.non_zero_crows(); - const DenseTensor& x_cols = x.non_zero_cols(); + out->set_dims(x.dims()); + *(out->mutable_non_zero_crows()) = x.non_zero_crows(); + *(out->mutable_non_zero_cols()) = x.non_zero_cols(); + const DenseTensor& x_values = x.non_zero_elements(); - DenseTensor* out_crows = out->mutable_non_zero_crows(); - DenseTensor* out_cols = out->mutable_non_zero_cols(); DenseTensor* out_values = out->mutable_non_zero_elements(); - - phi::Copy(dev_ctx, x_crows, dev_ctx.GetPlace(), false, out_crows); - phi::Copy(dev_ctx, x_cols, dev_ctx.GetPlace(), false, out_cols); - out_values->Resize(x_values.dims()); dev_ctx.template Alloc(out_values); - - out->set_dims(x.dims()); } } // namespace sparse diff --git a/paddle/phi/kernels/sparse/gpu/matmul_grad_kernel.cu b/paddle/phi/kernels/sparse/gpu/matmul_grad_kernel.cu index d5c128fea6f29..c4bb66827e35a 100644 --- a/paddle/phi/kernels/sparse/gpu/matmul_grad_kernel.cu +++ b/paddle/phi/kernels/sparse/gpu/matmul_grad_kernel.cu @@ -22,13 +22,52 @@ limitations under the License. */ #include "paddle/phi/kernels/empty_kernel.h" #include "paddle/phi/kernels/funcs/sparse/sparse_blas.h" #include "paddle/phi/kernels/sparse/empty_kernel.h" +#include "paddle/phi/kernels/sparse/sparse_utils_kernel.h" #include "paddle/phi/kernels/transpose_kernel.h" namespace phi { namespace sparse { template -void CsrDenseMatmulGradKernel(const Context& dev_ctx, +void MatmulCooDenseGradKernel(const Context& dev_ctx, + const SparseCooTensor& x, + const DenseTensor& y, + const DenseTensor& dout, + SparseCooTensor* dx, + DenseTensor* dy) { +#if CUDA_VERSION >= 11030 + auto sparse_blas = phi::funcs::sparse::GetSparseBlas(dev_ctx); + + // dx{SparseCoo} = dout{Dense} * y'{Dense} + if (dx) { + // 'cusparseSDDMM' only support CSR now, so use COO->CSR->COO, + // which will increase some expenses. 
+ EmptyLikeCooKernel(dev_ctx, x, dx); + SparseCsrTensor dx_csr = SparseCooToCsr(dev_ctx, *dx); + sparse_blas.SDDMM( + false, true, static_cast(1), dout, y, static_cast(0), &dx_csr); + SparseCsrToCooKernel(dev_ctx, dx_csr, dx); + } + + // dy{Dense} = x'{SparseCoo} * dout{Dense} + if (dy) { + MetaTensor meta_dy(dy); + meta_dy.set_dims(y.dims()); + meta_dy.set_dtype(y.dtype()); + dev_ctx.template Alloc(dy); + + sparse_blas.SPMM( + true, false, static_cast(1), x, dout, static_cast(0), dy); + } +#else + PADDLE_THROW(phi::errors::Unimplemented( + "backward of 'sparse.matmul' use cusparseSDDMM, which is supported from " + "CUDA 11.3")); +#endif +} + +template +void MatmulCsrDenseGradKernel(const Context& dev_ctx, const SparseCsrTensor& x, const DenseTensor& y, const DenseTensor& dout, @@ -66,7 +105,7 @@ void CsrDenseMatmulGradKernel(const Context& dev_ctx, } template -void CsrMaskedMatmulGradKernel(const Context& dev_ctx, +void MaskedMatmulCsrGradKernel(const Context& dev_ctx, const DenseTensor& x, const DenseTensor& y, const SparseCsrTensor& dout, @@ -119,18 +158,27 @@ void CsrMaskedMatmulGradKernel(const Context& dev_ctx, } // namespace sparse } // namespace phi -PD_REGISTER_KERNEL(csr_dense_matmul_grad, +PD_REGISTER_KERNEL(matmul_coo_dense_grad, + GPU, + ALL_LAYOUT, + phi::sparse::MatmulCooDenseGradKernel, + float, + double) { + kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO); +} + +PD_REGISTER_KERNEL(matmul_csr_dense_grad, GPU, ALL_LAYOUT, - phi::sparse::CsrDenseMatmulGradKernel, + phi::sparse::MatmulCsrDenseGradKernel, float, double) { kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_CSR); } -PD_REGISTER_KERNEL(csr_masked_matmul_grad, +PD_REGISTER_KERNEL(masked_matmul_csr_grad, GPU, ALL_LAYOUT, - phi::sparse::CsrMaskedMatmulGradKernel, + phi::sparse::MaskedMatmulCsrGradKernel, float, double) {} diff --git a/paddle/phi/kernels/sparse/gpu/matmul_kernel.cu b/paddle/phi/kernels/sparse/gpu/matmul_kernel.cu index 69cd4bac0c763..3adbce0dd17df 100644 --- a/paddle/phi/kernels/sparse/gpu/matmul_kernel.cu +++ b/paddle/phi/kernels/sparse/gpu/matmul_kernel.cu @@ -31,11 +31,11 @@ limitations under the License. 
*/ namespace phi { namespace sparse { -template -void CsrDenseMatmulKernel(const Context& dev_ctx, - const SparseCsrTensor& x, - const DenseTensor& y, - DenseTensor* out) { +template +void MatmulKernelImpl(const Context& dev_ctx, + const TensorType& x, + const DenseTensor& y, + DenseTensor* out) { #if CUDA_VERSION >= 11000 std::vector xdim_vec = phi::vectorize(x.dims()); std::vector ydim_vec = phi::vectorize(y.dims()); @@ -91,7 +91,23 @@ void CsrDenseMatmulKernel(const Context& dev_ctx, } template -void CsrMaskedMatmulKernel(const Context& dev_ctx, +void MatmulCooDenseKernel(const Context& dev_ctx, + const SparseCooTensor& x, + const DenseTensor& y, + DenseTensor* out) { + MatmulKernelImpl(dev_ctx, x, y, out); +} + +template +void MatmulCsrDenseKernel(const Context& dev_ctx, + const SparseCsrTensor& x, + const DenseTensor& y, + DenseTensor* out) { + MatmulKernelImpl(dev_ctx, x, y, out); +} + +template +void MaskedMatmulCsrKernel(const Context& dev_ctx, const DenseTensor& x, const DenseTensor& y, const SparseCsrTensor& mask, @@ -176,18 +192,27 @@ void CsrMaskedMatmulKernel(const Context& dev_ctx, } // namespace sparse } // namespace phi -PD_REGISTER_KERNEL(csr_dense_matmul, +PD_REGISTER_KERNEL(matmul_csr_dense, GPU, ALL_LAYOUT, - phi::sparse::CsrDenseMatmulKernel, + phi::sparse::MatmulCsrDenseKernel, float, double) { kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_CSR); } -PD_REGISTER_KERNEL(csr_masked_matmul, +PD_REGISTER_KERNEL(matmul_coo_dense, + GPU, + ALL_LAYOUT, + phi::sparse::MatmulCooDenseKernel, + float, + double) { + kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO); +} + +PD_REGISTER_KERNEL(masked_matmul_csr, GPU, ALL_LAYOUT, - phi::sparse::CsrMaskedMatmulKernel, + phi::sparse::MaskedMatmulCsrKernel, float, double) {} diff --git a/paddle/phi/kernels/sparse/impl/unary_kernel_impl.h b/paddle/phi/kernels/sparse/impl/unary_kernel_impl.h index 231fc551f4788..2639753266db6 100644 --- a/paddle/phi/kernels/sparse/impl/unary_kernel_impl.h +++ b/paddle/phi/kernels/sparse/impl/unary_kernel_impl.h @@ -134,7 +134,7 @@ void CastCooKernel(const Context& dev_ctx, DenseTensor* out_values = out->mutable_non_zero_elements(); if (index_dtype == DataType::UNDEFINED) { - phi::Copy(dev_ctx, x_indices, dev_ctx.GetPlace(), false, out_indices); + *out_indices = x_indices; } else { phi::MetaTensor meta(out_indices); meta.set_dims(x_indices.dims()); @@ -172,8 +172,8 @@ void CastCsrKernel(const Context& dev_ctx, DenseTensor* out_values = out->mutable_non_zero_elements(); if (index_dtype == DataType::UNDEFINED) { - phi::Copy(dev_ctx, x_crows, dev_ctx.GetPlace(), false, out_crows); - phi::Copy(dev_ctx, x_cols, dev_ctx.GetPlace(), false, out_cols); + *out_crows = x_crows; + *out_cols = x_cols; } else { phi::MetaTensor crows_meta(out_crows); crows_meta.set_dims(x_crows.dims()); diff --git a/paddle/phi/kernels/sparse/matmul_grad_kernel.h b/paddle/phi/kernels/sparse/matmul_grad_kernel.h index 787691f3515d6..4acb7bb7e1eb5 100644 --- a/paddle/phi/kernels/sparse/matmul_grad_kernel.h +++ b/paddle/phi/kernels/sparse/matmul_grad_kernel.h @@ -23,16 +23,16 @@ namespace sparse { // TODO(zhouwei25): implement Backward of " COO @ COO -> COO" template -void CooCooMatmulGradKernel(const Context& dev_ctx, +void MatmulCooCooGradKernel(const Context& dev_ctx, const SparseCooTensor& x, const SparseCooTensor& y, const SparseCooTensor& dout, SparseCooTensor* dx, SparseCooTensor* dy); -// TODO(zhouwei25): implement Backward of " COO @ DENSE -> DENSE" +// Backward of " COO @ DENSE -> DENSE" template -void 
CooDenseMatmulGradKernel(const Context& dev_ctx, +void MatmulCooDenseGradKernel(const Context& dev_ctx, const SparseCooTensor& x, const DenseTensor& y, const DenseTensor& dout, @@ -41,7 +41,7 @@ void CooDenseMatmulGradKernel(const Context& dev_ctx, // TODO(zhouwei25): implement Backward of " CSR @ CSR -> CSR" template -void CsrCsrMatmulGradKernel(const Context& dev_ctx, +void MatmulCsrCsrGradKernel(const Context& dev_ctx, const SparseCsrTensor& x, const SparseCsrTensor& y, const SparseCsrTensor& dout, @@ -50,7 +50,7 @@ void CsrCsrMatmulGradKernel(const Context& dev_ctx, /* Backward of "CSR @ DENSE -> DENSE" */ template -void CsrDenseMatmulGradKernel(const Context& dev_ctx, +void MatmulCsrDenseGradKernel(const Context& dev_ctx, const SparseCsrTensor& x, const DenseTensor& y, const DenseTensor& dout, @@ -59,7 +59,7 @@ void CsrDenseMatmulGradKernel(const Context& dev_ctx, /* Backward of "DENSE @ DENSE * CSR_MASK -> CSR" */ template -void CsrMaskedMatmulGradKernel(const Context& dev_ctx, +void MaskedMatmulCsrGradKernel(const Context& dev_ctx, const DenseTensor& x, const DenseTensor& y, const SparseCsrTensor& dout, diff --git a/paddle/phi/kernels/sparse/matmul_kernel.h b/paddle/phi/kernels/sparse/matmul_kernel.h index d9093a020c207..a261bbf3cd3f7 100644 --- a/paddle/phi/kernels/sparse/matmul_kernel.h +++ b/paddle/phi/kernels/sparse/matmul_kernel.h @@ -23,35 +23,35 @@ namespace sparse { // TODO(zhouwei25): implement " COO @ COO -> COO" template -void CooCooMatmulKernel(const Context& dev_ctx, +void MatmulCooCooKernel(const Context& dev_ctx, const SparseCooTensor& x, const SparseCooTensor& y, SparseCooTensor* out); -// TODO(zhouwei25): implement " COO @ DENSE -> DENSE" +/* COO @ DENSE -> DENSE */ template -void CooDenseMatmulKernel(const Context& dev_ctx, +void MatmulCooDenseKernel(const Context& dev_ctx, const SparseCooTensor& x, const DenseTensor& y, DenseTensor* out); // TODO(zhouwei25): implement " CSR @ CSR -> CSR" template -void CsrCsrMatmulKernel(const Context& dev_ctx, +void MatmulCsrCsrKernel(const Context& dev_ctx, const SparseCsrTensor& x, const SparseCsrTensor& y, SparseCsrTensor* out); /* CSR @ DENSE -> DENSE */ template -void CsrDenseMatmulKernel(const Context& dev_ctx, +void MatmulCsrDenseKernel(const Context& dev_ctx, const SparseCsrTensor& x, const DenseTensor& y, DenseTensor* out); /* DENSE @ DENSE * CSR_MASK -> CSR */ template -void CsrMaskedMatmulKernel(const Context& dev_ctx, +void MaskedMatmulCsrKernel(const Context& dev_ctx, const DenseTensor& x, const DenseTensor& y, const SparseCsrTensor& mask, diff --git a/python/paddle/fluid/tests/unittests/test_sparse_matmul_op.py b/python/paddle/fluid/tests/unittests/test_sparse_matmul_op.py index 96adf959b2b6e..8986d4a7ef5d2 100644 --- a/python/paddle/fluid/tests/unittests/test_sparse_matmul_op.py +++ b/python/paddle/fluid/tests/unittests/test_sparse_matmul_op.py @@ -13,8 +13,6 @@ # limitations under the License. 
import paddle -from paddle.fluid.framework import _test_eager_guard - import numpy as np import scipy import scipy.sparse as sp @@ -22,7 +20,7 @@ import os import re -np.random.seed(2022) +paddle.set_default_dtype('float64') def get_cuda_version(): @@ -37,153 +35,115 @@ def get_cuda_version(): return -1 -@unittest.skipIf( - not paddle.is_compiled_with_cuda() or get_cuda_version() < 11000, - "paddle is not compiled with CUDA and cuda version need to >= 11.0") -class TestCsrDenseMatmul2D(unittest.TestCase): - # x: csr, y: dense, out: dense - def test_matmul(self): - with _test_eager_guard(): - mask = np.random.rand(10, 12) < 0.2 - np_x = np.random.rand(10, 12) * mask - - np_csr = sp.csr_matrix(np_x) - np_dense = np.random.rand(12, 6) - np_out = np_csr @ np_dense - - np_out_grad = np.ones([10, 6]) - - # dx(csr) = dout(dense) * y'(dense) * mask - np_csr_grad = sp.csr_matrix( - np.matmul(np_out_grad, np_dense.transpose(1, 0)) * mask) - # dy(dense) = x'(csr) * dout(dense) - np_dense_grad = np_csr.transpose() @ np_out_grad - - csr = paddle.to_tensor(np_x, stop_gradient=False).to_sparse_csr() - dense = paddle.to_tensor(np_dense, stop_gradient=False) - out = paddle.incubate.sparse.matmul(csr, dense) - - self.assertTrue(np.allclose(np_out, out.numpy())) - - if get_cuda_version() >= 11030: - out.backward() - self.assertTrue( - np.allclose(np_csr_grad.indptr, - csr.grad.crows().numpy())) - self.assertTrue( - np.allclose(np_csr_grad.indices, - csr.grad.cols().numpy())) - self.assertTrue( - np.allclose(np_csr_grad.data, - csr.grad.values().numpy())) - - self.assertTrue(np.allclose(np_dense_grad, dense.grad.numpy())) - - -@unittest.skipIf( - not paddle.is_compiled_with_cuda() or get_cuda_version() < 11030, - "paddle is not compiled with CUDA and cuda version need to >= 11.3") -class TestCsrMaskedMatmul2D(unittest.TestCase): - # x: dense, y: dense, out: csr - def test_matmul(self): - with _test_eager_guard(): - np_mask = np.random.rand(10, 6) < 0.2 - - np_x = np.random.rand(10, 12) - np_y = np.random.rand(12, 6) - np_out = sp.csr_matrix(np.matmul(np_x, np_y) * np_mask) - - np_out_grad = sp.csr_matrix(np.ones([10, 6]) * np_mask) - # dx(dense) = dout(csr) * y'(dense) - np_x_grad = np_out_grad @ np_y.transpose(1, 0) - # dy(dense) = x'(dense) * dout(csr) -> dy'(dense) = dout'(csr) * x(dense) - np_y_grad = (np_out_grad.transpose() @ np_x).transpose(1, 0) - - x = paddle.to_tensor(np_x, stop_gradient=False) - y = paddle.to_tensor(np_y, stop_gradient=False) - mask = paddle.to_tensor(np.ones([10, 6]) * np_mask).to_sparse_csr() - out = paddle.incubate.sparse.masked_matmul(x, y, mask) - - self.assertTrue(np.allclose(np_out.indptr, out.crows().numpy())) - self.assertTrue(np.allclose(np_out.indices, out.cols().numpy())) - self.assertTrue(np.allclose(np_out.data, out.values().numpy())) - - out.backward() - self.assertTrue(np.allclose(out.is_sparse_csr(), True)) - self.assertTrue(np.allclose(np_x_grad, x.grad.numpy())) - self.assertTrue(np.allclose(np_y_grad, y.grad.numpy())) - - -@unittest.skipIf( - not paddle.is_compiled_with_cuda() or get_cuda_version() < 11070, - "paddle is not compiled with CUDA and cuda version need to >= 11.7") -class TestCsrDenseMatmul3D(unittest.TestCase): - # x: csr, y: dense, out: dense - def test_matmul(self): - with _test_eager_guard(): - paddle.set_default_dtype('float32') - origin_x = paddle.rand([16, 16, 12]) - mask = paddle.randint(0, 2, [16, 12]) - origin_x = origin_x * mask - origin_y = paddle.rand([16, 12, 10]) - - dense_x = origin_x.detach() - dense_x.stop_gradient = False - dense_y 
= origin_y.detach() - dense_y.stop_gradient = False - dense_out = paddle.matmul(dense_x, dense_y) - dense_out.backward() - +class TestMatmul(unittest.TestCase): + # x: sparse, y: dense, out: dense + def check_result(self, x_shape, y_shape, format): + if len(x_shape) == 3: + mask = paddle.randint(0, 2, [x_shape[-2], x_shape[-1]]) + else: + mask = paddle.randint(0, 2, x_shape) + origin_x = paddle.rand(x_shape) * mask + origin_y = paddle.rand(y_shape) + + dense_x = origin_x.detach() + dense_x.stop_gradient = False + dense_y = origin_y.detach() + dense_y.stop_gradient = False + dense_out = paddle.matmul(dense_x, dense_y) + + if format == "coo": + sp_x = origin_x.detach().to_sparse_coo(len(x_shape)) + else: sp_x = origin_x.detach().to_sparse_csr() - sp_x.stop_gradient = False - sp_y = origin_y.detach() - sp_y.stop_gradient = False - sp_out = paddle.incubate.sparse.matmul(sp_x, sp_y) - sp_out.backward() + sp_x.stop_gradient = False + sp_y = origin_y.detach() + sp_y.stop_gradient = False + sp_out = paddle.incubate.sparse.matmul(sp_x, sp_y) - self.assertTrue(np.allclose(sp_out.numpy(), dense_out.numpy())) + self.assertTrue(np.allclose(sp_out.numpy(), dense_out.numpy())) + if get_cuda_version() >= 11030: + dense_out.backward() + sp_out.backward() self.assertTrue( np.allclose(sp_x.grad.to_dense().numpy(), (dense_x.grad * mask).numpy())) self.assertTrue(np.allclose(sp_y.grad.numpy(), dense_y.grad.numpy())) - -@unittest.skipIf( - not paddle.is_compiled_with_cuda() or get_cuda_version() < 11070, - "paddle is not compiled with CUDA and cuda version need to >= 11.7") -class TestCsrMaskedMatmul3D(unittest.TestCase): - # x: dense, y: dense, out: csr - def test_matmul(self): - with _test_eager_guard(): - paddle.set_default_dtype('float64') - origin_x = paddle.rand([16, 16, 12]) - origin_y = paddle.rand([16, 12, 10]) - - mask = paddle.randint(0, 2, [16, 10]) - - dense_x = origin_x.detach() - dense_x.stop_gradient = False - dense_y = origin_y.detach() - dense_y.stop_gradient = False - dense_out = paddle.matmul(dense_x, dense_y) - dense_out = dense_out * mask - dense_out.backward() - - sp_x = origin_x.detach() - sp_x.stop_gradient = False - sp_y = origin_y.detach() - sp_y.stop_gradient = False - sp_out = paddle.incubate.sparse.masked_matmul( - sp_x, sp_y, dense_out.to_sparse_csr()) - sp_out.backward() - - self.assertTrue( - np.allclose(sp_out.to_dense().numpy(), dense_out.numpy())) - self.assertTrue(np.allclose(sp_x.grad.numpy(), - dense_x.grad.numpy())) - self.assertTrue(np.allclose(sp_y.grad.numpy(), - dense_y.grad.numpy())) + @unittest.skipIf(not paddle.is_compiled_with_cuda() + or get_cuda_version() < 11000, "only support cuda>=11.0") + def test_matmul_2d(self): + self.check_result([16, 12], [12, 10], 'coo') + self.check_result([16, 12], [12, 10], 'csr') + + @unittest.skipIf(not paddle.is_compiled_with_cuda() + or get_cuda_version() < 11070, "only support cuda>=11.7") + def test_matmul_3d(self): + self.check_result([8, 16, 12], [8, 12, 10], 'coo') + self.check_result([8, 16, 12], [8, 12, 10], 'csr') + + +class TestMaskedMatmul(unittest.TestCase): + # x: dense, y: dense, out: sparse_`csr + @unittest.skipIf(not paddle.is_compiled_with_cuda() + or get_cuda_version() < 11030, + "only support on cuda>=11.3") + def test_masked_matmul_2d(self): + np_mask = np.random.rand(10, 6) < 0.2 + + np_x = np.random.rand(10, 12) + np_y = np.random.rand(12, 6) + np_out = sp.csr_matrix(np.matmul(np_x, np_y) * np_mask) + + np_out_grad = sp.csr_matrix(np.ones([10, 6]) * np_mask) + # dx(dense) = dout(csr) * y'(dense) + np_x_grad 
= np_out_grad @ np_y.transpose(1, 0) + # dy(dense) = x'(dense) * dout(csr) -> dy'(dense) = dout'(csr) * x(dense) + np_y_grad = (np_out_grad.transpose() @ np_x).transpose(1, 0) + + x = paddle.to_tensor(np_x, stop_gradient=False) + y = paddle.to_tensor(np_y, stop_gradient=False) + mask = paddle.to_tensor(np.ones([10, 6]) * np_mask).to_sparse_csr() + out = paddle.incubate.sparse.masked_matmul(x, y, mask) + + self.assertTrue(np.allclose(np_out.indptr, out.crows().numpy())) + self.assertTrue(np.allclose(np_out.indices, out.cols().numpy())) + self.assertTrue(np.allclose(np_out.data, out.values().numpy())) + + out.backward() + self.assertTrue(np.allclose(out.is_sparse_csr(), True)) + self.assertTrue(np.allclose(np_x_grad, x.grad.numpy())) + self.assertTrue(np.allclose(np_y_grad, y.grad.numpy())) + + @unittest.skipIf(not paddle.is_compiled_with_cuda() + or get_cuda_version() < 11070, + "only support on cuda>=11.7") + def test_masked_matmul_3d(self): + paddle.set_default_dtype('float32') + origin_x = paddle.rand([16, 16, 12]) + mask = paddle.randint(0, 2, [16, 12]) + origin_x = origin_x * mask + origin_y = paddle.rand([16, 12, 10]) + + dense_x = origin_x.detach() + dense_x.stop_gradient = False + dense_y = origin_y.detach() + dense_y.stop_gradient = False + dense_out = paddle.matmul(dense_x, dense_y) + dense_out.backward() + + sp_x = origin_x.detach().to_sparse_csr() + sp_x.stop_gradient = False + sp_y = origin_y.detach() + sp_y.stop_gradient = False + sp_out = paddle.incubate.sparse.matmul(sp_x, sp_y) + sp_out.backward() + + self.assertTrue(np.allclose(sp_out.numpy(), dense_out.numpy())) + self.assertTrue( + np.allclose(sp_x.grad.to_dense().numpy(), + (dense_x.grad * mask).numpy())) + self.assertTrue(np.allclose(sp_y.grad.numpy(), dense_y.grad.numpy())) if __name__ == "__main__": diff --git a/python/paddle/incubate/sparse/binary.py b/python/paddle/incubate/sparse/binary.py index 0c90cd92a7537..7a7861f7b20e7 100644 --- a/python/paddle/incubate/sparse/binary.py +++ b/python/paddle/incubate/sparse/binary.py @@ -62,29 +62,37 @@ def matmul(x, y, name=None): .. code-block:: python import paddle - from paddle.fluid.framework import _test_eager_guard - paddle.seed(100) # csr @ dense -> dense - - with _test_eager_guard(): - crows = [0, 2, 3, 5] - cols = [1, 3, 2, 0, 1] - values = [1., 2., 3., 4., 5.] - dense_shape = [3, 4] - csr = paddle.incubate.sparse.sparse_csr_tensor(crows, cols, values, dense_shape) - # Tensor(shape=[3, 4], dtype=paddle.float32, place=Place(gpu:0), stop_gradient=True, - # crows=[0, 2, 3, 5], - # cols=[1, 3, 2, 0, 1], - # values=[1., 2., 3., 4., 5.]) - dense = paddle.randn([4, 3]) - - out = paddle.incubate.sparse.matmul(csr, dense) - # Tensor(shape=[3, 3], dtype=float32, place=Place(gpu:0), stop_gradient=True, - # [[-1.94294846 , -3.33990622 , 0.62359387 ], - # [-4.12815523 , 3.46535444 , -3.27413893 ], - # [-0.15209436 , -19.23207283, -3.35593438 ]]) - + crows = [0, 1, 2, 3] + cols = [1, 2, 0] + values = [1., 2., 3.] + csr = paddle.incubate.sparse.sparse_csr_tensor(crows, cols, values, [3, 3]) + # Tensor(shape=[3, 3], dtype=paddle.float32, place=Place(gpu:0), stop_gradient=True, + # crows=[0, 1, 2, 3], + # cols=[1, 2, 0], + # values=[1., 2., 3.]) + dense = paddle.ones([3, 2]) + out = paddle.incubate.sparse.matmul(csr, dense) + # Tensor(shape=[3, 2], dtype=float32, place=Place(gpu:0), stop_gradient=True, + # [[1., 1.], + # [2., 2.], + # [3., 3.]]) + + # coo @ dense -> dense + indices = [[0, 1, 2], [1, 2, 0]] + values = [1., 2., 3.] 
+ coo = paddle.incubate.sparse.sparse_coo_tensor(indices, values, [3, 3]) + # Tensor(shape=[3, 3], dtype=paddle.float32, place=Place(gpu:0), stop_gradient=True, + # indices=[[0, 1, 2], + # [1, 2, 0]], + # values=[1., 2., 3.]) + dense = paddle.ones([3, 2]) + out = paddle.incubate.sparse.matmul(coo, dense) + # Tensor(shape=[3, 2], dtype=float32, place=Place(gpu:0), stop_gradient=True, + # [[1., 1.], + # [2., 2.], + # [3., 3.]]) """ return _C_ops.final_state_sparse_matmul(x, y) @@ -123,30 +131,27 @@ def masked_matmul(x, y, mask, name=None): .. code-block:: python import paddle - from paddle.fluid.framework import _test_eager_guard paddle.seed(100) # dense @ dense * csr_mask -> csr - - with _test_eager_guard(): - crows = [0, 2, 3, 5] - cols = [1, 3, 2, 0, 1] - values = [1., 2., 3., 4., 5.] - dense_shape = [3, 4] - mask = paddle.incubate.sparse.sparse_csr_tensor(crows, cols, values, dense_shape) - # Tensor(shape=[3, 4], dtype=paddle.float32, place=Place(gpu:0), stop_gradient=True, - # crows=[0, 2, 3, 5], - # cols=[1, 3, 2, 0, 1], - # values=[1., 2., 3., 4., 5.]) - - x = paddle.rand([3, 5]) - y = paddle.rand([5, 4]) - - out = paddle.incubate.sparse.masked_matmul(x, y, mask) - # Tensor(shape=[3, 4], dtype=paddle.float32, place=Place(gpu:0), stop_gradient=True, - # crows=[0, 2, 3, 5], - # cols=[1, 3, 2, 0, 1], - # values=[0.98986477, 0.97800624, 1.14591956, 0.68561077, 0.94714981]) + crows = [0, 2, 3, 5] + cols = [1, 3, 2, 0, 1] + values = [1., 2., 3., 4., 5.] + dense_shape = [3, 4] + mask = paddle.incubate.sparse.sparse_csr_tensor(crows, cols, values, dense_shape) + # Tensor(shape=[3, 4], dtype=paddle.float32, place=Place(gpu:0), stop_gradient=True, + # crows=[0, 2, 3, 5], + # cols=[1, 3, 2, 0, 1], + # values=[1., 2., 3., 4., 5.]) + + x = paddle.rand([3, 5]) + y = paddle.rand([5, 4]) + + out = paddle.incubate.sparse.masked_matmul(x, y, mask) + # Tensor(shape=[3, 4], dtype=paddle.float32, place=Place(gpu:0), stop_gradient=True, + # crows=[0, 2, 3, 5], + # cols=[1, 3, 2, 0, 1], + # values=[0.98986477, 0.97800624, 1.14591956, 0.68561077, 0.94714981]) """ return _C_ops.final_state_sparse_masked_matmul(x, y, mask)
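Usage sketch (illustrative, not part of the patch): the snippet below mirrors the TestMatmul.check_result path added in test_sparse_matmul_op.py and walks the new COO @ DENSE forward/backward route end to end. It assumes a CUDA build of Paddle exposing the pre-2.4 paddle.incubate.sparse namespace; per the skip conditions in that test, the 2D forward kernel (matmul_coo_dense) needs CUDA >= 11.0 and the backward kernel (matmul_coo_dense_grad, built on cusparseSDDMM) needs CUDA >= 11.3.

    import paddle

    paddle.seed(100)

    # Dense operands; x is masked so the COO tensor is genuinely sparse.
    mask = paddle.randint(0, 2, [16, 12]).astype('float32')
    origin_x = paddle.rand([16, 12]) * mask
    origin_y = paddle.rand([12, 10])

    # COO view of x; gradients flow back to the sparse tensor.
    sp_x = origin_x.detach().to_sparse_coo(2)
    sp_x.stop_gradient = False
    sp_y = origin_y.detach()
    sp_y.stop_gradient = False

    # Forward: COO @ DENSE -> DENSE (registered as matmul_coo_dense).
    sp_out = paddle.incubate.sparse.matmul(sp_x, sp_y)

    # Backward: matmul_coo_dense_grad. sp_x.grad is a SparseCooTensor that
    # keeps x's sparsity pattern; sp_y.grad is an ordinary dense tensor.
    sp_out.backward()
    print(sp_out)
    print(sp_x.grad.to_dense())
    print(sp_y.grad)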