diff --git a/test/test_jit.py b/test/test_jit.py index 13c27b0efa55..0e9e2f179256 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -15059,6 +15059,67 @@ def jit_multihead_attn_forward(query, # type: Tensor # print(jit_out / py_out - 1) self.assertEqual(jit_out, py_out, atol=5e-4, rtol=1e-4) + def test_functional_multi_head_attn_fast_path(self): + src_l = 3 + bsz = 5 + embed_size = 8 + nhead = 2 + query = key = value = torch.rand((src_l, bsz, embed_size)) + multi_head_attn_nn = torch.nn.MultiheadAttention(embed_size, nhead, batch_first=True) + multi_head_attn_nn = multi_head_attn_nn.eval() + + with torch.no_grad(): + fn_out = torch.nn.functional.multi_head_attention_forward(query, key, value, + embed_size, nhead, + multi_head_attn_nn.in_proj_weight, + multi_head_attn_nn.in_proj_bias, + multi_head_attn_nn.bias_k, + multi_head_attn_nn.bias_v, + multi_head_attn_nn.add_zero_attn, + multi_head_attn_nn.dropout, + multi_head_attn_nn.out_proj.weight, + None, + training=False)[0] + + query_fb = key_fb = value_fb = query.transpose(1, 0) + py_out = multi_head_attn_nn(query_fb, key_fb, value_fb)[0].transpose(1,0) + mha = torch.jit.script(multi_head_attn_nn) + jit_out = mha(query_fb, key_fb, value_fb)[0].transpose(1,0) + torch.testing.assert_close(fn_out, py_out) + torch.testing.assert_close(fn_out, jit_out) + + def test_functional_multi_head_attn_fast_path_None(self): + random_seed = 1 # or any of your favorite number + torch.manual_seed(random_seed) + torch.cuda.manual_seed(random_seed) + + np.random.seed(random_seed) + + src_l = 3 + bsz = 5 + embed_size = 8 + nhead = 2 + query = key = value = torch.rand((src_l, bsz, embed_size)) + multi_head_attn_nn = torch.nn.MultiheadAttention(embed_size, nhead, batch_first=True) + multi_head_attn_nn = multi_head_attn_nn.eval() + + with torch.no_grad(): + fn_out = torch.nn.functional.multi_head_attention_forward(query, key, value, + embed_size, nhead, + multi_head_attn_nn.in_proj_weight, + multi_head_attn_nn.in_proj_bias, + multi_head_attn_nn.bias_k, + multi_head_attn_nn.bias_v, + multi_head_attn_nn.add_zero_attn, + multi_head_attn_nn.dropout, + multi_head_attn_nn.out_proj.weight, + None, + training=False)[0] + + query_fb = key_fb = value_fb = query.transpose(1, 0) + py_out = multi_head_attn_nn(query_fb, key_fb, value_fb)[0].transpose(1,0) + self.assertEqual(fn_out.shape, py_out.shape) + def test_torchscript_multi_head_attn_fast_path(self): src_l = 3 bsz = 5 diff --git a/torch/nn/functional.py b/torch/nn/functional.py index 961dd83f57b2..617bc2463216 100644 --- a/torch/nn/functional.py +++ b/torch/nn/functional.py @@ -4864,6 +4864,10 @@ def _mha_shape_check(query: Tensor, key: Tensor, value: Tensor, # Raises an error if `query` is not 2-D (unbatched) or 3-D (batched) tensor. # Shape check. + assert query.dim() == 3 or query.dim() == 2, \ + (f"query should be unbatched 2D or batched 3D tensor but received {query.dim()}-D query tensor") + + is_batched = False if query.dim() == 3: # Batched Inputs is_batched = True @@ -4898,9 +4902,6 @@ def _mha_shape_check(query: Tensor, key: Tensor, value: Tensor, expected_shape = (num_heads, query.shape[0], key.shape[0]) assert attn_mask.shape == expected_shape, \ (f"Expected `attn_mask` shape to be {expected_shape} but got {attn_mask.shape}") - else: - raise AssertionError( - f"query should be unbatched 2D or batched 3D tensor but received {query.dim()}-D query tensor") return is_batched @@ -4992,6 +4993,237 @@ def multi_head_attention_forward( :math:`S` is the source sequence length. 
If ``average_attn_weights=False``, returns attention weights per head of shape :math:`(num_heads, L, S)` when input is unbatched or :math:`(N, num_heads, L, S)`. """ + + why_not_fast_path = _can_use_fastpath(query, + key, + value, + embed_dim_to_check, + num_heads, + in_proj_weight, + in_proj_bias, + bias_k, + bias_v, + add_zero_attn, + dropout_p, + out_proj_weight, + out_proj_bias, + training, + key_padding_mask, + attn_mask) + if not why_not_fast_path: + + if key is value: + if query is key: + query = key = value = query.transpose(1, 0) + else: + query, key = [x.transpose(1, 0) for x in (query, key)] + value = key + else: + query, key, value = [x.transpose(1, 0) for x in (query, key, value)] + + merged_mask, mask_type = _merge_masks(attn_mask, key_padding_mask, query, num_heads) + assert in_proj_weight is not None + assert in_proj_bias is not None + assert out_proj_bias is not None + attn_output, attn_output_weights = torch._native_multi_head_attention( + query, + key, + value, + embed_dim_to_check, + num_heads, + in_proj_weight, + in_proj_bias, + out_proj_weight, + out_proj_bias, + merged_mask, + need_weights, + average_attn_weights, + mask_type) + return attn_output.transpose(1, 0), attn_output_weights + + any_nested = query.is_nested or key.is_nested or value.is_nested + assert not any_nested, ("MultiheadAttention does not support NestedTensor outside of its fast path. " + + f"The fast path was not hit because {why_not_fast_path}") + + return _multi_head_attention_forward_impl(query, + key, + value, + embed_dim_to_check, + num_heads, + in_proj_weight, + in_proj_bias, + bias_k, + bias_v, + add_zero_attn, + dropout_p, + out_proj_weight, + out_proj_bias, + training, + key_padding_mask, + need_weights, + attn_mask, + use_separate_proj_weight, + q_proj_weight, + k_proj_weight, + v_proj_weight, + static_k, + static_v, + average_attn_weights) + +def _merge_masks(attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor], + query: Tensor, num_heads: int,) -> Tuple[Optional[Tensor], Optional[int]]: + r""" + Determine mask type and combine masks if necessary. If only one mask is provided, that mask + and the corresponding mask type will be returned. 
If both masks are provided, they will be both + expanded to shape ``(batch_size, num_heads, seq_len, seq_len)``, combined with logical ``or`` + and mask type 2 will be returned + Args: + attn_mask: attention mask of shape ``(seq_len, seq_len)``, mask type 0 + key_padding_mask: padding mask of shape ``(batch_size, seq_len)``, mask type 1 + query: query embeddings of shape ``(batch_size, seq_len, embed_dim)`` + Returns: + merged_mask: merged mask + mask_type: merged mask type (0, 1, or 2) + """ + mask_type: Optional[int] = None + merged_mask: Optional[Tensor] = None + if attn_mask is not None: + mask_type = 0 + merged_mask = attn_mask + if key_padding_mask is not None: + mask_type = 1 + merged_mask = key_padding_mask + if (attn_mask is not None) and (key_padding_mask is not None): + # In this branch query can't be a nested tensor, so it has a shape + batch_size, seq_len, _ = query.shape + mask_type = 2 + key_padding_mask_expanded = key_padding_mask.view(batch_size, 1, 1, seq_len) \ + .expand(-1, num_heads, -1, -1) + attn_mask_expanded = attn_mask.view(1, 1, seq_len, seq_len).expand(batch_size, num_heads, -1, -1) + merged_mask = attn_mask_expanded.logical_or(key_padding_mask_expanded) + return merged_mask, mask_type + +def _can_use_fastpath( + query: Tensor, + key: Tensor, + value: Tensor, + embed_dim_to_check: int, + num_heads: int, + in_proj_weight: Optional[Tensor], + in_proj_bias: Optional[Tensor], + bias_k: Optional[Tensor], + bias_v: Optional[Tensor], + add_zero_attn: bool, + dropout_p: float, + out_proj_weight: Tensor, + out_proj_bias: Optional[Tensor], + training: bool = True, + key_padding_mask: Optional[Tensor] = None, + attn_mask: Optional[Tensor] = None, + +): + why_not_fast_path = '' + if torch.overrides.has_torch_function(( + query, + key, + value, + in_proj_weight, + in_proj_bias, + out_proj_weight, + out_proj_bias,)): + why_not_fast_path = "some Tensor argument has_torch_function" + else: + is_batched = query.dim() == 3 + if not is_batched: + why_not_fast_path = f"input not batched; expected query.dim() of 3 but got {query.dim()}" + elif query is not key or key is not value: + # When lifting this restriction, don't forget to either + # enforce that the dtypes all match or test cases where + # they don't! + why_not_fast_path = "non-self attention was used (query, key, and value are not the same Tensor)" + elif in_proj_bias is not None and query.dtype != in_proj_bias.dtype: + why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_bias ({in_proj_bias.dtype}) don't match" + elif in_proj_weight is not None and query.dtype != in_proj_weight.dtype: + # this case will fail anyway, but at least they'll get a useful error message. 
+            why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_weight ({in_proj_weight.dtype}) don't match"
+        elif training:
+            why_not_fast_path = "training is enabled"
+        #elif not self.batch_first:
+        #    why_not_fast_path = "batch_first was not True"
+        elif bias_k is not None:
+            why_not_fast_path = "self.bias_k was not None"
+        elif bias_v is not None:
+            why_not_fast_path = "self.bias_v was not None"
+        elif dropout_p:
+            why_not_fast_path = f"dropout was {dropout_p}, required zero"
+        elif add_zero_attn:
+            why_not_fast_path = "add_zero_attn was enabled"
+        #elif not self._qkv_same_embed_dim:
+        #    why_not_fast_path = "_qkv_same_embed_dim was not True"
+        elif query.is_nested and (key_padding_mask is not None or attn_mask is not None):
+            why_not_fast_path = "supplying both src_key_padding_mask and src_mask at the same time \
+                                 is not supported with NestedTensor input"
+        elif torch.is_autocast_enabled():
+            why_not_fast_path = "autocast is enabled"
+
+        if not why_not_fast_path:
+            tensor_args = (
+                query,
+                key,
+                value,
+                in_proj_weight,
+                in_proj_bias,
+                out_proj_weight,
+                out_proj_bias,
+            )
+
+            # We have to use list comprehensions below because TorchScript does not support
+            # generator expressions.
+            if torch.overrides.has_torch_function(tensor_args):
+                why_not_fast_path = "some Tensor argument has_torch_function"
+            elif in_proj_weight is None:
+                why_not_fast_path = "in projection weight is None"
+            elif in_proj_bias is None:
+                why_not_fast_path = "in projection bias is None"
+            elif out_proj_bias is None:
+                why_not_fast_path = "out projection bias is None"
+            elif not all([x is not None and (x.is_cuda or 'cpu' in str(x.device)) for x in tensor_args]):
+                why_not_fast_path = "some Tensor argument is neither CUDA nor CPU"
+            elif torch.is_grad_enabled() and any([x is not None and x.requires_grad for x in tensor_args]):
+                why_not_fast_path = ("grad is enabled and at least one of query or the "
+                                     "input/output projection weights or biases requires_grad")
+    return why_not_fast_path
+
+def _multi_head_attention_forward_impl(
+    query: Tensor,
+    key: Tensor,
+    value: Tensor,
+    embed_dim_to_check: int,
+    num_heads: int,
+    in_proj_weight: Optional[Tensor],
+    in_proj_bias: Optional[Tensor],
+    bias_k: Optional[Tensor],
+    bias_v: Optional[Tensor],
+    add_zero_attn: bool,
+    dropout_p: float,
+    out_proj_weight: Tensor,
+    out_proj_bias: Optional[Tensor],
+    training: bool = True,
+    key_padding_mask: Optional[Tensor] = None,
+    need_weights: bool = True,
+    attn_mask: Optional[Tensor] = None,
+    use_separate_proj_weight: bool = False,
+    q_proj_weight: Optional[Tensor] = None,
+    k_proj_weight: Optional[Tensor] = None,
+    v_proj_weight: Optional[Tensor] = None,
+    static_k: Optional[Tensor] = None,
+    static_v: Optional[Tensor] = None,
+    average_attn_weights: bool = True,
+) -> Tuple[Tensor, Optional[Tensor]]:
+    r"""
+    Refer to the documentation of ``multi_head_attention_forward``. This is the implementation of
+    multi-head attention without the fast path.
+    """
+
     tens_ops = (query, key, value, in_proj_weight, in_proj_bias, bias_k, bias_v, out_proj_weight, out_proj_bias)
     if has_torch_function(tens_ops):
         return handle_torch_function(
diff --git a/torch/nn/modules/activation.py b/torch/nn/modules/activation.py
index b00da06126a7..dc8a4fb2fb72 100644
--- a/torch/nn/modules/activation.py
+++ b/torch/nn/modules/activation.py
@@ -1046,6 +1046,57 @@ def forward(self, query: Tensor, key: Tensor, value: Tensor, key_padding_mask: O
             heads. Otherwise, ``attn_weights`` are provided separately per head.
Note that this flag only has an effect when ``need_weights=True``. Default: ``True`` (i.e. average weights across heads) + Outputs: + - **attn_output** - Attention outputs of shape :math:`(L, E)` when input is unbatched, + :math:`(L, N, E)` when ``batch_first=False`` or :math:`(N, L, E)` when ``batch_first=True``, + where :math:`L` is the target sequence length, :math:`N` is the batch size, and :math:`E` is the + embedding dimension ``embed_dim``. + - **attn_output_weights** - Only returned when ``need_weights=True``. If ``average_attn_weights=True``, + returns attention weights averaged across heads of shape :math:`(L, S)` when input is unbatched or + :math:`(N, L, S)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, and + :math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per + head of shape :math:`(\text{num\_heads}, L, S)` when input is unbatched or :math:`(N, \text{num\_heads}, L, S)`. + + .. note:: + `batch_first` argument is ignored for unbatched inputs. + """ + def forward(self, query: Tensor, key: Tensor, value: Tensor, key_padding_mask: Optional[Tensor] = None, + need_weights: bool = True, attn_mask: Optional[Tensor] = None, + average_attn_weights: bool = True) -> Tuple[Tensor, Optional[Tensor]]: + r""" + Args: + query: Query embeddings of shape :math:`(L, E_q)` for unbatched input, :math:`(L, N, E_q)` when ``batch_first=False`` + or :math:`(N, L, E_q)` when ``batch_first=True``, where :math:`L` is the target sequence length, + :math:`N` is the batch size, and :math:`E_q` is the query embedding dimension ``embed_dim``. + Queries are compared against key-value pairs to produce the output. + See "Attention Is All You Need" for more details. + key: Key embeddings of shape :math:`(S, E_k)` for unbatched input, :math:`(S, N, E_k)` when ``batch_first=False`` + or :math:`(N, S, E_k)` when ``batch_first=True``, where :math:`S` is the source sequence length, + :math:`N` is the batch size, and :math:`E_k` is the key embedding dimension ``kdim``. + See "Attention Is All You Need" for more details. + value: Value embeddings of shape :math:`(S, E_v)` for unbatched input, :math:`(S, N, E_v)` when + ``batch_first=False`` or :math:`(N, S, E_v)` when ``batch_first=True``, where :math:`S` is the source + sequence length, :math:`N` is the batch size, and :math:`E_v` is the value embedding dimension ``vdim``. + See "Attention Is All You Need" for more details. + key_padding_mask: If specified, a mask of shape :math:`(N, S)` indicating which elements within ``key`` + to ignore for the purpose of attention (i.e. treat as "padding"). For unbatched `query`, shape should be :math:`(S)`. + Binary and byte masks are supported. + For a binary mask, a ``True`` value indicates that the corresponding ``key`` value will be ignored for + the purpose of attention. For a float mask, it will be directly added to the corresponding ``key`` value. + need_weights: If specified, returns ``attn_output_weights`` in addition to ``attn_outputs``. + Default: ``True``. + attn_mask: If specified, a 2D or 3D mask preventing attention to certain positions. Must be of shape + :math:`(L, S)` or :math:`(N\cdot\text{num\_heads}, L, S)`, where :math:`N` is the batch size, + :math:`L` is the target sequence length, and :math:`S` is the source sequence length. A 2D mask will be + broadcasted across the batch while a 3D mask allows for a different mask for each entry in the batch. + Binary, byte, and float masks are supported. 
For a binary mask, a ``True`` value indicates that the + corresponding position is not allowed to attend. For a byte mask, a non-zero value indicates that the + corresponding position is not allowed to attend. For a float mask, the mask values will be added to + the attention weight. + average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across + heads. Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an + effect when ``need_weights=True``. Default: ``True`` (i.e. average weights across heads) + Outputs: - **attn_output** - Attention outputs of shape :math:`(L, E)` when input is unbatched, :math:`(L, N, E)` when ``batch_first=False`` or :math:`(N, L, E)` when ``batch_first=True``, @@ -1066,59 +1117,29 @@ def forward(self, query: Tensor, key: Tensor, value: Tensor, key_padding_mask: O if _kpm_dtype != torch.bool and not torch.is_floating_point(key_padding_mask): raise AssertionError( "only bool and floating types of key_padding_mask are supported") - why_not_fast_path = '' - if not is_batched: - why_not_fast_path = f"input not batched; expected query.dim() of 3 but got {query.dim()}" - elif query is not key or key is not value: - # When lifting this restriction, don't forget to either - # enforce that the dtypes all match or test cases where - # they don't! - why_not_fast_path = "non-self attention was used (query, key, and value are not the same Tensor)" - elif self.in_proj_bias is not None and query.dtype != self.in_proj_bias.dtype: - why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_bias ({self.in_proj_bias.dtype}) don't match" - elif self.in_proj_weight is not None and query.dtype != self.in_proj_weight.dtype: - # this case will fail anyway, but at least they'll get a useful error message. - why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_weight ({self.in_proj_weight.dtype}) don't match" - elif self.training: - why_not_fast_path = "training is enabled" - elif not self.batch_first: - why_not_fast_path = "batch_first was not True" - elif self.bias_k is not None: - why_not_fast_path = "self.bias_k was not None" - elif self.bias_v is not None: - why_not_fast_path = "self.bias_v was not None" - elif self.add_zero_attn: - why_not_fast_path = "add_zero_attn was enabled" - elif not self._qkv_same_embed_dim: - why_not_fast_path = "_qkv_same_embed_dim was not True" - elif query.is_nested and (key_padding_mask is not None or attn_mask is not None): - why_not_fast_path = "supplying both src_key_padding_mask and src_mask at the same time \ - is not supported with NestedTensor input" - elif torch.is_autocast_enabled(): - why_not_fast_path = "autocast is enabled" - + why_not_fast_path = F._can_use_fastpath(query, + key, + value, + self.embed_dim, + self.num_heads, + self.in_proj_weight, + self.in_proj_bias, + self.bias_k, + self.bias_v, + self.add_zero_attn, + self.dropout, + self.out_proj.weight, + self.out_proj.bias, + self.training, + key_padding_mask, + attn_mask) if not why_not_fast_path: - tensor_args = ( - query, - key, - value, - self.in_proj_weight, - self.in_proj_bias, - self.out_proj.weight, - self.out_proj.bias, - ) - # We have to use list comprehensions below because TorchScript does not support - # generator expressions. 
- if torch.overrides.has_torch_function(tensor_args): - why_not_fast_path = "some Tensor argument has_torch_function" - elif not all([(x is None or x.is_cuda or 'cpu' in str(x.device)) for x in tensor_args]): - why_not_fast_path = "some Tensor argument is neither CUDA nor CPU" - elif torch.is_grad_enabled() and any([x is not None and x.requires_grad for x in tensor_args]): - why_not_fast_path = ("grad is enabled and at least one of query or the " - "input/output projection weights or biases requires_grad") - if not why_not_fast_path: - merged_mask, mask_type = self.merge_masks(attn_mask, key_padding_mask, query) - + if not self.batch_first: + why_not_fast_path = "batch_first was not True" + elif not self._qkv_same_embed_dim: + why_not_fast_path = "_qkv_same_embed_dim was not True" + else: + merged_mask, mask_type = F._merge_masks(attn_mask, key_padding_mask, query, self.num_heads) return torch._native_multi_head_attention( query, key, @@ -1150,7 +1171,7 @@ def forward(self, query: Tensor, key: Tensor, value: Tensor, key_padding_mask: O query, key, value = [x.transpose(1, 0) for x in (query, key, value)] if not self._qkv_same_embed_dim: - attn_output, attn_output_weights = F.multi_head_attention_forward( + attn_output, attn_output_weights = F._multi_head_attention_forward_impl( query, key, value, self.embed_dim, self.num_heads, self.in_proj_weight, self.in_proj_bias, self.bias_k, self.bias_v, self.add_zero_attn, @@ -1161,7 +1182,7 @@ def forward(self, query: Tensor, key: Tensor, value: Tensor, key_padding_mask: O q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight, v_proj_weight=self.v_proj_weight, average_attn_weights=average_attn_weights) else: - attn_output, attn_output_weights = F.multi_head_attention_forward( + attn_output, attn_output_weights = F._multi_head_attention_forward_impl( query, key, value, self.embed_dim, self.num_heads, self.in_proj_weight, self.in_proj_bias, self.bias_k, self.bias_v, self.add_zero_attn, @@ -1174,39 +1195,6 @@ def forward(self, query: Tensor, key: Tensor, value: Tensor, key_padding_mask: O else: return attn_output, attn_output_weights - def merge_masks(self, attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor], - query: Tensor) -> Tuple[Optional[Tensor], Optional[int]]: - r""" - Determine mask type and combine masks if necessary. If only one mask is provided, that mask - and the corresponding mask type will be returned. 
If both masks are provided, they will be both - expanded to shape ``(batch_size, num_heads, seq_len, seq_len)``, combined with logical ``or`` - and mask type 2 will be returned - Args: - attn_mask: attention mask of shape ``(seq_len, seq_len)``, mask type 0 - key_padding_mask: padding mask of shape ``(batch_size, seq_len)``, mask type 1 - query: query embeddings of shape ``(batch_size, seq_len, embed_dim)`` - Returns: - merged_mask: merged mask - mask_type: merged mask type (0, 1, or 2) - """ - mask_type: Optional[int] = None - merged_mask: Optional[Tensor] = None - if attn_mask is not None: - mask_type = 0 - merged_mask = attn_mask - if key_padding_mask is not None: - mask_type = 1 - merged_mask = key_padding_mask - if (attn_mask is not None) and (key_padding_mask is not None): - # In this branch query can't be a nested tensor, so it has a shape - batch_size, seq_len, _ = query.shape - mask_type = 2 - key_padding_mask_expanded = key_padding_mask.view(batch_size, 1, 1, seq_len) \ - .expand(-1, self.num_heads, -1, -1) - attn_mask_expanded = attn_mask.view(1, 1, seq_len, seq_len).expand(batch_size, self.num_heads, -1, -1) - merged_mask = attn_mask_expanded.logical_or(key_padding_mask_expanded) - return merged_mask, mask_type - class PReLU(Module): r"""Applies the element-wise function: diff --git a/torch/nn/modules/transformer.py b/torch/nn/modules/transformer.py index 5f1bc7bb2785..476bac0e292a 100644 --- a/torch/nn/modules/transformer.py +++ b/torch/nn/modules/transformer.py @@ -503,7 +503,7 @@ def forward(self, src: Tensor, src_mask: Optional[Tensor] = None, "input/output projection weights or biases requires_grad") if not why_not_sparsity_fast_path: - merged_mask, mask_type = self.self_attn.merge_masks(src_mask, src_key_padding_mask, src) + merged_mask, mask_type = F._merge_masks(src_mask, src_key_padding_mask, src, self.self_attn.num_heads) return torch._transformer_encoder_layer_fwd( src, self.self_attn.embed_dim,
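
A minimal usage sketch of the new functional fast path, mirroring test_functional_multi_head_attn_fast_path above (the module, shapes, and values here are illustrative only, not part of the patch). With training disabled, zero dropout, self-attention (query is key is value), and all projection weights and biases present, _can_use_fastpath returns an empty reason string and the call dispatches to torch._native_multi_head_attention:

import torch
import torch.nn.functional as F

embed_dim, num_heads, seq_len, bsz = 8, 2, 3, 5
# eval() turns training off; dropout defaults to 0.0, both required by _can_use_fastpath
mha = torch.nn.MultiheadAttention(embed_dim, num_heads).eval()

# the functional API takes sequence-first inputs: (seq_len, batch, embed_dim)
query = key = value = torch.rand(seq_len, bsz, embed_dim)

# the module's parameters require grad, so disable grad to satisfy the fast-path check
with torch.no_grad():
    attn_out, attn_weights = F.multi_head_attention_forward(
        query, key, value, embed_dim, num_heads,
        mha.in_proj_weight, mha.in_proj_bias,
        mha.bias_k, mha.bias_v, mha.add_zero_attn,
        mha.dropout, mha.out_proj.weight, mha.out_proj.bias,  # out_proj bias must not be None
        training=False)

print(attn_out.shape)  # (seq_len, bsz, embed_dim), same layout as the slow path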