Skip to content

Commit

Permalink
Fix cuda/cpu check on NoneType (pytorch#88854)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#88854

Fix cuda/cpu check on NoneType

Test Plan: sandcastle / GitHub CI/CD

Differential Revision: D41203955

fbshipit-source-id: f8d14b3e241a149b35266b134ad622b29379a6d1
  • Loading branch information
mikekgfb authored and facebook-github-bot committed Nov 11, 2022
1 parent 495e7b1 commit 18d0333
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 1 deletion.
9 changes: 9 additions & 0 deletions test/test_transformers.py
Expand Up @@ -1132,6 +1132,15 @@ def make_tensor(*size, device=device, dtype=dtype):
self.assertRaises(RuntimeError, lambda: torch.nn.functional._scaled_dot_product_attention(
q, k, v, torch.ones_like(q), 0.0, False, False))

# Test failing MHA when bias was NoneType
def test_bias_is_none(self):
x = torch.rand((1, 5, 10))
model = torch.nn.modules.activation.MultiheadAttention(10, 1, bias=False, batch_first=True)
model.eval()
model(x, x, x)
# completes without error


# TODO: Replace this with instantiate_device_type_tests() to take advantage of test framework support for
# cross device / dtype testing.
instantiate_parametrized_tests(TestTransformers)
Expand Down
2 changes: 1 addition & 1 deletion torch/nn/modules/activation.py
Expand Up @@ -1114,7 +1114,7 @@ def forward(self, query: Tensor, key: Tensor, value: Tensor, key_padding_mask: O
# generator expressions.
if torch.overrides.has_torch_function(tensor_args):
why_not_fast_path = "some Tensor argument has_torch_function"
elif not all([(x.is_cuda or 'cpu' in str(x.device)) for x in tensor_args]):
elif not all([(x is None or x.is_cuda or 'cpu' in str(x.device)) for x in tensor_args]):
why_not_fast_path = "some Tensor argument is neither CUDA nor CPU"
elif torch.is_grad_enabled() and any([x.requires_grad for x in tensor_args]):
why_not_fast_path = ("grad is enabled and at least one of query or the "
Expand Down

0 comments on commit 18d0333

Please sign in to comment.