Skip to content

Commit

Permalink
[Sparse] support batch compute of SparseTensor matmul/masked_matmul/softmax (#43703)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhwesky2010 committed Jun 24, 2022
1 parent fa9586a commit eec4e03
Show file tree
Hide file tree
Showing 16 changed files with 457 additions and 232 deletions.
43 changes: 25 additions & 18 deletions paddle/fluid/platform/dynload/cusparse.h
Expand Up @@ -31,24 +31,22 @@ namespace dynload {
#if defined(PADDLE_WITH_CUDA)
// APIs available from CUDA 11.0 onward (guarded by CUDA_VERSION >= 11000).
// CUSPARSE_ROUTINE_EACH is an X-macro list: it applies `__macro` to each
// cuSPARSE routine name, and is expanded right after its definition with
// PLATFORM_DECLARE_DYNAMIC_LOAD_CUSPARSE_WRAP.
// NOTE(review): CUSPARSE_ROUTINE_EACH is defined twice back to back here; this
// looks like the before/after sides of a diff hunk (the second list omits the
// cusparseDnMatSetStridedBatch / cusparseCsrSetStridedBatch entries) --
// confirm which definition is the live one.
#if CUDA_VERSION >= 11000
#define CUSPARSE_ROUTINE_EACH(__macro) \
__macro(cusparseCreate); \
__macro(cusparseSetStream); \
__macro(cusparseCreateMatDescr); \
__macro(cusparseDestroy); \
__macro(cusparseSnnz); \
__macro(cusparseDnnz); \
__macro(cusparseSetMatType); \
__macro(cusparseSetMatIndexBase); \
__macro(cusparseCreateCsr); \
__macro(cusparseCreateCoo); \
__macro(cusparseCreateDnMat); \
__macro(cusparseSpMM_bufferSize); \
__macro(cusparseSpMM); \
__macro(cusparseDestroySpMat); \
__macro(cusparseDestroyDnMat); \
__macro(cusparseDnMatSetStridedBatch); \
__macro(cusparseCsrSetStridedBatch);
#define CUSPARSE_ROUTINE_EACH(__macro) \
__macro(cusparseCreate); \
__macro(cusparseSetStream); \
__macro(cusparseCreateMatDescr); \
__macro(cusparseDestroy); \
__macro(cusparseSnnz); \
__macro(cusparseDnnz); \
__macro(cusparseSetMatType); \
__macro(cusparseSetMatIndexBase); \
__macro(cusparseCreateCsr); \
__macro(cusparseCreateCoo); \
__macro(cusparseCreateDnMat); \
__macro(cusparseSpMM_bufferSize); \
__macro(cusparseSpMM); \
__macro(cusparseDestroySpMat); \
__macro(cusparseDestroyDnMat);

CUSPARSE_ROUTINE_EACH(PLATFORM_DECLARE_DYNAMIC_LOAD_CUSPARSE_WRAP)
#endif
Expand All @@ -62,8 +60,17 @@ CUSPARSE_ROUTINE_EACH(PLATFORM_DECLARE_DYNAMIC_LOAD_CUSPARSE_WRAP)
CUSPARSE_ROUTINE_EACH_R2(PLATFORM_DECLARE_DYNAMIC_LOAD_CUSPARSE_WRAP)
#endif

// Third routine list: the three cusparse*SetStridedBatch routines (DnMat, Coo,
// Csr). They are only listed and wrapped when CUDA_VERSION >= 11070; the list
// is expanded immediately with PLATFORM_DECLARE_DYNAMIC_LOAD_CUSPARSE_WRAP.
// NOTE(review): code that calls these wrappers presumably needs the same
// CUDA_VERSION guard -- confirm at the call sites.
#if CUDA_VERSION >= 11070
#define CUSPARSE_ROUTINE_EACH_R3(__macro) \
__macro(cusparseDnMatSetStridedBatch); \
__macro(cusparseCooSetStridedBatch); \
__macro(cusparseCsrSetStridedBatch);

CUSPARSE_ROUTINE_EACH_R3(PLATFORM_DECLARE_DYNAMIC_LOAD_CUSPARSE_WRAP)
#endif

#endif // PADDLE_WITH_CUDA

#undef PLATFORM_DECLARE_DYNAMIC_LOAD_CUSPARSE_WRAP
} // namespace dynload
} // namespace platform
Expand Down
43 changes: 25 additions & 18 deletions paddle/phi/backends/dynload/cusparse.h
Expand Up @@ -43,24 +43,22 @@ extern void *cusparse_dso_handle;
#if defined(PADDLE_WITH_CUDA)
// APIs available from CUDA 11.0 onward (guarded by CUDA_VERSION >= 11000).
// CUSPARSE_ROUTINE_EACH is an X-macro list: it applies `__macro` to each
// cuSPARSE routine name, and is expanded right after its definition with
// DECLARE_DYNAMIC_LOAD_CUSPARSE_WRAP.
// NOTE(review): CUSPARSE_ROUTINE_EACH is defined twice back to back here; this
// looks like the before/after sides of a diff hunk (the second list omits the
// cusparseDnMatSetStridedBatch / cusparseCsrSetStridedBatch entries) --
// confirm which definition is the live one.
#if CUDA_VERSION >= 11000
#define CUSPARSE_ROUTINE_EACH(__macro) \
__macro(cusparseCreate); \
__macro(cusparseSetStream); \
__macro(cusparseCreateMatDescr); \
__macro(cusparseDestroy); \
__macro(cusparseSnnz); \
__macro(cusparseDnnz); \
__macro(cusparseSetMatType); \
__macro(cusparseSetMatIndexBase); \
__macro(cusparseCreateCsr); \
__macro(cusparseCreateCoo); \
__macro(cusparseCreateDnMat); \
__macro(cusparseSpMM_bufferSize); \
__macro(cusparseSpMM); \
__macro(cusparseDestroySpMat); \
__macro(cusparseDestroyDnMat); \
__macro(cusparseDnMatSetStridedBatch); \
__macro(cusparseCsrSetStridedBatch);
#define CUSPARSE_ROUTINE_EACH(__macro) \
__macro(cusparseCreate); \
__macro(cusparseSetStream); \
__macro(cusparseCreateMatDescr); \
__macro(cusparseDestroy); \
__macro(cusparseSnnz); \
__macro(cusparseDnnz); \
__macro(cusparseSetMatType); \
__macro(cusparseSetMatIndexBase); \
__macro(cusparseCreateCsr); \
__macro(cusparseCreateCoo); \
__macro(cusparseCreateDnMat); \
__macro(cusparseSpMM_bufferSize); \
__macro(cusparseSpMM); \
__macro(cusparseDestroySpMat); \
__macro(cusparseDestroyDnMat);

CUSPARSE_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUSPARSE_WRAP)
#endif
Expand All @@ -74,8 +72,17 @@ CUSPARSE_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUSPARSE_WRAP)
CUSPARSE_ROUTINE_EACH_R2(DECLARE_DYNAMIC_LOAD_CUSPARSE_WRAP)
#endif

// Third routine list: the three cusparse*SetStridedBatch routines (DnMat, Coo,
// Csr). They are only listed and wrapped when CUDA_VERSION >= 11070; the list
// is expanded immediately with DECLARE_DYNAMIC_LOAD_CUSPARSE_WRAP.
// NOTE(review): code that calls these wrappers presumably needs the same
// CUDA_VERSION guard -- confirm at the call sites.
#if CUDA_VERSION >= 11070
#define CUSPARSE_ROUTINE_EACH_R3(__macro) \
__macro(cusparseDnMatSetStridedBatch); \
__macro(cusparseCooSetStridedBatch); \
__macro(cusparseCsrSetStridedBatch);

CUSPARSE_ROUTINE_EACH_R3(DECLARE_DYNAMIC_LOAD_CUSPARSE_WRAP)
#endif

#endif // PADDLE_WITH_CUDA

#undef DECLARE_DYNAMIC_LOAD_CUSPARSE_WRAP
} // namespace dynload
} // namespace phi
34 changes: 12 additions & 22 deletions paddle/phi/kernels/funcs/sparse/sparse_blas.h
Expand Up @@ -28,33 +28,23 @@ class SparseBlas {
public:
explicit SparseBlas(const DeviceContext& dev_ctx) : dev_ctx_(dev_ctx) {}

// TODO(zhouwei25): implement "COO @ DENSE -> DENSE" of DSDMM
template <typename T>
void DSDMM(bool transa,
bool transb,
T alpha,
const phi::SparseCooTensor& mat_a,
const phi::DenseTensor& mat_b,
T beta,
phi::DenseTensor* mat_c) const;

template <typename T>
void DSDMM(bool transa,
bool transb,
T alpha,
const phi::SparseCsrTensor& mat_a,
const phi::DenseTensor& mat_b,
T beta,
phi::DenseTensor* mat_c) const;
// Declaration only: SPMM takes an input `mat_a` whose type is the template
// parameter TensorType, a phi::DenseTensor `mat_b`, and writes through the
// phi::DenseTensor pointer `mat_out`. T is the type of the scalars
// `alpha` and `beta`.
// NOTE(review): TensorType is presumably phi::SparseCooTensor or
// phi::SparseCsrTensor, and the contract presumably follows the BLAS-style
// form mat_out = alpha * op(mat_a) * op(mat_b) + beta * mat_out, with
// transa/transb selecting transposition -- confirm against the definition,
// which is not visible here.
template <typename T, typename TensorType>
void SPMM(bool transa,
bool transb,
T alpha,
const TensorType& mat_a,
const phi::DenseTensor& mat_b,
T beta,
phi::DenseTensor* mat_out) const;

template <typename T>
template <typename T, typename TensorType>
void SDDMM(bool transa,
bool transb,
T alpha,
const phi::DenseTensor& mat_a,
const phi::DenseTensor& mat_b,
T beta,
phi::SparseCsrTensor* mat_c) const;
TensorType* mat_out) const;

private:
const DeviceContext& dev_ctx_;
Expand All @@ -66,8 +56,8 @@ class SparseBlasT : private SparseBlas<DeviceContext> {
using SparseBlas<DeviceContext>::SparseBlas;

template <typename... ARGS>
void DSDMM(ARGS... args) const {
Base()->template DSDMM<T>(args...);
void SPMM(ARGS... args) const {
Base()->template SPMM<T>(args...);
}

template <typename... ARGS>
Expand Down

0 comments on commit eec4e03

Please sign in to comment.