Skip to content

Commit

Permalink
Move out the pad operation from PatchMerging in swin transformer to m…
Browse files Browse the repository at this point in the history
…ake it fx compatible (pytorch#6252)
  • Loading branch information
YosuaMichael authored and NicolasHug committed Jul 21, 2022
1 parent da3794e commit 4bdd088
Showing 1 changed file with 10 additions and 2 deletions.
12 changes: 10 additions & 2 deletions torchvision/models/swin_transformer.py
Expand Up @@ -25,6 +25,15 @@
]


def _patch_merging_pad(x):
H, W, _ = x.shape[-3:]
x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
return x


torch.fx.wrap("_patch_merging_pad")


class PatchMerging(nn.Module):
"""Patch Merging Layer.
Args:
Expand All @@ -46,8 +55,7 @@ def forward(self, x: Tensor):
Returns:
Tensor with layout of [..., H/2, W/2, 2*C]
"""
H, W, _ = x.shape[-3:]
x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
x = _patch_merging_pad(x)

x0 = x[..., 0::2, 0::2, :] # ... H/2 W/2 C
x1 = x[..., 1::2, 0::2, :] # ... H/2 W/2 C
Expand Down

0 comments on commit 4bdd088

Please sign in to comment.