Swin support
michaelbenayoun committed May 11, 2022
1 parent ed7b409 commit e78ee26
Showing 3 changed files with 38 additions and 12 deletions.
8 changes: 4 additions & 4 deletions src/transformers/models/swin/modeling_swin.py
@@ -226,7 +226,7 @@ def window_reverse(windows, window_size, height, width):
     """
     Merges windows to produce higher resolution features.
     """
-    batch_size = int(windows.shape[0] / (height * width / window_size / window_size))
+    batch_size = math.floor(windows.shape[0] / (height * width / window_size / window_size))
     windows = windows.view(batch_size, height // window_size, width // window_size, window_size, window_size, -1)
     windows = windows.permute(0, 1, 3, 2, 4, 5).contiguous().view(batch_size, height, width, -1)
     return windows
@@ -433,7 +433,7 @@ def __init__(self, config, dim, num_heads):

     def transpose_for_scores(self, x):
         new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
-        x = x.view(*new_x_shape)
+        x = x.view(new_x_shape)
         return x.permute(0, 2, 1, 3)

     def forward(
@@ -486,7 +486,7 @@ def forward(
         context_layer = torch.matmul(attention_probs, value_layer)
         context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
         new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
-        context_layer = context_layer.view(*new_context_layer_shape)
+        context_layer = context_layer.view(new_context_layer_shape)

         outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

@@ -1051,7 +1051,7 @@ def forward(
         # Reshape to (batch_size, num_channels, height, width)
         sequence_output = sequence_output.transpose(1, 2)
         batch_size, num_channels, sequence_length = sequence_output.shape
-        height = width = int(sequence_length**0.5)
+        height = width = math.floor(sequence_length**0.5)
         sequence_output = sequence_output.reshape(batch_size, num_channels, height, width)

         # Reconstruct pixel values
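The four modeling changes above replace constructs that torch.fx symbolic tracing handles poorly: int() casts on shape arithmetic become math.floor(), and *-unpacking of a shape tuple built from x.size() becomes a direct .view(shape) call. A minimal sketch of the unpacking issue using plain torch.fx (an illustration, not the HFTracer this commit targets):

import torch
import torch.fx


class TinyReshape(torch.nn.Module):
    # Mirrors the transpose_for_scores pattern: build a shape tuple from x.size()
    # and pass it to .view() without unpacking it.
    def forward(self, x):
        new_shape = x.size()[:-1] + (2, x.size(-1) // 2)
        return x.view(new_shape)  # x.view(*new_shape) would raise a TraceError while tracing


traced = torch.fx.symbolic_trace(TinyReshape())
print(traced(torch.randn(3, 8)).shape)  # torch.Size([3, 2, 4])
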
41 changes: 33 additions & 8 deletions src/transformers/utils/fx.py
@@ -18,9 +18,9 @@
 import functools
 import inspect
 import math
+import operator
 import random
 import warnings
-from copy import deepcopy
 from typing import Any, Callable, Dict, Iterable, List, Optional, Type, Union

 import torch
@@ -104,6 +104,7 @@ def _generate_supported_model_classes(
     "t5",
     "roberta",
     "vit",
+    "swin",
     # TODO: add support for them as it should be quite easy to do so (small blocking issues).
     # "layoutlm",
     # "xlnet",
@@ -280,6 +281,10 @@ def torch_tensor_index_select(self, dim, index):
     return torch_tensor_index_select(self, dim, index)


+def torch_roll(input, shifts, dims=None):
+    return input
+
+
 def torch_nn_conv2d(self, input):
     h_in, w_in = input.shape[-2:]
     shape = None
@@ -325,6 +330,21 @@ def torch_nn_bcewithlogitsloss(self, input, target):
     return torch.empty(shape, device="meta")


+def torch_tensor_getitem(self, indices):
+    if not isinstance(self, torch.Tensor):
+        return operator.getitem(self, indices)
+    if not isinstance(indices, (tuple, list)):
+        indices = [indices]
+
+    def map_fn(x):
+        if isinstance(x, torch.Tensor):
+            return torch.zeros_like(x, device="cpu")
+        return x
+
+    indices = list(map(map_fn, indices))
+    return torch.empty_like(torch.empty_like(self, device="cpu")[indices], device="meta")
+
+
 _MANUAL_META_OVERRIDES: Dict[Callable, Callable] = {
     torch.nn.Embedding: embedding_override,
     torch.nn.LayerNorm: torch_nn_layernorm_override,
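The new torch_tensor_getitem override infers the result shape of advanced indexing without touching real data: meta tensors are swapped for cheap CPU placeholders (integer index tensors become zeros, since only their shape affects the result), the indexing runs on CPU, and an empty meta tensor of the resulting shape is returned. A standalone illustration of that idea, with made-up shapes:

import torch

data = torch.empty(4, 9, 9, 3, device="meta")           # carries shape/dtype only
index = torch.tensor([0, 2, 3])                          # index values are concrete here

placeholder = torch.empty_like(data, device="cpu")       # cheap CPU stand-in for the meta tensor
result_shape = placeholder[:, index].shape               # let PyTorch work out the indexed shape
meta_result = torch.empty(result_shape, device="meta")   # return only metadata, as the override does
print(meta_result.shape)                                 # torch.Size([4, 3, 9, 3])
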
@@ -342,13 +362,15 @@ def torch_nn_bcewithlogitsloss(self, input, target):
     torch.Tensor.mul: torch_tensor_mul_override,
     torch.matmul: torch_matmul_override,
     torch.Tensor.repeat: torch_tensor_repeat_override,
+    torch.roll: torch_roll,
     # TODO: those might not be needed.
     # torch.index_select: torch_index_select,
     # torch.Tensor.index_select: torch_tensor_index_select,
     torch.nn.Conv2d: torch_nn_conv2d,
     torch.nn.MSELoss: torch_nn_mseloss,
     torch.nn.CrossEntropyLoss: torch_nn_crossentropyloss,
     torch.nn.BCEWithLogitsLoss: torch_nn_bcewithlogitsloss,
+    # operator.getitem: torch_tensor_getitem,
 }
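_MANUAL_META_OVERRIDES maps callables to stand-ins that only reproduce output metadata when inputs live on the meta device; torch.roll maps to the new torch_roll because rolling never changes shape or dtype, and Swin's shifted windows rely on it. A simplified sketch of the lookup idea (hypothetical helper names, not the actual HFTracer dispatch):

import torch

def fake_roll(input, shifts, dims=None):
    # Rolling moves values around but keeps shape and dtype, so the input's
    # metadata already describes the output.
    return input

OVERRIDES = {torch.roll: fake_roll}

def run_for_metadata(op, *args, **kwargs):
    # Use the registered stand-in when there is one, otherwise call the op directly.
    return OVERRIDES.get(op, op)(*args, **kwargs)

x = torch.empty(2, 7, 7, 96, device="meta")
y = run_for_metadata(torch.roll, x, shifts=(-3, -3), dims=(1, 2))
print(y.shape, y.device)  # torch.Size([2, 7, 7, 96]) meta
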


@@ -394,6 +416,9 @@ def __getattr__(self, k):
         # we peephole optimize to the method invocation
         return HFAttribute(self, k)

+    def __setitem__(self, indices, values):
+        return self.tracer.create_proxy("call_method", "__setitem__", (self, indices, values), {})
+
     def __contains__(self, key):
         # To handle cases such as :
         # `"some_key" in kwargs`
@@ -795,13 +820,13 @@ def symbolic_trace(
     tracer = HFTracer()
     traced_graph = tracer.trace(model, concrete_args=concrete_args)
     traced = torch.fx.GraphModule(model, traced_graph)

-    # Copy all the original attributes to the traced GraphModule.
-    regular_module_attributes = dir(nn.Module())
-    for name in dir(model):
-        attr = getattr(model, name)
-        if name.startswith("_") or name in regular_module_attributes:
-            continue
-        setattr(traced, name, deepcopy(attr))
+    # from copy import deepcopy
+    # regular_module_attributes = dir(nn.Module())
+    # for name in dir(model):
+    #     attr = getattr(model, name)
+    #     if name.startswith("_") or name in regular_module_attributes:
+    #         continue
+    #     setattr(traced, name, deepcopy(attr))

     return traced
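With the new overrides in place, a Swin model should be traceable through transformers.utils.fx. A hypothetical usage sketch; the input_names argument and the randomly initialized SwinModel are assumptions for illustration, not taken from this diff:

import torch

from transformers import SwinConfig, SwinModel
from transformers.utils.fx import symbolic_trace

model = SwinModel(SwinConfig())  # random weights are enough to exercise tracing
traced = symbolic_trace(model, input_names=["pixel_values"])

pixel_values = torch.randn(1, 3, model.config.image_size, model.config.image_size)
outputs = traced(pixel_values=pixel_values)  # replays the recorded fx graph
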
1 change: 1 addition & 0 deletions tests/swin/test_modeling_swin.py
@@ -175,6 +175,7 @@ class SwinModelTest(ModelTesterMixin, unittest.TestCase):
         if is_torch_available()
         else ()
     )
+    fx_compatible = True

     test_pruning = False
     test_resize_embeddings = False
