diff --git a/paddle/fluid/eager/grad_node_info.cc b/paddle/fluid/eager/grad_node_info.cc index 542aad70cb696..1cba94339bfdf 100644 --- a/paddle/fluid/eager/grad_node_info.cc +++ b/paddle/fluid/eager/grad_node_info.cc @@ -27,6 +27,7 @@ #include "paddle/phi/common/data_type.h" #include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/sparse_coo_tensor.h" +#include "paddle/phi/core/sparse_csr_tensor.h" /** * Implementation of GradNodeBase, Edge and GradTensorHolder. @@ -114,6 +115,10 @@ void GradNodeBase::SetGradInMeta(const paddle::experimental::Tensor& fwd_out, phi::SparseCooTensor* coo_tensor = static_cast(fwd_out.impl().get()); dense_tensor = coo_tensor->mutable_non_zero_elements(); + } else if (phi::SparseCsrTensor::classof(fwd_out.impl().get())) { + phi::SparseCsrTensor* csr_tensor = + static_cast(fwd_out.impl().get()); + dense_tensor = csr_tensor->mutable_non_zero_elements(); } else { VLOG(6) << "Unable to initialize the DenseTensorMeta of GradSlotMeta with " "non-DenseTensor argument."; diff --git a/paddle/fluid/eager/grad_tensor_holder.cc b/paddle/fluid/eager/grad_tensor_holder.cc index 9de13f6deab74..ee5dd622412e1 100644 --- a/paddle/fluid/eager/grad_tensor_holder.cc +++ b/paddle/fluid/eager/grad_tensor_holder.cc @@ -66,8 +66,17 @@ void GradTensorHolder::CopyValueFromTensor( // Create new tensor->impl and fill it with 1.0 if (t.defined()) { // Fill 1.0, use full to support complex, one_like don't support it. - buffer_[slot_id][rank] = - paddle::experimental::full(t.shape(), 1, t.dtype(), t.place()); + if (t.is_dense_tensor()) { + buffer_[slot_id][rank] = + paddle::experimental::full(t.shape(), 1, t.dtype(), t.place()); + } else if (t.is_sparse_csr_tensor() || t.is_sparse_coo_tensor()) { + buffer_[slot_id][rank] = + paddle::experimental::sparse::full_like(t, 1, t.dtype()); + } else { + PADDLE_THROW(paddle::platform::errors::Fatal( + "Only Support DENSE_TENSOR, SPARSE_COO_TENSOR, SPARSE_CSR_TENSOR " + "now.")); + } egr::EagerUtils::autograd_meta(&(buffer_[slot_id][rank])) ->SetStopGradient(false); } diff --git a/paddle/fluid/platform/device/gpu/cuda/cusparse_helper.h b/paddle/fluid/platform/device/gpu/cuda/cusparse_helper.h index ff141f5f1aed0..1ee638242e112 100644 --- a/paddle/fluid/platform/device/gpu/cuda/cusparse_helper.h +++ b/paddle/fluid/platform/device/gpu/cuda/cusparse_helper.h @@ -31,7 +31,7 @@ class CusparseHandleHolder { // ROCM is not yet supported #if defined(PADDLE_WITH_CUDA) // The generic APIs is supported from CUDA10.1 -#if CUDA_VERSION >= 10010 +#if CUDA_VERSION >= 11000 PADDLE_RETRY_CUDA_SUCCESS(dynload::cusparseCreate(&handle_)); PADDLE_RETRY_CUDA_SUCCESS(dynload::cusparseSetStream(handle_, stream)); #endif @@ -41,7 +41,7 @@ class CusparseHandleHolder { ~CusparseHandleHolder() PADDLE_MAY_THROW { #if defined(PADDLE_WITH_CUDA) -#if CUDA_VERSION >= 10010 +#if CUDA_VERSION >= 11000 PADDLE_RETRY_CUDA_SUCCESS(dynload::cusparseDestroy(handle_)); #endif #endif diff --git a/paddle/fluid/platform/dynload/cusparse.cc b/paddle/fluid/platform/dynload/cusparse.cc index 998437997547b..da93455e8bc7d 100644 --- a/paddle/fluid/platform/dynload/cusparse.cc +++ b/paddle/fluid/platform/dynload/cusparse.cc @@ -24,10 +24,6 @@ namespace dynload { CUSPARSE_ROUTINE_EACH(DEFINE_WRAP); #endif -#ifdef CUSPARSE_ROUTINE_EACH_11020 -CUSPARSE_ROUTINE_EACH_11020(DEFINE_WRAP); -#endif - #ifdef CUSPARSE_ROUTINE_EACH_R2 CUSPARSE_ROUTINE_EACH_R2(DEFINE_WRAP); #endif diff --git a/paddle/fluid/platform/dynload/cusparse.h b/paddle/fluid/platform/dynload/cusparse.h index 925852bb4158b..7f29ec0e823a4 100644 --- a/paddle/fluid/platform/dynload/cusparse.h +++ b/paddle/fluid/platform/dynload/cusparse.h @@ -29,23 +29,17 @@ namespace dynload { extern DynLoad__##__name __name #if defined(PADDLE_WITH_CUDA) -// The generic APIs is supported from CUDA10.1 -#if CUDA_VERSION >= 10010 -#define CUSPARSE_ROUTINE_EACH(__macro) \ - __macro(cusparseCreate); \ - __macro(cusparseSetStream); \ - __macro(cusparseCreateMatDescr); \ - __macro(cusparseDestroy); \ - __macro(cusparseSnnz); \ - __macro(cusparseDnnz); \ - __macro(cusparseSetMatType); \ - __macro(cusparseSetMatIndexBase); - -CUSPARSE_ROUTINE_EACH(PLATFORM_DECLARE_DYNAMIC_LOAD_CUSPARSE_WRAP); - -// APIs available after CUDA 11.2 -#if CUDA_VERSION >= 11020 -#define CUSPARSE_ROUTINE_EACH_11020(__macro) \ +// APIs available after CUDA 11.0 +#if CUDA_VERSION >= 11000 +#define CUSPARSE_ROUTINE_EACH(__macro) \ + __macro(cusparseCreate); \ + __macro(cusparseSetStream); \ + __macro(cusparseCreateMatDescr); \ + __macro(cusparseDestroy); \ + __macro(cusparseSnnz); \ + __macro(cusparseDnnz); \ + __macro(cusparseSetMatType); \ + __macro(cusparseSetMatIndexBase); \ __macro(cusparseCreateCsr); \ __macro(cusparseCreateCoo); \ __macro(cusparseCreateDnMat); \ @@ -59,11 +53,13 @@ CUSPARSE_ROUTINE_EACH(PLATFORM_DECLARE_DYNAMIC_LOAD_CUSPARSE_WRAP); __macro(cusparseDenseToSparse_analysis); \ __macro(cusparseDenseToSparse_convert); \ __macro(cusparseSparseToDense_bufferSize); \ - __macro(cusparseSparseToDense); + __macro(cusparseSparseToDense); \ + __macro(cusparseDnMatSetStridedBatch); \ + __macro(cusparseCsrSetStridedBatch); -CUSPARSE_ROUTINE_EACH_11020(PLATFORM_DECLARE_DYNAMIC_LOAD_CUSPARSE_WRAP) +CUSPARSE_ROUTINE_EACH(PLATFORM_DECLARE_DYNAMIC_LOAD_CUSPARSE_WRAP) +#endif -// APIs available after CUDA 11.3 #if CUDA_VERSION >= 11030 #define CUSPARSE_ROUTINE_EACH_R2(__macro) \ __macro(cusparseSDDMM_bufferSize); \ @@ -72,8 +68,7 @@ CUSPARSE_ROUTINE_EACH_11020(PLATFORM_DECLARE_DYNAMIC_LOAD_CUSPARSE_WRAP) CUSPARSE_ROUTINE_EACH_R2(PLATFORM_DECLARE_DYNAMIC_LOAD_CUSPARSE_WRAP) #endif -#endif -#endif + #endif #undef PLATFORM_DECLARE_DYNAMIC_LOAD_CUSPARSE_WRAP diff --git a/paddle/phi/backends/dynload/cusparse.cc b/paddle/phi/backends/dynload/cusparse.cc index 326645726bbed..013211064b8e4 100644 --- a/paddle/phi/backends/dynload/cusparse.cc +++ b/paddle/phi/backends/dynload/cusparse.cc @@ -26,10 +26,6 @@ void *cusparse_dso_handle; CUSPARSE_ROUTINE_EACH(DEFINE_WRAP); #endif -#ifdef CUSPARSE_ROUTINE_EACH_11020 -CUSPARSE_ROUTINE_EACH_11020(DEFINE_WRAP); -#endif - #ifdef CUSPARSE_ROUTINE_EACH_R2 CUSPARSE_ROUTINE_EACH_R2(DEFINE_WRAP); #endif diff --git a/paddle/phi/backends/dynload/cusparse.h b/paddle/phi/backends/dynload/cusparse.h index a7e305f98d49a..6160faf1f422d 100644 --- a/paddle/phi/backends/dynload/cusparse.h +++ b/paddle/phi/backends/dynload/cusparse.h @@ -30,34 +30,28 @@ extern void *cusparse_dso_handle; struct DynLoad__##__name { \ template \ cusparseStatus_t operator()(Args... args) { \ - using cusparseFunc = decltype(&::__name); \ + using Func = decltype(&::__name); \ std::call_once(cusparse_dso_flag, []() { \ cusparse_dso_handle = phi::dynload::GetCusparseDsoHandle(); \ }); \ static void *p_##__name = dlsym(cusparse_dso_handle, #__name); \ - return reinterpret_cast(p_##__name)(args...); \ + return reinterpret_cast(p_##__name)(args...); \ } \ }; \ extern DynLoad__##__name __name #if defined(PADDLE_WITH_CUDA) -// The generic APIs is supported from CUDA10.1 -#if CUDA_VERSION >= 10010 -#define CUSPARSE_ROUTINE_EACH(__macro) \ - __macro(cusparseCreate); \ - __macro(cusparseSetStream); \ - __macro(cusparseCreateMatDescr); \ - __macro(cusparseDestroy); \ - __macro(cusparseSnnz); \ - __macro(cusparseDnnz); \ - __macro(cusparseSetMatType); \ - __macro(cusparseSetMatIndexBase); - -CUSPARSE_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUSPARSE_WRAP); - -// APIs available after CUDA 11.2 -#if CUDA_VERSION >= 11020 -#define CUSPARSE_ROUTINE_EACH_11020(__macro) \ +// APIs available after CUDA 11.0 +#if CUDA_VERSION >= 11000 +#define CUSPARSE_ROUTINE_EACH(__macro) \ + __macro(cusparseCreate); \ + __macro(cusparseSetStream); \ + __macro(cusparseCreateMatDescr); \ + __macro(cusparseDestroy); \ + __macro(cusparseSnnz); \ + __macro(cusparseDnnz); \ + __macro(cusparseSetMatType); \ + __macro(cusparseSetMatIndexBase); \ __macro(cusparseCreateCsr); \ __macro(cusparseCreateCoo); \ __macro(cusparseCreateDnMat); \ @@ -71,11 +65,13 @@ CUSPARSE_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUSPARSE_WRAP); __macro(cusparseDenseToSparse_analysis); \ __macro(cusparseDenseToSparse_convert); \ __macro(cusparseSparseToDense_bufferSize); \ - __macro(cusparseSparseToDense); + __macro(cusparseSparseToDense); \ + __macro(cusparseDnMatSetStridedBatch); \ + __macro(cusparseCsrSetStridedBatch); -CUSPARSE_ROUTINE_EACH_11020(DECLARE_DYNAMIC_LOAD_CUSPARSE_WRAP) +CUSPARSE_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUSPARSE_WRAP) +#endif -// APIs available after CUDA 11.3 #if CUDA_VERSION >= 11030 #define CUSPARSE_ROUTINE_EACH_R2(__macro) \ __macro(cusparseSDDMM_bufferSize); \ @@ -84,8 +80,7 @@ CUSPARSE_ROUTINE_EACH_11020(DECLARE_DYNAMIC_LOAD_CUSPARSE_WRAP) CUSPARSE_ROUTINE_EACH_R2(DECLARE_DYNAMIC_LOAD_CUSPARSE_WRAP) #endif -#endif -#endif + #endif #undef DECLARE_DYNAMIC_LOAD_CUSPARSE_WRAP diff --git a/paddle/phi/backends/gpu/gpu_context.cc b/paddle/phi/backends/gpu/gpu_context.cc index ead53d648109d..92c1fedae44af 100644 --- a/paddle/phi/backends/gpu/gpu_context.cc +++ b/paddle/phi/backends/gpu/gpu_context.cc @@ -402,7 +402,10 @@ struct GPUContext::Impl { void SetSolverHandle(solverHandle_t handle) { solver_handle_ = handle; } - sparseHandle_t GetSparseHandle() const { + sparseHandle_t GetSparseHandle() { + std::call_once(flag_sparse_, [=]() { + if (!sparse_handle_) phi::InitSparseHandle(&sparse_handle_, stream_); + }); PD_CHECK(sparse_handle_ != nullptr, "the gpu sparse handle is nullptr."); return sparse_handle_; } @@ -519,7 +522,12 @@ struct GPUContext::Impl { } inline void CusparseCall( - const std::function& callback) const { + const std::function& callback) { + std::call_once(flag_sparse_, [=]() { + if (!sparse_handle_) { + phi::InitSparseHandle(&sparse_handle_, stream_); + } + }); std::lock_guard guard(sparse_mtx_); callback(sparse_handle_); } @@ -598,6 +606,7 @@ struct GPUContext::Impl { sparseHandle_t sparse_handle_{nullptr}; DnnWorkspaceHandle* workspace_{nullptr}; + std::once_flag flag_sparse_; std::once_flag flag_blas_; std::once_flag flag_blaslt_; std::once_flag flag_dnn_; diff --git a/paddle/phi/backends/gpu/gpu_resources.cc b/paddle/phi/backends/gpu/gpu_resources.cc index 268024eb25949..0257139914384 100644 --- a/paddle/phi/backends/gpu/gpu_resources.cc +++ b/paddle/phi/backends/gpu/gpu_resources.cc @@ -250,7 +250,7 @@ void InitSparseHandle(sparseHandle_t* handle, gpuStream_t stream) { // ROCM is not yet supported #if defined(PADDLE_WITH_CUDA) // The generic APIs is supported from CUDA10.1 -#if CUDA_VERSION >= 10010 +#if CUDA_VERSION >= 11000 PADDLE_RETRY_CUDA_SUCCESS(dynload::cusparseCreate(handle)); PADDLE_RETRY_CUDA_SUCCESS(dynload::cusparseSetStream(*handle, stream)); #endif @@ -259,7 +259,7 @@ void InitSparseHandle(sparseHandle_t* handle, gpuStream_t stream) { void DestroySparseHandle(sparseHandle_t handle) { #ifdef PADDLE_WITH_CUDA -#if CUDA_VERSION >= 10010 +#if CUDA_VERSION >= 11000 if (handle != nullptr) { PADDLE_RETRY_CUDA_SUCCESS(dynload::cusparseDestroy(handle)); handle = nullptr; diff --git a/paddle/phi/core/sparse_csr_tensor.h b/paddle/phi/core/sparse_csr_tensor.h index 7e14cad242d12..0da69ee7ed16c 100644 --- a/paddle/phi/core/sparse_csr_tensor.h +++ b/paddle/phi/core/sparse_csr_tensor.h @@ -85,6 +85,10 @@ class SparseCsrTensor : public TensorBase, /// \return The non zero elemetns in original dense tensor. const DenseTensor& non_zero_elements() const { return non_zero_elements_; } + /// \brief Returns the total number of non zero elements in original dense + /// tensor. + int64_t nnz() const { return non_zero_elements_.numel(); } + /// \brief Return the number of elements contained in original dense tensor /// \return The number of elements contained in original dense tensor int64_t numel() const override { return product(dims_); } diff --git a/paddle/phi/kernels/funcs/sparse/sparse_blas.h b/paddle/phi/kernels/funcs/sparse/sparse_blas.h new file mode 100644 index 0000000000000..edad70edd74be --- /dev/null +++ b/paddle/phi/kernels/funcs/sparse/sparse_blas.h @@ -0,0 +1,96 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/device_context.h" +#include "paddle/phi/core/sparse_coo_tensor.h" +#include "paddle/phi/core/sparse_csr_tensor.h" + +namespace phi { +namespace funcs { +namespace sparse { + +template +class SparseBlas { + public: + explicit SparseBlas(const DeviceContext& dev_ctx) : dev_ctx_(dev_ctx) {} + + // TODO(zhouwei25): implement "COO @ DENSE -> DENSE" of DSDMM + template + void DSDMM(bool transa, + bool transb, + T alpha, + const phi::SparseCooTensor& mat_a, + const phi::DenseTensor& mat_b, + T beta, + phi::DenseTensor* mat_c) const; + + template + void DSDMM(bool transa, + bool transb, + T alpha, + const phi::SparseCsrTensor& mat_a, + const phi::DenseTensor& mat_b, + T beta, + phi::DenseTensor* mat_c) const; + + template + void SDDMM(bool transa, + bool transb, + T alpha, + const phi::DenseTensor& mat_a, + const phi::DenseTensor& mat_b, + T beta, + phi::SparseCsrTensor* mat_c) const; + + private: + const DeviceContext& dev_ctx_; +}; + +template +class SparseBlasT : private SparseBlas { + public: + using SparseBlas::SparseBlas; + + template + void DSDMM(ARGS... args) const { + Base()->template DSDMM(args...); + } + + template + void SDDMM(ARGS... args) const { + Base()->template SDDMM(args...); + } + + private: + const SparseBlas* Base() const { + return static_cast*>(this); + } +}; + +template +inline SparseBlasT GetSparseBlas( + const DeviceContext& dev_ctx) { + return SparseBlasT(dev_ctx); +} + +} // namespace sparse +} // namespace funcs +} // namespace phi + +#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11000 +#include "paddle/phi/kernels/funcs/sparse/sparse_blas_impl.cu.h" +#endif diff --git a/paddle/phi/kernels/funcs/sparse/sparse_blas_impl.cu.h b/paddle/phi/kernels/funcs/sparse/sparse_blas_impl.cu.h new file mode 100644 index 0000000000000..0c54f99bef80b --- /dev/null +++ b/paddle/phi/kernels/funcs/sparse/sparse_blas_impl.cu.h @@ -0,0 +1,279 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/memory/malloc.h" +#include "paddle/phi/backends/dynload/cusparse.h" +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/common/float16.h" +#include "paddle/phi/core/ddim.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/sparse_coo_tensor.h" +#include "paddle/phi/core/sparse_csr_tensor.h" +#include "paddle/phi/core/visit_type.h" + +namespace phi { +namespace funcs { +namespace sparse { + +template +cudaDataType_t GetGpuDataType() { + if (std::is_same::value) { + return CUDA_R_32F; + } else if (std::is_same::value) { + return CUDA_R_64F; + } else if (std::is_same::value) { + return CUDA_R_16F; + } +} + +inline cusparseOperation_t GetTransposeOperation(const bool trans) { + if (trans) { + return CUSPARSE_OPERATION_TRANSPOSE; + } else { + return CUSPARSE_OPERATION_NON_TRANSPOSE; + } +} + +template +class CuSparseSpMatDescriptor { + public: + explicit CuSparseSpMatDescriptor(const phi::SparseCsrTensor& x, + const phi::GPUContext& dev_ctx) + : dev_ctx_(dev_ctx) { + PD_VISIT_INTEGRAL_TYPES( + x.non_zero_crows().dtype(), "CuSparseSpMatDescriptor", ([&] { + const data_t* crows_data = x.non_zero_crows().data(); + const data_t* cols_data = x.non_zero_cols().data(); + const T* values_data = x.non_zero_elements().data(); + int64_t nnz = x.nnz(); + + std::vector xdim_vec = phi::vectorize(x.dims()); + auto x_ndims = xdim_vec.size(); + int64_t M = xdim_vec[x_ndims - 2]; + int64_t N = xdim_vec[x_ndims - 1]; + int batch_size = 1; + for (int i = 0; i < x_ndims - 2; i++) { + batch_size *= xdim_vec[i]; + } + + cudaDataType_t gpu_type = GetGpuDataType(); + dev_ctx_.CusparseCall([&](cusparseHandle_t handle) { + phi::dynload::cusparseCreateCsr(&descriptor_, + M, + N, + nnz, + const_cast(crows_data), + const_cast(cols_data), + const_cast(values_data), + CUSPARSE_INDEX_64I, + CUSPARSE_INDEX_64I, + CUSPARSE_INDEX_BASE_ZERO, + gpu_type); + }); + PADDLE_ENFORCE_EQ(x.non_zero_crows().numel(), batch_size * (M + 1)); + PADDLE_ENFORCE_EQ(x.non_zero_cols().numel(), x.nnz()); + if (batch_size > 1) { + dev_ctx_.CusparseCall([&](cusparseHandle_t handle) { + phi::dynload::cusparseCsrSetStridedBatch( + descriptor_, batch_size, M + 1, nnz); + }); + } + })); + + VLOG(6) << "Create cusparseSpMatDescr_t " << &descriptor_; + } + + ~CuSparseSpMatDescriptor() { + dev_ctx_.CusparseCall([&](cusparseHandle_t handle) { + phi::dynload::cusparseDestroySpMat(descriptor_); + }); + VLOG(6) << "Destroy cusparseSpMatDescr_t " << &descriptor_; + } + + const cusparseSpMatDescr_t& descriptor() const { return descriptor_; } + + private: + const phi::GPUContext& dev_ctx_; + cusparseSpMatDescr_t descriptor_; +}; + +template +class CuSparseDnMatDescriptor { + public: + explicit CuSparseDnMatDescriptor(const phi::DenseTensor& x, + const phi::GPUContext& dev_ctx) + : dev_ctx_(dev_ctx) { + const T* x_data = x.data(); + std::vector xdim_vec = phi::vectorize(x.dims()); + auto x_ndims = xdim_vec.size(); + int64_t M = xdim_vec[x_ndims - 2]; + int64_t N = xdim_vec[x_ndims - 1]; + int batch_size = 1; + for (int i = 0; i < x_ndims - 2; i++) { + batch_size *= xdim_vec[i]; + } + + cudaDataType_t gpu_type = GetGpuDataType(); + dev_ctx_.CusparseCall([&](cusparseHandle_t handle) { + phi::dynload::cusparseCreateDnMat(&descriptor_, + M, + N, + N, + const_cast(x_data), + gpu_type, + CUSPARSE_ORDER_ROW); + }); + + PADDLE_ENFORCE_EQ(x.numel(), batch_size * M * N); + if (batch_size > 1) { + dev_ctx_.CusparseCall([&](cusparseHandle_t handle) { + phi::dynload::cusparseDnMatSetStridedBatch( + descriptor_, batch_size, M * N); + }); + } + VLOG(6) << "Create cusparseDnMatDescr_t " << &descriptor_; + } + + ~CuSparseDnMatDescriptor() { + dev_ctx_.CusparseCall([&](cusparseHandle_t handle) { + phi::dynload::cusparseDestroyDnMat(descriptor_); + }); + VLOG(6) << "Destroy cusparseDnMatDescr_t " << &descriptor_; + } + + const cusparseDnMatDescr_t& descriptor() const { return descriptor_; } + + private: + const phi::GPUContext& dev_ctx_; + cusparseDnMatDescr_t descriptor_; +}; + +template <> +template +void SparseBlas::DSDMM(bool transa, + bool transb, + T alpha, + const phi::SparseCsrTensor& mat_a, + const phi::DenseTensor& mat_b, + T beta, + phi::DenseTensor* mat_c) const { + cudaDataType_t gpu_type = GetGpuDataType(); + + auto a_descriptor = CuSparseSpMatDescriptor(mat_a, dev_ctx_); + auto b_descriptor = CuSparseDnMatDescriptor(mat_b, dev_ctx_); + auto c_descriptor = CuSparseDnMatDescriptor(*mat_c, dev_ctx_); + + size_t buffer_size = 0; + dev_ctx_.CusparseCall([&](cusparseHandle_t handle) { + phi::dynload::cusparseSpMM_bufferSize(handle, + GetTransposeOperation(transa), + GetTransposeOperation(transb), + &alpha, + a_descriptor.descriptor(), + b_descriptor.descriptor(), + &beta, + c_descriptor.descriptor(), + gpu_type, + CUSPARSE_SPMM_ALG_DEFAULT, + &buffer_size); + }); + + paddle::memory::allocation::AllocationPtr tmp_buffer = + paddle::memory::Alloc(dev_ctx_, buffer_size); + void* tmp_buffer_ptr = tmp_buffer->ptr(); + dev_ctx_.CusparseCall([&](cusparseHandle_t handle) { + phi::dynload::cusparseSpMM(handle, + GetTransposeOperation(transa), + GetTransposeOperation(transb), + &alpha, + a_descriptor.descriptor(), + b_descriptor.descriptor(), + &beta, + c_descriptor.descriptor(), + gpu_type, + CUSPARSE_SPMM_ALG_DEFAULT, + tmp_buffer_ptr); + }); +} + +#if CUDA_VERSION >= 11030 +template <> +template +void SparseBlas::SDDMM(bool transa, + bool transb, + T alpha, + const phi::DenseTensor& mat_a, + const phi::DenseTensor& mat_b, + T beta, + phi::SparseCsrTensor* mat_c) const { + cudaDataType_t gpu_type = GetGpuDataType(); + + auto a_descriptor = CuSparseDnMatDescriptor(mat_a, dev_ctx_); + auto b_descriptor = CuSparseDnMatDescriptor(mat_b, dev_ctx_); + auto c_descriptor = CuSparseSpMatDescriptor(*mat_c, dev_ctx_); + + size_t buffer_size = 0; + dev_ctx_.CusparseCall([&](cusparseHandle_t handle) { + phi::dynload::cusparseSDDMM_bufferSize(handle, + GetTransposeOperation(transa), + GetTransposeOperation(transb), + &alpha, + a_descriptor.descriptor(), + b_descriptor.descriptor(), + &beta, + c_descriptor.descriptor(), + gpu_type, + CUSPARSE_SDDMM_ALG_DEFAULT, + &buffer_size); + }); + + paddle::memory::allocation::AllocationPtr tmp_buffer = + paddle::memory::Alloc(dev_ctx_, buffer_size); + void* tmp_buffer_ptr = tmp_buffer->ptr(); + + dev_ctx_.CusparseCall([&](cusparseHandle_t handle) { + phi::dynload::cusparseSDDMM_preprocess(handle, + GetTransposeOperation(transa), + GetTransposeOperation(transb), + &alpha, + a_descriptor.descriptor(), + b_descriptor.descriptor(), + &beta, + c_descriptor.descriptor(), + gpu_type, + CUSPARSE_SDDMM_ALG_DEFAULT, + tmp_buffer_ptr); + }); + + dev_ctx_.CusparseCall([&](cusparseHandle_t handle) { + phi::dynload::cusparseSDDMM(handle, + GetTransposeOperation(transa), + GetTransposeOperation(transb), + &alpha, + a_descriptor.descriptor(), + b_descriptor.descriptor(), + &beta, + c_descriptor.descriptor(), + gpu_type, + CUSPARSE_SDDMM_ALG_DEFAULT, + tmp_buffer_ptr); + }); +} +#endif + +} // namespace sparse +} // namespace funcs +} // namespace phi diff --git a/paddle/phi/kernels/sparse/cpu/full_kernel.cc b/paddle/phi/kernels/sparse/cpu/full_kernel.cc new file mode 100644 index 0000000000000..3c8be16626202 --- /dev/null +++ b/paddle/phi/kernels/sparse/cpu/full_kernel.cc @@ -0,0 +1,115 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/phi/kernels/sparse/full_kernel.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/copy_kernel.h" +#include "paddle/phi/kernels/funcs/eigen/common.h" +#include "paddle/phi/kernels/funcs/eigen/eigen_function.h" + +namespace phi { + +template +void FullValue(const Context& dev_ctx, DenseTensor* tensor, T val) { + dev_ctx.template Alloc(tensor); + auto t = phi::EigenVector::Flatten(*tensor); + t.device(*dev_ctx.eigen_device()) = t.constant(val); +} + +template +void CooFullLikeKernel(const Context& dev_ctx, + const SparseCooTensor& x, + const Scalar& val, + DataType dtype, + SparseCooTensor* out) { + phi::Copy(dev_ctx, + x.non_zero_indices(), + dev_ctx.GetPlace(), + false, + out->mutable_non_zero_indices()); + + DenseTensor* values = out->mutable_non_zero_elements(); + values->Resize(x.non_zero_elements().dims()); + dev_ctx.template Alloc(values); + FullValue(dev_ctx, values, val.to()); + + out->set_dims(x.dims()); +} + +template +void CsrFullLikeKernel(const Context& dev_ctx, + const SparseCsrTensor& x, + const Scalar& val, + DataType dtype, + SparseCsrTensor* out) { + phi::Copy(dev_ctx, + x.non_zero_crows(), + dev_ctx.GetPlace(), + false, + out->mutable_non_zero_crows()); + + phi::Copy(dev_ctx, + x.non_zero_cols(), + dev_ctx.GetPlace(), + false, + out->mutable_non_zero_cols()); + + DenseTensor* values = out->mutable_non_zero_elements(); + values->Resize(x.non_zero_elements().dims()); + dev_ctx.template Alloc(values); + FullValue(dev_ctx, values, val.to()); + + out->set_dims(x.dims()); +} + +} // namespace phi + +PD_REGISTER_KERNEL(coo_full_like, + CPU, + ALL_LAYOUT, + phi::CooFullLikeKernel, + float, + double, + uint8_t, + int16_t, + int, + int64_t, + bool, + phi::dtype::bfloat16, + phi::dtype::float16, + phi::dtype::complex, + phi::dtype::complex) { + kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO); +} + +PD_REGISTER_KERNEL(csr_full_like, + CPU, + ALL_LAYOUT, + phi::CsrFullLikeKernel, + float, + double, + uint8_t, + int16_t, + int, + int64_t, + bool, + phi::dtype::bfloat16, + phi::dtype::float16, + phi::dtype::complex, + phi::dtype::complex) { + kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_CSR); +} diff --git a/paddle/phi/kernels/sparse/cpu/matmul_grad_kernel.cc b/paddle/phi/kernels/sparse/cpu/matmul_grad_kernel.cc new file mode 100644 index 0000000000000..cd1665b66431b --- /dev/null +++ b/paddle/phi/kernels/sparse/cpu/matmul_grad_kernel.cc @@ -0,0 +1,64 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/phi/kernels/sparse/matmul_grad_kernel.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { +namespace sparse { + +// TODO(zhouwei25): implement CPU backward kernel of " CSR @ DENSE -> DENSE" +template +void CsrDenseMatmulGradKernel(const Context& dev_ctx, + const SparseCsrTensor& x, + const DenseTensor& y, + const DenseTensor& dout, + SparseCsrTensor* dx, + DenseTensor* dy) { + PADDLE_THROW(phi::errors::Unimplemented( + "Not support CPU backward kernel of Sparse Matmul now.")); +} + +// TODO(zhouwei25): implement CPU kernel of " DENSE @ DENSE * CSR_MASK -> CSR" +template +void CsrMaskedMatmulGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + const SparseCsrTensor& dout, + DenseTensor* dx, + DenseTensor* dy) { + PADDLE_THROW(phi::errors::Unimplemented( + "Not support CPU backward kernel of Matmul Mask As Sparse now.")); +} + +} // namespace sparse +} // namespace phi + +PD_REGISTER_KERNEL(csr_dense_matmul_grad, + CPU, + ALL_LAYOUT, + phi::sparse::CsrDenseMatmulGradKernel, + float, + double) { + kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_CSR); +} + +PD_REGISTER_KERNEL(csr_masked_matmul_grad, + CPU, + ALL_LAYOUT, + phi::sparse::CsrMaskedMatmulGradKernel, + float, + double) {} diff --git a/paddle/phi/kernels/sparse/cpu/matmul_kernel.cc b/paddle/phi/kernels/sparse/cpu/matmul_kernel.cc new file mode 100644 index 0000000000000..10ad848442c9d --- /dev/null +++ b/paddle/phi/kernels/sparse/cpu/matmul_kernel.cc @@ -0,0 +1,61 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/phi/kernels/sparse/matmul_kernel.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { +namespace sparse { + +// TODO(zhouwei25): implement CPU kernel of " CSR @ DENSE -> DENSE" +template +void CsrDenseMatmulKernel(const Context& dev_ctx, + const SparseCsrTensor& x, + const DenseTensor& y, + DenseTensor* out) { + PADDLE_THROW(phi::errors::Unimplemented( + "Not support CPU kernel of Sparse Matmul now.")); +} + +// TODO(zhouwei25): implement CPU kernel of " DENSE @ DENSE * CSR_MASK -> CSR" +template +void CsrMaskedMatmulKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + const SparseCsrTensor& mask, + SparseCsrTensor* out) { + PADDLE_THROW(phi::errors::Unimplemented( + "Not support CPU kernel of Matmul Mask As Sparse now.")); +} + +} // namespace sparse +} // namespace phi + +PD_REGISTER_KERNEL(csr_dense_matmul, + CPU, + ALL_LAYOUT, + phi::sparse::CsrDenseMatmulKernel, + float, + double) { + kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_CSR); +} + +PD_REGISTER_KERNEL(csr_masked_matmul, + CPU, + ALL_LAYOUT, + phi::sparse::CsrMaskedMatmulKernel, + float, + double) {} diff --git a/paddle/phi/kernels/sparse/full_kernel.h b/paddle/phi/kernels/sparse/full_kernel.h new file mode 100644 index 0000000000000..8c84d43ff0219 --- /dev/null +++ b/paddle/phi/kernels/sparse/full_kernel.h @@ -0,0 +1,38 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/common/data_type.h" +#include "paddle/phi/common/scalar.h" +#include "paddle/phi/core/sparse_coo_tensor.h" +#include "paddle/phi/core/sparse_csr_tensor.h" + +namespace phi { + +template +void CooFullLikeKernel(const Context& dev_ctx, + const SparseCooTensor& x, + const Scalar& val, + DataType dtype, + SparseCooTensor* out); + +template +void CsrFullLikeKernel(const Context& dev_ctx, + const SparseCsrTensor& x, + const Scalar& val, + DataType dtype, + SparseCsrTensor* out); + +} // namespace phi diff --git a/paddle/phi/kernels/sparse/gpu/full_kernel.cu b/paddle/phi/kernels/sparse/gpu/full_kernel.cu new file mode 100644 index 0000000000000..6c1e0160bca8e --- /dev/null +++ b/paddle/phi/kernels/sparse/gpu/full_kernel.cu @@ -0,0 +1,132 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/copy_kernel.h" +#include "paddle/phi/kernels/funcs/elementwise_base.h" +#include "paddle/phi/kernels/sparse/full_kernel.h" + +namespace phi { + +template +struct FullFuctor { + OutT value; + + template + explicit inline FullFuctor(VType val) { + value = static_cast(val); + } + + __device__ __forceinline__ OutT operator()() const { + return static_cast(value); + } +}; + +template +void CooFullLikeKernel(const Context& dev_ctx, + const SparseCooTensor& x, + const Scalar& val, + DataType dtype, + SparseCooTensor* out) { + phi::Copy(dev_ctx, + x.non_zero_indices(), + dev_ctx.GetPlace(), + false, + out->mutable_non_zero_indices()); + + DenseTensor* values = out->mutable_non_zero_elements(); + values->Resize(x.non_zero_elements().dims()); + dev_ctx.template Alloc(values); + + std::vector inputs = {}; + std::vector outputs = {values}; + int numel = values->numel(); + if (numel > 0) { + phi::funcs::ElementwiseKernel( + dev_ctx, inputs, &outputs, FullFuctor(val.to())); + } + out->set_dims(x.dims()); +} + +template +void CsrFullLikeKernel(const Context& dev_ctx, + const SparseCsrTensor& x, + const Scalar& val, + DataType dtype, + SparseCsrTensor* out) { + phi::Copy(dev_ctx, + x.non_zero_crows(), + dev_ctx.GetPlace(), + false, + out->mutable_non_zero_crows()); + + phi::Copy(dev_ctx, + x.non_zero_cols(), + dev_ctx.GetPlace(), + false, + out->mutable_non_zero_cols()); + + DenseTensor* values = out->mutable_non_zero_elements(); + values->Resize(x.non_zero_elements().dims()); + dev_ctx.template Alloc(values); + + std::vector inputs = {}; + std::vector outputs = {values}; + int numel = values->numel(); + if (numel > 0) { + phi::funcs::ElementwiseKernel( + dev_ctx, inputs, &outputs, FullFuctor(val.to())); + } + out->set_dims(x.dims()); +} + +} // namespace phi + +PD_REGISTER_KERNEL(coo_full_like, + GPU, + ALL_LAYOUT, + phi::CooFullLikeKernel, + float, + double, + uint8_t, + int16_t, + int, + int64_t, + bool, + phi::dtype::bfloat16, + phi::dtype::float16, + phi::dtype::complex, + phi::dtype::complex) { + kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO); +} + +PD_REGISTER_KERNEL(csr_full_like, + GPU, + ALL_LAYOUT, + phi::CsrFullLikeKernel, + float, + double, + uint8_t, + int16_t, + int, + int64_t, + bool, + phi::dtype::bfloat16, + phi::dtype::float16, + phi::dtype::complex, + phi::dtype::complex) { + kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_CSR); +} diff --git a/paddle/phi/kernels/sparse/gpu/matmul_grad_kernel.cu b/paddle/phi/kernels/sparse/gpu/matmul_grad_kernel.cu new file mode 100644 index 0000000000000..30344c307541e --- /dev/null +++ b/paddle/phi/kernels/sparse/gpu/matmul_grad_kernel.cu @@ -0,0 +1,149 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/copy_kernel.h" +#include "paddle/phi/kernels/empty_kernel.h" +#include "paddle/phi/kernels/funcs/sparse/sparse_blas.h" +#include "paddle/phi/kernels/sparse/matmul_grad_kernel.h" +#include "paddle/phi/kernels/transpose_kernel.h" + +namespace phi { +namespace sparse { + +template +void CsrDenseMatmulGradKernel(const Context& dev_ctx, + const SparseCsrTensor& x, + const DenseTensor& y, + const DenseTensor& dout, + SparseCsrTensor* dx, + DenseTensor* dy) { +#if CUDA_VERSION >= 11030 + auto sparse_blas = phi::funcs::sparse::GetSparseBlas(dev_ctx); + + // dx{SparseCsr} = dout{Dense} * y'{Dense} + if (dx) { + // InferMeta of SparseCsrTensor 'dx' + dx->set_dims(x.dims()); + + phi::Copy(dev_ctx, + x.non_zero_crows(), + dev_ctx.GetPlace(), + false, + dx->mutable_non_zero_crows()); + phi::Copy(dev_ctx, + x.non_zero_cols(), + dev_ctx.GetPlace(), + false, + dx->mutable_non_zero_cols()); + + DenseTensor* values = dx->mutable_non_zero_elements(); + values->Resize(x.non_zero_elements().dims()); + dev_ctx.template Alloc(values); + + sparse_blas.SDDMM( + false, true, static_cast(1), dout, y, static_cast(0), dx); + } + + // dy{Dense} = x'{SparseCsr} * dout{Dense} + if (dy) { + // InferMeta of DenseTensor 'dy' + MetaTensor meta_dy(dy); + meta_dy.set_dims(y.dims()); + meta_dy.set_dtype(y.dtype()); + + dev_ctx.template Alloc(dy); + + sparse_blas.DSDMM( + true, false, static_cast(1), x, dout, static_cast(0), dy); + } +#else + PADDLE_THROW(phi::errors::Unimplemented( + " backward of 'sparse.mm' use cusparseSDDMM, Only " + "support it from CUDA 11.3")); +#endif +} + +template +void CsrMaskedMatmulGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + const SparseCsrTensor& dout, + DenseTensor* dx, + DenseTensor* dy) { +#if CUDA_VERSION >= 11000 + auto sparse_blas = phi::funcs::sparse::GetSparseBlas(dev_ctx); + + // dx{Dense} = dout{SparseCsr} * y'{Dense} + if (dx) { + // InferMeta of DenseTensor 'dx' + MetaTensor meta_dx(dx); + meta_dx.set_dims(x.dims()); + meta_dx.set_dtype(x.dtype()); + + dev_ctx.template Alloc(dx); + sparse_blas.DSDMM( + false, true, static_cast(1), dout, y, static_cast(0), dx); + } + + // dy{Dense} = x'{Dense} * dout{SparseCsr} + // That is: dy'{Dense} = dout'{SparseCsr} * x{Dense} + if (dy) { + std::vector trans_dim_vec = phi::vectorize(y.dims()); + size_t rank = trans_dim_vec.size(); + std::swap(trans_dim_vec[rank - 1], trans_dim_vec[rank - 2]); + DenseTensor trans_dy = phi::Empty(dev_ctx, trans_dim_vec); + + sparse_blas.DSDMM( + true, false, static_cast(1), dout, x, static_cast(0), &trans_dy); + + // InferMeta of DenseTensor 'dy' + MetaTensor meta_dy(dy); + meta_dy.set_dims(y.dims()); + meta_dy.set_dtype(y.dtype()); + + dev_ctx.template Alloc(dy); + + size_t y_ndim = y.dims().size(); + std::vector axis(y_ndim); + for (size_t i = 0; i < y_ndim; ++i) { + axis[i] = i; + } + std::swap(axis[y_ndim - 1], axis[y_ndim - 2]); + TransposeKernel(dev_ctx, trans_dy, axis, dy); + } +#endif +} + +} // namespace sparse +} // namespace phi + +PD_REGISTER_KERNEL(csr_dense_matmul_grad, + GPU, + ALL_LAYOUT, + phi::sparse::CsrDenseMatmulGradKernel, + float, + double) { + kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_CSR); +} + +PD_REGISTER_KERNEL(csr_masked_matmul_grad, + GPU, + ALL_LAYOUT, + phi::sparse::CsrMaskedMatmulGradKernel, + float, + double) {} diff --git a/paddle/phi/kernels/sparse/gpu/matmul_kernel.cu b/paddle/phi/kernels/sparse/gpu/matmul_kernel.cu new file mode 100644 index 0000000000000..f8590737eda07 --- /dev/null +++ b/paddle/phi/kernels/sparse/gpu/matmul_kernel.cu @@ -0,0 +1,207 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/ddim.h" +#include "paddle/phi/core/enforce.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/core/meta_tensor.h" +#include "paddle/phi/core/sparse_coo_tensor.h" +#include "paddle/phi/core/sparse_csr_tensor.h" +#include "paddle/phi/kernels/copy_kernel.h" +#include "paddle/phi/kernels/empty_kernel.h" +#include "paddle/phi/kernels/funcs/sparse/sparse_blas.h" +#include "paddle/phi/kernels/sparse/matmul_kernel.h" + +namespace phi { +namespace sparse { + +template +void CsrDenseMatmulKernel(const Context& dev_ctx, + const SparseCsrTensor& x, + const DenseTensor& y, + DenseTensor* out) { +#if CUDA_VERSION >= 11000 + std::vector xdim_vec = phi::vectorize(x.dims()); + std::vector ydim_vec = phi::vectorize(y.dims()); + auto x_ndims = xdim_vec.size(); + auto y_ndims = ydim_vec.size(); + PADDLE_ENFORCE_EQ( + x_ndims, + y_ndims, + phi::errors::PreconditionNotMet("The dims size of Input(x) and Input(y) " + "should be equal, But received X's " + "dimensions=%d, Y's dimensions=%d.", + x_ndims, + y_ndims)); + PADDLE_ENFORCE_GE( + x_ndims, + 2, + phi::errors::InvalidArgument("the dims size of Input(x) and " + "Input(y) must be greater than " + "or eaqual to 2.")); + + for (size_t i = 0; i < x_ndims - 2; ++i) { + PADDLE_ENFORCE_EQ(xdim_vec[i], + ydim_vec[i], + phi::errors::InvalidArgument( + "x.dim[%d] and x.dim[%d] must match.", i, i)); + } + + PADDLE_ENFORCE_GE( + xdim_vec[x_ndims - 1], + ydim_vec[y_ndims - 2], + phi::errors::PreconditionNotMet( + "The shape of Input(x) and Input(y) is not suitable for matmul " + "opetation, x_dim[-1] must be eaqual to y_dim[-2].")); + + // InferMeta of DenseTensor 'out' + std::vector out_dim_vec(ydim_vec); + out_dim_vec[y_ndims - 2] = xdim_vec[x_ndims - 2]; + out_dim_vec[y_ndims - 1] = ydim_vec[y_ndims - 1]; + MetaTensor meta_out(out); + meta_out.set_dims(phi::make_ddim(out_dim_vec)); + meta_out.set_dtype(x.non_zero_elements().dtype()); + + dev_ctx.template Alloc(out); + + auto sparse_blas = phi::funcs::sparse::GetSparseBlas(dev_ctx); + sparse_blas.DSDMM( + false, false, static_cast(1), x, y, static_cast(0), out); +#else + PADDLE_THROW( + phi::errors::Unimplemented(" forward of 'sparse.mm' use cusparseSpMM, " + "which is supported from CUDA 11.0")); +#endif +} + +template +void CsrMaskedMatmulKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + const SparseCsrTensor& mask, + SparseCsrTensor* out) { +#if CUDA_VERSION >= 11030 + std::vector xdim_vec = phi::vectorize(x.dims()); + std::vector ydim_vec = phi::vectorize(y.dims()); + std::vector maskdim_vec = phi::vectorize(mask.dims()); + + auto x_ndims = xdim_vec.size(); + auto y_ndims = ydim_vec.size(); + auto mask_ndims = maskdim_vec.size(); + + PADDLE_ENFORCE_EQ( + x_ndims, + y_ndims, + phi::errors::PreconditionNotMet("The dims size of Input(x) and Input(y) " + "should be equal, But received X's " + "dimensions=%d, Y's dimensions=%d.", + x_ndims, + y_ndims)); + PADDLE_ENFORCE_EQ(x_ndims, + mask_ndims, + phi::errors::PreconditionNotMet( + "The dims size of Input(x) and Input(mask) " + "should be equal, But received X's " + "dimensions=%d, mask's dimensions=%d.", + x_ndims, + mask_ndims)); + PADDLE_ENFORCE_GE( + x_ndims, + 2, + phi::errors::InvalidArgument("the dims size of Input(x) and " + "Input(y) must be greater than " + "or eaqual to 2.")); + + for (size_t i = 0; i < x_ndims - 2; ++i) { + PADDLE_ENFORCE_EQ(xdim_vec[i], + ydim_vec[i], + phi::errors::InvalidArgument( + "x.dim[%d] and x.dim[%d] must match.", i, i)); + PADDLE_ENFORCE_EQ(xdim_vec[i], + maskdim_vec[i], + phi::errors::InvalidArgument( + "x.dim[%d] and mask.dim[%d] must match.", i, i)); + } + + PADDLE_ENFORCE_GE( + xdim_vec[x_ndims - 1], + ydim_vec[y_ndims - 2], + phi::errors::PreconditionNotMet( + "The shape of Input(x) and Input(y) is not suitable for matmul " + "opetation, x_dim[-1] must be eaqual to y_dim[-2].")); + + PADDLE_ENFORCE_EQ( + maskdim_vec[mask_ndims - 2], + xdim_vec[x_ndims - 2], + phi::errors::PreconditionNotMet( + "The shape of Input(x) and Input(y) is not suitable for matmul " + "opetation, mask_dim[-2] must be eaqual to x_dim[-2].")); + + PADDLE_ENFORCE_EQ( + maskdim_vec[mask_ndims - 1], + ydim_vec[y_ndims - 1], + phi::errors::PreconditionNotMet( + "The shape of Input(x) and Input(y) is not suitable for matmul " + "opetation, mask_dim[-1] must be eaqual to y_dim[-1].")); + + // InferMeta of SparseCsrTensor 'out' + out->set_dims(mask.dims()); + + phi::Copy(dev_ctx, + mask.non_zero_crows(), + dev_ctx.GetPlace(), + false, + out->mutable_non_zero_crows()); + phi::Copy(dev_ctx, + mask.non_zero_cols(), + dev_ctx.GetPlace(), + false, + out->mutable_non_zero_cols()); + + DenseTensor* values = out->mutable_non_zero_elements(); + values->Resize(mask.non_zero_elements().dims()); + dev_ctx.template Alloc(values); + + auto sparse_blas = phi::funcs::sparse::GetSparseBlas(dev_ctx); + sparse_blas.SDDMM( + false, false, static_cast(1), x, y, static_cast(0), out); +#else + PADDLE_THROW( + phi::errors::Unimplemented(" forward of 'sparse.masked_mm' use " + "cusparseSDDMM, which is supported from " + "CUDA 11.3")); +#endif +} + +} // namespace sparse +} // namespace phi + +PD_REGISTER_KERNEL(csr_dense_matmul, + GPU, + ALL_LAYOUT, + phi::sparse::CsrDenseMatmulKernel, + float, + double) { + kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_CSR); +} + +PD_REGISTER_KERNEL(csr_masked_matmul, + GPU, + ALL_LAYOUT, + phi::sparse::CsrMaskedMatmulKernel, + float, + double) {} diff --git a/paddle/phi/kernels/sparse/matmul_grad_kernel.h b/paddle/phi/kernels/sparse/matmul_grad_kernel.h new file mode 100644 index 0000000000000..787691f3515d6 --- /dev/null +++ b/paddle/phi/kernels/sparse/matmul_grad_kernel.h @@ -0,0 +1,70 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/sparse_coo_tensor.h" +#include "paddle/phi/core/sparse_csr_tensor.h" + +namespace phi { +namespace sparse { + +// TODO(zhouwei25): implement Backward of " COO @ COO -> COO" +template +void CooCooMatmulGradKernel(const Context& dev_ctx, + const SparseCooTensor& x, + const SparseCooTensor& y, + const SparseCooTensor& dout, + SparseCooTensor* dx, + SparseCooTensor* dy); + +// TODO(zhouwei25): implement Backward of " COO @ DENSE -> DENSE" +template +void CooDenseMatmulGradKernel(const Context& dev_ctx, + const SparseCooTensor& x, + const DenseTensor& y, + const DenseTensor& dout, + SparseCooTensor* dx, + DenseTensor* dy); + +// TODO(zhouwei25): implement Backward of " CSR @ CSR -> CSR" +template +void CsrCsrMatmulGradKernel(const Context& dev_ctx, + const SparseCsrTensor& x, + const SparseCsrTensor& y, + const SparseCsrTensor& dout, + SparseCsrTensor* dx, + SparseCsrTensor* dy); + +/* Backward of "CSR @ DENSE -> DENSE" */ +template +void CsrDenseMatmulGradKernel(const Context& dev_ctx, + const SparseCsrTensor& x, + const DenseTensor& y, + const DenseTensor& dout, + SparseCsrTensor* dx, + DenseTensor* dy); + +/* Backward of "DENSE @ DENSE * CSR_MASK -> CSR" */ +template +void CsrMaskedMatmulGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + const SparseCsrTensor& dout, + DenseTensor* dx, + DenseTensor* dy); + +} // namespace sparse +} // namespace phi diff --git a/paddle/phi/kernels/sparse/matmul_kernel.h b/paddle/phi/kernels/sparse/matmul_kernel.h new file mode 100644 index 0000000000000..d9093a020c207 --- /dev/null +++ b/paddle/phi/kernels/sparse/matmul_kernel.h @@ -0,0 +1,61 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/sparse_coo_tensor.h" +#include "paddle/phi/core/sparse_csr_tensor.h" + +namespace phi { +namespace sparse { + +// TODO(zhouwei25): implement " COO @ COO -> COO" +template +void CooCooMatmulKernel(const Context& dev_ctx, + const SparseCooTensor& x, + const SparseCooTensor& y, + SparseCooTensor* out); + +// TODO(zhouwei25): implement " COO @ DENSE -> DENSE" +template +void CooDenseMatmulKernel(const Context& dev_ctx, + const SparseCooTensor& x, + const DenseTensor& y, + DenseTensor* out); + +// TODO(zhouwei25): implement " CSR @ CSR -> CSR" +template +void CsrCsrMatmulKernel(const Context& dev_ctx, + const SparseCsrTensor& x, + const SparseCsrTensor& y, + SparseCsrTensor* out); + +/* CSR @ DENSE -> DENSE */ +template +void CsrDenseMatmulKernel(const Context& dev_ctx, + const SparseCsrTensor& x, + const DenseTensor& y, + DenseTensor* out); + +/* DENSE @ DENSE * CSR_MASK -> CSR */ +template +void CsrMaskedMatmulKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + const SparseCsrTensor& mask, + SparseCsrTensor* out); + +} // namespace sparse +} // namespace phi diff --git a/python/paddle/fluid/tests/unittests/test_sparse_matmul_op.py b/python/paddle/fluid/tests/unittests/test_sparse_matmul_op.py new file mode 100644 index 0000000000000..64087ed950743 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_sparse_matmul_op.py @@ -0,0 +1,120 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +from paddle.fluid.framework import _test_eager_guard + +import numpy as np +import scipy +import scipy.sparse as sp +import unittest +import os +import re + +np.random.seed(2022) + + +def get_cuda_version(): + result = os.popen("nvcc --version").read() + regex = r'release (\S+),' + match = re.search(regex, result) + if match: + num = str(match.group(1)) + integer, decimal = num.split('.') + return int(integer) * 1000 + int(float(decimal) * 10) + else: + return -1 + + +@unittest.skipIf( + not paddle.is_compiled_with_cuda() or get_cuda_version() < 11000, + "paddle is not compiled with CUDA and cuda version need to >= 11.0") +class TestCsrDenseMatmul2D(unittest.TestCase): + # x: csr, y: dense, out: dense + def test_matmul(self): + with _test_eager_guard(): + mask = np.random.rand(10, 12) < 0.2 + np_x = np.random.rand(10, 12) * mask + + np_csr = sp.csr_matrix(np_x) + np_dense = np.random.rand(12, 6) + np_out = np_csr @ np_dense + + np_out_grad = np.ones([10, 6]) + + # dx(csr) = dout(dense) * y'(dense) * mask + np_csr_grad = sp.csr_matrix( + np.matmul(np_out_grad, np_dense.transpose(1, 0)) * mask) + # dy(dense) = x'(csr) * dout(dense) + np_dense_grad = np_csr.transpose() @ np_out_grad + + csr = paddle.to_tensor(np_x, stop_gradient=False).to_sparse_csr() + dense = paddle.to_tensor(np_dense, stop_gradient=False) + out = paddle.incubate.sparse.matmul(csr, dense) + + self.assertTrue(np.allclose(np_out, out.numpy())) + + if get_cuda_version() >= 11030: + out.backward() + self.assertTrue( + np.allclose(np_csr_grad.indptr, + csr.grad.crows().numpy())) + self.assertTrue( + np.allclose(np_csr_grad.indices, + csr.grad.cols().numpy())) + self.assertTrue( + np.allclose(np_csr_grad.data, + csr.grad.values().numpy())) + + self.assertTrue(np.allclose(np_dense_grad, dense.grad.numpy())) + + +@unittest.skipIf( + not paddle.is_compiled_with_cuda() or get_cuda_version() < 11030, + "paddle is not compiled with CUDA and cuda version need to >= 11.3") +class TestCsrMaskedMatmul2D(unittest.TestCase): + # x: dense, y: dense, out: csr + def test_matmul(self): + with _test_eager_guard(): + np_mask = np.random.rand(10, 6) < 0.2 + + np_x = np.random.rand(10, 12) + np_y = np.random.rand(12, 6) + np_out = sp.csr_matrix(np.matmul(np_x, np_y) * np_mask) + + np_out_grad = sp.csr_matrix(np.ones([10, 6]) * np_mask) + # dx(dense) = dout(csr) * y'(dense) + np_x_grad = np_out_grad @ np_y.transpose(1, 0) + # dy(dense) = x'(dense) * dout(csr) -> dy'(dense) = dout'(csr) * x(dense) + np_y_grad = (np_out_grad.transpose() @ np_x).transpose(1, 0) + + x = paddle.to_tensor(np_x, stop_gradient=False) + y = paddle.to_tensor(np_y, stop_gradient=False) + mask = paddle.to_tensor(np.ones([10, 6]) * np_mask).to_sparse_csr() + out = paddle.incubate.sparse.masked_matmul(x, y, mask) + + self.assertTrue(np.allclose(np_out.indptr, out.crows().numpy())) + self.assertTrue(np.allclose(np_out.indices, out.cols().numpy())) + self.assertTrue(np.allclose(np_out.data, out.values().numpy())) + + out.backward() + self.assertTrue(np.allclose(out.is_sparse_csr(), True)) + self.assertTrue(np.allclose(np_x_grad, x.grad.numpy())) + self.assertTrue(np.allclose(np_y_grad, y.grad.numpy())) + + +#TODO(zhouwei25): support unit test of batch 'paddle.sparse.mm/masked_mm' + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/incubate/__init__.py b/python/paddle/incubate/__init__.py index c030cf5bbb9ee..bad6fe7b71375 100644 --- a/python/paddle/incubate/__init__.py +++ b/python/paddle/incubate/__init__.py @@ -28,10 +28,10 @@ from .tensor import segment_max from .tensor import segment_min from .passes import fuse_resnet_unit_pass -import paddle.incubate.autograd -import paddle.incubate.autotune -import paddle.incubate.sparse +from . import autograd #noqa: F401 +from . import autotune #noqa: F401 +from . import sparse #noqa: F401 from . import nn #noqa: F401 from . import asp #noqa: F401 diff --git a/python/paddle/incubate/sparse/__init__.py b/python/paddle/incubate/sparse/__init__.py index c499c017a48e8..5fe86995e1d30 100644 --- a/python/paddle/incubate/sparse/__init__.py +++ b/python/paddle/incubate/sparse/__init__.py @@ -19,6 +19,9 @@ from .unary import sin from .unary import tanh +from .binary import matmul +from .binary import masked_matmul + from . import nn __all__ = [ @@ -27,4 +30,6 @@ 'sqrt', 'sin', 'tanh', + 'matmul', + 'masked_matmul', ] diff --git a/python/paddle/incubate/sparse/binary.py b/python/paddle/incubate/sparse/binary.py new file mode 100644 index 0000000000000..f03cd985201fd --- /dev/null +++ b/python/paddle/incubate/sparse/binary.py @@ -0,0 +1,143 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from paddle.common_ops_import import dygraph_only +from paddle import _C_ops + +__all__ = [] + + +@dygraph_only +def matmul(x, y, name=None): + """ + Warning: + This API is only used from ``CUDA 11.0`` . + + Applies matrix multiplication of two Tensors. + + The supported input/output Tensor layout are as follows: + + Note: + x[SparseCsrTensor] @ y[SparseCsrTensor] -> out[SparseCsrTensor] + x[SparseCsrTensor] @ y[DenseTensor] -> out[DenseTensor] + x[SparseCooTensor] @ y[SparseCooTensor] -> out[SparseCooTensor] + x[SparseCooTensor] @ y[DenseTensor] -> out[DenseTensor] + + It supports backward propagation. + + Dimensions `x` and `y` must be >= 2D. Automatic broadcasting of Tensor is not supported. + the shape of `x` should be `[*, M, K]` , and the shape of `y` should be `[*, K, N]` , where `*` + is zero or more batch dimensions. + + Args: + x (Tensor): The input tensor. It can be SparseCooTensor/SparseCsrTensor. The data type can be float32 or float64. + y (Tensor): The input tensor. It can be SparseCooTensor/SparseCsrTensor/DenseTensor. The data type can be float32 or float64. + name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. + + Returns: + Tensor: Its layout is determined by that of `x` and `y` . + + Examples: + + .. code-block:: python + + import paddle + from paddle.fluid.framework import _test_eager_guard + paddle.seed(100) + + # csr @ dense -> dense + + with _test_eager_guard(): + crows = [0, 2, 3, 5] + cols = [1, 3, 2, 0, 1] + values = [1., 2., 3., 4., 5.] + dense_shape = [3, 4] + csr = paddle.incubate.sparse.sparse_csr_tensor(crows, cols, values, dense_shape) + # Tensor(shape=[3, 4], dtype=paddle.float32, place=Place(gpu:0), stop_gradient=True, + # crows=[0, 2, 3, 5], + # cols=[1, 3, 2, 0, 1], + # values=[1., 2., 3., 4., 5.]) + dense = paddle.randn([4, 3]) + + out = paddle.incubate.sparse.matmul(csr, dense) + # Tensor(shape=[3, 3], dtype=float32, place=Place(gpu:0), stop_gradient=True, + # [[-1.94294846 , -3.33990622 , 0.62359387 ], + # [-4.12815523 , 3.46535444 , -3.27413893 ], + # [-0.15209436 , -19.23207283, -3.35593438 ]]) + + """ + return _C_ops.final_state_sparse_matmul(x, y) + + +@dygraph_only +def masked_matmul(x, y, mask, name=None): + """ + Warning: + This API is only used from ``CUDA 11.3`` . + + Applies matrix multiplication of two Dense Tensors. + + The supported input/output Tensor layout are as follows: + + Note: + x[DenseTensor] @ y[DenseTensor] * mask[SparseCooTensor] -> out[SparseCooTensor] + x[DenseTensor] @ y[DenseTensor] * mask[SparseCsrTensor] -> out[SparseCsrTensor] + + It supports backward propagation. + + Dimensions `x` and `y` must be >= 2D. Automatic broadcasting of Tensor is not supported. + the shape of `x` should be `[*, M, K]` , and the shape of `y` should be `[*, K, N]` , and the shape of `mask` should be `[*, M, N]` , + where `*` is zero or more batch dimensions. + + Args: + x (Tensor): The input tensor. It is DenseTensor. The data type can be float32 or float64. + y (Tensor): The input tensor. It is DenseTensor. The data type can be float32 or float64. + mask (Tensor): The mask tensor, which can be SparseCooTensor/SparseCsrTensor. It specify sparse coordinates. The data type can be float32 or float64. + name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. + + Returns: + Tensor: SparseCoo or SparseCsr, which is determined by that of `mask` . + + Examples: + + .. code-block:: python + + import paddle + from paddle.fluid.framework import _test_eager_guard + paddle.seed(100) + + # dense @ dense * csr_mask -> csr + + with _test_eager_guard(): + crows = [0, 2, 3, 5] + cols = [1, 3, 2, 0, 1] + values = [1., 2., 3., 4., 5.] + dense_shape = [3, 4] + mask = paddle.incubate.sparse.sparse_csr_tensor(crows, cols, values, dense_shape) + # Tensor(shape=[3, 4], dtype=paddle.float32, place=Place(gpu:0), stop_gradient=True, + # crows=[0, 2, 3, 5], + # cols=[1, 3, 2, 0, 1], + # values=[1., 2., 3., 4., 5.]) + + x = paddle.rand([3, 5]) + y = paddle.rand([5, 4]) + + out = paddle.incubate.sparse.masked_matmul(x, y, mask) + # Tensor(shape=[3, 4], dtype=paddle.float32, place=Place(gpu:0), stop_gradient=True, + # crows=[0, 2, 3, 5], + # cols=[1, 3, 2, 0, 1], + # values=[0.98986477, 0.97800624, 1.14591956, 0.68561077, 0.94714981]) + + """ + return _C_ops.final_state_sparse_masked_matmul(x, y, mask) diff --git a/python/paddle/utils/code_gen/sparse_api.yaml b/python/paddle/utils/code_gen/sparse_api.yaml index 5d1dc55f0638d..84c6d2a16af43 100644 --- a/python/paddle/utils/code_gen/sparse_api.yaml +++ b/python/paddle/utils/code_gen/sparse_api.yaml @@ -88,6 +88,34 @@ layout : x backward : values_grad +- api: full_like + args : (Tensor x, Scalar value, DataType dtype=DataType::UNDEFINED) + output : Tensor(out) + kernel : + func : coo_full_like{sparse_coo -> sparse_coo}, + csr_full_like{sparse_csr -> sparse_csr} + layout : x + data_type : dtype + +- api: masked_matmul + args : (Tensor x, Tensor y, Tensor mask) + output : Tensor(out) + kernel : + func : csr_masked_matmul{dense, dense, sparse_csr -> sparse_csr} + layout : x + backward: masked_matmul_grad + +- api: matmul + args : (Tensor x, Tensor y) + output : Tensor(out) + kernel : + func : csr_dense_matmul{sparse_csr, dense -> dense}, + csr_csr_matmul{sparse_csr, sparse_csr -> sparse_csr}, + coo_dense_matmul{sparse_coo, dense -> dense}, + coo_coo_matmul{sparse_coo, sparse_coo -> sparse_coo} + layout : x + backward: matmul_grad + - api: maxpool args : (Tensor x, int[] kernel_sizes, int[] paddings, int[] dilations, int[] strides) output : Tensor(out), Tensor(rulebook) diff --git a/python/paddle/utils/code_gen/sparse_bw_api.yaml b/python/paddle/utils/code_gen/sparse_bw_api.yaml index eb7114cbdd2c9..5d9874dff29ec 100644 --- a/python/paddle/utils/code_gen/sparse_bw_api.yaml +++ b/python/paddle/utils/code_gen/sparse_bw_api.yaml @@ -25,6 +25,20 @@ output : Tensor(x_grad) invoke : to_dense_impl(out_grad) +- backward_api : masked_matmul_grad + forward : masked_matmul(Tensor x, Tensor y, Tensor mask) -> Tensor(out) + args : (Tensor x, Tensor y, Tensor out_grad) + output : Tensor(x_grad), Tensor(y_grad) + kernel : + func : csr_masked_matmul_grad{dense, dense, sparse_csr -> dense, dense} + +- backward_api : matmul_grad + forward : matmul(Tensor x, Tensor y) -> Tensor(out) + args : (Tensor x, Tensor y, Tensor out_grad) + output : Tensor(x_grad), Tensor(y_grad) + kernel : + func : csr_dense_matmul_grad{sparse_csr, dense, dense -> sparse_csr, dense} + - backward_api : relu_grad forward : relu(Tensor x) -> Tensor(out) args : (Tensor out, Tensor out_grad)