diff --git a/aten/src/ATen/native/nested/NestedTensorMath.cpp b/aten/src/ATen/native/nested/NestedTensorMath.cpp
index 4824c0329749..ec51187f28bd 100644
--- a/aten/src/ATen/native/nested/NestedTensorMath.cpp
+++ b/aten/src/ATen/native/nested/NestedTensorMath.cpp
@@ -316,7 +316,17 @@ Tensor nested_from_padded_generic(
         padded.size(2),
         padded.size(1) * padded.size(3)});
   }
-  const auto target_size = NestedTensor_get_max_size_from_size_tensor(sizes);
+  auto target_size = NestedTensor_get_max_size_from_size_tensor(sizes);
+  // There may be extra padding on padded beyond the max size in the nested tensor.
+  // Make the mask size match.
+  const size_t dim = padded_transformed.dim();
+  TORCH_CHECK(dim - 1 == target_size.size(), "dim: ", dim, "target_size: ", target_size.size());
+  for (size_t ii = 0; ii < dim - 1; ++ii) {
+    const auto padded_size_i = padded_transformed.sizes()[ii + 1];
+    if (target_size[ii] < padded_size_i) {
+      target_size[ii] = padded_size_i;
+    }
+  }
   IntArrayRef target_size_arr(target_size);
   std::vector<at::Tensor> masks;
   std::vector<at::Tensor> all_sizes = sizes.unbind();
diff --git a/test/test_nn.py b/test/test_nn.py
index 7759a4ed013e..aad884ebd4f2 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -5826,32 +5826,6 @@ def test_multihead_attn_3d_attn_mask(self):
             # output_2d in shape of [T, 1, D]
             self.assertEqual(output_3d[i].unsqueeze(0).transpose(0, 1), output_2d)
 
-    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
-    def test_self_attn_TxT_attn_mask(self):
-        embed_dim = 16
-        num_heads = 4
-        batch_size = 10
-        tgt_len = 16
-
-        query = torch.rand(batch_size, tgt_len, embed_dim, device="cuda")  # [N, T, D]
-        attn_mask = torch.randint(0, 2, (tgt_len, tgt_len)).cuda().float()  # [T, T]
-        attn_mask = attn_mask.masked_fill(attn_mask == 0, float('-inf')).masked_fill(attn_mask == 1, float(0.0))
-
-        attn_mask_4d = attn_mask.expand(batch_size, num_heads, tgt_len, tgt_len)
-
-        mta_model = torch.nn.MultiheadAttention(embed_dim, num_heads, batch_first=True).cuda()
-        mta_model.eval()
-
-        # Generate 3D results
-        with torch.inference_mode():
-            output_mask_4d = mta_model(query, query, query, attn_mask=attn_mask_4d)[0]
-            output_mask_4d = output_mask_4d.transpose(0, 1)  # [N, T, D]
-
-            output_mask_TxT = mta_model(query, query, query, attn_mask=attn_mask)[0]
-            output_mask_TxT = output_mask_TxT.transpose(0, 1)  # [N, T, D]
-
-            self.assertEqual(output_mask_4d, output_mask_TxT)
-
     def test_multihead_attn_no_bias(self):
         embed_dim = 8
         num_heads = 4
diff --git a/test/test_transformers.py b/test/test_transformers.py
new file mode 100644
index 000000000000..57fa79844d06
--- /dev/null
+++ b/test/test_transformers.py
@@ -0,0 +1,156 @@
+# Owner(s): ["module: nn"]
+
+import contextlib
+import torch
+import unittest
+
+from torch.testing._internal.common_nn import NNTestCase
+from torch.testing._internal.common_utils import run_tests, parametrize, instantiate_parametrized_tests
+from torch.testing._internal.common_cuda import TEST_CUDA
+
+@contextlib.contextmanager
+def set_default_dtype(dtype):
+    saved_dtype = torch.get_default_dtype()
+    torch.set_default_dtype(dtype)
+    try:
+        yield
+    finally:
+        torch.set_default_dtype(saved_dtype)
+
+class TestTransformers(NNTestCase):
+    _do_cuda_memory_leak_check = True
+    _do_cuda_non_default_stream = True
+
+    device_list = ['cpu']  # TODO: is there a way to do parametrize for this?
+    if TEST_CUDA:
+        device_list.append('cuda')
+
+    @unittest.skip("4D mask not supported yet - activate when 4D mask supported")
+    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")  # TODO: make this work for both cuda and cpu
+    def test_self_attn_TxT_attn_mask(self):
+        embed_dim = 16
+        num_heads = 4
+        batch_size = 10
+        tgt_len = 16
+
+        query = torch.rand(batch_size, tgt_len, embed_dim, device="cuda")  # [N, T, D]
+        attn_mask = torch.randint(0, 2, (tgt_len, tgt_len)).cuda().float()  # [T, T]
+        attn_mask = attn_mask.masked_fill(attn_mask == 0, float('-inf')).masked_fill(attn_mask == 1, float(0.0))
+
+        attn_mask_4d = attn_mask.expand(batch_size, num_heads, tgt_len, tgt_len)
+
+        mta_model = torch.nn.MultiheadAttention(embed_dim, num_heads, batch_first=True).cuda()
+        mta_model.eval()
+
+        # Generate 3D results
+        with torch.inference_mode():
+            output_mask_4d = mta_model(query, query, query, attn_mask=attn_mask_4d)[0]
+            output_mask_4d = output_mask_4d.transpose(0, 1)  # [N, T, D]
+
+            output_mask_TxT = mta_model(query, query, query, attn_mask=attn_mask)[0]
+            output_mask_TxT = output_mask_TxT.transpose(0, 1)  # [N, T, D]
+
+            self.assertEqual(output_mask_4d, output_mask_TxT)
+
+    @parametrize("device", device_list)
+    def test_transformerencoderlayer_src_mask(self, device):
+        batch_size = 2
+        seqlen = 4
+        d_model = 8
+        nhead = 8
+        dim_feedforward = 32
+
+        model = torch.nn.TransformerEncoderLayer(
+            d_model=d_model,
+            nhead=nhead,
+            dim_feedforward=dim_feedforward,
+            batch_first=True).to(device)
+        src = torch.rand(batch_size, seqlen, d_model).to(device)  # bs, seqlen, d_model
+        src_mask = torch.zeros(seqlen, seqlen).to(torch.bool).to(device)
+
+        model(src, src_mask=src_mask)
+        model.eval()
+        with torch.no_grad():
+            model(src, src_mask=src_mask)
+
+    @parametrize("use_torchscript", [True, False])
+    @parametrize("with_no_grad", [True, False])
+    @parametrize("training", [True, False])
+    def test_transformerencoder_fastpath_torchscript(self, use_torchscript, with_no_grad, training):
+        """
+        Test TransformerEncoder does not crash
+        """
+        model = torch.nn.TransformerEncoder(
+            torch.nn.TransformerEncoderLayer(d_model=2, nhead=2, dim_feedforward=8, batch_first=True),
+            num_layers=2,
+            enable_nested_tensor=True
+        )
+
+        if training:
+            model = model.train()
+        else:
+            model = model.eval()
+
+        if use_torchscript:
+            model = torch.jit.script(model)
+
+        x = torch.Tensor([[[1, 2], [3, 4]]]).to(torch.float)
+        mask = torch.Tensor([[0, 1]]).to(torch.bool)
+
+        if with_no_grad:
+            cm = torch.no_grad()
+        else:
+            cm = contextlib.nullcontext()
+        with cm:
+            model(x, src_key_padding_mask=mask)
+
+    @parametrize("with_no_grad", [True, False])
+    @parametrize("training", [True, False])
+    @parametrize("enable_nested_tensor", [False])
+    @parametrize("device", device_list)
+    def test_transformerencoder_square_input(self, with_no_grad, training, enable_nested_tensor, device):
+        """
+        Test for edge cases when input of shape (batch size, sequence length, embedding dimension) has
+        batch size == sequence length
+        """
+        model = torch.nn.TransformerEncoder(
+            torch.nn.TransformerEncoderLayer(d_model=4, nhead=2, dim_feedforward=16, dropout=0.0, batch_first=True),
+            num_layers=2,
+            enable_nested_tensor=enable_nested_tensor
+        ).to(device)
+
+        with torch.no_grad():
+            # set constant weights of the model
+            for idx, p in enumerate(model.parameters()):
+                x = p.data
+                sz = x.view(-1).size(0)
+                shape = x.shape
+                x = torch.cos(torch.arange(0, sz).float().view(shape))
+                p.data.copy_(x)
+
+        if training:
+            model = model.train()
+        else:
+            model = model.eval()
+        x = torch.arange(0, 16).reshape(2, 2, 4).to(torch.float).to(device)
+        src_mask = torch.Tensor([[0, 1], [0, 0]]).to(torch.bool).to(device)
+
+        if with_no_grad:
+            cm = torch.no_grad()
+        else:
+            cm = contextlib.nullcontext()
+        with cm:
+            result = model(x, mask=src_mask)
+
+        ref_output = torch.Tensor([[[2.420306205749512, 0.017629241570830, -0.607857942581177, -0.085519507527351],
+                                    [2.420306205749512, 0.017629241570830, -0.607857942581177, -0.085519507527351]],
+                                   [[2.419836044311523, 0.017548924311996, -0.608187675476074, -0.085347734391689],
+                                    [2.419836044311523, 0.017548924311996, -0.608187675476074, -0.085347734391689]]]
+                                  ).to(device)
+        self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
+        torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5)
+
+instantiate_parametrized_tests(TestTransformers)
+
+if __name__ == '__main__':
+    run_tests()
diff --git a/torch/nn/modules/activation.py b/torch/nn/modules/activation.py
index 27a73579fd35..b19cce10d53e 100644
--- a/torch/nn/modules/activation.py
+++ b/torch/nn/modules/activation.py
@@ -1085,10 +1085,10 @@ def forward(self, query: Tensor, key: Tensor, value: Tensor, key_padding_mask: O
             why_not_fast_path = "add_zero_attn was enabled"
         elif not self._qkv_same_embed_dim:
             why_not_fast_path = "_qkv_same_embed_dim was not True"
-        elif query.is_nested and (key_padding_mask is not None or attn_mask is not None):
-            why_not_fast_path = "key_padding_mask and attn_mask are not supported with NestedTensor input"
-        elif not query.is_nested and key_padding_mask is not None and attn_mask is not None:
-            why_not_fast_path = "key_padding_mask and attn_mask were both supplied"
+        elif attn_mask is not None:
+            why_not_fast_path = "attn_mask was not None"
+        elif query.is_nested and key_padding_mask is not None:
+            why_not_fast_path = "key_padding_mask is not supported with NestedTensor input"
 
         if not why_not_fast_path:
             tensor_args = (
diff --git a/torch/nn/modules/transformer.py b/torch/nn/modules/transformer.py
index b867b7b4af2c..cbac8e38cc19 100644
--- a/torch/nn/modules/transformer.py
+++ b/torch/nn/modules/transformer.py
@@ -411,9 +411,8 @@ def forward(self, src: Tensor, src_mask: Optional[Tensor] = None,
                 self.self_attn.batch_first and
                 self.self_attn._qkv_same_embed_dim and self.activation_relu_or_gelu and
                 self.norm1.eps == self.norm2.eps and
-                ((src_mask is None and src_key_padding_mask is None)
-                 if src.is_nested
-                 else (src_mask is None or src_key_padding_mask is None))):
+                src_mask is None and
+                not (src.is_nested and src_key_padding_mask is not None)):
             tensor_args = (
                 src,
                 self.self_attn.in_proj_weight,
@@ -453,7 +452,7 @@ def forward(self, src: Tensor, src_mask: Optional[Tensor] = None,
                 self.linear1.bias,
                 self.linear2.weight,
                 self.linear2.bias,
-                src_mask if src_mask is not None else src_key_padding_mask,
+                src_mask if src_mask is not None else src_key_padding_mask,  # TODO: split into two args
             )
             x = src
         if self.norm_first:
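
For illustration only (not part of the diff): a minimal sketch of the call pattern the relaxed fast-path gating above is aimed at, assuming an eval-mode, batch_first TransformerEncoderLayer invoked under no_grad with only src_key_padding_mask supplied (src_mask left as None). Shapes and sizes here are illustrative assumptions.

# Hypothetical usage sketch; mirrors what the new tests in test_transformers.py exercise.
import torch

layer = torch.nn.TransformerEncoderLayer(d_model=8, nhead=2, batch_first=True).eval()
src = torch.rand(2, 4, 8)                                    # (batch, seq, d_model)
padding_mask = torch.tensor([[False, False, True, True],
                             [False, False, False, False]])  # True marks padded positions
with torch.no_grad():
    # src_mask is None and only the padding mask is given, so the fast-path
    # eligibility check in TransformerEncoderLayer.forward can pass.
    out = layer(src, src_key_padding_mask=padding_mask)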