diff --git a/paddle/phi/api/yaml/sparse_api.yaml b/paddle/phi/api/yaml/sparse_api.yaml
index e99009a70fc3b..a6520a0d48472 100644
--- a/paddle/phi/api/yaml/sparse_api.yaml
+++ b/paddle/phi/api/yaml/sparse_api.yaml
@@ -141,6 +141,15 @@
     layout : x
     data_type : dtype
 
+- api: fused_attention
+  args : (Tensor query, Tensor key, Tensor value, Tensor sparse_mask, Tensor key_padding_mask, Tensor attn_mask)
+  output : Tensor(out), Tensor(softmax)
+  kernel :
+    func : fused_attention_csr{dense, dense, dense, sparse_csr, dense, dense -> dense, sparse_csr}
+    layout : sparse_mask
+  intermediate : softmax
+  backward: fused_attention_grad
+
 - api: masked_matmul
   args : (Tensor x, Tensor y, Tensor mask)
   output : Tensor(out)
diff --git a/paddle/phi/api/yaml/sparse_bw_api.yaml b/paddle/phi/api/yaml/sparse_bw_api.yaml
index 6ceedb0978121..5296d1b870bee 100644
--- a/paddle/phi/api/yaml/sparse_bw_api.yaml
+++ b/paddle/phi/api/yaml/sparse_bw_api.yaml
@@ -127,3 +127,10 @@
   output : Tensor(x_grad)
   kernel :
     func : coo_values_grad{sparse_coo, dense-> sparse_coo}
+
+- backward_api: fused_attention_grad
+  forward : fused_attention_csr(Tensor query, Tensor key, Tensor value, Tensor sparse_mask, Tensor key_padding_mask, Tensor attn_mask) -> Tensor(out), Tensor(softmax)
+  args: (Tensor query, Tensor key, Tensor value, Tensor softmax, Tensor out_grad)
+  output : Tensor(query_grad), Tensor(key_grad), Tensor(value_grad)
+  kernel :
+    func : fused_attention_csr_grad{dense, dense, dense, sparse_csr, dense -> dense, dense, dense}
diff --git a/paddle/phi/kernels/funcs/sparse/sparse_blas_impl.cu.h b/paddle/phi/kernels/funcs/sparse/sparse_blas_impl.cu.h
index 3d92674c92d6e..d817290882916 100644
--- a/paddle/phi/kernels/funcs/sparse/sparse_blas_impl.cu.h
+++ b/paddle/phi/kernels/funcs/sparse/sparse_blas_impl.cu.h
@@ -232,7 +232,7 @@ class CuSparseDnMatDescriptor {
     PADDLE_ENFORCE_EQ(x.numel(), batch_size * M * N);
 
     if (batch_size > 1) {
-#if CUDA_VERSION >= 11070
+#if CUDA_VERSION >= 11030
       dev_ctx_.CusparseCall([&](cusparseHandle_t handle) {
         phi::dynload::cusparseDnMatSetStridedBatch(
             descriptor_, batch_size, M * N);
diff --git a/paddle/phi/kernels/sparse/cpu/fused_attention_grad_kernel.cc b/paddle/phi/kernels/sparse/cpu/fused_attention_grad_kernel.cc
new file mode 100644
index 0000000000000..d83951594f364
--- /dev/null
+++ b/paddle/phi/kernels/sparse/cpu/fused_attention_grad_kernel.cc
@@ -0,0 +1,38 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/phi/kernels/sparse/fused_attention_grad_kernel.h"
+
+#include "paddle/phi/backends/cpu/cpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+
+namespace phi {
+namespace sparse {
+
+template <typename T, typename Context>
+void FusedAttentionCsrGradKernel(const Context& dev_ctx,
+                                 const DenseTensor& query,
+                                 const DenseTensor& key,
+                                 const DenseTensor& value,
+                                 const SparseCsrTensor& softmax,
+                                 const DenseTensor& dout,
+                                 DenseTensor* dquery,
+                                 DenseTensor* dkey,
+                                 DenseTensor* dvalue) {
+  PD_THROW(
+      "The CPU backward kernel of 'fused_attention' with SparseTensor is not supported now");
+}
+
+}  // namespace sparse
+}  // namespace phi
diff --git a/paddle/phi/kernels/sparse/cpu/fused_attention_kernel.cc b/paddle/phi/kernels/sparse/cpu/fused_attention_kernel.cc
new file mode 100644
index 0000000000000..bff25da8a5a37
--- /dev/null
+++ b/paddle/phi/kernels/sparse/cpu/fused_attention_kernel.cc
@@ -0,0 +1,37 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/phi/kernels/sparse/fused_attention_kernel.h"
+
+#include "paddle/phi/backends/cpu/cpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+
+namespace phi {
+namespace sparse {
+
+template <typename T, typename Context>
+void FusedAttentionCsrKernel(const Context& dev_ctx,
+                             const DenseTensor& query,
+                             const DenseTensor& key,
+                             const DenseTensor& value,
+                             const SparseCsrTensor& sparse_mask,
+                             const DenseTensor& key_padding_mask,
+                             const DenseTensor& attn_mask,
+                             DenseTensor* out,
+                             SparseCsrTensor* softmax) {
+  PD_THROW(
+      "The CPU kernel of 'fused_attention' with SparseTensor is not supported now");
+}
+
+}  // namespace sparse
+}  // namespace phi
diff --git a/paddle/phi/kernels/sparse/fused_attention_grad_kernel.h b/paddle/phi/kernels/sparse/fused_attention_grad_kernel.h
new file mode 100644
index 0000000000000..0a025d21f94f3
--- /dev/null
+++ b/paddle/phi/kernels/sparse/fused_attention_grad_kernel.h
@@ -0,0 +1,35 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include "paddle/phi/core/dense_tensor.h"
+#include "paddle/phi/core/sparse_csr_tensor.h"
+
+namespace phi {
+namespace sparse {
+
+template <typename T, typename Context>
+void FusedAttentionCsrGradKernel(const Context& dev_ctx,
+                                 const DenseTensor& query,
+                                 const DenseTensor& key,
+                                 const DenseTensor& value,
+                                 const SparseCsrTensor& softmax,
+                                 const DenseTensor& dout,
+                                 DenseTensor* dquery,
+                                 DenseTensor* dkey,
+                                 DenseTensor* dvalue);
+
+}  // namespace sparse
+}  // namespace phi
diff --git a/paddle/phi/kernels/sparse/fused_attention_kernel.h b/paddle/phi/kernels/sparse/fused_attention_kernel.h
new file mode 100644
index 0000000000000..feff9d72e644c
--- /dev/null
+++ b/paddle/phi/kernels/sparse/fused_attention_kernel.h
@@ -0,0 +1,35 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include "paddle/phi/core/dense_tensor.h"
+#include "paddle/phi/core/sparse_csr_tensor.h"
+
+namespace phi {
+namespace sparse {
+
+template <typename T, typename Context>
+void FusedAttentionCsrKernel(const Context& dev_ctx,
+                             const DenseTensor& query,
+                             const DenseTensor& key,
+                             const DenseTensor& value,
+                             const SparseCsrTensor& sparse_mask,
+                             const DenseTensor& key_padding_mask,
+                             const DenseTensor& attn_mask,
+                             DenseTensor* out,
+                             SparseCsrTensor* softmax);
+
+}  // namespace sparse
+}  // namespace phi
diff --git a/paddle/phi/kernels/sparse/gpu/fused_attention_grad_kernel.cu b/paddle/phi/kernels/sparse/gpu/fused_attention_grad_kernel.cu
new file mode 100644
index 0000000000000..b22965f663eb5
--- /dev/null
+++ b/paddle/phi/kernels/sparse/gpu/fused_attention_grad_kernel.cu
@@ -0,0 +1,150 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/sparse/fused_attention_grad_kernel.h"
+
+#include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/funcs/math_cuda_utils.h"
+#include "paddle/phi/kernels/funcs/sparse/sparse_blas.h"
+#include "paddle/phi/kernels/sparse/empty_kernel.h"
+#include "paddle/phi/kernels/sparse/matmul_grad_kernel.h"
+
+namespace phi {
+namespace sparse {
+
+template <typename T>
+__global__ void AttnSoftmaxGpuGradKernel(const int64_t* out_crows,
+                                         const T* out_values,
+                                         const T* dout_values,
+                                         T* dx_values,
+                                         int M,
+                                         int total_row_num,
+                                         float scale) {
+  // dx = (dout - sum(dout * out)) * out
+  int row = blockIdx.x * blockDim.y + threadIdx.y;
+  int non_zero_idx = threadIdx.x;
+  if (row >= total_row_num) return;
+  int cur_batch = row / M;
+  int crow_idx = cur_batch * (M + 1) + (row % M);
+  int cur_batch_offset = 0;
+  for (int i = 1; i < cur_batch + 1; ++i) {
+    cur_batch_offset += static_cast<int>(out_crows[i * (M + 1) - 1]);
+  }
+  int row_first = cur_batch_offset + static_cast<int>(out_crows[crow_idx]);
+  int row_nnz = static_cast<int>(out_crows[crow_idx + 1] - out_crows[crow_idx]);
+  if (row_nnz == 0) return;
+
+  int kIteration = (row_nnz + WARP_SIZE - 1) / WARP_SIZE;
+  T mul_result = 0;
+  for (int i = 0; i < kIteration; ++i) {
+    int idx = non_zero_idx + i * WARP_SIZE;
+    if (idx >= row_nnz) break;
+
+    mul_result += out_values[row_first + idx] * dout_values[row_first + idx];
+  }
+  T sum = phi::funcs::warpReduceSum<T>(mul_result, 0xFFFFFFFF);
+
+  for (int i = 0; i < kIteration; ++i) {
+    int idx = non_zero_idx + i * WARP_SIZE;
+    if (idx >= row_nnz) break;
+
+    dx_values[row_first + idx] = (dout_values[row_first + idx] - sum) *
+                                 out_values[row_first + idx] / scale;
+  }
+}
+
+template <typename T, typename Context>
+void FusedAttentionCsrGradKernel(const Context& dev_ctx,
+                                 const DenseTensor& query,
+                                 const DenseTensor& key,
+                                 const DenseTensor& value,
+                                 const SparseCsrTensor& softmax,
+                                 const DenseTensor& dout,
+                                 DenseTensor* dquery,
+                                 DenseTensor* dkey,
+                                 DenseTensor* dvalue) {
+#if CUDA_VERSION >= 11070
+  /* Step1: Forward: softmax{CSR} * value{Dense} -> out{Dense}, reuse */
+  SparseCsrTensor dsoftmax;
+  CsrDenseMatmulGradKernel<T, Context>(
+      dev_ctx, softmax, value, dout, &dsoftmax, dvalue);
+
+  /* Step2: Calculate grad of sdd_result manually, not reused */
+  SparseCsrTensor d_sdd_result;
+  EmptyLikeCsrKernel<T, Context>(dev_ctx, dsoftmax, &d_sdd_result);
+  auto q_dim = query.dims();
+  auto q_rank = q_dim.size();
+
+  int total_row_num = 1;
+  for (int i = 0; i < q_rank - 1; ++i) {
+    total_row_num *= q_dim[i];
+  }
+  int M = q_dim[q_rank - 2];
+  int N = q_dim[q_rank - 1];
+
+  dim3 grid((total_row_num + 3) / 4);
+  dim3 block(WARP_SIZE, 4);
+
+  AttnSoftmaxGpuGradKernel<T><<<grid, block, 0, dev_ctx.stream()>>>(
+      softmax.non_zero_crows().data<int64_t>(),
+      softmax.non_zero_elements().data<T>(),
+      dsoftmax.mutable_non_zero_elements()->data<T>(),
+      d_sdd_result.mutable_non_zero_elements()->data<T>(),
+      M,
+      total_row_num,
+      std::sqrt(N));
+
+  /* Step3: Forward: query{Dense} * key'{Dense} -> sdd_result{SparseCsr} */
+  auto sparse_blas = phi::funcs::sparse::GetSparseBlas<Context, T>(dev_ctx);
+  // dquery{Dense} = d_sdd_result{SparseCsr} * key{Dense} //
+  dquery->Resize(query.dims());
+  dev_ctx.template Alloc<T>(dquery);
+  sparse_blas.SPMM(false,
+                   false,
+                   static_cast<T>(1.f),
+                   d_sdd_result,
+                   key,
+                   static_cast<T>(0.f),
+                   dquery);
+
+  // dkey{Dense} = d_sdd_result'{SparseCsr} * query{Dense} //
+  dkey->Resize(key.dims());
+  dev_ctx.template Alloc<T>(dkey);
+  sparse_blas.SPMM(true,
+                   false,
+                   static_cast<T>(1.f),
+                   d_sdd_result,
+                   query,
+                   static_cast<T>(0.f),
+                   dkey);
+#else
+  PADDLE_THROW(
+      phi::errors::Unimplemented("backward of 'sparse.nn.functional.attention' "
+                                 "uses 'cusparseCsrSetStridedBatch', which is "
+                                 "only supported from CUDA 11.7"));
+#endif
+}
+
+}  // namespace sparse
+}  // namespace phi
+
+PD_REGISTER_KERNEL(fused_attention_csr_grad,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::sparse::FusedAttentionCsrGradKernel,
+                   float,
+                   double) {
+  kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_CSR);
+}
diff --git a/paddle/phi/kernels/sparse/gpu/fused_attention_kernel.cu b/paddle/phi/kernels/sparse/gpu/fused_attention_kernel.cu
new file mode 100644
index 0000000000000..8af0505fc3117
--- /dev/null
+++ b/paddle/phi/kernels/sparse/gpu/fused_attention_kernel.cu
@@ -0,0 +1,253 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/phi/kernels/sparse/fused_attention_kernel.h"
+
+#include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/funcs/activation_functor.h"
+#include "paddle/phi/kernels/funcs/math_cuda_utils.h"
+#include "paddle/phi/kernels/funcs/sparse/sparse_blas.h"
+#include "paddle/phi/kernels/sparse/empty_kernel.h"
+#include "paddle/phi/kernels/sparse/matmul_kernel.h"
+#include "paddle/phi/kernels/sparse/sparse_utils_kernel.h"
+
+namespace phi {
+namespace sparse {
+
+#define PRIVATE_CASE_VISIT_ATTN_SOFTMAX(NAME, size, HINT, ...) \
+  case size: {                                                 \
+    constexpr int HINT = size;                                 \
+    __VA_ARGS__();                                             \
+    break;                                                     \
+  }
+
+#define VISIT_ATTN_SOFTMAX(SIZE, NAME, ...)                                  \
+  [&] {                                                                     \
+    const auto& __size__ = SIZE;                                            \
+    switch (__size__) {                                                     \
+      PRIVATE_CASE_VISIT_ATTN_SOFTMAX(NAME, 1, KBufferSize, __VA_ARGS__)    \
+      PRIVATE_CASE_VISIT_ATTN_SOFTMAX(NAME, 2, KBufferSize, __VA_ARGS__)    \
+      PRIVATE_CASE_VISIT_ATTN_SOFTMAX(NAME, 3, KBufferSize, __VA_ARGS__)    \
+      PRIVATE_CASE_VISIT_ATTN_SOFTMAX(NAME, 4, KBufferSize, __VA_ARGS__)    \
+      PRIVATE_CASE_VISIT_ATTN_SOFTMAX(NAME, 8, KBufferSize, __VA_ARGS__)    \
+      PRIVATE_CASE_VISIT_ATTN_SOFTMAX(NAME, 12, KBufferSize, __VA_ARGS__)   \
+      PRIVATE_CASE_VISIT_ATTN_SOFTMAX(NAME, 16, KBufferSize, __VA_ARGS__)   \
+      default:                                                              \
+        PD_THROW("function " #NAME " is not implemented for columns>512");  \
+    }                                                                       \
+  }()
+
+template <typename T, int BufferSize>
+__global__ void AttnSoftmaxGpuKernel(const int64_t* x_crows,
+                                     const int64_t* x_cols,
+                                     const T* x_values,
+                                     const T* kp_mask,
+                                     const T* attn_mask,
+                                     T* out_values,
+                                     int M,
+                                     int total_row_num,
+                                     float scale,
+                                     int num_heads) {
+  // out = exp(x - x_max) / sum(exp(x - x_max))
+  int row = blockIdx.x * blockDim.y + threadIdx.y;
+  int non_zero_idx = threadIdx.x;
+  if (row >= total_row_num) return;
+  int cur_batch = row / M;
+  int cur_row = row % M;
+  int crow_idx = cur_batch * (M + 1) + cur_row;
+  int cur_batch_offset = 0;
+  for (int i = 1; i < cur_batch + 1; ++i) {
+    cur_batch_offset += static_cast<int>(x_crows[i * (M + 1) - 1]);
+  }
+  int row_first = cur_batch_offset + static_cast<int>(x_crows[crow_idx]);
+  int row_nnz = static_cast<int>(x_crows[crow_idx + 1] - x_crows[crow_idx]);
+  if (row_nnz == 0) return;
+
+  T buffer[BufferSize] = {0};
+  int kIteration = (row_nnz + WARP_SIZE - 1) / WARP_SIZE;
+
+  T max_val = -std::numeric_limits<T>::infinity();
+  for (int i = 0; i < kIteration; ++i) {
+    bool mask = false;
+    int idx = non_zero_idx + i * WARP_SIZE;
+    if (idx >= row_nnz) break;
+
+    int col_idx = static_cast<int>(x_cols[row_first + idx]);
+
+    if (kp_mask != nullptr &&
+        kp_mask[(cur_batch / num_heads) * M + col_idx] == 0) {
+      mask = true;
+    }
+    if (attn_mask != nullptr && attn_mask[cur_row * M + col_idx] == 0) {
+      mask = true;
+    }
+
+    if (!mask) {
+      buffer[i] = x_values[row_first + idx] / scale;
+      if (buffer[i] > max_val) {
+        max_val = buffer[i];
+      }
+    }
+  }
+  T row_max_val = phi::funcs::warpReduceMax<T>(max_val, 0xFFFFFFFF);
+
+  auto functor = phi::funcs::CudaExpFunctor<T>();
+  T exp_sum = 0;
+  for (int i = 0; i < kIteration; ++i) {
+    int idx = non_zero_idx + i * WARP_SIZE;
+    if (idx >= row_nnz) break;
+
+    if (buffer[i]) {
+      T exp = functor(buffer[i] - row_max_val);
+      exp_sum += exp;
+      buffer[i] = exp;
+    }
+  }
+  T row_exp_sum = phi::funcs::warpReduceSum<T>(exp_sum, 0xFFFFFFFF);
+
+  for (int i = 0; i < kIteration; ++i) {
+    int idx = non_zero_idx + i * WARP_SIZE;
+    if (idx >= row_nnz) break;
+
+    if (buffer[i]) {
+      out_values[row_first + idx] = buffer[i] / row_exp_sum;
+    } else {
+      out_values[row_first + idx] = static_cast<T>(0);
+    }
+  }
+}
+
+template <typename T, typename Context>
+void FusedAttentionCsrKernel(const Context& dev_ctx,
+                             const DenseTensor& query,
+                             const DenseTensor& key,
+                             const DenseTensor& value,
+                             const SparseCsrTensor& sparse_mask,
+                             const DenseTensor& key_padding_mask,
+                             const DenseTensor& attn_mask,
+                             DenseTensor* out,
+                             SparseCsrTensor* softmax) {
+#if CUDA_VERSION >= 11070
+  /* Check Shape */
+  auto q_dim = query.dims();
+  auto q_rank = q_dim.size();
+
+  int total_row_num = 1;
+  for (int i = 0; i < q_rank - 1; ++i) {
+    total_row_num *= q_dim[i];
+  }
+  int M = q_dim[q_rank - 2];
+  int N = q_dim[q_rank - 1];
+
+  PADDLE_ENFORCE_EQ(query.dims().size(),
+                    4,
+                    phi::errors::InvalidArgument("'query' must be a 4D Tensor"));
+  PADDLE_ENFORCE_EQ(key.dims().size(),
+                    4,
+                    phi::errors::InvalidArgument("'key' must be a 4D Tensor"));
+  PADDLE_ENFORCE_EQ(value.dims().size(),
+                    4,
+                    phi::errors::InvalidArgument("'value' must be a 4D Tensor"));
+
+  PADDLE_ENFORCE_EQ(
+      key_padding_mask.dims().size(),
+      2,
+      phi::errors::InvalidArgument(
+          "shape of 'key_padding_mask' must be [batch_size, seq_len]"));
+  PADDLE_ENFORCE_EQ(
+      key_padding_mask.dims()[0],
+      q_dim[0],
+      phi::errors::InvalidArgument(
+          "shape of 'key_padding_mask' must be [batch_size, seq_len]"));
+  PADDLE_ENFORCE_EQ(
+      key_padding_mask.dims()[1],
+      M,
+      phi::errors::InvalidArgument(
+          "shape of 'key_padding_mask' must be [batch_size, seq_len]"));
+
+  PADDLE_ENFORCE_EQ(attn_mask.dims().size(),
+                    2,
+                    phi::errors::InvalidArgument(
+                        "shape of 'attn_mask' must be [seq_len, seq_len]"));
+  PADDLE_ENFORCE_EQ(attn_mask.dims()[0],
+                    M,
+                    phi::errors::InvalidArgument(
+                        "shape of 'attn_mask' must be [seq_len, seq_len]"));
+  PADDLE_ENFORCE_EQ(attn_mask.dims()[1],
+                    M,
+                    phi::errors::InvalidArgument(
+                        "shape of 'attn_mask' must be [seq_len, seq_len]"));
+
+  /* Step1: SDD Matmul, reuse */
+  SparseCsrTensor sdd_result;
+  EmptyLikeCsrKernel<T, Context>(dev_ctx, sparse_mask, &sdd_result);
+  auto sparse_blas = phi::funcs::sparse::GetSparseBlas<Context, T>(dev_ctx);
+  sparse_blas.SDDMM(false,
+                    true,
+                    static_cast<T>(1),
+                    query,
+                    key,
+                    static_cast<T>(0),
+                    &sdd_result);
+
+  /* Step2: Softmax with kp_mask/attn_mask, computed manually, not reused */
+  EmptyLikeCsrKernel<T, Context>(dev_ctx, sdd_result, softmax);
+
+  int buffer_size;
+  if (M < 128) {
+    buffer_size = (M + 32 - 1) / 32;
+  } else {
+    buffer_size = ((M + 128 - 1) / 128) * 4;
+  }
+
+  dim3 grid((total_row_num + 3) / 4);
+  dim3 block(WARP_SIZE, 4);
+
+  VISIT_ATTN_SOFTMAX(buffer_size, "AttnSoftmaxGpuKernel", [&] {
+    AttnSoftmaxGpuKernel<T, KBufferSize><<<grid, block, 0, dev_ctx.stream()>>>(
+        sdd_result.non_zero_crows().data<int64_t>(),
+        sdd_result.non_zero_cols().data<int64_t>(),
+        sdd_result.non_zero_elements().data<T>(),
+        key_padding_mask.data<T>(),
+        attn_mask.data<T>(),
+        softmax->mutable_non_zero_elements()->data<T>(),
+        M,
+        total_row_num,
+        std::sqrt(N),
+        q_dim[1]);
+  });
+
+  /* Step3: DSD Matmul, reuse */
+  softmax->set_dims(phi::make_ddim({q_dim[0], q_dim[1], q_dim[2], q_dim[2]}));
+  CsrDenseMatmulKernel<T, Context>(dev_ctx, *softmax, value, out);
+#else
+  PADDLE_THROW(
+      phi::errors::Unimplemented("forward of 'sparse.nn.functional.attention' "
+                                 "uses 'cusparseCsrSetStridedBatch', which is "
+                                 "only supported from CUDA 11.7"));
+#endif
+}
+
+}  // namespace sparse
+}  // namespace phi
+
+PD_REGISTER_KERNEL(fused_attention_csr,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::sparse::FusedAttentionCsrKernel,
+                   float,
+                   double) {
+  kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_CSR);
+}
diff --git a/paddle/phi/kernels/sparse/gpu/matmul_kernel.cu b/paddle/phi/kernels/sparse/gpu/matmul_kernel.cu
index 9357bbd2ad083..69cd4bac0c763 100644
--- a/paddle/phi/kernels/sparse/gpu/matmul_kernel.cu
+++ b/paddle/phi/kernels/sparse/gpu/matmul_kernel.cu
@@ -76,7 +76,7 @@ void CsrDenseMatmulKernel(const Context& dev_ctx,
     out_dim_vec[y_ndims - 1] = ydim_vec[y_ndims - 1];
     MetaTensor meta_out(out);
     meta_out.set_dims(phi::make_ddim(out_dim_vec));
-    meta_out.set_dtype(x.non_zero_elements().dtype());
+    meta_out.set_dtype(y.dtype());
 
     dev_ctx.template Alloc<T>(out);
 
diff --git a/paddle/phi/kernels/sparse/gpu/softmax_kernel.cu b/paddle/phi/kernels/sparse/gpu/softmax_kernel.cu
index 9c9f5cfbca545..ee0671b333f81 100644
--- a/paddle/phi/kernels/sparse/gpu/softmax_kernel.cu
+++ b/paddle/phi/kernels/sparse/gpu/softmax_kernel.cu
@@ -52,8 +52,9 @@ __global__ void SoftmaxGpuKernel(const IntT* x_crows,
     int idx = non_zero_idx + i * warpSize;
     if (idx >= row_nnz) break;
 
-    if (max_val < x_values[row_first + idx]) {
-      max_val = x_values[row_first + idx];
+    T val = x_values[row_first + idx];
+    if (val > max_val) {
+      max_val = val;
     }
   }
   T row_max_val = phi::funcs::warpReduceMax<T>(max_val, 0xFFFFFFFF);
diff --git a/python/paddle/fluid/tests/unittests/test_sparse_fused_attention_op.py b/python/paddle/fluid/tests/unittests/test_sparse_fused_attention_op.py
new file mode 100644
index 0000000000000..a25ece822bb65
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_sparse_fused_attention_op.py
@@ -0,0 +1,146 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import math
+import re
+import copy
+import unittest
+import numpy as np
+import paddle
+import paddle.fluid.core as core
+from paddle.fluid.framework import _test_eager_guard
+
+
+def get_cuda_version():
+    result = os.popen("nvcc --version").read()
+    regex = r'release (\S+),'
+    match = re.search(regex, result)
+    if match:
+        num = str(match.group(1))
+        integer, decimal = num.split('.')
+        return int(integer) * 1000 + int(float(decimal) * 10)
+    else:
+        return -1
+
+
+@unittest.skipIf(
+    not core.is_compiled_with_cuda() or get_cuda_version() < 11070,
+    "core is not compiled with CUDA, or the CUDA version is lower than 11.7, "
+    "which 'sparse.nn.functional.attention' requires")
+class TestSparseAttentionAPI1(unittest.TestCase):
+
+    def setUp(self):
+        self.batch_size = 16
+        self.num_heads = 16
+        self.seq_len = 128
+        self.head_dim = 16
+        self.dtype = 'float64'
+
+    def test_dygraph(self):
+        with _test_eager_guard():
+            self.shape = [
+                self.batch_size, self.num_heads, self.seq_len, self.head_dim
+            ]
+            query = paddle.rand(self.shape, self.dtype)
+            key = paddle.rand(self.shape, self.dtype)
+            value = paddle.rand(self.shape, self.dtype)
+
+            query.stop_gradient = False
+            key.stop_gradient = False
+            value.stop_gradient = False
+
+            mask = paddle.nn.functional.dropout(
+                paddle.ones([self.seq_len, self.seq_len]),
+                mode='downscale_in_infer')
+            mask = mask.expand(
+                [self.batch_size, self.num_heads, self.seq_len, self.seq_len])
+            sp_mask = mask.reshape(
+                [-1, self.seq_len, self.seq_len]).to_sparse_csr()
+
+            kp_mask = paddle.randint(
+                0, 2, [self.batch_size, self.seq_len]).astype(self.dtype)
+            attn_mask = paddle.randint(
+                0, 2, [self.seq_len, self.seq_len]).astype(self.dtype)
+
+            sdd = paddle.matmul(query, key, False, True) / math.sqrt(
+                float(self.head_dim))
+            sdd = sdd + (
+                (mask * kp_mask.unsqueeze([1, 2]) * attn_mask) - 1.0) * 1e9
+            softmax = paddle.nn.functional.softmax(sdd)
+            output = paddle.matmul(softmax, value)
+            output.backward()
+
+            query_cp = copy.deepcopy(query)
+            key_cp = copy.deepcopy(key)
+            value_cp = copy.deepcopy(value)
+
+            query_cp.stop_gradient = False
+            key_cp.stop_gradient = False
+            value_cp.stop_gradient = False
+
+            output_cp = paddle.incubate.sparse.nn.functional.attention(
+                query_cp, key_cp, value_cp, sp_mask, kp_mask, attn_mask)
+            output_cp.backward()
+
+            self.assertTrue(np.allclose(output_cp.numpy(), output.numpy()))
+            self.assertTrue(
+                np.allclose(query_cp.grad.numpy(), query.grad.numpy()))
+            self.assertTrue(np.allclose(key_cp.grad.numpy(), key.grad.numpy()))
+            self.assertTrue(
+                np.allclose(value_cp.grad.numpy(), value.grad.numpy()))
+
+
+class TestSparseAttentionAPI2(TestSparseAttentionAPI1):
+
+    def setUp(self):
+        self.batch_size = 16
+        self.num_heads = 16
+        self.seq_len = 128
+        self.head_dim = 32
+        self.dtype = 'float64'
+
+
+class TestSparseAttentionAPI3(TestSparseAttentionAPI1):
+
+    def setUp(self):
+        self.batch_size = 16
+        self.num_heads = 16
+        self.seq_len = 512
+        self.head_dim = 16
+        self.dtype = 'float64'
+
+
+class TestSparseAttentionAPI4(TestSparseAttentionAPI1):
+
+    def setUp(self):
+        self.batch_size = 16
+        self.num_heads = 16
+        self.seq_len = 512
+        self.head_dim = 32
+        self.dtype = 'float64'
+
+
+class TestSparseAttentionAPI5(TestSparseAttentionAPI1):
+
+    def setUp(self):
+        self.batch_size = 16
+        self.num_heads = 16
+        self.seq_len = 512
+        self.head_dim = 64
+        self.dtype = 'float64'
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/python/paddle/incubate/sparse/nn/functional/__init__.py b/python/paddle/incubate/sparse/nn/functional/__init__.py
index af5636a622a9a..21939eeb1a4f9 100644
--- a/python/paddle/incubate/sparse/nn/functional/__init__.py
+++ b/python/paddle/incubate/sparse/nn/functional/__init__.py
@@ -14,6 +14,7 @@
 
 from .conv import conv3d  # noqa: F401
 from .conv import subm_conv3d  # noqa: F401
+from .transformer import attention  # noqa: F401
 from .pooling import max_pool3d  # noqa: F401
 from .activation import relu  # noqa: F401
 from .activation import softmax  # noqa: F401
@@ -24,4 +25,5 @@
     'max_pool3d',
     'relu',
     'softmax',
+    'attention',
 ]
diff --git a/python/paddle/incubate/sparse/nn/functional/activation.py b/python/paddle/incubate/sparse/nn/functional/activation.py
index 12d44063e0015..dc2969424086e 100644
--- a/python/paddle/incubate/sparse/nn/functional/activation.py
+++ b/python/paddle/incubate/sparse/nn/functional/activation.py
@@ -14,7 +14,7 @@
 
 __all__ = []
 
-from paddle import _C_ops, in_dynamic_mode
+from paddle import _C_ops
 from paddle.fluid.framework import dygraph_only
 
 
diff --git a/python/paddle/incubate/sparse/nn/functional/transformer.py b/python/paddle/incubate/sparse/nn/functional/transformer.py
new file mode 100644
index 0000000000000..1686487ef7dc1
--- /dev/null
+++ b/python/paddle/incubate/sparse/nn/functional/transformer.py
@@ -0,0 +1,95 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+__all__ = []
+
+from paddle import _C_ops
+from paddle.fluid.framework import dygraph_only
+
+
+@dygraph_only
+def attention(query,
+              key,
+              value,
+              sparse_mask,
+              key_padding_mask,
+              attn_mask,
+              name=None):
+    """
+    Note:
+        This API is only supported from ``CUDA 11.7``.
+
+    A SparseCsrTensor is used to store the intermediate attention matrix in the
+    Transformer module, which reduces memory usage and improves performance.
+    ``sparse_mask`` expresses the sparse layout in CSR format.
+
+    The equation is:
+
+    .. math::
+
+        result = softmax(\frac{ Q * K^T }{\sqrt{d}}) * V
+
+    where ``Q``, ``K``, and ``V`` are the three inputs of the attention module,
+    each with shape [batch_size, num_heads, seq_len, head_dim], and ``d``
+    represents ``head_dim``.
+
+    Args:
+        query(DenseTensor): `query` in the Attention module. 4D Tensor with float32 or float64.
+        key(DenseTensor): `key` in the Attention module. 4D Tensor with float32 or float64.
+        value(DenseTensor): `value` in the Attention module. 4D Tensor with float32 or float64.
+        sparse_mask(SparseCsrTensor): The sparse layout in the Attention module. The shape of `crows` is
+            [batch_size, num_heads, seq_len + 1] and the shape of `cols` is [batch_size, num_heads, nnz].
+            The dtype of `crows` and `cols` must be int64; the dtype of `values` can be float32 or float64.
+        key_padding_mask(DenseTensor): The key padding mask tensor in the Attention module.
+            2D tensor with shape [batch_size, seq_len]. Its dtype can be float32 or float64.
+        attn_mask(DenseTensor): The attention mask tensor in the Attention module.
+            2D tensor with shape [seq_len, seq_len]. Its dtype can be float32 or float64.
+        name(str, optional): The default value is None. Normally there is no need for user
+            to set this property. For more information, please refer to
+            :ref:`api_guide_Name`.
+
+    Returns:
+        4D tensor with shape [batch_size, num_heads, seq_len, head_dim], with the same dtype as the inputs.
+
+    Examples:
+        .. code-block:: python
+
+            import paddle
+
+            batch_size = 16
+            num_heads = 16
+            seq_len = 512
+            head_dim = 32
+
+            query = paddle.rand([batch_size, num_heads, seq_len, head_dim])
+            key = paddle.rand([batch_size, num_heads, seq_len, head_dim])
+            value = paddle.rand([batch_size, num_heads, seq_len, head_dim])
+
+            query.stop_gradient = False
+            key.stop_gradient = False
+            value.stop_gradient = False
+
+            mask = paddle.nn.functional.dropout(paddle.ones([seq_len, seq_len])).expand([batch_size, num_heads, seq_len, seq_len])
+            sp_mask = mask.reshape([-1, seq_len, seq_len]).to_sparse_csr()
+
+            kp_mask = paddle.randint(0, 2, [batch_size, seq_len]).astype('float32')
+            attn_mask = paddle.randint(0, 2, [seq_len, seq_len]).astype('float32')
+
+            output = paddle.incubate.sparse.nn.functional.attention(query, key, value, sp_mask, kp_mask, attn_mask)
+            output.backward()
+    """
+    return _C_ops.final_state_sparse_fused_attention(query, key, value,
+                                                     sparse_mask,
+                                                     key_padding_mask,
+                                                     attn_mask)
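
For reference, the unit test above checks the fused kernel against a plain dense computation. Below is a minimal dense-equivalent sketch of what `fused_attention_csr` computes, assuming the same mask convention as the test (entries equal to 0 in `kp_mask` and `attn_mask` are masked out, and `mask` is the dense counterpart of `sparse_mask`'s CSR layout); the helper name `dense_attention_reference` is illustrative only and not part of the patch's API.

```python
import math
import paddle


def dense_attention_reference(query, key, value, mask, kp_mask, attn_mask):
    # scores = Q * K^T / sqrt(head_dim), matching the sdd computation in the test
    head_dim = query.shape[-1]
    scores = paddle.matmul(query, key, transpose_y=True) / math.sqrt(head_dim)
    # combine the layout mask [b, h, s, s], key padding mask [b, s] and
    # attention mask [s, s]; positions where any mask is 0 get a -1e9 bias
    full_mask = mask * kp_mask.unsqueeze([1, 2]) * attn_mask
    scores = scores + (full_mask - 1.0) * 1e9
    # row-wise softmax over the last axis, then weight the values
    softmax = paddle.nn.functional.softmax(scores)
    return paddle.matmul(softmax, value)
```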