From 18d03331157a8776a51cf7ff08b434ce324d90f9 Mon Sep 17 00:00:00 2001 From: Michael Gschwind Date: Thu, 10 Nov 2022 18:57:08 -0800 Subject: [PATCH] Fix cuda/cpu check on NoneType (#88854) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/88854 Fix cuda/cpu check on NoneType Test Plan: sandcastle/ github CI/CD Differential Revision: D41203955 fbshipit-source-id: f8d14b3e241a149b35266b134ad622b29379a6d1 --- test/test_transformers.py | 9 +++++++++ torch/nn/modules/activation.py | 2 +- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/test/test_transformers.py b/test/test_transformers.py index a9d0d960fb9a603..939d91e7ee87421 100644 --- a/test/test_transformers.py +++ b/test/test_transformers.py @@ -1132,6 +1132,15 @@ def make_tensor(*size, device=device, dtype=dtype): self.assertRaises(RuntimeError, lambda: torch.nn.functional._scaled_dot_product_attention( q, k, v, torch.ones_like(q), 0.0, False, False)) + # Test failing MHA when bias was NoneType + def test_bias_is_none(self): + x = torch.rand((1, 5, 10)) + model = torch.nn.modules.activation.MultiheadAttention(10, 1, bias=False, batch_first=True) + model.eval() + model(x, x, x) + # completes without error + + # TODO: Replace this with instantiate_device_type_tests() to take advantage of test framework support for + # cross device / dtype testing. instantiate_parametrized_tests(TestTransformers) diff --git a/torch/nn/modules/activation.py b/torch/nn/modules/activation.py index 5f5615b496d7d05..2960d793e054818 100644 --- a/torch/nn/modules/activation.py +++ b/torch/nn/modules/activation.py @@ -1114,7 +1114,7 @@ def forward(self, query: Tensor, key: Tensor, value: Tensor, key_padding_mask: O # generator expressions. 
if torch.overrides.has_torch_function(tensor_args): why_not_fast_path = "some Tensor argument has_torch_function" - elif not all([(x.is_cuda or 'cpu' in str(x.device)) for x in tensor_args]): + elif not all([(x is None or x.is_cuda or 'cpu' in str(x.device)) for x in tensor_args]): why_not_fast_path = "some Tensor argument is neither CUDA nor CPU" elif torch.is_grad_enabled() and any([x.requires_grad for x in tensor_args]): why_not_fast_path = ("grad is enabled and at least one of query or the "