From 26cb85088c2e1b830a0847f62b5acef2bb99970a Mon Sep 17 00:00:00 2001 From: "Eric Han (Meta Employee)" Date: Mon, 20 Jun 2022 16:52:17 -0700 Subject: [PATCH] Add test for torchscripting nn.TransformerEncoder, including fast path (#79796) (#79796) Summary: Add test just to check if TransformerEncoder will crash when enumerating over params [with_no_grad, use_torchscript, training]. Motivation for this was that TransformerEncoder fast path (so with_no_grad=True) and use_torchscript=True would crash with the issue that NestedTensor doesn't have size. This was caused because the TransformerEncoder fast path generates a NestedTensor automatically as a perf optimization and torchscript attempts to find intermediate tensor sizes while it optimizes. But NestedTensor has not implemented a size method, so things fail. This test goes together with this fix https://github.com/pytorch/pytorch/pull/79480 Pull Request resolved: https://github.com/pytorch/pytorch/pull/79796 Approved by: https://github.com/zrphercule Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/06274d7a487bf7995da77b9df9b5c1f7dc13f35b Test plan from GitHub: ``` buck build --show-output mode/opt -c fbcode.enable_gpu_sections=true -c fbcode.nvcc_arch=a100 mode/inplace //caffe2/test:transformers ./fbcode/buck-out/gen/caffe2/test/transformers#binary.par ``` Test runs and passes together with the changes from the PR above (I made another diff on top of this with those changes). Does not pass without the fix. Reviewed By: mikekgfb Differential Revision: D37222923 Pulled By: erichan1 fbshipit-source-id: 5a16e7d240cb51c0a613d16a79931d41122aba8b --- test/test_transformers.py | 349 ++++++++++++++++++++++++ torch/testing/_internal/common_utils.py | 1 + 2 files changed, 350 insertions(+) create mode 100644 test/test_transformers.py diff --git a/test/test_transformers.py b/test/test_transformers.py new file mode 100644 index 000000000000..19670f418313 --- /dev/null +++ b/test/test_transformers.py @@ -0,0 +1,349 @@ +# Owner(s): ["module: nn"] + +import torch +import unittest + +from torch.testing._internal.common_nn import NNTestCase +from torch.testing._internal.common_utils import TEST_FAIRSEQ, parametrize, instantiate_parametrized_tests +from torch.testing._internal.common_cuda import TEST_CUDA + +if TEST_FAIRSEQ: + import fairseq.models.transformer as fairseq_transformer + +class TestTransformers(NNTestCase): + _do_cuda_memory_leak_check = True + _do_cuda_non_default_stream = True + + @parametrize("use_torchscript", [True, False]) + @parametrize("with_no_grad", [True, False]) + @parametrize("training", [True, False]) + def test_transformerencoder_fastpath_torchscript(self, use_torchscript, with_no_grad, training): + """ + Test TransformerEncoder does not crash + """ + model = torch.nn.TransformerEncoder( + torch.nn.TransformerEncoderLayer(d_model=2, nhead=2, dim_feedforward=8, batch_first=True), + num_layers=2, + enable_nested_tensor=True + ) + + if training: + model = model.train() + else: + model = model.eval() + + if use_torchscript: + model = torch.jit.script(model) + + x = torch.Tensor([[[1, 2], [3, 4]]]).to(torch.float) + mask = torch.Tensor([[0, 1]]).to(torch.bool) + + if with_no_grad: + with torch.no_grad(): + model(x, src_key_padding_mask=mask) + else: + model(x, src_key_padding_mask=mask) + + @unittest.skipIf(not TEST_FAIRSEQ, "numpy not found") + @unittest.skipIf(not TEST_CUDA, 'CUDA not available') + def test_decoder_only_layer(self): + DEFAULT_PADDING_IDX = 0 + + class FairseqDecoder(torch.nn.Module): + def __init__( + self, + embed_dim, + attention_heads, + ffn_embed_dim, + num_layers, + embedding_layer, # torch.nn.Embedding. Must have a padding_idx field + dropout=0, + normalize_before=False, + torch_encoder=None, # torch encoder that you can map weights from + activation="relu", + ): + super().__init__() + + cfg = fairseq_transformer.TransformerConfig() + cfg.decoder.embed_dim = embed_dim + cfg.decoder.output_dim = embed_dim + cfg.decoder.attention_heads = attention_heads + cfg.decoder.ffn_embed_dim = ffn_embed_dim + cfg.dropout = dropout + cfg.decoder.normalize_before = normalize_before + cfg.decoder.layers = num_layers + # make embedding behavior same as other encoders + cfg.no_token_positional_embeddings = True + cfg.no_scale_embedding = True + cfg.activation_fn = activation + + dictionary = {} # TODO: verify what this is + + self.decoder = fairseq_transformer.TransformerDecoder( + cfg, + dictionary, + embedding_layer, + no_encoder_attn=True, + output_projection=None, + ) + + if torch_encoder is not None: + self.decoder = torch_to_fairseq(torch_encoder, self.decoder) + self.decoder = self.decoder.eval().cuda().half() + + def forward( + self, + tokens, + src_lengths=None, + with_triangle_mask=False, + incremental_state=None, + ): + return self.decoder( + prev_output_tokens=tokens, + encoder_out=None, + incremental_state=incremental_state, + features_only=True, + full_context_alignment=not with_triangle_mask, + alignment_layer=None, + alignment_heads=None, + src_lengths=src_lengths, + return_all_hiddens=False, + )[0] + + class BetterDecoder(torch.nn.Module): + """ + Only incremental decoder for now + """ + + def __init__(self, transformer, embedding, pad_idx): + super().__init__() + self.transformer = transformer + self.embedding = embedding + self.padding_idx = pad_idx + + def forward( + self, + x, + src_mask=None, + include_padding_mask=True, + incr_key_lst=None, + incr_value_lst=None, + is_incremental_decoding=False, + ): + padding_mask = None + if not x.is_nested and include_padding_mask: + padding_mask = x.eq(self.padding_idx) + if(is_incremental_decoding): + x = x[:, -1:] # only take the last token + x = self.embedding(x) + + one_encoder_layer = self.transformer.layers[0] + self_attn = one_encoder_layer.self_attn + embed_dim = self_attn.embed_dim + num_heads = self_attn.num_heads + + use_gelu = ( + one_encoder_layer.activation_relu_or_gelu == 2 + ) # see torch/nn/modules/activation attention impl. 1 == relu, 2 == gelu + assert ( + one_encoder_layer.activation_relu_or_gelu != 0 + ) # 0 == not relu or gelu + + norm_first = one_encoder_layer.norm_first + + + # TODO: make this a bit less janky. but for now we initialize with an empty tensor. + if(not is_incremental_decoding): + assert len(incr_key_lst) == 0 or incr_key_lst[0] is None + assert len(incr_value_lst) == 0 or incr_value_lst[0] is None + while len(incr_key_lst) <= len(self.transformer.layers): + if(is_incremental_decoding): + incr_key_lst.append(torch.Tensor([]).cuda().half()) + incr_value_lst.append(torch.Tensor([]).cuda().half()) + else: + incr_key_lst.append(None) + incr_value_lst.append(None) + + for i, layer in enumerate(self.transformer.layers): + incr_key = incr_key_lst[i] + incr_value = incr_value_lst[i] + + x, incr_key, incr_value = torch._transformer_decoder_only_layer_fwd( + src=x, + embed_dim=embed_dim, + num_heads=num_heads, + qkv_weight=layer.self_attn.in_proj_weight, + qkv_bias=layer.self_attn.in_proj_bias, + proj_weight=layer.self_attn.out_proj.weight, + proj_bias=layer.self_attn.out_proj.bias, + use_gelu=use_gelu, + norm_first=norm_first, + # TODO: layer_norm_eps hardcoded to be same as nn.TransformerEncoder default. + # fix by pulling from self_attn.norm1 + eps=1e-5, + norm_weight_1=layer.norm1.weight, + norm_bias_1=layer.norm1.bias, + norm_weight_2=layer.norm2.weight, + norm_bias_2=layer.norm2.bias, + ffn_weight_1=layer.linear1.weight, + ffn_bias_1=layer.linear1.bias, + ffn_weight_2=layer.linear2.weight, + ffn_bias_2=layer.linear2.bias, + mask=src_mask, + incr_key=incr_key, # altered in place + incr_value=incr_value, + ) + + # not in-place + if(not is_incremental_decoding): + incr_key = None + incr_value = None + incr_key_lst[i] = incr_key + incr_value_lst[i] = incr_value + + return x, incr_key_lst, incr_value_lst + + def torch_to_fairseq(torch_encoder, fairseq_encoder): + for src_layer, dst_layer in zip(torch_encoder.layers, fairseq_encoder.layers): + w_q, w_k, w_v = src_layer.self_attn.in_proj_weight.chunk(3, dim=0) + b_q, b_k, b_v = src_layer.self_attn.in_proj_bias.chunk(3, dim=0) + + dst_layer.self_attn.q_proj.weight = torch.nn.Parameter(w_q) + dst_layer.self_attn.q_proj.bias = torch.nn.Parameter(b_q) + dst_layer.self_attn.k_proj.weight = torch.nn.Parameter(w_k) + dst_layer.self_attn.k_proj.bias = torch.nn.Parameter(b_k) + dst_layer.self_attn.v_proj.weight = torch.nn.Parameter(w_v) + dst_layer.self_attn.v_proj.bias = torch.nn.Parameter(b_v) + + dst_layer.self_attn.out_proj.weight = src_layer.self_attn.out_proj.weight + dst_layer.self_attn.out_proj.bias = src_layer.self_attn.out_proj.bias + + dst_layer.fc1.weight = src_layer.linear1.weight + dst_layer.fc1.bias = src_layer.linear1.bias + + # fairseq may use fusedlayernorm from nvidia apex - diff properties + dst_layer.self_attn_layer_norm.load_state_dict(src_layer.norm1.state_dict()) + + dst_layer.fc2.weight = src_layer.linear2.weight + dst_layer.fc2.bias = src_layer.linear2.bias + + dst_layer.final_layer_norm.load_state_dict(src_layer.norm2.state_dict()) + + return fairseq_encoder + + def set_weights_deterministic(model): + for idx, p in enumerate(model.parameters()): + x = p.data + sz = x.view(-1).size(0) + shape = x.shape + x = torch.cos(torch.arange(0, sz).float().view(shape)) + p.data.copy_(x) + + D = 4 # d_model + H = 2 # nhead + FD = 16 # dim_feedforward + V = 100 # vocab size + L = 2 # num layers + + embedding_layer = torch.nn.Embedding(V, D, DEFAULT_PADDING_IDX) + layer = torch.nn.TransformerEncoderLayer( + d_model=D, + nhead=H, + dim_feedforward=FD, + batch_first=True, + activation="gelu", + ) + transformer = torch.nn.TransformerEncoder( + layer, + num_layers=L, + ).eval().cuda().half() + + set_weights_deterministic(embedding_layer) + set_weights_deterministic(transformer) + + better_decoder = ( + BetterDecoder(transformer, embedding_layer, DEFAULT_PADDING_IDX) + .eval() + .cuda() + .half() + ) + fairseq_decoder = ( + FairseqDecoder( + D, + H, + FD, + L, + embedding_layer, + dropout=0, + normalize_before=False, + torch_encoder=transformer, + activation="gelu", + ) + .eval() + .cuda() + .half() + ) + + tokens = torch.Tensor([ + [5, 6, 7, 8], + [9, 10, 11, 12] + ]).to(torch.int).cuda() + lengths_tensor = torch.Tensor([2, 2]).to(torch.int).cuda() + # bs = 2, seqlen = 4 + bs, seqlen = tokens.shape + + upper_triangle = torch.zeros(seqlen, seqlen) + upper_triangle.fill_(-100000000) + upper_triangle = torch.triu(upper_triangle, 1) + upper_triangle = upper_triangle.cuda().half() + upper_triangle_expanded = upper_triangle.unsqueeze(0).unsqueeze(0) + upper_triangle_expanded = upper_triangle_expanded.expand( + bs, H, -1, -1 + ) + + # test forced decoding + with torch.no_grad(): + result, _, _ = better_decoder( + tokens, + src_mask=upper_triangle_expanded, + include_padding_mask=False, + incr_key_lst=[], + incr_value_lst=[], + is_incremental_decoding=False, + ) + ref_output = fairseq_decoder(tokens, lengths_tensor, with_triangle_mask=True) + + self.assertEqual(result.shape, ref_output.shape) + torch.testing.assert_close(result, ref_output, atol=1e-3, rtol=1e-2) + + # test incremental decoding + bs, seqlen = tokens.shape + + incr_state = {} + ref_outputs = [fairseq_decoder( + tokens[:, :i], + src_lengths=None, + with_triangle_mask=False, + incremental_state=incr_state, + ) for i in range(1, seqlen + 1)] + ref_output = torch.stack(ref_outputs) + + incr_key_lst = [] + incr_value_lst = [] + results = [] + for i in range(1, seqlen + 1): + res, incr_key_lst, incr_value_lst = better_decoder( + tokens[:, :i], + src_mask=None, + include_padding_mask=False, + incr_key_lst=incr_key_lst, + incr_value_lst=incr_value_lst, + is_incremental_decoding=True, + ) + results.append(res) + result = torch.stack(results) + + self.assertEqual(result.shape, ref_output.shape) + torch.testing.assert_close(result, ref_output, atol=1e-3, rtol=1e-2) + +instantiate_parametrized_tests(TestTransformers) diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py index e1417ed1dc41..9c792ae090ab 100644 --- a/torch/testing/_internal/common_utils.py +++ b/torch/testing/_internal/common_utils.py @@ -758,6 +758,7 @@ def _check_module_exists(name: str) -> bool: return False TEST_NUMPY = _check_module_exists('numpy') +TEST_FAIRSEQ = _check_module_exists('fairseq') TEST_SCIPY = _check_module_exists('scipy') TEST_MKL = torch.backends.mkl.is_available() TEST_CUDA = torch.cuda.is_available()