diff --git a/paddle/phi/api/yaml/sparse_api.yaml b/paddle/phi/api/yaml/sparse_api.yaml
index fcde699e71dfd..28f35535bebeb 100644
--- a/paddle/phi/api/yaml/sparse_api.yaml
+++ b/paddle/phi/api/yaml/sparse_api.yaml
@@ -266,6 +266,17 @@
     layout : x
   backward : values_grad
 
+- api: addmm
+  args : (Tensor input, Tensor x, Tensor y, float alpha=1.0, float beta=1.0)
+  output : Tensor(out)
+  kernel :
+    func : addmm_csr_dense {dense, sparse_csr, dense -> dense},
+           addmm_csr_csr {sparse_csr, sparse_csr, sparse_csr -> sparse_csr},
+           addmm_coo_dense {dense, sparse_coo, dense -> dense},
+           addmm_coo_coo {sparse_coo, sparse_coo, sparse_coo -> sparse_coo}
+  layout : x
+  backward : addmm_grad
+
 - api: coalesce
   args : (Tensor x)
   output : Tensor(out)
diff --git a/paddle/phi/api/yaml/sparse_bw_api.yaml b/paddle/phi/api/yaml/sparse_bw_api.yaml
index ab0070840f7fd..a39577e7e677c 100644
--- a/paddle/phi/api/yaml/sparse_bw_api.yaml
+++ b/paddle/phi/api/yaml/sparse_bw_api.yaml
@@ -30,6 +30,16 @@
   func : add_coo_coo_grad{sparse_coo, sparse_coo, sparse_coo -> sparse_coo, sparse_coo},
          add_csr_csr_grad{sparse_csr, sparse_csr, sparse_csr -> sparse_csr, sparse_csr}
 
+- backward_api : addmm_grad
+  forward : addmm(Tensor input, Tensor x, Tensor y, float alpha=1.0, float beta=1.0) -> Tensor(out)
+  args : (Tensor input, Tensor x, Tensor y, Tensor out_grad, float alpha=1.0, float beta=1.0)
+  output : Tensor(input_grad), Tensor(x_grad), Tensor(y_grad)
+  kernel :
+    func : addmm_csr_dense_grad {dense, sparse_csr, dense, dense -> dense, sparse_csr, dense},
+           addmm_csr_csr_grad {sparse_csr, sparse_csr, sparse_csr, sparse_csr -> sparse_csr, sparse_csr, sparse_csr},
+           addmm_coo_dense_grad {dense, sparse_coo, dense, dense -> dense, sparse_coo, dense},
+           addmm_coo_coo_grad {sparse_coo, sparse_coo, sparse_coo, sparse_coo -> sparse_coo, sparse_coo, sparse_coo}
+
 - backward_api : asin_grad
   forward : asin(Tensor x) -> Tensor(out)
   args : (Tensor x, Tensor out_grad)
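The yaml entries above declare the user-facing `addmm` API, route each input-layout combination to a concrete kernel, and wire up `addmm_grad`. Only the dense + sparse @ dense variants are implemented in this PR; the coo/coo and csr/csr entries are declared for later. A minimal sketch of the path that `addmm_csr_dense` serves (assuming a CUDA 11.x build; it mirrors the docstring example further below):

    import paddle

    input = paddle.rand([3, 2])
    crows, cols, values = [0, 1, 2, 3], [1, 2, 0], [1., 2., 3.]
    x = paddle.incubate.sparse.sparse_csr_tensor(crows, cols, values, [3, 3])
    y = paddle.rand([3, 2])
    # positional 3.0 is beta (scales input), 2.0 is alpha (scales x @ y)
    out = paddle.incubate.sparse.addmm(input, x, y, 3.0, 2.0)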
diff --git a/paddle/phi/kernels/sparse/addmm_grad_kernel.h b/paddle/phi/kernels/sparse/addmm_grad_kernel.h
new file mode 100644
index 0000000000000..e320ba954139c
--- /dev/null
+++ b/paddle/phi/kernels/sparse/addmm_grad_kernel.h
@@ -0,0 +1,77 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include "paddle/phi/core/dense_tensor.h"
+#include "paddle/phi/core/sparse_coo_tensor.h"
+#include "paddle/phi/core/sparse_csr_tensor.h"
+
+namespace phi {
+namespace sparse {
+
+// TODO(zhouwei25): implement Backward of "COO + COO @ COO -> COO"
+template <typename T, typename Context>
+void AddmmCooCooGradKernel(const Context& dev_ctx,
+                           const SparseCooTensor& input,
+                           const SparseCooTensor& x,
+                           const SparseCooTensor& y,
+                           const SparseCooTensor& dout,
+                           float alpha,
+                           float beta,
+                           SparseCooTensor* dinput,
+                           SparseCooTensor* dx,
+                           SparseCooTensor* dy);
+
+// Backward of "DENSE + COO @ DENSE -> DENSE"
+template <typename T, typename Context>
+void AddmmCooDenseGradKernel(const Context& dev_ctx,
+                             const DenseTensor& input,
+                             const SparseCooTensor& x,
+                             const DenseTensor& y,
+                             const DenseTensor& dout,
+                             float alpha,
+                             float beta,
+                             DenseTensor* dinput,
+                             SparseCooTensor* dx,
+                             DenseTensor* dy);
+
+// TODO(zhouwei25): implement Backward of "CSR + CSR @ CSR -> CSR"
+template <typename T, typename Context>
+void AddmmCsrCsrGradKernel(const Context& dev_ctx,
+                           const SparseCsrTensor& input,
+                           const SparseCsrTensor& x,
+                           const SparseCsrTensor& y,
+                           const SparseCsrTensor& dout,
+                           float alpha,
+                           float beta,
+                           SparseCsrTensor* dinput,
+                           SparseCsrTensor* dx,
+                           SparseCsrTensor* dy);
+
+/* Backward of "DENSE + CSR @ DENSE -> DENSE" */
+template <typename T, typename Context>
+void AddmmCsrDenseGradKernel(const Context& dev_ctx,
+                             const DenseTensor& input,
+                             const SparseCsrTensor& x,
+                             const DenseTensor& y,
+                             const DenseTensor& dout,
+                             float alpha,
+                             float beta,
+                             DenseTensor* dinput,
+                             SparseCsrTensor* dx,
+                             DenseTensor* dy);
+
+}  // namespace sparse
+}  // namespace phi
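With out = beta * input + alpha * (x @ y), the gradients these declarations cover are, written in dense notation:

    d(input) = beta * dout
    dx       = alpha * dout @ y^T    (restricted to x's sparsity pattern)
    dy       = alpha * x^T @ dout

The restriction of dx to x's nonzero pattern is exactly what the unit test below verifies via `dense_x.grad * mask`.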
diff --git a/paddle/phi/kernels/sparse/addmm_kernel.h b/paddle/phi/kernels/sparse/addmm_kernel.h
new file mode 100644
index 0000000000000..3cf21fbca2f81
--- /dev/null
+++ b/paddle/phi/kernels/sparse/addmm_kernel.h
@@ -0,0 +1,65 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include "paddle/phi/core/dense_tensor.h"
+#include "paddle/phi/core/sparse_coo_tensor.h"
+#include "paddle/phi/core/sparse_csr_tensor.h"
+
+namespace phi {
+namespace sparse {
+
+// TODO(zhouwei25): implement "COO + COO @ COO -> COO"
+template <typename T, typename Context>
+void AddmmCooCooKernel(const Context& dev_ctx,
+                       const SparseCooTensor& input,
+                       const SparseCooTensor& x,
+                       const SparseCooTensor& y,
+                       float alpha,
+                       float beta,
+                       SparseCooTensor* out);
+
+/* DENSE + COO @ DENSE -> DENSE */
+template <typename T, typename Context>
+void AddmmCooDenseKernel(const Context& dev_ctx,
+                         const DenseTensor& input,
+                         const SparseCooTensor& x,
+                         const DenseTensor& y,
+                         float alpha,
+                         float beta,
+                         DenseTensor* out);
+
+// TODO(zhouwei25): implement "CSR + CSR @ CSR -> CSR"
+template <typename T, typename Context>
+void AddmmCsrCsrKernel(const Context& dev_ctx,
+                       const SparseCsrTensor& input,
+                       const SparseCsrTensor& x,
+                       const SparseCsrTensor& y,
+                       float alpha,
+                       float beta,
+                       SparseCsrTensor* out);
+
+/* DENSE + CSR @ DENSE -> DENSE */
+template <typename T, typename Context>
+void AddmmCsrDenseKernel(const Context& dev_ctx,
+                         const DenseTensor& input,
+                         const SparseCsrTensor& x,
+                         const DenseTensor& y,
+                         float alpha,
+                         float beta,
+                         DenseTensor* out);
+
+}  // namespace sparse
+}  // namespace phi
diff --git a/paddle/phi/kernels/sparse/cpu/addmm_grad_kernel.cc b/paddle/phi/kernels/sparse/cpu/addmm_grad_kernel.cc
new file mode 100644
index 0000000000000..1ed96bf28c5d0
--- /dev/null
+++ b/paddle/phi/kernels/sparse/cpu/addmm_grad_kernel.cc
@@ -0,0 +1,72 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/phi/kernels/sparse/addmm_grad_kernel.h"
+
+#include "paddle/phi/backends/cpu/cpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+
+namespace phi {
+namespace sparse {
+
+template <typename T, typename Context>
+void AddmmCooDenseGradKernel(const Context& dev_ctx,
+                             const DenseTensor& input,
+                             const SparseCooTensor& x,
+                             const DenseTensor& y,
+                             const DenseTensor& dout,
+                             float alpha,
+                             float beta,
+                             DenseTensor* dinput,
+                             SparseCooTensor* dx,
+                             DenseTensor* dy) {
+  PADDLE_THROW(phi::errors::Unimplemented(
+      "Not support CPU backward kernel of 'sparse.addmm' now."));
+}
+
+template <typename T, typename Context>
+void AddmmCsrDenseGradKernel(const Context& dev_ctx,
+                             const DenseTensor& input,
+                             const SparseCsrTensor& x,
+                             const DenseTensor& y,
+                             const DenseTensor& dout,
+                             float alpha,
+                             float beta,
+                             DenseTensor* dinput,
+                             SparseCsrTensor* dx,
+                             DenseTensor* dy) {
+  PADDLE_THROW(phi::errors::Unimplemented(
+      "Not support CPU backward kernel of 'sparse.addmm' now."));
+}
+
+}  // namespace sparse
+}  // namespace phi
+
+PD_REGISTER_KERNEL(addmm_coo_dense_grad,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::sparse::AddmmCooDenseGradKernel,
+                   float,
+                   double) {
+  // the sparse operand x is argument index 1 (index 0 is the dense input)
+  kernel->InputAt(1).SetDataLayout(phi::DataLayout::SPARSE_COO);
+}
+
+PD_REGISTER_KERNEL(addmm_csr_dense_grad,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::sparse::AddmmCsrDenseGradKernel,
+                   float,
+                   double) {
+  kernel->InputAt(1).SetDataLayout(phi::DataLayout::SPARSE_CSR);
+}
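Both CPU backward kernels above (and the CPU forward kernels below) only raise Unimplemented, so `sparse.addmm` effectively requires a CUDA device. A hedged guard one might put in user code (the RuntimeError surfaced from PADDLE_THROW is an assumption):

    import paddle

    if paddle.is_compiled_with_cuda():
        paddle.set_device('gpu')  # sparse.addmm currently ships GPU kernels only
    else:
        raise RuntimeError('sparse.addmm has no CPU kernel yet')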
diff --git a/paddle/phi/kernels/sparse/cpu/addmm_kernel.cc b/paddle/phi/kernels/sparse/cpu/addmm_kernel.cc
new file mode 100644
index 0000000000000..e58d9d0e69196
--- /dev/null
+++ b/paddle/phi/kernels/sparse/cpu/addmm_kernel.cc
@@ -0,0 +1,67 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/phi/kernels/sparse/addmm_kernel.h"
+
+#include "paddle/phi/backends/cpu/cpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+
+namespace phi {
+namespace sparse {
+
+/* DENSE + COO @ DENSE -> DENSE */
+template <typename T, typename Context>
+void AddmmCooDenseKernel(const Context& dev_ctx,
+                         const DenseTensor& input,
+                         const SparseCooTensor& x,
+                         const DenseTensor& y,
+                         float alpha,
+                         float beta,
+                         DenseTensor* out) {
+  PADDLE_THROW(phi::errors::Unimplemented(
+      "Not support CPU kernel of 'sparse.addmm' now."));
+}
+
+/* DENSE + CSR @ DENSE -> DENSE */
+template <typename T, typename Context>
+void AddmmCsrDenseKernel(const Context& dev_ctx,
+                         const DenseTensor& input,
+                         const SparseCsrTensor& x,
+                         const DenseTensor& y,
+                         float alpha,
+                         float beta,
+                         DenseTensor* out) {
+  PADDLE_THROW(phi::errors::Unimplemented(
+      "Not support CPU kernel of 'sparse.addmm' now."));
+}
+
+}  // namespace sparse
+}  // namespace phi
+
+PD_REGISTER_KERNEL(addmm_coo_dense,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::sparse::AddmmCooDenseKernel,
+                   float,
+                   double) {
+  // the sparse operand x is argument index 1 (index 0 is the dense input)
+  kernel->InputAt(1).SetDataLayout(phi::DataLayout::SPARSE_COO);
+}
+
+PD_REGISTER_KERNEL(addmm_csr_dense,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::sparse::AddmmCsrDenseKernel,
+                   float,
+                   double) {
+  kernel->InputAt(1).SetDataLayout(phi::DataLayout::SPARSE_CSR);
+}
diff --git a/paddle/phi/kernels/sparse/cpu/matmul_grad_kernel.cc b/paddle/phi/kernels/sparse/cpu/matmul_grad_kernel.cc
index 2586976b7636c..5811880249a6d 100644
--- a/paddle/phi/kernels/sparse/cpu/matmul_grad_kernel.cc
+++ b/paddle/phi/kernels/sparse/cpu/matmul_grad_kernel.cc
@@ -29,7 +29,7 @@ void MatmulCsrDenseGradKernel(const Context& dev_ctx,
                               SparseCsrTensor* dx,
                               DenseTensor* dy) {
   PADDLE_THROW(phi::errors::Unimplemented(
-      "Not support CPU backward kernel of Sparse Matmul now."));
+      "Not support CPU backward kernel of 'sparse.matmul' now."));
 }
 
 // TODO(zhouwei25): implement CPU kernel of " DENSE @ DENSE * CSR_MASK -> CSR"
@@ -41,7 +41,7 @@ void MaskedMatmulCsrGradKernel(const Context& dev_ctx,
                               DenseTensor* dx,
                               DenseTensor* dy) {
   PADDLE_THROW(phi::errors::Unimplemented(
-      "Not support CPU backward kernel of Matmul Mask As Sparse now."));
+      "Not support CPU backward kernel of 'sparse.masked_matmul' now."));
 }
 
 }  // namespace sparse
diff --git a/paddle/phi/kernels/sparse/gpu/addmm_grad_kernel.cu b/paddle/phi/kernels/sparse/gpu/addmm_grad_kernel.cu
new file mode 100644
index 0000000000000..1f907415aacbc
--- /dev/null
+++ b/paddle/phi/kernels/sparse/gpu/addmm_grad_kernel.cu
@@ -0,0 +1,96 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/phi/kernels/sparse/addmm_grad_kernel.h"
+
+#include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/empty_kernel.h"
+#include "paddle/phi/kernels/funcs/blas/blas.h"
+#include "paddle/phi/kernels/sparse/matmul_grad_kernel.h"
+
+namespace phi {
+namespace sparse {
+
+template <typename T, typename Context>
+void AddmmCooDenseGradKernel(const Context& dev_ctx,
+                             const DenseTensor& input,
+                             const SparseCooTensor& x,
+                             const DenseTensor& y,
+                             const DenseTensor& dout,
+                             float alpha,
+                             float beta,
+                             DenseTensor* dinput,
+                             SparseCooTensor* dx,
+                             DenseTensor* dy) {
+  auto blas = funcs::GetBlas<Context, T>(dev_ctx);
+  if (dinput) {
+    dinput->Resize(input.dims());
+    dev_ctx.template Alloc<T>(dinput);
+
+    blas.VCOPY(input.numel(), dout.data<T>(), dinput->data<T>());
+    blas.SCAL(input.numel(), beta, dinput->data<T>());
+  }
+  DenseTensor dout_scale = phi::EmptyLike<T>(dev_ctx, dout);
+  blas.VCOPY(dout.numel(), dout.data<T>(), dout_scale.data<T>());
+  blas.SCAL(dout.numel(), alpha, dout_scale.data<T>());
+  MatmulCooDenseGradKernel<T, Context>(dev_ctx, x, y, dout_scale, dx, dy);
+}
+
+// Backward of "DENSE + CSR @ DENSE -> DENSE"
+template <typename T, typename Context>
+void AddmmCsrDenseGradKernel(const Context& dev_ctx,
+                             const DenseTensor& input,
+                             const SparseCsrTensor& x,
+                             const DenseTensor& y,
+                             const DenseTensor& dout,
+                             float alpha,
+                             float beta,
+                             DenseTensor* dinput,
+                             SparseCsrTensor* dx,
+                             DenseTensor* dy) {
+  auto blas = funcs::GetBlas<Context, T>(dev_ctx);
+  if (dinput) {
+    dinput->Resize(input.dims());
+    dev_ctx.template Alloc<T>(dinput);
+
+    blas.VCOPY(input.numel(), dout.data<T>(), dinput->data<T>());
+    blas.SCAL(input.numel(), beta, dinput->data<T>());
+  }
+  DenseTensor dout_scale = phi::EmptyLike<T>(dev_ctx, dout);
+  blas.VCOPY(dout.numel(), dout.data<T>(), dout_scale.data<T>());
+  blas.SCAL(dout.numel(), alpha, dout_scale.data<T>());
+  MatmulCsrDenseGradKernel<T, Context>(dev_ctx, x, y, dout_scale, dx, dy);
+}
+
+}  // namespace sparse
+}  // namespace phi
+
+PD_REGISTER_KERNEL(addmm_coo_dense_grad,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::sparse::AddmmCooDenseGradKernel,
+                   float,
+                   double) {
+  // the sparse operand x is argument index 1 (index 0 is the dense input)
+  kernel->InputAt(1).SetDataLayout(phi::DataLayout::SPARSE_COO);
+}
+
+PD_REGISTER_KERNEL(addmm_csr_dense_grad,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::sparse::AddmmCsrDenseGradKernel,
+                   float,
+                   double) {
+  kernel->InputAt(1).SetDataLayout(phi::DataLayout::SPARSE_CSR);
+}
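Rather than a dedicated backward, these GPU kernels scale `dout` once by beta (for dinput) and once by alpha, then delegate to the existing sparse matmul backward. A NumPy sketch of the same decomposition on dense arrays (illustrative only; the real dx additionally stays on x's sparsity pattern):

    import numpy as np

    x = np.random.rand(4, 5)      # dense stand-in for the sparse operand
    y = np.random.rand(5, 3)
    dout = np.ones((4, 3))
    alpha, beta = 2.0, 3.0

    dinput = beta * dout          # the first VCOPY + SCAL pair
    dout_scale = alpha * dout     # the second VCOPY + SCAL pair
    dx = dout_scale @ y.T         # what the delegated matmul backward produces
    dy = x.T @ dout_scale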
diff --git a/paddle/phi/kernels/sparse/gpu/addmm_kernel.cu b/paddle/phi/kernels/sparse/gpu/addmm_kernel.cu
new file mode 100644
index 0000000000000..3e5d423b9f96a
--- /dev/null
+++ b/paddle/phi/kernels/sparse/gpu/addmm_kernel.cu
@@ -0,0 +1,146 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/phi/kernels/sparse/addmm_kernel.h"
+
+#include <vector>
+
+#include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/core/ddim.h"
+#include "paddle/phi/core/enforce.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/core/tensor_utils.h"
+#include "paddle/phi/kernels/funcs/sparse/sparse_blas.h"
+
+namespace phi {
+namespace sparse {
+
+template <typename T, typename Context, typename TensorType>
+void AddmmKernelImpl(const Context& dev_ctx,
+                     const DenseTensor& input,
+                     const TensorType& x,
+                     const DenseTensor& y,
+                     float alpha,
+                     float beta,
+                     DenseTensor* out) {
+#if CUDA_VERSION >= 11000
+  std::vector<int64_t> input_dim = phi::vectorize(input.dims());
+  std::vector<int64_t> x_dim = phi::vectorize(x.dims());
+  std::vector<int64_t> y_dim = phi::vectorize(y.dims());
+  auto rank = input_dim.size();
+
+  PADDLE_ENFORCE_GE(
+      rank,
+      2,
+      phi::errors::InvalidArgument(
+          "The dims size of input must be greater than or equal to 2."));
+
+  PADDLE_ENFORCE_EQ(
+      x_dim.size(),
+      rank,
+      phi::errors::PreconditionNotMet(
+          "The dims size of Input(input) and Input(x) must be equal."));
+
+  PADDLE_ENFORCE_EQ(
+      y_dim.size(),
+      rank,
+      phi::errors::InvalidArgument(
+          "The dims size of Input(input) and Input(y) must be equal."));
+
+  for (size_t i = 0; i < rank - 2; ++i) {
+    PADDLE_ENFORCE_EQ(input_dim[i],
+                      x_dim[i],
+                      phi::errors::InvalidArgument(
+                          "input.dim[%d] and x.dim[%d] must be equal.", i, i));
+    PADDLE_ENFORCE_EQ(input_dim[i],
+                      y_dim[i],
+                      phi::errors::InvalidArgument(
+                          "input.dim[%d] and y.dim[%d] must be equal.", i, i));
+  }
+
+  PADDLE_ENFORCE_EQ(
+      input_dim[rank - 2],
+      x_dim[rank - 2],
+      phi::errors::PreconditionNotMet(
+          "The shape of Input(input) and Input(x) is not suitable for matmul "
+          "operation, input_dim[-2] must be equal to x_dim[-2]."));
+
+  PADDLE_ENFORCE_EQ(
+      input_dim[rank - 1],
+      y_dim[rank - 1],
+      phi::errors::PreconditionNotMet(
+          "The shape of Input(input) and Input(y) is not suitable for matmul "
+          "operation, input_dim[-1] must be equal to y_dim[-1]."));
+
+  PADDLE_ENFORCE_EQ(
+      x_dim[rank - 1],
+      y_dim[rank - 2],
+      phi::errors::PreconditionNotMet(
+          "The shape of Input(x) and Input(y) is not suitable for matmul "
+          "operation, x_dim[-1] must be equal to y_dim[-2]."));
+
+  phi::Copy(dev_ctx, input, dev_ctx.GetPlace(), false, out);
+
+  auto sparse_blas = phi::funcs::sparse::GetSparseBlas<Context, T>(dev_ctx);
+  sparse_blas.SPMM(
+      false, false, static_cast<T>(alpha), x, y, static_cast<T>(beta), out);
+#else
+  PADDLE_THROW(
+      phi::errors::Unimplemented("forward of 'sparse.addmm' uses cusparseSpMM, "
+                                 "which is supported from CUDA 11.0."));
+#endif
+}
+
+template <typename T, typename Context>
+void AddmmCooDenseKernel(const Context& dev_ctx,
+                         const DenseTensor& input,
+                         const SparseCooTensor& x,
+                         const DenseTensor& y,
+                         float alpha,
+                         float beta,
+                         DenseTensor* out) {
+  AddmmKernelImpl<T>(dev_ctx, input, x, y, alpha, beta, out);
+}
+
+template <typename T, typename Context>
+void AddmmCsrDenseKernel(const Context& dev_ctx,
+                         const DenseTensor& input,
+                         const SparseCsrTensor& x,
+                         const DenseTensor& y,
+                         float alpha,
+                         float beta,
+                         DenseTensor* out) {
+  AddmmKernelImpl<T>(dev_ctx, input, x, y, alpha, beta, out);
+}
+
+}  // namespace sparse
+}  // namespace phi
+
+PD_REGISTER_KERNEL(addmm_coo_dense,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::sparse::AddmmCooDenseKernel,
+                   float,
+                   double) {
+  // the sparse operand x is argument index 1 (index 0 is the dense input)
+  kernel->InputAt(1).SetDataLayout(phi::DataLayout::SPARSE_COO);
+}
+
+PD_REGISTER_KERNEL(addmm_csr_dense,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::sparse::AddmmCsrDenseKernel,
+                   float,
+                   double) {
+  kernel->InputAt(1).SetDataLayout(phi::DataLayout::SPARSE_CSR);
+}
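The enforcement block pins down the shape contract: input is [*, M, N], x is [*, M, K], y is [*, K, N], with batch dimensions matching exactly and no broadcasting. The unit test's 2-D case as a concrete instance (needs a CUDA 11.x device, per the CPU stubs above):

    import paddle

    M, K, N = 16, 12, 10
    input = paddle.rand([M, N])
    x = paddle.rand([M, K]).to_sparse_coo(2)  # dispatches to addmm_coo_dense
    y = paddle.rand([K, N])
    out = paddle.incubate.sparse.addmm(input, x, y)  # out: [M, N]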
diff --git a/python/paddle/fluid/tests/unittests/test_sparse_addmm_op.py b/python/paddle/fluid/tests/unittests/test_sparse_addmm_op.py
new file mode 100644
index 0000000000000..458ee25e41075
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_sparse_addmm_op.py
@@ -0,0 +1,95 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+import numpy as np
+import scipy
+import scipy.sparse as sp
+import unittest
+import os
+import re
+
+paddle.set_default_dtype('float64')
+
+
+def get_cuda_version():
+    result = os.popen("nvcc --version").read()
+    regex = r'release (\S+),'
+    match = re.search(regex, result)
+    if match:
+        num = str(match.group(1))
+        integer, decimal = num.split('.')
+        return int(integer) * 1000 + int(float(decimal) * 10)
+    else:
+        return -1
+
+
+class TestAddmm(unittest.TestCase):
+    # input: dense, x: sparse, y: dense, out: dense
+    def check_result(self, input_shape, x_shape, y_shape, format):
+        if len(x_shape) == 3:
+            mask = paddle.randint(0, 2, [x_shape[-2], x_shape[-1]])
+        else:
+            mask = paddle.randint(0, 2, x_shape)
+
+        origin_input = paddle.rand(input_shape)
+        origin_x = paddle.rand(x_shape) * mask
+        origin_y = paddle.rand(y_shape)
+
+        dense_input = origin_input.detach()
+        dense_input.stop_gradient = False
+        dense_x = origin_x.detach()
+        dense_x.stop_gradient = False
+        dense_y = origin_y.detach()
+        dense_y.stop_gradient = False
+        dense_out = 2. * paddle.matmul(dense_x, dense_y) + 3. * dense_input
+
+        sp_input = dense_input.detach()
+        sp_input.stop_gradient = False
+        if format == "coo":
+            sp_x = origin_x.detach().to_sparse_coo(len(x_shape))
+        else:
+            sp_x = origin_x.detach().to_sparse_csr()
+        sp_x.stop_gradient = False
+        sp_y = origin_y.detach()
+        sp_y.stop_gradient = False
+        sp_out = paddle.incubate.sparse.addmm(sp_input, sp_x, sp_y, 3.0, 2.0)
+
+        self.assertTrue(np.allclose(sp_out.numpy(), dense_out.numpy()))
+        if get_cuda_version() >= 11030:
+            dense_out.backward()
+            sp_out.backward()
+            self.assertTrue(
+                np.allclose(sp_input.grad.numpy(), dense_input.grad.numpy()))
+            self.assertTrue(
+                np.allclose(sp_x.grad.to_dense().numpy(),
+                            (dense_x.grad * mask).numpy()))
+            self.assertTrue(np.allclose(sp_y.grad.numpy(),
+                                        dense_y.grad.numpy()))
+
+    @unittest.skipIf(not paddle.is_compiled_with_cuda()
+                     or get_cuda_version() < 11000, "only support cuda>=11.0")
+    def test_addmm_2d(self):
+        self.check_result([16, 10], [16, 12], [12, 10], 'coo')
+        self.check_result([16, 10], [16, 12], [12, 10], 'csr')
+
+    @unittest.skipIf(not paddle.is_compiled_with_cuda()
+                     or get_cuda_version() < 11070, "only support cuda>=11.7")
+    def test_addmm_3d(self):
+        self.check_result([8, 16, 10], [8, 16, 12], [8, 12, 10], 'coo')
+        self.check_result([8, 16, 10], [8, 16, 12], [8, 12, 10], 'csr')
+
+
+if __name__ == "__main__":
+    unittest.main()
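`get_cuda_version` folds nvcc's release string into one comparable integer, which the skip decorators and the backward gate use. Worked through for a hypothetical `release 11.7,` line in the nvcc output:

    num = '11.7'
    integer, decimal = num.split('.')
    version = int(integer) * 1000 + int(float(decimal) * 10)
    assert version == 11070  # hence the >= 11070 gate on the 3-D tests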
diff --git a/python/paddle/incubate/sparse/__init__.py b/python/paddle/incubate/sparse/__init__.py
index 47c7a312e24d8..6a672cb49415c 100644
--- a/python/paddle/incubate/sparse/__init__.py
+++ b/python/paddle/incubate/sparse/__init__.py
@@ -40,6 +40,8 @@
 from .binary import multiply
 from .binary import subtract
 
+from .multiary import addmm
+
 from . import nn
 
 __all__ = [
@@ -63,6 +65,7 @@
     'mv',
     'matmul',
     'masked_matmul',
+    'addmm',
     'add',
     'subtract',
     'multiply',
diff --git a/python/paddle/incubate/sparse/multiary.py b/python/paddle/incubate/sparse/multiary.py
new file mode 100644
index 0000000000000..17cf75fdc3903
--- /dev/null
+++ b/python/paddle/incubate/sparse/multiary.py
@@ -0,0 +1,81 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from paddle import _C_ops
+from paddle.fluid.framework import dygraph_only
+
+__all__ = []
+
+
+@dygraph_only
+def addmm(input, x, y, beta=1.0, alpha=1.0, name=None):
+    """
+    Note:
+        This API is only supported from ``CUDA 11.0``.
+
+    Performs matrix multiplication of `x` and `y`, then adds `input` to the
+    result. The equation is:
+
+    .. math::
+
+        Out = alpha * x * y + beta * input
+
+    The supported input/output Tensor layouts are as follows:
+
+    Note:
+        input[SparseCsrTensor] + x[SparseCsrTensor] @ y[SparseCsrTensor] -> out[SparseCsrTensor]
+        input[DenseTensor] + x[SparseCsrTensor] @ y[DenseTensor] -> out[DenseTensor]
+        input[SparseCooTensor] + x[SparseCooTensor] @ y[SparseCooTensor] -> out[SparseCooTensor]
+        input[DenseTensor] + x[SparseCooTensor] @ y[DenseTensor] -> out[DenseTensor]
+
+    It supports backward propagation.
+
+    The dimensions of `input`, `x` and `y` must be the same and at least 2-D;
+    automatic broadcasting of Tensor is not supported.
+
+    Args:
+        input (Tensor): The input tensor. Shape is [*, M, N]. The data type can be float32 or float64.
+        x (Tensor): The input tensor. Shape is [*, M, K]. The data type can be float32 or float64.
+        y (Tensor): The input tensor. Shape is [*, K, N]. The data type can be float32 or float64.
+        beta (float, optional): Coefficient of `input`. Default: 1.0
+        alpha (float, optional): Coefficient of `x * y`. Default: 1.0
+        name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
+
+    Returns:
+        Tensor: its layout is determined by that of `x` and `y`; its dtype and shape are the same as `input`.
+
+    Examples:
+
+        .. code-block:: python
+
+            import paddle
+
+            # dense + csr @ dense -> dense
+            input = paddle.rand([3, 2])
+            crows = [0, 1, 2, 3]
+            cols = [1, 2, 0]
+            values = [1., 2., 3.]
+            x = paddle.incubate.sparse.sparse_csr_tensor(crows, cols, values, [3, 3])
+            y = paddle.rand([3, 2])
+            out = paddle.incubate.sparse.addmm(input, x, y, 3.0, 2.0)
+
+            # dense + coo @ dense -> dense
+            input = paddle.rand([3, 2])
+            indices = [[0, 1, 2], [1, 2, 0]]
+            values = [1., 2., 3.]
+            x = paddle.incubate.sparse.sparse_coo_tensor(indices, values, [3, 3])
+            y = paddle.rand([3, 2])
+            out = paddle.incubate.sparse.addmm(input, x, y, 3.0, 2.0)
+    """
+    return _C_ops.final_state_sparse_addmm(input, x, y, alpha, beta)
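Because `addmm_grad` is wired up in the yaml above, the op also works under autograd. A short sketch of the backward path, following the unit test (needs a CUDA 11.x device; the sparse gradient stays on x's nonzero pattern):

    import paddle

    input = paddle.rand([3, 2])
    input.stop_gradient = False
    mask = paddle.randint(0, 2, [3, 3])
    x = (paddle.rand([3, 3]) * mask).to_sparse_coo(2)
    x.stop_gradient = False
    y = paddle.rand([3, 2])
    y.stop_gradient = False

    out = paddle.incubate.sparse.addmm(input, x, y, 3.0, 2.0)
    out.backward()
    # input.grad is dense, x.grad is sparse COO, y.grad is dense
    print(input.grad.shape, x.grad.to_dense().shape, y.grad.shape)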