
Runtime Error raised by torch.nn.modules.activation.MultiheadAttention when bias=False, batch_first=True #88669

Closed
shakedbr opened this issue Nov 8, 2022 · 9 comments
Labels
oncall: transformer/mha Issues related to Transformers and MultiheadAttention
triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module
Milestone
1.13.1

Comments

@shakedbr

shakedbr commented Nov 8, 2022

🐛 Describe the bug

Hi,

When you create a torch.nn.modules.activation.MultiheadAttention module with bias=False and batch_first=True, switch it to evaluation mode, and call the forward pass, an exception is raised:

import torch

x = torch.rand((1, 5, 10))
model = torch.nn.modules.activation.MultiheadAttention(10, 1, bias=False, batch_first=True)
model.eval()
model(x, x, x)
Traceback (most recent call last):
  File "/Users/test.py", line 376, in <module>
    model(x,x,x)
  File "/opt/homebrew/Caskroom/miniforge/base/envs/py39/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/opt/homebrew/Caskroom/miniforge/base/envs/py39/lib/python3.9/site-packages/torch/nn/modules/activation.py", line 1107, in forward
    elif not all([(x.is_cuda or 'cpu' in str(x.device)) for x in tensor_args]):
  File "/opt/homebrew/Caskroom/miniforge/base/envs/py39/lib/python3.9/site-packages/torch/nn/modules/activation.py", line 1107, in <listcomp>
    elif not all([(x.is_cuda or 'cpu' in str(x.device)) for x in tensor_args]):
AttributeError: 'NoneType' object has no attribute 'is_cuda'

It seems that the following lines don't handle the case where a parameter is None (with bias=False, both in_proj_bias and out_proj.bias are None, and both are included in tensor_args).

elif not all([(x.is_cuda or 'cpu' in str(x.device)) for x in tensor_args]):
    why_not_fast_path = "some Tensor argument is neither CUDA nor CPU"
elif torch.is_grad_enabled() and any([x.requires_grad for x in tensor_args]):
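
For illustration, a None-tolerant version of that device check could look something like the sketch below. This is only a guess at the direction of a fix (the helper name is made up here), not the change that actually landed in PyTorch:

def tensors_on_cpu_or_cuda(tensor_args):
    # Hypothetical helper: skip None entries (in_proj_bias and out_proj.bias are None
    # when bias=False) instead of probing .is_cuda on them, which is what raises the
    # AttributeError above.
    return all(
        t.is_cuda or 'cpu' in str(t.device)
        for t in tensor_args
        if t is not None
    )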

Versions

[pip3] numpy==1.23.4
[pip3] torch==1.12.1
[pip3] torch-scatter==2.0.9
[pip3] torchaudio==0.12.1
[pip3] torchvision==0.2.2
[conda] numpy 1.23.4 py39hefdcf20_0 conda-forge
[conda] pytorch 1.12.1 py3.9_0 pytorch
[conda] torch-scatter 2.0.9 pypi_0 pypi
[conda] torchaudio 0.12.1 py39_cpu pytorch
[conda] torchvision 0.2.2 py_3 pytorch

cc @jbschlosser @bhosmer @cpuhrsch @erichan1

@shakedbr shakedbr changed the title Runtime Error raised by torch.nn.modules.activation.MultiheadAttention when bias=True, batch_first=True Runtime Error raised by torch.nn.modules.activation.MultiheadAttention when bias=False, batch_first=True Nov 8, 2022
@albanD albanD added the triaged and oncall: transformer/mha labels Nov 10, 2022
@cpuhrsch
Contributor

Thank you for opening the issue @shakedbr - does this issue persist with newer versions of PyTorch or nightlies?

@mikekgfb mikekgfb added this to the 1.13.1 milestone Nov 11, 2022
@shakedbr
Author

@cpuhrsch, this also happens in version 1.13.0 and in the nightly version 1.14.0.dev20221113, but in those versions you need an even number of heads to reproduce the bug, e.g.:

import torch

x = torch.rand((1, 5, 10))
model = torch.nn.modules.activation.MultiheadAttention(10, num_heads=2, bias=False, batch_first=True)
model.eval()
model(x, x, x)
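
Until a fixed build is available, one possible workaround (assuming the crashing fast-path check is only attempted when batch_first=True, which is what the check order in activation.py suggests) is to keep the default batch_first=False and transpose the inputs yourself:

import torch

x = torch.rand((1, 5, 10))  # (batch, seq, embed)
model = torch.nn.MultiheadAttention(10, num_heads=2, bias=False)  # batch_first defaults to False
model.eval()
# Transpose to (seq, batch, embed) for the module, then back for the caller.
out, attn_weights = model(x.transpose(0, 1), x.transpose(0, 1), x.transpose(0, 1))
out = out.transpose(0, 1)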

@cpuhrsch
Contributor

Thanks for checking @shakedbr - @mikekgfb has sent a fix and it looks like it'll be included in 1.13.1.

@malfet
Contributor

malfet commented Nov 28, 2022

@mikekgfb can you please link the fix to the issue?

@weiwangmeta
Contributor

Can this be closed given the above PR? cc @mikekgfb @atalman @malfet

@mikekgfb
Contributor

mikekgfb commented Dec 2, 2022

Also needs #88854

@weiwangmeta
Contributor

Also needs #88854

Thank you Michael. #89855 (comment) is the cherry-pick to release/1.13

@atalman
Contributor

atalman commented Dec 13, 2022

Closing since the cherry-pick is included in the release.

@atalman atalman closed this as completed Dec 13, 2022