Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix cuda/cpu check on NoneType (Unit test) (#88970) #89934

Merged
merged 1 commit into from Dec 6, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
9 changes: 9 additions & 0 deletions test/test_transformers.py
Expand Up @@ -970,6 +970,15 @@ def make_tensor(*size, device=device, dtype=dtype):
self.assertRaises(RuntimeError, lambda: torch.nn.functional._scaled_dot_product_attention(
q, k, v, torch.ones_like(q), 0.0, False, False))

# Test failing MHA when bias was NoneType
def test_bias_is_none(self):
x = torch.rand((1, 5, 10))
model = torch.nn.modules.activation.MultiheadAttention(10, 1, bias=False, batch_first=True)
model.eval()
model(x, x, x)
# completes without error


# TODO: Replace this with instantiate_device_type_tests() to take advantage of test framework support for
# cross device / dtype testing.
instantiate_parametrized_tests(TestTransformers)
Expand Down
2 changes: 1 addition & 1 deletion torch/nn/modules/activation.py
Expand Up @@ -1118,7 +1118,7 @@ def forward(self, query: Tensor, key: Tensor, value: Tensor, key_padding_mask: O
why_not_fast_path = "some Tensor argument has_torch_function"
elif not all([(x.is_cuda or 'cpu' in str(x.device)) for x in tensor_args]):
why_not_fast_path = "some Tensor argument is neither CUDA nor CPU"
elif torch.is_grad_enabled() and any([x.requires_grad for x in tensor_args]):
elif torch.is_grad_enabled() and any([x is not None and x.requires_grad for x in tensor_args]):
why_not_fast_path = ("grad is enabled and at least one of query or the "
"input/output projection weights or biases requires_grad")
if not why_not_fast_path:
Expand Down