Skip to content

Commit

Permalink
Hopefully make torch.fx happy
Browse files Browse the repository at this point in the history
  • Loading branch information
njhill committed Mar 11, 2023
1 parent fc54502 commit e41a6f4
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 4 deletions.
16 changes: 14 additions & 2 deletions src/transformers/models/codegen/modeling_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,16 +59,27 @@ def create_sinusoidal_positions(num_pos: int, dim: int) -> torch.Tensor:


# Copied from transformers.models.gptj.modeling_gptj.rotate_every_two
def rotate_every_two(x):
def rotate_every_two(x: torch.Tensor) -> torch.Tensor:
x1 = x[:, :, :, ::2]
x2 = x[:, :, :, 1::2]
x = torch.stack((-x2, x1), dim=-1)
return x.flatten(-2) # in einsum notation: rearrange(x, '... d j -> ... (d j)')


# Copied from transformers.models.gptj.modeling_gptj.duplicate_interleave
def duplicate_interleave(m: torch.Tensor) -> torch.Tensor:
"""
A simple version of `torch.repeat_interleave` for duplicating a matrix while interleaving the copy.
"""
dims = m.shape[:-1]
m = m.view(-1, 1) # flatten the matrix
m = m.repeat(1, 2) # repeat all elements into the 2nd dimension
return m.view(*dims, -1) # reshape into a matrix, interleaving the copy


# Copied from transformers.models.gptj.modeling_gptj.apply_rotary_pos_emb
def apply_rotary_pos_emb(tensor: torch.Tensor, sincos: torch.Tensor) -> torch.Tensor:
sin, cos = (torch.repeat_interleave(t[:, :, None, :], 2, 3) for t in sincos)
sin, cos = (duplicate_interleave(t)[:, :, None, :] for t in sincos)
return (tensor * cos) + (rotate_every_two(tensor) * sin)


Expand Down Expand Up @@ -195,6 +206,7 @@ def forward(

sincos = embed_positions[position_ids.long()]
sincos = torch.split(sincos, sincos.shape[-1] // 2, dim=-1)
sincos = [t.contiguous() for t in sincos]

if self.rotary_dim is not None:
k_rot = key[:, :, :, : self.rotary_dim]
Expand Down
15 changes: 13 additions & 2 deletions src/transformers/models/gptj/modeling_gptj.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,15 +54,25 @@ def create_sinusoidal_positions(num_pos: int, dim: int) -> torch.Tensor:
return torch.concat((torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)), dim=1)


def rotate_every_two(x):
def rotate_every_two(x: torch.Tensor) -> torch.Tensor:
x1 = x[:, :, :, ::2]
x2 = x[:, :, :, 1::2]
x = torch.stack((-x2, x1), dim=-1)
return x.flatten(-2) # in einsum notation: rearrange(x, '... d j -> ... (d j)')


def duplicate_interleave(m: torch.Tensor) -> torch.Tensor:
"""
A simple version of `torch.repeat_interleave` for duplicating a matrix while interleaving the copy.
"""
dims = m.shape[:-1]
m = m.view(-1, 1) # flatten the matrix
m = m.repeat(1, 2) # repeat all elements into the 2nd dimension
return m.view(*dims, -1) # reshape into a matrix, interleaving the copy


def apply_rotary_pos_emb(tensor: torch.Tensor, sincos: torch.Tensor) -> torch.Tensor:
sin, cos = (torch.repeat_interleave(t[:, :, None, :], 2, 3) for t in sincos)
sin, cos = (duplicate_interleave(t)[:, :, None, :] for t in sincos)
return (tensor * cos) + (rotate_every_two(tensor) * sin)


Expand Down Expand Up @@ -198,6 +208,7 @@ def forward(

sincos = embed_positions[position_ids.long()]
sincos = torch.split(sincos, sincos.shape[-1] // 2, dim=-1)
sincos = [t.contiguous() for t in sincos]

if self.rotary_dim is not None:
k_rot = key[:, :, :, : self.rotary_dim]
Expand Down

0 comments on commit e41a6f4

Please sign in to comment.