[Sparse] optimize sparse attention
zhwesky2010 committed Aug 1, 2022
1 parent 16506d8 commit cf2db1f
Showing 6 changed files with 54 additions and 101 deletions.
1 change: 1 addition & 0 deletions paddle/fluid/platform/dynload/cusparse.h
@@ -56,6 +56,7 @@ CUSPARSE_ROUTINE_EACH(PLATFORM_DECLARE_DYNAMIC_LOAD_CUSPARSE_WRAP)

#if CUDA_VERSION >= 11030
#define CUSPARSE_ROUTINE_EACH_R2(__macro) \
__macro(cusparseSpMM_preprocess); \
__macro(cusparseSDDMM_bufferSize); \
__macro(cusparseSDDMM_preprocess); \
__macro(cusparseSDDMM);
1 change: 1 addition & 0 deletions paddle/phi/backends/dynload/cusparse.h
@@ -68,6 +68,7 @@ CUSPARSE_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUSPARSE_WRAP)

#if CUDA_VERSION >= 11030
#define CUSPARSE_ROUTINE_EACH_R2(__macro) \
__macro(cusparseSpMM_preprocess); \
__macro(cusparseSDDMM_bufferSize); \
__macro(cusparseSDDMM_preprocess); \
__macro(cusparseSDDMM);
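Both dynload headers make the same one-line change: cusparseSpMM_preprocess joins the CUSPARSE_ROUTINE_EACH_R2 list, so the new entry point gets a lazily loaded wrapper like the other cuSPARSE routines guarded by CUDA >= 11.3. A minimal sketch of how such an X-macro list typically expands (the DEMO_* macros and the dlsym-based loader are illustrative stand-ins, not Paddle's PLATFORM_DECLARE_DYNAMIC_LOAD_CUSPARSE_WRAP):

#include <dlfcn.h>
#include <cuda.h>
#include <cusparse.h>

// Illustrative only: each name handed to the list macro expands into a small
// functor that resolves the symbol once (here via dlsym on the already-loaded
// process image) and forwards every call to it.
#define DEMO_DECLARE_CUSPARSE_WRAP(__name)                               \
  struct DynLoad__##__name {                                             \
    template <typename... Args>                                          \
    cusparseStatus_t operator()(Args... args) {                          \
      using Fn = decltype(&::__name);                                    \
      static Fn fn = reinterpret_cast<Fn>(dlsym(RTLD_DEFAULT, #__name)); \
      return fn(args...);                                                \
    }                                                                    \
  };                                                                     \
  static DynLoad__##__name demo_##__name

#if CUDA_VERSION >= 11030
#define DEMO_CUSPARSE_ROUTINE_EACH_R2(__macro) \
  __macro(cusparseSpMM_preprocess);            \
  __macro(cusparseSDDMM_bufferSize);           \
  __macro(cusparseSDDMM_preprocess);           \
  __macro(cusparseSDDMM);

// Generates DynLoad__cusparseSpMM_preprocess, DynLoad__cusparseSDDMM_bufferSize, ...
DEMO_CUSPARSE_ROUTINE_EACH_R2(DEMO_DECLARE_CUSPARSE_WRAP)
#endif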
13 changes: 11 additions & 2 deletions paddle/phi/kernels/funcs/sparse/sparse_blas_impl.cu.h
@@ -48,6 +48,15 @@ inline cusparseOperation_t GetTransposeOperation(const bool trans) {
}
}

inline cusparseSpMMAlg_t GetSpMMAlgorithm(const SparseCsrTensor& x) {
// TODO(zhouwei): will change to 'CUSPARSE_SPMM_CSR_ALG2' when support batch
return CUSPARSE_SPMM_CSR_ALG2;
}

inline cusparseSpMMAlg_t GetSpMMAlgorithm(const SparseCooTensor& x) {
return CUSPARSE_SPMM_ALG_DEFAULT;
}

/************* SPARSE MATRIX DESCRIPTOR (COO/CSR) ************/

template <typename T, typename IntT>
@@ -324,7 +333,7 @@ void SparseBlas<phi::GPUContext>::SPMM(bool transa,
&beta,
out_descriptor.descriptor(),
gpu_type,
CUSPARSE_SPMM_ALG_DEFAULT,
GetSpMMAlgorithm(mat_a),
&buffer_size);
});

@@ -341,7 +350,7 @@ void SparseBlas<phi::GPUContext>::SPMM(bool transa,
&beta,
out_descriptor.descriptor(),
gpu_type,
CUSPARSE_SPMM_ALG_DEFAULT,
GetSpMMAlgorithm(mat_a),
tmp_buffer_ptr);
});
}
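The two GetSpMMAlgorithm overloads let the SPMM path pick the algorithm from the static type of the sparse operand: CUSPARSE_SPMM_CSR_ALG2 for SparseCsrTensor, CUSPARSE_SPMM_ALG_DEFAULT for SparseCooTensor. Both the buffer-size query and the cusparseSpMM call above simply pass whatever the overload returns, so no runtime branch is needed. A minimal sketch of that dispatch at a call site (CsrTag, CooTag and RunCsrSpMM are illustrative stand-ins; descriptor setup and error handling are omitted):

#include <cusparse.h>

// Stand-ins for SparseCsrTensor / SparseCooTensor, just to show the dispatch.
struct CsrTag {};
struct CooTag {};

inline cusparseSpMMAlg_t GetSpMMAlgorithm(const CsrTag&) {
  return CUSPARSE_SPMM_CSR_ALG2;     // CSR-specific algorithm chosen here
}
inline cusparseSpMMAlg_t GetSpMMAlgorithm(const CooTag&) {
  return CUSPARSE_SPMM_ALG_DEFAULT;  // let cuSPARSE pick for COO
}

cusparseStatus_t RunCsrSpMM(cusparseHandle_t handle,
                            cusparseSpMatDescr_t mat_a,   // CSR descriptor
                            cusparseDnMatDescr_t mat_b,
                            cusparseDnMatDescr_t mat_out,
                            const float* alpha, const float* beta,
                            void* workspace) {
  return cusparseSpMM(handle,
                      CUSPARSE_OPERATION_NON_TRANSPOSE,
                      CUSPARSE_OPERATION_NON_TRANSPOSE,
                      alpha, mat_a, mat_b, beta, mat_out,
                      CUDA_R_32F,
                      GetSpMMAlgorithm(CsrTag{}),  // -> CUSPARSE_SPMM_CSR_ALG2
                      workspace);
}

Because the overload is resolved at compile time from the tensor type, adding another sparse layout only needs a new overload, not edits to the shared SPMM body.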
23 changes: 8 additions & 15 deletions paddle/phi/kernels/sparse/gpu/fused_attention_grad_kernel.cu
@@ -43,21 +43,14 @@ __global__ void AttnSoftmaxGpuGradKernel(const int64_t* out_crows,
int row_nnz = static_cast<int>(out_crows[crow_idx + 1] - out_crows[crow_idx]);
if (row_nnz == 0) return;

int kIteration = (row_nnz + WARP_SIZE - 1) / WARP_SIZE;
T mul_result = 0;
for (int i = 0; i < kIteration; ++i) {
int idx = threadIdx.x + i * WARP_SIZE;
if (idx >= row_nnz) break;

mul_result += out_values[row_first + idx] * dout_values[row_first + idx];
T mul = 0;
for (int idx = threadIdx.x; idx < row_nnz; idx += blockDim.x) {
mul += out_values[row_first + idx] * dout_values[row_first + idx];
}
T sum = phi::funcs::warpReduceSum<T>(mul_result, 0xFFFFFFFF);

for (int i = 0; i < kIteration; ++i) {
int idx = threadIdx.x + i * WARP_SIZE;
if (idx >= row_nnz) break;
T mul_sum = phi::funcs::warpReduceSum<T>(mul, 0xFFFFFFFF);

dx_values[row_first + idx] = (dout_values[row_first + idx] - sum) *
for (int idx = threadIdx.x; idx < row_nnz; idx += blockDim.x) {
dx_values[row_first + idx] = (dout_values[row_first + idx] - mul_sum) *
out_values[row_first + idx] / scale;
}
}
@@ -96,8 +89,8 @@ void FusedAttentionCsrGradKernel(const Context& dev_ctx,
int N = q_dim[q_rank - 1];
int batch_nnz = softmax.nnz() / batch_num;

dim3 grid((total_row_num + 3) / 4);
dim3 block(WARP_SIZE, 4);
dim3 grid((total_row_num + 7) / 8);
dim3 block(WARP_SIZE, 8);

AttnSoftmaxGpuGradKernel<T><<<grid, block, 0, dev_ctx.stream()>>>(
softmax.non_zero_crows().data<int64_t>(),
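The rewritten backward kernel drops the fixed kIteration unroll (and its early break) in favor of a plain blockDim.x-stride loop over the row's nonzeros, followed by one warp reduction. The two building blocks in isolation look roughly like this (WarpReduceSum below uses __shfl_xor_sync and is only a sketch; phi::funcs::warpReduceSum may be implemented differently):

#include <cuda_runtime.h>

constexpr int kWarpSize = 32;

// Butterfly warp reduction: after log2(32) shuffle steps every lane holds the
// full sum of the warp's partial values.
template <typename T>
__device__ T WarpReduceSum(T val, unsigned mask) {
  for (int offset = kWarpSize / 2; offset > 0; offset >>= 1) {
    val += __shfl_xor_sync(mask, val, offset);
  }
  return val;
}

// One warp computes the dot product of two length-row_nnz segments, the same
// pattern the grad kernel uses for sum(out * dout) over a CSR row.
// Launch with dim3 block(kWarpSize) so blockDim.x equals the warp size.
template <typename T>
__global__ void RowDotKernel(const T* a, const T* b, int row_nnz, T* out) {
  T partial = 0;
  for (int idx = threadIdx.x; idx < row_nnz; idx += blockDim.x) {
    partial += a[idx] * b[idx];
  }
  T sum = WarpReduceSum(partial, 0xFFFFFFFFu);
  if (threadIdx.x == 0) {
    *out = sum;
  }
}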
115 changes: 32 additions & 83 deletions paddle/phi/kernels/sparse/gpu/fused_attention_kernel.cu
@@ -26,30 +26,7 @@ limitations under the License. */
namespace phi {
namespace sparse {

#define PRIVATE_CASE_VISIT_ATTN_SOFTMAX(NAME, size, HINT, ...) \
case size: { \
constexpr int HINT = size; \
__VA_ARGS__(); \
break; \
}

#define VISIT_ATTN_SFOTMAX(SIZE, NAME, ...) \
[&] { \
const auto& __size__ = SIZE; \
switch (__size__) { \
PRIVATE_CASE_VISIT_ATTN_SOFTMAX(NAME, 1, KBufferSize, __VA_ARGS__) \
PRIVATE_CASE_VISIT_ATTN_SOFTMAX(NAME, 2, KBufferSize, __VA_ARGS__) \
PRIVATE_CASE_VISIT_ATTN_SOFTMAX(NAME, 3, KBufferSize, __VA_ARGS__) \
PRIVATE_CASE_VISIT_ATTN_SOFTMAX(NAME, 4, KBufferSize, __VA_ARGS__) \
PRIVATE_CASE_VISIT_ATTN_SOFTMAX(NAME, 8, KBufferSize, __VA_ARGS__) \
PRIVATE_CASE_VISIT_ATTN_SOFTMAX(NAME, 12, KBufferSize, __VA_ARGS__) \
PRIVATE_CASE_VISIT_ATTN_SOFTMAX(NAME, 16, KBufferSize, __VA_ARGS__) \
default: \
PD_THROW("function " #NAME " is not implemented for columns>512 "); \
} \
}()

template <typename T, int BufferSize>
template <typename T>
__global__ void AttnSoftmaxGpuKernel(const int64_t* x_crows,
const int64_t* x_cols,
const T* x_values,
@@ -58,7 +35,6 @@ __global__ void AttnSoftmaxGpuKernel(const int64_t* x_crows,
T* out_values,
int M,
int total_row_num,
float scale,
int num_heads,
int batch_nnz) {
// out = exp(x-x_max) / sum(exp(x-x_max))
@@ -72,17 +48,10 @@ __global__ void AttnSoftmaxGpuKernel(const int64_t* x_crows,
int row_nnz = static_cast<int>(x_crows[crow_idx + 1] - x_crows[crow_idx]);
if (row_nnz == 0) return;

T buffer[BufferSize] = {0};
int kIteration = (row_nnz + WARP_SIZE - 1) / WARP_SIZE;

T max_val = -std::numeric_limits<T>::infinity();
for (int i = 0; i < kIteration; ++i) {
for (int idx = threadIdx.x; idx < row_nnz; idx += blockDim.x) {
bool mask = false;
int idx = threadIdx.x + i * WARP_SIZE;
if (idx >= row_nnz) break;

int col_idx = static_cast<int>(x_cols[row_first + idx]);

if (kp_mask != nullptr &&
kp_mask[(cur_batch / num_heads) * M + col_idx] == 0) {
mask = true;
@@ -92,37 +61,30 @@ __global__ void AttnSoftmaxGpuKernel(const int64_t* x_crows,
}

if (!mask) {
buffer[i] = x_values[row_first + idx] / scale;
if (buffer[i] > max_val) {
max_val = buffer[i];
T val = x_values[row_first + idx];
if (val > max_val) {
max_val = val;
}
out_values[row_first + idx] = val;
} else {
// Note corner case: when all elements of the row are masked, result
// may be wrong because of exp('-inf' - '-inf'), just ignore now.
out_values[row_first + idx] = -std::numeric_limits<T>::infinity();
}
}
T row_max_val = phi::funcs::warpReduceMax<T>(max_val, 0xFFFFFFFF);

auto functor = phi::funcs::CudaExpFunctor<T>();
T exp_sum = 0;
for (int i = 0; i < kIteration; ++i) {
int idx = threadIdx.x + i * WARP_SIZE;
if (idx >= row_nnz) break;

if (buffer[i]) {
T exp = functor(buffer[i] - row_max_val);
exp_sum += exp;
buffer[i] = exp;
}
for (int idx = threadIdx.x; idx < row_nnz; idx += blockDim.x) {
auto functor = phi::funcs::CudaExpFunctor<T>();
T exp = functor(out_values[row_first + idx] - row_max_val);
exp_sum += exp;
out_values[row_first + idx] = exp;
}
T row_exp_sum = phi::funcs::warpReduceSum<T>(exp_sum, 0xFFFFFFFF);

for (int i = 0; i < kIteration; ++i) {
int idx = threadIdx.x + i * WARP_SIZE;
if (idx >= row_nnz) break;

if (buffer[i]) {
out_values[row_first + idx] = buffer[i] / row_exp_sum;
} else {
out_values[row_first + idx] = static_cast<T>(0);
}
for (int idx = threadIdx.x; idx < row_nnz; idx += blockDim.x) {
out_values[row_first + idx] = out_values[row_first + idx] / row_exp_sum;
}
}
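With the per-thread register buffer gone, the forward kernel now makes three stride passes over out_values: write the masked inputs while tracking the row max, exponentiate in place while accumulating the sum, then normalize. A small CPU reference of the same per-row computation, useful when checking the kernel, might look like this (a sketch under the kernel's masking convention, where masked entries behave like exp(-inf) and end up zero; not code from this PR):

#include <cmath>
#include <limits>
#include <vector>

// Reference softmax over one CSR row's nonzero values. 'masked[i]' marks
// entries suppressed by kp_mask/attn_mask; they come out exactly 0, matching
// the kernel, which writes -inf for them before the exp pass.
std::vector<float> RowSoftmaxRef(const std::vector<float>& vals,
                                 const std::vector<bool>& masked) {
  float max_val = -std::numeric_limits<float>::infinity();
  for (size_t i = 0; i < vals.size(); ++i) {
    if (!masked[i] && vals[i] > max_val) max_val = vals[i];
  }
  float exp_sum = 0.0f;
  std::vector<float> out(vals.size(), 0.0f);
  for (size_t i = 0; i < vals.size(); ++i) {
    out[i] = masked[i] ? 0.0f : std::exp(vals[i] - max_val);
    exp_sum += out[i];
  }
  // An all-masked row divides by zero here, mirroring the corner case noted
  // in the kernel's comment.
  for (float& v : out) v /= exp_sum;
  return out;
}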

@@ -219,49 +181,36 @@ void FusedAttentionCsrKernel(
"shape of 'attn_mask' must be [seq_len, seq_len]"));
}

/* Step1: SDD Matmul, reuse */
/* Step1: SDD Matmul, reuse matmul */
SparseCsrTensor sdd_result;
EmptyLikeCsrKernel<T, Context>(dev_ctx, sparse_mask, &sdd_result);
auto sparse_blas = phi::funcs::sparse::GetSparseBlas<Context, T>(dev_ctx);
sparse_blas.SDDMM(false,
true,
static_cast<T>(1),
static_cast<T>(1 / std::sqrt(N)),
query,
key,
static_cast<T>(0),
&sdd_result);

/* Step2: Softmax with kp_mask/attn_mask, manualy not reuse */
EmptyLikeCsrKernel<T, Context>(dev_ctx, sdd_result, softmax);

int buffer_size;
if (M < 128) {
buffer_size = (M + 32 - 1) / 32;
} else {
buffer_size = ((M + 128 - 1) / 128) * 4;
}

dim3 grid((total_row_num + 3) / 4);
dim3 block(WARP_SIZE, 4);
dim3 grid((total_row_num + 7) / 8);
dim3 block(WARP_SIZE, 8);

int batch_nnz = sdd_result.nnz() / batch_num;
AttnSoftmaxGpuKernel<T><<<grid, block, 0, dev_ctx.stream()>>>(
sdd_result.non_zero_crows().data<int64_t>(),
sdd_result.non_zero_cols().data<int64_t>(),
sdd_result.non_zero_elements().data<T>(),
kp_mask_ptr ? kp_mask_ptr->data<T>() : nullptr,
attn_mask_ptr ? attn_mask_ptr->data<T>() : nullptr,
softmax->mutable_non_zero_elements()->data<T>(),
M,
total_row_num,
q_dim[1],
batch_nnz);

VISIT_ATTN_SFOTMAX(buffer_size, "AttnSoftmaxGpuKernel", [&] {
AttnSoftmaxGpuKernel<T, KBufferSize><<<grid, block, 0, dev_ctx.stream()>>>(
sdd_result.non_zero_crows().data<int64_t>(),
sdd_result.non_zero_cols().data<int64_t>(),
sdd_result.non_zero_elements().data<T>(),
kp_mask_ptr ? kp_mask_ptr->data<T>() : nullptr,
attn_mask_ptr ? attn_mask_ptr->data<T>() : nullptr,
softmax->mutable_non_zero_elements()->data<T>(),
M,
total_row_num,
std::sqrt(N),
q_dim[1],
batch_nnz);
});

/* Step3: DSD Matmul, reuse */
softmax->set_dims(phi::make_ddim({q_dim[0], q_dim[1], q_dim[2], q_dim[2]}));
MatmulCsrDenseKernel<T, Context>(dev_ctx, *softmax, value, out);
#else
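Elsewhere in this file, the 1/sqrt(N) attention scaling moves into the SDDMM alpha (static_cast<T>(1 / std::sqrt(N))), so the softmax kernel no longer takes a scale argument, and the launch packs eight rows per block instead of four, matching the backward kernel. A sketch of that launch geometry, one warp per CSR row (the row-index math here is an assumption for illustration; the real kernels also derive batch/head indices via total_row_num and batch_nnz):

#include <cuda_runtime.h>

constexpr int kWarpSize = 32;
constexpr int kWarpsPerBlock = 8;  // was 4 before this commit

__global__ void PerRowKernel(int total_row_num /*, row data ... */) {
  // blockDim = (32, 8): threadIdx.y selects the warp, i.e. the CSR row,
  // and the 32 lanes of that warp stride over the row's nonzeros.
  int row = blockIdx.x * blockDim.y + threadIdx.y;
  if (row >= total_row_num) return;
  // ... per-row work as in AttnSoftmaxGpuKernel above ...
}

void LaunchPerRow(int total_row_num, cudaStream_t stream) {
  dim3 grid((total_row_num + kWarpsPerBlock - 1) / kWarpsPerBlock);
  dim3 block(kWarpSize, kWarpsPerBlock);
  PerRowKernel<<<grid, block, 0, stream>>>(total_row_num);
}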
@@ -37,7 +37,7 @@ def get_cuda_version():

@unittest.skipIf(
not core.is_compiled_with_cuda() or get_cuda_version() < 11070,
"core is not compiled with CUDA and cuda version need larger than or equal to 11.3"
"core is not compiled with CUDA and cuda version need larger than or equal to 11.7"
)
class TestSparseAttentionAPI1(unittest.TestCase):

