Skip to content

Commit

Permalink
Add changes for conversion script for pretrained window size case
Browse files Browse the repository at this point in the history
  • Loading branch information
nandwalritik committed Jun 16, 2022
1 parent e51c3f1 commit 2547f45
Showing 1 changed file with 12 additions and 4 deletions.
16 changes: 12 additions & 4 deletions src/transformers/models/swinv2/convert_swinv2_timm_to_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,14 @@ def get_swinv2_config(swinv2_name):
name_split = swinv2_name.split("_")

model_size = name_split[1]
img_size = int(name_split[4])
window_size = int(name_split[3][6:])

if "to" in name_split[4]:
img_size = int(name_split[4][-3:])
else:
img_size = int(name_split[4])
if "to" in name_split[3]:
window_size = int(name_split[3][-2:])
else:
window_size = int(name_split[3][6:])
if model_size == "tiny":
embed_dim = 96
depths = (2, 2, 6, 2)
Expand All @@ -34,8 +39,11 @@ def get_swinv2_config(swinv2_name):
embed_dim = 192
depths = (2, 2, 18, 2)
num_heads = (6, 12, 24, 48)

if "to" in swinv2_name:
config.pretrained_window_sizes = (12,12,12,6)

if "22k" in swinv2_name:
if ("22k" in swinv2_name) and ("to" not in swinv2_name):
num_classes = 21841
else:
num_classes = 1000
Expand Down

0 comments on commit 2547f45

Please sign in to comment.