From 849deff98b7acf707241825c554549b89e42fe64 Mon Sep 17 00:00:00 2001
From: zhouwei25
Date: Mon, 4 Jul 2022 13:35:15 +0000
Subject: [PATCH] fix comment

---
 .../funcs/sparse/sparse_blas_impl.cu.h        |  2 +-
 .../sparse/cpu/fused_attention_grad_kernel.cc |  2 +-
 .../sparse/cpu/fused_attention_kernel.cc      |  3 +-
 .../sparse/gpu/fused_attention_grad_kernel.cu | 23 +++++----
 .../sparse/gpu/fused_attention_kernel.cu      | 47 ++++++++++++++-----
 .../test_sparse_fused_attention_op.py         |  4 +-
 .../sparse/nn/functional/transformer.py       |  9 ++--
 7 files changed, 59 insertions(+), 31 deletions(-)

diff --git a/paddle/phi/kernels/funcs/sparse/sparse_blas_impl.cu.h b/paddle/phi/kernels/funcs/sparse/sparse_blas_impl.cu.h
index d817290882916..3d92674c92d6e 100644
--- a/paddle/phi/kernels/funcs/sparse/sparse_blas_impl.cu.h
+++ b/paddle/phi/kernels/funcs/sparse/sparse_blas_impl.cu.h
@@ -232,7 +232,7 @@ class CuSparseDnMatDescriptor {
     PADDLE_ENFORCE_EQ(x.numel(), batch_size * M * N);
 
     if (batch_size > 1) {
-#if CUDA_VERSION >= 11030
+#if CUDA_VERSION >= 11070
       dev_ctx_.CusparseCall([&](cusparseHandle_t handle) {
         phi::dynload::cusparseDnMatSetStridedBatch(
             descriptor_, batch_size, M * N);
       });
diff --git a/paddle/phi/kernels/sparse/cpu/fused_attention_grad_kernel.cc b/paddle/phi/kernels/sparse/cpu/fused_attention_grad_kernel.cc
index d83951594f364..416b715a9a6a2 100644
--- a/paddle/phi/kernels/sparse/cpu/fused_attention_grad_kernel.cc
+++ b/paddle/phi/kernels/sparse/cpu/fused_attention_grad_kernel.cc
@@ -31,7 +31,7 @@ void FusedAttentionCsrGradKernel(const Context& dev_ctx,
                                  DenseTensor* dkey,
                                  DenseTensor* dvalue) {
   PD_THROW(
-      "Only support 'fused_attention' CPU backward kernel of SparseTensor now");
+      "The CPU backward kernel of 'sparse.nn.functional.fused_attention' is not supported now");
 }
 
 }  // namespace sparse
diff --git a/paddle/phi/kernels/sparse/cpu/fused_attention_kernel.cc b/paddle/phi/kernels/sparse/cpu/fused_attention_kernel.cc
index bff25da8a5a37..6c652c6a8c4d6 100644
--- a/paddle/phi/kernels/sparse/cpu/fused_attention_kernel.cc
+++ b/paddle/phi/kernels/sparse/cpu/fused_attention_kernel.cc
@@ -30,7 +30,8 @@ void FusedAttentionCsrKernel(const Context& dev_ctx,
                              const DenseTensor& attn_mask,
                              DenseTensor* out,
                              SparseCsrTensor* softmax) {
-  PD_THROW("Only support 'fused_attention' CPU kernel of SparseTensor now");
+  PD_THROW(
+      "The CPU kernel of 'sparse.nn.functional.fused_attention' is not supported now");
 }
 
 }  // namespace sparse
diff --git a/paddle/phi/kernels/sparse/gpu/fused_attention_grad_kernel.cu b/paddle/phi/kernels/sparse/gpu/fused_attention_grad_kernel.cu
index b22965f663eb5..4d31ad96cdd3b 100644
--- a/paddle/phi/kernels/sparse/gpu/fused_attention_grad_kernel.cu
+++ b/paddle/phi/kernels/sparse/gpu/fused_attention_grad_kernel.cu
@@ -31,25 +31,22 @@ __global__ void AttnSoftmaxGpuGradKernel(const int64_t* out_crows,
                                          T* dx_values,
                                          int M,
                                          int total_row_num,
-                                         float scale) {
+                                         float scale,
+                                         int batch_nnz) {
   // dx = (dout - sum(dout * out)) * out
   int row = blockIdx.x * blockDim.y + threadIdx.y;
-  int non_zero_idx = threadIdx.x;
   if (row >= total_row_num) return;
+
   int cur_batch = row / M;
   int crow_idx = cur_batch * (M + 1) + (row % M);
-  int cur_batch_offset = 0;
-  for (int i = 1; i < cur_batch + 1; ++i) {
-    cur_batch_offset += static_cast<int>(out_crows[i * (M + 1) - 1]);
-  }
-  int row_first = cur_batch_offset + static_cast<int>(out_crows[crow_idx]);
+  int row_first = cur_batch * batch_nnz + static_cast<int>(out_crows[crow_idx]);
   int row_nnz = static_cast<int>(out_crows[crow_idx + 1] - out_crows[crow_idx]);
   if (row_nnz == 0) return;
 
   int kIteration = (row_nnz + WARP_SIZE - 1) / WARP_SIZE;
   T mul_result = 0;
   for (int i = 0; i < kIteration; ++i) {
-    int idx = non_zero_idx + i * WARP_SIZE;
+    int idx = threadIdx.x + i * WARP_SIZE;
     if (idx >= row_nnz) break;
 
     mul_result += out_values[row_first + idx] * dout_values[row_first + idx];
@@ -57,7 +54,7 @@ __global__ void AttnSoftmaxGpuGradKernel(const int64_t* out_crows,
   T sum = phi::funcs::warpReduceSum<T>(mul_result, 0xFFFFFFFF);
 
   for (int i = 0; i < kIteration; ++i) {
-    int idx = non_zero_idx + i * WARP_SIZE;
+    int idx = threadIdx.x + i * WARP_SIZE;
     if (idx >= row_nnz) break;
 
     dx_values[row_first + idx] = (dout_values[row_first + idx] - sum) *
@@ -88,11 +85,16 @@ void FusedAttentionCsrGradKernel(const Context& dev_ctx,
   auto q_rank = q_dim.size();
 
   int total_row_num = 1;
+  int batch_num = 1;
   for (int i = 0; i < q_rank - 1; ++i) {
     total_row_num *= q_dim[i];
+    if (i < q_rank - 2) {
+      batch_num *= q_dim[i];
+    }
   }
   int M = q_dim[q_rank - 2];
   int N = q_dim[q_rank - 1];
+  int batch_nnz = softmax.nnz() / batch_num;
 
   dim3 grid((total_row_num + 3) / 4);
   dim3 block(WARP_SIZE, 4);
@@ -104,7 +106,8 @@ void FusedAttentionCsrGradKernel(const Context& dev_ctx,
       d_sdd_result.mutable_non_zero_elements()->data<T>(),
       M,
       total_row_num,
-      std::sqrt(N));
+      std::sqrt(N),
+      batch_nnz);
 
   /* Step3: Forward: query{Dense} * key'{Dense} -> sdd_result{SparseCsr} */
   auto sparse_blas = phi::funcs::sparse::GetSparseBlas<Context, T>(dev_ctx);
diff --git a/paddle/phi/kernels/sparse/gpu/fused_attention_kernel.cu b/paddle/phi/kernels/sparse/gpu/fused_attention_kernel.cu
index 8af0505fc3117..9a7e55d2d6210 100644
--- a/paddle/phi/kernels/sparse/gpu/fused_attention_kernel.cu
+++ b/paddle/phi/kernels/sparse/gpu/fused_attention_kernel.cu
@@ -59,19 +59,16 @@ __global__ void AttnSoftmaxGpuKernel(const int64_t* x_crows,
                                      int M,
                                      int total_row_num,
                                      float scale,
-                                     int num_heads) {
+                                     int num_heads,
+                                     int batch_nnz) {
   // out = exp(x-x_max) / sum(exp(x-x_max))
   int row = blockIdx.x * blockDim.y + threadIdx.y;
-  int non_zero_idx = threadIdx.x;
   if (row >= total_row_num) return;
+
   int cur_batch = row / M;
   int cur_row = row % M;
   int crow_idx = cur_batch * (M + 1) + cur_row;
-  int cur_batch_offset = 0;
-  for (int i = 1; i < cur_batch + 1; ++i) {
-    cur_batch_offset += static_cast<int>(x_crows[i * (M + 1) - 1]);
-  }
-  int row_first = cur_batch_offset + static_cast<int>(x_crows[crow_idx]);
+  int row_first = cur_batch * batch_nnz + static_cast<int>(x_crows[crow_idx]);
   int row_nnz = static_cast<int>(x_crows[crow_idx + 1] - x_crows[crow_idx]);
   if (row_nnz == 0) return;
 
@@ -81,7 +78,7 @@ __global__ void AttnSoftmaxGpuKernel(const int64_t* x_crows,
   T max_val = -std::numeric_limits<T>::infinity();
   for (int i = 0; i < kIteration; ++i) {
     bool mask = false;
-    int idx = non_zero_idx + i * WARP_SIZE;
+    int idx = threadIdx.x + i * WARP_SIZE;
     if (idx >= row_nnz) break;
 
     int col_idx = static_cast<int>(x_cols[row_first + idx]);
@@ -106,7 +103,7 @@ __global__ void AttnSoftmaxGpuKernel(const int64_t* x_crows,
   auto functor = phi::funcs::CudaExpFunctor<T>();
   T exp_sum = 0;
   for (int i = 0; i < kIteration; ++i) {
-    int idx = non_zero_idx + i * WARP_SIZE;
+    int idx = threadIdx.x + i * WARP_SIZE;
     if (idx >= row_nnz) break;
 
     if (buffer[i]) {
@@ -118,7 +115,7 @@ __global__ void AttnSoftmaxGpuKernel(const int64_t* x_crows,
   T row_exp_sum = phi::funcs::warpReduceSum<T>(exp_sum, 0xFFFFFFFF);
 
   for (int i = 0; i < kIteration; ++i) {
-    int idx = non_zero_idx + i * WARP_SIZE;
+    int idx = threadIdx.x + i * WARP_SIZE;
     if (idx >= row_nnz) break;
 
     if (buffer[i]) {
@@ -145,8 +142,12 @@ void FusedAttentionCsrKernel(const Context& dev_ctx,
   auto q_rank = q_dim.size();
 
   int total_row_num = 1;
+  int batch_num = 1;
   for (int i = 0; i < q_rank - 1; ++i) {
     total_row_num *= q_dim[i];
+    if (i < q_rank - 2) {
+      batch_num *= q_dim[i];
+    }
   }
   int M = q_dim[q_rank - 2];
   int N = q_dim[q_rank - 1];
@@ -161,6 +162,27 @@ void FusedAttentionCsrKernel(const Context& dev_ctx,
       4,
       phi::errors::InvalidArgument(" 'value' must be 4D Tensor"));
 
+  PADDLE_ENFORCE_EQ(
+      sparse_mask.dims().size(),
+      3,
+      phi::errors::InvalidArgument("dense shape of 'sparse_mask' must be "
+                                   "[batch_size*num_heads, seq_len, seq_len]"));
+  PADDLE_ENFORCE_EQ(
+      sparse_mask.dims()[0],
+      q_dim[0] * q_dim[1],
+      phi::errors::InvalidArgument("dense shape of 'sparse_mask' must be "
+                                   "[batch_size*num_heads, seq_len, seq_len]"));
+  PADDLE_ENFORCE_EQ(
+      sparse_mask.dims()[1],
+      M,
+      phi::errors::InvalidArgument("dense shape of 'sparse_mask' must be "
+                                   "[batch_size*num_heads, seq_len, seq_len]"));
+  PADDLE_ENFORCE_EQ(
+      sparse_mask.dims()[2],
+      M,
+      phi::errors::InvalidArgument("dense shape of 'sparse_mask' must be "
+                                   "[batch_size*num_heads, seq_len, seq_len]"));
+
   PADDLE_ENFORCE_EQ(
       key_padding_mask.dims().size(),
       2,
@@ -215,6 +237,8 @@ void FusedAttentionCsrKernel(const Context& dev_ctx,
   dim3 grid((total_row_num + 3) / 4);
   dim3 block(WARP_SIZE, 4);
 
+  int batch_nnz = sdd_result.nnz() / batch_num;
+
   VISIT_ATTN_SFOTMAX(buffer_size, "AttnSoftmaxGpuKernel", [&] {
     AttnSoftmaxGpuKernel<T, KBufferSize><<<grid, block, 0, dev_ctx.stream()>>>(
         sdd_result.non_zero_crows().data<int64_t>(),
@@ -226,7 +250,8 @@ void FusedAttentionCsrKernel(const Context& dev_ctx,
         M,
         total_row_num,
         std::sqrt(N),
-        q_dim[1]);
+        q_dim[1],
+        batch_nnz);
   });
 
   /* Step3: DSD Matmul, reuse */
diff --git a/python/paddle/fluid/tests/unittests/test_sparse_fused_attention_op.py b/python/paddle/fluid/tests/unittests/test_sparse_fused_attention_op.py
index a25ece822bb65..e34f890cc53d4 100644
--- a/python/paddle/fluid/tests/unittests/test_sparse_fused_attention_op.py
+++ b/python/paddle/fluid/tests/unittests/test_sparse_fused_attention_op.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -36,7 +36,7 @@ def get_cuda_version():
 
 
 @unittest.skipIf(
-    not core.is_compiled_with_cuda() or get_cuda_version() < 11030,
+    not core.is_compiled_with_cuda() or get_cuda_version() < 11070,
     "core is not compiled with CUDA and cuda version need larger than or equal to 11.3"
 )
 class TestSparseAttentionAPI1(unittest.TestCase):
diff --git a/python/paddle/incubate/sparse/nn/functional/transformer.py b/python/paddle/incubate/sparse/nn/functional/transformer.py
index 1686487ef7dc1..3429eecccd758 100644
--- a/python/paddle/incubate/sparse/nn/functional/transformer.py
+++ b/python/paddle/incubate/sparse/nn/functional/transformer.py
@@ -37,19 +37,18 @@ def attention(query,
 
     .. math::
 
-        result=softmax(\frac{ Q * K^T }{\sqrt{d}}) * V
+        result = softmax(\frac{ Q * K^T }{\sqrt{d}}) * V
 
     where : ``Q``, ``K``, and ``V`` represent the three input parameters of the attention module.
-    The shape of the three parameters are:
-    The dimensions of these three parameters are: [batch_size, num_heads, seq_len, head_dim].
+    The shape of the three parameters is `[batch_size, num_heads, seq_len, head_dim]`, and ``d`` represents ``head_dim``.
 
     Args:
         query(DenseTensor): `query` in the Attention module. 4D Tensor with float32 or float64.
         key(DenseTensor): `key` in the Attention module. 4D Tensor with float32 or float64.
         value(DenseTensor): `value` in the Attention module. 4D Tensor with float32 or float64.
-        sparse_mask(SparseCsrTensor): The sparse layout in the Attention module. shape of `crows` is
-            [batch_size, num_heads, seq_len + 1], shape of `cols` is [batch_size, num_heads, nnz].
+        sparse_mask(SparseCsrTensor): The sparse layout in the Attention module. Its dense shape
+            is `[batch_size*num_heads, seq_len, seq_len]`. The `nnz` of each batch must be the same.
             dtype of `crows` and `cols` must be int64, dtype of `values` can be float32 or float64.
         key_padding_mask(DenseTensor): The key padding mask tensor in the Attention module.
             2D tensor with shape: [batch_size, seq_len]. dtype can be float32 or float64.