From e65e4ac1f12629ebf60ee14461089a1217a7bd34 Mon Sep 17 00:00:00 2001 From: erichan1 <30481032+erichan1@users.noreply.github.com> Date: Mon, 25 Jul 2022 00:54:24 -0700 Subject: [PATCH] 1.12.1/bt fix (#81952) * Add test for torchscripting nn.TransformerEncoder, including fast path (#79796) (#79796) Summary: Add test just to check if TransformerEncoder will crash when enumerating over params [with_no_grad, use_torchscript, training]. Motivation for this was that TransformerEncoder fast path (so with_no_grad=True) and use_torchscript=True would crash with the issue that NestedTensor doesn't have size. This was caused because the TransformerEncoder fast path generates a NestedTensor automatically as a perf optimization and torchscript attempts to find intermediate tensor sizes while it optimizes. But NestedTensor has not implemented a size method, so things fail. This test goes together with this fix https://github.com/pytorch/pytorch/pull/79480 Pull Request resolved: https://github.com/pytorch/pytorch/pull/79796 Approved by: https://github.com/zrphercule Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/06274d7a487bf7995da77b9df9b5c1f7dc13f35b Test plan from GitHub: ``` buck build --show-output mode/opt -c fbcode.enable_gpu_sections=true -c fbcode.nvcc_arch=a100 mode/inplace //caffe2/test:transformers ./fbcode/buck-out/gen/caffe2/test/transformers#binary.par ``` Test runs and passes together with the changes from the PR above (I made another diff on top of this with those changes). Does not pass without the fix. Reviewed By: mikekgfb Differential Revision: D37222923 Pulled By: erichan1 fbshipit-source-id: 5a16e7d240cb51c0a613d16a79931d41122aba8b * disable src mask for transformer and multiheadattention fastpath (#81277) (#81277) Summary: Disable fastpath if src_mask passed to TransformerEncoderLayer and MultiheadAttention. - Refactored test_transformerencoder from test_nn.py to test_transformers.py. Added a src_mask test there. - Added a specific src_mask test in test_transformers.py Fixes https://github.com/pytorch/pytorch/issues/81129 Pull Request resolved: https://github.com/pytorch/pytorch/pull/81277 Approved by: https://github.com/zrphercule Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/23088fcfdf77632d4e6db4d35ce62735ca6622d2 Reviewed By: DanilBaibak Differential Revision: D37919513 Pulled By: erichan1 fbshipit-source-id: 0697d789634775136897fdb6a310356a6a45030d * remove decoder tests for feature not in 1.12 * remove unnecessary changes from #77903 to make changes more minimal --- .../ATen/native/nested/NestedTensorMath.cpp | 12 +- test/test_nn.py | 26 --- test/test_transformers.py | 156 ++++++++++++++++++ torch/nn/modules/activation.py | 8 +- torch/nn/modules/transformer.py | 7 +- 5 files changed, 174 insertions(+), 35 deletions(-) create mode 100644 test/test_transformers.py diff --git a/aten/src/ATen/native/nested/NestedTensorMath.cpp b/aten/src/ATen/native/nested/NestedTensorMath.cpp index 4824c0329749..ec51187f28bd 100644 --- a/aten/src/ATen/native/nested/NestedTensorMath.cpp +++ b/aten/src/ATen/native/nested/NestedTensorMath.cpp @@ -316,7 +316,17 @@ Tensor nested_from_padded_generic( padded.size(2), padded.size(1) * padded.size(3)}); } - const auto target_size = NestedTensor_get_max_size_from_size_tensor(sizes); + auto target_size = NestedTensor_get_max_size_from_size_tensor(sizes); + // There may be extra padding on padded beyond the max size in the nested tensor. + // Make the mask size match. + const size_t dim = padded_transformed.dim(); + TORCH_CHECK(dim - 1 == target_size.size(), "dim: ", dim, "target_size: ", target_size.size()); + for (size_t ii = 0; ii < dim - 1; ++ii) { + const auto padded_size_i = padded_transformed.sizes()[ii + 1]; + if (target_size[ii] < padded_size_i) { + target_size[ii] = padded_size_i; + } + } IntArrayRef target_size_arr(target_size); std::vector masks; std::vector all_sizes = sizes.unbind(); diff --git a/test/test_nn.py b/test/test_nn.py index 7759a4ed013e..aad884ebd4f2 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -5826,32 +5826,6 @@ def test_multihead_attn_3d_attn_mask(self): # output_2d in shape of [T, 1, D] self.assertEqual(output_3d[i].unsqueeze(0).transpose(0, 1), output_2d) - @unittest.skipIf(not TEST_CUDA, "CUDA unavailable") - def test_self_attn_TxT_attn_mask(self): - embed_dim = 16 - num_heads = 4 - batch_size = 10 - tgt_len = 16 - - query = torch.rand(batch_size, tgt_len, embed_dim, device="cuda") # [N, T, D] - attn_mask = torch.randint(0, 2, (tgt_len, tgt_len)).cuda().float() # [T, T] - attn_mask = attn_mask.masked_fill(attn_mask == 0, float('-inf')).masked_fill(attn_mask == 1, float(0.0)) - - attn_mask_4d = attn_mask.expand(batch_size, num_heads, tgt_len, tgt_len) - - mta_model = torch.nn.MultiheadAttention(embed_dim, num_heads, batch_first=True).cuda() - mta_model.eval() - - # Generate 3D results - with torch.inference_mode(): - output_mask_4d = mta_model(query, query, query, attn_mask=attn_mask_4d)[0] - output_mask_4d = output_mask_4d.transpose(0, 1) # [N, T, D] - - output_mask_TxT = mta_model(query, query, query, attn_mask=attn_mask)[0] - output_mask_TxT = output_mask_TxT.transpose(0, 1) # [N, T, D] - - self.assertEqual(output_mask_4d, output_mask_TxT) - def test_multihead_attn_no_bias(self): embed_dim = 8 num_heads = 4 diff --git a/test/test_transformers.py b/test/test_transformers.py new file mode 100644 index 000000000000..57fa79844d06 --- /dev/null +++ b/test/test_transformers.py @@ -0,0 +1,156 @@ +# Owner(s): ["module: nn"] + +import contextlib +import torch +import unittest + +from torch.testing._internal.common_nn import NNTestCase +from torch.testing._internal.common_utils import run_tests, parametrize, instantiate_parametrized_tests +from torch.testing._internal.common_cuda import TEST_CUDA + +@contextlib.contextmanager +def set_default_dtype(dtype): + saved_dtype = torch.get_default_dtype() + torch.set_default_dtype(dtype) + try: + yield + finally: + torch.set_default_dtype(saved_dtype) + +class TestTransformers(NNTestCase): + _do_cuda_memory_leak_check = True + _do_cuda_non_default_stream = True + + device_list = ['cpu'] # TODO: is there a way to do parametrize for this? + if TEST_CUDA: + device_list.append('cuda') + + @unittest.skip("4D mask not supported yet - activate when 4D mask supported") + @unittest.skipIf(not TEST_CUDA, "CUDA unavailable") # TODO: make this work for both cuda and cpu + def test_self_attn_TxT_attn_mask(self): + embed_dim = 16 + num_heads = 4 + batch_size = 10 + tgt_len = 16 + + query = torch.rand(batch_size, tgt_len, embed_dim, device="cuda") # [N, T, D] + attn_mask = torch.randint(0, 2, (tgt_len, tgt_len)).cuda().float() # [T, T] + attn_mask = attn_mask.masked_fill(attn_mask == 0, float('-inf')).masked_fill(attn_mask == 1, float(0.0)) + + attn_mask_4d = attn_mask.expand(batch_size, num_heads, tgt_len, tgt_len) + + mta_model = torch.nn.MultiheadAttention(embed_dim, num_heads, batch_first=True).cuda() + mta_model.eval() + + # Generate 3D results + with torch.inference_mode(): + output_mask_4d = mta_model(query, query, query, attn_mask=attn_mask_4d)[0] + output_mask_4d = output_mask_4d.transpose(0, 1) # [N, T, D] + + output_mask_TxT = mta_model(query, query, query, attn_mask=attn_mask)[0] + output_mask_TxT = output_mask_TxT.transpose(0, 1) # [N, T, D] + + self.assertEqual(output_mask_4d, output_mask_TxT) + + @parametrize("device", device_list) + def test_transformerencoderlayer_src_mask(self, device): + batch_size = 2 + seqlen = 4 + d_model = 8 + nhead = 8 + dim_feedforward = 32 + + model = torch.nn.TransformerEncoderLayer( + d_model=d_model, + nhead=nhead, + dim_feedforward=dim_feedforward, + batch_first=True).to(device) + src = torch.rand(batch_size, seqlen, d_model).to(device) # bs, seqlen, d_model + src_mask = torch.zeros(seqlen, seqlen).to(torch.bool).to(device) + + model(src, src_mask=src_mask) + model.eval() + with torch.no_grad(): + model(src, src_mask=src_mask) + + @parametrize("use_torchscript", [True, False]) + @parametrize("with_no_grad", [True, False]) + @parametrize("training", [True, False]) + def test_transformerencoder_fastpath_torchscript(self, use_torchscript, with_no_grad, training): + """ + Test TransformerEncoder does not crash + """ + model = torch.nn.TransformerEncoder( + torch.nn.TransformerEncoderLayer(d_model=2, nhead=2, dim_feedforward=8, batch_first=True), + num_layers=2, + enable_nested_tensor=True + ) + + if training: + model = model.train() + else: + model = model.eval() + + if use_torchscript: + model = torch.jit.script(model) + + x = torch.Tensor([[[1, 2], [3, 4]]]).to(torch.float) + mask = torch.Tensor([[0, 1]]).to(torch.bool) + + if with_no_grad: + cm = torch.no_grad() + else: + cm = contextlib.nullcontext() + with cm: + model(x, src_key_padding_mask=mask) + + @parametrize("with_no_grad", [True, False]) + @parametrize("training", [True, False]) + @parametrize("enable_nested_tensor", [False]) + @parametrize("device", device_list) + def test_transformerencoder_square_input(self, with_no_grad, training, enable_nested_tensor, device): + """ + Test for edge cases when input of shape (batch size, sequence length, embedding dimension) has + batch size == sequence length + """ + model = torch.nn.TransformerEncoder( + torch.nn.TransformerEncoderLayer(d_model=4, nhead=2, dim_feedforward=16, dropout=0.0, batch_first=True), + num_layers=2, + enable_nested_tensor=enable_nested_tensor + ).to(device) + + with torch.no_grad(): + # set constant weights of the model + for idx, p in enumerate(model.parameters()): + x = p.data + sz = x.view(-1).size(0) + shape = x.shape + x = torch.cos(torch.arange(0, sz).float().view(shape)) + p.data.copy_(x) + + if training: + model = model.train() + else: + model = model.eval() + x = torch.arange(0, 16).reshape(2, 2, 4).to(torch.float).to(device) + src_mask = torch.Tensor([[0, 1], [0, 0]]).to(torch.bool).to(device) + + if with_no_grad: + cm = torch.no_grad() + else: + cm = contextlib.nullcontext() + with cm: + result = model(x, mask=src_mask) + + ref_output = torch.Tensor([[[2.420306205749512, 0.017629241570830, -0.607857942581177, -0.085519507527351], + [2.420306205749512, 0.017629241570830, -0.607857942581177, -0.085519507527351]], + [[2.419836044311523, 0.017548924311996, -0.608187675476074, -0.085347734391689], + [2.419836044311523, 0.017548924311996, -0.608187675476074, -0.085347734391689]]] + ).to(device) + self.assertEqual(tuple(result.shape), tuple(ref_output.shape)) + torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5) + +instantiate_parametrized_tests(TestTransformers) + +if __name__ == '__main__': + run_tests() diff --git a/torch/nn/modules/activation.py b/torch/nn/modules/activation.py index 27a73579fd35..b19cce10d53e 100644 --- a/torch/nn/modules/activation.py +++ b/torch/nn/modules/activation.py @@ -1085,10 +1085,10 @@ def forward(self, query: Tensor, key: Tensor, value: Tensor, key_padding_mask: O why_not_fast_path = "add_zero_attn was enabled" elif not self._qkv_same_embed_dim: why_not_fast_path = "_qkv_same_embed_dim was not True" - elif query.is_nested and (key_padding_mask is not None or attn_mask is not None): - why_not_fast_path = "key_padding_mask and attn_mask are not supported with NestedTensor input" - elif not query.is_nested and key_padding_mask is not None and attn_mask is not None: - why_not_fast_path = "key_padding_mask and attn_mask were both supplied" + elif attn_mask is not None: + why_not_fast_path = "attn_mask was not None" + elif query.is_nested and key_padding_mask is not None: + why_not_fast_path = "key_padding_mask is not supported with NestedTensor input" if not why_not_fast_path: tensor_args = ( diff --git a/torch/nn/modules/transformer.py b/torch/nn/modules/transformer.py index b867b7b4af2c..cbac8e38cc19 100644 --- a/torch/nn/modules/transformer.py +++ b/torch/nn/modules/transformer.py @@ -411,9 +411,8 @@ def forward(self, src: Tensor, src_mask: Optional[Tensor] = None, self.self_attn.batch_first and self.self_attn._qkv_same_embed_dim and self.activation_relu_or_gelu and self.norm1.eps == self.norm2.eps and - ((src_mask is None and src_key_padding_mask is None) - if src.is_nested - else (src_mask is None or src_key_padding_mask is None))): + src_mask is None and + not (src.is_nested and src_key_padding_mask is not None)): tensor_args = ( src, self.self_attn.in_proj_weight, @@ -453,7 +452,7 @@ def forward(self, src: Tensor, src_mask: Optional[Tensor] = None, self.linear1.bias, self.linear2.weight, self.linear2.bias, - src_mask if src_mask is not None else src_key_padding_mask, + src_mask if src_mask is not None else src_key_padding_mask, # TODO: split into two args ) x = src if self.norm_first: