1.12.1/bt fix (#81952)
* Add test for torchscripting nn.TransformerEncoder, including fast path (#79796)

Summary:
Add a test to check that TransformerEncoder does not crash when enumerating over the parameters [with_no_grad, use_torchscript, training].

The motivation was that the TransformerEncoder fast path (i.e., with_no_grad=True) combined with use_torchscript=True would crash with the error that NestedTensor doesn't have a size. This happened because the TransformerEncoder fast path automatically generates a NestedTensor as a perf optimization, and TorchScript attempts to find intermediate tensor sizes while it optimizes; NestedTensor had not implemented a size method, so this failed.

This test goes together with the fix in #79480.
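
A minimal repro sketch of the failure mode, mirroring the test added below (the values are illustrative; before the fix in #79480, the last line crashed because TorchScript queried size() on the NestedTensor produced by the fast path):

```python
import torch

# Build a small encoder with the nested-tensor fast path enabled.
layer = torch.nn.TransformerEncoderLayer(d_model=2, nhead=2, dim_feedforward=8, batch_first=True)
model = torch.nn.TransformerEncoder(layer, num_layers=2, enable_nested_tensor=True).eval()
model = torch.jit.script(model)  # use_torchscript=True

x = torch.tensor([[[1.0, 2.0], [3.0, 4.0]]])
mask = torch.tensor([[False, True]])  # src_key_padding_mask triggers the NestedTensor conversion

with torch.no_grad():  # with_no_grad=True enables the fast path
    model(x, src_key_padding_mask=mask)  # used to crash; passes with the fix
```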

Pull Request resolved: #79796
Approved by: https://github.com/zrphercule

Test Plan:
contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/06274d7a487bf7995da77b9df9b5c1f7dc13f35b

Test plan from GitHub:
```
buck build --show-output mode/opt -c fbcode.enable_gpu_sections=true -c fbcode.nvcc_arch=a100 mode/inplace  //caffe2/test:transformers

./fbcode/buck-out/gen/caffe2/test/transformers#binary.par
```
The test runs and passes together with the changes from the PR above (I made another diff on top of this one with those changes). It does not pass without the fix.

Reviewed By: mikekgfb

Differential Revision: D37222923

Pulled By: erichan1

fbshipit-source-id: 5a16e7d240cb51c0a613d16a79931d41122aba8b

* disable src mask for transformer and multiheadattention fastpath (#81277)

Summary:
Disable the fast path when a src_mask is passed to TransformerEncoderLayer or an attn_mask is passed to MultiheadAttention (see the sketch after this list).
- Refactored test_transformerencoder from test_nn.py into test_transformers.py and added a src_mask test there.
- Added a dedicated src_mask test for TransformerEncoderLayer in test_transformers.py.
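
A hedged sketch of the new behavior (parameter values are illustrative): supplying an explicit mask now routes MultiheadAttention onto the regular path, while the fused fast path remains available when no mask is given.

```python
import torch

mha = torch.nn.MultiheadAttention(embed_dim=8, num_heads=8, batch_first=True).eval()
q = torch.rand(2, 4, 8)
attn_mask = torch.zeros(4, 4, dtype=torch.bool)

with torch.no_grad():
    out_masked, _ = mha(q, q, q, attn_mask=attn_mask)  # attn_mask present -> fast path disabled
    out_plain, _ = mha(q, q, q)                        # no mask -> fast path still eligible
```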

Fixes #81129

Pull Request resolved: #81277
Approved by: https://github.com/zrphercule

Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/23088fcfdf77632d4e6db4d35ce62735ca6622d2

Reviewed By: DanilBaibak

Differential Revision: D37919513

Pulled By: erichan1

fbshipit-source-id: 0697d789634775136897fdb6a310356a6a45030d

* remove decoder tests for a feature not in 1.12

* remove unnecessary changes from #77903 to keep the change minimal
erichan1 committed Jul 25, 2022
1 parent e8534b9 commit e65e4ac
Showing 5 changed files with 174 additions and 35 deletions.
12 changes: 11 additions & 1 deletion aten/src/ATen/native/nested/NestedTensorMath.cpp
```diff
@@ -316,7 +316,17 @@ Tensor nested_from_padded_generic(
         padded.size(2),
         padded.size(1) * padded.size(3)});
   }
-  const auto target_size = NestedTensor_get_max_size_from_size_tensor(sizes);
+  auto target_size = NestedTensor_get_max_size_from_size_tensor(sizes);
+  // There may be extra padding on padded beyond the max size in the nested tensor.
+  // Make the mask size match.
+  const size_t dim = padded_transformed.dim();
+  TORCH_CHECK(dim - 1 == target_size.size(), "dim: ", dim, "target_size: ", target_size.size());
+  for (size_t ii = 0; ii < dim - 1; ++ii) {
+    const auto padded_size_i = padded_transformed.sizes()[ii + 1];
+    if (target_size[ii] < padded_size_i) {
+      target_size[ii] = padded_size_i;
+    }
+  }
   IntArrayRef target_size_arr(target_size);
   std::vector<at::Tensor> masks;
   std::vector<at::Tensor> all_sizes = sizes.unbind();
```
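
The in-code comment above captures the intent; here is a toy illustration in plain Python (hypothetical sizes, not the real API) of how target_size is grown to match a padded buffer that is larger than the nested tensor's max sizes:

```python
# Hypothetical per-dimension sizes (batch dim excluded):
padded_sizes = [5, 8]   # sizes of the padded buffer, i.e. padded_transformed.sizes()[1:]
target_size = [5, 6]    # max sizes recorded in the nested tensor's size tensor

for i, padded_size_i in enumerate(padded_sizes):
    if target_size[i] < padded_size_i:
        target_size[i] = padded_size_i  # grow so the generated mask lines up

print(target_size)  # [5, 8]
```
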
26 changes: 0 additions & 26 deletions test/test_nn.py
```diff
@@ -5826,32 +5826,6 @@ def test_multihead_attn_3d_attn_mask(self):
             # output_2d in shape of [T, 1, D]
             self.assertEqual(output_3d[i].unsqueeze(0).transpose(0, 1), output_2d)
 
-    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
-    def test_self_attn_TxT_attn_mask(self):
-        embed_dim = 16
-        num_heads = 4
-        batch_size = 10
-        tgt_len = 16
-
-        query = torch.rand(batch_size, tgt_len, embed_dim, device="cuda")  # [N, T, D]
-        attn_mask = torch.randint(0, 2, (tgt_len, tgt_len)).cuda().float()  # [T, T]
-        attn_mask = attn_mask.masked_fill(attn_mask == 0, float('-inf')).masked_fill(attn_mask == 1, float(0.0))
-
-        attn_mask_4d = attn_mask.expand(batch_size, num_heads, tgt_len, tgt_len)
-
-        mta_model = torch.nn.MultiheadAttention(embed_dim, num_heads, batch_first=True).cuda()
-        mta_model.eval()
-
-        # Generate 3D results
-        with torch.inference_mode():
-            output_mask_4d = mta_model(query, query, query, attn_mask=attn_mask_4d)[0]
-            output_mask_4d = output_mask_4d.transpose(0, 1)  # [N, T, D]
-
-            output_mask_TxT = mta_model(query, query, query, attn_mask=attn_mask)[0]
-            output_mask_TxT = output_mask_TxT.transpose(0, 1)  # [N, T, D]
-
-            self.assertEqual(output_mask_4d, output_mask_TxT)
-
     def test_multihead_attn_no_bias(self):
         embed_dim = 8
         num_heads = 4
```
156 changes: 156 additions & 0 deletions test/test_transformers.py
@@ -0,0 +1,156 @@
```python
# Owner(s): ["module: nn"]

import contextlib
import torch
import unittest

from torch.testing._internal.common_nn import NNTestCase
from torch.testing._internal.common_utils import run_tests, parametrize, instantiate_parametrized_tests
from torch.testing._internal.common_cuda import TEST_CUDA

@contextlib.contextmanager
def set_default_dtype(dtype):
    saved_dtype = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    try:
        yield
    finally:
        torch.set_default_dtype(saved_dtype)

class TestTransformers(NNTestCase):
    _do_cuda_memory_leak_check = True
    _do_cuda_non_default_stream = True

    device_list = ['cpu']  # TODO: is there a way to do parametrize for this?
    if TEST_CUDA:
        device_list.append('cuda')

    @unittest.skip("4D mask not supported yet - activate when 4D mask supported")
    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")  # TODO: make this work for both cuda and cpu
    def test_self_attn_TxT_attn_mask(self):
        embed_dim = 16
        num_heads = 4
        batch_size = 10
        tgt_len = 16

        query = torch.rand(batch_size, tgt_len, embed_dim, device="cuda")  # [N, T, D]
        attn_mask = torch.randint(0, 2, (tgt_len, tgt_len)).cuda().float()  # [T, T]
        attn_mask = attn_mask.masked_fill(attn_mask == 0, float('-inf')).masked_fill(attn_mask == 1, float(0.0))

        attn_mask_4d = attn_mask.expand(batch_size, num_heads, tgt_len, tgt_len)

        mta_model = torch.nn.MultiheadAttention(embed_dim, num_heads, batch_first=True).cuda()
        mta_model.eval()

        # Generate 3D results
        with torch.inference_mode():
            output_mask_4d = mta_model(query, query, query, attn_mask=attn_mask_4d)[0]
            output_mask_4d = output_mask_4d.transpose(0, 1)  # [N, T, D]

            output_mask_TxT = mta_model(query, query, query, attn_mask=attn_mask)[0]
            output_mask_TxT = output_mask_TxT.transpose(0, 1)  # [N, T, D]

            self.assertEqual(output_mask_4d, output_mask_TxT)

    @parametrize("device", device_list)
    def test_transformerencoderlayer_src_mask(self, device):
        batch_size = 2
        seqlen = 4
        d_model = 8
        nhead = 8
        dim_feedforward = 32

        model = torch.nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            batch_first=True).to(device)
        src = torch.rand(batch_size, seqlen, d_model).to(device)  # bs, seqlen, d_model
        src_mask = torch.zeros(seqlen, seqlen).to(torch.bool).to(device)

        model(src, src_mask=src_mask)
        model.eval()
        with torch.no_grad():
            model(src, src_mask=src_mask)

    @parametrize("use_torchscript", [True, False])
    @parametrize("with_no_grad", [True, False])
    @parametrize("training", [True, False])
    def test_transformerencoder_fastpath_torchscript(self, use_torchscript, with_no_grad, training):
        """
        Test TransformerEncoder does not crash
        """
        model = torch.nn.TransformerEncoder(
            torch.nn.TransformerEncoderLayer(d_model=2, nhead=2, dim_feedforward=8, batch_first=True),
            num_layers=2,
            enable_nested_tensor=True
        )

        if training:
            model = model.train()
        else:
            model = model.eval()

        if use_torchscript:
            model = torch.jit.script(model)

        x = torch.Tensor([[[1, 2], [3, 4]]]).to(torch.float)
        mask = torch.Tensor([[0, 1]]).to(torch.bool)

        if with_no_grad:
            cm = torch.no_grad()
        else:
            cm = contextlib.nullcontext()
        with cm:
            model(x, src_key_padding_mask=mask)

    @parametrize("with_no_grad", [True, False])
    @parametrize("training", [True, False])
    @parametrize("enable_nested_tensor", [False])
    @parametrize("device", device_list)
    def test_transformerencoder_square_input(self, with_no_grad, training, enable_nested_tensor, device):
        """
        Test for edge cases when input of shape (batch size, sequence length, embedding dimension) has
        batch size == sequence length
        """
        model = torch.nn.TransformerEncoder(
            torch.nn.TransformerEncoderLayer(d_model=4, nhead=2, dim_feedforward=16, dropout=0.0, batch_first=True),
            num_layers=2,
            enable_nested_tensor=enable_nested_tensor
        ).to(device)

        with torch.no_grad():
            # set constant weights of the model
            for idx, p in enumerate(model.parameters()):
                x = p.data
                sz = x.view(-1).size(0)
                shape = x.shape
                x = torch.cos(torch.arange(0, sz).float().view(shape))
                p.data.copy_(x)

        if training:
            model = model.train()
        else:
            model = model.eval()
        x = torch.arange(0, 16).reshape(2, 2, 4).to(torch.float).to(device)
        src_mask = torch.Tensor([[0, 1], [0, 0]]).to(torch.bool).to(device)

        if with_no_grad:
            cm = torch.no_grad()
        else:
            cm = contextlib.nullcontext()
        with cm:
            result = model(x, mask=src_mask)

        ref_output = torch.Tensor([[[2.420306205749512, 0.017629241570830, -0.607857942581177, -0.085519507527351],
                                    [2.420306205749512, 0.017629241570830, -0.607857942581177, -0.085519507527351]],
                                   [[2.419836044311523, 0.017548924311996, -0.608187675476074, -0.085347734391689],
                                    [2.419836044311523, 0.017548924311996, -0.608187675476074, -0.085347734391689]]]
                                  ).to(device)
        self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
        torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5)

instantiate_parametrized_tests(TestTransformers)

if __name__ == '__main__':
    run_tests()
```
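
For readers unfamiliar with the internal test helpers: instantiate_parametrized_tests expands every @parametrize decorator on the class into concretely named test methods. A simplified sketch of the idea (illustration only, not torch's actual implementation):

```python
# Rough illustration: one generated, separately named method per parameter value.
def expand_parametrized(cls, name, values):
    base = getattr(cls, name)
    for v in values:
        def generated(self, _v=v, _base=base):
            _base(self, _v)
        setattr(cls, f"{name}_{v}", generated)
    delattr(cls, name)
```
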
8 changes: 4 additions & 4 deletions torch/nn/modules/activation.py
```diff
@@ -1085,10 +1085,10 @@ def forward(self, query: Tensor, key: Tensor, value: Tensor, key_padding_mask: O
             why_not_fast_path = "add_zero_attn was enabled"
         elif not self._qkv_same_embed_dim:
             why_not_fast_path = "_qkv_same_embed_dim was not True"
-        elif query.is_nested and (key_padding_mask is not None or attn_mask is not None):
-            why_not_fast_path = "key_padding_mask and attn_mask are not supported with NestedTensor input"
-        elif not query.is_nested and key_padding_mask is not None and attn_mask is not None:
-            why_not_fast_path = "key_padding_mask and attn_mask were both supplied"
+        elif attn_mask is not None:
+            why_not_fast_path = "attn_mask was not None"
+        elif query.is_nested and key_padding_mask is not None:
+            why_not_fast_path = "key_padding_mask is not supported with NestedTensor input"
 
         if not why_not_fast_path:
             tensor_args = (
```
7 changes: 3 additions & 4 deletions torch/nn/modules/transformer.py
```diff
@@ -411,9 +411,8 @@ def forward(self, src: Tensor, src_mask: Optional[Tensor] = None,
                 self.self_attn.batch_first and
                 self.self_attn._qkv_same_embed_dim and self.activation_relu_or_gelu and
                 self.norm1.eps == self.norm2.eps and
-                ((src_mask is None and src_key_padding_mask is None)
-                 if src.is_nested
-                 else (src_mask is None or src_key_padding_mask is None))):
+                src_mask is None and
+                not (src.is_nested and src_key_padding_mask is not None)):
             tensor_args = (
                 src,
                 self.self_attn.in_proj_weight,
@@ -453,7 +452,7 @@ def forward(self, src: Tensor, src_mask: Optional[Tensor] = None,
                 self.linear1.bias,
                 self.linear2.weight,
                 self.linear2.bias,
-                src_mask if src_mask is not None else src_key_padding_mask,
+                src_mask if src_mask is not None else src_key_padding_mask,  # TODO: split into two args
             )
         x = src
         if self.norm_first:
```
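
Putting the pieces together, a hedged end-to-end check of the new gating (mirrors the tests added above; shapes are illustrative):

```python
import torch

layer = torch.nn.TransformerEncoderLayer(d_model=8, nhead=8, dim_feedforward=32, batch_first=True).eval()
src = torch.rand(2, 4, 8)
src_mask = torch.zeros(4, 4, dtype=torch.bool)

with torch.no_grad():
    out_masked = layer(src, src_mask=src_mask)  # src_mask present -> regular path after this change
    out_plain = layer(src)                      # no mask -> fast path still eligible
```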
