Commit
* Add test for torchscripting nn.TransformerEncoder, including fast path (#79796)

  Summary: Add a test that checks whether TransformerEncoder crashes when enumerating over the parameters [with_no_grad, use_torchscript, training]. The motivation: combining the TransformerEncoder fast path (with_no_grad=True) with use_torchscript=True crashed with the error that NestedTensor has no size. The fast path automatically generates a NestedTensor as a performance optimization, and torchscript tries to query intermediate tensor sizes while it optimizes; since NestedTensor had not implemented a size method, this failed. This test goes together with the fix in #79480.

  Pull Request resolved: #79796
  Approved by: https://github.com/zrphercule

  Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/06274d7a487bf7995da77b9df9b5c1f7dc13f35b

  Test plan from GitHub:

  ```
  buck build --show-output mode/opt -c fbcode.enable_gpu_sections=true -c fbcode.nvcc_arch=a100 mode/inplace //caffe2/test:transformers
  ./fbcode/buck-out/gen/caffe2/test/transformers#binary.par
  ```

  The test runs and passes together with the changes from the PR above (verified in a diff stacked on top of this one), and does not pass without the fix.

  Reviewed By: mikekgfb
  Differential Revision: D37222923
  Pulled By: erichan1
  fbshipit-source-id: 5a16e7d240cb51c0a613d16a79931d41122aba8b

* disable src mask for transformer and multiheadattention fastpath (#81277)

  Summary: Disable the fast path when src_mask is passed to TransformerEncoderLayer or MultiheadAttention.

  - Refactored test_transformerencoder from test_nn.py to test_transformers.py and added a src_mask test there.
  - Added a specific src_mask test in test_transformers.py.

  Fixes #81129

  Pull Request resolved: #81277
  Approved by: https://github.com/zrphercule

  Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/23088fcfdf77632d4e6db4d35ce62735ca6622d2

  Reviewed By: DanilBaibak
  Differential Revision: D37919513
  Pulled By: erichan1
  fbshipit-source-id: 0697d789634775136897fdb6a310356a6a45030d

* remove decoder tests for a feature not in 1.12

* remove unnecessary changes from #77903 to keep the diff minimal
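For readers skimming the summary above, here is a minimal sketch of the failing combination (not part of the commit; it mirrors test_transformerencoder_fastpath_torchscript in the new test file below). Scripting an eval-mode TransformerEncoder built with enable_nested_tensor=True and calling it under torch.no_grad() is the path that hit the NestedTensor size() error before the #79480 fix; on a fixed build it returns normally:

```python
import torch

# Eval-mode encoder with the nested-tensor fast path enabled, as in the test below.
encoder = torch.nn.TransformerEncoder(
    torch.nn.TransformerEncoderLayer(d_model=2, nhead=2, dim_feedforward=8, batch_first=True),
    num_layers=2,
    enable_nested_tensor=True,  # fast path converts padded input to a NestedTensor
).eval()
scripted = torch.jit.script(encoder)

x = torch.tensor([[[1.0, 2.0], [3.0, 4.0]]])  # [batch=1, seq=2, d_model=2]
mask = torch.tensor([[False, True]])          # padding mask triggers the nested-tensor path

with torch.no_grad():
    # Before the fix, torchscript queried sizes of the intermediate NestedTensor
    # while optimizing and crashed here; with the fix this returns a dense tensor.
    out = scripted(x, src_key_padding_mask=mask)
```

And a similarly hedged sketch of the second change: passing an explicit src_mask to TransformerEncoderLayer now falls back to the regular (non-fast) path, so the call below is expected to succeed on a build containing #81277 (it mirrors test_transformerencoderlayer_src_mask below):

```python
import torch

layer = torch.nn.TransformerEncoderLayer(d_model=8, nhead=8, dim_feedforward=32, batch_first=True).eval()
src = torch.rand(2, 4, 8)                       # [batch, seq, d_model]
src_mask = torch.zeros(4, 4, dtype=torch.bool)  # [seq, seq] attention mask

with torch.no_grad():
    # src_mask is present, so the fast path is skipped and the regular path runs.
    layer(src, src_mask=src_mask)
```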
Showing 5 changed files with 174 additions and 35 deletions.
New file: test/test_transformers.py (+156 lines)
```python
# Owner(s): ["module: nn"]

import contextlib
import torch
import unittest

from torch.testing._internal.common_nn import NNTestCase
from torch.testing._internal.common_utils import run_tests, parametrize, instantiate_parametrized_tests
from torch.testing._internal.common_cuda import TEST_CUDA


@contextlib.contextmanager
def set_default_dtype(dtype):
    saved_dtype = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    try:
        yield
    finally:
        torch.set_default_dtype(saved_dtype)


class TestTransformers(NNTestCase):
    _do_cuda_memory_leak_check = True
    _do_cuda_non_default_stream = True

    device_list = ['cpu']  # TODO: is there a way to do parametrize for this?
    if TEST_CUDA:
        device_list.append('cuda')

    @unittest.skip("4D mask not supported yet - activate when 4D mask supported")
    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")  # TODO: make this work for both cuda and cpu
    def test_self_attn_TxT_attn_mask(self):
        embed_dim = 16
        num_heads = 4
        batch_size = 10
        tgt_len = 16

        query = torch.rand(batch_size, tgt_len, embed_dim, device="cuda")  # [N, T, D]
        attn_mask = torch.randint(0, 2, (tgt_len, tgt_len)).cuda().float()  # [T, T]
        attn_mask = attn_mask.masked_fill(attn_mask == 0, float('-inf')).masked_fill(attn_mask == 1, float(0.0))

        attn_mask_4d = attn_mask.expand(batch_size, num_heads, tgt_len, tgt_len)

        mta_model = torch.nn.MultiheadAttention(embed_dim, num_heads, batch_first=True).cuda()
        mta_model.eval()

        # Generate 3D results
        with torch.inference_mode():
            output_mask_4d = mta_model(query, query, query, attn_mask=attn_mask_4d)[0]
            output_mask_4d = output_mask_4d.transpose(0, 1)  # [N, T, D]

            output_mask_TxT = mta_model(query, query, query, attn_mask=attn_mask)[0]
            output_mask_TxT = output_mask_TxT.transpose(0, 1)  # [N, T, D]

            self.assertEqual(output_mask_4d, output_mask_TxT)

    @parametrize("device", device_list)
    def test_transformerencoderlayer_src_mask(self, device):
        batch_size = 2
        seqlen = 4
        d_model = 8
        nhead = 8
        dim_feedforward = 32

        model = torch.nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            batch_first=True).to(device)
        src = torch.rand(batch_size, seqlen, d_model).to(device)  # bs, seqlen, d_model
        src_mask = torch.zeros(seqlen, seqlen).to(torch.bool).to(device)

        model(src, src_mask=src_mask)
        model.eval()
        with torch.no_grad():
            model(src, src_mask=src_mask)

    @parametrize("use_torchscript", [True, False])
    @parametrize("with_no_grad", [True, False])
    @parametrize("training", [True, False])
    def test_transformerencoder_fastpath_torchscript(self, use_torchscript, with_no_grad, training):
        """
        Test TransformerEncoder does not crash
        """
        model = torch.nn.TransformerEncoder(
            torch.nn.TransformerEncoderLayer(d_model=2, nhead=2, dim_feedforward=8, batch_first=True),
            num_layers=2,
            enable_nested_tensor=True
        )

        if training:
            model = model.train()
        else:
            model = model.eval()

        if use_torchscript:
            model = torch.jit.script(model)

        x = torch.Tensor([[[1, 2], [3, 4]]]).to(torch.float)
        mask = torch.Tensor([[0, 1]]).to(torch.bool)

        if with_no_grad:
            cm = torch.no_grad()
        else:
            cm = contextlib.nullcontext()
        with cm:
            model(x, src_key_padding_mask=mask)

    @parametrize("with_no_grad", [True, False])
    @parametrize("training", [True, False])
    @parametrize("enable_nested_tensor", [False])
    @parametrize("device", device_list)
    def test_transformerencoder_square_input(self, with_no_grad, training, enable_nested_tensor, device):
        """
        Test for edge cases when input of shape (batch size, sequence length, embedding dimension) has
        batch size == sequence length
        """
        model = torch.nn.TransformerEncoder(
            torch.nn.TransformerEncoderLayer(d_model=4, nhead=2, dim_feedforward=16, dropout=0.0, batch_first=True),
            num_layers=2,
            enable_nested_tensor=enable_nested_tensor
        ).to(device)

        with torch.no_grad():
            # set constant weights of the model
            for idx, p in enumerate(model.parameters()):
                x = p.data
                sz = x.view(-1).size(0)
                shape = x.shape
                x = torch.cos(torch.arange(0, sz).float().view(shape))
                p.data.copy_(x)

        if training:
            model = model.train()
        else:
            model = model.eval()
        x = torch.arange(0, 16).reshape(2, 2, 4).to(torch.float).to(device)
        src_mask = torch.Tensor([[0, 1], [0, 0]]).to(torch.bool).to(device)

        if with_no_grad:
            cm = torch.no_grad()
        else:
            cm = contextlib.nullcontext()
        with cm:
            result = model(x, mask=src_mask)

        ref_output = torch.Tensor([[[2.420306205749512, 0.017629241570830, -0.607857942581177, -0.085519507527351],
                                    [2.420306205749512, 0.017629241570830, -0.607857942581177, -0.085519507527351]],
                                   [[2.419836044311523, 0.017548924311996, -0.608187675476074, -0.085347734391689],
                                    [2.419836044311523, 0.017548924311996, -0.608187675476074, -0.085347734391689]]]
                                  ).to(device)
        self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
        torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5)


instantiate_parametrized_tests(TestTransformers)

if __name__ == '__main__':
    run_tests()
```