
torch.nn.MultiheadAttention with thunder.jit error #287

Closed
Fuzzkatt opened this issue Apr 26, 2024 · 1 comment · Fixed by #319
Labels
bug Something isn't working

Comments

@Fuzzkatt
Collaborator

I'm trying to trace torch.nn.MultiheadAttention with thunder and I'm hitting an AttributeError: The torch language context has no method or attribute is_nested. Taking a closer look, it's coming from https://github.com/pytorch/pytorch/blob/main/torch/nn/modules/activation.py#L1236-L1238, where MHA checks that the input tensors are not nested. If I comment out those lines, the rest of MHA works fine, and thunder is able to trace and run both fwd_only and with bwd.

Minimal repro:

import torch
import thunder

model = torch.nn.MultiheadAttention(128, 16, device='cuda')
q = torch.randn(100, 128, device='cuda')
k = torch.randn(100, 128, device='cuda')
v = torch.randn(100, 128, device='cuda')
jfunc = thunder.jit(model)
jfunc(q, k, v)
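For reference, a plain eager-mode run of the same module (no thunder, and on CPU so it runs anywhere) goes through fine, which matches the observation above that only the is_nested check trips up tracing:

```python
import torch

# Eager-mode sanity check: the same unbatched (seq_len, embed_dim) call
# succeeds in plain PyTorch, so the failure is specific to thunder tracing.
model = torch.nn.MultiheadAttention(128, 16)
q = torch.randn(100, 128)
k = torch.randn(100, 128)
v = torch.randn(100, 128)

# Unbatched inputs: output is (L, E); attention weights averaged over
# heads come back as (L, S).
out, weights = model(q, k, v)
print(out.shape)      # torch.Size([100, 128])
print(weights.shape)  # torch.Size([100, 100])
```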
@Fuzzkatt Fuzzkatt added the bug Something isn't working label Apr 26, 2024
@mruberry
Collaborator

triage review –

following #93 (comment), we should just have is_nested return False here
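A minimal sketch of the suggested fix, using a hypothetical stand-in class (thunder's actual proxy type and registration mechanism may differ): since thunder only ever traces regular, non-nested tensors, the tensor proxy can expose is_nested as a constant False, which lets MHA's guard pass without special handling.

```python
class TensorProxy:
    """Hypothetical stand-in for thunder's tensor proxy.

    thunder traces regular (non-nested) tensors only, so the proxy can
    statically report is_nested=False and the check in
    torch/nn/modules/activation.py passes during tracing.
    """

    @property
    def is_nested(self) -> bool:
        return False


t = TensorProxy()
print(t.is_nested)  # False, so `not t.is_nested` guards succeed
```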
