From 1a89581acf99c8a7121cdd7bc7966157002d8cf8 Mon Sep 17 00:00:00 2001 From: Olga Gerasimova Date: Fri, 11 Nov 2022 22:39:41 +0100 Subject: [PATCH 01/12] add fastpath run in functional multiheadattn --- test/test_jit.py | 61 +++++++++ torch/nn/functional.py | 238 ++++++++++++++++++++++++++++++++- torch/nn/modules/activation.py | 162 +++++++++++----------- 3 files changed, 371 insertions(+), 90 deletions(-) diff --git a/test/test_jit.py b/test/test_jit.py index 13c27b0efa55..23556f9bded8 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -15059,6 +15059,67 @@ def jit_multihead_attn_forward(query, # type: Tensor # print(jit_out / py_out - 1) self.assertEqual(jit_out, py_out, atol=5e-4, rtol=1e-4) + def test_functional_multi_head_attn_fast_path(self): + src_l = 3 + bsz = 5 + embed_size = 8 + nhead = 2 + query = key = value = torch.rand((src_l, bsz, embed_size)) + multi_head_attn_nn = torch.nn.MultiheadAttention(embed_size, nhead, batch_first=True) + multi_head_attn_nn = multi_head_attn_nn.eval() + + with torch.no_grad(): + fn_out = torch.nn.functional.multi_head_attention_forward(query, key, value, + embed_size, nhead, + multi_head_attn_nn.in_proj_weight, + multi_head_attn_nn.in_proj_bias, + multi_head_attn_nn.bias_k, + multi_head_attn_nn.bias_v, + multi_head_attn_nn.add_zero_attn, + multi_head_attn_nn.dropout, + multi_head_attn_nn.out_proj.weight, + None, + training=False)[0] + + query_fb = key_fb = value_fb = query.transpose(1, 0) + py_out = multi_head_attn_nn(query_fb, key_fb, value_fb)[0].transpose(1,0) + mha = torch.jit.script(multi_head_attn_nn) + jit_out = mha(query_fb, key_fb, value_fb)[0].transpose(1,0) + torch.testing.assert_close(fn_out, py_out) + torch.testing.assert_close(fn_out, jit_out) + + def test_functional_multi_head_attn_fast_path_None(self): + random_seed = 1 # or any of your favorite number + torch.manual_seed(random_seed) + torch.cuda.manual_seed(random_seed) + + np.random.seed(random_seed) + + src_l = 3 + bsz = 5 + embed_size = 8 + nhead = 2 + query = key = value = torch.rand((src_l, bsz, embed_size)) + multi_head_attn_nn = torch.nn.MultiheadAttention(embed_size, nhead, batch_first=True) + multi_head_attn_nn = multi_head_attn_nn.eval() + + with torch.no_grad(): + fn_out = torch.nn.functional.multi_head_attention_forward(query, key, value, + embed_size, nhead, + None, + None, + multi_head_attn_nn.bias_k, + multi_head_attn_nn.bias_v, + multi_head_attn_nn.add_zero_attn, + multi_head_attn_nn.dropout, + multi_head_attn_nn.out_proj.weight, + None, + training=False)[0] + + query_fb = key_fb = value_fb = query.transpose(1, 0) + py_out = multi_head_attn_nn(query_fb, key_fb, value_fb)[0].transpose(1,0) + self.assertEqual(fn_out.shape, py_out.shape) + def test_torchscript_multi_head_attn_fast_path(self): src_l = 3 bsz = 5 diff --git a/torch/nn/functional.py b/torch/nn/functional.py index 961dd83f57b2..811d9529e752 100644 --- a/torch/nn/functional.py +++ b/torch/nn/functional.py @@ -4864,6 +4864,10 @@ def _mha_shape_check(query: Tensor, key: Tensor, value: Tensor, # Raises an error if `query` is not 2-D (unbatched) or 3-D (batched) tensor. # Shape check. 
+ assert query.dim() == 3 or query.dim() == 2, \ + (f"query should be unbatched 2D or batched 3D tensor but received {query.dim()}-D query tensor") + + is_batched = False if query.dim() == 3: # Batched Inputs is_batched = True @@ -4898,9 +4902,6 @@ def _mha_shape_check(query: Tensor, key: Tensor, value: Tensor, expected_shape = (num_heads, query.shape[0], key.shape[0]) assert attn_mask.shape == expected_shape, \ (f"Expected `attn_mask` shape to be {expected_shape} but got {attn_mask.shape}") - else: - raise AssertionError( - f"query should be unbatched 2D or batched 3D tensor but received {query.dim()}-D query tensor") return is_batched @@ -4992,6 +4993,237 @@ def multi_head_attention_forward( :math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per head of shape :math:`(num_heads, L, S)` when input is unbatched or :math:`(N, num_heads, L, S)`. """ + with torch.no_grad(): + in_proj_weight_ = torch.nn.init.xavier_uniform_(torch.empty((3 * embed_dim_to_check, embed_dim_to_check),\ + device=query.device, dtype=query.dtype))\ + if in_proj_weight is None or in_proj_weight.dim() != 2 else in_proj_weight + with torch.no_grad(): + in_proj_bias_ = torch.zeros(3 * embed_dim_to_check, device=query.device, dtype=query.dtype) \ + if in_proj_bias is None else in_proj_bias + + with torch.no_grad(): + out_proj_bias_ = torch.zeros(embed_dim_to_check, device=query.device, dtype=query.dtype) \ + if out_proj_bias is None else out_proj_bias + why_not_fast_path = _can_use_fastpath(query, + key, + value, + embed_dim_to_check, + num_heads, + in_proj_weight_, + in_proj_bias_, + bias_k, + bias_v, + add_zero_attn, + dropout_p, + out_proj_weight, + out_proj_bias_, + training, + key_padding_mask, + attn_mask) + if not why_not_fast_path: + + if key is value: + if query is key: + query = key = value = query.transpose(1, 0) + else: + query, key = [x.transpose(1, 0) for x in (query, key)] + value = key + else: + query, key, value = [x.transpose(1, 0) for x in (query, key, value)] + + merged_mask, mask_type = _merge_masks(attn_mask, key_padding_mask, query, num_heads) + attn_output, attn_output_weights = torch._native_multi_head_attention( + query, + key, + value, + embed_dim_to_check, + num_heads, + in_proj_weight_, + in_proj_bias_, + out_proj_weight, + out_proj_bias_, + merged_mask, + need_weights, + average_attn_weights, + mask_type) + return attn_output.transpose(1, 0), attn_output_weights + + any_nested = query.is_nested or key.is_nested or value.is_nested + assert not any_nested, ("MultiheadAttention does not support NestedTensor outside of its fast path. " + + f"The fast path was not hit because {why_not_fast_path}") + + return _multi_head_attention_forward_impl(query, + key, + value, + embed_dim_to_check, + num_heads, + in_proj_weight, + in_proj_bias, + bias_k, + bias_v, + add_zero_attn, + dropout_p, + out_proj_weight, + out_proj_bias, + training, + key_padding_mask, + need_weights, + attn_mask, + use_separate_proj_weight, + q_proj_weight, + k_proj_weight, + v_proj_weight, + static_k, + static_v, + average_attn_weights) + +def _merge_masks(attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor], + query: Tensor, num_heads: int,) -> Tuple[Optional[Tensor], Optional[int]]: + r""" + Determine mask type and combine masks if necessary. If only one mask is provided, that mask + and the corresponding mask type will be returned. 
If both masks are provided, they will be both + expanded to shape ``(batch_size, num_heads, seq_len, seq_len)``, combined with logical ``or`` + and mask type 2 will be returned + Args: + attn_mask: attention mask of shape ``(seq_len, seq_len)``, mask type 0 + key_padding_mask: padding mask of shape ``(batch_size, seq_len)``, mask type 1 + query: query embeddings of shape ``(batch_size, seq_len, embed_dim)`` + Returns: + merged_mask: merged mask + mask_type: merged mask type (0, 1, or 2) + """ + mask_type: Optional[int] = None + merged_mask: Optional[Tensor] = None + if attn_mask is not None: + mask_type = 0 + merged_mask = attn_mask + if key_padding_mask is not None: + mask_type = 1 + merged_mask = key_padding_mask + if (attn_mask is not None) and (key_padding_mask is not None): + # In this branch query can't be a nested tensor, so it has a shape + batch_size, seq_len, _ = query.shape + mask_type = 2 + key_padding_mask_expanded = key_padding_mask.view(batch_size, 1, 1, seq_len) \ + .expand(-1, num_heads, -1, -1) + attn_mask_expanded = attn_mask.view(1, 1, seq_len, seq_len).expand(batch_size, num_heads, -1, -1) + merged_mask = attn_mask_expanded.logical_or(key_padding_mask_expanded) + return merged_mask, mask_type + +def _can_use_fastpath( + query: Tensor, + key: Tensor, + value: Tensor, + embed_dim_to_check: int, + num_heads: int, + in_proj_weight: Tensor, + in_proj_bias: Tensor, + bias_k: Optional[Tensor], + bias_v: Optional[Tensor], + add_zero_attn: bool, + dropout_p: float, + out_proj_weight: Tensor, + out_proj_bias: Tensor, + training: bool = True, + key_padding_mask: Optional[Tensor] = None, + attn_mask: Optional[Tensor] = None, + +): + why_not_fast_path = '' + if torch.overrides.has_torch_function(( + query, + key, + value, + in_proj_weight, + in_proj_bias, + out_proj_weight, + out_proj_bias,)): + why_not_fast_path = "some Tensor argument has_torch_function" + else: + is_batched = query.dim() == 3 + if not is_batched: + why_not_fast_path = f"input not batched; expected query.dim() of 3 but got {query.dim()}" + elif query is not key or key is not value: + # When lifting this restriction, don't forget to either + # enforce that the dtypes all match or test cases where + # they don't! + why_not_fast_path = "non-self attention was used (query, key, and value are not the same Tensor)" + elif in_proj_bias is not None and query.dtype != in_proj_bias.dtype: + why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_bias ({in_proj_bias.dtype}) don't match" + elif in_proj_weight is not None and query.dtype != in_proj_weight.dtype: + # this case will fail anyway, but at least they'll get a useful error message. 
+ why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_weight ({in_proj_weight.dtype}) don't match" + elif training: + why_not_fast_path = "training is enabled" + #elif not self.batch_first: + # why_not_fast_path = "batch_first was not True" + elif bias_k is not None: + why_not_fast_path = "self.bias_k was not None" + elif bias_v is not None: + why_not_fast_path = "self.bias_v was not None" + elif dropout_p: + why_not_fast_path = f"dropout was {dropout_p}, required zero" + elif add_zero_attn: + why_not_fast_path = "add_zero_attn was enabled" + #elif not self._qkv_same_embed_dim: + # why_not_fast_path = "_qkv_same_embed_dim was not True" + elif query.is_nested and (key_padding_mask is not None or attn_mask is not None): + why_not_fast_path = "supplying both src_key_padding_mask and src_mask at the same time \ + is not supported with NestedTensor input" + elif torch.is_autocast_enabled(): + why_not_fast_path = "autocast is enabled" + + if not why_not_fast_path: + tensor_args = ( + query, + key, + value, + in_proj_weight, + in_proj_bias, + out_proj_weight, + out_proj_bias, + ) + # We have to use list comprehensions below because TorchScript does not support + # generator expressions. + if torch.overrides.has_torch_function(tensor_args): + why_not_fast_path = "some Tensor argument has_torch_function" + elif not all([(x.is_cuda or 'cpu' in str(x.device)) for x in tensor_args]): + why_not_fast_path = "some Tensor argument is neither CUDA nor CPU" + elif torch.is_grad_enabled() and any([x.requires_grad for x in tensor_args]): + why_not_fast_path = ("grad is enabled and at least one of query or the " + "input/output projection weights or biases requires_grad") + return why_not_fast_path + +def _multi_head_attention_forward_impl( + query: Tensor, + key: Tensor, + value: Tensor, + embed_dim_to_check: int, + num_heads: int, + in_proj_weight: Optional[Tensor], + in_proj_bias: Optional[Tensor], + bias_k: Optional[Tensor], + bias_v: Optional[Tensor], + add_zero_attn: bool, + dropout_p: float, + out_proj_weight: Tensor, + out_proj_bias: Optional[Tensor], + training: bool = True, + key_padding_mask: Optional[Tensor] = None, + need_weights: bool = True, + attn_mask: Optional[Tensor] = None, + use_separate_proj_weight: bool = False, + q_proj_weight: Optional[Tensor] = None, + k_proj_weight: Optional[Tensor] = None, + v_proj_weight: Optional[Tensor] = None, + static_k: Optional[Tensor] = None, + static_v: Optional[Tensor] = None, + average_attn_weights: bool = True, +) -> Tuple[Tensor, Optional[Tensor]]: + r""" + Please, refer to multi_head_attention_forward documentation. Implementation of multihead_attention without fast path. + """ + tens_ops = (query, key, value, in_proj_weight, in_proj_bias, bias_k, bias_v, out_proj_weight, out_proj_bias) if has_torch_function(tens_ops): return handle_torch_function( diff --git a/torch/nn/modules/activation.py b/torch/nn/modules/activation.py index 7b0e7e3effaa..fc145bacce55 100644 --- a/torch/nn/modules/activation.py +++ b/torch/nn/modules/activation.py @@ -1046,6 +1046,57 @@ def forward(self, query: Tensor, key: Tensor, value: Tensor, key_padding_mask: O heads. Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an effect when ``need_weights=True``. Default: ``True`` (i.e. 
average weights across heads) + Outputs: + - **attn_output** - Attention outputs of shape :math:`(L, E)` when input is unbatched, + :math:`(L, N, E)` when ``batch_first=False`` or :math:`(N, L, E)` when ``batch_first=True``, + where :math:`L` is the target sequence length, :math:`N` is the batch size, and :math:`E` is the + embedding dimension ``embed_dim``. + - **attn_output_weights** - Only returned when ``need_weights=True``. If ``average_attn_weights=True``, + returns attention weights averaged across heads of shape :math:`(L, S)` when input is unbatched or + :math:`(N, L, S)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, and + :math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per + head of shape :math:`(\text{num\_heads}, L, S)` when input is unbatched or :math:`(N, \text{num\_heads}, L, S)`. + + .. note:: + `batch_first` argument is ignored for unbatched inputs. + """ + def forward(self, query: Tensor, key: Tensor, value: Tensor, key_padding_mask: Optional[Tensor] = None, + need_weights: bool = True, attn_mask: Optional[Tensor] = None, + average_attn_weights: bool = True) -> Tuple[Tensor, Optional[Tensor]]: + r""" + Args: + query: Query embeddings of shape :math:`(L, E_q)` for unbatched input, :math:`(L, N, E_q)` when ``batch_first=False`` + or :math:`(N, L, E_q)` when ``batch_first=True``, where :math:`L` is the target sequence length, + :math:`N` is the batch size, and :math:`E_q` is the query embedding dimension ``embed_dim``. + Queries are compared against key-value pairs to produce the output. + See "Attention Is All You Need" for more details. + key: Key embeddings of shape :math:`(S, E_k)` for unbatched input, :math:`(S, N, E_k)` when ``batch_first=False`` + or :math:`(N, S, E_k)` when ``batch_first=True``, where :math:`S` is the source sequence length, + :math:`N` is the batch size, and :math:`E_k` is the key embedding dimension ``kdim``. + See "Attention Is All You Need" for more details. + value: Value embeddings of shape :math:`(S, E_v)` for unbatched input, :math:`(S, N, E_v)` when + ``batch_first=False`` or :math:`(N, S, E_v)` when ``batch_first=True``, where :math:`S` is the source + sequence length, :math:`N` is the batch size, and :math:`E_v` is the value embedding dimension ``vdim``. + See "Attention Is All You Need" for more details. + key_padding_mask: If specified, a mask of shape :math:`(N, S)` indicating which elements within ``key`` + to ignore for the purpose of attention (i.e. treat as "padding"). For unbatched `query`, shape should be :math:`(S)`. + Binary and byte masks are supported. + For a binary mask, a ``True`` value indicates that the corresponding ``key`` value will be ignored for + the purpose of attention. For a float mask, it will be directly added to the corresponding ``key`` value. + need_weights: If specified, returns ``attn_output_weights`` in addition to ``attn_outputs``. + Default: ``True``. + attn_mask: If specified, a 2D or 3D mask preventing attention to certain positions. Must be of shape + :math:`(L, S)` or :math:`(N\cdot\text{num\_heads}, L, S)`, where :math:`N` is the batch size, + :math:`L` is the target sequence length, and :math:`S` is the source sequence length. A 2D mask will be + broadcasted across the batch while a 3D mask allows for a different mask for each entry in the batch. + Binary, byte, and float masks are supported. For a binary mask, a ``True`` value indicates that the + corresponding position is not allowed to attend. 
For a byte mask, a non-zero value indicates that the + corresponding position is not allowed to attend. For a float mask, the mask values will be added to + the attention weight. + average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across + heads. Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an + effect when ``need_weights=True``. Default: ``True`` (i.e. average weights across heads) + Outputs: - **attn_output** - Attention outputs of shape :math:`(L, E)` when input is unbatched, :math:`(L, N, E)` when ``batch_first=False`` or :math:`(N, L, E)` when ``batch_first=True``, @@ -1066,59 +1117,29 @@ def forward(self, query: Tensor, key: Tensor, value: Tensor, key_padding_mask: O if _kpm_dtype != torch.bool and not torch.is_floating_point(key_padding_mask): raise AssertionError( "only bool and floating types of key_padding_mask are supported") - why_not_fast_path = '' - if not is_batched: - why_not_fast_path = f"input not batched; expected query.dim() of 3 but got {query.dim()}" - elif query is not key or key is not value: - # When lifting this restriction, don't forget to either - # enforce that the dtypes all match or test cases where - # they don't! - why_not_fast_path = "non-self attention was used (query, key, and value are not the same Tensor)" - elif self.in_proj_bias is not None and query.dtype != self.in_proj_bias.dtype: - why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_bias ({self.in_proj_bias.dtype}) don't match" - elif self.in_proj_weight is not None and query.dtype != self.in_proj_weight.dtype: - # this case will fail anyway, but at least they'll get a useful error message. - why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_weight ({self.in_proj_weight.dtype}) don't match" - elif self.training: - why_not_fast_path = "training is enabled" - elif not self.batch_first: - why_not_fast_path = "batch_first was not True" - elif self.bias_k is not None: - why_not_fast_path = "self.bias_k was not None" - elif self.bias_v is not None: - why_not_fast_path = "self.bias_v was not None" - elif self.add_zero_attn: - why_not_fast_path = "add_zero_attn was enabled" - elif not self._qkv_same_embed_dim: - why_not_fast_path = "_qkv_same_embed_dim was not True" - elif query.is_nested and (key_padding_mask is not None or attn_mask is not None): - why_not_fast_path = "supplying both src_key_padding_mask and src_mask at the same time \ - is not supported with NestedTensor input" - elif torch.is_autocast_enabled(): - why_not_fast_path = "autocast is enabled" - + why_not_fast_path = F._can_use_fastpath(query, + key, + value, + self.embed_dim, + self.num_heads, + self.in_proj_weight, + self.in_proj_bias, + self.bias_k, + self.bias_v, + self.add_zero_attn, + self.dropout, + self.out_proj.weight, + self.out_proj.bias, + self.training, + key_padding_mask, + attn_mask) if not why_not_fast_path: - tensor_args = ( - query, - key, - value, - self.in_proj_weight, - self.in_proj_bias, - self.out_proj.weight, - self.out_proj.bias, - ) - # We have to use list comprehensions below because TorchScript does not support - # generator expressions. 
- if torch.overrides.has_torch_function(tensor_args): - why_not_fast_path = "some Tensor argument has_torch_function" - elif not all([(x.is_cuda or 'cpu' in str(x.device)) for x in tensor_args]): - why_not_fast_path = "some Tensor argument is neither CUDA nor CPU" - elif torch.is_grad_enabled() and any([x.requires_grad for x in tensor_args]): - why_not_fast_path = ("grad is enabled and at least one of query or the " - "input/output projection weights or biases requires_grad") - if not why_not_fast_path: - merged_mask, mask_type = self.merge_masks(attn_mask, key_padding_mask, query) - + if not self.batch_first: + why_not_fast_path = "batch_first was not True" + elif not self._qkv_same_embed_dim: + why_not_fast_path = "_qkv_same_embed_dim was not True" + else: + merged_mask, mask_type = F._merge_masks(attn_mask, key_padding_mask, query, self.num_heads) return torch._native_multi_head_attention( query, key, @@ -1150,7 +1171,7 @@ def forward(self, query: Tensor, key: Tensor, value: Tensor, key_padding_mask: O query, key, value = [x.transpose(1, 0) for x in (query, key, value)] if not self._qkv_same_embed_dim: - attn_output, attn_output_weights = F.multi_head_attention_forward( + attn_output, attn_output_weights = F._multi_head_attention_forward_impl( query, key, value, self.embed_dim, self.num_heads, self.in_proj_weight, self.in_proj_bias, self.bias_k, self.bias_v, self.add_zero_attn, @@ -1161,7 +1182,7 @@ def forward(self, query: Tensor, key: Tensor, value: Tensor, key_padding_mask: O q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight, v_proj_weight=self.v_proj_weight, average_attn_weights=average_attn_weights) else: - attn_output, attn_output_weights = F.multi_head_attention_forward( + attn_output, attn_output_weights = F._multi_head_attention_forward_impl( query, key, value, self.embed_dim, self.num_heads, self.in_proj_weight, self.in_proj_bias, self.bias_k, self.bias_v, self.add_zero_attn, @@ -1174,39 +1195,6 @@ def forward(self, query: Tensor, key: Tensor, value: Tensor, key_padding_mask: O else: return attn_output, attn_output_weights - def merge_masks(self, attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor], - query: Tensor) -> Tuple[Optional[Tensor], Optional[int]]: - r""" - Determine mask type and combine masks if necessary. If only one mask is provided, that mask - and the corresponding mask type will be returned. 
If both masks are provided, they will be both - expanded to shape ``(batch_size, num_heads, seq_len, seq_len)``, combined with logical ``or`` - and mask type 2 will be returned - Args: - attn_mask: attention mask of shape ``(seq_len, seq_len)``, mask type 0 - key_padding_mask: padding mask of shape ``(batch_size, seq_len)``, mask type 1 - query: query embeddings of shape ``(batch_size, seq_len, embed_dim)`` - Returns: - merged_mask: merged mask - mask_type: merged mask type (0, 1, or 2) - """ - mask_type: Optional[int] = None - merged_mask: Optional[Tensor] = None - if attn_mask is not None: - mask_type = 0 - merged_mask = attn_mask - if key_padding_mask is not None: - mask_type = 1 - merged_mask = key_padding_mask - if (attn_mask is not None) and (key_padding_mask is not None): - # In this branch query can't be a nested tensor, so it has a shape - batch_size, seq_len, _ = query.shape - mask_type = 2 - key_padding_mask_expanded = key_padding_mask.view(batch_size, 1, 1, seq_len) \ - .expand(-1, self.num_heads, -1, -1) - attn_mask_expanded = attn_mask.view(1, 1, seq_len, seq_len).expand(batch_size, self.num_heads, -1, -1) - merged_mask = attn_mask_expanded.logical_or(key_padding_mask_expanded) - return merged_mask, mask_type - class PReLU(Module): r"""Applies the element-wise function: From 2d44588167c99ae58c95863a10476438bcdcebae Mon Sep 17 00:00:00 2001 From: Olga Gerasimova Date: Mon, 14 Nov 2022 17:46:27 +0100 Subject: [PATCH 02/12] add null check in_proj, out_proj --- test/test_jit.py | 4 ++-- torch/nn/functional.py | 41 +++++++++++++++------------------- torch/nn/modules/activation.py | 2 +- 3 files changed, 21 insertions(+), 26 deletions(-) diff --git a/test/test_jit.py b/test/test_jit.py index 23556f9bded8..0e9e2f179256 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -15106,8 +15106,8 @@ def test_functional_multi_head_attn_fast_path_None(self): with torch.no_grad(): fn_out = torch.nn.functional.multi_head_attention_forward(query, key, value, embed_size, nhead, - None, - None, + multi_head_attn_nn.in_proj_weight, + multi_head_attn_nn.in_proj_bias, multi_head_attn_nn.bias_k, multi_head_attn_nn.bias_v, multi_head_attn_nn.add_zero_attn, diff --git a/torch/nn/functional.py b/torch/nn/functional.py index 811d9529e752..5b4319f78d6a 100644 --- a/torch/nn/functional.py +++ b/torch/nn/functional.py @@ -2524,8 +2524,6 @@ def group_norm( """ if has_torch_function_variadic(input, weight, bias): return handle_torch_function(group_norm, (input, weight, bias,), input, num_groups, weight=weight, bias=bias, eps=eps) - if input.dim() < 2: - raise RuntimeError(f"Expected at least 2 dimensions for input tensor but received {input.dim()}") _verify_batch_size([input.size(0) * input.size(1) // num_groups, num_groups] + list(input.size()[2:])) return torch.group_norm(input, num_groups, weight, bias, eps, torch.backends.cudnn.enabled) @@ -4993,30 +4991,20 @@ def multi_head_attention_forward( :math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per head of shape :math:`(num_heads, L, S)` when input is unbatched or :math:`(N, num_heads, L, S)`. 
""" - with torch.no_grad(): - in_proj_weight_ = torch.nn.init.xavier_uniform_(torch.empty((3 * embed_dim_to_check, embed_dim_to_check),\ - device=query.device, dtype=query.dtype))\ - if in_proj_weight is None or in_proj_weight.dim() != 2 else in_proj_weight - with torch.no_grad(): - in_proj_bias_ = torch.zeros(3 * embed_dim_to_check, device=query.device, dtype=query.dtype) \ - if in_proj_bias is None else in_proj_bias - with torch.no_grad(): - out_proj_bias_ = torch.zeros(embed_dim_to_check, device=query.device, dtype=query.dtype) \ - if out_proj_bias is None else out_proj_bias why_not_fast_path = _can_use_fastpath(query, key, value, embed_dim_to_check, num_heads, - in_proj_weight_, - in_proj_bias_, + in_proj_weight, + in_proj_bias, bias_k, bias_v, add_zero_attn, dropout_p, out_proj_weight, - out_proj_bias_, + out_proj_bias, training, key_padding_mask, attn_mask) @@ -5038,10 +5026,10 @@ def multi_head_attention_forward( value, embed_dim_to_check, num_heads, - in_proj_weight_, - in_proj_bias_, + in_proj_weight, + in_proj_bias, out_proj_weight, - out_proj_bias_, + out_proj_bias, merged_mask, need_weights, average_attn_weights, @@ -5116,14 +5104,14 @@ def _can_use_fastpath( value: Tensor, embed_dim_to_check: int, num_heads: int, - in_proj_weight: Tensor, - in_proj_bias: Tensor, + in_proj_weight: Optional[Tensor], + in_proj_bias: Optional[Tensor], bias_k: Optional[Tensor], bias_v: Optional[Tensor], add_zero_attn: bool, dropout_p: float, out_proj_weight: Tensor, - out_proj_bias: Tensor, + out_proj_bias: Optional[Tensor], training: bool = True, key_padding_mask: Optional[Tensor] = None, attn_mask: Optional[Tensor] = None, @@ -5183,13 +5171,20 @@ def _can_use_fastpath( out_proj_weight, out_proj_bias, ) + # We have to use list comprehensions below because TorchScript does not support # generator expressions. if torch.overrides.has_torch_function(tensor_args): why_not_fast_path = "some Tensor argument has_torch_function" - elif not all([(x.is_cuda or 'cpu' in str(x.device)) for x in tensor_args]): + elif in_proj_weight is None: + why_not_fast_path = "in projection weight is None" + elif in_proj_bias is None: + why_not_fast_path = "in projection bias is None" + elif out_proj_bias is None: + why_not_fast_path = "out projecion bias is None" + elif not all([x is not None and (x.is_cuda or 'cpu' in str(x.device)) for x in tensor_args]): why_not_fast_path = "some Tensor argument is neither CUDA nor CPU" - elif torch.is_grad_enabled() and any([x.requires_grad for x in tensor_args]): + elif torch.is_grad_enabled() and any([x is not None and x.requires_grad for x in tensor_args]): why_not_fast_path = ("grad is enabled and at least one of query or the " "input/output projection weights or biases requires_grad") return why_not_fast_path diff --git a/torch/nn/modules/activation.py b/torch/nn/modules/activation.py index fc145bacce55..dc8a4fb2fb72 100644 --- a/torch/nn/modules/activation.py +++ b/torch/nn/modules/activation.py @@ -1060,7 +1060,7 @@ def forward(self, query: Tensor, key: Tensor, value: Tensor, key_padding_mask: O .. note:: `batch_first` argument is ignored for unbatched inputs. 
""" - def forward(self, query: Tensor, key: Tensor, value: Tensor, key_padding_mask: Optional[Tensor] = None, + def forward(self, query: Tensor, key: Tensor, value: Tensor, key_padding_mask: Optional[Tensor] = None, need_weights: bool = True, attn_mask: Optional[Tensor] = None, average_attn_weights: bool = True) -> Tuple[Tensor, Optional[Tensor]]: r""" From b55b16d21acad46a56b2c420b17f9a2b98e93262 Mon Sep 17 00:00:00 2001 From: Olga Gerasimova Date: Mon, 14 Nov 2022 18:40:30 +0100 Subject: [PATCH 03/12] use merge masks from functional --- torch/nn/modules/transformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/nn/modules/transformer.py b/torch/nn/modules/transformer.py index 37e8823edf2c..6d70c51772f9 100644 --- a/torch/nn/modules/transformer.py +++ b/torch/nn/modules/transformer.py @@ -500,7 +500,7 @@ def forward(self, src: Tensor, src_mask: Optional[Tensor] = None, "input/output projection weights or biases requires_grad") if not why_not_sparsity_fast_path: - merged_mask, mask_type = self.self_attn.merge_masks(src_mask, src_key_padding_mask, src) + merged_mask, mask_type = F.merge_masks(src_mask, src_key_padding_mask, src) return torch._transformer_encoder_layer_fwd( src, self.self_attn.embed_dim, From b42f855b7099897d145c29740405ec68f614a9ab Mon Sep 17 00:00:00 2001 From: Olga Gerasimova Date: Mon, 14 Nov 2022 17:46:27 +0100 Subject: [PATCH 04/12] add null check in_proj, out_proj --- test/test_jit.py | 4 ++-- torch/nn/functional.py | 41 +++++++++++++++------------------- torch/nn/modules/activation.py | 2 +- 3 files changed, 21 insertions(+), 26 deletions(-) diff --git a/test/test_jit.py b/test/test_jit.py index 23556f9bded8..0e9e2f179256 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -15106,8 +15106,8 @@ def test_functional_multi_head_attn_fast_path_None(self): with torch.no_grad(): fn_out = torch.nn.functional.multi_head_attention_forward(query, key, value, embed_size, nhead, - None, - None, + multi_head_attn_nn.in_proj_weight, + multi_head_attn_nn.in_proj_bias, multi_head_attn_nn.bias_k, multi_head_attn_nn.bias_v, multi_head_attn_nn.add_zero_attn, diff --git a/torch/nn/functional.py b/torch/nn/functional.py index 811d9529e752..5b4319f78d6a 100644 --- a/torch/nn/functional.py +++ b/torch/nn/functional.py @@ -2524,8 +2524,6 @@ def group_norm( """ if has_torch_function_variadic(input, weight, bias): return handle_torch_function(group_norm, (input, weight, bias,), input, num_groups, weight=weight, bias=bias, eps=eps) - if input.dim() < 2: - raise RuntimeError(f"Expected at least 2 dimensions for input tensor but received {input.dim()}") _verify_batch_size([input.size(0) * input.size(1) // num_groups, num_groups] + list(input.size()[2:])) return torch.group_norm(input, num_groups, weight, bias, eps, torch.backends.cudnn.enabled) @@ -4993,30 +4991,20 @@ def multi_head_attention_forward( :math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per head of shape :math:`(num_heads, L, S)` when input is unbatched or :math:`(N, num_heads, L, S)`. 
""" - with torch.no_grad(): - in_proj_weight_ = torch.nn.init.xavier_uniform_(torch.empty((3 * embed_dim_to_check, embed_dim_to_check),\ - device=query.device, dtype=query.dtype))\ - if in_proj_weight is None or in_proj_weight.dim() != 2 else in_proj_weight - with torch.no_grad(): - in_proj_bias_ = torch.zeros(3 * embed_dim_to_check, device=query.device, dtype=query.dtype) \ - if in_proj_bias is None else in_proj_bias - with torch.no_grad(): - out_proj_bias_ = torch.zeros(embed_dim_to_check, device=query.device, dtype=query.dtype) \ - if out_proj_bias is None else out_proj_bias why_not_fast_path = _can_use_fastpath(query, key, value, embed_dim_to_check, num_heads, - in_proj_weight_, - in_proj_bias_, + in_proj_weight, + in_proj_bias, bias_k, bias_v, add_zero_attn, dropout_p, out_proj_weight, - out_proj_bias_, + out_proj_bias, training, key_padding_mask, attn_mask) @@ -5038,10 +5026,10 @@ def multi_head_attention_forward( value, embed_dim_to_check, num_heads, - in_proj_weight_, - in_proj_bias_, + in_proj_weight, + in_proj_bias, out_proj_weight, - out_proj_bias_, + out_proj_bias, merged_mask, need_weights, average_attn_weights, @@ -5116,14 +5104,14 @@ def _can_use_fastpath( value: Tensor, embed_dim_to_check: int, num_heads: int, - in_proj_weight: Tensor, - in_proj_bias: Tensor, + in_proj_weight: Optional[Tensor], + in_proj_bias: Optional[Tensor], bias_k: Optional[Tensor], bias_v: Optional[Tensor], add_zero_attn: bool, dropout_p: float, out_proj_weight: Tensor, - out_proj_bias: Tensor, + out_proj_bias: Optional[Tensor], training: bool = True, key_padding_mask: Optional[Tensor] = None, attn_mask: Optional[Tensor] = None, @@ -5183,13 +5171,20 @@ def _can_use_fastpath( out_proj_weight, out_proj_bias, ) + # We have to use list comprehensions below because TorchScript does not support # generator expressions. if torch.overrides.has_torch_function(tensor_args): why_not_fast_path = "some Tensor argument has_torch_function" - elif not all([(x.is_cuda or 'cpu' in str(x.device)) for x in tensor_args]): + elif in_proj_weight is None: + why_not_fast_path = "in projection weight is None" + elif in_proj_bias is None: + why_not_fast_path = "in projection bias is None" + elif out_proj_bias is None: + why_not_fast_path = "out projecion bias is None" + elif not all([x is not None and (x.is_cuda or 'cpu' in str(x.device)) for x in tensor_args]): why_not_fast_path = "some Tensor argument is neither CUDA nor CPU" - elif torch.is_grad_enabled() and any([x.requires_grad for x in tensor_args]): + elif torch.is_grad_enabled() and any([x is not None and x.requires_grad for x in tensor_args]): why_not_fast_path = ("grad is enabled and at least one of query or the " "input/output projection weights or biases requires_grad") return why_not_fast_path diff --git a/torch/nn/modules/activation.py b/torch/nn/modules/activation.py index fc145bacce55..dc8a4fb2fb72 100644 --- a/torch/nn/modules/activation.py +++ b/torch/nn/modules/activation.py @@ -1060,7 +1060,7 @@ def forward(self, query: Tensor, key: Tensor, value: Tensor, key_padding_mask: O .. note:: `batch_first` argument is ignored for unbatched inputs. 
""" - def forward(self, query: Tensor, key: Tensor, value: Tensor, key_padding_mask: Optional[Tensor] = None, + def forward(self, query: Tensor, key: Tensor, value: Tensor, key_padding_mask: Optional[Tensor] = None, need_weights: bool = True, attn_mask: Optional[Tensor] = None, average_attn_weights: bool = True) -> Tuple[Tensor, Optional[Tensor]]: r""" From 8ee3da1d8743246a13782cadb38ad24f3ac63972 Mon Sep 17 00:00:00 2001 From: Olga Gerasimova Date: Mon, 14 Nov 2022 18:40:30 +0100 Subject: [PATCH 05/12] use merge masks from functional --- torch/nn/modules/transformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/nn/modules/transformer.py b/torch/nn/modules/transformer.py index 37e8823edf2c..6d70c51772f9 100644 --- a/torch/nn/modules/transformer.py +++ b/torch/nn/modules/transformer.py @@ -500,7 +500,7 @@ def forward(self, src: Tensor, src_mask: Optional[Tensor] = None, "input/output projection weights or biases requires_grad") if not why_not_sparsity_fast_path: - merged_mask, mask_type = self.self_attn.merge_masks(src_mask, src_key_padding_mask, src) + merged_mask, mask_type = F.merge_masks(src_mask, src_key_padding_mask, src) return torch._transformer_encoder_layer_fwd( src, self.self_attn.embed_dim, From ee378ee968232b143dc33b8e3421dbb1924d17d4 Mon Sep 17 00:00:00 2001 From: Olga Gerasimova Date: Mon, 14 Nov 2022 20:10:16 +0100 Subject: [PATCH 06/12] merge from master --- torch/nn/functional.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torch/nn/functional.py b/torch/nn/functional.py index 5b4319f78d6a..124b705dd74f 100644 --- a/torch/nn/functional.py +++ b/torch/nn/functional.py @@ -2524,6 +2524,8 @@ def group_norm( """ if has_torch_function_variadic(input, weight, bias): return handle_torch_function(group_norm, (input, weight, bias,), input, num_groups, weight=weight, bias=bias, eps=eps) + if input.dim() < 2: + raise RuntimeError(f"Expected at least 2 dimensions for input tensor but received {input.dim()}") _verify_batch_size([input.size(0) * input.size(1) // num_groups, num_groups] + list(input.size()[2:])) return torch.group_norm(input, num_groups, weight, bias, eps, torch.backends.cudnn.enabled) From fb355b0e36db1dd1b86889d91ae6fa2408c4a291 Mon Sep 17 00:00:00 2001 From: Olga Gerasimova Date: Mon, 14 Nov 2022 21:34:43 +0100 Subject: [PATCH 07/12] _merge_masks --- torch/nn/modules/transformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/nn/modules/transformer.py b/torch/nn/modules/transformer.py index 6d70c51772f9..9cb67365935c 100644 --- a/torch/nn/modules/transformer.py +++ b/torch/nn/modules/transformer.py @@ -500,7 +500,7 @@ def forward(self, src: Tensor, src_mask: Optional[Tensor] = None, "input/output projection weights or biases requires_grad") if not why_not_sparsity_fast_path: - merged_mask, mask_type = F.merge_masks(src_mask, src_key_padding_mask, src) + merged_mask, mask_type = F._merge_masks(src_mask, src_key_padding_mask, src) return torch._transformer_encoder_layer_fwd( src, self.self_attn.embed_dim, From 00f1a76ca4dfa4a1545e1a54c710c26df5dc3248 Mon Sep 17 00:00:00 2001 From: Olga Gerasimova Date: Mon, 14 Nov 2022 22:56:23 +0100 Subject: [PATCH 08/12] missing num_heads argument fix --- torch/nn/modules/transformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/nn/modules/transformer.py b/torch/nn/modules/transformer.py index 9cb67365935c..5fc3882c0e48 100644 --- a/torch/nn/modules/transformer.py +++ b/torch/nn/modules/transformer.py @@ -500,7 +500,7 @@ def forward(self, src: 
Tensor, src_mask: Optional[Tensor] = None, "input/output projection weights or biases requires_grad") if not why_not_sparsity_fast_path: - merged_mask, mask_type = F._merge_masks(src_mask, src_key_padding_mask, src) + merged_mask, mask_type = F._merge_masks(src_mask, src_key_padding_mask, src, self.self_attn.num_heads) return torch._transformer_encoder_layer_fwd( src, self.self_attn.embed_dim, From d840579ebeb1e5d1628520c57563d86b4f9aee88 Mon Sep 17 00:00:00 2001 From: Olga Gerasimova Date: Tue, 15 Nov 2022 09:35:27 +0100 Subject: [PATCH 09/12] cast optional to tensor --- torch/nn/functional.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/torch/nn/functional.py b/torch/nn/functional.py index 124b705dd74f..c10e050d9362 100644 --- a/torch/nn/functional.py +++ b/torch/nn/functional.py @@ -23,6 +23,7 @@ from . import grad # noqa: F401 from .modules import utils from .modules.utils import _single, _pair, _triple, _list_with_default +from typing import cast Tensor = torch.Tensor @@ -5028,10 +5029,10 @@ def multi_head_attention_forward( value, embed_dim_to_check, num_heads, - in_proj_weight, - in_proj_bias, + cast(torch.Tensor, in_proj_weight), + cast(torch.Tensor, in_proj_bias), out_proj_weight, - out_proj_bias, + cast(torch.Tensor, out_proj_bias), merged_mask, need_weights, average_attn_weights, From 06cb1a2b875ada68215172d42dfb9f10673188e7 Mon Sep 17 00:00:00 2001 From: Olga Gerasimova Date: Tue, 15 Nov 2022 23:09:23 +0100 Subject: [PATCH 10/12] casting Optional as Tensor --- torch/nn/functional.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/torch/nn/functional.py b/torch/nn/functional.py index c10e050d9362..020fa5ece0bf 100644 --- a/torch/nn/functional.py +++ b/torch/nn/functional.py @@ -5023,16 +5023,19 @@ def multi_head_attention_forward( query, key, value = [x.transpose(1, 0) for x in (query, key, value)] merged_mask, mask_type = _merge_masks(attn_mask, key_padding_mask, query, num_heads) + in_proj_weight_ = cast(torch.Tensor, in_proj_weight) + in_proj_bias_ = cast(torch.Tensor, in_proj_bias) + out_proj_bias_ = cast(torch.Tensor, out_proj_bias) attn_output, attn_output_weights = torch._native_multi_head_attention( query, key, value, embed_dim_to_check, num_heads, - cast(torch.Tensor, in_proj_weight), - cast(torch.Tensor, in_proj_bias), + in_proj_weight_, + in_proj_bias_, out_proj_weight, - cast(torch.Tensor, out_proj_bias), + out_proj_bias_, merged_mask, need_weights, average_attn_weights, From 9854d6bfa319e37b9b45dd6a23efabf2d86ae283 Mon Sep 17 00:00:00 2001 From: Olga Gerasimova Date: Wed, 16 Nov 2022 09:37:45 +0100 Subject: [PATCH 11/12] assertion Optional parameters not None --- torch/nn/functional.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/torch/nn/functional.py b/torch/nn/functional.py index 020fa5ece0bf..b7a5b72ff224 100644 --- a/torch/nn/functional.py +++ b/torch/nn/functional.py @@ -5023,19 +5023,19 @@ def multi_head_attention_forward( query, key, value = [x.transpose(1, 0) for x in (query, key, value)] merged_mask, mask_type = _merge_masks(attn_mask, key_padding_mask, query, num_heads) - in_proj_weight_ = cast(torch.Tensor, in_proj_weight) - in_proj_bias_ = cast(torch.Tensor, in_proj_bias) - out_proj_bias_ = cast(torch.Tensor, out_proj_bias) + assert in_proj_weight is not None + assert in_proj_bias is not None + assert out_proj_bias is not None attn_output, attn_output_weights = torch._native_multi_head_attention( query, key, value, embed_dim_to_check, num_heads, - in_proj_weight_, - 
in_proj_bias_, + in_proj_weight, + in_proj_bias, out_proj_weight, - out_proj_bias_, + out_proj_bias, merged_mask, need_weights, average_attn_weights, From 00b7ca793984c917c25018326906d72d1186f6fe Mon Sep 17 00:00:00 2001 From: Olga Gerasimova Date: Wed, 16 Nov 2022 10:54:58 +0100 Subject: [PATCH 12/12] removed cast import --- torch/nn/functional.py | 1 - 1 file changed, 1 deletion(-) diff --git a/torch/nn/functional.py b/torch/nn/functional.py index b7a5b72ff224..617bc2463216 100644 --- a/torch/nn/functional.py +++ b/torch/nn/functional.py @@ -23,7 +23,6 @@ from . import grad # noqa: F401 from .modules import utils from .modules.utils import _single, _pair, _triple, _list_with_default -from typing import cast Tensor = torch.Tensor
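
A minimal usage sketch of the functional fast path this series enables, mirroring the new test_functional_multi_head_attn_fast_path: an eval-mode module, torch.no_grad(), zero dropout, and no bias_k/bias_v satisfy the _can_use_fastpath conditions. Shapes, and passing out_proj.bias rather than None, are illustrative choices and not part of the patch.

import torch
import torch.nn.functional as F

seq_len, bsz, embed_dim, num_heads = 3, 5, 8, 2
query = key = value = torch.rand(seq_len, bsz, embed_dim)

# batch_first only affects the module call; the functional API stays sequence-first.
mha = torch.nn.MultiheadAttention(embed_dim, num_heads, batch_first=True).eval()

with torch.no_grad():
    # Eligible for the new fast path: eval mode, grad disabled, zero dropout,
    # no bias_k / bias_v, no add_zero_attn, and query is key is value.
    fn_out, _ = F.multi_head_attention_forward(
        query, key, value,
        embed_dim, num_heads,
        mha.in_proj_weight, mha.in_proj_bias,
        mha.bias_k, mha.bias_v,
        mha.add_zero_attn, mha.dropout,
        mha.out_proj.weight, mha.out_proj.bias,
        training=False,
    )
    # Module call takes batch-first inputs because batch_first=True.
    q_batched = query.transpose(1, 0)
    mod_out, _ = mha(q_batched, q_batched, q_batched)

# Functional and module outputs agree up to the default assert_close tolerances.
torch.testing.assert_close(fn_out, mod_out.transpose(1, 0))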
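The new F._merge_masks helper (shared by MultiheadAttention and TransformerEncoderLayer after these patches) reports mask type 0 for an attn_mask alone, type 1 for a key_padding_mask alone, and type 2 when both are given, in which case both are broadcast to (batch_size, num_heads, seq_len, seq_len) and combined with logical or. A small standalone illustration of the type-2 combination using plain tensor ops; the mask values here are arbitrary.

import torch

batch_size, num_heads, seq_len = 2, 2, 3
# mask type 0: (L, L) attention mask (here: causal, True means "do not attend")
attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
# mask type 1: (N, S) key padding mask (True means "padding, ignore this key")
key_padding_mask = torch.tensor([[False, False, True],
                                 [False, True,  True]])

# mask type 2: expand both to (N, num_heads, L, S) and OR them together,
# exactly as the helper does for the native kernel.
kpm = key_padding_mask.view(batch_size, 1, 1, seq_len).expand(-1, num_heads, -1, -1)
am = attn_mask.view(1, 1, seq_len, seq_len).expand(batch_size, num_heads, -1, -1)
merged_mask = am.logical_or(kpm)

print(merged_mask.shape)  # torch.Size([2, 2, 3, 3])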
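Under the null checks added in the "add null check in_proj, out_proj" commits, a None in_proj_weight or in_proj_bias simply disqualifies the fast path (for example "in projection weight is None") and the call is routed to the slow-path implementation, so separate q/k/v projection weights keep working. A sketch under that assumption, with arbitrary random weights.

import torch
import torch.nn.functional as F

seq_len, bsz, embed_dim, num_heads = 3, 5, 8, 2
query = key = value = torch.rand(seq_len, bsz, embed_dim)
q_proj_w = torch.rand(embed_dim, embed_dim)
k_proj_w = torch.rand(embed_dim, embed_dim)
v_proj_w = torch.rand(embed_dim, embed_dim)
out_proj_w = torch.rand(embed_dim, embed_dim)

with torch.no_grad():
    # in_proj_weight / in_proj_bias are None, so the fast-path check rejects
    # this call and it falls through to the separate-projection slow path.
    attn_out, attn_weights = F.multi_head_attention_forward(
        query, key, value,
        embed_dim, num_heads,
        None, None,                 # in_proj_weight, in_proj_bias (unused here)
        None, None,                 # bias_k, bias_v
        False, 0.0,                 # add_zero_attn, dropout_p
        out_proj_w, None,           # out_proj_weight, out_proj_bias
        training=False,
        use_separate_proj_weight=True,
        q_proj_weight=q_proj_w,
        k_proj_weight=k_proj_w,
        v_proj_weight=v_proj_w,
    )

print(attn_out.shape)  # torch.Size([3, 5, 8])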