Skip to content

Commit

Permalink
fix CI
Browse files Browse the repository at this point in the history
  • Loading branch information
zhwesky2010 committed Jun 13, 2022
1 parent abe4db0 commit f61234d
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 14 deletions.
12 changes: 1 addition & 11 deletions paddle/phi/kernels/sparse/CMakeLists.txt
Expand Up @@ -10,14 +10,4 @@ set(SPARSE_KERNEL_DEPS
math_function
custom_kernel
copy_kernel)

set(MANUAL_BUILD_KERNELS sparse_mm_kernel sparse_mm_grad_kernel)

register_kernels(EXCLUDES ${MANUAL_BUILD_KERNELS} DEPS ${SPARSE_KERNEL_DEPS}
SUB_DIR "sparse")

message("===SPARSE======CUDA_VERSION:==${CUDA_VERSION}============")
if((WITH_GPU) AND (CUDA_VERSION GREATER_EQUAL 11.0))
kernel_library(sparse_mm_kernel DEPS ${SPARSE_KERNEL_DEPS})
kernel_library(sparse_mm_grad_kernel DEPS ${SPARSE_KERNEL_DEPS})
endif()
register_kernels(DEPS ${SPARSE_KERNEL_DEPS} SUB_DIR "sparse")
2 changes: 2 additions & 0 deletions paddle/phi/kernels/sparse/gpu/sparse_mm_grad_kernel.cu
Expand Up @@ -83,6 +83,7 @@ void CsrMaskedMatmulGradKernel(const Context& dev_ctx,
const SparseCsrTensor& dout,
DenseTensor* dx,
DenseTensor* dy) {
#if CUDA_VERSION >= 11000
auto sparse_blas = phi::funcs::sparse::GetSparseBlas<Context, T>(dev_ctx);

// dx{Dense} = dout{SparseCsr} * y'{Dense}
Expand Down Expand Up @@ -123,6 +124,7 @@ void CsrMaskedMatmulGradKernel(const Context& dev_ctx,
std::swap(axis[y_ndim - 1], axis[y_ndim - 2]);
TransposeKernel<T, Context>(dev_ctx, trans_dy, axis, dy);
}
#endif
}

} // namespace sparse
Expand Down
13 changes: 10 additions & 3 deletions paddle/phi/kernels/sparse/gpu/sparse_mm_kernel.cu
Expand Up @@ -34,6 +34,7 @@ void CsrDenseMatmulKernel(const Context& dev_ctx,
const SparseCsrTensor& x,
const DenseTensor& y,
DenseTensor* out) {
#if CUDA_VERSION >= 11000
std::vector<int64_t> xdim_vec = phi::vectorize(x.dims());
std::vector<int64_t> ydim_vec = phi::vectorize(y.dims());
auto x_ndims = xdim_vec.size();
Expand Down Expand Up @@ -80,6 +81,11 @@ void CsrDenseMatmulKernel(const Context& dev_ctx,
auto sparse_blas = phi::funcs::sparse::GetSparseBlas<Context, T>(dev_ctx);
sparse_blas.DSDMM(
false, false, static_cast<T>(1), x, y, static_cast<T>(0), out);
#else
PADDLE_THROW(
phi::errors::Unimplemented(" forward of 'sparse.mm' use cusparseSpMM, "
"which is supported from CUDA 11.0"));
#endif
}

template <typename T, typename Context>
Expand Down Expand Up @@ -174,9 +180,10 @@ void CsrMaskedMatmulKernel(const Context& dev_ctx,
sparse_blas.SDDMM(
false, false, static_cast<T>(1), x, y, static_cast<T>(0), out);
#else
PADDLE_THROW(phi::errors::Unimplemented(
" forward of 'sparse.masked_mm' use cusparseSDDMM, Only support it from "
"CUDA 11.3"));
PADDLE_THROW(
phi::errors::Unimplemented(" forward of 'sparse.masked_mm' use "
"cusparseSDDMM, which is supported from "
"CUDA 11.3"));
#endif
}

Expand Down

0 comments on commit f61234d

Please sign in to comment.