
Commit

fix comment
zhwesky2010 committed Jul 5, 2022
1 parent 7ef44da commit 849deff
Showing 7 changed files with 59 additions and 31 deletions.
2 changes: 1 addition & 1 deletion paddle/phi/kernels/funcs/sparse/sparse_blas_impl.cu.h
@@ -232,7 +232,7 @@ class CuSparseDnMatDescriptor {

PADDLE_ENFORCE_EQ(x.numel(), batch_size * M * N);
if (batch_size > 1) {
#if CUDA_VERSION >= 11030
#if CUDA_VERSION >= 11070
dev_ctx_.CusparseCall([&](cusparseHandle_t handle) {
phi::dynload::cusparseDnMatSetStridedBatch(
descriptor_, batch_size, M * N);
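For reference, the strided-batch descriptor call guarded above maps onto the raw cuSPARSE generic API roughly as in the sketch below (an illustration only, assuming float values, row-major layout, and contiguously stored batches; `values_ptr` and the function name are made up, not Paddle code):

#include <cusparse.h>

// Describe batch_size row-major M x N dense matrices stored back-to-back,
// each M * N values apart, for use in batched SpMM.
void SetupBatchedDnMat(float* values_ptr, int batch_size, int64_t M, int64_t N) {
  cusparseDnMatDescr_t mat;
  cusparseCreateDnMat(&mat, M, N, /*ld=*/N, values_ptr, CUDA_R_32F,
                      CUSPARSE_ORDER_ROW);
  cusparseDnMatSetStridedBatch(mat, batch_size, /*batchStride=*/M * N);
}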
paddle/phi/kernels/sparse/cpu/fused_attention_grad_kernel.cc
@@ -31,7 +31,7 @@ void FusedAttentionCsrGradKernel(const Context& dev_ctx,
DenseTensor* dkey,
DenseTensor* dvalue) {
PD_THROW(
"Only support 'fused_attention' CPU backward kernel of SparseTensor now");
"Not support CPU kernel of 'sparse.nn.functional.fused_attention' now");
}

} // namespace sparse
3 changes: 2 additions & 1 deletion paddle/phi/kernels/sparse/cpu/fused_attention_kernel.cc
@@ -30,7 +30,8 @@ void FusedAttentionCsrKernel(const Context& dev_ctx,
const DenseTensor& attn_mask,
DenseTensor* out,
SparseCsrTensor* softmax) {
PD_THROW("Only support 'fused_attention' CPU kernel of SparseTensor now");
PD_THROW(
"Not support CPU kernel of 'sparse.nn.functional.fused_attention' now");
}

} // namespace sparse
23 changes: 13 additions & 10 deletions paddle/phi/kernels/sparse/gpu/fused_attention_grad_kernel.cu
@@ -31,33 +31,30 @@ __global__ void AttnSoftmaxGpuGradKernel(const int64_t* out_crows,
T* dx_values,
int M,
int total_row_num,
float scale) {
float scale,
int batch_nnz) {
// dx = (dout - sum(dout * out)) * out
int row = blockIdx.x * blockDim.y + threadIdx.y;
int non_zero_idx = threadIdx.x;
if (row >= total_row_num) return;

int cur_batch = row / M;
int crow_idx = cur_batch * (M + 1) + (row % M);
int cur_batch_offset = 0;
for (int i = 1; i < cur_batch + 1; ++i) {
cur_batch_offset += static_cast<int>(out_crows[i * (M + 1) - 1]);
}
int row_first = cur_batch_offset + static_cast<int>(out_crows[crow_idx]);
int row_first = cur_batch * batch_nnz + static_cast<int>(out_crows[crow_idx]);
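// The CSR non-zeros of all batches are stored back-to-back, and the API
// requires every batch to hold the same number of non-zeros (batch_nnz), so
// this batch's values start at cur_batch * batch_nnz; out_crows[crow_idx]
// then locates the current row inside that batch.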
int row_nnz = static_cast<int>(out_crows[crow_idx + 1] - out_crows[crow_idx]);
if (row_nnz == 0) return;

int kIteration = (row_nnz + WARP_SIZE - 1) / WARP_SIZE;
T mul_result = 0;
for (int i = 0; i < kIteration; ++i) {
int idx = non_zero_idx + i * WARP_SIZE;
int idx = threadIdx.x + i * WARP_SIZE;
if (idx >= row_nnz) break;

mul_result += out_values[row_first + idx] * dout_values[row_first + idx];
}
T sum = phi::funcs::warpReduceSum<T>(mul_result, 0xFFFFFFFF);

for (int i = 0; i < kIteration; ++i) {
int idx = non_zero_idx + i * WARP_SIZE;
int idx = threadIdx.x + i * WARP_SIZE;
if (idx >= row_nnz) break;

dx_values[row_first + idx] = (dout_values[row_first + idx] - sum) *
@@ -88,11 +85,16 @@ void FusedAttentionCsrGradKernel(const Context& dev_ctx,
auto q_rank = q_dim.size();

int total_row_num = 1;
int batch_num = 1;
for (int i = 0; i < q_rank - 1; ++i) {
total_row_num *= q_dim[i];
if (i < q_rank - 2) {
batch_num *= q_dim[i];
}
}
int M = q_dim[q_rank - 2];
int N = q_dim[q_rank - 1];
int batch_nnz = softmax.nnz() / batch_num;
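// batch_num folds all leading dims (batch_size * num_heads); each of these
// batches contributes batch_nnz = softmax.nnz() / batch_num non-zeros.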

dim3 grid((total_row_num + 3) / 4);
dim3 block(WARP_SIZE, 4);
@@ -104,7 +106,8 @@ void FusedAttentionCsrGradKernel(const Context& dev_ctx,
d_sdd_result.mutable_non_zero_elements()->data<T>(),
M,
total_row_num,
std::sqrt(N));
std::sqrt(N),
batch_nnz);

/* Step3: Forward: query{Dense} * key'{Dense} -> sdd_result{SparseCsr} */
auto sparse_blas = phi::funcs::sparse::GetSparseBlas<Context, T>(dev_ctx);
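Both softmax kernels in this commit reduce per-thread partial sums across a warp via phi::funcs::warpReduceSum. As a rough reference only (not Paddle's actual implementation), a warp-wide sum can be written with shuffle intrinsics like this, assuming all 32 lanes of the warp are active:

// Sketch: XOR-shuffle reduction; after log2(32) = 5 steps every lane holds
// the sum of the values initially held by the 32 lanes of the warp.
template <typename T>
__device__ T WarpSumSketch(T val, unsigned mask = 0xFFFFFFFFu) {
  for (int offset = 16; offset > 0; offset >>= 1) {
    val += __shfl_xor_sync(mask, val, offset);
  }
  return val;
}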
47 changes: 36 additions & 11 deletions paddle/phi/kernels/sparse/gpu/fused_attention_kernel.cu
@@ -59,19 +59,16 @@ __global__ void AttnSoftmaxGpuKernel(const int64_t* x_crows,
int M,
int total_row_num,
float scale,
int num_heads) {
int num_heads,
int batch_nnz) {
// out = exp(x-x_max) / sum(exp(x-x_max))
int row = blockIdx.x * blockDim.y + threadIdx.y;
int non_zero_idx = threadIdx.x;
if (row >= total_row_num) return;

int cur_batch = row / M;
int cur_row = row % M;
int crow_idx = cur_batch * (M + 1) + cur_row;
int cur_batch_offset = 0;
for (int i = 1; i < cur_batch + 1; ++i) {
cur_batch_offset += static_cast<int>(x_crows[i * (M + 1) - 1]);
}
int row_first = cur_batch_offset + static_cast<int>(x_crows[crow_idx]);
int row_first = cur_batch * batch_nnz + static_cast<int>(x_crows[crow_idx]);
int row_nnz = static_cast<int>(x_crows[crow_idx + 1] - x_crows[crow_idx]);
if (row_nnz == 0) return;

@@ -81,7 +78,7 @@ __global__ void AttnSoftmaxGpuKernel(const int64_t* x_crows,
T max_val = -std::numeric_limits<T>::infinity();
for (int i = 0; i < kIteration; ++i) {
bool mask = false;
int idx = non_zero_idx + i * WARP_SIZE;
int idx = threadIdx.x + i * WARP_SIZE;
if (idx >= row_nnz) break;

int col_idx = static_cast<int>(x_cols[row_first + idx]);
@@ -106,7 +103,7 @@ __global__ void AttnSoftmaxGpuKernel(const int64_t* x_crows,
auto functor = phi::funcs::CudaExpFunctor<T>();
T exp_sum = 0;
for (int i = 0; i < kIteration; ++i) {
int idx = non_zero_idx + i * WARP_SIZE;
int idx = threadIdx.x + i * WARP_SIZE;
if (idx >= row_nnz) break;

if (buffer[i]) {
@@ -118,7 +115,7 @@ __global__ void AttnSoftmaxGpuKernel(const int64_t* x_crows,
T row_exp_sum = phi::funcs::warpReduceSum<T>(exp_sum, 0xFFFFFFFF);

for (int i = 0; i < kIteration; ++i) {
int idx = non_zero_idx + i * WARP_SIZE;
int idx = threadIdx.x + i * WARP_SIZE;
if (idx >= row_nnz) break;

if (buffer[i]) {
@@ -145,8 +142,12 @@ void FusedAttentionCsrKernel(const Context& dev_ctx,
auto q_rank = q_dim.size();

int total_row_num = 1;
int batch_num = 1;
for (int i = 0; i < q_rank - 1; ++i) {
total_row_num *= q_dim[i];
if (i < q_rank - 2) {
batch_num *= q_dim[i];
}
}
int M = q_dim[q_rank - 2];
int N = q_dim[q_rank - 1];
@@ -161,6 +162,27 @@ void FusedAttentionCsrKernel(const Context& dev_ctx,
4,
phi::errors::InvalidArgument(" 'value' must be 4D Tensor"));

PADDLE_ENFORCE_EQ(
sparse_mask.dims().size(),
3,
phi::errors::InvalidArgument("dense shape of 'sparse_mask' must be "
"[batch_size*num_heads, seq_len, seq_len]"));
PADDLE_ENFORCE_EQ(
sparse_mask.dims()[0],
q_dim[0] * q_dim[1],
phi::errors::InvalidArgument("dense shape of 'sparse_mask' must be "
"[batch_size*num_heads, seq_len, seq_len]"));
PADDLE_ENFORCE_EQ(
sparse_mask.dims()[1],
M,
phi::errors::InvalidArgument("dense shape of 'sparse_mask' must be "
"[batch_size*num_heads, seq_len, seq_len]"));
PADDLE_ENFORCE_EQ(
sparse_mask.dims()[2],
M,
phi::errors::InvalidArgument("dense shape of 'sparse_mask' must be "
"[batch_size*num_heads, seq_len, seq_len]"));

PADDLE_ENFORCE_EQ(
key_padding_mask.dims().size(),
2,
@@ -215,6 +237,8 @@ void FusedAttentionCsrKernel(const Context& dev_ctx,
dim3 grid((total_row_num + 3) / 4);
dim3 block(WARP_SIZE, 4);

int batch_nnz = sdd_result.nnz() / batch_num;

VISIT_ATTN_SFOTMAX(buffer_size, "AttnSoftmaxGpuKernel", [&] {
AttnSoftmaxGpuKernel<T, KBufferSize><<<grid, block, 0, dev_ctx.stream()>>>(
sdd_result.non_zero_crows().data<int64_t>(),
Expand All @@ -226,7 +250,8 @@ void FusedAttentionCsrKernel(const Context& dev_ctx,
M,
total_row_num,
std::sqrt(N),
q_dim[1]);
q_dim[1],
batch_nnz);
});

/* Step3: DSD Matmul, reuse */
@@ -1,4 +1,4 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -36,7 +36,7 @@ def get_cuda_version():


@unittest.skipIf(
not core.is_compiled_with_cuda() or get_cuda_version() < 11030,
not core.is_compiled_with_cuda() or get_cuda_version() < 11070,
"core is not compiled with CUDA and cuda version need larger than or equal to 11.3"
)
class TestSparseAttentionAPI1(unittest.TestCase):
9 changes: 4 additions & 5 deletions python/paddle/incubate/sparse/nn/functional/transformer.py
@@ -37,19 +37,18 @@ def attention(query,
.. math::
result=softmax(\frac{ Q * K^T }{\sqrt{d}}) * V
result = softmax(\frac{ Q * K^T }{\sqrt{d}}) * V
where : ``Q``, ``K``, and ``V`` represent the three input parameters of the attention module.
The shape of the three parameters are:
The dimensions of these three parameters are: [batch_size, num_heads, seq_len, head_dim].
The shape of the three parameters are: `[batch_size, num_heads, seq_len, head_dim]`, and
``d`` represents ``head_dim`` .
Args:
query(DenseTensor): `query` in the Attention module. 4D Tensor with float32 or float64.
key(DenseTensor): `key` in the Attention module. 4D Tensor with float32 or float64.
value(DenseTensor): `value` in the Attention module. 4D Tensor with float32 or float64.
sparse_mask(SparseCsrTensor): The sparse layout in the Attention module. shape of `crows` is
[batch_size, num_heads, seq_len + 1], shape of `cols` is [batch_size, num_heads, nnz].
sparse_mask(SparseCsrTensor): The sparse layout in the Attention module. Its dense shape
is `[batch_size*num_heads, seq_len, seq_len]` . `nnz` of each batch must be the same.
dtype of `crows` and `cols` must be int64, dtype of `values` can be float32 or float64.
key_padding_mask(DenseTensor): The key padding mask tensor in the Attention module.
2D tensor with shape: [batch_size, seq_len]. dtype can be float32 or float64.
