Skip to content

Commit

Permalink
[fbsync] Move out the pad operation from PatchMerging in swin transfo…
Browse files Browse the repository at this point in the history
…rmer to make it fx compatible (#6252)

Reviewed By: jdsgomes

Differential Revision: D37993420

fbshipit-source-id: 6b9dd3e161e74a00ce479a5c42376463f38a844a
  • Loading branch information
NicolasHug authored and facebook-github-bot committed Jul 21, 2022
1 parent a5ab6d7 commit f113498
Showing 1 changed file with 10 additions and 2 deletions.
12 changes: 10 additions & 2 deletions torchvision/models/swin_transformer.py
Expand Up @@ -25,6 +25,15 @@
]


def _patch_merging_pad(x):
H, W, _ = x.shape[-3:]
x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
return x


torch.fx.wrap("_patch_merging_pad")


class PatchMerging(nn.Module):
"""Patch Merging Layer.
Args:
Expand All @@ -46,8 +55,7 @@ def forward(self, x: Tensor):
Returns:
Tensor with layout of [..., H/2, W/2, 2*C]
"""
H, W, _ = x.shape[-3:]
x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
x = _patch_merging_pad(x)

x0 = x[..., 0::2, 0::2, :] # ... H/2 W/2 C
x1 = x[..., 1::2, 0::2, :] # ... H/2 W/2 C
Expand Down

0 comments on commit f113498

Please sign in to comment.