Skip to content

Commit

Permalink
Apply suggestions from code review
Browse files Browse the repository at this point in the history
  • Loading branch information
NielsRogge committed Jan 27, 2022
1 parent c369419 commit 62b14c0
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 18 deletions.
6 changes: 6 additions & 0 deletions docs/source/model_doc/convnext.mdx
Expand Up @@ -15,6 +15,7 @@ specific language governing permissions and limitations under the License.
## Overview

The ConvNeXT model was proposed in [A ConvNet for the 2020s](https://arxiv.org/abs/2201.03545) by Zhuang Liu, Hanzi Mao, Chao-Yuan Wu, Christoph Feichtenhofer, Trevor Darrell, Saining Xie.
ConvNeXT is a pure convolutional model (ConvNet), inspired by the design of Vision Transformers, that claims to outperform them.

The abstract from the paper is the following:

Expand All @@ -33,6 +34,11 @@ Tips:

This model was contributed by [nielsr](https://huggingface.co/nielsr). The original code can be found [here](https://github.com/facebookresearch/ConvNeXt).

## ConvNeXT specific outputs

[[autodoc]] models.convnext.modeling_convnext.ConvNextModelOutput


## ConvNextConfig

[[autodoc]] ConvNextConfig
Expand Down
8 changes: 4 additions & 4 deletions src/transformers/models/convnext/configuration_convnext.py
Expand Up @@ -74,8 +74,8 @@ def __init__(
self,
num_channels=3,
num_stages=4,
dims=[96, 192, 384, 768],
depths=[3, 3, 9, 3],
dims=None,
depths=None,
hidden_act="gelu",
initializer_range=0.02,
layer_norm_eps=1e-12,
Expand All @@ -88,8 +88,8 @@ def __init__(

self.num_channels = num_channels
self.num_stages = num_stages
self.dims = dims
self.depths = depths
self.dims = [96, 192, 384, 768] if dims is None else dims
self.depths = [3, 3, 9, 3] if depths is None else depths
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.layer_norm_eps = layer_norm_eps
Expand Down
46 changes: 32 additions & 14 deletions src/transformers/models/convnext/modeling_convnext.py
Expand Up @@ -29,7 +29,6 @@
add_start_docstrings_to_model_forward,
replace_return_docstrings,
)
from ...modeling_outputs import SequenceClassifierOutput
from ...modeling_utils import PreTrainedModel
from ...utils import logging
from .configuration_convnext import ConvNextConfig
Expand All @@ -49,23 +48,43 @@
@dataclass
class ConvNextModelOutput(ModelOutput):
"""
Args:
Class for [`ConvNextModel`]'s outputs, with potential hidden states.
Args:
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Output feature map of the last stage of the model.
Last hidden states (final feature map) of the last stage of the model.
pooler_output (`torch.FloatTensor` of shape `(batch_size, config.dim[-1])`):
Global average pooling of the last feature map followed by a layernorm.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or
when `config.output_hidden_states=True`):
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of each stage) of shape `(batch_size, num_channels,
height, width)`. Hidden-states of the model at the output of each layer.
height, width)`. Hidden-states (also called feature maps) of the model at the output of each stage.
"""

last_hidden_state: torch.FloatTensor = None
pooler_output: Optional[torch.FloatTensor] = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None


@dataclass
class ConvNextClassifierOutput(ModelOutput):
"""
Base class for outputs of sentence classification models.
Args:
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
Classification (or regression if config.num_labels==1) loss.
logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
Classification (or regression if config.num_labels==1) scores (before SoftMax).
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of each stage) of shape `(batch_size, num_channels,
height, width)`. Hidden-states (also called feature maps) of the model at the output of each stage.
"""

loss: Optional[torch.FloatTensor] = None
logits: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None


# Stochastic depth implementation
# Taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py
def drop_path(x, drop_prob: float = 0.0, training: bool = False):
Expand Down Expand Up @@ -115,17 +134,17 @@ def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):

def forward(self, x):
if self.data_format == "channels_last":
return torch.nn.functional.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
x = torch.nn.functional.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
elif self.data_format == "channels_first":
u = x.mean(1, keepdim=True)
s = (x - u).pow(2).mean(1, keepdim=True)
x = (x - u) / torch.sqrt(s + self.eps)
x = self.weight[:, None, None] * x + self.bias[:, None, None]
return x
return x


class ConvNextBlock(nn.Module):
"""This corresponds to the ConvNextBlock class in the original implementation.
class ConvNextLayer(nn.Module):
"""This corresponds to the Block class in the original implementation.
There are two equivalent implementations: [DwConv, LayerNorm (channels_first), Conv, GELU,1x1 Conv]; all in (N, C,
H, W) (2) [DwConv, Permute to (N, H, W, C), LayerNorm (channels_last), Linear, GELU, Linear]; Permute back
Expand Down Expand Up @@ -251,7 +270,7 @@ def __init__(self, config):
for i in range(config.num_stages):
stage = nn.Sequential(
*[
ConvNextBlock(config, dim=config.dims[i], drop_path=dp_rates[cur + j])
ConvNextLayer(config, dim=config.dims[i], drop_path=dp_rates[cur + j])
for j in range(config.depths[i])
]
)
Expand Down Expand Up @@ -338,7 +357,7 @@ def __init__(self, config):
self.post_init()

@add_start_docstrings_to_model_forward(CONVNEXT_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
@replace_return_docstrings(output_type=ConvNextClassifierOutput, config_class=_CONFIG_FOR_DOC)
def forward(self, pixel_values=None, labels=None, output_hidden_states=None, return_dict=None):
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Expand Down Expand Up @@ -402,9 +421,8 @@ def forward(self, pixel_values=None, labels=None, output_hidden_states=None, ret
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output

return SequenceClassifierOutput(
return ConvNextClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=None,
)

0 comments on commit 62b14c0

Please sign in to comment.