From cf2db1fb13214eb178fc7a68980d38d846d0ca36 Mon Sep 17 00:00:00 2001 From: zhouwei25 Date: Fri, 29 Jul 2022 14:02:38 +0000 Subject: [PATCH] [Sparse] optimize sparse attention --- paddle/fluid/platform/dynload/cusparse.h | 1 + paddle/phi/backends/dynload/cusparse.h | 1 + .../funcs/sparse/sparse_blas_impl.cu.h | 13 +- .../sparse/gpu/fused_attention_grad_kernel.cu | 23 ++-- .../sparse/gpu/fused_attention_kernel.cu | 115 +++++------------- .../test_sparse_fused_attention_op.py | 2 +- 6 files changed, 54 insertions(+), 101 deletions(-) diff --git a/paddle/fluid/platform/dynload/cusparse.h b/paddle/fluid/platform/dynload/cusparse.h index 480245fec253d..f026197490d6a 100644 --- a/paddle/fluid/platform/dynload/cusparse.h +++ b/paddle/fluid/platform/dynload/cusparse.h @@ -56,6 +56,7 @@ CUSPARSE_ROUTINE_EACH(PLATFORM_DECLARE_DYNAMIC_LOAD_CUSPARSE_WRAP) #if CUDA_VERSION >= 11030 #define CUSPARSE_ROUTINE_EACH_R2(__macro) \ + __macro(cusparseSpMM_preprocess); \ __macro(cusparseSDDMM_bufferSize); \ __macro(cusparseSDDMM_preprocess); \ __macro(cusparseSDDMM); diff --git a/paddle/phi/backends/dynload/cusparse.h b/paddle/phi/backends/dynload/cusparse.h index 45a466b3801ff..2f4ec151b1ece 100644 --- a/paddle/phi/backends/dynload/cusparse.h +++ b/paddle/phi/backends/dynload/cusparse.h @@ -68,6 +68,7 @@ CUSPARSE_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUSPARSE_WRAP) #if CUDA_VERSION >= 11030 #define CUSPARSE_ROUTINE_EACH_R2(__macro) \ + __macro(cusparseSpMM_preprocess); \ __macro(cusparseSDDMM_bufferSize); \ __macro(cusparseSDDMM_preprocess); \ __macro(cusparseSDDMM); diff --git a/paddle/phi/kernels/funcs/sparse/sparse_blas_impl.cu.h b/paddle/phi/kernels/funcs/sparse/sparse_blas_impl.cu.h index 9f7be26857bdb..0458f0d83ed1a 100644 --- a/paddle/phi/kernels/funcs/sparse/sparse_blas_impl.cu.h +++ b/paddle/phi/kernels/funcs/sparse/sparse_blas_impl.cu.h @@ -48,6 +48,15 @@ inline cusparseOperation_t GetTransposeOperation(const bool trans) { } } +inline cusparseSpMMAlg_t GetSpMMAlgorithm(const SparseCsrTensor& x) { + // TODO(zhouwei): will change to 'CUSPARSE_SPMM_CSR_ALG2' when support batch + return CUSPARSE_SPMM_CSR_ALG2; +} + +inline cusparseSpMMAlg_t GetSpMMAlgorithm(const SparseCooTensor& x) { + return CUSPARSE_SPMM_ALG_DEFAULT; +} + /************* SPARSE MATRIX DESCRIPTOR (COO/CSR) ************/ template @@ -324,7 +333,7 @@ void SparseBlas::SPMM(bool transa, &beta, out_descriptor.descriptor(), gpu_type, - CUSPARSE_SPMM_ALG_DEFAULT, + GetSpMMAlgorithm(mat_a), &buffer_size); }); @@ -341,7 +350,7 @@ void SparseBlas::SPMM(bool transa, &beta, out_descriptor.descriptor(), gpu_type, - CUSPARSE_SPMM_ALG_DEFAULT, + GetSpMMAlgorithm(mat_a), tmp_buffer_ptr); }); } diff --git a/paddle/phi/kernels/sparse/gpu/fused_attention_grad_kernel.cu b/paddle/phi/kernels/sparse/gpu/fused_attention_grad_kernel.cu index 70203836d4412..5be45605983f6 100644 --- a/paddle/phi/kernels/sparse/gpu/fused_attention_grad_kernel.cu +++ b/paddle/phi/kernels/sparse/gpu/fused_attention_grad_kernel.cu @@ -43,21 +43,14 @@ __global__ void AttnSoftmaxGpuGradKernel(const int64_t* out_crows, int row_nnz = static_cast(out_crows[crow_idx + 1] - out_crows[crow_idx]); if (row_nnz == 0) return; - int kIteration = (row_nnz + WARP_SIZE - 1) / WARP_SIZE; - T mul_result = 0; - for (int i = 0; i < kIteration; ++i) { - int idx = threadIdx.x + i * WARP_SIZE; - if (idx >= row_nnz) break; - - mul_result += out_values[row_first + idx] * dout_values[row_first + idx]; + T mul = 0; + for (int idx = threadIdx.x; idx < row_nnz; idx += blockDim.x) { + mul += out_values[row_first + idx] * dout_values[row_first + idx]; } - T sum = phi::funcs::warpReduceSum(mul_result, 0xFFFFFFFF); - - for (int i = 0; i < kIteration; ++i) { - int idx = threadIdx.x + i * WARP_SIZE; - if (idx >= row_nnz) break; + T mul_sum = phi::funcs::warpReduceSum(mul, 0xFFFFFFFF); - dx_values[row_first + idx] = (dout_values[row_first + idx] - sum) * + for (int idx = threadIdx.x; idx < row_nnz; idx += blockDim.x) { + dx_values[row_first + idx] = (dout_values[row_first + idx] - mul_sum) * out_values[row_first + idx] / scale; } } @@ -96,8 +89,8 @@ void FusedAttentionCsrGradKernel(const Context& dev_ctx, int N = q_dim[q_rank - 1]; int batch_nnz = softmax.nnz() / batch_num; - dim3 grid((total_row_num + 3) / 4); - dim3 block(WARP_SIZE, 4); + dim3 grid((total_row_num + 7) / 8); + dim3 block(WARP_SIZE, 8); AttnSoftmaxGpuGradKernel<<>>( softmax.non_zero_crows().data(), diff --git a/paddle/phi/kernels/sparse/gpu/fused_attention_kernel.cu b/paddle/phi/kernels/sparse/gpu/fused_attention_kernel.cu index b1e30f3b654a4..8761319ee8d63 100644 --- a/paddle/phi/kernels/sparse/gpu/fused_attention_kernel.cu +++ b/paddle/phi/kernels/sparse/gpu/fused_attention_kernel.cu @@ -26,30 +26,7 @@ limitations under the License. */ namespace phi { namespace sparse { -#define PRIVATE_CASE_VISIT_ATTN_SOFTMAX(NAME, size, HINT, ...) \ - case size: { \ - constexpr int HINT = size; \ - __VA_ARGS__(); \ - break; \ - } - -#define VISIT_ATTN_SFOTMAX(SIZE, NAME, ...) \ - [&] { \ - const auto& __size__ = SIZE; \ - switch (__size__) { \ - PRIVATE_CASE_VISIT_ATTN_SOFTMAX(NAME, 1, KBufferSize, __VA_ARGS__) \ - PRIVATE_CASE_VISIT_ATTN_SOFTMAX(NAME, 2, KBufferSize, __VA_ARGS__) \ - PRIVATE_CASE_VISIT_ATTN_SOFTMAX(NAME, 3, KBufferSize, __VA_ARGS__) \ - PRIVATE_CASE_VISIT_ATTN_SOFTMAX(NAME, 4, KBufferSize, __VA_ARGS__) \ - PRIVATE_CASE_VISIT_ATTN_SOFTMAX(NAME, 8, KBufferSize, __VA_ARGS__) \ - PRIVATE_CASE_VISIT_ATTN_SOFTMAX(NAME, 12, KBufferSize, __VA_ARGS__) \ - PRIVATE_CASE_VISIT_ATTN_SOFTMAX(NAME, 16, KBufferSize, __VA_ARGS__) \ - default: \ - PD_THROW("function " #NAME " is not implemented for columns>512 "); \ - } \ - }() - -template +template __global__ void AttnSoftmaxGpuKernel(const int64_t* x_crows, const int64_t* x_cols, const T* x_values, @@ -58,7 +35,6 @@ __global__ void AttnSoftmaxGpuKernel(const int64_t* x_crows, T* out_values, int M, int total_row_num, - float scale, int num_heads, int batch_nnz) { // out = exp(x-x_max) / sum(exp(x-x_max)) @@ -72,17 +48,10 @@ __global__ void AttnSoftmaxGpuKernel(const int64_t* x_crows, int row_nnz = static_cast(x_crows[crow_idx + 1] - x_crows[crow_idx]); if (row_nnz == 0) return; - T buffer[BufferSize] = {0}; - int kIteration = (row_nnz + WARP_SIZE - 1) / WARP_SIZE; - T max_val = -std::numeric_limits::infinity(); - for (int i = 0; i < kIteration; ++i) { + for (int idx = threadIdx.x; idx < row_nnz; idx += blockDim.x) { bool mask = false; - int idx = threadIdx.x + i * WARP_SIZE; - if (idx >= row_nnz) break; - int col_idx = static_cast(x_cols[row_first + idx]); - if (kp_mask != nullptr && kp_mask[(cur_batch / num_heads) * M + col_idx] == 0) { mask = true; @@ -92,37 +61,30 @@ __global__ void AttnSoftmaxGpuKernel(const int64_t* x_crows, } if (!mask) { - buffer[i] = x_values[row_first + idx] / scale; - if (buffer[i] > max_val) { - max_val = buffer[i]; + T val = x_values[row_first + idx]; + if (val > max_val) { + max_val = val; } + out_values[row_first + idx] = val; + } else { + // Note corner case: when all elements of the row are masked, result + // may be wrong because of exp('-inf' - '-inf'), just ignore now. + out_values[row_first + idx] = -std::numeric_limits::infinity(); } } T row_max_val = phi::funcs::warpReduceMax(max_val, 0xFFFFFFFF); - auto functor = phi::funcs::CudaExpFunctor(); T exp_sum = 0; - for (int i = 0; i < kIteration; ++i) { - int idx = threadIdx.x + i * WARP_SIZE; - if (idx >= row_nnz) break; - - if (buffer[i]) { - T exp = functor(buffer[i] - row_max_val); - exp_sum += exp; - buffer[i] = exp; - } + for (int idx = threadIdx.x; idx < row_nnz; idx += blockDim.x) { + auto functor = phi::funcs::CudaExpFunctor(); + T exp = functor(out_values[row_first + idx] - row_max_val); + exp_sum += exp; + out_values[row_first + idx] = exp; } T row_exp_sum = phi::funcs::warpReduceSum(exp_sum, 0xFFFFFFFF); - for (int i = 0; i < kIteration; ++i) { - int idx = threadIdx.x + i * WARP_SIZE; - if (idx >= row_nnz) break; - - if (buffer[i]) { - out_values[row_first + idx] = buffer[i] / row_exp_sum; - } else { - out_values[row_first + idx] = static_cast(0); - } + for (int idx = threadIdx.x; idx < row_nnz; idx += blockDim.x) { + out_values[row_first + idx] = out_values[row_first + idx] / row_exp_sum; } } @@ -219,49 +181,36 @@ void FusedAttentionCsrKernel( "shape of 'attn_mask' must be [seq_len, seq_len]")); } - /* Step1: SDD Matmul, reuse */ + /* Step1: SDD Matmul, reuse matmul */ SparseCsrTensor sdd_result; EmptyLikeCsrKernel(dev_ctx, sparse_mask, &sdd_result); auto sparse_blas = phi::funcs::sparse::GetSparseBlas(dev_ctx); sparse_blas.SDDMM(false, true, - static_cast(1), + static_cast(1 / std::sqrt(N)), query, key, static_cast(0), &sdd_result); - /* Step2: Softmax with kp_mask/attn_mask, manualy not reuse */ EmptyLikeCsrKernel(dev_ctx, sdd_result, softmax); - int buffer_size; - if (M < 128) { - buffer_size = (M + 32 - 1) / 32; - } else { - buffer_size = ((M + 128 - 1) / 128) * 4; - } - - dim3 grid((total_row_num + 3) / 4); - dim3 block(WARP_SIZE, 4); + dim3 grid((total_row_num + 7) / 8); + dim3 block(WARP_SIZE, 8); int batch_nnz = sdd_result.nnz() / batch_num; + AttnSoftmaxGpuKernel<<>>( + sdd_result.non_zero_crows().data(), + sdd_result.non_zero_cols().data(), + sdd_result.non_zero_elements().data(), + kp_mask_ptr ? kp_mask_ptr->data() : nullptr, + attn_mask_ptr ? attn_mask_ptr->data() : nullptr, + softmax->mutable_non_zero_elements()->data(), + M, + total_row_num, + q_dim[1], + batch_nnz); - VISIT_ATTN_SFOTMAX(buffer_size, "AttnSoftmaxGpuKernel", [&] { - AttnSoftmaxGpuKernel<<>>( - sdd_result.non_zero_crows().data(), - sdd_result.non_zero_cols().data(), - sdd_result.non_zero_elements().data(), - kp_mask_ptr ? kp_mask_ptr->data() : nullptr, - attn_mask_ptr ? attn_mask_ptr->data() : nullptr, - softmax->mutable_non_zero_elements()->data(), - M, - total_row_num, - std::sqrt(N), - q_dim[1], - batch_nnz); - }); - - /* Step3: DSD Matmul, reuse */ softmax->set_dims(phi::make_ddim({q_dim[0], q_dim[1], q_dim[2], q_dim[2]})); MatmulCsrDenseKernel(dev_ctx, *softmax, value, out); #else diff --git a/python/paddle/fluid/tests/unittests/test_sparse_fused_attention_op.py b/python/paddle/fluid/tests/unittests/test_sparse_fused_attention_op.py index 0383247886ff2..58a3c1ad20113 100644 --- a/python/paddle/fluid/tests/unittests/test_sparse_fused_attention_op.py +++ b/python/paddle/fluid/tests/unittests/test_sparse_fused_attention_op.py @@ -37,7 +37,7 @@ def get_cuda_version(): @unittest.skipIf( not core.is_compiled_with_cuda() or get_cuda_version() < 11070, - "core is not compiled with CUDA and cuda version need larger than or equal to 11.3" + "core is not compiled with CUDA and cuda version need larger than or equal to 11.7" ) class TestSparseAttentionAPI1(unittest.TestCase):