Fix position embeddings for GPT-J and CodeGen (#22069)
* Revert "[GPT-J] add deprecation warning (#21869)"

This reverts commit fb76994.

* Fix position embeddings for GPT-J and CodeGen

* Address review comments from @gante

* Fix "Copied from" comment referencing wrong function

* Fix copy/paste mistake

* Fix training path

* Hopefully make torch.fx happy

* Move position_ids long cast

* Revert "Hopefully make torch.fx happy"

This reverts commit e41a6f4.

* Changes to help with torch.fx tracing

* Linter fix

* Correct position_ids tensor type hint

* Work-around torch.fx tracing issue

* Get the changes to work with torch.fx

* Address review comment from @michaelbenayoun

* Another small adjustment

* Add explanatory comment; small code tidyup
njhill committed Mar 22, 2023
Commit 4e94c6c · 1 parent: 8e6c34b
Showing 3 changed files with 151 additions and 136 deletions.
src/transformers/models/codegen/modeling_codegen.py (78 changes: 32 additions, 46 deletions)
@@ -51,43 +51,26 @@
 ]
 
 
-# Copied from transformers.models.gptj.modeling_gptj.fixed_pos_embedding
-def fixed_pos_embedding(x, seq_dim=1, seq_len=None):
-    dim = x.shape[-1]
-    if seq_len is None:
-        seq_len = x.shape[seq_dim]
+# Copied from transformers.models.gptj.modeling_gptj.create_sinusoidal_positions
+def create_sinusoidal_positions(num_pos: int, dim: int) -> torch.Tensor:
     inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2) / dim))
-    sinusoid_inp = (
-        torch.einsum("i , j -> i j", torch.arange(seq_len, dtype=torch.float), inv_freq).to(x.device).float()
-    )
-    return torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)
+    sinusoid_inp = torch.einsum("i , j -> i j", torch.arange(num_pos, dtype=torch.float), inv_freq).float()
+    return torch.concat((torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)), dim=1)
 
 
 # Copied from transformers.models.gptj.modeling_gptj.rotate_every_two
-def rotate_every_two(x):
+def rotate_every_two(x: torch.Tensor) -> torch.Tensor:
     x1 = x[:, :, :, ::2]
     x2 = x[:, :, :, 1::2]
     x = torch.stack((-x2, x1), dim=-1)
     return x.flatten(-2)  # in einsum notation: rearrange(x, '... d j -> ... (d j)')
 
 
-# Copied from transformers.models.gptj.modeling_gptj.duplicate_interleave
-def duplicate_interleave(m):
-    """
-    A simple version of `torch.repeat_interleave` for duplicating a matrix while interleaving the copy.
-    """
-    dim0 = m.shape[0]
-    m = m.view(-1, 1)  # flatten the matrix
-    m = m.repeat(1, 2)  # repeat all elements into the 2nd dimension
-    m = m.view(dim0, -1)  # reshape into a matrix, interleaving the copy
-    return m
-
-
 # Copied from transformers.models.gptj.modeling_gptj.apply_rotary_pos_emb
-def apply_rotary_pos_emb(x, sincos, offset=0):
-    sin, cos = (duplicate_interleave(t)[None, offset : x.shape[1] + offset, None, :] for t in sincos)
-    # einsum notation for lambda t: repeat(t[offset:x.shape[1]+offset,:], "n d -> () n () (d j)", j=2)
-    return (x * cos) + (rotate_every_two(x) * sin)
+def apply_rotary_pos_emb(tensor: torch.Tensor, sin: torch.Tensor, cos: torch.Tensor) -> torch.Tensor:
+    sin = torch.repeat_interleave(sin[:, :, None, :], 2, 3)
+    cos = torch.repeat_interleave(cos[:, :, None, :], 2, 3)
+    return (tensor * cos) + (rotate_every_two(tensor) * sin)
 
 
 class CodeGenAttention(nn.Module):
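For readers skimming the diff, here is a minimal, self-contained sketch of how the new helpers are intended to compose: the sinusoidal table is built once, its rows are gathered with position_ids, split into sin/cos halves, and applied to the rotary slice of query/key. The helper bodies mirror the diff above; the toy driver values (batch, seq, heads, rotary_dim) are illustrative assumptions, not values from the model config.

import torch

def create_sinusoidal_positions(num_pos, dim):
    # Rows are positions, columns are the [sin | cos] halves, as in the diff above.
    inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2) / dim))
    sinusoid_inp = torch.einsum("i , j -> i j", torch.arange(num_pos, dtype=torch.float), inv_freq).float()
    return torch.cat((torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)), dim=1)

def rotate_every_two(x):
    x1 = x[:, :, :, ::2]
    x2 = x[:, :, :, 1::2]
    return torch.stack((-x2, x1), dim=-1).flatten(-2)

def apply_rotary_pos_emb(tensor, sin, cos):
    # sin/cos arrive as [batch, seq, rotary_dim // 2]; broadcast over heads and interleave.
    sin = torch.repeat_interleave(sin[:, :, None, :], 2, 3)
    cos = torch.repeat_interleave(cos[:, :, None, :], 2, 3)
    return (tensor * cos) + (rotate_every_two(tensor) * sin)

# Toy sizes (assumptions).
batch, seq, heads, rotary_dim, max_positions = 2, 5, 4, 8, 16
embed_positions = create_sinusoidal_positions(max_positions, rotary_dim)   # [16, 8]
position_ids = torch.arange(seq).unsqueeze(0).expand(batch, -1)            # [2, 5]
sincos = embed_positions[position_ids]                                     # [2, 5, 8]
sin, cos = torch.split(sincos, sincos.shape[-1] // 2, dim=-1)              # each [2, 5, 4]
query = torch.randn(batch, seq, heads, rotary_dim)                         # [2, 5, 4, 8]
print(apply_rotary_pos_emb(query, sin, cos).shape)                         # torch.Size([2, 5, 4, 8])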
@@ -117,9 +100,9 @@ def __init__(self, config):
         self.qkv_proj = nn.Linear(self.embed_dim, self.embed_dim * 3, bias=False)
 
         self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
-        self.rotary_dim = None
-        if config.rotary_dim is not None:
-            self.rotary_dim = config.rotary_dim
+        self.rotary_dim = config.rotary_dim
+        pos_embd_dim = self.rotary_dim or self.embed_dim
+        self.embed_positions = create_sinusoidal_positions(max_positions, pos_embd_dim)
 
     def _split_heads(self, x, n_head, dim_head, mp_num):
         reshaped = x.reshape(x.shape[:-1] + (n_head // mp_num, dim_head))
@@ -183,8 +166,9 @@ def _attn(
     def forward(
         self,
         hidden_states: Optional[torch.FloatTensor],
-        attention_mask: Optional[torch.FloatTensor] = None,
         layer_past: Optional[Tuple[torch.Tensor]] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
         head_mask: Optional[torch.FloatTensor] = None,
         use_cache: Optional[bool] = False,
         output_attentions: Optional[bool] = False,
@@ -205,12 +189,13 @@ def forward(
         value = self._split_heads(value, self.num_attention_heads, self.head_dim, mp_num=mp_num)
         value = value.permute(0, 2, 1, 3)
 
-        seq_len = key.shape[1]
-        offset = 0
+        embed_positions = self.embed_positions
+        if embed_positions.device != position_ids.device:
+            embed_positions = embed_positions.to(position_ids.device)
+            self.embed_positions = embed_positions
 
-        if layer_past is not None:
-            offset = layer_past[0].shape[-2]
-            seq_len += offset
+        sincos = embed_positions[position_ids]
+        sin, cos = torch.split(sincos, sincos.shape[-1] // 2, dim=-1)
 
         if self.rotary_dim is not None:
             k_rot = key[:, :, :, : self.rotary_dim]
@@ -219,16 +204,14 @@ def forward(
             q_rot = query[:, :, :, : self.rotary_dim]
             q_pass = query[:, :, :, self.rotary_dim :]
 
-            sincos = fixed_pos_embedding(k_rot, 1, seq_len=seq_len)
-            k_rot = apply_rotary_pos_emb(k_rot, sincos, offset=offset)
-            q_rot = apply_rotary_pos_emb(q_rot, sincos, offset=offset)
+            k_rot = apply_rotary_pos_emb(k_rot, sin, cos)
+            q_rot = apply_rotary_pos_emb(q_rot, sin, cos)
 
             key = torch.cat([k_rot, k_pass], dim=-1)
             query = torch.cat([q_rot, q_pass], dim=-1)
         else:
-            sincos = fixed_pos_embedding(key, 1, seq_len=seq_len)
-            key = apply_rotary_pos_emb(key, sincos, offset=offset)
-            query = apply_rotary_pos_emb(query, sincos, offset=offset)
+            key = apply_rotary_pos_emb(key, sin, cos)
+            query = apply_rotary_pos_emb(query, sin, cos)
 
         key = key.permute(0, 2, 1, 3)
         query = query.permute(0, 2, 1, 3)
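With the gather-based lookup above, the rotary angles for a token depend only on the position_ids the caller supplies, not on the cache length that the removed seq_len/offset bookkeeping inferred. A hedged sketch of the lookup during cached decoding (toy tensors, assumed names):

import torch

max_positions, rotary_dim = 16, 8
# Stand-in for self.embed_positions, the table built by create_sinusoidal_positions.
embed_positions = torch.randn(max_positions, rotary_dim)

# Prompt pass: a 5-token prompt occupies absolute positions 0..4.
position_ids = torch.arange(5).unsqueeze(0)                     # shape [1, 5]
sincos = embed_positions[position_ids]                          # shape [1, 5, 8]

# Cached decoding step: only the new token is fed, carrying its absolute position.
position_ids = torch.tensor([[5]])                              # shape [1, 1]
sincos = embed_positions[position_ids]                          # shape [1, 1, 8]
sin, cos = torch.split(sincos, sincos.shape[-1] // 2, dim=-1)   # each [1, 1, 4]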
@@ -292,16 +275,18 @@ def forward(
         hidden_states: Optional[torch.FloatTensor],
         layer_past: Optional[Tuple[torch.Tensor]] = None,
         attention_mask: Optional[torch.FloatTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
         head_mask: Optional[torch.FloatTensor] = None,
         use_cache: Optional[bool] = False,
         output_attentions: Optional[bool] = False,
     ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
         residual = hidden_states
         hidden_states = self.ln_1(hidden_states)
         attn_outputs = self.attn(
-            hidden_states,
+            hidden_states=hidden_states,
             layer_past=layer_past,
             attention_mask=attention_mask,
+            position_ids=position_ids,
             head_mask=head_mask,
             use_cache=use_cache,
             output_attentions=output_attentions,
@@ -488,7 +473,7 @@ def forward(
             token_type_ids = token_type_ids.view(-1, input_shape[-1])
 
         if position_ids is not None:
-            position_ids = position_ids.view(-1, input_shape[-1])
+            position_ids = position_ids.view(-1, input_shape[-1]).long()
 
         if past_key_values is None:
             past_length = 0
@@ -568,13 +553,15 @@ def custom_forward(*inputs):
                     hidden_states,
                     None,
                     attention_mask,
+                    position_ids,
                     head_mask[i],
                 )
             else:
                 outputs = block(
-                    hidden_states,
+                    hidden_states=hidden_states,
                     layer_past=layer_past,
                     attention_mask=attention_mask,
+                    position_ids=position_ids,
                     head_mask=head_mask[i],
                     use_cache=use_cache,
                     output_attentions=output_attentions,
@@ -645,8 +632,7 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
             position_ids.masked_fill_(attention_mask == 0, 1)
             if past_key_values:
                 position_ids = position_ids[:, -1].unsqueeze(-1)
-        else:
-            position_ids = None
+
         return {
             "input_ids": input_ids,
             "past_key_values": past_key_values,
