diff --git a/test/test_nn.py b/test/test_nn.py
index a110163aef49..b6891a818873 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -5399,7 +5399,7 @@ def _create_src_lengths_mask(batch_size, src_lengths):
             return (src_indices < src_lengths).int().detach()
 
         def _multihead_attn_test_helper(add_key_padding_mask=False, add_bias_kv=False, add_zero_attn=False,
-                                        saved_kv=False, same_embed_dim=False, byte_mask=False,
+                                        saved_kv=False, same_embed_dim=False,
                                         average_attn_weights=average_attn_weights):
             for _ in range(100):
                 batch_sz, seq_len = [random.randint(2, 10) for r in range(2)]
@@ -5428,20 +5428,15 @@ def _multihead_attn_test_helper(add_key_padding_mask=False, add_bias_kv=False, a
                     seq_mask = np.random.randint(0, 2, (1, seq_len))
                     key_padding_mask = (np.repeat(seq_mask, batch_sz, axis=0) == 1)
                     key_padding_mask_tensor = torch.from_numpy(key_padding_mask)
-                    if byte_mask:
-                        key_padding_mask_tensor = key_padding_mask_tensor.byte()
                 decoder_state = np.random.rand(batch_sz, d_model)
                 K = np.random.rand(*dims)
                 V = K
                 Q = np.expand_dims(decoder_state, 1)
                 attn_mask = np.random.randint(0 , 2, size=(1, seq_len))
                 attn_mask_tensor = torch.from_numpy(attn_mask).float()
-                if byte_mask:
-                    attn_mask_tensor = (attn_mask_tensor == 0).byte()
-                else:
-                    attn_mask_tensor.masked_fill_(attn_mask_tensor == 0, float('-inf'))
-                    attn_mask_tensor.masked_fill_(attn_mask_tensor > 0, float('0.0'))
-                    attn_mask_tensor = attn_mask_tensor.double()
+                attn_mask_tensor.masked_fill_(attn_mask_tensor == 0, float('-inf'))
+                attn_mask_tensor.masked_fill_(attn_mask_tensor > 0, float('0.0'))
+                attn_mask_tensor = attn_mask_tensor.double()
 
                 decoder_state_tensor = torch.from_numpy(decoder_state).to(torch.get_default_dtype())
                 source_hid_tensor = torch.from_numpy(K).to(torch.get_default_dtype()).transpose(0, 1)
@@ -5588,10 +5583,6 @@ def test_multihead_attn_all_arguments3():
         def test_multihead_attn_all_arguments3():
             _multihead_attn_test_helper(add_key_padding_mask=True, add_zero_attn=True,
                                         saved_kv=True, same_embed_dim=True)
 
-        def test_multihead_attn_all_arguments4():
-            _multihead_attn_test_helper(add_key_padding_mask=True, add_zero_attn=True,
-                                        saved_kv=True, same_embed_dim=True, byte_mask=True)
-
         test_multihead_attn_add_zero_attn()  # Test MultiheadAttention with add_zero_attn
         test_multihead_attn_add_bias_kv()  # Test MultiheadAttention with add_bias_kv
         test_multihead_attn_no_masking()   # Test MultiheadAttention without masking
@@ -5602,7 +5593,6 @@ def test_multihead_attn_all_arguments4():
         with self.assertRaisesRegex(AssertionError, "bias cannot be added to static key."):
             test_multihead_attn_all_arguments2()  # Test MultiheadAttention with all the argument.
         test_multihead_attn_all_arguments3()  # Test MultiheadAttention with all the argument.
-        test_multihead_attn_all_arguments4()  # Test MultiheadAttention with all the argument.
 
     def test_multihead_attn_3d_attn_mask(self):
         embed_dim = 8
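For reference, here is a minimal standalone sketch (not part of the patch) of the additive float attn_mask construction that the helper above now uses unconditionally; the sequence length and variable names are illustrative only.

import numpy as np
import torch

seq_len = 6
attn_mask = np.random.randint(0, 2, size=(1, seq_len))
attn_mask_tensor = torch.from_numpy(attn_mask).float()
# 0 entries become -inf (position masked out), 1 entries become 0.0 (kept);
# the result is added to the attention logits instead of being used as a byte mask.
attn_mask_tensor.masked_fill_(attn_mask_tensor == 0, float('-inf'))
attn_mask_tensor.masked_fill_(attn_mask_tensor > 0, float('0.0'))
attn_mask_tensor = attn_mask_tensor.double()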
diff --git a/test/test_transformers.py b/test/test_transformers.py
index bc70e531f0ef..1eff4d61fb20 100644
--- a/test/test_transformers.py
+++ b/test/test_transformers.py
@@ -8,6 +8,7 @@ from unittest.mock import patch
 import math
 
 from torch.backends.cuda import sdp_kernel
+import torch.optim as optim
 
 from torch.testing._internal.common_nn import NNTestCase
 from torch.testing._internal.common_utils import (
@@ -69,6 +70,59 @@ def test_self_attn_TxT_attn_mask(self):
 
         self.assertEqual(output_mask_4d, output_mask_TxT)
 
+    @parametrize("device", device_list)
+    def test_train_with_pad_and_catch_error(self, device):
+        iters = 100
+        pad_mask = torch.tensor([[1, 1, 0, 0]], dtype=torch.bool).to(device)
+        layer = nn.TransformerEncoderLayer(
+            d_model=2,
+            dim_feedforward=4,
+            nhead=2,
+            batch_first=True,
+            activation="gelu",
+            dropout=0,
+        )
+        criterion = nn.MSELoss()
+        encoder = nn.TransformerEncoder(layer, 2).to(device)
+        optimizer = optim.SGD(encoder.parameters(), lr=0.1, momentum=0.9)
+        encoder.train()
+        for i in range(iters):
+            encoder.train()
+            optimizer.zero_grad()
+            inputs = torch.cat([torch.randn(1, 2, 2), torch.zeros(1, 2, 2)], dim=1).to(device)
+
+            outputs = encoder(inputs, src_key_padding_mask=pad_mask)
+
+            loss = criterion(outputs[:, 0:2, :], inputs[:, 0:2, :])
+            loss.backward()
+            optimizer.step()
+
+            with torch.no_grad():
+                test = torch.cat([torch.randn(1, 2, 2), torch.zeros(1, 2, 2)], dim=1).to(device)
+
+                # Expect uint8 type not supported
+                ex = None
+                try:
+                    test_train_uint8 = encoder(test, src_key_padding_mask=pad_mask.to(torch.uint8))
+                except AssertionError as e:
+                    continue
+                self.assertFalse(e, "Failed to catch unsupported uint8 type exception")
+
+                test_train_bool = encoder(test, src_key_padding_mask=pad_mask)
+                encoder.eval()
+
+                # Expect long type not supported
+                ex = None
+                try:
+                    test_eval_uint8 = encoder(test, src_key_padding_mask=pad_mask.to(torch.int64))
+                except AssertionError as e:
+                    continue
+                self.assertFalse(e, "Failed to catch unsupported Long type exception")
+
+                test_eval_bool = encoder(test, src_key_padding_mask=pad_mask)
+                l1_bool = nn.L1Loss()(test_train_bool[:, 0:2, :], test_eval_bool[:, 0:2, :]).item()
+                self.assertTrue(l1_bool < 1e-4, "Eval/Train difference in pad_mask BOOL")
+
     @parametrize("device", device_list)
     @parametrize("nhead", [1, 4, 8])
     def test_transformerencoderlayer_src_mask(self, device, nhead):
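Outside the test harness, the behaviour exercised by test_train_with_pad_and_catch_error can be reproduced with a short sketch like the following (it assumes the dtype checks added later in this patch; module sizes and names are arbitrary).

import torch
import torch.nn as nn

layer = nn.TransformerEncoderLayer(d_model=2, nhead=2, dim_feedforward=4,
                                   batch_first=True, dropout=0)
encoder = nn.TransformerEncoder(layer, num_layers=2)
x = torch.randn(1, 4, 2)
pad_mask = torch.tensor([[False, False, True, True]])  # True marks padding positions

out = encoder(x, src_key_padding_mask=pad_mask)  # bool mask: accepted

try:
    encoder(x, src_key_padding_mask=pad_mask.to(torch.uint8))  # byte mask: rejected
except AssertionError as e:
    print(e)  # only bool and floating types of key_padding_mask are supported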
diff --git a/torch/nn/functional.py b/torch/nn/functional.py
index 47176d74b59f..0e4f58ccd64e 100644
--- a/torch/nn/functional.py
+++ b/torch/nn/functional.py
@@ -4966,8 +4966,8 @@ def multi_head_attention_forward(
        - value: :math:`(S, E)` or :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
          the embedding dimension.
        - key_padding_mask: :math:`(S)` or :math:`(N, S)` where N is the batch size, S is the source sequence length.
-          If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions
-          will be unchanged. If a BoolTensor is provided, the positions with the
+          If a FloatTensor is provided, it will be directly added to the value.
+          If a BoolTensor is provided, the positions with the
          value of ``True`` will be ignored while the position with the
          value of ``False`` will be unchanged.
        - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
          3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
@@ -5037,6 +5037,11 @@ def multi_head_attention_forward(
     # set up shape vars
     tgt_len, bsz, embed_dim = query.shape
     src_len, _, _ = key.shape
+    if key_padding_mask is not None:
+        _skpm_dtype = key_padding_mask.dtype
+        if _skpm_dtype != torch.bool and not torch.is_floating_point(key_padding_mask):
+            raise AssertionError(
+                "only bool and floating types of key_padding_mask are supported")
     assert embed_dim == embed_dim_to_check, \
         f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}"
     if isinstance(embed_dim, torch.Tensor):
@@ -5089,11 +5094,6 @@ def multi_head_attention_forward(
         else:
             raise RuntimeError(f"attn_mask's dimension {attn_mask.dim()} is not supported")
 
-    # prep key padding mask
-    if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
-        warnings.warn("Byte tensor for key_padding_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.")
-        key_padding_mask = key_padding_mask.to(torch.bool)
-
     # add bias along batch dimension (currently second)
     if bias_k is not None and bias_v is not None:
         assert static_k is None, "bias cannot be added to static key."
diff --git a/torch/nn/modules/activation.py b/torch/nn/modules/activation.py
index bd6866aa3543..986aa4cf03d5 100644
--- a/torch/nn/modules/activation.py
+++ b/torch/nn/modules/activation.py
@@ -1031,8 +1031,7 @@ def forward(self, query: Tensor, key: Tensor, value: Tensor, key_padding_mask: O
                 to ignore for the purpose of attention (i.e. treat as "padding"). For unbatched `query`, shape should be :math:`(S)`.
                 Binary and byte masks are supported.
                 For a binary mask, a ``True`` value indicates that the corresponding ``key`` value will be ignored for
-                the purpose of attention. For a byte mask, a non-zero value indicates that the corresponding ``key``
-                value will be ignored.
+                the purpose of attention. For a float mask, it will be directly added to the corresponding ``key`` value.
             need_weights: If specified, returns ``attn_output_weights`` in addition to ``attn_outputs``.
                 Default: ``True``.
             attn_mask: If specified, a 2D or 3D mask preventing attention to certain positions. Must be of shape
@@ -1062,6 +1061,11 @@ def forward(self, query: Tensor, key: Tensor, value: Tensor, key_padding_mask: O
             `batch_first` argument is ignored for unbatched inputs.
         """
         is_batched = query.dim() == 3
+        if key_padding_mask is not None:
+            _skpm_dtype = key_padding_mask.dtype
+            if _skpm_dtype != torch.bool and not torch.is_floating_point(key_padding_mask):
+                raise AssertionError(
+                    "only bool and floating types of key_padding_mask are supported")
         why_not_fast_path = ''
         if not is_batched:
             why_not_fast_path = f"input not batched; expected query.dim() of 3 but got {query.dim()}"
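As a rough illustration of the new key_padding_mask contract in nn.MultiheadAttention (bool or floating point only, with float masks added to the attention logits); this snippet is not from the patch and the shapes are arbitrary.

import torch
import torch.nn as nn

mha = nn.MultiheadAttention(embed_dim=8, num_heads=2, batch_first=True)
q = torch.randn(3, 5, 8)

bool_mask = torch.zeros(3, 5, dtype=torch.bool)
bool_mask[:, -1] = True                    # True: ignore this key position
out_bool, _ = mha(q, q, q, key_padding_mask=bool_mask)

float_mask = torch.zeros(3, 5)
float_mask[:, -1] = float('-inf')          # added to the attention logits
out_float, _ = mha(q, q, q, key_padding_mask=float_mask)

torch.testing.assert_close(out_bool, out_float)  # both forms mask the same position

try:
    mha(q, q, q, key_padding_mask=bool_mask.to(torch.uint8))  # byte mask: no longer supported
except AssertionError as e:
    print(e)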
""" + if src_key_padding_mask is not None: + _skpm_dtype = src_key_padding_mask.dtype + if _skpm_dtype != torch.bool and not torch.is_floating_point(_skpm_dtype): + raise AssertionError( + "only bool and floating types of key_padding_mask are supported") output = src convert_to_nested = False first_layer = self.layers[0] @@ -442,6 +447,11 @@ def forward(self, src: Tensor, src_mask: Optional[Tensor] = None, see the docs in Transformer class. """ + if src_key_padding_mask is not None: + _skpm_dtype = src_key_padding_mask.dtype + if _skpm_dtype != torch.bool and not torch.is_floating_point(_skpm_dtype): + raise AssertionError( + "only bool and floating types of key_padding_mask are supported") # see Fig. 1 of https://arxiv.org/pdf/2002.04745v1.pdf why_not_sparsity_fast_path = '' if not src.dim() == 3: