GPT Neox rotary embedding does not work with padding left #22161
Comments
Hey, thanks for reporting!
It is possible that this is not the root cause, but there is an issue with these lines:

```python
offset = 0
if has_layer_past:
    offset = layer_past[0].shape[-2]
    seq_len += offset
cos, sin = self.rotary_emb(value, seq_len=seq_len)
query, key = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, offset=offset)
```
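With left padding, the first real token of a padded row does not sit at position 0, so a single scalar `offset` shared by the whole batch cannot describe every sequence. A minimal sketch of the mask-based alternative, assuming a standard 0/1 `attention_mask` (the tensor values are illustrative, not from the thread):

```python
import torch

# Hypothetical batch: the second sequence is left-padded by two tokens.
attention_mask = torch.tensor([[1, 1, 1, 1],
                               [0, 0, 1, 1]])

# Derive a per-token position from the mask instead of a scalar offset.
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)  # pad slots are never attended; the value is arbitrary
print(position_ids)
# tensor([[0, 1, 2, 3],
#         [1, 1, 0, 1]])
```

The rotary tables can then be indexed with these per-token positions, which is what the class further down in this thread does.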
We use left padding extensively on the serving side, as we have dynamic batching logic that batches sequences of very different lengths together. While the pad==256 example above seems extreme in isolation, it is completely normal when serving. We sometimes go even higher in chat applications, where one member of the batch has a very large history (> 1000 tokens) while other sequences have only just started (~ 40 tokens). We also serve all models in bfloat16 when available, and we almost always use sampling, which amplifies the logits issue even more.
Hey everyone! Yes, that's correct, it is pretty much the same issue as I reported here -- we should be passing `position_ids` when computing the rotary embeddings. We have an open PR to fix the same issue with GPT-J (#22069), and I'll make sure it is ported to GPT NeoX when it is merged. We are currently ironing out the details.
Hi @OlivierDehaene, I'm actually in the middle of porting the fix from #22069 to GPT-NeoX too, since I was also interested in that one (in parallel with other things, including resolving this torch.fx issue). Also, for reference, there's a similar existing issue which went stale: #18999
Hi @njhill! Here is a version that takes `position_ids` directly:

```python
import torch


class RotaryEmbedding(torch.nn.Module):
    def __init__(self, dim, max_position_embeddings, base=10000, device=None):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
        self.register_buffer("inv_freq", inv_freq)

        # Build here to make `torch.jit.trace` work.
        self.max_seq_len_cached = max_position_embeddings
        self.cos_cached = None
        self.sin_cached = None

    @staticmethod
    def rotate_half(x):
        """Rotates half the hidden dims of the input."""
        x1 = x[..., : x.shape[-1] // 2]
        x2 = x[..., x.shape[-1] // 2 :]
        return torch.cat((-x2, x1), dim=-1)

    @staticmethod
    def _create_cos_sin(inv_freq, max_position_embeddings, dtype, device):
        t = torch.arange(max_position_embeddings, device=inv_freq.device, dtype=inv_freq.dtype)
        freqs = torch.einsum("i,j->ij", t, inv_freq)
        # Different from the paper, but it uses a different permutation
        # in order to obtain the same calculation.
        emb = torch.cat((freqs, freqs), dim=-1)
        return emb.cos().to(device).to(dtype), emb.sin().to(device).to(dtype)

    def forward(self, q, k, position_ids, seq_len=None):
        # q, k: [bs, num_attention_heads, seq_len, head_size]
        if seq_len > self.max_seq_len_cached or self.cos_cached is None or self.sin_cached is None:
            if seq_len > self.max_seq_len_cached:
                self.max_seq_len_cached = seq_len
            self.cos_cached, self.sin_cached = self._create_cos_sin(
                self.inv_freq, self.max_seq_len_cached, q.dtype, q.device
            )
        # Index the cached tables with per-token positions so that left
        # padding is accounted for, then broadcast over the head dimension.
        cos = self.cos_cached[position_ids].unsqueeze(1)
        sin = self.sin_cached[position_ids].unsqueeze(1)
        q_embed = (q * cos) + (self.rotate_half(q) * sin)
        k_embed = (k * cos) + (self.rotate_half(k) * sin)
        return q_embed, k_embed
```
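A quick usage sketch for the class above; the shapes, the random tensors, and the mask-derived `position_ids` are illustrative assumptions, not from the thread:

```python
# Hypothetical shapes: batch 2, 4 heads, 4 tokens, head size 8 (must be even).
bs, n_heads, seq_len, head_size = 2, 4, 4, 8
q = torch.randn(bs, n_heads, seq_len, head_size)
k = torch.randn(bs, n_heads, seq_len, head_size)

# Second sequence left-padded by two tokens; positions come from the mask.
attention_mask = torch.tensor([[1, 1, 1, 1], [0, 0, 1, 1]])
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)

rotary = RotaryEmbedding(dim=head_size, max_position_embeddings=2048)
q_rot, k_rot = rotary(q, k, position_ids, seq_len=seq_len)
print(q_rot.shape)  # torch.Size([2, 4, 4, 8])
```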
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
System Info

transformers version: 4.26.1

Who can help?

@ArthurZucker, @younesbelkada, @gante

Information

Tasks

examples folder (such as GLUE/SQuAD, ...)

Reproduction
The problem is exacerbated in bfloat16
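The reproduction script itself did not survive extraction; below is a minimal sketch of the kind of comparison the report describes, assuming a small GPT-NeoX checkpoint such as `EleutherAI/pythia-70m`:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "EleutherAI/pythia-70m"  # assumed; any GPT-NeoX checkpoint shows the same pattern
tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(model_name).eval()

text = "Hello my name is"
plain = tokenizer(text, return_tensors="pt")
padded = tokenizer(text, return_tensors="pt", padding="max_length", max_length=64)

with torch.no_grad():
    logits_plain = model(**plain).logits[0, -1]
    logits_padded = model(**padded).logits[0, -1]

# Left padding should leave the last real token's logits unchanged,
# but with the offset-based rotary embedding they diverge.
print((logits_plain - logits_padded).abs().max())
```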
Expected behavior
Left padding should have no influence on the resulting logits. While the differences do not look like much, they have a huge impact on generation.