From 4bdd08846c58ef1e9ce9db4024c1a44dcaf24e5d Mon Sep 17 00:00:00 2001
From: YosuaMichael
Date: Fri, 8 Jul 2022 18:38:55 +0100
Subject: [PATCH] Move out the pad operation from PatchMerging in swin transformer to make it fx compatible (#6252)

---
 torchvision/models/swin_transformer.py | 12 ++++++++++--
 1 file changed, 10 insertions(+), 2 deletions(-)

diff --git a/torchvision/models/swin_transformer.py b/torchvision/models/swin_transformer.py
index 2f2cfd44445..f61dfb6154e 100644
--- a/torchvision/models/swin_transformer.py
+++ b/torchvision/models/swin_transformer.py
@@ -25,6 +25,15 @@
 ]
 
 
+def _patch_merging_pad(x):
+    H, W, _ = x.shape[-3:]
+    x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
+    return x
+
+
+torch.fx.wrap("_patch_merging_pad")
+
+
 class PatchMerging(nn.Module):
     """Patch Merging Layer.
     Args:
@@ -46,8 +55,7 @@ def forward(self, x: Tensor):
         Returns:
             Tensor with layout of [..., H/2, W/2, 2*C]
         """
-        H, W, _ = x.shape[-3:]
-        x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
+        x = _patch_merging_pad(x)
         x0 = x[..., 0::2, 0::2, :]  # ... H/2 W/2 C
         x1 = x[..., 1::2, 0::2, :]  # ... H/2 W/2 C
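
The following is a minimal sketch, not part of the patch, of why the padding is hoisted into a module-level helper and registered with torch.fx.wrap: inside forward, symbolic tracing cannot unpack x.shape[-3:] from a Proxy, so wrapping the helper makes the tracer record a single call node instead of tracing through the shape-dependent padding. TinyPatchMerging is a hypothetical stand-in for torchvision's PatchMerging, with the norm and reduction layers omitted.

# Minimal sketch, assuming a stand-alone script.
import torch
import torch.fx
import torch.nn.functional as F


def _patch_merging_pad(x):
    # Pad H and W up to even sizes so the 2x2 patch merging below works.
    H, W, _ = x.shape[-3:]
    x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
    return x


# Register the helper as a leaf for symbolic tracing: the tracer emits one
# call_function node for it instead of tracing into the shape-dependent body.
torch.fx.wrap("_patch_merging_pad")


class TinyPatchMerging(torch.nn.Module):
    # Hypothetical stand-in for torchvision's PatchMerging (norm/reduction
    # layers omitted for brevity).
    def forward(self, x):
        x = _patch_merging_pad(x)
        x0 = x[..., 0::2, 0::2, :]  # ... H/2 W/2 C
        x1 = x[..., 1::2, 0::2, :]  # ... H/2 W/2 C
        return torch.cat([x0, x1], dim=-1)


gm = torch.fx.symbolic_trace(TinyPatchMerging())  # traces without error
print(gm.graph)                                   # graph contains a call_function node for _patch_merging_pad
print(gm(torch.randn(1, 7, 7, 8)).shape)          # odd H/W padded to even -> torch.Size([1, 4, 4, 16])

Without the torch.fx.wrap call, torch.fx.symbolic_trace raises a TraceError on the tuple unpacking, since iterating a Proxy is not supported.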