From 1f88b208acab2cf974849c9161d24f08486f592c Mon Sep 17 00:00:00 2001
From: Michael Gschwind
Date: Tue, 15 Nov 2022 01:25:17 +0000
Subject: [PATCH] Fix cuda/cpu check on NoneType (Unit test) (#88970)

Summary: Fix cuda/cpu check on NoneType (unit test)

Test Plan: sandcastle / github CI/CD

Differential Revision: D41208798

Pull Request resolved: https://github.com/pytorch/pytorch/pull/88970
Approved by: https://github.com/Skylion007, https://github.com/cpuhrsch
---
 test/test_transformers.py      | 9 +++++++++
 torch/nn/modules/activation.py | 2 +-
 2 files changed, 10 insertions(+), 1 deletion(-)

diff --git a/test/test_transformers.py b/test/test_transformers.py
index c86b89bed5efd0c..93a94a5604c919e 100644
--- a/test/test_transformers.py
+++ b/test/test_transformers.py
@@ -1168,6 +1168,15 @@ def make_tensor(*size, device=device, dtype=dtype):
         self.assertRaises(RuntimeError, lambda: torch.nn.functional._scaled_dot_product_attention(
             q, k, v, torch.ones_like(q), 0.0, False, False))
 
+    # Test failing MHA when bias was NoneType
+    def test_bias_is_none(self):
+        x = torch.rand((1, 5, 10))
+        model = torch.nn.modules.activation.MultiheadAttention(10, 1, bias=False, batch_first=True)
+        model.eval()
+        model(x, x, x)
+        # completes without error
+
+
 # TODO: Replace this with instantiate_device_type_tests() to take advantage of test framework support for
 # cross device / dtype testing.
 instantiate_parametrized_tests(TestTransformers)

diff --git a/torch/nn/modules/activation.py b/torch/nn/modules/activation.py
index e6b3b778e5fbcb7..b00da06126a7a99 100644
--- a/torch/nn/modules/activation.py
+++ b/torch/nn/modules/activation.py
@@ -1113,7 +1113,7 @@ def forward(self, query: Tensor, key: Tensor, value: Tensor, key_padding_mask: O
             why_not_fast_path = "some Tensor argument has_torch_function"
         elif not all([(x is None or x.is_cuda or 'cpu' in str(x.device)) for x in tensor_args]):
             why_not_fast_path = "some Tensor argument is neither CUDA nor CPU"
-        elif torch.is_grad_enabled() and any([x.requires_grad for x in tensor_args]):
+        elif torch.is_grad_enabled() and any([x is not None and x.requires_grad for x in tensor_args]):
             why_not_fast_path = ("grad is enabled and at least one of query or the "
                                  "input/output projection weights or biases requires_grad")
         if not why_not_fast_path:
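
For context, a minimal standalone reproduction of the failure this patch addresses (a sketch, not part of the commit; shapes and variable names are illustrative). Constructing MultiheadAttention with bias=False leaves in_proj_bias and out_proj.bias as None, and the fast-path gate in forward() previously evaluated x.requires_grad on every entry of tensor_args, including those None biases:

    import torch

    # bias=False leaves in_proj_bias and out_proj.bias as None, which is
    # what tripped the fast-path check; embed_dim=10, one head is arbitrary.
    mha = torch.nn.MultiheadAttention(10, 1, bias=False, batch_first=True)
    mha.eval()  # the fast-path gate is only evaluated outside training mode

    x = torch.rand(1, 5, 10)  # (batch, seq, embed_dim), self-attention below
    # Before this patch (with grad enabled, the default), the gate computed
    # x.requires_grad for every tensor_args entry and raised:
    #   AttributeError: 'NoneType' object has no attribute 'requires_grad'
    # With the fix, the None biases are skipped and the call falls back to
    # the regular (slow) path.
    out, attn_weights = mha(x, x, x)
    print(out.shape)  # torch.Size([1, 5, 10])

This is also exactly what the new test_bias_is_none asserts: the forward call simply completes without error.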