
Commit 9c54c65

[PT1.13 cherry pick] Fix Transformer's issue when padding is long type & unit test (pytorch#87106)

Summary:
Pull Request resolved: pytorch#87106

Pull Request resolved: pytorch#86353

Fix the issue described in pytorch#86120

Test Plan: buck test mode/opt caffe2/test:test_transformers -- test_train_with_long_type_pad

Differential Revision: D40129968

fbshipit-source-id: 5fdfe5742d30344d12bf0faf11a0e93c12b8be76
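
For context, a minimal sketch (not part of this commit; module sizes, shapes, and variable names are illustrative) of the behavior the fix enforces: a bool padding mask keeps working, while an integer (uint8/long) mask now fails fast with an AssertionError rather than the harder-to-debug failure reported in pytorch#86120.

import torch
import torch.nn as nn

layer = nn.TransformerEncoderLayer(d_model=2, nhead=2, dim_feedforward=4,
                                   batch_first=True, dropout=0)
encoder = nn.TransformerEncoder(layer, num_layers=2)

src = torch.randn(1, 4, 2)                              # (batch, seq, feature)
pad_mask = torch.tensor([[False, False, True, True]])   # True marks padded positions

out = encoder(src, src_key_padding_mask=pad_mask)       # bool mask: supported

try:
    encoder(src, src_key_padding_mask=pad_mask.long())  # long mask: rejected after this fix
except AssertionError as err:
    print(err)  # "only bool and floating type of key_padding_mask is supported"
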
zrphercule authored and facebook-github-bot committed Oct 17, 2022
1 parent 3193151 commit 9c54c65
Showing 5 changed files with 81 additions and 23 deletions.
18 changes: 4 additions & 14 deletions test/test_nn.py
@@ -5399,7 +5399,7 @@ def _create_src_lengths_mask(batch_size, src_lengths):
             return (src_indices < src_lengths).int().detach()
 
         def _multihead_attn_test_helper(add_key_padding_mask=False, add_bias_kv=False, add_zero_attn=False,
-                                        saved_kv=False, same_embed_dim=False, byte_mask=False,
+                                        saved_kv=False, same_embed_dim=False,
                                         average_attn_weights=average_attn_weights):
             for _ in range(100):
                 batch_sz, seq_len = [random.randint(2, 10) for r in range(2)]
@@ -5428,20 +5428,15 @@ def _multihead_attn_test_helper(add_key_padding_mask=False, add_bias_kv=False, a
                     seq_mask = np.random.randint(0, 2, (1, seq_len))
                     key_padding_mask = (np.repeat(seq_mask, batch_sz, axis=0) == 1)
                     key_padding_mask_tensor = torch.from_numpy(key_padding_mask)
-                    if byte_mask:
-                        key_padding_mask_tensor = key_padding_mask_tensor.byte()
                 decoder_state = np.random.rand(batch_sz, d_model)
                 K = np.random.rand(*dims)
                 V = K
                 Q = np.expand_dims(decoder_state, 1)
                 attn_mask = np.random.randint(0 , 2, size=(1, seq_len))
                 attn_mask_tensor = torch.from_numpy(attn_mask).float()
-                if byte_mask:
-                    attn_mask_tensor = (attn_mask_tensor == 0).byte()
-                else:
-                    attn_mask_tensor.masked_fill_(attn_mask_tensor == 0, float('-inf'))
-                    attn_mask_tensor.masked_fill_(attn_mask_tensor > 0, float('0.0'))
-                    attn_mask_tensor = attn_mask_tensor.double()
+                attn_mask_tensor.masked_fill_(attn_mask_tensor == 0, float('-inf'))
+                attn_mask_tensor.masked_fill_(attn_mask_tensor > 0, float('0.0'))
+                attn_mask_tensor = attn_mask_tensor.double()
 
                 decoder_state_tensor = torch.from_numpy(decoder_state).to(torch.get_default_dtype())
                 source_hid_tensor = torch.from_numpy(K).to(torch.get_default_dtype()).transpose(0, 1)
@@ -5588,10 +5583,6 @@ def test_multihead_attn_all_arguments3():
             _multihead_attn_test_helper(add_key_padding_mask=True, add_zero_attn=True,
                                         saved_kv=True, same_embed_dim=True)
 
-        def test_multihead_attn_all_arguments4():
-            _multihead_attn_test_helper(add_key_padding_mask=True, add_zero_attn=True,
-                                        saved_kv=True, same_embed_dim=True, byte_mask=True)
-
         test_multihead_attn_add_zero_attn()  # Test MultiheadAttention with add_zero_attn
         test_multihead_attn_add_bias_kv()  # Test MultiheadAttention with add_bias_kv
         test_multihead_attn_no_masking()  # Test MultiheadAttention without masking
@@ -5602,7 +5593,6 @@ def test_multihead_attn_all_arguments4():
         with self.assertRaisesRegex(AssertionError, "bias cannot be added to static key."):
             test_multihead_attn_all_arguments2()  # Test MultiheadAttention with all the argument.
         test_multihead_attn_all_arguments3()  # Test MultiheadAttention with all the argument.
-        test_multihead_attn_all_arguments4()  # Test MultiheadAttention with all the argument.
 
     def test_multihead_attn_3d_attn_mask(self):
         embed_dim = 8
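
With the byte-mask branches removed, the only attn_mask construction the helper still exercises is the additive float form kept above. A standalone sketch of that conversion (variable names and sizes are illustrative, not taken from the test):

import numpy as np
import torch

seq_len = 5
attn_mask = np.random.randint(0, 2, size=(1, seq_len))   # 0 = block attention, 1 = allow

# Build the additive float mask: blocked positions become -inf, allowed positions 0.0.
attn_mask_tensor = torch.from_numpy(attn_mask).float()
attn_mask_tensor.masked_fill_(attn_mask_tensor == 0, float('-inf'))
attn_mask_tensor.masked_fill_(attn_mask_tensor > 0, float('0.0'))
attn_mask_tensor = attn_mask_tensor.double()
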
54 changes: 54 additions & 0 deletions test/test_transformers.py
@@ -8,6 +8,7 @@
 from unittest.mock import patch
 import math
 from torch.backends.cuda import sdp_kernel
+import torch.optim as optim
 
 from torch.testing._internal.common_nn import NNTestCase
 from torch.testing._internal.common_utils import (
@@ -67,6 +68,59 @@ def test_self_attn_TxT_attn_mask(self):
 
         self.assertEqual(output_mask_4d, output_mask_TxT)
 
+    @parametrize("device", device_list)
+    def test_train_with_pad_and_catch_error(self, device):
+        iters = 100
+        pad_mask = torch.tensor([[1, 1, 0, 0]], dtype=torch.bool).to(device)
+        layer = nn.TransformerEncoderLayer(
+            d_model=2,
+            dim_feedforward=4,
+            nhead=2,
+            batch_first=True,
+            activation="gelu",
+            dropout=0,
+        )
+        criterion = nn.MSELoss()
+        encoder = nn.TransformerEncoder(layer, 2).to(device)
+        optimizer = optim.SGD(encoder.parameters(), lr=0.1, momentum=0.9)
+        encoder.train()
+        for i in range(iters):
+            encoder.train()
+            optimizer.zero_grad()
+            inputs = torch.cat([torch.randn(1, 2, 2), torch.zeros(1, 2, 2)], dim=1).to(device)
+
+            outputs = encoder(inputs, src_key_padding_mask=pad_mask)
+
+            loss = criterion(outputs[:, 0:2, :], inputs[:, 0:2, :])
+            loss.backward()
+            optimizer.step()
+
+            with torch.no_grad():
+                test = torch.cat([torch.randn(1, 2, 2), torch.zeros(1, 2, 2)], dim=1).to(device)
+
+                # Expect uint8 type not supported
+                ex = None
+                try:
+                    test_train_uint8 = encoder(test, src_key_padding_mask=pad_mask.to(torch.uint8))
+                except AssertionError as e:
+                    continue
+                self.assertFalse(e, "Failed to catch unsupported uint8 type exception")
+
+                test_train_bool = encoder(test, src_key_padding_mask=pad_mask)
+                encoder.eval()
+
+                # Expect long type not supported
+                ex = None
+                try:
+                    test_eval_uint8 = encoder(test, src_key_padding_mask=pad_mask.to(torch.int64))
+                except AssertionError as e:
+                    continue
+                self.assertFalse(e, "Failed to catch unsupported Long type exception")
+
+                test_eval_bool = encoder(test, src_key_padding_mask=pad_mask)
+                l1_bool = nn.L1Loss()(test_train_bool[:, 0:2, :], test_eval_bool[:, 0:2, :]).item()
+                self.assertTrue(l1_bool < 1e-4, "Eval/Train difference in pad_mask BOOL")
+
     @parametrize("device", device_list)
     @parametrize("nhead", [1, 4, 8])
     def test_transformerencoderlayer_src_mask(self, device, nhead):
14 changes: 7 additions & 7 deletions torch/nn/functional.py
@@ -4962,8 +4962,8 @@ def multi_head_attention_forward(
         - value: :math:`(S, E)` or :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
           the embedding dimension.
         - key_padding_mask: :math:`(S)` or :math:`(N, S)` where N is the batch size, S is the source sequence length.
-          If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions
-          will be unchanged. If a BoolTensor is provided, the positions with the
+          If a FloatTensor is provided, it will be directly added to the value.
+          If a BoolTensor is provided, the positions with the
           value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
         - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
           3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
@@ -5033,6 +5033,11 @@ def multi_head_attention_forward(
     # set up shape vars
     tgt_len, bsz, embed_dim = query.shape
     src_len, _, _ = key.shape
+    if key_padding_mask is not None:
+        if key_padding_mask.dtype != torch.bool:
+            if not torch.is_floating_point(key_padding_mask):
+                raise AssertionError(
+                    "only bool and floating type of key_padding_mask is supported")
     assert embed_dim == embed_dim_to_check, \
         f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}"
     if isinstance(embed_dim, torch.Tensor):
@@ -5085,11 +5090,6 @@
         else:
             raise RuntimeError(f"attn_mask's dimension {attn_mask.dim()} is not supported")
 
-    # prep key padding mask
-    if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
-        warnings.warn("Byte tensor for key_padding_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.")
-        key_padding_mask = key_padding_mask.to(torch.bool)
-
     # add bias along batch dimension (currently second)
     if bias_k is not None and bias_v is not None:
         assert static_k is None, "bias cannot be added to static key."
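
The removed uint8 branch used to warn and silently convert byte masks to bool; the new guard rejects anything that is neither bool nor floating point. A minimal sketch of that validation in isolation (the helper name is made up for illustration, not part of the API):

import torch

def _check_key_padding_mask(key_padding_mask: torch.Tensor) -> None:
    # Same shape of check as the lines added to multi_head_attention_forward.
    if key_padding_mask.dtype != torch.bool:
        if not torch.is_floating_point(key_padding_mask):
            raise AssertionError(
                "only bool and floating type of key_padding_mask is supported")

_check_key_padding_mask(torch.zeros(2, 4, dtype=torch.bool))      # accepted
_check_key_padding_mask(torch.zeros(2, 4, dtype=torch.float32))   # accepted
try:
    _check_key_padding_mask(torch.zeros(2, 4, dtype=torch.uint8))  # formerly warned, now rejected
except AssertionError as err:
    print(err)
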
8 changes: 6 additions & 2 deletions torch/nn/modules/activation.py
@@ -1031,8 +1031,7 @@ def forward(self, query: Tensor, key: Tensor, value: Tensor, key_padding_mask: O
                 to ignore for the purpose of attention (i.e. treat as "padding"). For unbatched `query`, shape should be :math:`(S)`.
                 Binary and byte masks are supported.
                 For a binary mask, a ``True`` value indicates that the corresponding ``key`` value will be ignored for
-                the purpose of attention. For a byte mask, a non-zero value indicates that the corresponding ``key``
-                value will be ignored.
+                the purpose of attention. For a float mask, it will be directly added to the corresponding ``key`` value.
             need_weights: If specified, returns ``attn_output_weights`` in addition to ``attn_outputs``.
                 Default: ``True``.
             attn_mask: If specified, a 2D or 3D mask preventing attention to certain positions. Must be of shape
@@ -1062,6 +1061,11 @@ def forward(self, query: Tensor, key: Tensor, value: Tensor, key_padding_mask: O
             `batch_first` argument is ignored for unbatched inputs.
         """
         is_batched = query.dim() == 3
+        if key_padding_mask is not None:
+            if key_padding_mask.dtype != torch.bool:
+                if not torch.is_floating_point(key_padding_mask):
+                    raise AssertionError(
+                        "only bool and floating type of key_padding_mask is supported")
         why_not_fast_path = ''
         if not is_batched:
             why_not_fast_path = f"input not batched; expected query.dim() of 3 but got {query.dim()}"
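
The docstring change above says a float key_padding_mask is added directly to the attention scores, so a 0.0 / -inf float mask should behave like the equivalent bool mask. A quick sketch (shapes, seed, and the allclose check are illustrative, not from the commit):

import torch
import torch.nn as nn

torch.manual_seed(0)
mha = nn.MultiheadAttention(embed_dim=8, num_heads=2, batch_first=True)

x = torch.randn(1, 4, 8)
bool_mask = torch.tensor([[False, False, True, True]])                # True = ignore this key
float_mask = torch.zeros(1, 4).masked_fill(bool_mask, float("-inf"))  # additive equivalent

out_bool, _ = mha(x, x, x, key_padding_mask=bool_mask)
out_float, _ = mha(x, x, x, key_padding_mask=float_mask)
print(torch.allclose(out_bool, out_float))  # expected: True
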
10 changes: 10 additions & 0 deletions torch/nn/modules/transformer.py
@@ -209,6 +209,11 @@ def forward(self, src: Tensor, mask: Optional[Tensor] = None, src_key_padding_ma
         src_key_padding_mask_for_layers = src_key_padding_mask
         why_not_sparsity_fast_path = ''
         str_first_layer = "self.layers[0]"
+        if src_key_padding_mask is not None:
+            if src_key_padding_mask.dtype != torch.bool:
+                if not torch.is_floating_point(src_key_padding_mask):
+                    raise AssertionError(
+                        "only bool and floating type of key_padding_mask is supported")
         if not isinstance(first_layer, torch.nn.TransformerEncoderLayer):
             why_not_sparsity_fast_path = f"{str_first_layer} was not TransformerEncoderLayer"
         elif first_layer.norm_first :
@@ -443,6 +448,11 @@ def forward(self, src: Tensor, src_mask: Optional[Tensor] = None,
         """
 
         # see Fig. 1 of https://arxiv.org/pdf/2002.04745v1.pdf
+        if src_key_padding_mask is not None:
+            if src_key_padding_mask.dtype != torch.bool:
+                if not torch.is_floating_point(src_key_padding_mask):
+                    raise AssertionError(
+                        "only bool and floating type of key_padding_mask is supported")
         why_not_sparsity_fast_path = ''
         if not src.dim() == 3:
             why_not_sparsity_fast_path = f"input not batched; expected src.dim() of 3 but got {src.dim()}"
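
The same guard is added to both TransformerEncoder.forward and TransformerEncoderLayer.forward, so the standalone layer should reject integer masks as well, including in eval mode where the sparsity fast path would otherwise be considered. A hedged sketch (sizes and names are illustrative):

import torch
import torch.nn as nn

layer = nn.TransformerEncoderLayer(d_model=2, nhead=2, dim_feedforward=4,
                                   batch_first=True, dropout=0)
layer.eval()  # the dtype check runs before the fast-path decision

src = torch.randn(1, 4, 2)
pad_mask = torch.tensor([[False, False, True, True]])

with torch.no_grad():
    layer(src, src_key_padding_mask=pad_mask)                      # bool mask: fine
    try:
        layer(src, src_key_padding_mask=pad_mask.to(torch.int64))  # long mask: AssertionError
    except AssertionError as err:
        print(err)
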
