max_pool2d: Shape propagation error for input dimensions used by resnet #365

Open
kiya00 opened this issue May 6, 2024 · 2 comments
Labels: bug, operators

Comments

@kiya00
Collaborator

kiya00 commented May 6, 2024

🐛 Bug

import torch
import torch.nn.functional as F
import thunder

a = torch.randn(1, 64, 112, 112).cuda().requires_grad_()
def func(a):
    return F.max_pool2d(a, 3, 2, 1, 1, False, False)  # t79: "cuda:0 f32[1, 64, 56, 56]"
cfunc = thunder.jit(func)
b = cfunc(a)

print(thunder.last_traces(cfunc)[-1].output[0]['output'].shape)
print(b.shape)

Outputs:

(56, 56)
torch.Size([1, 64, 56, 56])

The output shape recorded in the trace is wrong: it is (56, 56), while the actual output has shape (1, 64, 56, 56). The function itself still runs successfully.

Trace:

# Constructed by Delete Last Used (took 0 milliseconds)
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def augmented_forward_fn(a):
  # a: "cuda:0 f32[1, 64, 112, 112]"
  (t0, t1) = max_pool2d_with_indices(a, 3, 2, 1, 1, False)
  return {'output': t0, 'flat_args': [a], 'flat_output': (t0,)}, ((a, t1), (False, 3, 2, 1, 1))

cc @apaz-cli

@nikitaved
Contributor

nikitaved commented May 6, 2024

OK, it looks like max_pool2d_with_indices comes from #163. max_pool2d without indices already has a well-tested meta-function, which could be reused here.
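For reference, the shape such a meta-function would be expected to produce can be sketched in plain Python. This is only an illustration of the standard pooling output-size formula, not thunder's actual meta-function; the names `max_pool2d_output_shape`, `_pair`, and `_pooled_size` are hypothetical helpers introduced here.

```python
from typing import Sequence, Tuple, Union

IntOrPair = Union[int, Sequence[int]]


def _pair(value: IntOrPair) -> Tuple[int, int]:
    """Expand a single int to (int, int); pass a 2-sequence through."""
    if isinstance(value, int):
        return (value, value)
    first, second = value
    return (int(first), int(second))


def _pooled_size(size: int, kernel: int, stride: int, padding: int,
                 dilation: int, ceil_mode: bool) -> int:
    """Output length along one spatial dimension."""
    numerator = size + 2 * padding - dilation * (kernel - 1) - 1
    if not ceil_mode:
        return numerator // stride + 1
    out = -(-numerator // stride) + 1  # ceiling division
    # The last window must start inside the input or the left padding.
    if (out - 1) * stride >= size + padding:
        out -= 1
    return out


def max_pool2d_output_shape(input_shape: Sequence[int], kernel_size: IntOrPair,
                            stride: Union[IntOrPair, None] = None,
                            padding: IntOrPair = 0, dilation: IntOrPair = 1,
                            ceil_mode: bool = False) -> Tuple[int, ...]:
    """Hypothetical helper: full output shape, keeping batch/channel dims."""
    if len(input_shape) not in (3, 4):
        raise ValueError("expected a (C, H, W) or (N, C, H, W) shape")
    kh, kw = _pair(kernel_size)
    sh, sw = _pair(kernel_size if stride is None else stride)
    ph, pw = _pair(padding)
    dh, dw = _pair(dilation)
    *leading, h, w = input_shape
    out_h = _pooled_size(h, kh, sh, ph, dh, ceil_mode)
    out_w = _pooled_size(w, kw, sw, pw, dw, ceil_mode)
    # Leading dimensions are preserved; only H and W change.
    return (*leading, out_h, out_w)


print(max_pool2d_output_shape((1, 64, 112, 112), 3, 2, 1, 1, False))
```

The key point for this bug is the last line of `max_pool2d_output_shape`: the leading batch and channel dimensions are carried through, whereas the trace above reports only the two pooled spatial dimensions.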

@mruberry
Collaborator

triage review: we should also test that the metadata thunder produces is consistent with the actual output
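A minimal sketch of what such a consistency check might look like, written without any thunder or torch dependency. The helper name `assert_metadata_matches` is hypothetical and is not part of thunder's test suite; it simply compares a shape taken from trace metadata with the shape of the tensor actually returned.

```python
from typing import Sequence


def assert_metadata_matches(predicted_shape: Sequence[int],
                            actual_shape: Sequence[int],
                            name: str = "output") -> None:
    """Hypothetical check: trace metadata shape must equal the runtime shape."""
    predicted = tuple(int(d) for d in predicted_shape)
    actual = tuple(int(d) for d in actual_shape)
    if predicted != actual:
        raise AssertionError(
            f"{name}: trace metadata shape {predicted} "
            f"does not match runtime shape {actual}"
        )


# Shapes reported in this issue: the check fails as intended.
try:
    assert_metadata_matches((56, 56), (1, 64, 56, 56))
except AssertionError as err:
    print(err)
```

With the repro from this issue, the two arguments would be `thunder.last_traces(cfunc)[-1].output[0]['output'].shape` and `b.shape`. Converting both to tuples first lets a `torch.Size` be compared against a plain tuple.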
