Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Sparse] add Fused Attention kernel and API for SparseCsrTensor #43966

Merged
merged 2 commits into from Jul 5, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
9 changes: 9 additions & 0 deletions paddle/phi/api/yaml/sparse_api.yaml
Expand Up @@ -141,6 +141,15 @@
layout : x
data_type : dtype

- api: fused_attention
args : (Tensor query, Tensor key, Tensor value, Tensor sparse_mask, Tensor key_padding_mask, Tensor attn_mask)
output : Tensor(out), Tensor(softmax)
kernel :
func : fused_attention_csr{dense, dense, dense, sparse_csr, dense, dense -> dense, sparse_csr}
layout : sparse_mask
intermediate : softmax
backward: fused_attention_grad

- api: masked_matmul
args : (Tensor x, Tensor y, Tensor mask)
output : Tensor(out)
Expand Down
7 changes: 7 additions & 0 deletions paddle/phi/api/yaml/sparse_bw_api.yaml
Expand Up @@ -127,3 +127,10 @@
output : Tensor(x_grad)
kernel :
func : coo_values_grad{sparse_coo, dense-> sparse_coo}

- backward_api: fused_attention_grad
forward : fused_attention_csr(Tensor query, Tensor key, Tensor value, Tensor sparse_mask, Tensor key_padding_mask, Tensor attn_mask) -> Tensor(out), Tensor(softmax)
args: (Tensor query, Tensor key, Tensor value, Tensor softmax, Tensor out_grad)
output : Tensor(query_grad), Tensor(key_grad), Tensor(value_grad)
kernel :
func : fused_attention_csr_grad{dense, dense, dense, sparse_csr, dense -> dense, dense, dense}
2 changes: 1 addition & 1 deletion paddle/phi/kernels/funcs/sparse/sparse_blas_impl.cu.h
Expand Up @@ -232,7 +232,7 @@ class CuSparseDnMatDescriptor {

PADDLE_ENFORCE_EQ(x.numel(), batch_size * M * N);
if (batch_size > 1) {
#if CUDA_VERSION >= 11070
#if CUDA_VERSION >= 11030
dev_ctx_.CusparseCall([&](cusparseHandle_t handle) {
phi::dynload::cusparseDnMatSetStridedBatch(
descriptor_, batch_size, M * N);
Expand Down
38 changes: 38 additions & 0 deletions paddle/phi/kernels/sparse/cpu/fused_attention_grad_kernel.cc
@@ -0,0 +1,38 @@
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/phi/kernels/sparse/fused_attention_grad_kernel.h"

#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"

namespace phi {
namespace sparse {

template <typename T, typename Context>
void FusedAttentionCsrGradKernel(const Context& dev_ctx,
const DenseTensor& query,
const DenseTensor& key,
const DenseTensor& value,
const SparseCsrTensor& softmax,
const DenseTensor& dout,
DenseTensor* dquery,
DenseTensor* dkey,
DenseTensor* dvalue) {
PD_THROW(
"Only support 'fused_attention' CPU backward kernel of SparseTensor now");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

GPU?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

}

} // namespace sparse
} // namespace phi
37 changes: 37 additions & 0 deletions paddle/phi/kernels/sparse/cpu/fused_attention_kernel.cc
@@ -0,0 +1,37 @@
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/phi/kernels/sparse/fused_attention_kernel.h"

#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"

namespace phi {
namespace sparse {

template <typename T, typename Context>
void FusedAttentionCsrKernel(const Context& dev_ctx,
const DenseTensor& query,
const DenseTensor& key,
const DenseTensor& value,
const SparseCsrTensor& sparse_mask,
const DenseTensor& key_padding_mask,
const DenseTensor& attn_mask,
DenseTensor* out,
SparseCsrTensor* softmax) {
PD_THROW("Only support 'fused_attention' CPU kernel of SparseTensor now");
}

} // namespace sparse
} // namespace phi
35 changes: 35 additions & 0 deletions paddle/phi/kernels/sparse/fused_attention_grad_kernel.h
@@ -0,0 +1,35 @@
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/sparse_csr_tensor.h"

namespace phi {
namespace sparse {

template <typename T, typename Context>
void FusedAttentionCsrGradKernel(const Context& dev_ctx,
const DenseTensor& query,
const DenseTensor& key,
const DenseTensor& value,
const SparseCsrTensor& softmax,
const DenseTensor& dout,
DenseTensor* dquery,
DenseTensor* dkey,
DenseTensor* dvalue);

} // namespace sparse
} // namespace phi
35 changes: 35 additions & 0 deletions paddle/phi/kernels/sparse/fused_attention_kernel.h
@@ -0,0 +1,35 @@
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/sparse_csr_tensor.h"

namespace phi {
namespace sparse {

template <typename T, typename Context>
void FusedAttentionCsrKernel(const Context& dev_ctx,
const DenseTensor& query,
const DenseTensor& key,
const DenseTensor& value,
const SparseCsrTensor& sparse_mask,
const DenseTensor& key_padding_mask,
const DenseTensor& attn_mask,
DenseTensor* out,
SparseCsrTensor* softmax);

} // namespace sparse
} // namespace phi
150 changes: 150 additions & 0 deletions paddle/phi/kernels/sparse/gpu/fused_attention_grad_kernel.cu
@@ -0,0 +1,150 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/sparse/fused_attention_grad_kernel.h"

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/math_cuda_utils.h"
#include "paddle/phi/kernels/funcs/sparse/sparse_blas.h"
#include "paddle/phi/kernels/sparse/empty_kernel.h"
#include "paddle/phi/kernels/sparse/matmul_grad_kernel.h"

namespace phi {
namespace sparse {

template <typename T>
__global__ void AttnSoftmaxGpuGradKernel(const int64_t* out_crows,
const T* out_values,
const T* dout_values,
T* dx_values,
int M,
int total_row_num,
float scale) {
// dx = (dout - sum(dout * out)) * out
int row = blockIdx.x * blockDim.y + threadIdx.y;
int non_zero_idx = threadIdx.x;
if (row >= total_row_num) return;
int cur_batch = row / M;
int crow_idx = cur_batch * (M + 1) + (row % M);
int cur_batch_offset = 0;
for (int i = 1; i < cur_batch + 1; ++i) {
cur_batch_offset += static_cast<int>(out_crows[i * (M + 1) - 1]);
}
int row_first = cur_batch_offset + static_cast<int>(out_crows[crow_idx]);
int row_nnz = static_cast<int>(out_crows[crow_idx + 1] - out_crows[crow_idx]);
if (row_nnz == 0) return;

int kIteration = (row_nnz + WARP_SIZE - 1) / WARP_SIZE;
T mul_result = 0;
for (int i = 0; i < kIteration; ++i) {
int idx = non_zero_idx + i * WARP_SIZE;
if (idx >= row_nnz) break;

mul_result += out_values[row_first + idx] * dout_values[row_first + idx];
}
T sum = phi::funcs::warpReduceSum<T>(mul_result, 0xFFFFFFFF);

for (int i = 0; i < kIteration; ++i) {
int idx = non_zero_idx + i * WARP_SIZE;
if (idx >= row_nnz) break;

dx_values[row_first + idx] = (dout_values[row_first + idx] - sum) *
out_values[row_first + idx] / scale;
}
}

template <typename T, typename Context>
void FusedAttentionCsrGradKernel(const Context& dev_ctx,
const DenseTensor& query,
const DenseTensor& key,
const DenseTensor& value,
const SparseCsrTensor& softmax,
const DenseTensor& dout,
DenseTensor* dquery,
DenseTensor* dkey,
DenseTensor* dvalue) {
#if CUDA_VERSION >= 11070
/* Step1: Forward: softmax{CSR} * value{Dense} -> out{Dense}, reuse */
SparseCsrTensor dsoftmax;
CsrDenseMatmulGradKernel<T, Context>(
dev_ctx, softmax, value, dout, &dsoftmax, dvalue);

/* Step2: Calculate grad of sdd_result, manualy not reuse */
SparseCsrTensor d_sdd_result;
EmptyLikeCsrKernel<T, Context>(dev_ctx, dsoftmax, &d_sdd_result);
auto q_dim = query.dims();
auto q_rank = q_dim.size();

int total_row_num = 1;
for (int i = 0; i < q_rank - 1; ++i) {
total_row_num *= q_dim[i];
}
int M = q_dim[q_rank - 2];
int N = q_dim[q_rank - 1];

dim3 grid((total_row_num + 3) / 4);
dim3 block(WARP_SIZE, 4);

AttnSoftmaxGpuGradKernel<T><<<grid, block, 0, dev_ctx.stream()>>>(
softmax.non_zero_crows().data<int64_t>(),
softmax.non_zero_elements().data<T>(),
dsoftmax.mutable_non_zero_elements()->data<T>(),
d_sdd_result.mutable_non_zero_elements()->data<T>(),
M,
total_row_num,
std::sqrt(N));

/* Step3: Forward: query{Dense} * key'{Dense} -> sdd_result{SparseCsr} */
auto sparse_blas = phi::funcs::sparse::GetSparseBlas<Context, T>(dev_ctx);
// dquery{Dense} = d_sdd_result{SparseCsr} * key{Dense} //
dquery->Resize(query.dims());
dev_ctx.template Alloc<T>(dquery);
sparse_blas.SPMM(false,
false,
static_cast<T>(1.f),
d_sdd_result,
key,
static_cast<T>(0.f),
dquery);

// dkey{Dense} = d_sdd_result'{SparseCsr} * query{Dense} //
dkey->Resize(key.dims());
dev_ctx.template Alloc<T>(dkey);
sparse_blas.SPMM(true,
false,
static_cast<T>(1.f),
d_sdd_result,
query,
static_cast<T>(0.f),
dkey);
#else
PADDLE_THROW(
phi::errors::Unimplemented("backward of 'sparse.nn.functional.attention' "
"use 'cusparseCsrSetStridedBatch', which is "
"completed supported from CUDA 11.7"));
#endif
}

} // namespace sparse
} // namespace phi

PD_REGISTER_KERNEL(fused_attention_csr_grad,
GPU,
ALL_LAYOUT,
phi::sparse::FusedAttentionCsrGradKernel,
float,
double) {
kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_CSR);
}