From c3f5c8dc5cd1fb3b641d0912d7ec548c56db0f23 Mon Sep 17 00:00:00 2001 From: Driss Guessous Date: Tue, 15 Nov 2022 01:10:35 +0000 Subject: [PATCH] Add mem efficient backward (#88856) # Registers the derivative for mem efficient backward - Use gradcheck to test correctness. The kernel is not implemented for fp64 so run checks with bumped tolerances in fp32 - I also made updates based off of Xformer main branch and flash-attention cutlass branch. - This will enable the fused backward to be called for scaled dot product attention Pull Request resolved: https://github.com/pytorch/pytorch/pull/88856 Approved by: https://github.com/cpuhrsch --- aten/src/ATen/native/native_functions.yaml | 5 + .../native/transformers/cuda/attention.cu | 16 +- .../transformers/cuda/attention_backward.cu | 261 ++++++++++++++++++ .../transformers/cuda/flash_attn/fmha_api.cpp | 4 + .../attention_backward_generic.cu | 166 ----------- .../attention_forward_generic.cu | 232 ---------------- .../cuda/mem_eff_attention/find_default_mma.h | 7 +- .../cuda/mem_eff_attention/kernel_backward.h | 250 +++++++++++------ .../ATen/native/transformers/cuda/sdp_utils.h | 12 +- test/test_transformers.py | 44 ++- tools/autograd/derivatives.yaml | 7 +- .../_internal/common_methods_invocations.py | 4 +- 12 files changed, 501 insertions(+), 507 deletions(-) create mode 100644 aten/src/ATen/native/transformers/cuda/attention_backward.cu delete mode 100644 aten/src/ATen/native/transformers/cuda/mem_eff_attention/attention_backward_generic.cu delete mode 100644 aten/src/ATen/native/transformers/cuda/mem_eff_attention/attention_forward_generic.cu diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index de087c0b8a89..9572ccc56653 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -13287,6 +13287,11 @@ dispatch: CUDA: _efficient_attention_forward +- func: _efficient_attention_backward(Tensor grad, Tensor query, Tensor key, Tensor value, Tensor logsumexp, Tensor out, bool is_causal=False) -> (Tensor, Tensor, Tensor) + variants: function + dispatch: + CUDA: _efficient_attention_backward + - func: _transformer_decoder_only_layer_fwd(Tensor src, int embed_dim, int num_heads, Tensor qkv_weight, Tensor qkv_bias, Tensor proj_weight, Tensor proj_bias, bool use_gelu, bool norm_first, float eps, Tensor norm_weight_1, Tensor norm_bias_1, Tensor norm_weight_2, Tensor norm_bias_2, Tensor ffn_weight_1, Tensor ffn_bias_1, Tensor ffn_weight_2, Tensor ffn_bias_2, Tensor? mask=None, Tensor? incr_key=None, Tensor? 
incr_value=None) -> (Tensor, Tensor, Tensor) variants: function dispatch: diff --git a/aten/src/ATen/native/transformers/cuda/attention.cu b/aten/src/ATen/native/transformers/cuda/attention.cu index f65fedd6d795..46543d4663fa 100644 --- a/aten/src/ATen/native/transformers/cuda/attention.cu +++ b/aten/src/ATen/native/transformers/cuda/attention.cu @@ -746,7 +746,9 @@ std::tuple flash_attention_helper_dense_unpacked( std::tuple mem_eff_helper( const Tensor& query, const Tensor& key, - const Tensor& value){ + const Tensor& value, + bool compute_log_sumexp, + bool is_causal) { // Query -> Query(Batch x Q_seq_len x Num_heads x Dim_per_head) // Key -> Key(Batch x KV_seq_len x Num_heads x Dim_per_head) // Value -> Value(Batch x KV_seq_len x Num_heads x Dim_per_head) @@ -754,16 +756,18 @@ std::tuple mem_eff_helper( Tensor k_t = key.transpose(1, 2); Tensor v_t = value.transpose(1, 2); - Tensor attention = std::get<0>(at::_efficient_attention_forward( + Tensor attention, log_sumexp; + std::tie(attention, log_sumexp) = at::_efficient_attention_forward( q_t, k_t, v_t, c10::nullopt, c10::nullopt, c10::nullopt, - false, - false)).transpose(1,2); - return std::make_tuple(attention, Tensor()); + compute_log_sumexp, + is_causal); + attention = attention.transpose(1,2); + return std::make_tuple(std::move(attention), Tensor()); } std::tuple _scaled_dot_product_attention_forward_cuda( @@ -776,7 +780,7 @@ std::tuple _scaled_dot_product_attention_forward_cuda( case sdp::SDPBackend::flash_attention: return flash_attention_helper_dense_unpacked(query_, key, value, dropout_p, need_attn_weights, is_causal); case sdp::SDPBackend::efficient_attention: - return mem_eff_helper(query_, key , value); + return mem_eff_helper(query_, key , value, need_attn_weights, is_causal); case sdp::SDPBackend::math: return at::_scaled_dot_product_attention_math(query_, key, value, attn_mask_, dropout_p, need_attn_weights, is_causal); default: diff --git a/aten/src/ATen/native/transformers/cuda/attention_backward.cu b/aten/src/ATen/native/transformers/cuda/attention_backward.cu new file mode 100644 index 000000000000..af005b2669b2 --- /dev/null +++ b/aten/src/ATen/native/transformers/cuda/attention_backward.cu @@ -0,0 +1,261 @@ +#include + +#include + +#include +#include + +#include +#include +#include +#include + +#ifdef USE_FLASH_ATTENTION +#include +#endif + +#define ASSIGN_CHECK_OVERFLOW(A, B) \ + { \ + A = B; \ + TORCH_CHECK(B < std::numeric_limits::max(), #B " overflows"); \ + } + +#define DISPATCH_MAXK(func) \ + { \ + const auto maxK = std::max(query.size(3), value.size(3)); \ + if (maxK <= 64) { \ + constexpr int kMaxK = 64; \ + func(); \ + } else if (maxK <= 128) { \ + constexpr int kMaxK = 128; \ + func(); \ + } else { \ + constexpr int kMaxK = std::numeric_limits::max(); \ + func(); \ + } \ + } + +#define DISPATCH_KERNEL(QUERY, KEY, VALUE, FUNC) \ + { \ + cudaDeviceProp* properties = \ + at::cuda::getDeviceProperties(QUERY.device().index()); \ + const int computeCapability = properties->major * 10 + properties->minor; \ + DISPATCH_MAXK(([&] { \ + DISPATCH_TYPES( \ + QUERY, ([&]() { \ + DISPATCH_ARCHTAG( \ + computeCapability, ([&]() { \ + using AlignedAK = \ + AttentionBackwardKernel; \ + bool isAligned = \ + (QUERY.stride(2) % AlignedAK::kOptimalAlignement == 0 && \ + KEY.stride(2) % AlignedAK::kOptimalAlignement == 0 && \ + VALUE.stride(2) % AlignedAK::kOptimalAlignement == 0); \ + DISPATCH_BOOL(isAligned, kIsAligned, ([&]() { \ + using Kernel = AttentionBackwardKernel< \ + ArchTag, \ + scalar_t, \ + kIsAligned, \ + 
kMaxK>; \ + FUNC(); \ + })) \ + })) \ + })) \ + })); \ + } + +namespace at { + +namespace native { + +std::tuple _efficient_attention_backward( + const at::Tensor& grad_out_, + const at::Tensor& query, + const at::Tensor& key, + const at::Tensor& value, + const at::Tensor& logsumexp, + const at::Tensor& out, + bool causal) { + #if defined(USE_FLASH_ATTENTION) + if (!grad_out_.defined()) { + return std::make_tuple(Tensor{}, Tensor{}, Tensor{}); + } + // ndim + TORCH_CHECK(query.dim() == grad_out_.dim()); + TORCH_CHECK(query.dim() == key.dim()); + TORCH_CHECK(query.dim() == value.dim()); + TORCH_CHECK(query.dim() == 4); + + // batch size + TORCH_CHECK(query.size(0) == grad_out_.size(0)); + TORCH_CHECK(query.size(0) == key.size(0)); + TORCH_CHECK(query.size(0) == value.size(0)); + + // seqlen + TORCH_CHECK(key.size(1) == value.size(1)); + TORCH_CHECK(query.size(1) == grad_out_.size(1)); + + // Num heads + TORCH_CHECK(query.size(2) == key.size(2)); + TORCH_CHECK(query.size(2) == value.size(2)); + TORCH_CHECK(query.size(2) == grad_out_.size(2)); + + // Embedding per head + TORCH_CHECK(query.size(3) == key.size(3)); + TORCH_CHECK(value.size(3) == grad_out_.size(3)); + + // handle potentially non-contiguous grad_out through a copy + auto grad_out = grad_out_.contiguous(); + CHECK_NOSPARSE_CONTIGUOUS_CUDA(grad_out); + + CHECK_NOSPARSE_LASTCONTIGUOUS_CUDA(query); + CHECK_NOSPARSE_LASTCONTIGUOUS_CUDA(key); + CHECK_NOSPARSE_LASTCONTIGUOUS_CUDA(value); + + at::cuda::CUDAGuard device_guard(query.device()); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + int64_t B = query.size(0); + int64_t M = query.size(1); + int64_t N = key.size(1); + int64_t nH = query.size(2); + int64_t K = query.size(3); + + // It does not make sense to use that in practice, + // but let's still make sure we are correct + // As we iterate through keys first, we skip + // keys with no query associated, so they are not + // initialized + bool grad_kv_needs_init = causal && N > M; + at::Tensor grad_q, grad_k, grad_v; + if (!grad_kv_needs_init && query.size(1) == key.size(1) && + query.size(3) == value.size(3) && + query.storage().is_alias_of(key.storage()) && + query.storage().is_alias_of(value.storage())) { + // Create one big contiguous chunk + // This is because q, k and v usually come from a single + // output of a linear layer that is chunked. + // Creating the gradients with the right layout saves us + // a `torch.cat` call in the backward pass + at::Tensor chunk = at::empty({B, M, 3, nH, K}, query.options()); + grad_q = chunk.select(2, 0); + grad_k = chunk.select(2, 1); + grad_v = chunk.select(2, 2); + } else { + grad_q = at::empty_like(query); + grad_k = grad_kv_needs_init ? at::zeros_like(key) : at::empty_like(key); + grad_v = grad_kv_needs_init ? at::zeros_like(value) : at::empty_like(value); + } + + auto launchKernel = [&](auto _k, int computeCapability) { + using Kernel = decltype(_k); + using scalar_t = typename Kernel::scalar_t; + (void)_k; + + size_t smem_bytes = sizeof(typename Kernel::SharedStorage); + + // TODO: Fuse this into a kernel? + // This is a bottleneck for smaller sequences (M <= 128) + auto delta = Kernel::kKernelComputesDelta + ? 
at::empty({B, nH, M}, query.options().dtype(at::ScalarType::Float)) + : (grad_out.to(at::kFloat) * out.to(at::kFloat)) + .sum(-1) + .transpose(-2, -1) + .contiguous(); + TORCH_INTERNAL_ASSERT(delta.size(0) == B); + TORCH_INTERNAL_ASSERT(delta.size(1) == nH); + TORCH_INTERNAL_ASSERT(delta.size(2) == M); + + typename Kernel::Params p; + p.query_ptr = (scalar_t*)query.data_ptr(); + p.key_ptr = (scalar_t*)key.data_ptr(); + p.value_ptr = (scalar_t*)value.data_ptr(); + p.logsumexp_ptr = (typename Kernel::lse_scalar_t*)logsumexp.data_ptr(); + p.output_ptr = (scalar_t*)out.data_ptr(); + p.grad_output_ptr = (scalar_t*)grad_out.data_ptr(); + p.grad_query_ptr = (scalar_t*)grad_q.data_ptr(); + p.grad_key_ptr = (scalar_t*)grad_k.data_ptr(); + p.grad_value_ptr = (scalar_t*)grad_v.data_ptr(); + p.delta_ptr = (float*)delta.data_ptr(); + p.head_dim = query.size(3); + p.head_dim_value = value.size(3); + p.num_queries = query.size(1); + p.num_keys = key.size(1); + p.num_batches = B; + p.num_heads = nH; + p.causal = causal; + + ASSIGN_CHECK_OVERFLOW(p.gO_strideB, grad_out.stride(0)); + ASSIGN_CHECK_OVERFLOW(p.gO_strideM, grad_out.stride(1)); + ASSIGN_CHECK_OVERFLOW(p.gO_strideH, grad_out.stride(2)); + + ASSIGN_CHECK_OVERFLOW(p.o_strideB, out.stride(0)); + ASSIGN_CHECK_OVERFLOW(p.o_strideH, out.stride(2)); + + ASSIGN_CHECK_OVERFLOW(p.gQ_strideB, grad_q.stride(0)); + ASSIGN_CHECK_OVERFLOW(p.gK_strideB, grad_k.stride(0)); + ASSIGN_CHECK_OVERFLOW(p.gV_strideB, grad_v.stride(0)); + ASSIGN_CHECK_OVERFLOW(p.gQ_strideH, grad_q.stride(2)); + ASSIGN_CHECK_OVERFLOW(p.gK_strideH, grad_k.stride(2)); + ASSIGN_CHECK_OVERFLOW(p.gV_strideH, grad_v.stride(2)); + p.gQKV_strideM_multiplier = grad_q.is_contiguous() ? 1 : 3; + TORCH_INTERNAL_ASSERT(p.gQ_strideM() == grad_q.stride(1)); + TORCH_INTERNAL_ASSERT(p.gK_strideM() == grad_k.stride(1)); + TORCH_INTERNAL_ASSERT(p.gV_strideM() == grad_v.stride(1)); + + ASSIGN_CHECK_OVERFLOW(p.q_strideB, query.stride(0)); + ASSIGN_CHECK_OVERFLOW(p.k_strideB, key.stride(0)); + ASSIGN_CHECK_OVERFLOW(p.v_strideB, value.stride(0)); + ASSIGN_CHECK_OVERFLOW(p.q_strideM, query.stride(1)); + ASSIGN_CHECK_OVERFLOW(p.k_strideM, key.stride(1)); + ASSIGN_CHECK_OVERFLOW(p.v_strideM, value.stride(1)); + ASSIGN_CHECK_OVERFLOW(p.q_strideH, query.stride(2)); + ASSIGN_CHECK_OVERFLOW(p.k_strideH, key.stride(2)); + ASSIGN_CHECK_OVERFLOW(p.v_strideH, value.stride(2)); + + Kernel::check_supported(p); + + constexpr auto kernel_fn = attention_kernel_backward_batched; + + if (smem_bytes > 0xc000) { + TORCH_INTERNAL_ASSERT( + computeCapability >= 70, + "This kernel requires too much shared memory on this machine!"); + cudaFuncSetAttribute( + kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes); + } + + // second syntax resulted in the error below on windows + // error C3495: 'kernel_fn': a simple capture must be a variable + // with automatic storage duration declared + // in the reaching scope of the lambda +#ifdef _WIN32 + cudaFuncAttributes attr; + AT_CUDA_CHECK(cudaFuncGetAttributes(&attr, kernel_fn)); + TORCH_INTERNAL_ASSERT( + attr.binaryVersion >= Kernel::ArchTag::kMinComputeCapability, + "Something went wrong in the build process"); +#else + auto checkBinaryArchMatches = [&]() { + cudaFuncAttributes attr; + AT_CUDA_CHECK(cudaFuncGetAttributes(&attr, kernel_fn)); + return attr.binaryVersion >= Kernel::ArchTag::kMinComputeCapability; + }; + TORCH_INTERNAL_ASSERT( + checkBinaryArchMatches(), "Something went wrong in the build process"); +#endif + + kernel_fn<<>>(p); + }; + + DISPATCH_KERNEL( + 
query, key, value, ([&] { launchKernel(Kernel{}, computeCapability); })); + AT_CUDA_CHECK(cudaGetLastError()); + return std::make_tuple(grad_q, grad_k, grad_v); + #endif + TORCH_CHECK(false, "USE_FLASH_ATTENTION was not enabled for build.") + return std::make_tuple(Tensor{}, Tensor{}, Tensor{}); +} + +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/fmha_api.cpp b/aten/src/ATen/native/transformers/cuda/flash_attn/fmha_api.cpp index a8d6110e951d..6c86e1ff63b0 100644 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/fmha_api.cpp +++ b/aten/src/ATen/native/transformers/cuda/flash_attn/fmha_api.cpp @@ -29,6 +29,7 @@ #ifdef USE_FLASH_ATTENTION #include #include +#include #include #include @@ -185,6 +186,9 @@ mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q int max_seqlen_q = ((max_seqlen_q_ + 16 - 1) / 16) * 16; bool loop = max_seqlen_k > blocksize_c; + // Otherwise the kernel will be launched from cuda:0 device + at::cuda::CUDAGuard device_guard{q.get_device()}; + auto opts = q.options(); auto o = at::empty({ total_q, num_heads, head_size }, opts); diff --git a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/attention_backward_generic.cu b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/attention_backward_generic.cu deleted file mode 100644 index 07c14ad8195d..000000000000 --- a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/attention_backward_generic.cu +++ /dev/null @@ -1,166 +0,0 @@ -#include - -#define DISPATCH_MAXK(func) \ - { \ - const auto maxK = std::max(query.size(2), value.size(2)); \ - if (maxK <= 64) { \ - constexpr int kMaxK = 64; \ - func(); \ - } else if (maxK <= 128) { \ - constexpr int kMaxK = 128; \ - func(); \ - } else { \ - constexpr int kMaxK = std::numeric_limits::max(); \ - func(); \ - } \ - } - -#define DISPATCH_KERNEL(QUERY, KEY, VALUE, FUNC) \ - { \ - cudaDeviceProp* properties = \ - at::cuda::getDeviceProperties(QUERY.device().index()); \ - const int computeCapability = properties->major * 10 + properties->minor; \ - DISPATCH_MAXK(([&] { \ - DISPATCH_TYPES( \ - QUERY, ([&]() { \ - DISPATCH_ARCHTAG( \ - computeCapability, ([&]() { \ - using AlignedAK = \ - AttentionBackwardKernel; \ - bool isAligned = \ - (QUERY.stride(1) % AlignedAK::kOptimalAlignement == 0 && \ - KEY.stride(1) % AlignedAK::kOptimalAlignement == 0 && \ - VALUE.stride(1) % AlignedAK::kOptimalAlignement == 0); \ - DISPATCH_BOOL(isAligned, kIsAligned, ([&]() { \ - using Kernel = AttentionBackwardKernel< \ - ArchTag, \ - scalar_t, \ - kIsAligned, \ - kMaxK>; \ - FUNC(); \ - })) \ - })) \ - })) \ - })); \ - } - -namespace { -std::tuple -mem_efficient_attention_backward_cutlass( - const at::Tensor& grad_out_, - const at::Tensor& query, - const at::Tensor& key, - const at::Tensor& value, - const at::Tensor& logsumexp, - const at::Tensor& out, - bool causal) { - TORCH_CHECK(query.dim() == grad_out_.dim()); - TORCH_CHECK(query.dim() == key.dim()); - TORCH_CHECK(query.dim() == 3); - - TORCH_CHECK(query.size(0) == grad_out_.size(0)); - TORCH_CHECK(query.size(1) == grad_out_.size(1)); - TORCH_CHECK(value.size(2) == grad_out_.size(2)); - - TORCH_CHECK(query.size(2) == key.size(2)); - TORCH_CHECK(query.size(0) == key.size(0)); - - TORCH_CHECK(query.size(0) == value.size(0)); - TORCH_CHECK(key.size(1) == value.size(1)); - - // handle potentially non-contiguous grad_out through a copy - auto grad_out = grad_out_.contiguous(); - - CHECK_NOSPARSE_CONTIGUOUS_CUDA(query); - 
CHECK_NOSPARSE_CONTIGUOUS_CUDA(key); - CHECK_NOSPARSE_CONTIGUOUS_CUDA(value); - CHECK_NOSPARSE_CONTIGUOUS_CUDA(grad_out); - - at::cuda::CUDAGuard device_guard(query.device()); - - int64_t B = query.size(0); - int64_t M = query.size(1); - int64_t N = key.size(1); - int64_t K = query.size(2); - - // It does not make sense to use that in practice, - // but let's still make sure we are correct - // As we iterate through keys first, we skip - // keys with no query associated, so they are not - // initialized - bool grad_kv_needs_init = causal && N > M; - at::Tensor grad_q = at::empty_like(query); - at::Tensor grad_k = - grad_kv_needs_init ? at::zeros_like(key) : at::empty_like(key); - at::Tensor grad_v = - grad_kv_needs_init ? at::zeros_like(value) : at::empty_like(value); - - auto launchKernel = [&](auto _k, int computeCapability) { - using Kernel = decltype(_k); - using scalar_t = typename Kernel::scalar_t; - (void)_k; - - size_t smem_bytes = sizeof(typename Kernel::SharedStorage); - - // TODO: Fuse this into a kernel? - // This is a bottleneck for smaller sequences (M <= 128) - auto delta = Kernel::kKernelComputesDelta - ? at::empty({B, M}, query.options().dtype(at::ScalarType::Float)) - : (grad_out.to(at::kFloat) * out.to(at::kFloat)).sum(-1); - TORCH_INTERNAL_ASSERT(delta.size(0) == B); - TORCH_INTERNAL_ASSERT(delta.size(1) == M); - - typename Kernel::Params params; - params.query_ptr = (scalar_t*)query.data_ptr(); - params.key_ptr = (scalar_t*)key.data_ptr(); - params.value_ptr = (scalar_t*)value.data_ptr(); - params.logsumexp_ptr = (typename Kernel::lse_scalar_t*)logsumexp.data_ptr(); - params.output_ptr = (scalar_t*)out.data_ptr(); - params.grad_output_ptr = (scalar_t*)grad_out.data_ptr(); - params.grad_query_ptr = (scalar_t*)grad_q.data_ptr(); - params.grad_key_ptr = (scalar_t*)grad_k.data_ptr(); - params.grad_value_ptr = (scalar_t*)grad_v.data_ptr(); - params.delta_ptr = (float*)delta.data_ptr(); - params.head_dim = query.size(2); - params.head_dim_value = value.size(2); - params.num_queries = query.size(1); - params.num_keys = key.size(1); - params.num_batches = B; - params.causal = causal; - Kernel::check_supported(params); - - constexpr auto kernel_fn = attention_kernel_backward_batched; - - if (smem_bytes > 0xc000) { - TORCH_INTERNAL_ASSERT( - computeCapability >= 70, - "This kernel requires too much shared memory on this machine!"); - cudaFuncSetAttribute( - kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes); - } - - auto checkBinaryArchMatches = [&]() { - cudaFuncAttributes attr; - AT_CUDA_CHECK(cudaFuncGetAttributes(&attr, kernel_fn)); - return attr.binaryVersion >= Kernel::ArchTag::kMinComputeCapability; - }; - TORCH_INTERNAL_ASSERT( - checkBinaryArchMatches(), "Something went wrong in the build process"); - - kernel_fn<<>>( - params); - }; - - DISPATCH_KERNEL( - query, key, value, ([&] { launchKernel(Kernel{}, computeCapability); })); - AT_CUDA_CHECK(cudaGetLastError()); - return std::make_tuple(grad_q, grad_k, grad_v); -} // namespace - -} // namespace - -// TORCH_LIBRARY_IMPL(xformers, CUDA, m) { -// m.impl( -// TORCH_SELECTIVE_NAME("xformers::efficient_attention_backward_cutlass"), -// TORCH_FN(mem_efficient_attention_backward_cutlass)); -// } diff --git a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/attention_forward_generic.cu b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/attention_forward_generic.cu deleted file mode 100644 index 59b3637c8a43..000000000000 --- 
a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/attention_forward_generic.cu +++ /dev/null @@ -1,232 +0,0 @@ -#include - - -#define DISPATCH_BLOCKSIZE(VALUE_HEAD_DIM, FN) \ - { \ - if (VALUE_HEAD_DIM <= 64) { \ - constexpr bool kIs64x64 = true; \ - constexpr bool kSingleValueIteration = true; \ - FN(); \ - } else { \ - constexpr bool kIs64x64 = false; \ - if (VALUE_HEAD_DIM <= 128) { \ - constexpr bool kSingleValueIteration = true; \ - FN(); \ - } else { \ - constexpr bool kSingleValueIteration = false; \ - FN(); \ - } \ - } \ - } - -#define DISPATCH_KERNEL(QUERY, KEY, VALUE, FUNC) \ - { \ - cudaDeviceProp* properties = \ - at::cuda::getDeviceProperties(QUERY.device().index()); \ - const int computeCapability = properties->major * 10 + properties->minor; \ - DISPATCH_BLOCKSIZE( \ - VALUE.size(-1), ([&]() { \ - static constexpr int64_t kQueriesPerBlock = kIs64x64 ? 64 : 32; \ - static constexpr int64_t kKeysPerBlock = kIs64x64 ? 64 : 128; \ - DISPATCH_TYPES( \ - QUERY, ([&]() { \ - DISPATCH_ARCHTAG( \ - computeCapability, ([&]() { \ - using AlignedAK = AttentionKernel< \ - scalar_t, \ - ArchTag, \ - true, \ - kQueriesPerBlock, \ - kKeysPerBlock, \ - kSingleValueIteration>; \ - /* Run a more efficient kernel (with `isAligned=True`) \ - if memory is correctly aligned*/ \ - bool isAligned = \ - (QUERY.stride(2) % AlignedAK::kAlignmentQ == 0 && \ - KEY.stride(2) % AlignedAK::kAlignmentK == 0 && \ - VALUE.stride(2) % AlignedAK::kAlignmentV == 0); \ - /* TODO: Should we warn or log somewhere when we use a \ - less efficient kernel due to wrong alignment? */ \ - DISPATCH_BOOL(isAligned, kIsAligned, ([&]() { \ - using Kernel = AttentionKernel< \ - scalar_t, \ - ArchTag, \ - kIsAligned, \ - kQueriesPerBlock, \ - kKeysPerBlock, \ - kSingleValueIteration>; \ - FUNC(); \ - })) \ - })) \ - })); \ - })); \ - } - -namespace { -/* - There are 2 modes for using this function. 
- (Mode BMHK) With all the heads having the same seqlen - (Mode 1MHK) `batch=1` with all tokens across batches concatenated -*/ -std::tuple efficient_attention_forward_cutlass( - const at::Tensor& query, // [b, seqlen, num_heads, K] - const at::Tensor& key, // [b, seqlen, num_heads, K] - const at::Tensor& value, // [b, seqlen, num_heads, Kv] - // (Mode 1MHK only) [b+1]: cu_seqlens_q[b] contains the - // position of the first query token for batch $b - const c10::optional& cu_seqlens_q, - // (Mode 1MHK only) [b+1]: cu_seqlens_k[b] contains the - // position of the first key token for batch $b - const c10::optional& cu_seqlens_k, - // (Mode 1MHK only) Maximum sequence length across batches - const c10::optional max_seqlen_q_, - bool compute_logsumexp, - bool causal) { - TORCH_CHECK(query.dim() == 4); - TORCH_CHECK(key.dim() == 4); - TORCH_CHECK(value.dim() == 4); - - // Batch sizes - TORCH_CHECK(query.size(0) == key.size(0)); - TORCH_CHECK(query.size(0) == value.size(0)); - - // Sequence length - TORCH_CHECK(key.size(1) == value.size(1)); - - // Num heads - TORCH_CHECK(query.size(2) == key.size(2)); - TORCH_CHECK(query.size(2) == value.size(2)); - - // Embedding per head - TORCH_CHECK(query.size(3) == key.size(3)); - - int64_t max_seqlen_q, max_seqlen_k; - TORCH_CHECK(cu_seqlens_q.has_value() == cu_seqlens_k.has_value()); - if (cu_seqlens_q.has_value()) { - TORCH_CHECK(cu_seqlens_q->scalar_type() == at::ScalarType::Int); - TORCH_CHECK(cu_seqlens_k->scalar_type() == at::ScalarType::Int); - TORCH_CHECK(cu_seqlens_q->dim() == 1 && cu_seqlens_k->dim() == 1); - CHECK_NOSPARSE_CONTIGUOUS_CUDA((*cu_seqlens_q)); - CHECK_NOSPARSE_CONTIGUOUS_CUDA((*cu_seqlens_k)); - TORCH_CHECK(cu_seqlens_q->size(0) == cu_seqlens_k->size(0)); - TORCH_CHECK(query.size(0) == 1, "cu_seqlen only supports batch_size=1"); - TORCH_CHECK(max_seqlen_q_.has_value()); - max_seqlen_q = *max_seqlen_q_; - max_seqlen_k = 0; // Will be set inside the kernel - } else { - max_seqlen_q = query.size(1); - max_seqlen_k = key.size(1); - } - - CHECK_NOSPARSE_LASTCONTIGUOUS_CUDA(query); - CHECK_NOSPARSE_LASTCONTIGUOUS_CUDA(key); - CHECK_NOSPARSE_LASTCONTIGUOUS_CUDA(value); - - at::cuda::CUDAGuard device_guard(query.device()); - cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - - int64_t B = query.size(0); - int64_t M = query.size(1); - int64_t N = key.size(1); - int64_t num_heads = query.size(-2); - int64_t K = query.size(-1); - int64_t Kv = value.size(-1); - - at::Tensor res; - at::Tensor logsumexp; - - auto launchKernel = [&](auto _k, int computeCapability) { - using Kernel = decltype(_k); - using scalar_t = typename Kernel::scalar_t; - (void)_k; - - res = at::empty( - {B, M, num_heads, Kv}, - query.options().dtype( - TypeTraits::atScalarType())); - - // NOTE: Should be aligned (by padding) in case M is - // not a good number for loading during backward - constexpr decltype(M) kAlignLSE = Kernel::kAlignLSE; - logsumexp = at::empty( - {B, - num_heads, - compute_logsumexp ? ceil_div(max_seqlen_q, kAlignLSE) * kAlignLSE : 0}, - query.options().dtype(at::ScalarType::Float)); - - typename Kernel::Params p; - p.query_ptr = (scalar_t*)query.data_ptr(); - p.key_ptr = (scalar_t*)key.data_ptr(); - p.value_ptr = (scalar_t*)value.data_ptr(); - p.logsumexp_ptr = compute_logsumexp - ? 
(typename Kernel::lse_scalar_t*)logsumexp.data_ptr() - : nullptr; - at::Tensor output_accum; - if (Kernel::kNeedsOutputAccumulatorBuffer) { - output_accum = at::empty( - {B, M, num_heads, Kv}, - query.options().dtype( - TypeTraits::atScalarType())); - p.output_accum_ptr = - (typename Kernel::output_accum_t*)output_accum.data_ptr(); - } else { - p.output_accum_ptr = nullptr; - } - p.output_ptr = (typename Kernel::output_t*)res.data_ptr(); - - if (cu_seqlens_q.has_value()) { - p.cu_seqlens_q_ptr = (int32_t*)cu_seqlens_q->data_ptr(); - p.cu_seqlens_k_ptr = (int32_t*)cu_seqlens_k->data_ptr(); - } - -#define ASSIGN_CHECK_OVERFLOW(A, B) \ - { \ - A = B; \ - TORCH_CHECK(B < std::numeric_limits::max(), #B " overflows"); \ - } - - p.num_heads = num_heads; - p.head_dim = query.size(3); - p.head_dim_value = value.size(3); - p.num_queries = max_seqlen_q; - p.num_keys = max_seqlen_k; - p.num_batches = cu_seqlens_q.has_value() ? cu_seqlens_q->size(0) - 1 : B; - p.causal = causal; - - ASSIGN_CHECK_OVERFLOW(p.q_strideB, query.stride(0)); - ASSIGN_CHECK_OVERFLOW(p.k_strideB, key.stride(0)); - ASSIGN_CHECK_OVERFLOW(p.v_strideB, value.stride(0)); - ASSIGN_CHECK_OVERFLOW(p.q_strideM, query.stride(1)); - ASSIGN_CHECK_OVERFLOW(p.k_strideM, key.stride(1)); - ASSIGN_CHECK_OVERFLOW(p.v_strideM, value.stride(1)); - ASSIGN_CHECK_OVERFLOW(p.q_strideH, query.stride(2)); - ASSIGN_CHECK_OVERFLOW(p.k_strideH, key.stride(2)); - ASSIGN_CHECK_OVERFLOW(p.v_strideH, value.stride(2)); - - constexpr auto kernel_fn = attention_kernel_batched; - size_t smem_bytes = sizeof(typename Kernel::SharedStorage); - if (smem_bytes > 0xc000) { - TORCH_INTERNAL_ASSERT( - computeCapability >= 70, - "This kernel requires too much shared memory on this machine!"); - AT_CUDA_CHECK(cudaFuncSetAttribute( - kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes)); - } - Kernel::check_supported(p); - kernel_fn<<>>(p); - }; - // Dispatch to the right kernel - DISPATCH_KERNEL(query, key, value, ([&]() { - launchKernel(Kernel{}, computeCapability); - })); - - AT_CUDA_CHECK(cudaGetLastError()); - return std::make_tuple(res, logsumexp); -} -} // namespace - -// TORCH_LIBRARY_IMPL(xformers, CUDA, m) { -// m.impl( -// TORCH_SELECTIVE_NAME("xformers::efficient_attention_forward_cutlass"), -// TORCH_FN(efficient_attention_forward_cutlass)); -// } diff --git a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/find_default_mma.h b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/find_default_mma.h index 399593fd0957..b0e7106f3cfc 100644 --- a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/find_default_mma.h +++ b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/find_default_mma.h @@ -1,15 +1,16 @@ /*! \file \brief Cutlass provides helper template functions to figure out the right - datastructures to instanciate to run a GEMM with various parameters (see + datastructures to instantiate to run a GEMM with various parameters (see `cutlass/gemm/threadblock/default_mma.h`). However, due to template - instanciation priority rules, it will only create an MmaMultiStage with + instantiation priority rules, it will only create an MmaMultiStage with kStages=3 (otherwise creates an MmePipelined - which is not compatible with FastF32). kStages=3 uses too much shared memory and we want to use kStages=2, so we just copy-pasted some code from `default_mma.h` and - `default_mma_core.h` files and wrapped this template to allow our usecase. + `default_mma_core.h` files and wrapped this template to allow our use case. 
This is really only for the FastF32 case - aka using TensorCores with fp32. */ +#pragma once #include #include diff --git a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernel_backward.h b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernel_backward.h index e25701a7588a..c9652c40d38e 100644 --- a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernel_backward.h +++ b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernel_backward.h @@ -1,7 +1,5 @@ #pragma once - #include -#include #include #include @@ -75,46 +73,113 @@ struct AttentionBackwardKernel { struct Params { // Input tensors - scalar_t* query_ptr; // [num_queries, head_dim] - scalar_t* key_ptr; // [num_keys, head_dim] - scalar_t* value_ptr; // [num_keys, head_dim_value] - lse_scalar_t* logsumexp_ptr; // [num_queries] - scalar_t* output_ptr; // [num_queries, head_dim_value] - scalar_t* grad_output_ptr; // [num_queries, head_dim_value] - accum_t* delta_ptr; // [num_queries] + scalar_t* query_ptr; // [Mq, nH, K] + scalar_t* key_ptr; // [Mk, nH, K] + scalar_t* value_ptr; // [Mk, nH, Kv] + lse_scalar_t* logsumexp_ptr; // [nH, Mq] + scalar_t* output_ptr; // [Mq, nH, Kv] + scalar_t* grad_output_ptr; // [Mq, nH, Kv] + accum_t* delta_ptr; // [Mq, nH] // Output tensors - scalar_t* grad_query_ptr; // [num_queries, head_dim] - scalar_t* grad_key_ptr; // [num_keys, head_dim] - scalar_t* grad_value_ptr; // [num_keys, head_dim_value] + output_t* grad_query_ptr; // [Mq, nH, K] + output_t* grad_key_ptr; // [Mk, nH, K] + output_t* grad_value_ptr; // [Mk, nH, Kv] // Dimensions/strides int32_t head_dim; int32_t head_dim_value; int32_t num_queries; int32_t num_keys; - int32_t num_batches; + int32_t num_heads; bool causal; - __device__ void advance_batches(int32_t batch_id) { + int32_t q_strideM; + int32_t k_strideM; + int32_t v_strideM; + int32_t gO_strideM; + int8_t gQKV_strideM_multiplier; // 3 for packed, 1 otherwise + + CUTLASS_HOST_DEVICE int32_t o_strideM() const { + return head_dim_value * num_heads; + } + CUTLASS_HOST_DEVICE int32_t gQ_strideM() const { + return gQKV_strideM_multiplier * num_heads * head_dim; + } + CUTLASS_HOST_DEVICE int32_t gK_strideM() const { + return gQKV_strideM_multiplier * num_heads * head_dim; + } + CUTLASS_HOST_DEVICE int32_t gV_strideM() const { + return gQKV_strideM_multiplier * num_heads * head_dim_value; + } + + // Everything below is only used in `advance_to_block` + // and shouldn't use registers + int64_t o_strideH; + int32_t q_strideH; + int32_t k_strideH; + int32_t v_strideH; + int64_t o_strideB; + int64_t q_strideB; + int64_t k_strideB; + int64_t v_strideB; + int32_t num_batches; + + int64_t gO_strideB; + int64_t gQ_strideB; + int64_t gK_strideB; + int64_t gV_strideB; + int64_t gO_strideH; + int64_t gQ_strideH; + int64_t gK_strideH; + int64_t gV_strideH; + + CUTLASS_DEVICE void advance_to_block() { constexpr int32_t kAlignLSE = 32; // block size of backward auto lse_dim = ceil_div((int32_t)num_queries, kAlignLSE) * kAlignLSE; - query_ptr += batch_id * head_dim * num_queries; - key_ptr += batch_id * head_dim * num_keys; - value_ptr += batch_id * head_dim_value * num_keys; - logsumexp_ptr += batch_id * lse_dim; - output_ptr += batch_id * head_dim_value * num_queries; - grad_output_ptr += batch_id * head_dim_value * num_queries; - delta_ptr += batch_id * num_queries; - - grad_query_ptr += batch_id * head_dim * num_queries; - grad_key_ptr += batch_id * head_dim * num_keys; - grad_value_ptr += batch_id * head_dim_value * num_keys; + int32_t batch_id = blockIdx.z; + 
int32_t head_id = blockIdx.y; + + query_ptr += batch_id * q_strideB + head_id * q_strideH; + key_ptr += batch_id * k_strideB + head_id * k_strideH; + value_ptr += batch_id * v_strideB + head_id * v_strideH; + logsumexp_ptr += (batch_id * num_heads + head_id) * lse_dim; + output_ptr += batch_id * o_strideB + head_id * o_strideH; + grad_output_ptr += batch_id * gO_strideB + head_id * gO_strideH; + delta_ptr += (batch_id * num_heads + head_id) * num_queries; + + grad_query_ptr += batch_id * gQ_strideB + head_id * gQ_strideH; + grad_key_ptr += batch_id * gK_strideB + head_id * gK_strideH; + grad_value_ptr += batch_id * gV_strideB + head_id * gV_strideH; + + head_dim = warp_uniform(head_dim); + head_dim_value = warp_uniform(head_dim_value); + num_queries = warp_uniform(num_queries); + num_keys = warp_uniform(num_keys); + num_heads = warp_uniform(num_heads); + + gO_strideM = warp_uniform(gO_strideM); + gQKV_strideM_multiplier = warp_uniform(gQKV_strideM_multiplier); + q_strideM = warp_uniform(q_strideM); + k_strideM = warp_uniform(k_strideM); + v_strideM = warp_uniform(v_strideM); + + query_ptr = warp_uniform(query_ptr); + key_ptr = warp_uniform(key_ptr); + value_ptr = warp_uniform(value_ptr); + logsumexp_ptr = warp_uniform(logsumexp_ptr); + output_ptr = warp_uniform(output_ptr); + grad_output_ptr = warp_uniform(grad_output_ptr); + delta_ptr = warp_uniform(delta_ptr); + + grad_query_ptr = warp_uniform(grad_query_ptr); + grad_key_ptr = warp_uniform(grad_key_ptr); + grad_value_ptr = warp_uniform(grad_value_ptr); } __host__ dim3 getBlocksGrid() const { - return dim3(1, 1, num_batches); + return dim3(1, num_heads, num_batches); } __host__ dim3 getThreadsGrid() const { return dim3(kWarpSize, kNumWarpsPerBlock, 1); @@ -179,7 +244,6 @@ struct AttentionBackwardKernel { attn_T = k_j @ q_i.transpose(-2, -1) # matmul attn_T = (attn_T - logsumexp[i_start:i_end].unsqueeze(1).transpose(-2, -1)).exp() # epilogue - with attn_T.shape = (kBlockSizeJ, kBlockSizeI) */ using ThreadblockShape = @@ -225,7 +289,6 @@ struct AttentionBackwardKernel { struct MatmulGradV { /* grad_v[j_start:j_end] += attn_T @ do_i # matmul - Dimensions: (kBlockSizeJ * kNumWarpsPerBlock, kBlockSizeI, K) (we might need to iterate multiple times on K) */ @@ -601,7 +664,7 @@ struct AttentionBackwardKernel { typename MatmulGradV::Mma::FragmentC gradV; typename MatmulGradK::Mma::FragmentC gradK; - __device__ __forceinline__ void clear() { + CUTLASS_DEVICE void clear() { gradV.clear(); gradK.clear(); } @@ -614,14 +677,14 @@ struct AttentionBackwardKernel { CHECK_ALIGNED_PTR(p.output_ptr, kMinimumAlignment); CHECK_ALIGNED_PTR(p.grad_output_ptr, kMinimumAlignment); TORCH_CHECK( - p.head_dim % kMinimumAlignment == 0, - "query/key is not correctly aligned"); + p.q_strideH % kMinimumAlignment == 0, "query is not correctly aligned"); TORCH_CHECK( - p.head_dim_value % kMinimumAlignment == 0, - "value is not correctly aligned"); + p.k_strideH % kMinimumAlignment == 0, "key is not correctly aligned"); + TORCH_CHECK( + p.v_strideH % kMinimumAlignment == 0, "value is not correctly aligned"); } - static __device__ void kernel(Params& p_) { + static CUTLASS_DEVICE void kernel(Params& p_) { // Hint to nvcc to store points & tensor shapes in registers // as we use them a lot register const Params p = p_; @@ -658,7 +721,7 @@ struct AttentionBackwardKernel { __syncthreads(); } - OutputFragments output_frags; + OutputFragments register output_frags; int32_t key_start = 0; int32_t key_end = p.num_keys / kBlockSizeJ * kBlockSizeJ; for (; key_start < key_end; 
key_start += kBlockSizeJ) { @@ -695,7 +758,7 @@ struct AttentionBackwardKernel { } } - static __device__ __forceinline__ void loadDi( + static CUTLASS_DEVICE void loadDi( cutlass::Array& di, Params const& p, int32_t query_start) { @@ -710,7 +773,7 @@ struct AttentionBackwardKernel { } template - static __device__ __forceinline__ void processBlockIJ( + static CUTLASS_DEVICE void processBlockIJ( SharedStorage& shared_storage, OutputFragments& output_frags, Params const& p, @@ -718,9 +781,9 @@ struct AttentionBackwardKernel { int32_t key_start) { cutlass::MatrixCoord no_offset{0, 0}; accum_t scale = accum_t(1.0 / std::sqrt(float(p.head_dim))); - int32_t thread_id = threadIdx.x + threadIdx.y * blockDim.x; - int32_t warp_id = threadIdx.y; - int32_t lane_id = threadIdx.x; + int16_t thread_id = threadIdx.x + threadIdx.y * blockDim.x; + int8_t warp_id = warp_uniform(threadIdx.y); + int8_t lane_id = threadIdx.x; __syncthreads(); loadDi(shared_storage.di(), p, query_start); @@ -734,8 +797,8 @@ struct AttentionBackwardKernel { auto prologueGradV = [&](int col) { typename MatmulGradV::Mma::IteratorB iterator_dO( - {int32_t(p.head_dim_value)}, - p.grad_output_ptr + query_start * p.head_dim_value + col, + {int32_t(p.gO_strideM)}, + p.grad_output_ptr + query_start * p.gO_strideM + col, {num_queries_in_block, p.head_dim_value - col}, thread_id, no_offset); @@ -747,8 +810,8 @@ struct AttentionBackwardKernel { }; auto prologueGradQ = [&](int col) { typename MatmulGradQ::Mma::IteratorB iterator_K( - {int32_t(p.head_dim)}, - p.key_ptr + key_start * p.head_dim + col, + {int32_t(p.k_strideM)}, + p.key_ptr + key_start * p.k_strideM + col, {num_keys_in_block, p.head_dim - col}, thread_id, no_offset); @@ -757,8 +820,8 @@ struct AttentionBackwardKernel { }; auto prologueGradK = [&](int col) { typename MatmulGradK::Mma::IteratorB iterator_Q( - {int32_t(p.head_dim)}, - p.query_ptr + query_start * p.head_dim + col, + {int32_t(p.q_strideM)}, + p.query_ptr + query_start * p.q_strideM + col, {num_queries_in_block, p.head_dim - col}, thread_id, no_offset); @@ -770,14 +833,14 @@ struct AttentionBackwardKernel { }; auto prologueDOV = [&]() { typename MatmulDOIVJ::Mma::IteratorA iterator_A( - {int32_t(p.head_dim_value)}, - p.grad_output_ptr + query_start * p.head_dim_value, + {int32_t(p.gO_strideM)}, + p.grad_output_ptr + query_start * p.gO_strideM, {num_queries_in_block, p.head_dim_value}, thread_id, no_offset); typename MatmulDOIVJ::Mma::IteratorB iterator_B( - {int32_t(p.head_dim_value)}, - p.value_ptr + key_start * p.head_dim_value, + {int32_t(p.v_strideM)}, + p.value_ptr + key_start * p.v_strideM, {p.head_dim_value, num_keys_in_block}, thread_id, no_offset); @@ -803,16 +866,16 @@ struct AttentionBackwardKernel { // k_j typename Mma::IteratorA iterator_A( - {int32_t(p.head_dim)}, - p.key_ptr + key_start * p.head_dim, + {int32_t(p.k_strideM)}, + p.key_ptr + key_start * p.k_strideM, {problem_size.m(), problem_size.k()}, thread_id, no_offset); // q_i.transpose(-2, -1) typename Mma::IteratorB iterator_B( - {int32_t(p.head_dim)}, - p.query_ptr + query_start * p.head_dim, + {int32_t(p.q_strideM)}, + p.query_ptr + query_start * p.q_strideM, {problem_size.k(), problem_size.n()}, thread_id, no_offset); @@ -893,14 +956,14 @@ struct AttentionBackwardKernel { num_keys_in_block, p.head_dim_value - col, num_queries_in_block); auto createEpilogueIter = [&]() { return typename MatmulGradV::OutputTileIterator( - typename MatmulGradV::OutputTileIterator::Params{p.head_dim_value}, - p.grad_value_ptr + key_start * p.head_dim_value + col, + 
typename MatmulGradV::OutputTileIterator::Params{p.gV_strideM()}, + p.grad_value_ptr + key_start * p.gV_strideM() + col, {num_keys_in_block, p.head_dim_value - col}, thread_id); }; typename Mma::IteratorB iterator_B( - {int32_t(p.head_dim_value)}, - p.grad_output_ptr + query_start * p.head_dim_value + col, + {int32_t(p.gO_strideM)}, + p.grad_output_ptr + query_start * p.gO_strideM + col, {num_queries_in_block, p.head_dim_value - col}, thread_id, no_offset); @@ -951,16 +1014,16 @@ struct AttentionBackwardKernel { using Mma = typename MatmulDOIVJ::Mma; // do_i typename Mma::IteratorA iterator_A( - {int32_t(p.head_dim_value)}, - p.grad_output_ptr + query_start * p.head_dim_value, + {int32_t(p.gO_strideM)}, + p.grad_output_ptr + query_start * p.gO_strideM, {num_queries_in_block, p.head_dim_value}, thread_id, no_offset); // v_j.transpose(-2, -1) typename Mma::IteratorB iterator_B( - {int32_t(p.head_dim_value)}, - p.value_ptr + key_start * p.head_dim_value, + {int32_t(p.v_strideM)}, + p.value_ptr + key_start * p.v_strideM, {p.head_dim_value, num_keys_in_block}, thread_id, no_offset); @@ -1057,16 +1120,16 @@ struct AttentionBackwardKernel { num_keys_in_block); auto createEpilogueIter = [&]() { return typename MatmulGradQ::OutputTileIterator( - typename MatmulGradQ::OutputTileIterator::Params{p.head_dim}, - p.grad_query_ptr + query_start * p.head_dim + col, + typename MatmulGradQ::OutputTileIterator::Params{p.gQ_strideM()}, + p.grad_query_ptr + query_start * p.gQ_strideM() + col, {problem_size.m(), problem_size.n()}, thread_id); }; // k_j typename Mma::IteratorB iterator_B( - {int32_t(p.head_dim)}, - p.key_ptr + key_start * p.head_dim + col, + {int32_t(p.k_strideM)}, + p.key_ptr + key_start * p.k_strideM + col, {problem_size.k(), problem_size.n()}, thread_id, no_offset); @@ -1153,8 +1216,8 @@ struct AttentionBackwardKernel { num_queries_in_block); auto createEpilogueIter = [&]() { return typename MatmulGradK::OutputTileIterator( - typename MatmulGradK::OutputTileIterator::Params{p.head_dim}, - p.grad_key_ptr + key_start * p.head_dim + col, + typename MatmulGradK::OutputTileIterator::Params{p.gK_strideM()}, + p.grad_key_ptr + key_start * p.gK_strideM() + col, {num_keys_in_block, false ? 
MatmulGradK::ThreadblockShape::kN : p.head_dim - col}, thread_id); @@ -1162,8 +1225,8 @@ struct AttentionBackwardKernel { // q_i typename Mma::IteratorB iterator_B( - {int32_t(p.head_dim)}, - p.query_ptr + query_start * p.head_dim + col, + {int32_t(p.q_strideM)}, + p.query_ptr + query_start * p.q_strideM + col, {problem_size.k(), problem_size.n()}, thread_id, no_offset); @@ -1236,15 +1299,15 @@ struct AttentionBackwardKernel { kForceReloadK || !MatmulQK::Mma::kSmemContainsEntireMat; auto thread_id = get_thread_id(); typename MatmulQK::Mma::IteratorA iterator_A( - {int32_t(p.head_dim)}, - p.key_ptr + key_start * p.head_dim, + {int32_t(p.k_strideM)}, + p.key_ptr + key_start * p.k_strideM, {p.num_keys - key_start, p.head_dim}, thread_id, cutlass::MatrixCoord{0, 0}); typename MatmulQK::Mma::IteratorB iterator_B( - {int32_t(p.head_dim)}, - p.query_ptr + query_start * p.head_dim, + {int32_t(p.q_strideM)}, + p.query_ptr + query_start * p.q_strideM, {p.head_dim, p.num_queries - query_start}, thread_id, cutlass::MatrixCoord{0, 0}); @@ -1259,7 +1322,7 @@ struct AttentionBackwardKernel { } template - static __device__ __forceinline__ void writeFragsToGmem( + static CUTLASS_DEVICE void writeFragsToGmem( SharedStorage& shared_storage, OutputFragments& output_frags, Params const& p, @@ -1268,8 +1331,8 @@ struct AttentionBackwardKernel { ? MatmulQK::Mma::Shape::kM : std::min((int32_t)MatmulQK::Mma::Shape::kM, p.num_keys - key_start); typename MatmulGradV::OutputTileIterator outputV_it( - typename MatmulGradV::OutputTileIterator::Params{p.head_dim_value}, - p.grad_value_ptr + key_start * p.head_dim_value, + typename MatmulGradV::OutputTileIterator::Params{p.gV_strideM()}, + p.grad_value_ptr + key_start * p.gV_strideM(), {num_keys_in_block, p.head_dim_value}, get_thread_id()); accumulateInGmem( @@ -1279,8 +1342,8 @@ struct AttentionBackwardKernel { true); typename MatmulGradK::OutputTileIterator outputK_it( - typename MatmulGradK::OutputTileIterator::Params{p.head_dim}, - p.grad_key_ptr + key_start * p.head_dim, + typename MatmulGradK::OutputTileIterator::Params{p.gK_strideM()}, + p.grad_key_ptr + key_start * p.gK_strideM(), {num_keys_in_block, false ? 
MatmulGradK::ThreadblockShape::kN : p.head_dim}, get_thread_id()); @@ -1292,7 +1355,7 @@ struct AttentionBackwardKernel { } template - static __device__ __forceinline__ void accumulateInGmem( + static CUTLASS_DEVICE void accumulateInGmem( typename MatmulT::DefaultEpilogue::SharedStorage& epilogue_smem, typename MatmulT::Mma::FragmentC const& accum, typename MatmulT::OutputTileIterator output_it, @@ -1334,7 +1397,9 @@ struct AttentionBackwardKernel { } template - static __device__ void computeDelta(Params const& p, int32_t query_start) { + static CUTLASS_DEVICE void computeDelta( + Params const& p, + int32_t query_start) { // Each thread computes one value for Delta // Depending on warp configuration, we might have multiple // threads of the same warp working on the same row @@ -1349,13 +1414,15 @@ struct AttentionBackwardKernel { bool rowPred = (query_start + laneRow) < p.num_queries; bool pred = rowPred; - const __restrict__ AccessType* grad_output_ptr = - reinterpret_cast( - p.grad_output_ptr + (query_start + laneRow) * p.head_dim_value + + // on windows, previous syntax __restrict__ AccessType* + // resulted in error: "restrict" is not allowed + const AccessType* __restrict__ grad_output_ptr = + reinterpret_cast( + p.grad_output_ptr + (query_start + laneRow) * p.gO_strideM + laneFirstCol); - const __restrict__ AccessType* output_ptr = - reinterpret_cast( - p.output_ptr + (query_start + laneRow) * p.head_dim_value + + const AccessType* __restrict__ output_ptr = + reinterpret_cast( + p.output_ptr + (query_start + laneRow) * p.o_strideM() + laneFirstCol); static constexpr int64_t kMaxIters = @@ -1430,13 +1497,13 @@ struct AttentionBackwardKernel { } } - static __device__ __forceinline__ int8_t get_lane_id() { + static CUTLASS_DEVICE int8_t get_lane_id() { return threadIdx.x; } - static __device__ __forceinline__ int8_t get_warp_id() { + static CUTLASS_DEVICE int8_t get_warp_id() { return threadIdx.y; } - static __device__ __forceinline__ int16_t get_thread_id() { + static CUTLASS_DEVICE int16_t get_thread_id() { return threadIdx.x + threadIdx.y * blockDim.x; } }; @@ -1457,8 +1524,7 @@ __global__ void __launch_bounds__(AK::kNumThreads, AK::kMinBlocksPerSm) #define INSTANTIATE_ATTENTION_KERNEL_BACKWARD(ARCH, ...) \ _ATTENTION_KERNEL_BACKWARD_BEGIN( \ AttentionBackwardKernel) \ - auto batch_id = blockIdx.z; \ - p.advance_batches(batch_id); \ + p.advance_to_block(); \ Kernel::kernel(p); \ _ATTENTION_KERNEL_BACKWARD_END(); diff --git a/aten/src/ATen/native/transformers/cuda/sdp_utils.h b/aten/src/ATen/native/transformers/cuda/sdp_utils.h index 564adb2d51ea..e9f3d5029aa8 100644 --- a/aten/src/ATen/native/transformers/cuda/sdp_utils.h +++ b/aten/src/ATen/native/transformers/cuda/sdp_utils.h @@ -62,6 +62,15 @@ inline bool check_for_attn_weights(sdp_params params, bool debug) { } return true; } + +inline bool check_for_non_zero_dropout(sdp_params params, bool debug) { + if (params.dropout != 0.0) { + TORCH_CHECK(!debug, "Mem_efficient does not support non_zero dropout. 
Dropout_p: ", params.dropout); + return false; + } + return true; +} + inline bool check_for_seq_len_1_nested_tensor(sdp_params params, bool debug) { if (!params.query.is_nested()) { return true; @@ -230,7 +239,8 @@ inline bool use_mem_efficient_attention(sdp_params params, bool debug) { check_for_attn_weights, check_tensor_shapes, check_for_attn_mask, - check_for_seq_len_1_nested_tensor}; + check_for_seq_len_1_nested_tensor, + check_for_non_zero_dropout}; for (auto& constraint : constraints) { if (!constraint(params, debug)) { return false; diff --git a/test/test_transformers.py b/test/test_transformers.py index a9d0d960fb9a..c86b89bed5ef 100644 --- a/test/test_transformers.py +++ b/test/test_transformers.py @@ -21,8 +21,11 @@ TEST_WITH_ROCM, IS_WINDOWS, slowTest, - set_default_dtype + set_default_dtype, + gradcheck ) + +from torch.testing._internal.common_methods_invocations import wrapper_set_seed from torch.testing._internal.common_cuda import TEST_CUDA, SM80OrLater if TEST_FAIRSEQ: @@ -860,11 +863,22 @@ def rand_tensor(*shape): actual = torch.ops.aten._scaled_dot_product_attention( query, key, value, attn_mask, dropout_p, need_attn_weights, is_causal) - # freeze_rng_state() doesn't seem to work outside of CPU, so dropout makes the results incomparable. - # TODO: Do this skipping in a nicer way once the granular test skipping logic lands. - if dropout_p == 0.0 or device == 'cpu': self.assertEqual(actual, expected) + if attn_mask_dim is None: + q = q.double().clone() + k = k.double().clone() + v = v.double().clone() + q.requires_grad_() + k.requires_grad_() + v.requires_grad_() + + assert gradcheck(lambda *args, **kwargs: wrapper_set_seed(sdp_ref, *args, **kwargs), + (q, k, v, attn_mask, dropout_p)) + assert gradcheck(lambda *args, **kwargs: + wrapper_set_seed(torch.nn.functional._scaled_dot_product_attention, *args, **kwargs), + (q, k, v, attn_mask, dropout_p)) + @unittest.skipIf(TEST_WITH_CROSSREF, 'Fastpath not available with crossref') @torch.no_grad() def test_mask_check_fastpath(self): @@ -1079,6 +1093,28 @@ def rand_tensor(shape): self.assertEqual(math_ref_test, math_ref_lp_test, atol=7e-3, rtol=7e-3) self.assertEqual(actual_test, math_ref_test, atol=5e-3, rtol=5e-3) + @unittest.skipIf(not TEST_CUDA or TEST_WITH_ROCM or IS_WINDOWS, "Flash Attention was not built for this system") + @parametrize("contiguous_inputs", [True, False]) + def test_efficient_attention_gradcheck(self, contiguous_inputs: bool): + + batch_size, seq_len, num_heads, head_dim = 8, 8, 4, 64 + query, key, value = torch.rand((batch_size, seq_len, 3 * num_heads * head_dim), + device="cuda", dtype=torch.float32, requires_grad=True).chunk(3, -1) + query = query.view(batch_size, -1, num_heads, head_dim) + key = key.view(batch_size, -1, num_heads, head_dim) + value = value.view(batch_size, -1, num_heads, head_dim) + + if contiguous_inputs: + query = query.contiguous() + key = key.contiguous() + value = value.contiguous() + + # Normally we would transpose the inputs but the fused kernels expect + # (batch, seq_len, num_heads, head_dim) bump the tolerance since we can only run kernel + # in fp32 + assert gradcheck(lambda *args, **kwargs: + wrapper_set_seed(torch.ops.aten._efficient_attention_forward, *args, **kwargs), + (query, key, value, None, None, None, True, False), fast_mode=True, atol=8e-5, rtol=1e-3) @unittest.skipIf(not TEST_CUDA, "CUDA unavailable") def test_sdp_runtime_dispatch(self): diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml index 8349a308be35..a0892b32a835 100644 --- 
a/tools/autograd/derivatives.yaml +++ b/tools/autograd/derivatives.yaml @@ -2591,7 +2591,7 @@ - name: _mkldnn_reshape(Tensor self, int[] shape) -> Tensor self: grad.reshape_symint(self.sym_sizes()) -# Nested Tensor +# NestedTensor - name: _nested_tensor_from_tensor_list(Tensor[] list, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor list: "grad.defined()? at::unbind(grad) : std::vector(list.size())" @@ -2612,6 +2612,11 @@ nested_size: non_differentiable nested_strides: non_differentiable +# Transformers +- name: _efficient_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, int? max_seqlen_q, bool compute_log_sumexp=False, bool causal=False) -> (Tensor, Tensor) + output_differentiability: [True, False] + query, key, value: _efficient_attention_backward(grad, query, key, value, result1, result0, causal) + # fft - name: _fft_r2c(Tensor self, int[] dim, int normalization, bool onesided) -> Tensor self: fft_r2c_backward(grad, dim, normalization, onesided, self.sym_size(dim.back())) diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 001fd455e82e..3b43b8fb4863 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -11944,8 +11944,8 @@ def reference_flatten(input, start_dim=0, end_dim=-1): ), OpInfo( 'nn.functional._scaled_dot_product_attention', - op=lambda inp, *args, **kwargs: - wrapper_set_seed(torch.nn.functional._scaled_dot_product_attention, inp, *args, **kwargs), + op=lambda *args, **kwargs: + wrapper_set_seed(torch.nn.functional._scaled_dot_product_attention, *args, **kwargs), sample_inputs_func=sample_inputs_scaled_dot_product_attention, dtypes=floating_types_and(torch.bfloat16), dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
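
For reference, a minimal standalone sketch of how the newly registered derivative can be exercised outside the test suite. It mirrors the `test_efficient_attention_gradcheck` test added above; it is not part of the patch, assumes a CUDA build with `USE_FLASH_ATTENTION` and a supported GPU, and runs in fp32 with the same loosened tolerances because the kernel has no fp64 implementation.

```python
# Standalone sketch (not part of the patch), adapted from the
# test_efficient_attention_gradcheck test added in test_transformers.py.
# Requires a CUDA build with USE_FLASH_ATTENTION and a supported GPU; the
# kernel has no fp64 path, so gradcheck runs in fp32 with bumped tolerances.
import torch
from torch.autograd import gradcheck

batch_size, seq_len, num_heads, head_dim = 8, 8, 4, 64

# q, k, v come from one chunked tensor, matching the packed-QKV layout the
# new backward can exploit (a single [B, M, 3, nH, K] gradient chunk).
query, key, value = torch.rand(
    (batch_size, seq_len, 3 * num_heads * head_dim),
    device="cuda", dtype=torch.float32, requires_grad=True).chunk(3, -1)
query = query.view(batch_size, -1, num_heads, head_dim)
key = key.view(batch_size, -1, num_heads, head_dim)
value = value.view(batch_size, -1, num_heads, head_dim)

def efficient_attention(q, k, v):
    # Signature per the native_functions.yaml / derivatives.yaml entries above:
    # (query, key, value, cu_seqlens_q, cu_seqlens_k, max_seqlen_q,
    #  compute_log_sumexp, causal) -> (output, logsumexp)
    out, _logsumexp = torch.ops.aten._efficient_attention_forward(
        q, k, v, None, None, None, True, False)
    return out

# No dropout is involved, so the op is deterministic and gradcheck can be
# called directly; tolerances match the new test.
assert gradcheck(efficient_attention, (query, key, value),
                 fast_mode=True, atol=8e-5, rtol=1e-3)
```

With the new `derivatives.yaml` entry, the analytical gradients compared here are produced by `_efficient_attention_backward`, so the check exercises the fused backward kernel rather than the composite math path.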