Fix cuda/cpu check on NoneType (Unit test) (pytorch#88970)
Summary: Fix cuda/cpu check on NoneType (unit test)

Test Plan: sandcastle / GitHub CI/CD

Differential Revision: D41208798

Pull Request resolved: pytorch#88970
Approved by: https://github.com/Skylion007, https://github.com/cpuhrsch
mikekgfb authored and kulinseth committed Dec 9, 2022
1 parent c3f5c8d commit 4681335
Showing 2 changed files with 10 additions and 1 deletion.
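
For context, the failure this commit addresses: constructing MultiheadAttention with bias=False leaves in_proj_bias and out_proj.bias as None, and the fast-path eligibility check in forward() evaluated x.requires_grad over those None entries, raising AttributeError: 'NoneType' object has no attribute 'requires_grad'. A minimal repro sketch (runs cleanly after this fix; the failure described applies to pre-fix builds):

import torch

# Sketch of the fixed scenario: bias=False makes the projection biases None,
# which the fast-path checks in MultiheadAttention.forward iterate over.
x = torch.rand((1, 5, 10))
mha = torch.nn.MultiheadAttention(10, 1, bias=False, batch_first=True)
mha.eval()

out, attn_weights = mha(x, x, x)  # raised AttributeError before this fix
print(out.shape)  # torch.Size([1, 5, 10])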
9 changes: 9 additions & 0 deletions test/test_transformers.py
@@ -1168,6 +1168,15 @@ def make_tensor(*size, device=device, dtype=dtype):
        self.assertRaises(RuntimeError, lambda: torch.nn.functional._scaled_dot_product_attention(
            q, k, v, torch.ones_like(q), 0.0, False, False))

+    # Regression test: MultiheadAttention forward should not fail when bias is None
+    def test_bias_is_none(self):
+        x = torch.rand((1, 5, 10))
+        model = torch.nn.modules.activation.MultiheadAttention(10, 1, bias=False, batch_first=True)
+        model.eval()
+        model(x, x, x)
+        # completes without error


# TODO: Replace this with instantiate_device_type_tests() to take advantage of test framework support for
# cross device / dtype testing.
instantiate_parametrized_tests(TestTransformers)
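
The test pins the regression through the bias=False configuration because that is exactly what produces None tensor arguments on the fast path; a quick check of that precondition (an illustrative sketch, not part of the commit):

import torch

# bias=False registers both projection biases as None, so the tensor_args
# list inspected by the fast-path checks genuinely contains None entries.
mha = torch.nn.MultiheadAttention(10, 1, bias=False)
print(mha.in_proj_bias)   # None
print(mha.out_proj.bias)  # None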
2 changes: 1 addition & 1 deletion torch/nn/modules/activation.py
@@ -1113,7 +1113,7 @@ def forward(self, query: Tensor, key: Tensor, value: Tensor, key_padding_mask: O
            why_not_fast_path = "some Tensor argument has_torch_function"
        elif not all([(x is None or x.is_cuda or 'cpu' in str(x.device)) for x in tensor_args]):
            why_not_fast_path = "some Tensor argument is neither CUDA nor CPU"
-        elif torch.is_grad_enabled() and any([x.requires_grad for x in tensor_args]):
+        elif torch.is_grad_enabled() and any([x is not None and x.requires_grad for x in tensor_args]):
            why_not_fast_path = ("grad is enabled and at least one of query or the "
                                 "input/output projection weights or biases requires_grad")
        if not why_not_fast_path:
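
The change itself is the usual None-safe short-circuit: check x is not None before touching x.requires_grad. A standalone sketch of the pattern with a hypothetical tensor_args list (not the module's actual one):

import torch

q = torch.rand(2, 3, requires_grad=True)
tensor_args = [q, None, torch.rand(2, 3)]  # biases may legitimately be None

# Old form: any([x.requires_grad for x in tensor_args]) builds the list
# eagerly and raises AttributeError at the None entry.
# Fixed form: `and` short-circuits, so None entries are skipped safely.
needs_grad = torch.is_grad_enabled() and any(
    x is not None and x.requires_grad for x in tensor_args
)
print(needs_grad)  # True, with no exception on the None entry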
