Skip to content

Commit

Permalink
Fix gamma parameter
Browse files Browse the repository at this point in the history
  • Loading branch information
NielsRogge committed Feb 3, 2022
1 parent 0293c67 commit 6e71db9
Show file tree
Hide file tree
Showing 5 changed files with 10 additions and 13 deletions.
2 changes: 1 addition & 1 deletion docs/source/model_doc/convnext.mdx
Expand Up @@ -30,7 +30,7 @@ and outperforming Swin Transformers on COCO detection and ADE20K segmentation, w

Tips:

- One can use the [`AutoFeatureExtractor`] API to prepare images for the model.
- See the code examples below each model regarding usage.

This model was contributed by [nielsr](https://huggingface.co/nielsr). The original code can be found [here](https://github.com/facebookresearch/ConvNeXt).

Expand Down
4 changes: 2 additions & 2 deletions src/transformers/models/convnext/configuration_convnext.py
@@ -1,5 +1,5 @@
# coding=utf-8
# Copyright 2022 Facebook AI and The HuggingFace Inc. team. All rights reserved.
# Copyright 2022 Meta Platforms, Inc. and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -40,7 +40,7 @@ class ConvNextConfig(PretrainedConfig):
num_channels (`int`, *optional*, defaults to 3):
The number of input channels.
patch_size (`int`, *optional*, defaults to 4):
Patch size to use in the stem layer.
Patch size to use in the patch embedding layer.
num_stages (`int`, *optional*, defaults to 4):
The number of stages in the model.
hidden_sizes (`List[int]`, *optional*, defaults to [96, 192, 384, 768]):
Expand Down
Expand Up @@ -102,7 +102,7 @@ def rename_key(name):
if "norm" in name:
name = name.replace("norm", "layernorm")
if "gamma" in name:
name = name.replace("gamma", "gamma_parameter")
name = name.replace("gamma", "layer_scale_parameter")
if "head" in name:
name = name.replace("head", "classifier")

Expand Down
10 changes: 5 additions & 5 deletions src/transformers/models/convnext/modeling_convnext.py
@@ -1,5 +1,5 @@
# coding=utf-8
# Copyright 2022 Facebook AI and The HuggingFace Inc. team. All rights reserved.
# Copyright 2022 Meta Platforms, Inc. and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -164,7 +164,7 @@ def forward(self, x):


class ConvNextEmbeddings(nn.Module):
"""This class is comparable to (and inspired by) the SwinPatchEmbeddings class
"""This class is comparable to (and inspired by) the SwinEmbeddings class
found in src/transformers/models/swin/modeling_swin.py.
"""

Expand Down Expand Up @@ -202,7 +202,7 @@ def __init__(self, config, dim, drop_path=0):
self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers
self.act = ACT2FN[config.hidden_act]
self.pwconv2 = nn.Linear(4 * dim, dim)
self.gamma_parameter = (
self.layer_scale_parameter = (
nn.Parameter(config.layer_scale_init_value * torch.ones((dim)), requires_grad=True)
if config.layer_scale_init_value > 0
else None
Expand All @@ -217,8 +217,8 @@ def forward(self, hidden_states):
x = self.pwconv1(x)
x = self.act(x)
x = self.pwconv2(x)
if self.gamma_parameter is not None:
x = self.gamma_parameter * x
if self.layer_scale_parameter is not None:
x = self.layer_scale_parameter * x
x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)

x = input + self.drop_path(x)
Expand Down
5 changes: 1 addition & 4 deletions tests/test_modeling_convnext.py
Expand Up @@ -335,9 +335,6 @@ def test_inference_image_classification_head(self):
expected_shape = torch.Size((1, 1000))
self.assertEqual(outputs.logits.shape, expected_shape)

print("Predicted class:", model.config.id2label[torch.argmax(outputs.logits, dim=-1).item()])
print("Logits:", outputs.logits[0, :3])

expected_slice = torch.tensor([-0.0750, 0.2478, 0.5982]).to(torch_device)
expected_slice = torch.tensor([-0.1210, -0.6605, 0.1918]).to(torch_device)

self.assertTrue(torch.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4))

0 comments on commit 6e71db9

Please sign in to comment.