🐛 Bug

LitGPTSDPABenchmark, located here, does not correctly run the litgpt config provided to it. This is because it uses NanoGPTConfig underneath, located here.
This creates multiple issues when benchmarking, as an incorrect sdpa operation is launched. For example:

- NanoGPTConfig has a default dropout of 0.1, so all sdpa operations have dropout ON.
- When running GQA sizes with config.n_query_groups, the value is not propagated into the model (see the sketch below).
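A quick way to see which values the benchmark should be honoring is to read them from litgpt's own config. This is a minimal sketch, assuming litgpt's Config lives at litgpt.config and exposes these fields (names may differ across litgpt versions):

# Sketch (assumption): inspect the parameters litgpt itself specifies for this model,
# to compare against what LitGPTSDPABenchmark actually runs via NanoGPTConfig.
from litgpt.config import Config

cfg = Config.from_name("Llama-2-7b-hf")
print(cfg.n_head, cfg.n_query_groups, cfg.head_size)
# litgpt applies no dropout inside its attention call, while NanoGPTConfig defaults to dropout=0.1.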
To Reproduce

from thunder.benchmarks import LitGPTSDPABenchmark, Benchmark
import thunder
import torch

bench: Benchmark = LitGPTSDPABenchmark(
    config="Llama-2-7b-hf",
    device="cuda:0",
    dtype=thunder.bfloat16,
    requires_grad=True
)

args, kwargs = bench.make_batch()
torch.cuda.synchronize()

fn = bench.fn()
jfn = thunder.jit(fn)
jfn(*args, **kwargs)

print(thunder.last_traces(jfn)[-1])
Outputs:
# Constructed by Delete Last Used (took 0 milliseconds)
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def augmented_forward_fn(q, k, v):
  # q: "cuda:0 bf16[16, 32, 4096, 128]"
  # k: "cuda:0 bf16[16, 32, 4096, 128]"
  # v: "cuda:0 bf16[16, 32, 4096, 128]"
  (t0, t1, t2, t3, _, _, t4, t5, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(q, k, v, 0.1, True, scale=None)
  return {'output': t0, 'flat_args': [q, k, v], 'flat_output': (t0,)}, ((k, q, t0, t1, t2, t3, t4, t5, v), (True, 0.1))
Notice the is_causal=True and dropout_p=0.1. The litgpt model has no code path to enable both of them (code).
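For contrast, here is a minimal sketch of the call one would expect for this config, assuming litgpt-style causal attention with no attention dropout (the real logic, including mask handling, lives in litgpt/model.py):

import math
import torch
import torch.nn.functional as F

# Sketch (assumption): litgpt-style attention is causal with no dropout,
# unlike the traced call above, which inherits dropout_p=0.1 from NanoGPTConfig.
q = torch.randn(16, 32, 4096, 128, device="cuda:0", dtype=torch.bfloat16)
k = torch.randn(16, 32, 4096, 128, device="cuda:0", dtype=torch.bfloat16)
v = torch.randn(16, 32, 4096, 128, device="cuda:0", dtype=torch.bfloat16)

y = F.scaled_dot_product_attention(
    q, k, v,
    dropout_p=0.0,                      # never 0.1 for a litgpt config
    scale=1.0 / math.sqrt(q.size(-1)),  # head_size = 128 here
    is_causal=True,
)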
Expected behavior

LitGPTSDPABenchmark runs the exact parameters specified by litgpt configs.
cc @crcrpar @carmocca
triage review — @riccardofelluga, is this something you'd like to look at?
Sure! If it's not urgent I can give it a look as soon as I have a minute
I can help here too
carmocca:

> Running GQA sizes, with config.n_query_groups, it is not propagated into the model.

Note that litgpt expands this dimension during training for flash attention support (https://github.com/Lightning-AI/litgpt/blob/main/litgpt/model.py#L226-L231), so the n_query_groups argument doesn't impact the SDPA call.
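To make that concrete, here is a rough sketch of the expansion (not the exact litgpt code; see the linked lines for the real implementation): grouped K/V heads are expanded to match the query heads and flattened back to n_head before SDPA, so the SDPA call itself never sees n_query_groups during training.

import torch

# Sketch (assumption): GQA-style expansion of key/value heads before SDPA,
# roughly what the linked litgpt lines do for flash-attention support.
B, T, head_size = 2, 128, 64
n_head, n_query_groups = 32, 8          # hypothetical GQA sizes
q_per_kv = n_head // n_query_groups

q = torch.randn(B, n_query_groups, q_per_kv, T, head_size)
k = torch.randn(B, n_query_groups, 1, T, head_size)
v = torch.randn(B, n_query_groups, 1, T, head_size)

# Expand k/v so every query head has its own key/value head...
k = k.expand(B, n_query_groups, q_per_kv, T, head_size)
v = v.expand(B, n_query_groups, q_per_kv, T, head_size)

# ...and flatten the groups back into n_head heads.
q = q.reshape(B, n_head, T, head_size)
k = k.reshape(B, n_head, T, head_size)
v = v.reshape(B, n_head, T, head_size)

y = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)
print(y.shape)  # torch.Size([2, 32, 128, 64])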