Support tracing for ViT

michaelbenayoun committed May 2, 2022
1 parent f275e59 commit ed7b409
Showing 4 changed files with 40 additions and 11 deletions.
6 changes: 3 additions & 3 deletions src/transformers/models/vit/modeling_vit.py
@@ -213,7 +213,7 @@ def __init__(self, config: ViTConfig) -> None:
 
     def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
         new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
-        x = x.view(*new_x_shape)
+        x = x.view(new_x_shape)
         return x.permute(0, 2, 1, 3)
 
     def forward(
@@ -245,7 +245,7 @@ def forward(
 
         context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
         new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
-        context_layer = context_layer.view(*new_context_layer_shape)
+        context_layer = context_layer.view(new_context_layer_shape)
 
         outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
 
@@ -687,7 +687,7 @@ def forward(
         # Reshape to (batch_size, num_channels, height, width)
         sequence_output = sequence_output[:, 1:]
         batch_size, sequence_length, num_channels = sequence_output.shape
-        height = width = int(sequence_length**0.5)
+        height = width = math.floor(sequence_length**0.5)
         sequence_output = sequence_output.permute(0, 2, 1).reshape(batch_size, num_channels, height, width)
 
         # Reconstruct pixel values
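All three one-line edits above serve symbolic tracing: torch.fx represents intermediate values as Proxy objects, and a Proxy cannot be star-unpacked (iterating it raises a TraceError), so `view(*shape)` becomes `view(shape)`; the `int(...)` to `math.floor(...)` swap presumably keeps the height/width computation traceable in the same spirit. A minimal sketch of the unpacking issue, using plain torch.fx rather than transformers' HFTracer, included only to illustrate the motivation:

```python
import torch
from torch import fx, nn


class Unpacks(nn.Module):
    def forward(self, x):
        new_shape = x.size()[:-1] + (2, 4)
        # Star-unpacking iterates the shape Proxy, which torch.fx forbids.
        return x.view(*new_shape)


class NoUnpack(nn.Module):
    def forward(self, x):
        new_shape = x.size()[:-1] + (2, 4)
        # Passing the Proxy as a single argument keeps the call traceable.
        return x.view(new_shape)


try:
    fx.symbolic_trace(Unpacks())
except Exception as err:  # torch.fx.proxy.TraceError
    print("cannot trace view(*shape):", err)

print(fx.symbolic_trace(NoUnpack()).code)  # traces cleanly
```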
35 changes: 35 additions & 0 deletions src/transformers/utils/fx.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 
 import builtins
+import collections
 import functools
 import inspect
 import math
@@ -31,6 +32,7 @@
     CONFIG_MAPPING,
     MODEL_FOR_CAUSAL_LM_MAPPING,
     MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
+    MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING,
     MODEL_FOR_MASKED_LM_MAPPING,
     MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
     MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
@@ -71,6 +73,7 @@ def _generate_supported_model_classes(
     "question-answering": MODEL_FOR_QUESTION_ANSWERING_MAPPING,
     "sequence-classification": MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
     "token-classification": MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
+    "masked-image-modeling": MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING,
     "image-classification": MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
 }
 
@@ -100,6 +103,7 @@ def _generate_supported_model_classes(
     "gpt_neo",
     "t5",
     "roberta",
+    "vit",
     # TODO: add support for them as it should be quite easy to do so (small blocking issues).
     # "layoutlm",
     # "xlnet",
@@ -276,6 +280,27 @@ def torch_tensor_index_select(self, dim, index):
     return torch_tensor_index_select(self, dim, index)
 
 
+def torch_nn_conv2d(self, input):
+    h_in, w_in = input.shape[-2:]
+    shape = None
+    padding = self.padding
+    if padding == "valid":
+        padding = (0, 0)
+    if padding == "same":
+        shape = list(input.shape)
+    if shape is None:
+        shape = list(input.shape)
+        h_out = math.floor(
+            (h_in + 2 * padding[0] - self.dilation[0] * (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1
+        )
+        w_out = math.floor(
+            (w_in + 2 * padding[1] - self.dilation[1] * (self.kernel_size[1] - 1) - 1) / self.stride[1] + 1
+        )
+        shape[-2:] = [h_out, w_out]
+    shape[-3] = self.out_channels
+    return torch.empty(shape, device="meta")
+
+
 def torch_nn_mseloss(self, input, target):
     if self.reduction == "none":
         shape = target.shape
@@ -320,6 +345,7 @@ def torch_nn_bcewithlogitsloss(self, input, target):
     # TODO: those might not be needed.
     # torch.index_select: torch_index_select,
     # torch.Tensor.index_select: torch_tensor_index_select,
+    torch.nn.Conv2d: torch_nn_conv2d,
     torch.nn.MSELoss: torch_nn_mseloss,
     torch.nn.CrossEntropyLoss: torch_nn_crossentropyloss,
     torch.nn.BCEWithLogitsLoss: torch_nn_bcewithlogitsloss,
@@ -521,6 +547,15 @@ def _generate_dummy_input(
                 inputs_dict["labels"] = torch.zeros(shape, dtype=torch.long, device=device)
             else:
                 raise NotImplementedError(f"{model_class} not supported yet.")
+        elif "pixel_values" in input_name:
+            batch_size = shape[0]
+            image_size = model.config.image_size
+            if not isinstance(image_size, collections.abc.Iterable):
+                image_size = (image_size, image_size)
+            height, width = image_size
+            inputs_dict[input_name] = torch.zeros(
+                batch_size, model.config.num_channels, height, width, dtype=torch.float32, device=device
+            )
 
         elif "mask" in input_name or "ids" in input_name:
             inputs_dict[input_name] = torch.zeros(shape, dtype=torch.long, device=device)
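For reference, the new torch_nn_conv2d override reproduces nn.Conv2d's output-size arithmetic on the meta device so that HFTracer can record result shapes without running the real convolution, and the new `pixel_values` branch of `_generate_dummy_input` supplies the zero image tensor that drives that recording. A quick eager-mode sanity check of the formula; the ViT-style 16x16 patch-embedding numbers are illustrative, not taken from the commit:

```python
import math

import torch

# ViT-base style patch embedding: 3 -> 768 channels, 16x16 kernel, stride 16, no padding.
conv = torch.nn.Conv2d(3, 768, kernel_size=16, stride=16)
x = torch.randn(2, 3, 224, 224)

h_in, w_in = x.shape[-2:]
h_out = math.floor(
    (h_in + 2 * conv.padding[0] - conv.dilation[0] * (conv.kernel_size[0] - 1) - 1) / conv.stride[0] + 1
)
w_out = math.floor(
    (w_in + 2 * conv.padding[1] - conv.dilation[1] * (conv.kernel_size[1] - 1) - 1) / conv.stride[1] + 1
)

# The formula predicts exactly what the real convolution produces: (2, 768, 14, 14).
assert conv(x).shape == (2, conv.out_channels, h_out, w_out)
```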
9 changes: 1 addition & 8 deletions tests/test_modeling_common.py
@@ -738,8 +738,7 @@ def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=Fa
             traced_model = symbolic_trace(model, input_names)
             traced_output = traced_model(**filtered_inputs)
         else:
-            input_names = ["input_ids", "attention_mask", "token_type_ids"]
-            input_ids = inputs["input_ids"]
+            input_names = ["input_ids", "attention_mask", "token_type_ids", "pixel_values"]
 
             labels = inputs.get("labels", None)
             start_positions = inputs.get("start_positions", None)
@@ -756,12 +755,6 @@ def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=Fa
 
             model_output = model(**filtered_inputs)
 
-            rank = len(input_ids.shape)
-            if rank not in [2, 3]:
-                raise NotImplementedError(
-                    f"symbolic_trace automatic parameters inference not implemented for input of rank {rank}."
-                )
-
             traced_model = symbolic_trace(model, input_names)
             traced_output = traced_model(**filtered_inputs)
 
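Adding "pixel_values" to input_names is harmless for text-only models: the surrounding helper (not shown in these hunks) keeps only the keys the model tester actually prepared before calling symbolic_trace, which is also why the input_ids rank check could be dropped for models such as ViT that take no input_ids. A hedged sketch of that filtering idiom with made-up inputs:

```python
import torch

input_names = ["input_ids", "attention_mask", "token_type_ids", "pixel_values"]

# Illustrative inputs a text tester and a vision tester might prepare.
text_inputs = {
    "input_ids": torch.zeros(1, 8, dtype=torch.long),
    "attention_mask": torch.ones(1, 8, dtype=torch.long),
}
vision_inputs = {"pixel_values": torch.randn(1, 3, 30, 30)}

for inputs in (text_inputs, vision_inputs):
    filtered_inputs = {k: v for k, v in inputs.items() if k in input_names}
    # Each model only ever receives the inputs its own tester produced.
    print(sorted(filtered_inputs))
```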
1 change: 1 addition & 0 deletions tests/vit/test_modeling_vit.py
@@ -155,6 +155,7 @@ class ViTModelTest(ModelTesterMixin, unittest.TestCase):
         if is_torch_available()
         else ()
     )
+    fx_compatible = True
 
     test_pruning = False
     test_resize_embeddings = False
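With fx_compatible = True, the common FX tracing test above now runs against the ViT test models. A minimal usage sketch of what this enables; the checkpoint name and the keyword form of input_names are assumptions, and whether the traced module returns a ModelOutput or a plain dict/tuple depends on the transformers version:

```python
import torch

from transformers import ViTForImageClassification
from transformers.utils.fx import symbolic_trace

model = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224")
model.eval()

# Trace with pixel_values as the only symbolic input.
traced = symbolic_trace(model, input_names=["pixel_values"])  # returns a torch.fx.GraphModule

pixel_values = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    eager_out = model(pixel_values=pixel_values)
    traced_out = traced(pixel_values=pixel_values)

# traced.graph and traced.code can now be inspected or transformed.
```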
