[Sparse] optimize sparse attention op kernel #44743

Merged
1 change: 1 addition & 0 deletions paddle/fluid/platform/dynload/cusparse.h
@@ -56,6 +56,7 @@ CUSPARSE_ROUTINE_EACH(PLATFORM_DECLARE_DYNAMIC_LOAD_CUSPARSE_WRAP)

#if CUDA_VERSION >= 11030
 #define CUSPARSE_ROUTINE_EACH_R2(__macro) \
+  __macro(cusparseSpMM_preprocess);       \
   __macro(cusparseSDDMM_bufferSize);      \
   __macro(cusparseSDDMM_preprocess);      \
   __macro(cusparseSDDMM);
1 change: 1 addition & 0 deletions paddle/phi/backends/dynload/cusparse.h
@@ -68,6 +68,7 @@ CUSPARSE_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUSPARSE_WRAP)

#if CUDA_VERSION >= 11030
 #define CUSPARSE_ROUTINE_EACH_R2(__macro) \
+  __macro(cusparseSpMM_preprocess);       \
   __macro(cusparseSDDMM_bufferSize);      \
   __macro(cusparseSDDMM_preprocess);      \
   __macro(cusparseSDDMM);
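Both headers above only declare a dynamic-load wrapper for cusparseSpMM_preprocess. For context, the sketch below shows how that entry point fits between the existing cusparseSpMM_bufferSize query and the cusparseSpMM call. It is a generic cuSPARSE usage sketch (descriptor setup and error checking omitted), not code from this PR.

#include <cuda_runtime.h>
#include <cusparse.h>

// Generic sketch of one CSR x dense multiply using the entry points wrapped
// above. cusparseSpMM_preprocess is guarded behind CUDA >= 11.3 (matching the
// version check in these headers) and lets cuSPARSE analyze the sparse matrix
// once, so repeated SpMM calls on the same matrix run faster.
void SpMMSketch(cusparseHandle_t handle,
                cusparseSpMatDescr_t mat_a,   // sparse CSR operand
                cusparseDnMatDescr_t mat_b,   // dense operand
                cusparseDnMatDescr_t mat_c,   // dense output
                const float* alpha,
                const float* beta) {
  size_t buffer_size = 0;
  cusparseSpMM_bufferSize(handle, CUSPARSE_OPERATION_NON_TRANSPOSE,
                          CUSPARSE_OPERATION_NON_TRANSPOSE, alpha, mat_a, mat_b,
                          beta, mat_c, CUDA_R_32F, CUSPARSE_SPMM_CSR_ALG2,
                          &buffer_size);
  void* tmp_buffer = nullptr;
  cudaMalloc(&tmp_buffer, buffer_size);
  cusparseSpMM_preprocess(handle, CUSPARSE_OPERATION_NON_TRANSPOSE,
                          CUSPARSE_OPERATION_NON_TRANSPOSE, alpha, mat_a, mat_b,
                          beta, mat_c, CUDA_R_32F, CUSPARSE_SPMM_CSR_ALG2,
                          tmp_buffer);
  cusparseSpMM(handle, CUSPARSE_OPERATION_NON_TRANSPOSE,
               CUSPARSE_OPERATION_NON_TRANSPOSE, alpha, mat_a, mat_b, beta,
               mat_c, CUDA_R_32F, CUSPARSE_SPMM_CSR_ALG2, tmp_buffer);
  cudaFree(tmp_buffer);
}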
13 changes: 11 additions & 2 deletions paddle/phi/kernels/funcs/sparse/sparse_blas_impl.cu.h
@@ -48,6 +48,15 @@ inline cusparseOperation_t GetTransposeOperation(const bool trans) {
}
}

+inline cusparseSpMMAlg_t GetSpMMAlgorithm(const SparseCsrTensor& x) {
+  // TODO(zhouwei): will change to 'CUSPARSE_SPMM_CSR_ALG2' when support batch
+  return CUSPARSE_SPMM_CSR_ALG2;
+}
+
+inline cusparseSpMMAlg_t GetSpMMAlgorithm(const SparseCooTensor& x) {
+  return CUSPARSE_SPMM_ALG_DEFAULT;
+}

/************* SPARSE MATRIX DESCRIPTOR (COO/CSR) ************/

template <typename T, typename IntT>
@@ -324,7 +333,7 @@ void SparseBlas<phi::GPUContext>::SPMM(bool transa,
&beta,
out_descriptor.descriptor(),
gpu_type,
-CUSPARSE_SPMM_ALG_DEFAULT,
+GetSpMMAlgorithm(mat_a),
&buffer_size);
});

@@ -341,7 +350,7 @@ void SparseBlas<phi::GPUContext>::SPMM(bool transa,
&beta,
out_descriptor.descriptor(),
gpu_type,
-CUSPARSE_SPMM_ALG_DEFAULT,
+GetSpMMAlgorithm(mat_a),
tmp_buffer_ptr);
});
}
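The SPMM path above now derives the cuSPARSE algorithm from the operand's storage format via overload resolution: CSR inputs request CUSPARSE_SPMM_CSR_ALG2, while COO inputs keep CUSPARSE_SPMM_ALG_DEFAULT. Note that the same enum should be passed to both the bufferSize query and the actual SpMM call, since the required workspace can depend on the algorithm. A minimal illustration (the helper below is hypothetical, not Paddle code):

// Hypothetical caller, for illustration only: the algorithm chosen once here
// should be reused for both cusparseSpMM_bufferSize and cusparseSpMM.
template <typename SparseTensorT>
cusparseSpMMAlg_t PickSpMMAlg(const SparseTensorT& mat_a) {
  // SparseCsrTensor overload -> CUSPARSE_SPMM_CSR_ALG2
  // SparseCooTensor overload -> CUSPARSE_SPMM_ALG_DEFAULT
  return GetSpMMAlgorithm(mat_a);
}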
23 changes: 8 additions & 15 deletions paddle/phi/kernels/sparse/gpu/fused_attention_grad_kernel.cu
@@ -43,21 +43,14 @@ __global__ void AttnSoftmaxGpuGradKernel(const int64_t* out_crows,
int row_nnz = static_cast<int>(out_crows[crow_idx + 1] - out_crows[crow_idx]);
if (row_nnz == 0) return;

-  int kIteration = (row_nnz + WARP_SIZE - 1) / WARP_SIZE;
-  T mul_result = 0;
-  for (int i = 0; i < kIteration; ++i) {
-    int idx = threadIdx.x + i * WARP_SIZE;
-    if (idx >= row_nnz) break;
-
-    mul_result += out_values[row_first + idx] * dout_values[row_first + idx];
+  T mul = 0;
+  for (int idx = threadIdx.x; idx < row_nnz; idx += blockDim.x) {
+    mul += out_values[row_first + idx] * dout_values[row_first + idx];
   }
-  T sum = phi::funcs::warpReduceSum<T>(mul_result, 0xFFFFFFFF);
-
-  for (int i = 0; i < kIteration; ++i) {
-    int idx = threadIdx.x + i * WARP_SIZE;
-    if (idx >= row_nnz) break;
+  T mul_sum = phi::funcs::warpReduceSum<T>(mul, 0xFFFFFFFF);

-    dx_values[row_first + idx] = (dout_values[row_first + idx] - sum) *
+  for (int idx = threadIdx.x; idx < row_nnz; idx += blockDim.x) {
+    dx_values[row_first + idx] = (dout_values[row_first + idx] - mul_sum) *
                                  out_values[row_first + idx] / scale;
   }
 }
Expand Down Expand Up @@ -96,8 +89,8 @@ void FusedAttentionCsrGradKernel(const Context& dev_ctx,
int N = q_dim[q_rank - 1];
int batch_nnz = softmax.nnz() / batch_num;

-  dim3 grid((total_row_num + 3) / 4);
-  dim3 block(WARP_SIZE, 4);
+  dim3 grid((total_row_num + 7) / 8);
+  dim3 block(WARP_SIZE, 8);

AttnSoftmaxGpuGradKernel<T><<<grid, block, 0, dev_ctx.stream()>>>(
softmax.non_zero_crows().data<int64_t>(),
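The backward kernel above replaces the counted kIteration loop (which stepped by WARP_SIZE and broke out per lane) with a simpler lane-strided loop, and the launch now packs 8 warps per block instead of 4. Below is a self-contained sketch of that warp-per-row access pattern; the names and the shuffle reduction are illustrative stand-ins, not Paddle's phi::funcs helpers.

#include <cstdint>
#include <cuda_runtime.h>

#define WARP_SIZE 32

// Sum a value across the 32 lanes of a warp with shuffle instructions.
__device__ float WarpReduceSumSketch(float val) {
  for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
    val += __shfl_xor_sync(0xFFFFFFFF, val, offset);
  }
  return val;
}

// Warp-per-row pattern used by the rewritten kernels: each warp owns one CSR
// row, lanes stride through the row's nonzeros, and the per-lane partial sums
// are combined with a warp reduction.
__global__ void RowDotSketch(const int64_t* crows,  // CSR row offsets
                             const float* a_values,
                             const float* b_values,
                             float* row_dot,         // one result per row
                             int num_rows) {
  int row = blockIdx.x * blockDim.y + threadIdx.y;   // one warp per row
  if (row >= num_rows) return;
  int row_first = static_cast<int>(crows[row]);
  int row_nnz = static_cast<int>(crows[row + 1] - crows[row]);

  float partial = 0.0f;
  for (int idx = threadIdx.x; idx < row_nnz; idx += blockDim.x) {
    partial += a_values[row_first + idx] * b_values[row_first + idx];
  }
  float total = WarpReduceSumSketch(partial);
  if (threadIdx.x == 0) row_dot[row] = total;
}

// Launch shape mirroring the PR: blockDim.x = WARP_SIZE lanes, 8 rows per block.
//   dim3 block(WARP_SIZE, 8);
//   dim3 grid((num_rows + 7) / 8);
//   RowDotSketch<<<grid, block>>>(crows, a_values, b_values, row_dot, num_rows);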
115 changes: 32 additions & 83 deletions paddle/phi/kernels/sparse/gpu/fused_attention_kernel.cu
@@ -26,30 +26,7 @@ limitations under the License. */
namespace phi {
namespace sparse {

-#define PRIVATE_CASE_VISIT_ATTN_SOFTMAX(NAME, size, HINT, ...) \
-  case size: { \
-    constexpr int HINT = size; \
-    __VA_ARGS__(); \
-    break; \
-  }
-
-#define VISIT_ATTN_SFOTMAX(SIZE, NAME, ...) \
-  [&] { \
-    const auto& __size__ = SIZE; \
-    switch (__size__) { \
-      PRIVATE_CASE_VISIT_ATTN_SOFTMAX(NAME, 1, KBufferSize, __VA_ARGS__) \
-      PRIVATE_CASE_VISIT_ATTN_SOFTMAX(NAME, 2, KBufferSize, __VA_ARGS__) \
-      PRIVATE_CASE_VISIT_ATTN_SOFTMAX(NAME, 3, KBufferSize, __VA_ARGS__) \
-      PRIVATE_CASE_VISIT_ATTN_SOFTMAX(NAME, 4, KBufferSize, __VA_ARGS__) \
-      PRIVATE_CASE_VISIT_ATTN_SOFTMAX(NAME, 8, KBufferSize, __VA_ARGS__) \
-      PRIVATE_CASE_VISIT_ATTN_SOFTMAX(NAME, 12, KBufferSize, __VA_ARGS__) \
-      PRIVATE_CASE_VISIT_ATTN_SOFTMAX(NAME, 16, KBufferSize, __VA_ARGS__) \
-      default: \
-        PD_THROW("function " #NAME " is not implemented for columns>512 "); \
-    } \
-  }()
-
-template <typename T, int BufferSize>
+template <typename T>
__global__ void AttnSoftmaxGpuKernel(const int64_t* x_crows,
const int64_t* x_cols,
const T* x_values,
@@ -58,7 +35,6 @@ __global__ void AttnSoftmaxGpuKernel(const int64_t* x_crows,
T* out_values,
int M,
int total_row_num,
-float scale,
int num_heads,
int batch_nnz) {
// out = exp(x-x_max) / sum(exp(x-x_max))
@@ -72,17 +48,10 @@ __global__ void AttnSoftmaxGpuKernel(const int64_t* x_crows,
int row_nnz = static_cast<int>(x_crows[crow_idx + 1] - x_crows[crow_idx]);
if (row_nnz == 0) return;

-  T buffer[BufferSize] = {0};
-  int kIteration = (row_nnz + WARP_SIZE - 1) / WARP_SIZE;
-
   T max_val = -std::numeric_limits<T>::infinity();
-  for (int i = 0; i < kIteration; ++i) {
+  for (int idx = threadIdx.x; idx < row_nnz; idx += blockDim.x) {
     bool mask = false;
-    int idx = threadIdx.x + i * WARP_SIZE;
-    if (idx >= row_nnz) break;
-
     int col_idx = static_cast<int>(x_cols[row_first + idx]);
-
     if (kp_mask != nullptr &&
         kp_mask[(cur_batch / num_heads) * M + col_idx] == 0) {
       mask = true;
@@ -92,37 +61,30 @@ __global__ void AttnSoftmaxGpuKernel(const int64_t* x_crows,
     }

     if (!mask) {
-      buffer[i] = x_values[row_first + idx] / scale;
-      if (buffer[i] > max_val) {
-        max_val = buffer[i];
+      T val = x_values[row_first + idx];
+      if (val > max_val) {
+        max_val = val;
       }
+      out_values[row_first + idx] = val;
+    } else {
+      // Note corner case: when all elements of the row are masked, result
+      // may be wrong because of exp('-inf' - '-inf'), just ignore now.
+      out_values[row_first + idx] = -std::numeric_limits<T>::infinity();
     }
   }
   T row_max_val = phi::funcs::warpReduceMax<T>(max_val, 0xFFFFFFFF);

-  auto functor = phi::funcs::CudaExpFunctor<T>();
   T exp_sum = 0;
-  for (int i = 0; i < kIteration; ++i) {
-    int idx = threadIdx.x + i * WARP_SIZE;
-    if (idx >= row_nnz) break;
-
-    if (buffer[i]) {
-      T exp = functor(buffer[i] - row_max_val);
-      exp_sum += exp;
-      buffer[i] = exp;
-    }
+  for (int idx = threadIdx.x; idx < row_nnz; idx += blockDim.x) {
+    auto functor = phi::funcs::CudaExpFunctor<T>();
+    T exp = functor(out_values[row_first + idx] - row_max_val);
+    exp_sum += exp;
+    out_values[row_first + idx] = exp;
   }
   T row_exp_sum = phi::funcs::warpReduceSum<T>(exp_sum, 0xFFFFFFFF);

-  for (int i = 0; i < kIteration; ++i) {
-    int idx = threadIdx.x + i * WARP_SIZE;
-    if (idx >= row_nnz) break;
-
-    if (buffer[i]) {
-      out_values[row_first + idx] = buffer[i] / row_exp_sum;
-    } else {
-      out_values[row_first + idx] = static_cast<T>(0);
-    }
+  for (int idx = threadIdx.x; idx < row_nnz; idx += blockDim.x) {
+    out_values[row_first + idx] = out_values[row_first + idx] / row_exp_sum;
   }
 }
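The rewritten kernel writes masked positions as -infinity instead of keeping a per-thread buffer, so after the max-shifted exponential they become exp(-inf - max) = 0 and normalize to zero without the old else branch in the final pass. A host-side reference of the per-row computation, for checking the math only (not code from this PR):

#include <cmath>
#include <limits>
#include <vector>

// Reference of the per-row masked softmax above: masked entries are set to
// -inf, turn into 0 after exp(), and therefore normalize to 0.
std::vector<float> RowSoftmaxRef(const std::vector<float>& vals,
                                 const std::vector<bool>& masked) {
  const float neg_inf = -std::numeric_limits<float>::infinity();
  std::vector<float> out(vals.size());
  float row_max = neg_inf;
  for (size_t i = 0; i < vals.size(); ++i) {
    out[i] = masked[i] ? neg_inf : vals[i];
    if (out[i] > row_max) row_max = out[i];
  }
  float exp_sum = 0.0f;
  for (size_t i = 0; i < out.size(); ++i) {
    out[i] = std::exp(out[i] - row_max);  // exp(-inf - row_max) == 0 for masked
    exp_sum += out[i];
  }
  for (size_t i = 0; i < out.size(); ++i) {
    out[i] /= exp_sum;
  }
  // A fully masked row still hits the exp(-inf - -inf) corner case noted in
  // the kernel's comment and produces NaN here as well.
  return out;
}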

@@ -219,49 +181,36 @@ void FusedAttentionCsrKernel(
"shape of 'attn_mask' must be [seq_len, seq_len]"));
}

-/* Step1: SDD Matmul, reuse */
+/* Step1: SDD Matmul, reuse matmul */
SparseCsrTensor sdd_result;
EmptyLikeCsrKernel<T, Context>(dev_ctx, sparse_mask, &sdd_result);
auto sparse_blas = phi::funcs::sparse::GetSparseBlas<Context, T>(dev_ctx);
sparse_blas.SDDMM(false,
true,
-static_cast<T>(1),
+static_cast<T>(1 / std::sqrt(N)),
query,
key,
static_cast<T>(0),
&sdd_result);

/* Step2: Softmax with kp_mask/attn_mask, manualy not reuse */
EmptyLikeCsrKernel<T, Context>(dev_ctx, sdd_result, softmax);

-  int buffer_size;
-  if (M < 128) {
-    buffer_size = (M + 32 - 1) / 32;
-  } else {
-    buffer_size = ((M + 128 - 1) / 128) * 4;
-  }
-
-  dim3 grid((total_row_num + 3) / 4);
-  dim3 block(WARP_SIZE, 4);
+  dim3 grid((total_row_num + 7) / 8);
+  dim3 block(WARP_SIZE, 8);

   int batch_nnz = sdd_result.nnz() / batch_num;
-  VISIT_ATTN_SFOTMAX(buffer_size, "AttnSoftmaxGpuKernel", [&] {
-    AttnSoftmaxGpuKernel<T, KBufferSize><<<grid, block, 0, dev_ctx.stream()>>>(
-        sdd_result.non_zero_crows().data<int64_t>(),
-        sdd_result.non_zero_cols().data<int64_t>(),
-        sdd_result.non_zero_elements().data<T>(),
-        kp_mask_ptr ? kp_mask_ptr->data<T>() : nullptr,
-        attn_mask_ptr ? attn_mask_ptr->data<T>() : nullptr,
-        softmax->mutable_non_zero_elements()->data<T>(),
-        M,
-        total_row_num,
-        std::sqrt(N),
-        q_dim[1],
-        batch_nnz);
-  });
+  AttnSoftmaxGpuKernel<T><<<grid, block, 0, dev_ctx.stream()>>>(
+      sdd_result.non_zero_crows().data<int64_t>(),
+      sdd_result.non_zero_cols().data<int64_t>(),
+      sdd_result.non_zero_elements().data<T>(),
+      kp_mask_ptr ? kp_mask_ptr->data<T>() : nullptr,
+      attn_mask_ptr ? attn_mask_ptr->data<T>() : nullptr,
+      softmax->mutable_non_zero_elements()->data<T>(),
+      M,
+      total_row_num,
+      q_dim[1],
+      batch_nnz);

/* Step3: DSD Matmul, reuse */
softmax->set_dims(phi::make_ddim({q_dim[0], q_dim[1], q_dim[2], q_dim[2]}));
MatmulCsrDenseKernel<T, Context>(dev_ctx, *softmax, value, out);
#else
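Two things fall out of this file's changes: the 1/sqrt(N) attention scaling moves from the softmax kernel (the old per-element x_values / scale with scale = sqrt(N)) into the SDDMM alpha, and dropping the per-thread buffer removes the need for the VISIT_ATTN_SFOTMAX buffer-size dispatch, so the kernel is launched directly. The scaling move is purely algebraic; a tiny check with made-up numbers, showing both orderings give the same attention weights up to rounding:

#include <cmath>
#include <cstdio>

int main() {
  const double logits[3] = {1.5, -0.3, 2.0};  // one row of Q*K^T (made up)
  const double scale = std::sqrt(64.0);       // head dim N = 64 (assumed)
  double old_way[3], new_way[3], old_sum = 0.0, new_sum = 0.0;
  for (int i = 0; i < 3; ++i) {
    old_way[i] = std::exp(logits[i] / scale);          // scale inside softmax
    new_way[i] = std::exp((1.0 / scale) * logits[i]);  // scale folded into SDDMM alpha
    old_sum += old_way[i];
    new_sum += new_way[i];
  }
  for (int i = 0; i < 3; ++i) {
    std::printf("%.12f  %.12f\n", old_way[i] / old_sum, new_way[i] / new_sum);
  }
  return 0;
}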
@@ -37,7 +37,7 @@ def get_cuda_version():

@unittest.skipIf(
not core.is_compiled_with_cuda() or get_cuda_version() < 11070,
"core is not compiled with CUDA and cuda version need larger than or equal to 11.3"
"core is not compiled with CUDA and cuda version need larger than or equal to 11.7"
)
class TestSparseAttentionAPI1(unittest.TestCase):
