
[Help needed] Impact of padding on causal attention? #901

Open
Optimox opened this issue Apr 30, 2024 · 5 comments

Comments

@Optimox
Contributor

Optimox commented Apr 30, 2024

Hi,

I noticed that the batch size used during inference had an impact on my predictions, and I couldn't understand why. I dug a little deeper and ended up with the following minimal example, which makes me nervous about my understanding of causal attention:

from torchtune import utils
from torchtune.models.llama3._component_builders import llama3
import torch

_device = utils.get_device(device="cuda")

with utils.set_default_dtype(torch.float32), _device:
    llm = llama3(
        vocab_size=128_256,
        num_layers=1,
        num_heads=32,
        num_kv_heads=8,
        embed_dim=4096,
        max_seq_len=8192,
        intermediate_dim=14336,
        attn_dropout=0.0,
        norm_eps=1e-5,
        rope_base=500000.0,
    )
    llm = llm.eval()

with torch.no_grad():
    basic_inputs = torch.ones((1, 256)).long().to("cuda")
    longer_outputs = torch.ones((1, 257)).long().to("cuda")
    single_predict = llm(basic_inputs)
    reproducible_inference = llm(basic_inputs)
    padded_predict = llm(longer_outputs)

    print((single_predict == reproducible_inference).float().mean())
    print((single_predict == padded_predict[:, :256, :]).float().mean())

The first print outputs 1.0, as expected (inference is deterministic).
However, the second print outputs a very low value (as low as 0.0244).

Does that mean the last token has an impact on the representations of all the previous tokens? I thought this was not supposed to be the case with causal LLMs, right?

I would really appreciate someone taking the time to let me know what I missed or what I did wrong. Thanks!
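
For reference, here is the property I expected to hold, checked on plain scaled dot-product attention (nothing torchtune-specific; just a standalone float64 sanity check with made-up shapes):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
# (batch, heads, seq, head_dim); float64 keeps the rounding noise tiny
q = torch.randn(1, 4, 257, 64, dtype=torch.float64)
k = torch.randn(1, 4, 257, 64, dtype=torch.float64)
v = torch.randn(1, 4, 257, 64, dtype=torch.float64)

full = F.scaled_dot_product_attention(q, k, v, is_causal=True)
short = F.scaled_dot_product_attention(q[:, :, :256], k[:, :, :256], v[:, :, :256], is_causal=True)

# With a causal mask, position t only attends to positions <= t, so the
# first 256 outputs should not depend on token 257 at all.
print(torch.allclose(full[:, :, :256], short, atol=1e-9))  # expected: True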

@kartikayk
Contributor

Hi @Optimox, thanks for filing the issue!

So I think comparing the values exactly is not a great check here. These are large floating-point matmuls, so different input sizes can result in slightly different outputs (this also depends on the hardware). I usually use torch.allclose with relatively small tolerances for these checks. In this case, setting an absolute tolerance of 1e-5 (quite a small value) still returns True, which is what I'd expect. Let me know if this helps.

[screenshot: the torch.allclose check returning True with atol=1e-5]
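
Roughly the check from the screenshot, reusing single_predict and padded_predict from your snippet:

print(torch.allclose(single_predict, padded_predict[:, :256, :], atol=1e-5))  # True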

@Optimox
Contributor Author

Optimox commented Apr 30, 2024

Hi @kartikayk , thank you for taking some time to help me out.

While I agree that hoping for an exact match is not necessarily reasonable, the errors seem to propagate quite a lot: with a model trained for a specific task, I see significant differences in metrics based only on whether there is a padding tail.

Also, things seem even worse with bfloat16 (and my real use case uses bfloat16): when I run my previous example in bfloat16, print(torch.isclose(single_predict, padded_predict[:, :256, :], atol=1e-3).float().mean()) gives only 0.626. So more than a third of the values differ by more than an absolute tolerance of 1e-3, which could have quite a large impact if you have an extra layer on top of the model.
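
For completeness, the bfloat16 variant is essentially the first snippet with the dtype swapped (reusing basic_inputs and longer_outputs from above):

# same builder call as before, just with bfloat16 instead of float32
with utils.set_default_dtype(torch.bfloat16), _device:
    llm_bf16 = llama3(
        vocab_size=128_256,
        num_layers=1,
        num_heads=32,
        num_kv_heads=8,
        embed_dim=4096,
        max_seq_len=8192,
        intermediate_dim=14336,
        attn_dropout=0.0,
        norm_eps=1e-5,
        rope_base=500000.0,
    ).eval()

with torch.no_grad():
    single_predict = llm_bf16(basic_inputs)
    padded_predict = llm_bf16(longer_outputs)
    # fraction of logits that agree within an absolute tolerance of 1e-3
    print(torch.isclose(single_predict, padded_predict[:, :256, :], atol=1e-3).float().mean())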

Actually, I just wanted to be sure that I did not do anything wrong. In theory we expect the results to be exactly the same, but in practice we see discrepancies due to numerical errors, right? Those errors are larger with bfloat16 since its numerical precision is lower? And I guess the impact would be even bigger if the model had been quantized to int8?

Do you know of a "good practice" for inference, then? Is it recommended to pad to the padding length used during training? Or, on the contrary, would the results be more precise when predicting samples one by one without padding? (In the few experiments I have run, the second option seems to yield significantly better results.)

@kartikayk
Contributor

This is a great discussion!

I'm digging into this now and will share more observations in a bit.

A few questions:

  • Do you see differences in the quality of the generated output as well?
  • For inference, are you running batched inference and debating whether to make calls at the instance level instead? Can you share more details on this?

@Optimox
Contributor Author

Optimox commented May 2, 2024

Of course I can share a bit more!

My current project is a simple toy project from Kaggle for automated essay scoring: https://www.kaggle.com/competitions/learning-agency-lab-automated-essay-scoring-2/overview.

My approach is quite simple: take an LLM architecture, append a special token (a CLS token) at the end of each essay, and make predictions with a dedicated head on top of this token's representation.

Since essays have different lengths, I must use padding (after the CLS token) for batched training or inference. What I noticed is that the inference batch size can move the quadratic weighted kappa from 0.82 to 0.80, for example.
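
A hypothetical sketch of that pooling step (made-up names, not my exact code):

import torch

def pool_cls(hidden_states: torch.Tensor, input_ids: torch.Tensor, pad_id: int) -> torch.Tensor:
    # hidden_states: (batch, seq_len, dim), input_ids: (batch, seq_len), right padding assumed
    lengths = (input_ids != pad_id).sum(dim=1)  # number of real tokens per example
    cls_positions = lengths - 1                 # the CLS token is the last real token
    batch_idx = torch.arange(input_ids.size(0), device=input_ids.device)
    return hidden_states[batch_idx, cls_positions]  # (batch, dim), fed to the scoring head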

So my minimal code example above is a simpler case of the impact of padding length, since I only look at a batch size of 1 (with different padding lengths). But I guess numerical errors could be even larger when the matrices contain several examples.

Since it is not desirable to get different predictions depending on the order of the samples at inference (the padding depends on the other examples in the batch), I opted for a batch size of 1 and no padding at all. But this is not ideal in a time-constrained setting.

@Optimox
Contributor Author

Optimox commented May 3, 2024

To my surprise, it seems that the transformers Gemma tokenizer has padding_side set to "left", which has a far larger impact on the final predictions, since the pad tokens interfere with the CLS token.
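
A minimal check/workaround with the transformers API (the model id below is just an example checkpoint):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("google/gemma-7b")  # example Gemma checkpoint
print(tok.padding_side)      # "left" for this tokenizer, as noted above
tok.padding_side = "right"   # pad after the essay (and CLS token) instead of before it
batch = tok(["short essay", "a somewhat longer essay"], padding=True, return_tensors="pt")
print(batch["input_ids"][0])  # pad ids now come after the real tokens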
