Check for contiguous well-formed mask (#79927)
Summary: Check for contiguous well-formed mask

Test Plan: Sandcastle, GitHub CI

Reviewed By: frank-wei

Differential Revision: D37301243

Pull Request resolved: #79927
Approved by: https://github.com/jbschlosser
mikekgfb authored and erichan1 committed Jul 21, 2022
1 parent 37b49cf commit ae79b3e
Showing 2 changed files with 22 additions and 18 deletions.
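
Background for the change: the nested-tensor fast path in nn.TransformerEncoder only handles key padding masks whose padding forms a contiguous suffix of every row, a "left-aligned" (contiguous well-formed) mask. The snippet below only illustrates that property; the mask values and the pure-Python check are assumptions for exposition, while the PR itself relies on the private helper torch._nested_tensor_from_mask_left_aligned.

import torch

# True marks a padded position; valid tokens sit at the front of each row.
left_aligned = torch.tensor([[False, False, False, True, True],
                             [False, False, True, True, True]])

# Padding interleaved with valid tokens: not a contiguous well-formed mask.
not_left_aligned = torch.tensor([[False, True, False, False, True],
                                 [False, False, True, False, True]])

def is_left_aligned(key_padding_mask: torch.Tensor) -> bool:
    # A row is left-aligned when its padding flags never go from True back
    # to False, i.e. they are monotonically non-decreasing along the row.
    m = key_padding_mask.int()
    return bool(torch.all(m[:, :-1] <= m[:, 1:]))

print(is_left_aligned(left_aligned))      # True
print(is_left_aligned(not_left_aligned))  # False
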
29 changes: 16 additions & 13 deletions test/test_nn.py
@@ -7906,7 +7906,8 @@ def get_a_test_layer(use_cuda, activation, batch_first=False):
                 dim_feedforward=dim_feedforward,
                 dropout=dropout,
                 activation=activation,
-                batch_first=batch_first).to(device)
+                batch_first=batch_first,
+            ).to(device)

             with torch.no_grad():
                 # set constant weights of the model
@@ -7924,7 +7925,7 @@ def get_a_test_layer(use_cuda, activation, batch_first=False):
         use_cuda = torch.cuda.is_available()
         device = torch.device("cuda" if use_cuda else "cpu")

-        def _test(batch_first, training):
+        def _test(batch_first, training, enable_nested_tensor):
             def perm_fn(x):
                 return x.transpose(1, 0) if batch_first else x
@@ -7972,7 +7973,7 @@ def perm_fn(x):
             mask[1, 4] = 1
             # If mask is not left aligned
             # We disable nested tensor
-            model.enable_nested_tensor = False
+            model.enable_nested_tensor = enable_nested_tensor
             result = model(encoder_input, src_key_padding_mask=mask)
             ref_output = perm_fn(torch.tensor([[[2.429026, 0.020793, -0.601741, -0.085642],
                                                 [2.428811, 0.021445, -0.601912, -0.084252]],
@@ -7989,7 +7990,7 @@ def perm_fn(x):
             torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5)

             # test case 2, multiple layers no norm
-            model = nn.TransformerEncoder(encoder_layer, 2, enable_nested_tensor=False).to(device)
+            model = nn.TransformerEncoder(encoder_layer, 2, enable_nested_tensor=enable_nested_tensor).to(device)
             if not training:
                 model = model.eval()
             result = model(encoder_input, src_key_padding_mask=mask)
@@ -8007,7 +8008,7 @@ def perm_fn(x):
             self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
             torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5)

-            model = nn.TransformerEncoder(encoder_layer, 6, enable_nested_tensor=False).to(device)
+            model = nn.TransformerEncoder(encoder_layer, 6, enable_nested_tensor=enable_nested_tensor).to(device)
             if not training:
                 model = model.eval()
             result = model(encoder_input, src_key_padding_mask=mask)
@@ -8028,7 +8029,7 @@ def perm_fn(x):
             # test case 3, multiple layers with norm
             # d_model = 4
             norm = nn.LayerNorm(4)
-            model = nn.TransformerEncoder(encoder_layer, 2, norm=norm, enable_nested_tensor=False).to(device)
+            model = nn.TransformerEncoder(encoder_layer, 2, norm=norm, enable_nested_tensor=enable_nested_tensor).to(device)
             if not training:
                 model = model.eval()
             result = model(encoder_input, src_key_padding_mask=mask)
@@ -8063,15 +8064,17 @@ def perm_fn(x):
                                               )).to(device)
             self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
             torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5)

         for batch_first in (True, False):
             for training in (True, False):
-                # Fast path requires inference mode.
-                if training:
-                    cm = contextlib.nullcontext()
-                else:
-                    cm = torch.no_grad()
-                with cm:
-                    _test(batch_first, training)
+                for enable_nested_tensor in (True, False):
+                    # Fast path requires inference mode.
+                    if training:
+                        cm = contextlib.nullcontext()
+                    else:
+                        cm = torch.no_grad()
+                    with cm:
+                        _test(batch_first, training, enable_nested_tensor)

     def test_transformerdecoder(self):
         def get_a_test_layer(use_cuda, activation, batch_first=False):
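
A rough sketch of the scenario the updated test now parameterizes over (layer sizes, tensor shapes, and values here are illustrative assumptions, not taken from the test): with enable_nested_tensor toggled, eval mode, and no_grad, an encoder fed a left-aligned mask may convert the batch to a nested tensor internally, while the returned output stays an ordinary padded tensor (here with the same shape as the input, since one row has no padding).

import torch
import torch.nn as nn

layer = nn.TransformerEncoderLayer(d_model=4, nhead=2, dim_feedforward=16,
                                   batch_first=True)

for enable_nested_tensor in (True, False):
    # TransformerEncoder deep-copies the layer, so reusing it here is fine.
    model = nn.TransformerEncoder(layer, num_layers=2,
                                  enable_nested_tensor=enable_nested_tensor).eval()
    src = torch.randn(2, 5, 4)
    mask = torch.zeros(2, 5, dtype=torch.bool)
    mask[1, 3:] = True  # padding only at the end of the row: left-aligned
    with torch.no_grad():
        out = model(src, src_key_padding_mask=mask)
    assert out.shape == (2, 5, 4)
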
11 changes: 6 additions & 5 deletions torch/nn/modules/transformer.py
@@ -203,12 +203,15 @@ def forward(self, src: Tensor, mask: Optional[Tensor] = None, src_key_padding_ma
         output = src
         convert_to_nested = False
         first_layer = self.layers[0]
+        src_key_padding_mask_for_layers = src_key_padding_mask
         if isinstance(first_layer, torch.nn.TransformerEncoderLayer):
             if (not first_layer.norm_first and not first_layer.training and
                     first_layer.self_attn.batch_first and
                     first_layer.self_attn._qkv_same_embed_dim and first_layer.activation_relu_or_gelu and
                     first_layer.norm1.eps == first_layer.norm2.eps and
-                    src.dim() == 3 and self.enable_nested_tensor) :
+                    src.dim() == 3 and self.enable_nested_tensor and
+                    src_key_padding_mask is not None and
+                    torch._nested_tensor_from_mask_left_aligned(src, src_key_padding_mask.logical_not())):
                 if src_key_padding_mask is not None and not output.is_nested and mask is None:
                     tensor_args = (
                         src,
@@ -230,12 +233,10 @@ def forward(self, src: Tensor, mask: Optional[Tensor] = None, src_key_padding_ma
                     if output.is_cuda or 'cpu' in str(output.device):
                         convert_to_nested = True
                         output = torch._nested_tensor_from_mask(output, src_key_padding_mask.logical_not())
+                        src_key_padding_mask_for_layers = None

         for mod in self.layers:
-            if convert_to_nested:
-                output = mod(output, src_mask=mask)
-            else:
-                output = mod(output, src_mask=mask, src_key_padding_mask=src_key_padding_mask)
+            output = mod(output, src_mask=mask, src_key_padding_mask=src_key_padding_mask_for_layers)

         if convert_to_nested:
             output = output.to_padded_tensor(0.)
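
The practical effect of the new gating, sketched under assumed shapes and values (not code from the PR): when the padding mask is not left-aligned, the encoder now skips the nested-tensor conversion and runs the padded-tensor path with the mask forwarded to every layer, so the call still succeeds.

import torch
import torch.nn as nn

layer = nn.TransformerEncoderLayer(d_model=4, nhead=2, dim_feedforward=16,
                                   batch_first=True)
model = nn.TransformerEncoder(layer, num_layers=2, enable_nested_tensor=True).eval()

src = torch.randn(2, 5, 4)
ragged_mask = torch.zeros(2, 5, dtype=torch.bool)
ragged_mask[1, 2] = True  # padding in the middle of the row: not left-aligned

with torch.no_grad():
    # torch._nested_tensor_from_mask_left_aligned(...) is False for this mask,
    # so the input is not converted to a nested tensor and the padded-tensor
    # path is used instead.
    out = model(src, src_key_padding_mask=ragged_mask)
print(out.shape)  # torch.Size([2, 5, 4])
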
