From e78ee26547d346e8bf9fb0a506f951dc6b4f6200 Mon Sep 17 00:00:00 2001
From: Michael Benayoun
Date: Wed, 11 May 2022 16:30:42 +0200
Subject: [PATCH] Swin support

---
 src/transformers/models/swin/modeling_swin.py |  8 ++--
 src/transformers/utils/fx.py                  | 41 +++++++++++++++----
 tests/swin/test_modeling_swin.py              |  1 +
 3 files changed, 38 insertions(+), 12 deletions(-)

diff --git a/src/transformers/models/swin/modeling_swin.py b/src/transformers/models/swin/modeling_swin.py
index 51a19ab73b8ccb..1d9b7aa6ddb910 100644
--- a/src/transformers/models/swin/modeling_swin.py
+++ b/src/transformers/models/swin/modeling_swin.py
@@ -226,7 +226,7 @@ def window_reverse(windows, window_size, height, width):
     """
     Merges windows to produce higher resolution features.
     """
-    batch_size = int(windows.shape[0] / (height * width / window_size / window_size))
+    batch_size = math.floor(windows.shape[0] / (height * width / window_size / window_size))
     windows = windows.view(batch_size, height // window_size, width // window_size, window_size, window_size, -1)
     windows = windows.permute(0, 1, 3, 2, 4, 5).contiguous().view(batch_size, height, width, -1)
     return windows
@@ -433,7 +433,7 @@ def __init__(self, config, dim, num_heads):
 
     def transpose_for_scores(self, x):
         new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
-        x = x.view(*new_x_shape)
+        x = x.view(new_x_shape)
         return x.permute(0, 2, 1, 3)
 
     def forward(
@@ -486,7 +486,7 @@ def forward(
         context_layer = torch.matmul(attention_probs, value_layer)
         context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
         new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
-        context_layer = context_layer.view(*new_context_layer_shape)
+        context_layer = context_layer.view(new_context_layer_shape)
 
         outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
 
@@ -1051,7 +1051,7 @@ def forward(
         # Reshape to (batch_size, num_channels, height, width)
         sequence_output = sequence_output.transpose(1, 2)
         batch_size, num_channels, sequence_length = sequence_output.shape
-        height = width = int(sequence_length**0.5)
+        height = width = math.floor(sequence_length**0.5)
         sequence_output = sequence_output.reshape(batch_size, num_channels, height, width)
 
         # Reconstruct pixel values
diff --git a/src/transformers/utils/fx.py b/src/transformers/utils/fx.py
index bff49d87b7344e..f16c48f7ab5f11 100644
--- a/src/transformers/utils/fx.py
+++ b/src/transformers/utils/fx.py
@@ -18,9 +18,9 @@
 import functools
 import inspect
 import math
+import operator
 import random
 import warnings
-from copy import deepcopy
 from typing import Any, Callable, Dict, Iterable, List, Optional, Type, Union
 
 import torch
@@ -104,6 +104,7 @@ def _generate_supported_model_classes(
     "t5",
     "roberta",
     "vit",
+    "swin",
     # TODO: add support for them as it should be quite easy to do so (small blocking issues).
     # "layoutlm",
     # "xlnet",
@@ -280,6 +281,10 @@ def torch_tensor_index_select(self, dim, index):
     return torch_tensor_index_select(self, dim, index)
 
 
+def torch_roll(input, shifts, dims=None):
+    return input
+
+
 def torch_nn_conv2d(self, input):
     h_in, w_in = input.shape[-2:]
     shape = None
@@ -325,6 +330,21 @@ def torch_nn_bcewithlogitsloss(self, input, target):
     return torch.empty(shape, device="meta")
 
 
+def torch_tensor_getitem(self, indices):
+    if not isinstance(self, torch.Tensor):
+        return operator.getitem(self, indices)
+    if not isinstance(indices, (tuple, list)):
+        indices = [indices]
+
+    def map_fn(x):
+        if isinstance(x, torch.Tensor):
+            return torch.zeros_like(x, device="cpu")
+        return x
+
+    indices = list(map(map_fn, indices))
+    return torch.empty_like(torch.empty_like(self, device="cpu")[indices], device="meta")
+
+
 _MANUAL_META_OVERRIDES: Dict[Callable, Callable] = {
     torch.nn.Embedding: embedding_override,
     torch.nn.LayerNorm: torch_nn_layernorm_override,
@@ -342,6 +362,7 @@ def torch_nn_bcewithlogitsloss(self, input, target):
     torch.Tensor.mul: torch_tensor_mul_override,
     torch.matmul: torch_matmul_override,
     torch.Tensor.repeat: torch_tensor_repeat_override,
+    torch.roll: torch_roll,
     # TODO: those might not be needed.
     # torch.index_select: torch_index_select,
     # torch.Tensor.index_select: torch_tensor_index_select,
@@ -349,6 +370,7 @@ def torch_nn_bcewithlogitsloss(self, input, target):
     torch.nn.MSELoss: torch_nn_mseloss,
     torch.nn.CrossEntropyLoss: torch_nn_crossentropyloss,
     torch.nn.BCEWithLogitsLoss: torch_nn_bcewithlogitsloss,
+    # operator.getitem: torch_tensor_getitem,
 }
 
 
@@ -394,6 +416,9 @@ def __getattr__(self, k):
         # we peephole optimize to the method invocation
         return HFAttribute(self, k)
 
+    def __setitem__(self, indices, values):
+        return self.tracer.create_proxy("call_method", "__setitem__", (self, indices, values), {})
+
     def __contains__(self, key):
         # To handle cases such as :
         # `"some_key" in kwargs`
@@ -795,13 +820,13 @@ def symbolic_trace(
     tracer = HFTracer()
     traced_graph = tracer.trace(model, concrete_args=concrete_args)
     traced = torch.fx.GraphModule(model, traced_graph)
-    # Copy all the original attributes to the traced GraphModule.
-    regular_module_attributes = dir(nn.Module())
-    for name in dir(model):
-        attr = getattr(model, name)
-        if name.startswith("_") or name in regular_module_attributes:
-            continue
-        setattr(traced, name, deepcopy(attr))
+    # from copy import deepcopy
+    # regular_module_attributes = dir(nn.Module())
+    # for name in dir(model):
+    #     attr = getattr(model, name)
+    #     if name.startswith("_") or name in regular_module_attributes:
+    #         continue
+    #     setattr(traced, name, deepcopy(attr))
 
     return traced
 
diff --git a/tests/swin/test_modeling_swin.py b/tests/swin/test_modeling_swin.py
index 2147f578e73ea0..e286a002e1f920 100644
--- a/tests/swin/test_modeling_swin.py
+++ b/tests/swin/test_modeling_swin.py
@@ -175,6 +175,7 @@ class SwinModelTest(ModelTesterMixin, unittest.TestCase):
         if is_torch_available()
         else ()
     )
+    fx_compatible = True
     test_pruning = False
     test_resize_embeddings = False