Get the changes to work with torch.fx
njhill committed Mar 19, 2023
1 parent 15d1d5e commit 03056aa
Showing 3 changed files with 63 additions and 35 deletions.
28 changes: 8 additions & 20 deletions src/transformers/models/codegen/modeling_codegen.py
@@ -24,13 +24,7 @@
from ...activations import ACT2FN
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from ...modeling_utils import PreTrainedModel
from ...utils import (
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_torch_fx_proxy,
logging,
)
from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
from .configuration_codegen import CodeGenConfig


@@ -73,8 +67,7 @@ def rotate_every_two(x: torch.Tensor) -> torch.Tensor:


# Copied from transformers.models.gptj.modeling_gptj.apply_rotary_pos_emb
def apply_rotary_pos_emb(tensor: torch.Tensor, sincos: torch.Tensor) -> torch.Tensor:
sin, cos = sincos
def apply_rotary_pos_emb(tensor: torch.Tensor, sin: torch.Tensor, cos: torch.Tensor) -> torch.Tensor:
sin = torch.repeat_interleave(sin[:, :, None, :], 2, 3)
cos = torch.repeat_interleave(cos[:, :, None, :], 2, 3)
return (tensor * cos) + (rotate_every_two(tensor) * sin)
@@ -201,13 +194,8 @@ def forward(
embed_positions = embed_positions.to(position_ids.device)
self.embed_positions = embed_positions

if is_torch_fx_proxy(position_ids):
# Assume no padding in torch.fx case, index-by-tensor can't be traced
sincos = embed_positions[None, : position_ids.shape[-1], :].repeat(position_ids.shape[0], 1, 1)
else:
sincos = embed_positions[position_ids]

sincos = torch.split(sincos, sincos.shape[-1] // 2, dim=-1)
sincos = embed_positions[position_ids]
sin, cos = torch.split(sincos, sincos.shape[-1] // 2, dim=-1)

if self.rotary_dim is not None:
k_rot = key[:, :, :, : self.rotary_dim]
@@ -216,14 +204,14 @@
q_rot = query[:, :, :, : self.rotary_dim]
q_pass = query[:, :, :, self.rotary_dim :]

k_rot = apply_rotary_pos_emb(k_rot, sincos)
q_rot = apply_rotary_pos_emb(q_rot, sincos)
k_rot = apply_rotary_pos_emb(k_rot, sin, cos)
q_rot = apply_rotary_pos_emb(q_rot, sin, cos)

key = torch.cat([k_rot, k_pass], dim=-1)
query = torch.cat([q_rot, q_pass], dim=-1)
else:
key = apply_rotary_pos_emb(key, sincos)
query = apply_rotary_pos_emb(query, sincos)
key = apply_rotary_pos_emb(key, sin, cos)
query = apply_rotary_pos_emb(query, sin, cos)

key = key.permute(0, 2, 1, 3)
query = query.permute(0, 2, 1, 3)
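Note on the CodeGen change above: `apply_rotary_pos_emb` now takes `sin` and `cos` as separate tensors, split once by the caller, instead of a packed `sincos` tuple. The snippet below is a minimal standalone sketch of the new call pattern; the helper bodies are adapted from the diff above, and the tensor sizes (batch 2, sequence 4, 4 heads, head dimension 8) are arbitrary illustration values, not anything from the commit.

```python
import torch

# Helpers as they look after this commit (adapted from the diff above).
def rotate_every_two(x: torch.Tensor) -> torch.Tensor:
    x1 = x[:, :, :, ::2]
    x2 = x[:, :, :, 1::2]
    return torch.stack((-x2, x1), dim=-1).flatten(-2)

def apply_rotary_pos_emb(tensor: torch.Tensor, sin: torch.Tensor, cos: torch.Tensor) -> torch.Tensor:
    sin = torch.repeat_interleave(sin[:, :, None, :], 2, 3)
    cos = torch.repeat_interleave(cos[:, :, None, :], 2, 3)
    return (tensor * cos) + (rotate_every_two(tensor) * sin)

# Arbitrary illustration sizes.
batch, seq, heads, head_dim = 2, 4, 4, 8
query = torch.randn(batch, seq, heads, head_dim)

# The positional table packs [sin | cos] along the last dimension; the caller
# now splits it once and passes both halves instead of a (sin, cos) tuple.
sincos = torch.randn(batch, seq, head_dim)
sin, cos = torch.split(sincos, sincos.shape[-1] // 2, dim=-1)

rotated = apply_rotary_pos_emb(query, sin, cos)
print(rotated.shape)  # torch.Size([2, 4, 4, 8])
```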
37 changes: 22 additions & 15 deletions src/transformers/models/gptj/modeling_gptj.py
@@ -60,15 +60,19 @@ def create_sinusoidal_positions(num_pos: int, dim: int) -> torch.Tensor:
return torch.concat((torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)), dim=1)


@torch.fx.wrap
def get_embed_positions(embed_positions, position_ids):
return embed_positions.to(position_ids.device).repeat(position_ids.shape[0], 1, 1)


def rotate_every_two(x: torch.Tensor) -> torch.Tensor:
x1 = x[:, :, :, ::2]
x2 = x[:, :, :, 1::2]
x = torch.stack((-x2, x1), dim=-1)
return x.flatten(-2) # in einsum notation: rearrange(x, '... d j -> ... (d j)')


def apply_rotary_pos_emb(tensor: torch.Tensor, sincos: torch.Tensor) -> torch.Tensor:
sin, cos = sincos
def apply_rotary_pos_emb(tensor: torch.Tensor, sin: torch.Tensor, cos: torch.Tensor) -> torch.Tensor:
sin = torch.repeat_interleave(sin[:, :, None, :], 2, 3)
cos = torch.repeat_interleave(cos[:, :, None, :], 2, 3)
return (tensor * cos) + (rotate_every_two(tensor) * sin)
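A note on the `@torch.fx.wrap` decorator applied to `get_embed_positions` above: wrapping makes the tracer record the call as a single opaque node instead of tracing into its body, so the device check and `.to(...)` move are not specialized into the graph. Below is a generic `torch.fx.wrap` illustration, not code from this commit; the `lookup` function and `Toy` module are made up for the example.

```python
import torch
import torch.fx

@torch.fx.wrap
def lookup(table: torch.Tensor, ids: torch.Tensor) -> torch.Tensor:
    # Recorded as one call_function node during tracing; the body only runs
    # when the traced GraphModule is executed with real tensors.
    return table.to(ids.device)[ids]

class Toy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.table = torch.nn.Parameter(torch.randn(10, 4))

    def forward(self, ids: torch.Tensor) -> torch.Tensor:
        return lookup(self.table, ids)

gm = torch.fx.symbolic_trace(Toy())
print(gm.graph)                        # shows a single call to lookup
print(gm(torch.tensor([1, 2])).shape)  # torch.Size([2, 4])
```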
@@ -178,6 +182,13 @@ def _attn(

return attn_output, attn_weights

def _get_embed_positions(self, position_ids):
embed_positions = self.embed_positions
if embed_positions.device != position_ids.device:
embed_positions = embed_positions.to(position_ids.device)
self.embed_positions = embed_positions
return embed_positions.repeat(position_ids.shape[0], 1, 1)

def forward(
self,
hidden_states: Optional[torch.FloatTensor],
@@ -199,18 +210,14 @@ def forward(
key = self._split_heads(key, self.num_attention_heads, self.head_dim, True)
value = self._split_heads(value, self.num_attention_heads, self.head_dim, False)

embed_positions = self.embed_positions
if embed_positions.device != position_ids.device:
embed_positions = embed_positions.to(position_ids.device)
self.embed_positions = embed_positions

if is_torch_fx_proxy(position_ids):
# Assume no padding in torch.fx case, index-by-tensor can't be traced
sincos = embed_positions[None, : position_ids.shape[-1], :].repeat(position_ids.shape[0], 1, 1)
embed_positions = get_embed_positions(self.embed_positions, position_ids)
else:
sincos = embed_positions[position_ids]
embed_positions = self._get_embed_positions(position_ids)

sincos = torch.split(sincos, sincos.shape[-1] // 2, dim=-1)
repeated_position_ids = position_ids.unsqueeze(-1).repeat(1, 1, embed_positions.shape[-1])
sincos = torch.gather(embed_positions, 1, repeated_position_ids)
sin, cos = torch.split(sincos, sincos.shape[-1] // 2, dim=-1)

if self.rotary_dim is not None:
k_rot = key[:, :, :, : self.rotary_dim]
@@ -219,14 +226,14 @@
q_rot = query[:, :, :, : self.rotary_dim]
q_pass = query[:, :, :, self.rotary_dim :]

k_rot = apply_rotary_pos_emb(k_rot, sincos)
q_rot = apply_rotary_pos_emb(q_rot, sincos)
k_rot = apply_rotary_pos_emb(k_rot, sin, cos)
q_rot = apply_rotary_pos_emb(q_rot, sin, cos)

key = torch.cat([k_rot, k_pass], dim=-1)
query = torch.cat([q_rot, q_pass], dim=-1)
else:
key = apply_rotary_pos_emb(key, sincos)
query = apply_rotary_pos_emb(query, sincos)
key = apply_rotary_pos_emb(key, sin, cos)
query = apply_rotary_pos_emb(query, sin, cos)

key = key.permute(0, 2, 1, 3)
query = query.permute(0, 2, 1, 3)
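The GPT-J forward pass above replaces advanced indexing (`embed_positions[position_ids]`) with an explicit repeat plus `torch.gather`, which torch.fx can trace and which is backed by the new `torch_gather` meta function in src/transformers/utils/fx.py below. A quick standalone check of the equivalence, with arbitrary illustration shapes (not from the commit):

```python
import torch

batch, num_positions, dim, seq_len = 2, 16, 8, 5    # arbitrary sizes
table = torch.randn(num_positions, dim)             # like self.embed_positions
position_ids = torch.randint(0, num_positions, (batch, seq_len))

# Old path: index-by-tensor, which cannot be traced through an fx proxy.
old = table[position_ids]                            # (batch, seq_len, dim)

# New path: repeat the table per batch item, expand the ids to the table's
# last dimension, and gather along the position axis.
embed_positions = table.repeat(batch, 1, 1)          # (batch, num_positions, dim)
repeated_position_ids = position_ids.unsqueeze(-1).repeat(1, 1, dim)
new = torch.gather(embed_positions, 1, repeated_position_ids)

print(torch.equal(old, new))  # True
```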
33 changes: 33 additions & 0 deletions src/transformers/utils/fx.py
@@ -361,6 +361,26 @@ def torch_tensor_repeat(self, *sizes):
return torch.empty(shape, device="meta")


def torch_repeat_interleave(*args, dim=None, output_size=None):
num_args = len(args)
if num_args == 1:
shape = [output_size if output_size is not None else args[0].sum()]
else:
shape = list(args[0].shape)
repeats = args[1]
if dim is None:
if len(args) > 2:
dim = args[2]
else:
shape = [sum(shape)]
dim = 0
if type(repeats) == int or torch.numel(repeats) == 1:
shape[dim] *= int(repeats)
else:
shape[dim] = output_size if output_size is not None else args[1].sum()
return torch.empty(*shape, device="meta")


def torch_index_select(input, dim, index, *, out=None):
shape = list(input.shape)
shape[dim] = len(index)
@@ -371,6 +391,16 @@ def torch_tensor_index_select(self, dim, index):
return torch_index_select(self, dim, index)


def torch_gather(input, dim, index, *, sparse_grad=False, out=None):
shape = list(input.shape)
shape[dim] = index.shape[dim]
return torch.empty(*shape, device="meta")


def torch_tensor_gather(self, dim, index):
return torch_gather(self, dim, index)


def torch_roll(input, shifts, dims=None):
return input

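The two meta implementations above only need to reproduce output shapes for shape inference during tracing. A quick sanity check of the shape rules they encode, using arbitrary example sizes that are not part of the commit:

```python
import torch

# torch.gather keeps the input's shape except along `dim`, where the output
# takes the index's size -- the rule torch_gather applies to meta tensors.
inp = torch.randn(2, 16, 8)
idx = torch.randint(0, 16, (2, 5, 8))
print(torch.gather(inp, 1, idx).shape)             # torch.Size([2, 5, 8])

# torch.repeat_interleave with an integer repeat multiplies the size along
# `dim`, matching the int/one-element branch of torch_repeat_interleave.
x = torch.randn(2, 4, 1, 8)
print(torch.repeat_interleave(x, 2, dim=3).shape)  # torch.Size([2, 4, 1, 16])
```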
@@ -537,11 +567,14 @@ def to_concrete(t):
torch.Tensor.baddbmm: torch_tensor_baddbmm,
torch.einsum: torch_einsum,
torch.Tensor.repeat: torch_tensor_repeat,
torch.repeat_interleave: torch_repeat_interleave,
torch.roll: torch_roll,
torch.flip: torch_flip,
torch.Tensor.flip: torch_tensor_flip,
torch.index_select: torch_index_select,
torch.Tensor.index_select: torch_tensor_index_select,
torch.gather: torch_gather,
torch.Tensor.gather: torch_tensor_gather,
torch.nn.Conv1d: torch_nn_conv1d,
torch.nn.Conv2d: torch_nn_conv2d,
torch.squeeze: torch_squeeze,
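Taken together, these changes let the rotary-embedding path be traced symbolically. A rough end-to-end sketch, assuming a transformers build that includes this commit; the tiny GPTJConfig values are arbitrary and the traced module's exact output container may vary between versions:

```python
import torch
from transformers import GPTJConfig, GPTJForCausalLM
from transformers.utils.fx import symbolic_trace

# Tiny, arbitrary config so the example runs quickly; real checkpoints trace the same way.
config = GPTJConfig(vocab_size=128, n_positions=64, n_embd=64, n_layer=2, n_head=4, rotary_dim=16)
model = GPTJForCausalLM(config).eval()

# With the gather/repeat_interleave meta functions registered above, the tracer
# can get through the rotary-embedding code without index-by-tensor.
traced = symbolic_trace(model, input_names=["input_ids"])
print(isinstance(traced, torch.fx.GraphModule))  # True

input_ids = torch.randint(0, config.vocab_size, (1, 8))
with torch.no_grad():
    out = traced(input_ids=input_ids)  # runs the recorded graph on real tensors
```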
