LitGPTSDPABenchmark runs incorrect configs #317

Closed
vedaanta opened this issue Apr 30, 2024 · 4 comments · Fixed by #378

vedaanta commented Apr 30, 2024

🐛 Bug

LitGPTSDPABenchmark, located here, does not correctly run the litgpt config provided to it.
This is because it uses NanoGPTConfig underneath, located here.

This creates multiple issues when benchmarking, as an incorrect SDPA operation is launched. For example:

  1. litgpt never enables dropout, but NanoGPTConfig defaults dropout to 0.1, so every SDPA operation is launched with dropout on (see the sketch after this list).
  2. When running GQA sizes via config.n_query_groups, the value is not propagated into the model.
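
To illustrate the first point, here is a minimal sketch of the mismatch using plain torch.nn.functional.scaled_dot_product_attention, with shapes shrunk from the Llama-2-7b-hf repro below; the dropout values are the point, not the exact shapes:

import torch
import torch.nn.functional as F

# Small stand-ins for the (16, 32, 4096, 128) bf16 tensors in the repro below.
q = torch.randn(2, 4, 64, 16)
k = torch.randn(2, 4, 64, 16)
v = torch.randn(2, 4, 64, 16)

# What the benchmark currently launches: NanoGPTConfig's default dropout of 0.1.
out_benchmark = F.scaled_dot_product_attention(q, k, v, dropout_p=0.1, is_causal=True)

# What litgpt semantics call for: dropout is never enabled, so dropout_p should be 0.0.
out_expected = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=True)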

To Reproduce

from thunder.benchmarks import LitGPTSDPABenchmark, Benchmark
import thunder
import torch

bench: Benchmark = LitGPTSDPABenchmark(
    config="Llama-2-7b-hf", device="cuda:0", dtype=thunder.bfloat16, requires_grad=True
)

args, kwargs = bench.make_batch()
torch.cuda.synchronize()
fn = bench.fn()

jfn = thunder.jit(fn)
jfn(*args, **kwargs)

print(thunder.last_traces(jfn)[-1])

Outputs:

# Constructed by Delete Last Used (took 0 milliseconds)
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def augmented_forward_fn(q, k, v):
  # q: "cuda:0 bf16[16, 32, 4096, 128]"
  # k: "cuda:0 bf16[16, 32, 4096, 128]"
  # v: "cuda:0 bf16[16, 32, 4096, 128]"
  (t0, t1, t2, t3, _, _, t4, t5, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(q, k, v, 0.1, True, scale=None)
  return {'output': t0, 'flat_args': [q, k, v], 'flat_output': (t0,)}, ((k, q, t0, t1, t2, t3, t4, t5, v), (True, 0.1))

Notice is_causal=True and dropout_p=0.1.
The litgpt model has no code path that enables both of them (code).

Expected behavior

LitGPTSDPABenchmark should run with the exact parameters specified by the litgpt configs.
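
One way to sanity-check a fix (a sketch, not part of the original report; matching on the printed trace string is an assumption about its formatting and may need adjusting):

# Reusing jfn, args, kwargs from the reproduction above.
jfn(*args, **kwargs)
trace_str = str(thunder.last_traces(jfn)[-1])
# litgpt never enables dropout, so a fixed benchmark should not trace SDPA with dropout_p=0.1.
assert " 0.1," not in trace_str, "benchmark still launches SDPA with dropout_p=0.1"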

cc @crcrpar @carmocca

mruberry commented May 6, 2024

triage review — @riccardofelluga, is this something you'd like to look at?

@riccardofelluga

Sure! If it's not urgent, I can give it a look as soon as I have a minute.

carmocca commented May 7, 2024

I can help here too

carmocca commented May 7, 2024

Running GQA sizes, with config.n_query_groups, it is not propagated into the model.

Note that litgpt expands this dimension during training for flash attention support: https://github.com/Lightning-AI/litgpt/blob/main/litgpt/model.py#L226-L231. So the n_query_groups argument doesn't impact the SDPA call; see the sketch below.
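
For context, a rough sketch of that expansion (illustrative shapes; repeat_interleave stands in for litgpt's actual reshape-and-expand): the grouped key/value heads are broadcast up to the number of query heads before SDPA, so the call itself always sees matching head counts regardless of n_query_groups.

import torch
import torch.nn.functional as F

B, T, head_size = 2, 8, 16
n_head, n_query_groups = 8, 2  # illustrative GQA sizes

q = torch.randn(B, n_head, T, head_size)
k = torch.randn(B, n_query_groups, T, head_size)
v = torch.randn(B, n_query_groups, T, head_size)

# Expand the grouped k/v heads so every query head has a matching key/value head,
# mirroring what litgpt does during training for flash-attention support.
q_per_kv = n_head // n_query_groups
k = k.repeat_interleave(q_per_kv, dim=1)  # (B, n_head, T, head_size)
v = v.repeat_interleave(q_per_kv, dim=1)

out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(out.shape)  # torch.Size([2, 8, 8, 16])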
