Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add test for torchscripting nn.TransformerEncoder, including fast path (
#79796) (#79796) Summary: Add test just to check if TransformerEncoder will crash when enumerating over params [with_no_grad, use_torchscript, training]. Motivation for this was that TransformerEncoder fast path (so with_no_grad=True) and use_torchscript=True would crash with the issue that NestedTensor doesn't have size. This was caused because the TransformerEncoder fast path generates a NestedTensor automatically as a perf optimization and torchscript attempts to find intermediate tensor sizes while it optimizes. But NestedTensor has not implemented a size method, so things fail. This test goes together with this fix #79480 Pull Request resolved: #79796 Approved by: https://github.com/zrphercule Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/06274d7a487bf7995da77b9df9b5c1f7dc13f35b Test plan from GitHub: ``` buck build --show-output mode/opt -c fbcode.enable_gpu_sections=true -c fbcode.nvcc_arch=a100 mode/inplace //caffe2/test:transformers ./fbcode/buck-out/gen/caffe2/test/transformers#binary.par ``` Test runs and passes together with the changes from the PR above (I made another diff on top of this with those changes). Does not pass without the fix. Reviewed By: mikekgfb Differential Revision: D37222923 Pulled By: erichan1 fbshipit-source-id: 5a16e7d240cb51c0a613d16a79931d41122aba8b
- Loading branch information
Showing
2 changed files
with
350 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,349 @@ | ||
# Owner(s): ["module: nn"] | ||
|
||
import torch | ||
import unittest | ||
|
||
from torch.testing._internal.common_nn import NNTestCase | ||
from torch.testing._internal.common_utils import TEST_FAIRSEQ, parametrize, instantiate_parametrized_tests | ||
from torch.testing._internal.common_cuda import TEST_CUDA | ||
|
||
if TEST_FAIRSEQ: | ||
import fairseq.models.transformer as fairseq_transformer | ||
|
||
class TestTransformers(NNTestCase): | ||
_do_cuda_memory_leak_check = True | ||
_do_cuda_non_default_stream = True | ||
|
||
@parametrize("use_torchscript", [True, False]) | ||
@parametrize("with_no_grad", [True, False]) | ||
@parametrize("training", [True, False]) | ||
def test_transformerencoder_fastpath_torchscript(self, use_torchscript, with_no_grad, training): | ||
""" | ||
Test TransformerEncoder does not crash | ||
""" | ||
model = torch.nn.TransformerEncoder( | ||
torch.nn.TransformerEncoderLayer(d_model=2, nhead=2, dim_feedforward=8, batch_first=True), | ||
num_layers=2, | ||
enable_nested_tensor=True | ||
) | ||
|
||
if training: | ||
model = model.train() | ||
else: | ||
model = model.eval() | ||
|
||
if use_torchscript: | ||
model = torch.jit.script(model) | ||
|
||
x = torch.Tensor([[[1, 2], [3, 4]]]).to(torch.float) | ||
mask = torch.Tensor([[0, 1]]).to(torch.bool) | ||
|
||
if with_no_grad: | ||
with torch.no_grad(): | ||
model(x, src_key_padding_mask=mask) | ||
else: | ||
model(x, src_key_padding_mask=mask) | ||
|
||
@unittest.skipIf(not TEST_FAIRSEQ, "numpy not found") | ||
@unittest.skipIf(not TEST_CUDA, 'CUDA not available') | ||
def test_decoder_only_layer(self): | ||
DEFAULT_PADDING_IDX = 0 | ||
|
||
class FairseqDecoder(torch.nn.Module): | ||
def __init__( | ||
self, | ||
embed_dim, | ||
attention_heads, | ||
ffn_embed_dim, | ||
num_layers, | ||
embedding_layer, # torch.nn.Embedding. Must have a padding_idx field | ||
dropout=0, | ||
normalize_before=False, | ||
torch_encoder=None, # torch encoder that you can map weights from | ||
activation="relu", | ||
): | ||
super().__init__() | ||
|
||
cfg = fairseq_transformer.TransformerConfig() | ||
cfg.decoder.embed_dim = embed_dim | ||
cfg.decoder.output_dim = embed_dim | ||
cfg.decoder.attention_heads = attention_heads | ||
cfg.decoder.ffn_embed_dim = ffn_embed_dim | ||
cfg.dropout = dropout | ||
cfg.decoder.normalize_before = normalize_before | ||
cfg.decoder.layers = num_layers | ||
# make embedding behavior same as other encoders | ||
cfg.no_token_positional_embeddings = True | ||
cfg.no_scale_embedding = True | ||
cfg.activation_fn = activation | ||
|
||
dictionary = {} # TODO: verify what this is | ||
|
||
self.decoder = fairseq_transformer.TransformerDecoder( | ||
cfg, | ||
dictionary, | ||
embedding_layer, | ||
no_encoder_attn=True, | ||
output_projection=None, | ||
) | ||
|
||
if torch_encoder is not None: | ||
self.decoder = torch_to_fairseq(torch_encoder, self.decoder) | ||
self.decoder = self.decoder.eval().cuda().half() | ||
|
||
def forward( | ||
self, | ||
tokens, | ||
src_lengths=None, | ||
with_triangle_mask=False, | ||
incremental_state=None, | ||
): | ||
return self.decoder( | ||
prev_output_tokens=tokens, | ||
encoder_out=None, | ||
incremental_state=incremental_state, | ||
features_only=True, | ||
full_context_alignment=not with_triangle_mask, | ||
alignment_layer=None, | ||
alignment_heads=None, | ||
src_lengths=src_lengths, | ||
return_all_hiddens=False, | ||
)[0] | ||
|
||
class BetterDecoder(torch.nn.Module): | ||
""" | ||
Only incremental decoder for now | ||
""" | ||
|
||
def __init__(self, transformer, embedding, pad_idx): | ||
super().__init__() | ||
self.transformer = transformer | ||
self.embedding = embedding | ||
self.padding_idx = pad_idx | ||
|
||
def forward( | ||
self, | ||
x, | ||
src_mask=None, | ||
include_padding_mask=True, | ||
incr_key_lst=None, | ||
incr_value_lst=None, | ||
is_incremental_decoding=False, | ||
): | ||
padding_mask = None | ||
if not x.is_nested and include_padding_mask: | ||
padding_mask = x.eq(self.padding_idx) | ||
if(is_incremental_decoding): | ||
x = x[:, -1:] # only take the last token | ||
x = self.embedding(x) | ||
|
||
one_encoder_layer = self.transformer.layers[0] | ||
self_attn = one_encoder_layer.self_attn | ||
embed_dim = self_attn.embed_dim | ||
num_heads = self_attn.num_heads | ||
|
||
use_gelu = ( | ||
one_encoder_layer.activation_relu_or_gelu == 2 | ||
) # see torch/nn/modules/activation attention impl. 1 == relu, 2 == gelu | ||
assert ( | ||
one_encoder_layer.activation_relu_or_gelu != 0 | ||
) # 0 == not relu or gelu | ||
|
||
norm_first = one_encoder_layer.norm_first | ||
|
||
|
||
# TODO: make this a bit less janky. but for now we initialize with an empty tensor. | ||
if(not is_incremental_decoding): | ||
assert len(incr_key_lst) == 0 or incr_key_lst[0] is None | ||
assert len(incr_value_lst) == 0 or incr_value_lst[0] is None | ||
while len(incr_key_lst) <= len(self.transformer.layers): | ||
if(is_incremental_decoding): | ||
incr_key_lst.append(torch.Tensor([]).cuda().half()) | ||
incr_value_lst.append(torch.Tensor([]).cuda().half()) | ||
else: | ||
incr_key_lst.append(None) | ||
incr_value_lst.append(None) | ||
|
||
for i, layer in enumerate(self.transformer.layers): | ||
incr_key = incr_key_lst[i] | ||
incr_value = incr_value_lst[i] | ||
|
||
x, incr_key, incr_value = torch._transformer_decoder_only_layer_fwd( | ||
src=x, | ||
embed_dim=embed_dim, | ||
num_heads=num_heads, | ||
qkv_weight=layer.self_attn.in_proj_weight, | ||
qkv_bias=layer.self_attn.in_proj_bias, | ||
proj_weight=layer.self_attn.out_proj.weight, | ||
proj_bias=layer.self_attn.out_proj.bias, | ||
use_gelu=use_gelu, | ||
norm_first=norm_first, | ||
# TODO: layer_norm_eps hardcoded to be same as nn.TransformerEncoder default. | ||
# fix by pulling from self_attn.norm1 | ||
eps=1e-5, | ||
norm_weight_1=layer.norm1.weight, | ||
norm_bias_1=layer.norm1.bias, | ||
norm_weight_2=layer.norm2.weight, | ||
norm_bias_2=layer.norm2.bias, | ||
ffn_weight_1=layer.linear1.weight, | ||
ffn_bias_1=layer.linear1.bias, | ||
ffn_weight_2=layer.linear2.weight, | ||
ffn_bias_2=layer.linear2.bias, | ||
mask=src_mask, | ||
incr_key=incr_key, # altered in place | ||
incr_value=incr_value, | ||
) | ||
|
||
# not in-place | ||
if(not is_incremental_decoding): | ||
incr_key = None | ||
incr_value = None | ||
incr_key_lst[i] = incr_key | ||
incr_value_lst[i] = incr_value | ||
|
||
return x, incr_key_lst, incr_value_lst | ||
|
||
def torch_to_fairseq(torch_encoder, fairseq_encoder): | ||
for src_layer, dst_layer in zip(torch_encoder.layers, fairseq_encoder.layers): | ||
w_q, w_k, w_v = src_layer.self_attn.in_proj_weight.chunk(3, dim=0) | ||
b_q, b_k, b_v = src_layer.self_attn.in_proj_bias.chunk(3, dim=0) | ||
|
||
dst_layer.self_attn.q_proj.weight = torch.nn.Parameter(w_q) | ||
dst_layer.self_attn.q_proj.bias = torch.nn.Parameter(b_q) | ||
dst_layer.self_attn.k_proj.weight = torch.nn.Parameter(w_k) | ||
dst_layer.self_attn.k_proj.bias = torch.nn.Parameter(b_k) | ||
dst_layer.self_attn.v_proj.weight = torch.nn.Parameter(w_v) | ||
dst_layer.self_attn.v_proj.bias = torch.nn.Parameter(b_v) | ||
|
||
dst_layer.self_attn.out_proj.weight = src_layer.self_attn.out_proj.weight | ||
dst_layer.self_attn.out_proj.bias = src_layer.self_attn.out_proj.bias | ||
|
||
dst_layer.fc1.weight = src_layer.linear1.weight | ||
dst_layer.fc1.bias = src_layer.linear1.bias | ||
|
||
# fairseq may use fusedlayernorm from nvidia apex - diff properties | ||
dst_layer.self_attn_layer_norm.load_state_dict(src_layer.norm1.state_dict()) | ||
|
||
dst_layer.fc2.weight = src_layer.linear2.weight | ||
dst_layer.fc2.bias = src_layer.linear2.bias | ||
|
||
dst_layer.final_layer_norm.load_state_dict(src_layer.norm2.state_dict()) | ||
|
||
return fairseq_encoder | ||
|
||
def set_weights_deterministic(model): | ||
for idx, p in enumerate(model.parameters()): | ||
x = p.data | ||
sz = x.view(-1).size(0) | ||
shape = x.shape | ||
x = torch.cos(torch.arange(0, sz).float().view(shape)) | ||
p.data.copy_(x) | ||
|
||
D = 4 # d_model | ||
H = 2 # nhead | ||
FD = 16 # dim_feedforward | ||
V = 100 # vocab size | ||
L = 2 # num layers | ||
|
||
embedding_layer = torch.nn.Embedding(V, D, DEFAULT_PADDING_IDX) | ||
layer = torch.nn.TransformerEncoderLayer( | ||
d_model=D, | ||
nhead=H, | ||
dim_feedforward=FD, | ||
batch_first=True, | ||
activation="gelu", | ||
) | ||
transformer = torch.nn.TransformerEncoder( | ||
layer, | ||
num_layers=L, | ||
).eval().cuda().half() | ||
|
||
set_weights_deterministic(embedding_layer) | ||
set_weights_deterministic(transformer) | ||
|
||
better_decoder = ( | ||
BetterDecoder(transformer, embedding_layer, DEFAULT_PADDING_IDX) | ||
.eval() | ||
.cuda() | ||
.half() | ||
) | ||
fairseq_decoder = ( | ||
FairseqDecoder( | ||
D, | ||
H, | ||
FD, | ||
L, | ||
embedding_layer, | ||
dropout=0, | ||
normalize_before=False, | ||
torch_encoder=transformer, | ||
activation="gelu", | ||
) | ||
.eval() | ||
.cuda() | ||
.half() | ||
) | ||
|
||
tokens = torch.Tensor([ | ||
[5, 6, 7, 8], | ||
[9, 10, 11, 12] | ||
]).to(torch.int).cuda() | ||
lengths_tensor = torch.Tensor([2, 2]).to(torch.int).cuda() | ||
# bs = 2, seqlen = 4 | ||
bs, seqlen = tokens.shape | ||
|
||
upper_triangle = torch.zeros(seqlen, seqlen) | ||
upper_triangle.fill_(-100000000) | ||
upper_triangle = torch.triu(upper_triangle, 1) | ||
upper_triangle = upper_triangle.cuda().half() | ||
upper_triangle_expanded = upper_triangle.unsqueeze(0).unsqueeze(0) | ||
upper_triangle_expanded = upper_triangle_expanded.expand( | ||
bs, H, -1, -1 | ||
) | ||
|
||
# test forced decoding | ||
with torch.no_grad(): | ||
result, _, _ = better_decoder( | ||
tokens, | ||
src_mask=upper_triangle_expanded, | ||
include_padding_mask=False, | ||
incr_key_lst=[], | ||
incr_value_lst=[], | ||
is_incremental_decoding=False, | ||
) | ||
ref_output = fairseq_decoder(tokens, lengths_tensor, with_triangle_mask=True) | ||
|
||
self.assertEqual(result.shape, ref_output.shape) | ||
torch.testing.assert_close(result, ref_output, atol=1e-3, rtol=1e-2) | ||
|
||
# test incremental decoding | ||
bs, seqlen = tokens.shape | ||
|
||
incr_state = {} | ||
ref_outputs = [fairseq_decoder( | ||
tokens[:, :i], | ||
src_lengths=None, | ||
with_triangle_mask=False, | ||
incremental_state=incr_state, | ||
) for i in range(1, seqlen + 1)] | ||
ref_output = torch.stack(ref_outputs) | ||
|
||
incr_key_lst = [] | ||
incr_value_lst = [] | ||
results = [] | ||
for i in range(1, seqlen + 1): | ||
res, incr_key_lst, incr_value_lst = better_decoder( | ||
tokens[:, :i], | ||
src_mask=None, | ||
include_padding_mask=False, | ||
incr_key_lst=incr_key_lst, | ||
incr_value_lst=incr_value_lst, | ||
is_incremental_decoding=True, | ||
) | ||
results.append(res) | ||
result = torch.stack(results) | ||
|
||
self.assertEqual(result.shape, ref_output.shape) | ||
torch.testing.assert_close(result, ref_output, atol=1e-3, rtol=1e-2) | ||
|
||
instantiate_parametrized_tests(TestTransformers) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters