diff --git a/torch/nn/modules/activation.py b/torch/nn/modules/activation.py index 666c31874f9a30b..c4ad42e11606565 100644 --- a/torch/nn/modules/activation.py +++ b/torch/nn/modules/activation.py @@ -1116,7 +1116,7 @@ def forward(self, query: Tensor, key: Tensor, value: Tensor, key_padding_mask: O # generator expressions. if torch.overrides.has_torch_function(tensor_args): why_not_fast_path = "some Tensor argument has_torch_function" - elif not all([(x.is_cuda or 'cpu' in str(x.device)) for x in tensor_args]): + elif not all([(x is None or x.is_cuda or 'cpu' in str(x.device)) for x in tensor_args]): why_not_fast_path = "some Tensor argument is neither CUDA nor CPU" elif torch.is_grad_enabled() and any([x.requires_grad for x in tensor_args]): why_not_fast_path = ("grad is enabled and at least one of query or the "