From f5e065b661d54708e37855b9e6370f5d49e662ae Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Thu, 12 May 2022 10:42:27 +0200 Subject: [PATCH] ViT and Swin symbolic tracing with torch.fx (#17182) * Support tracing for ViT * Swin support * Fix copies * Fix type annotation issue * Removed unused import --- src/transformers/models/deit/modeling_deit.py | 4 +- src/transformers/models/dpt/modeling_dpt.py | 4 +- .../models/maskformer/modeling_maskformer.py | 6 +- src/transformers/models/swin/modeling_swin.py | 8 +-- src/transformers/models/vit/modeling_vit.py | 6 +- .../models/vit_mae/modeling_vit_mae.py | 4 +- .../models/yolos/modeling_yolos.py | 4 +- src/transformers/utils/fx.py | 58 ++++++++++++++++--- tests/models/swin/test_modeling_swin.py | 1 + tests/models/vit/test_modeling_vit.py | 1 + tests/test_modeling_common.py | 9 +-- 11 files changed, 70 insertions(+), 35 deletions(-) diff --git a/src/transformers/models/deit/modeling_deit.py b/src/transformers/models/deit/modeling_deit.py index 94bf5dcfbe487..d6fc9d85518bd 100644 --- a/src/transformers/models/deit/modeling_deit.py +++ b/src/transformers/models/deit/modeling_deit.py @@ -168,7 +168,7 @@ def __init__(self, config: DeiTConfig) -> None: def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(*new_x_shape) + x = x.view(new_x_shape) return x.permute(0, 2, 1, 3) def forward( @@ -200,7 +200,7 @@ def forward( context_layer = context_layer.permute(0, 2, 1, 3).contiguous() new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) - context_layer = context_layer.view(*new_context_layer_shape) + context_layer = context_layer.view(new_context_layer_shape) outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) diff --git a/src/transformers/models/dpt/modeling_dpt.py b/src/transformers/models/dpt/modeling_dpt.py index 6c5fd2385232c..64ea40a5c534f 100755 --- a/src/transformers/models/dpt/modeling_dpt.py +++ b/src/transformers/models/dpt/modeling_dpt.py @@ -177,7 +177,7 @@ def __init__(self, config: DPTConfig) -> None: def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(*new_x_shape) + x = x.view(new_x_shape) return x.permute(0, 2, 1, 3) def forward( @@ -209,7 +209,7 @@ def forward( context_layer = context_layer.permute(0, 2, 1, 3).contiguous() new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) - context_layer = context_layer.view(*new_context_layer_shape) + context_layer = context_layer.view(new_context_layer_shape) outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) diff --git a/src/transformers/models/maskformer/modeling_maskformer.py b/src/transformers/models/maskformer/modeling_maskformer.py index bfb020895ce99..0d3538b04fab5 100644 --- a/src/transformers/models/maskformer/modeling_maskformer.py +++ b/src/transformers/models/maskformer/modeling_maskformer.py @@ -496,7 +496,7 @@ def window_reverse(windows, window_size, height, width): """ Merges windows to produce higher resolution features. 
""" - batch_size = int(windows.shape[0] / (height * width / window_size / window_size)) + batch_size = math.floor(windows.shape[0] / (height * width / window_size / window_size)) windows = windows.view(batch_size, height // window_size, width // window_size, window_size, window_size, -1) windows = windows.permute(0, 1, 3, 2, 4, 5).contiguous().view(batch_size, height, width, -1) return windows @@ -697,7 +697,7 @@ def __init__(self, config, dim, num_heads): def transpose_for_scores(self, x): new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(*new_x_shape) + x = x.view(new_x_shape) return x.permute(0, 2, 1, 3) def forward( @@ -750,7 +750,7 @@ def forward( context_layer = torch.matmul(attention_probs, value_layer) context_layer = context_layer.permute(0, 2, 1, 3).contiguous() new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) - context_layer = context_layer.view(*new_context_layer_shape) + context_layer = context_layer.view(new_context_layer_shape) outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) diff --git a/src/transformers/models/swin/modeling_swin.py b/src/transformers/models/swin/modeling_swin.py index 727399f17f4dd..b2d6b348fbaa7 100644 --- a/src/transformers/models/swin/modeling_swin.py +++ b/src/transformers/models/swin/modeling_swin.py @@ -226,7 +226,7 @@ def window_reverse(windows, window_size, height, width): """ Merges windows to produce higher resolution features. """ - batch_size = int(windows.shape[0] / (height * width / window_size / window_size)) + batch_size = math.floor(windows.shape[0] / (height * width / window_size / window_size)) windows = windows.view(batch_size, height // window_size, width // window_size, window_size, window_size, -1) windows = windows.permute(0, 1, 3, 2, 4, 5).contiguous().view(batch_size, height, width, -1) return windows @@ -435,7 +435,7 @@ def __init__(self, config, dim, num_heads): def transpose_for_scores(self, x): new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(*new_x_shape) + x = x.view(new_x_shape) return x.permute(0, 2, 1, 3) def forward( @@ -488,7 +488,7 @@ def forward( context_layer = torch.matmul(attention_probs, value_layer) context_layer = context_layer.permute(0, 2, 1, 3).contiguous() new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) - context_layer = context_layer.view(*new_context_layer_shape) + context_layer = context_layer.view(new_context_layer_shape) outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) @@ -1071,7 +1071,7 @@ def forward( # Reshape to (batch_size, num_channels, height, width) sequence_output = sequence_output.transpose(1, 2) batch_size, num_channels, sequence_length = sequence_output.shape - height = width = int(sequence_length**0.5) + height = width = math.floor(sequence_length**0.5) sequence_output = sequence_output.reshape(batch_size, num_channels, height, width) # Reconstruct pixel values diff --git a/src/transformers/models/vit/modeling_vit.py b/src/transformers/models/vit/modeling_vit.py index b2fc044fcb09c..a5fc9a633617d 100644 --- a/src/transformers/models/vit/modeling_vit.py +++ b/src/transformers/models/vit/modeling_vit.py @@ -213,7 +213,7 @@ def __init__(self, config: ViTConfig) -> None: def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(*new_x_shape) + x = x.view(new_x_shape) return 
x.permute(0, 2, 1, 3) def forward( @@ -245,7 +245,7 @@ def forward( context_layer = context_layer.permute(0, 2, 1, 3).contiguous() new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) - context_layer = context_layer.view(*new_context_layer_shape) + context_layer = context_layer.view(new_context_layer_shape) outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) @@ -687,7 +687,7 @@ def forward( # Reshape to (batch_size, num_channels, height, width) sequence_output = sequence_output[:, 1:] batch_size, sequence_length, num_channels = sequence_output.shape - height = width = int(sequence_length**0.5) + height = width = math.floor(sequence_length**0.5) sequence_output = sequence_output.permute(0, 2, 1).reshape(batch_size, num_channels, height, width) # Reconstruct pixel values diff --git a/src/transformers/models/vit_mae/modeling_vit_mae.py b/src/transformers/models/vit_mae/modeling_vit_mae.py index 473ccd14feb09..f827978739af6 100755 --- a/src/transformers/models/vit_mae/modeling_vit_mae.py +++ b/src/transformers/models/vit_mae/modeling_vit_mae.py @@ -342,7 +342,7 @@ def __init__(self, config: ViTMAEConfig) -> None: def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(*new_x_shape) + x = x.view(new_x_shape) return x.permute(0, 2, 1, 3) def forward( @@ -374,7 +374,7 @@ def forward( context_layer = context_layer.permute(0, 2, 1, 3).contiguous() new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) - context_layer = context_layer.view(*new_context_layer_shape) + context_layer = context_layer.view(new_context_layer_shape) outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) diff --git a/src/transformers/models/yolos/modeling_yolos.py b/src/transformers/models/yolos/modeling_yolos.py index 86ef903167d67..578e8ca609279 100755 --- a/src/transformers/models/yolos/modeling_yolos.py +++ b/src/transformers/models/yolos/modeling_yolos.py @@ -280,7 +280,7 @@ def __init__(self, config: YolosConfig) -> None: def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(*new_x_shape) + x = x.view(new_x_shape) return x.permute(0, 2, 1, 3) def forward( @@ -312,7 +312,7 @@ def forward( context_layer = context_layer.permute(0, 2, 1, 3).contiguous() new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) - context_layer = context_layer.view(*new_context_layer_shape) + context_layer = context_layer.view(new_context_layer_shape) outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) diff --git a/src/transformers/utils/fx.py b/src/transformers/utils/fx.py index d112a71516806..83fbee36c34f2 100644 --- a/src/transformers/utils/fx.py +++ b/src/transformers/utils/fx.py @@ -14,12 +14,12 @@ # limitations under the License. 
import builtins +import collections import functools import inspect import math import random import warnings -from copy import deepcopy from typing import Any, Callable, Dict, Iterable, List, Optional, Type, Union import torch @@ -31,6 +31,7 @@ CONFIG_MAPPING, MODEL_FOR_CAUSAL_LM_MAPPING, MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING, + MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING, MODEL_FOR_MASKED_LM_MAPPING, MODEL_FOR_MULTIPLE_CHOICE_MAPPING, MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING, @@ -71,6 +72,7 @@ def _generate_supported_model_classes( "question-answering": MODEL_FOR_QUESTION_ANSWERING_MAPPING, "sequence-classification": MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING, "token-classification": MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING, + "masked-image-modeling": MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING, "image-classification": MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING, } @@ -100,6 +102,8 @@ def _generate_supported_model_classes( "gpt_neo", "t5", "roberta", + "vit", + "swin", # TODO: add support for them as it should be quite easy to do so (small blocking issues). # "layoutlm", # "xlnet", @@ -276,6 +280,31 @@ def torch_tensor_index_select(self, dim, index): return torch_tensor_index_select(self, dim, index) +def torch_roll(input, shifts, dims=None): + return input + + +def torch_nn_conv2d(self, input): + h_in, w_in = input.shape[-2:] + shape = None + padding = self.padding + if padding == "valid": + padding = (0, 0) + if padding == "same": + shape = list(input.shape) + if shape is None: + shape = list(input.shape) + h_out = math.floor( + (h_in + 2 * padding[0] - self.dilation[0] * (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1 + ) + w_out = math.floor( + (w_in + 2 * padding[1] - self.dilation[1] * (self.kernel_size[1] - 1) - 1) / self.stride[1] + 1 + ) + shape[-2:] = [h_out, w_out] + shape[-3] = self.out_channels + return torch.empty(shape, device="meta") + + def torch_nn_mseloss(self, input, target): if self.reduction == "none": shape = target.shape @@ -317,9 +346,11 @@ def torch_nn_bcewithlogitsloss(self, input, target): torch.Tensor.mul: torch_tensor_mul_override, torch.matmul: torch_matmul_override, torch.Tensor.repeat: torch_tensor_repeat_override, + torch.roll: torch_roll, # TODO: those might not be needed. 
# torch.index_select: torch_index_select, # torch.Tensor.index_select: torch_tensor_index_select, + torch.nn.Conv2d: torch_nn_conv2d, torch.nn.MSELoss: torch_nn_mseloss, torch.nn.CrossEntropyLoss: torch_nn_crossentropyloss, torch.nn.BCEWithLogitsLoss: torch_nn_bcewithlogitsloss, @@ -368,6 +399,9 @@ def __getattr__(self, k): # we peephole optimize to the method invocation return HFAttribute(self, k) + def __setitem__(self, indices, values): + return self.tracer.create_proxy("call_method", "__setitem__", (self, indices, values), {}) + def __contains__(self, key): # To handle cases such as : # `"some_key" in kwargs` @@ -521,6 +555,15 @@ def _generate_dummy_input( inputs_dict["labels"] = torch.zeros(shape, dtype=torch.long, device=device) else: raise NotImplementedError(f"{model_class} not supported yet.") + elif "pixel_values" in input_name: + batch_size = shape[0] + image_size = model.config.image_size + if not isinstance(image_size, collections.abc.Iterable): + image_size = (image_size, image_size) + height, width = image_size + inputs_dict[input_name] = torch.zeros( + batch_size, model.config.num_channels, height, width, dtype=torch.float32, device=device + ) elif "mask" in input_name or "ids" in input_name: inputs_dict[input_name] = torch.zeros(shape, dtype=torch.long, device=device) @@ -663,6 +706,11 @@ def trace( else: self.graph.erase_node(node) + # TODO: solves GraphModule creation. + # Without this, return type annotation "Tuple" is causing code execution failure. + if node.op == "output": + node.type = None + return self.graph def _stateless_mod_instanciation_depends_on_proxies(self, mod: nn.Module) -> bool: @@ -761,12 +809,4 @@ def symbolic_trace( traced_graph = tracer.trace(model, concrete_args=concrete_args) traced = torch.fx.GraphModule(model, traced_graph) - # Copy all the original attributes to the traced GraphModule. 
- regular_module_attributes = dir(nn.Module()) - for name in dir(model): - attr = getattr(model, name) - if name.startswith("_") or name in regular_module_attributes: - continue - setattr(traced, name, deepcopy(attr)) - return traced diff --git a/tests/models/swin/test_modeling_swin.py b/tests/models/swin/test_modeling_swin.py index ef7a64e998d7d..c4f73ef360a14 100644 --- a/tests/models/swin/test_modeling_swin.py +++ b/tests/models/swin/test_modeling_swin.py @@ -175,6 +175,7 @@ class SwinModelTest(ModelTesterMixin, unittest.TestCase): if is_torch_available() else () ) + fx_compatible = True test_pruning = False test_resize_embeddings = False diff --git a/tests/models/vit/test_modeling_vit.py b/tests/models/vit/test_modeling_vit.py index a1379a9d31ec7..bfca8bf5cb9aa 100644 --- a/tests/models/vit/test_modeling_vit.py +++ b/tests/models/vit/test_modeling_vit.py @@ -155,6 +155,7 @@ class ViTModelTest(ModelTesterMixin, unittest.TestCase): if is_torch_available() else () ) + fx_compatible = True test_pruning = False test_resize_embeddings = False diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index ac45a1c10822c..09fd338d3d190 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -738,8 +738,7 @@ def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=Fa traced_model = symbolic_trace(model, input_names) traced_output = traced_model(**filtered_inputs) else: - input_names = ["input_ids", "attention_mask", "token_type_ids"] - input_ids = inputs["input_ids"] + input_names = ["input_ids", "attention_mask", "token_type_ids", "pixel_values"] labels = inputs.get("labels", None) start_positions = inputs.get("start_positions", None) @@ -756,12 +755,6 @@ def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=Fa model_output = model(**filtered_inputs) - rank = len(input_ids.shape) - if rank not in [2, 3]: - raise NotImplementedError( - f"symbolic_trace automatic parameters inference not implemented for input of rank {rank}." - ) - traced_model = symbolic_trace(model, input_names) traced_output = traced_model(**filtered_inputs)
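
Notes and usage sketches (illustrative only, not part of the diff above):

The recurring `x.view(*new_x_shape)` -> `x.view(new_x_shape)` change matters because, under torch.fx symbolic tracing, `x.size()[:-1] + (num_heads, head_size)` is a Proxy rather than a concrete tuple; star-unpacking a Proxy requires iterating over it, which the tracer cannot record, while passing the tuple-like value as a single argument to `view()` traces cleanly. The `int(...)` -> `math.floor(...)` swaps serve the same goal, since calling plain `int()` on a proxied size value generally breaks tracing. A minimal sketch of the pattern, using a toy module rather than code from the patch:

    import torch
    from torch import nn


    class Reshaper(nn.Module):
        """Toy stand-in for the `transpose_for_scores` pattern in the attention layers."""

        def __init__(self, num_heads: int, head_size: int) -> None:
            super().__init__()
            self.num_heads = num_heads
            self.head_size = head_size

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            new_shape = x.size()[:-1] + (self.num_heads, self.head_size)
            # `x.view(*new_shape)` would iterate over a Proxy during tracing and fail;
            # passing the shape as a single argument keeps the graph traceable.
            return x.view(new_shape).permute(0, 2, 1, 3)


    traced = torch.fx.symbolic_trace(Reshaper(num_heads=2, head_size=4))
    print(traced(torch.randn(1, 3, 8)).shape)  # torch.Size([1, 2, 3, 4])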
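
The new `torch_nn_conv2d` meta override infers Conv2d output shapes with the standard formula from the nn.Conv2d documentation, H_out = floor((H_in + 2*padding - dilation*(kernel_size - 1) - 1) / stride + 1), while `torch_roll` can simply return its input because `torch.roll` never changes the shape; these overrides only need to produce values of the right shape on the meta device. A small self-contained check of the convolution formula, with arbitrary example hyperparameters:

    import math

    import torch
    from torch import nn

    conv = nn.Conv2d(in_channels=3, out_channels=8, kernel_size=3, stride=2, padding=1)
    pixel_values = torch.randn(1, 3, 224, 224)

    h_in, w_in = pixel_values.shape[-2:]
    h_out = math.floor(
        (h_in + 2 * conv.padding[0] - conv.dilation[0] * (conv.kernel_size[0] - 1) - 1) / conv.stride[0] + 1
    )
    w_out = math.floor(
        (w_in + 2 * conv.padding[1] - conv.dilation[1] * (conv.kernel_size[1] - 1) - 1) / conv.stride[1] + 1
    )

    # The formula agrees with the shape PyTorch actually produces: (1, 8, 112, 112).
    assert conv(pixel_values).shape == (1, conv.out_channels, h_out, w_out)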
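
With `pixel_values` dummy-input generation and the "vit"/"swin" entries in the supported-architecture list, vision models can now be traced end to end. A hedged usage sketch; the checkpoint name is only an example, and any ViT or Swin checkpoint supported by this version of the tracer should behave the same way:

    import torch

    from transformers import ViTForImageClassification
    from transformers.utils.fx import symbolic_trace

    model = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224")
    traced = symbolic_trace(model, input_names=["pixel_values"])

    pixel_values = torch.randn(1, 3, 224, 224)
    # The traced GraphModule takes the same keyword arguments as the original model
    # and returns the matching outputs (including "logits" for classification heads).
    outputs = traced(pixel_values=pixel_values)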