GPT Neox rotary embedding does not work with padding left #22161

Closed
2 of 4 tasks
OlivierDehaene opened this issue Mar 14, 2023 · 8 comments

OlivierDehaene (Member) commented Mar 14, 2023

System Info

  • transformers version: 4.26.1
  • Platform: Linux-5.4.0-1097-aws-x86_64-with-glibc2.27
  • Python version: 3.10.9
  • Huggingface_hub version: 0.12.1
  • PyTorch version (GPU?): 1.13.1+cu117 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using GPU in script?: yes
  • Using distributed or parallel set-up in script?: no

Who can help?

@ArthurZucker, @younesbelkada, @gante

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

import torch

from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("OpenAssistant/oasst-sft-1-pythia-12b", padding_side="left")

model = AutoModelForCausalLM.from_pretrained("OpenAssistant/oasst-sft-1-pythia-12b", device_map="auto")

f_not_padded = model.forward(**tokenizer(["<|prompter|>What is a meme, and what's the history behind this word?<|endoftext|><|assistant|>"], padding=False, return_tensors="pt"))
f_padded = model.forward(**tokenizer(["<|prompter|>What is a meme, and what's the history behind this word?<|endoftext|><|assistant|>"], padding=True, pad_to_multiple_of=256, return_tensors="pt"))

torch.testing.assert_allclose(f_not_padded.logits[:, -1], f_padded.logits[:, -1])

# AssertionError: Tensor-likes are not close!

# Mismatched elements: 6057 / 50288 (12.0%)
# Greatest absolute difference: 0.0003177821636199951 at index (0, 4649) (up to 1e-05 allowed)
# Greatest relative difference: 1.5682868874196898 at index (0, 30410) (up to 0.0001 allowed)

The problem is exacerbated in bfloat16:

import torch

from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("OpenAssistant/oasst-sft-1-pythia-12b", padding_side="left")

model = AutoModelForCausalLM.from_pretrained("OpenAssistant/oasst-sft-1-pythia-12b", device_map="auto", torch_dtype=torch.bfloat16)

f_not_padded = model.forward(**tokenizer(["<|prompter|>What is a meme, and what's the history behind this word?<|endoftext|><|assistant|>"], padding=False, return_tensors="pt"))
f_padded = model.forward(**tokenizer(["<|prompter|>What is a meme, and what's the history behind this word?<|endoftext|><|assistant|>"], padding=True, pad_to_multiple_of=256, return_tensors="pt"))

torch.testing.assert_allclose(f_not_padded.logits[:, -1], f_padded.logits[:, -1])

# AssertionError: Tensor-likes are not equal!

# Mismatched elements: 49417 / 50288 (98.3%)
# Greatest absolute difference: 1.154541015625 at index (0, 50271)
# Greatest relative difference: 2058.906976744186 at index (0, 29917)

Expected behavior

Left padding should have no influence on the resulting logits.

While the differences may not look like much, they have a huge impact on generation.

ArthurZucker (Collaborator) commented:

Hey thanks for reporting!

OlivierDehaene (Member, Author) commented:

It is possible that this is not the root cause, but there is an issue with these lines:

offset = 0
if has_layer_past:
    offset = layer_past[0].shape[-2]
    seq_len += offset
cos, sin = self.rotary_emb(value, seq_len=seq_len)
query, key = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, offset=offset)

offset and seq_len are not computed correctly when there is padding.
On a side note, it is impossible to use a single value for offset, as different sequences in the batch might have different lengths and therefore different offsets when padding on the left.
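
To make the left-padding point concrete, here is a minimal, self-contained sketch (the names are mine, not the transformers code) that derives per-sequence position ids from the attention mask, so every row effectively gets its own offset:

import torch

# Batch of two left-padded sequences: 0 = padding, 1 = real token.
attention_mask = torch.tensor([
    [0, 0, 0, 1, 1, 1],  # 3 real tokens, 3 padding tokens on the left
    [1, 1, 1, 1, 1, 1],  # 6 real tokens, no padding
])

# Count positions from the first real token of each row; padding positions
# get a dummy value since they are masked out of the attention anyway.
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)

print(position_ids)
# tensor([[1, 1, 1, 0, 1, 2],
#         [0, 1, 2, 3, 4, 5]])

A single scalar offset cannot express this, because the real tokens of the two rows start at different absolute indices.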

OlivierDehaene (Member, Author) commented Mar 15, 2023

We use left padding extensively on the serving side, as we have dynamic batching logic that batches sequences of very different lengths together.

While the pad==256 example above seems extreme in isolation, it is completely normal when serving. We sometimes go even higher in chat applications, where one member of the batch has a very long history (> 1000 tokens) while other sequences have only just started (~40 tokens).

We also serve all models in bfloat16 when available, and we almost always use sampling, which amplifies the logits issue even more.

gante (Member) commented Mar 15, 2023

Hey everyone! Yes, this is correct: it is pretty much the same issue as I reported here. We should be passing position_ids all the way down to the attention layer and computing the sequence length from it.

We have an open PR that fixes the same issue for GPT-J (#22069); I'll make sure it is ported to GPT NeoX once it is merged. We are currently ironing out torch.fx issues (adding the correct behavior makes the tensors dynamic, which blocks existing features).
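
To illustrate the principle with a standalone sketch (my own names and shapes, not the actual #22069 diff): once the rotary angles are gathered per token through position_ids, instead of being sliced with a single offset, left padding no longer changes the result for the real tokens.

import torch

def rotate_half(x):
    # Rotate half the hidden dims of the input.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary(x, cos_table, sin_table, position_ids):
    # cos_table/sin_table: [max_pos, head_dim]; position_ids: [bs, seq_len]
    cos = cos_table[position_ids].unsqueeze(1)  # [bs, 1, seq_len, head_dim]
    sin = sin_table[position_ids].unsqueeze(1)
    return (x * cos) + (rotate_half(x) * sin)

# Precompute small cos/sin tables.
dim, max_pos = 8, 16
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
freqs = torch.einsum("i,j->ij", torch.arange(max_pos).float(), inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
cos_table, sin_table = emb.cos(), emb.sin()

q = torch.randn(1, 1, 3, dim)                                # 3 real tokens
q_padded = torch.cat([torch.zeros(1, 1, 2, dim), q], dim=2)  # 2 left-padding tokens

out = apply_rotary(q, cos_table, sin_table, torch.tensor([[0, 1, 2]]))
out_padded = apply_rotary(q_padded, cos_table, sin_table, torch.tensor([[1, 1, 0, 1, 2]]))
torch.testing.assert_close(out, out_padded[:, :, 2:])  # the real tokens match exactly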

njhill (Contributor) commented Mar 15, 2023

Hi @OlivierDehaene, I'm actually in the middle of porting the fix from #22069 to GPT-NeoX too, since I was also interested in that one (in parallel with other things, including resolving this torch.fx issue).

Also, for reference, there's a similar existing issue that went stale: #18999

OlivierDehaene (Member, Author) commented:

Hi @njhill!
Nice, thanks for working on this!
For now I have a fix on my text-generation-inference fork, as we have multiple NeoX models in prod and I need a fix ASAP. It's essentially the same as yours, I think:

import torch

class RotaryEmbedding(torch.nn.Module):
    def __init__(self, dim, max_position_embeddings, base=10000, device=None):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
        self.register_buffer("inv_freq", inv_freq)

        # Build here to make `torch.jit.trace` work.
        self.max_seq_len_cached = max_position_embeddings
        self.cos_cached = None
        self.sin_cached = None

    @staticmethod
    def rotate_half(x):
        """Rotates half the hidden dims of the input."""
        x1 = x[..., : x.shape[-1] // 2]
        x2 = x[..., x.shape[-1] // 2 :]
        return torch.cat((-x2, x1), dim=-1)

    @staticmethod
    def _create_cos_sin(inv_freq, max_position_embeddings, dtype, device):
        t = torch.arange(max_position_embeddings, device=inv_freq.device, dtype=inv_freq.dtype)
        freqs = torch.einsum("i,j->ij", t, inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        return emb.cos().to(device).to(dtype), emb.sin().to(device).to(dtype)

    def forward(self, q, k, position_ids, seq_len=None):
        # q, k: [bs, num_attention_heads, seq_len, head_size]
        if seq_len is None:
            # default to the current sequence length when the caller does not pass one
            seq_len = q.shape[-2]
        if seq_len > self.max_seq_len_cached or self.cos_cached is None or self.sin_cached is None:
            if seq_len > self.max_seq_len_cached:
                self.max_seq_len_cached = seq_len
            self.cos_cached, self.sin_cached = self._create_cos_sin(
                self.inv_freq, self.max_seq_len_cached, q.dtype, q.device
            )
        cos = self.cos_cached[position_ids].unsqueeze(1)
        sin = self.sin_cached[position_ids].unsqueeze(1)

        q_embed = (q * cos) + (self.rotate_half(q) * sin)
        k_embed = (k * cos) + (self.rotate_half(k) * sin)
        return q_embed, k_embed
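
For completeness, a hedged usage sketch of the class above (the wiring, shapes and variable names are mine, not the fork's actual attention code): the caller derives position_ids from the attention mask and runs the rotary slice of the query/key states through the module.

import torch

rotary = RotaryEmbedding(dim=8, max_position_embeddings=32)

# [bs, num_attention_heads, seq_len, head_size]
q = torch.randn(2, 4, 6, 8)
k = torch.randn(2, 4, 6, 8)

# Left-padded batch: the first row has 2 padding tokens.
attention_mask = torch.tensor([[0, 0, 1, 1, 1, 1],
                               [1, 1, 1, 1, 1, 1]])
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)

q_embed, k_embed = rotary(q, k, position_ids, seq_len=q.shape[-2])

During incremental decoding, only the position ids of the newly generated tokens would be passed, so each sequence keeps its own offset regardless of how much left padding it carries.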


github-actions bot commented May 9, 2023

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.
