
Commit

Rename base_model_prefix to vit
Niels Rogge authored and committed on Apr 20, 2022
Parent: d2646cd · Commit: 960d324
Showing 2 changed files with 13 additions and 13 deletions.
src/transformers/models/yolos/convert_yolos_to_pytorch.py (20 changes: 10 additions & 10 deletions)
@@ -88,7 +88,7 @@ def read_in_q_k_v(state_dict, config, base_model=False):
 
 def rename_key(name):
     if "backbone" in name:
-        name = name.replace("backbone", "yolos")
+        name = name.replace("backbone", "vit")
     if "cls_token" in name:
         name = name.replace("cls_token", "embeddings.cls_token")
     if "det_token" in name:
@@ -117,8 +117,8 @@ def rename_key(name):
         name = name.replace("class_embed", "class_labels_classifier")
     if "bbox_embed" in name:
         name = name.replace("bbox_embed", "bbox_predictor")
-    if "yolos.norm" in name:
-        name = name.replace("yolos.norm", "yolos.layernorm")
+    if "vit.norm" in name:
+        name = name.replace("vit.norm", "vit.layernorm")
 
     return name

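For illustration, here is a minimal, self-contained sketch of the two renaming rules visible in the first hunk; the real rename_key applies many more substitutions (elided in this diff), and the sample checkpoint key is hypothetical.

# Sketch only: reproduces just the "backbone" -> "vit" and "cls_token" rules shown above.
def rename_key_sketch(name: str) -> str:
    if "backbone" in name:
        name = name.replace("backbone", "vit")
    if "cls_token" in name:
        name = name.replace("cls_token", "embeddings.cls_token")
    return name

print(rename_key_sketch("backbone.cls_token"))  # -> "vit.embeddings.cls_token"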
@@ -130,17 +130,17 @@ def convert_state_dict(orig_state_dict, model):
         if "qkv" in key:
             key_split = key.split(".")
             layer_num = int(key_split[2])
-            dim = model.yolos.encoder.layer[layer_num].attention.attention.all_head_size
+            dim = model.vit.encoder.layer[layer_num].attention.attention.all_head_size
             if "weight" in key:
-                orig_state_dict[f"yolos.encoder.layer.{layer_num}.attention.attention.query.weight"] = val[:dim, :]
-                orig_state_dict[f"yolos.encoder.layer.{layer_num}.attention.attention.key.weight"] = val[
+                orig_state_dict[f"vit.encoder.layer.{layer_num}.attention.attention.query.weight"] = val[:dim, :]
+                orig_state_dict[f"vit.encoder.layer.{layer_num}.attention.attention.key.weight"] = val[
                     dim : dim * 2, :
                 ]
-                orig_state_dict[f"yolos.encoder.layer.{layer_num}.attention.attention.value.weight"] = val[-dim:, :]
+                orig_state_dict[f"vit.encoder.layer.{layer_num}.attention.attention.value.weight"] = val[-dim:, :]
             else:
-                orig_state_dict[f"yolos.encoder.layer.{layer_num}.attention.attention.query.bias"] = val[:dim]
-                orig_state_dict[f"yolos.encoder.layer.{layer_num}.attention.attention.key.bias"] = val[dim : dim * 2]
-                orig_state_dict[f"yolos.encoder.layer.{layer_num}.attention.attention.value.bias"] = val[-dim:]
+                orig_state_dict[f"vit.encoder.layer.{layer_num}.attention.attention.query.bias"] = val[:dim]
+                orig_state_dict[f"vit.encoder.layer.{layer_num}.attention.attention.key.bias"] = val[dim : dim * 2]
+                orig_state_dict[f"vit.encoder.layer.{layer_num}.attention.attention.value.bias"] = val[-dim:]
         else:
             orig_state_dict[rename_key(key)] = val
 
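The "qkv" branch above splits the fused query/key/value projection that the original checkpoint stores under a single key. A toy sketch of that row-wise slicing, using made-up shapes rather than actual YOLOS weights:

import torch

dim = 8  # stand-in for all_head_size; real YOLOS hidden sizes are larger (e.g. 768)
fused_weight = torch.randn(3 * dim, dim)  # plays the role of an original "...qkv.weight"
fused_bias = torch.randn(3 * dim)

# Rows [0:dim] hold the query, [dim:2*dim] the key, and the last dim rows the value.
query_weight = fused_weight[:dim, :]
key_weight = fused_weight[dim : dim * 2, :]
value_weight = fused_weight[-dim:, :]
query_bias, key_bias, value_bias = fused_bias[:dim], fused_bias[dim : dim * 2], fused_bias[-dim:]

assert query_weight.shape == key_weight.shape == value_weight.shape == (dim, dim)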
src/transformers/models/yolos/modeling_yolos.py (6 changes: 3 additions & 3 deletions)
@@ -713,8 +713,8 @@ class YolosForObjectDetection(YolosPreTrainedModel):
     def __init__(self, config: YolosConfig):
         super().__init__(config)
 
-        # YOLOS encoder model
-        self.yolos = YolosModel(config, add_pooling_layer=False)
+        # YOLOS (ViT) encoder model
+        self.vit = YolosModel(config, add_pooling_layer=False)
 
         # Object detection heads
         # We add one for the "no object" class
@@ -778,7 +778,7 @@ def forward(
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
         # First, sent images through YOLOS base model to obtain hidden states
-        outputs = self.yolos(
+        outputs = self.vit(
             pixel_values,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,

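In transformers, base_model_prefix is the attribute name under which a task-specific model nests its backbone, and from_pretrained relies on it to map checkpoint keys between the base model and models with heads. Renaming the attribute to self.vit keeps it consistent with the new prefix and with the "vit."-prefixed keys written by the conversion script. A minimal sketch, using a randomly initialized model from the default config rather than a released checkpoint, and assuming a transformers version that ships YOLOS:

from transformers import YolosConfig, YolosForObjectDetection

model = YolosForObjectDetection(YolosConfig())

# Encoder parameters now live under the "vit." prefix (e.g. "vit.embeddings.cls_token"),
# matching base_model_prefix and the keys produced by convert_yolos_to_pytorch.py.
vit_keys = [name for name in model.state_dict() if name.startswith("vit.")]
print(len(vit_keys), vit_keys[:2])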