diff --git a/docs/source/index.mdx b/docs/source/index.mdx
index a1c32e5dadfcd..1ad5f0a908848 100644
--- a/docs/source/index.mdx
+++ b/docs/source/index.mdx
@@ -179,7 +179,7 @@ Flax), PyTorch, and/or TensorFlow.
| Canine | ✅ | ❌ | ✅ | ❌ | ❌ |
| CLIP | ✅ | ✅ | ✅ | ✅ | ✅ |
| ConvBERT | ✅ | ✅ | ✅ | ✅ | ❌ |
-| ConvNext | ❌ | ❌ | ✅ | ❌ | ❌ |
+| ConvNext | ❌ | ❌ | ✅ | ✅ | ❌ |
| CTRL | ✅ | ❌ | ✅ | ✅ | ❌ |
| DeBERTa | ✅ | ✅ | ✅ | ✅ | ❌ |
| DeBERTa-v2 | ✅ | ❌ | ✅ | ✅ | ❌ |
diff --git a/docs/source/model_doc/convnext.mdx b/docs/source/model_doc/convnext.mdx
index e3a04d371e64c..4d46248565f94 100644
--- a/docs/source/model_doc/convnext.mdx
+++ b/docs/source/model_doc/convnext.mdx
@@ -37,7 +37,8 @@ alt="drawing" width="600"/>
ConvNeXT architecture. Taken from the original paper.
-This model was contributed by [nielsr](https://huggingface.co/nielsr). The original code can be found [here](https://github.com/facebookresearch/ConvNeXt).
+This model was contributed by [nielsr](https://huggingface.co/nielsr). The TensorFlow version of the model was contributed by [ariG23498](https://github.com/ariG23498),
+[gante](https://github.com/gante), and [sayakpaul](https://github.com/sayakpaul) (equal contribution). The original code can be found [here](https://github.com/facebookresearch/ConvNeXt).
## ConvNeXT specific outputs
@@ -63,4 +64,16 @@ This model was contributed by [nielsr](https://huggingface.co/nielsr). The origi
## ConvNextForImageClassification
[[autodoc]] ConvNextForImageClassification
- - forward
\ No newline at end of file
+ - forward
+
+
+## TFConvNextModel
+
+[[autodoc]] TFConvNextModel
+ - call
+
+
+## TFConvNextForImageClassification
+
+[[autodoc]] TFConvNextForImageClassification
+ - call
\ No newline at end of file
diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py
index d97e582c35809..3e858cebee74c 100755
--- a/src/transformers/__init__.py
+++ b/src/transformers/__init__.py
@@ -1743,6 +1743,13 @@
"TFConvBertPreTrainedModel",
]
)
+ _import_structure["models.convnext"].extend(
+ [
+ "TFConvNextForImageClassification",
+ "TFConvNextModel",
+ "TFConvNextPreTrainedModel",
+ ]
+ )
_import_structure["models.ctrl"].extend(
[
"TF_CTRL_PRETRAINED_MODEL_ARCHIVE_LIST",
@@ -3751,6 +3758,7 @@
TFConvBertModel,
TFConvBertPreTrainedModel,
)
+ from .models.convnext import TFConvNextForImageClassification, TFConvNextModel, TFConvNextPreTrainedModel
from .models.ctrl import (
TF_CTRL_PRETRAINED_MODEL_ARCHIVE_LIST,
TFCTRLForSequenceClassification,
diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py
index 4debfad93153e..2ab3e79381171 100644
--- a/src/transformers/modeling_tf_utils.py
+++ b/src/transformers/modeling_tf_utils.py
@@ -311,9 +311,10 @@ def booleans_processing(config, **kwargs):
final_booleans = {}
if tf.executing_eagerly():
- final_booleans["output_attentions"] = (
- kwargs["output_attentions"] if kwargs["output_attentions"] is not None else config.output_attentions
- )
+ # Pure conv models (such as ConvNext) do not have `output_attentions`
+ final_booleans["output_attentions"] = kwargs.get("output_attentions", None)
+ if final_booleans["output_attentions"] is None:
+ final_booleans["output_attentions"] = config.output_attentions
final_booleans["output_hidden_states"] = (
kwargs["output_hidden_states"]
if kwargs["output_hidden_states"] is not None
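Reviewer note: a minimal standalone sketch of the fallback behaviour this hunk introduces (the `DummyConfig` and `resolve_output_attentions` names are hypothetical, for illustration only):

```python
# Sketch of the new behaviour: `kwargs.get` tolerates models (pure conv ones
# like ConvNext) whose `call` never defines `output_attentions`, where the
# old `kwargs["output_attentions"]` lookup raised a KeyError.
class DummyConfig:
    output_attentions = False

def resolve_output_attentions(config, **kwargs):
    value = kwargs.get("output_attentions", None)  # None both when absent and when unset
    if value is None:
        value = config.output_attentions  # fall back to the config default
    return value

config = DummyConfig()
assert resolve_output_attentions(config) is False                           # key absent
assert resolve_output_attentions(config, output_attentions=None) is False   # key present but unset
assert resolve_output_attentions(config, output_attentions=True) is True    # explicit kwarg wins
```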
diff --git a/src/transformers/models/auto/modeling_tf_auto.py b/src/transformers/models/auto/modeling_tf_auto.py
index cd4158bc7dd46..1b95cfa01d545 100644
--- a/src/transformers/models/auto/modeling_tf_auto.py
+++ b/src/transformers/models/auto/modeling_tf_auto.py
@@ -36,6 +36,7 @@
("rembert", "TFRemBertModel"),
("roformer", "TFRoFormerModel"),
("convbert", "TFConvBertModel"),
+ ("convnext", "TFConvNextModel"),
("led", "TFLEDModel"),
("lxmert", "TFLxmertModel"),
("mt5", "TFMT5Model"),
@@ -155,6 +156,7 @@
[
# Model for Image-classsification
("vit", "TFViTForImageClassification"),
+ ("convnext", "TFConvNextForImageClassification"),
]
)
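With these two mapping entries, the TF auto classes should dispatch ConvNext checkpoints to the new classes. A quick smoke check, assuming the checkpoint is reachable (`from_pt=True` may be needed if the Hub checkpoint only hosts PyTorch weights):

```python
from transformers import TFAutoModel, TFAutoModelForImageClassification

# Dispatch happens on the config's model_type ("convnext"), so both auto
# classes should now resolve to the TF ConvNext implementations.
model = TFAutoModel.from_pretrained("facebook/convnext-tiny-224")
classifier = TFAutoModelForImageClassification.from_pretrained("facebook/convnext-tiny-224")
print(type(model).__name__)       # TFConvNextModel
print(type(classifier).__name__)  # TFConvNextForImageClassification
```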
diff --git a/src/transformers/models/convnext/__init__.py b/src/transformers/models/convnext/__init__.py
index cdc064d3c994a..a627c462e9ba4 100644
--- a/src/transformers/models/convnext/__init__.py
+++ b/src/transformers/models/convnext/__init__.py
@@ -18,7 +18,7 @@
from typing import TYPE_CHECKING
# rely on isort to merge the imports
-from ...file_utils import _LazyModule, is_torch_available, is_vision_available
+from ...file_utils import _LazyModule, is_tf_available, is_torch_available, is_vision_available
_import_structure = {
@@ -36,6 +36,12 @@
"ConvNextPreTrainedModel",
]
+if is_tf_available():
+ _import_structure["modeling_tf_convnext"] = [
+ "TFConvNextForImageClassification",
+ "TFConvNextModel",
+ "TFConvNextPreTrainedModel",
+ ]
if TYPE_CHECKING:
from .configuration_convnext import CONVNEXT_PRETRAINED_CONFIG_ARCHIVE_MAP, ConvNextConfig
@@ -51,6 +57,9 @@
ConvNextPreTrainedModel,
)
+ if is_tf_available():
+        from .modeling_tf_convnext import TFConvNextForImageClassification, TFConvNextModel, TFConvNextPreTrainedModel
+
else:
import sys
diff --git a/src/transformers/models/convnext/configuration_convnext.py b/src/transformers/models/convnext/configuration_convnext.py
index 8d99c657cc639..74067ad337bbf 100644
--- a/src/transformers/models/convnext/configuration_convnext.py
+++ b/src/transformers/models/convnext/configuration_convnext.py
@@ -85,6 +85,7 @@ def __init__(
is_encoder_decoder=False,
layer_scale_init_value=1e-6,
drop_path_rate=0.0,
+ image_size=224,
**kwargs
):
super().__init__(**kwargs)
@@ -99,3 +100,4 @@ def __init__(
self.layer_norm_eps = layer_norm_eps
self.layer_scale_init_value = layer_scale_init_value
self.drop_path_rate = drop_path_rate
+ self.image_size = image_size
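The new `image_size` attribute is what lets the TF port build concrete `dummy_inputs` (see `TFConvNextPreTrainedModel.dummy_inputs` below); a small illustration of how the shape is derived from the config alone:

```python
import tensorflow as tf
from transformers import ConvNextConfig

config = ConvNextConfig()  # image_size now defaults to 224
# The TF model builds its NCHW dummy batch purely from config attributes:
dummy = tf.random.uniform((3, config.num_channels, config.image_size, config.image_size))
print(dummy.shape)  # (3, 3, 224, 224)
```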
diff --git a/src/transformers/models/convnext/modeling_tf_convnext.py b/src/transformers/models/convnext/modeling_tf_convnext.py
new file mode 100644
index 0000000000000..fbb436059340f
--- /dev/null
+++ b/src/transformers/models/convnext/modeling_tf_convnext.py
@@ -0,0 +1,618 @@
+# coding=utf-8
+# Copyright 2022 Meta Platforms Inc. and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" TF 2.0 ConvNext model."""
+
+
+from typing import Dict, Optional, Tuple, Union
+
+import numpy as np
+import tensorflow as tf
+
+from ...activations_tf import get_tf_activation
+from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings
+from ...modeling_tf_outputs import TFBaseModelOutput, TFBaseModelOutputWithPooling, TFSequenceClassifierOutput
+from ...modeling_tf_utils import (
+ TFModelInputType,
+ TFPreTrainedModel,
+ TFSequenceClassificationLoss,
+ get_initializer,
+ input_processing,
+ keras_serializable,
+)
+from ...utils import logging
+from .configuration_convnext import ConvNextConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+_CONFIG_FOR_DOC = "ConvNextConfig"
+_CHECKPOINT_FOR_DOC = "facebook/convnext-tiny-224"
+
+
+class TFConvNextDropPath(tf.keras.layers.Layer):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+ References:
+        (1) https://github.com/rwightman/pytorch-image-models
+ """
+
+ def __init__(self, drop_path, **kwargs):
+ super().__init__(**kwargs)
+ self.drop_path = drop_path
+
+ def call(self, x, training=None):
+ if training:
+ keep_prob = 1 - self.drop_path
+ shape = (tf.shape(x)[0],) + (1,) * (len(tf.shape(x)) - 1)
+ random_tensor = keep_prob + tf.random.uniform(shape, 0, 1)
+ random_tensor = tf.floor(random_tensor)
+ return (x / keep_prob) * random_tensor
+ return x
+
+
+class TFConvNextEmbeddings(tf.keras.layers.Layer):
+ """This class is comparable to (and inspired by) the SwinEmbeddings class
+ found in src/transformers/models/swin/modeling_swin.py.
+ """
+
+ def __init__(self, config, **kwargs):
+ super().__init__(**kwargs)
+ self.patch_embeddings = tf.keras.layers.Conv2D(
+ filters=config.hidden_sizes[0],
+ kernel_size=config.patch_size,
+ strides=config.patch_size,
+ name="patch_embeddings",
+ kernel_initializer=get_initializer(config.initializer_range),
+ bias_initializer="zeros",
+ )
+ self.layernorm = tf.keras.layers.LayerNormalization(epsilon=1e-6, name="layernorm")
+
+ def call(self, pixel_values):
+ if isinstance(pixel_values, dict):
+ pixel_values = pixel_values["pixel_values"]
+
+ # When running on CPU, `tf.keras.layers.Conv2D` doesn't support `NCHW` format.
+ # So change the input format from `NCHW` to `NHWC`.
+ # shape = (batch_size, in_height, in_width, in_channels=num_channels)
+ pixel_values = tf.transpose(pixel_values, perm=(0, 2, 3, 1))
+
+ embeddings = self.patch_embeddings(pixel_values)
+ embeddings = self.layernorm(embeddings)
+ return embeddings
+
+
+class TFConvNextLayer(tf.keras.layers.Layer):
+ """This corresponds to the `Block` class in the original implementation.
+
+    There are two equivalent implementations: (1) [DwConv, LayerNorm (channels_first), Conv, GELU, 1x1 Conv], all in (N, C,
+    H, W); (2) [DwConv, Permute to (N, H, W, C), LayerNorm (channels_last), Linear, GELU, Linear], then permute back.
+
+ The authors used (2) as they find it slightly faster in PyTorch. Since we already permuted the inputs to follow
+ NHWC ordering, we can just apply the operations straight-away without the permutation.
+
+ Args:
+ config ([`ConvNextConfig`]): Model configuration class.
+ dim (`int`): Number of input channels.
+ drop_path (`float`): Stochastic depth rate. Default: 0.0.
+ """
+
+ def __init__(self, config, dim, drop_path=0.0, **kwargs):
+ super().__init__(**kwargs)
+ self.dim = dim
+ self.config = config
+ self.dwconv = tf.keras.layers.Conv2D(
+ filters=dim,
+ kernel_size=7,
+ padding="same",
+ groups=dim,
+ kernel_initializer=get_initializer(config.initializer_range),
+ bias_initializer="zeros",
+ name="dwconv",
+ ) # depthwise conv
+ self.layernorm = tf.keras.layers.LayerNormalization(
+ epsilon=1e-6,
+ name="layernorm",
+ )
+ self.pwconv1 = tf.keras.layers.Dense(
+ units=4 * dim,
+ kernel_initializer=get_initializer(config.initializer_range),
+ bias_initializer="zeros",
+ name="pwconv1",
+ ) # pointwise/1x1 convs, implemented with linear layers
+ self.act = get_tf_activation(config.hidden_act)
+ self.pwconv2 = tf.keras.layers.Dense(
+ units=dim,
+ kernel_initializer=get_initializer(config.initializer_range),
+ bias_initializer="zeros",
+ name="pwconv2",
+ )
+ # Using `layers.Activation` instead of `tf.identity` to better control `training`
+ # behaviour.
+ self.drop_path = (
+ TFConvNextDropPath(drop_path, name="drop_path")
+ if drop_path > 0.0
+ else tf.keras.layers.Activation("linear", name="drop_path")
+ )
+
+ def build(self, input_shape: tf.TensorShape):
+ # PT's `nn.Parameters` must be mapped to a TF layer weight to inherit the same name hierarchy (and vice-versa)
+ self.layer_scale_parameter = (
+ self.add_weight(
+ shape=(self.dim,),
+ initializer=tf.keras.initializers.Constant(value=self.config.layer_scale_init_value),
+ trainable=True,
+ name="layer_scale_parameter",
+ )
+ if self.config.layer_scale_init_value > 0
+ else None
+ )
+ super().build(input_shape)
+
+ def call(self, hidden_states, training=False):
+ input = hidden_states
+ x = self.dwconv(hidden_states)
+ x = self.layernorm(x)
+ x = self.pwconv1(x)
+ x = self.act(x)
+ x = self.pwconv2(x)
+
+ if self.layer_scale_parameter is not None:
+ x = self.layer_scale_parameter * x
+
+ x = input + self.drop_path(x, training=training)
+ return x
+
+
+class TFConvNextStage(tf.keras.layers.Layer):
+ """ConvNext stage, consisting of an optional downsampling layer + multiple residual blocks.
+
+ Args:
+ config ([`ConvNextConfig`]): Model configuration class.
+ in_channels (`int`): Number of input channels.
+ out_channels (`int`): Number of output channels.
+ depth (`int`): Number of residual blocks.
+        drop_path_rates (`List[float]`): Stochastic depth rates for each layer.
+ """
+
+ def __init__(
+ self, config, in_channels, out_channels, kernel_size=2, stride=2, depth=2, drop_path_rates=None, **kwargs
+ ):
+ super().__init__(**kwargs)
+ if in_channels != out_channels or stride > 1:
+ self.downsampling_layer = [
+ tf.keras.layers.LayerNormalization(
+ epsilon=1e-6,
+ name="downsampling_layer.0",
+ ),
+ # Inputs to this layer will follow NHWC format since we
+ # transposed the inputs from NCHW to NHWC in the `TFConvNextEmbeddings`
+ # layer. All the outputs throughout the model will be in NHWC
+ # from this point on until the output where we again change to
+ # NCHW.
+ tf.keras.layers.Conv2D(
+ filters=out_channels,
+ kernel_size=kernel_size,
+ strides=stride,
+ kernel_initializer=get_initializer(config.initializer_range),
+ bias_initializer="zeros",
+ name="downsampling_layer.1",
+ ),
+ ]
+ else:
+ self.downsampling_layer = [tf.identity]
+
+ drop_path_rates = drop_path_rates or [0.0] * depth
+ self.layers = [
+ TFConvNextLayer(
+ config,
+ dim=out_channels,
+ drop_path=drop_path_rates[j],
+ name=f"layers.{j}",
+ )
+ for j in range(depth)
+ ]
+
+ def call(self, hidden_states):
+ for layer in self.downsampling_layer:
+ hidden_states = layer(hidden_states)
+ for layer in self.layers:
+ hidden_states = layer(hidden_states)
+ return hidden_states
+
+
+class TFConvNextEncoder(tf.keras.layers.Layer):
+ def __init__(self, config, **kwargs):
+ super().__init__(**kwargs)
+ self.stages = []
+ drop_path_rates = [x for x in tf.linspace(0.0, config.drop_path_rate, sum(config.depths))]
+ cur = 0
+ prev_chs = config.hidden_sizes[0]
+ for i in range(config.num_stages):
+ out_chs = config.hidden_sizes[i]
+ stage = TFConvNextStage(
+ config,
+ in_channels=prev_chs,
+ out_channels=out_chs,
+ stride=2 if i > 0 else 1,
+ depth=config.depths[i],
+                drop_path_rates=drop_path_rates[cur : cur + config.depths[i]],
+ name=f"stages.{i}",
+ )
+ self.stages.append(stage)
+ cur += config.depths[i]
+ prev_chs = out_chs
+
+ def call(self, hidden_states, output_hidden_states=False, return_dict=True):
+ all_hidden_states = () if output_hidden_states else None
+
+ for i, layer_module in enumerate(self.stages):
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ hidden_states = layer_module(hidden_states)
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, all_hidden_states] if v is not None)
+
+ return TFBaseModelOutput(last_hidden_state=hidden_states, hidden_states=all_hidden_states)
+
+
+@keras_serializable
+class TFConvNextMainLayer(tf.keras.layers.Layer):
+ config_class = ConvNextConfig
+
+ def __init__(self, config: ConvNextConfig, add_pooling_layer: bool = True, **kwargs):
+ super().__init__(**kwargs)
+
+ self.config = config
+ self.embeddings = TFConvNextEmbeddings(config, name="embeddings")
+ self.encoder = TFConvNextEncoder(config, name="encoder")
+ self.layernorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm")
+ # We are setting the `data_format` like so because from here on we will revert to the
+ # NCHW output format
+ self.pooler = tf.keras.layers.GlobalAvgPool2D(data_format="channels_first") if add_pooling_layer else None
+
+ def call(
+ self,
+ pixel_values: Optional[TFModelInputType] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ training: bool = False,
+ **kwargs,
+ ) -> Union[TFBaseModelOutputWithPooling, Tuple[tf.Tensor]]:
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ inputs = input_processing(
+ func=self.call,
+ config=self.config,
+ input_ids=pixel_values,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ training=training,
+ kwargs_call=kwargs,
+ )
+
+ if "input_ids" in inputs:
+ inputs["pixel_values"] = inputs.pop("input_ids")
+
+ if inputs["pixel_values"] is None:
+ raise ValueError("You have to specify pixel_values")
+
+ embedding_output = self.embeddings(inputs["pixel_values"], training=inputs["training"])
+
+ encoder_outputs = self.encoder(
+ embedding_output,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ training=inputs["training"],
+ )
+
+ last_hidden_state = encoder_outputs[0]
+        # Change to NCHW output format to have uniformity in the modules
+ last_hidden_state = tf.transpose(last_hidden_state, perm=(0, 3, 1, 2))
+ pooled_output = self.layernorm(self.pooler(last_hidden_state))
+
+ # Change the other hidden state outputs to NCHW as well
+ if output_hidden_states:
+ hidden_states = tuple([tf.transpose(h, perm=(0, 3, 1, 2)) for h in encoder_outputs[1]])
+
+        if not return_dict:
+            return (last_hidden_state, pooled_output) + ((hidden_states,) if output_hidden_states else ())
+
+ return TFBaseModelOutputWithPooling(
+ last_hidden_state=last_hidden_state,
+ pooler_output=pooled_output,
+ hidden_states=hidden_states if output_hidden_states else encoder_outputs.hidden_states,
+ )
+
+
+class TFConvNextPreTrainedModel(TFPreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = ConvNextConfig
+ base_model_prefix = "convnext"
+ main_input_name = "pixel_values"
+
+ @property
+ def dummy_inputs(self) -> Dict[str, tf.Tensor]:
+ """
+ Dummy inputs to build the network.
+
+ Returns:
+ `Dict[str, tf.Tensor]`: The dummy inputs.
+ """
+ VISION_DUMMY_INPUTS = tf.random.uniform(
+ shape=(
+ 3,
+ self.config.num_channels,
+ self.config.image_size,
+ self.config.image_size,
+ ),
+ dtype=tf.float32,
+ )
+ return {"pixel_values": tf.constant(VISION_DUMMY_INPUTS)}
+
+ @tf.function(
+ input_signature=[
+ {
+ "pixel_values": tf.TensorSpec((None, None, None, None), tf.float32, name="pixel_values"),
+ }
+ ]
+ )
+ def serving(self, inputs):
+ """
+ Method used for serving the model.
+
+ Args:
+ inputs (`Dict[str, tf.Tensor]`):
+ The input of the saved model as a dictionary of tensors.
+ """
+ return self.call(inputs)
+
+
+CONVNEXT_START_DOCSTRING = r"""
+ This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
+    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
+ etc.)
+
+ This model is also a [tf.keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it
+ as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and
+ behavior.
+
+
+
+    TF 2.0 models accept two formats as inputs:
+
+ - having all inputs as keyword arguments (like PyTorch models), or
+ - having all inputs as a list, tuple or dict in the first positional arguments.
+
+    This second option is useful when using the [`tf.keras.Model.fit`] method, which currently requires having all the
+    tensors in the first argument of the model call function: `model(inputs)`.
+
+
+
+ Parameters:
+ config ([`ConvNextConfig`]): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the
+ configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+CONVNEXT_INPUTS_DOCSTRING = r"""
+ Args:
+        pixel_values (`np.ndarray`, `tf.Tensor`, `List[tf.Tensor]`, `Dict[str, tf.Tensor]` or `Dict[str, np.ndarray]`, with each example of shape `(batch_size, num_channels, height, width)`):
+ Pixel values. Pixel values can be obtained using [`ConvNextFeatureExtractor`]. See
+ [`ConvNextFeatureExtractor.__call__`] for details.
+
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail. This argument can be used only in eager mode, in graph mode the value in the config will be
+ used instead.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple. This argument can be used
+ in eager mode, in graph mode the value will always be set to True.
+"""
+
+
+@add_start_docstrings(
+ "The bare ConvNext model outputting raw features without any specific head on top.",
+ CONVNEXT_START_DOCSTRING,
+)
+class TFConvNextModel(TFConvNextPreTrainedModel):
+ def __init__(self, config, *inputs, add_pooling_layer=True, **kwargs):
+ super().__init__(config, *inputs, **kwargs)
+ self.convnext = TFConvNextMainLayer(config, add_pooling_layer=add_pooling_layer, name="convnext")
+
+ @add_start_docstrings_to_model_forward(CONVNEXT_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=TFBaseModelOutputWithPooling, config_class=_CONFIG_FOR_DOC)
+ def call(
+ self,
+ pixel_values: Optional[TFModelInputType] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ training: bool = False,
+ **kwargs,
+ ) -> Union[TFBaseModelOutputWithPooling, Tuple[tf.Tensor]]:
+ r"""
+ Returns:
+
+ Examples:
+
+ ```python
+ >>> from transformers import ConvNextFeatureExtractor, TFConvNextModel
+ >>> from PIL import Image
+ >>> import requests
+
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> feature_extractor = ConvNextFeatureExtractor.from_pretrained("facebook/convnext-tiny-224")
+ >>> model = TFConvNextModel.from_pretrained("facebook/convnext-tiny-224")
+
+ >>> inputs = feature_extractor(images=image, return_tensors="tf")
+ >>> outputs = model(**inputs)
+ >>> last_hidden_states = outputs.last_hidden_state
+ ```"""
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ inputs = input_processing(
+ func=self.call,
+ config=self.config,
+ input_ids=pixel_values,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ training=training,
+ kwargs_call=kwargs,
+ )
+
+ if "input_ids" in inputs:
+ inputs["pixel_values"] = inputs.pop("input_ids")
+
+ if inputs["pixel_values"] is None:
+ raise ValueError("You have to specify pixel_values")
+
+ outputs = self.convnext(
+ pixel_values=inputs["pixel_values"],
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ training=inputs["training"],
+ )
+
+ if not return_dict:
+ return (outputs[0],) + outputs[1:]
+
+ return TFBaseModelOutputWithPooling(
+ last_hidden_state=outputs.last_hidden_state,
+ pooler_output=outputs.pooler_output,
+ hidden_states=outputs.hidden_states,
+ )
+
+
+@add_start_docstrings(
+ """
+ ConvNext Model with an image classification head on top (a linear layer on top of the pooled features), e.g. for
+ ImageNet.
+ """,
+ CONVNEXT_START_DOCSTRING,
+)
+class TFConvNextForImageClassification(TFConvNextPreTrainedModel, TFSequenceClassificationLoss):
+ def __init__(self, config: ConvNextConfig, *inputs, **kwargs):
+ super().__init__(config, *inputs, **kwargs)
+
+ self.num_labels = config.num_labels
+ self.convnext = TFConvNextMainLayer(config, name="convnext")
+
+ # Classifier head
+ self.classifier = tf.keras.layers.Dense(
+ units=config.num_labels,
+ kernel_initializer=get_initializer(config.initializer_range),
+ bias_initializer="zeros",
+ name="classifier",
+ )
+
+ @add_start_docstrings_to_model_forward(CONVNEXT_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=TFSequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
+ def call(
+ self,
+ pixel_values: Optional[TFModelInputType] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ labels: Optional[Union[np.ndarray, tf.Tensor]] = None,
+ training: Optional[bool] = False,
+ **kwargs,
+ ) -> Union[TFSequenceClassifierOutput, Tuple[tf.Tensor]]:
+ r"""
+ labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size,)`, *optional*):
+ Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+
+ Returns:
+
+ Examples:
+
+ ```python
+ >>> from transformers import ConvNextFeatureExtractor, TFConvNextForImageClassification
+ >>> import tensorflow as tf
+ >>> from PIL import Image
+ >>> import requests
+
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> feature_extractor = ConvNextFeatureExtractor.from_pretrained("facebook/convnext-tiny-224")
+        >>> model = TFConvNextForImageClassification.from_pretrained("facebook/convnext-tiny-224")
+
+ >>> inputs = feature_extractor(images=image, return_tensors="tf")
+ >>> outputs = model(**inputs)
+ >>> logits = outputs.logits
+ >>> # model predicts one of the 1000 ImageNet classes
+ >>> predicted_class_idx = tf.math.argmax(logits, axis=-1)[0]
+ >>> print("Predicted class:", model.config.id2label[int(predicted_class_idx)])
+ ```"""
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ inputs = input_processing(
+ func=self.call,
+ config=self.config,
+ input_ids=pixel_values,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ labels=labels,
+ training=training,
+ kwargs_call=kwargs,
+ )
+
+ if "input_ids" in inputs:
+ inputs["pixel_values"] = inputs.pop("input_ids")
+
+ if inputs["pixel_values"] is None:
+ raise ValueError("You have to specify pixel_values")
+
+ outputs = self.convnext(
+ inputs["pixel_values"],
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ training=inputs["training"],
+ )
+
+ pooled_output = outputs.pooler_output if return_dict else outputs[1]
+
+ logits = self.classifier(pooled_output)
+ loss = None if inputs["labels"] is None else self.hf_compute_loss(labels=inputs["labels"], logits=logits)
+
+ if not inputs["return_dict"]:
+ output = (logits,) + outputs[2:]
+ return ((loss,) + output) if loss is not None else output
+
+ return TFSequenceClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ )
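Reviewer note on `TFConvNextDropPath`: dividing by `keep_prob` before applying the binary mask keeps the expected output equal to the input, which is why the layer can be a plain identity at inference. A standalone sketch of that invariant:

```python
import tensorflow as tf

drop_path = 0.2
keep_prob = 1.0 - drop_path

x = tf.ones((100000, 1, 1, 4))  # large batch so the sample mean is stable
shape = (tf.shape(x)[0],) + (1,) * (len(x.shape) - 1)  # one mask entry per sample
random_tensor = keep_prob + tf.random.uniform(shape, 0, 1)
mask = tf.floor(random_tensor)  # 1.0 with probability keep_prob, else 0.0
out = (x / keep_prob) * mask

# E[out] == x: the 1/keep_prob rescaling compensates for the dropped samples.
print(float(tf.reduce_mean(out)))  # ~1.0
```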
diff --git a/src/transformers/utils/dummy_tf_objects.py b/src/transformers/utils/dummy_tf_objects.py
index 6bba825a88978..ae7ffee3fb9e8 100644
--- a/src/transformers/utils/dummy_tf_objects.py
+++ b/src/transformers/utils/dummy_tf_objects.py
@@ -641,6 +641,27 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
+class TFConvNextForImageClassification(metaclass=DummyObject):
+ _backends = ["tf"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["tf"])
+
+
+class TFConvNextModel(metaclass=DummyObject):
+ _backends = ["tf"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["tf"])
+
+
+class TFConvNextPreTrainedModel(metaclass=DummyObject):
+ _backends = ["tf"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["tf"])
+
+
TF_CTRL_PRETRAINED_MODEL_ARCHIVE_LIST = None
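These dummies keep `from transformers import TFConvNextModel` importable in a TF-less environment, failing only at instantiation with an actionable message. A sketch of the expected behaviour (only meaningful where TensorFlow is not installed):

```python
# Without TensorFlow, the import below resolves to the DummyObject defined
# above rather than the real model class.
from transformers import TFConvNextModel

try:
    TFConvNextModel()  # requires_backends fires here, not at import time
except ImportError as err:
    print(err)  # points the user at the TensorFlow installation instructions
```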
diff --git a/tests/convnext/test_modeling_tf_convnext.py b/tests/convnext/test_modeling_tf_convnext.py
new file mode 100644
index 0000000000000..880e006f1abf2
--- /dev/null
+++ b/tests/convnext/test_modeling_tf_convnext.py
@@ -0,0 +1,281 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" Testing suite for the TensorFlow ConvNext model. """
+
+import inspect
+import unittest
+from typing import List, Tuple
+
+from transformers import ConvNextConfig
+from transformers.file_utils import cached_property, is_tf_available, is_vision_available
+from transformers.testing_utils import require_tf, require_vision, slow
+
+from ..test_configuration_common import ConfigTester
+from ..test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor
+
+
+if is_tf_available():
+ import tensorflow as tf
+
+ from transformers import TFConvNextForImageClassification, TFConvNextModel
+
+
+if is_vision_available():
+ from PIL import Image
+
+ from transformers import ConvNextFeatureExtractor
+
+
+class TFConvNextModelTester:
+ def __init__(
+ self,
+ parent,
+ batch_size=13,
+ image_size=32,
+ num_channels=3,
+ num_stages=4,
+ hidden_sizes=[10, 20, 30, 40],
+ depths=[2, 2, 3, 2],
+ is_training=True,
+ use_labels=True,
+ intermediate_size=37,
+ hidden_act="gelu",
+ type_sequence_label_size=10,
+ initializer_range=0.02,
+ num_labels=3,
+ scope=None,
+ ):
+ self.parent = parent
+ self.batch_size = batch_size
+ self.image_size = image_size
+ self.num_channels = num_channels
+ self.num_stages = num_stages
+ self.hidden_sizes = hidden_sizes
+ self.depths = depths
+ self.is_training = is_training
+ self.use_labels = use_labels
+ self.intermediate_size = intermediate_size
+ self.hidden_act = hidden_act
+ self.type_sequence_label_size = type_sequence_label_size
+ self.initializer_range = initializer_range
+ self.scope = scope
+
+ def prepare_config_and_inputs(self):
+ pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
+
+ labels = None
+ if self.use_labels:
+ labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
+
+ config = self.get_config()
+
+ return config, pixel_values, labels
+
+ def get_config(self):
+ return ConvNextConfig(
+ num_channels=self.num_channels,
+ hidden_sizes=self.hidden_sizes,
+ depths=self.depths,
+ num_stages=self.num_stages,
+ hidden_act=self.hidden_act,
+ is_decoder=False,
+ initializer_range=self.initializer_range,
+ )
+
+ def create_and_check_model(self, config, pixel_values, labels):
+ model = TFConvNextModel(config=config)
+ result = model(pixel_values, training=False)
+ # expected last hidden states: B, C, H // 32, W // 32
+ self.parent.assertEqual(
+ result.last_hidden_state.shape,
+ (self.batch_size, self.hidden_sizes[-1], self.image_size // 32, self.image_size // 32),
+ )
+
+ def create_and_check_for_image_classification(self, config, pixel_values, labels):
+ config.num_labels = self.type_sequence_label_size
+ model = TFConvNextForImageClassification(config)
+ result = model(pixel_values, labels=labels, training=False)
+ self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))
+
+ def prepare_config_and_inputs_for_common(self):
+ config_and_inputs = self.prepare_config_and_inputs()
+ config, pixel_values, labels = config_and_inputs
+ inputs_dict = {"pixel_values": pixel_values}
+ return config, inputs_dict
+
+
+@require_tf
+class TFConvNextModelTest(TFModelTesterMixin, unittest.TestCase):
+ """
+    Here we also overwrite some of the tests of test_modeling_tf_common.py, as ConvNext does not use input_ids,
+    inputs_embeds, attention_mask and seq_length.
+ """
+
+ all_model_classes = (TFConvNextModel, TFConvNextForImageClassification) if is_tf_available() else ()
+
+ test_pruning = False
+ test_onnx = False
+ test_resize_embeddings = False
+ test_head_masking = False
+
+ def setUp(self):
+ self.model_tester = TFConvNextModelTester(self)
+ self.config_tester = ConfigTester(
+ self,
+ config_class=ConvNextConfig,
+ has_text_modality=False,
+ hidden_size=37,
+ )
+
+ @unittest.skip(reason="ConvNext does not use inputs_embeds")
+ def test_inputs_embeds(self):
+ pass
+
+ @unittest.skip(reason="ConvNext does not support input and output embeddings")
+ def test_model_common_attributes(self):
+ pass
+
+ def test_forward_signature(self):
+ config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+
+ for model_class in self.all_model_classes:
+ model = model_class(config)
+ signature = inspect.signature(model.call)
+ # signature.parameters is an OrderedDict => so arg_names order is deterministic
+ arg_names = [*signature.parameters.keys()]
+
+ expected_arg_names = ["pixel_values"]
+ self.assertListEqual(arg_names[:1], expected_arg_names)
+
+ def test_model(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_model(*config_and_inputs)
+
+ @unittest.skip(reason="Model doesn't have attention layers")
+ def test_attention_outputs(self):
+ pass
+
+ def test_hidden_states_output(self):
+ def check_hidden_states_output(inputs_dict, config, model_class):
+ model = model_class(config)
+
+ outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+ hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states
+
+ expected_num_stages = self.model_tester.num_stages
+ self.assertEqual(len(hidden_states), expected_num_stages + 1)
+
+ # ConvNext's feature maps are of shape (batch_size, num_channels, height, width)
+ self.assertListEqual(
+ list(hidden_states[0].shape[-2:]),
+ [self.model_tester.image_size // 4, self.model_tester.image_size // 4],
+ )
+
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ for model_class in self.all_model_classes:
+ inputs_dict["output_hidden_states"] = True
+ check_hidden_states_output(inputs_dict, config, model_class)
+
+ # check that output_hidden_states also work using config
+ del inputs_dict["output_hidden_states"]
+ config.output_hidden_states = True
+
+ check_hidden_states_output(inputs_dict, config, model_class)
+
+ # Since ConvNext does not have any attention we need to rewrite this test.
+    # Since ConvNext does not have any attention, we need to overwrite this test.
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ def check_equivalence(model, tuple_inputs, dict_inputs, additional_kwargs={}):
+ tuple_output = model(tuple_inputs, return_dict=False, **additional_kwargs)
+ dict_output = model(dict_inputs, return_dict=True, **additional_kwargs).to_tuple()
+
+ def recursive_check(tuple_object, dict_object):
+ if isinstance(tuple_object, (List, Tuple)):
+ for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object):
+ recursive_check(tuple_iterable_value, dict_iterable_value)
+ elif tuple_object is None:
+ return
+ else:
+ self.assertTrue(
+ all(tf.equal(tuple_object, dict_object)),
+ msg=f"Tuple and dict output are not equal. Difference: {tf.math.reduce_max(tf.abs(tuple_object - dict_object))}",
+ )
+
+ recursive_check(tuple_output, dict_output)
+
+ for model_class in self.all_model_classes:
+ model = model_class(config)
+
+ tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
+ dict_inputs = self._prepare_for_class(inputs_dict, model_class)
+ check_equivalence(model, tuple_inputs, dict_inputs)
+
+ tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
+ dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
+ check_equivalence(model, tuple_inputs, dict_inputs)
+
+ tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
+ dict_inputs = self._prepare_for_class(inputs_dict, model_class)
+ check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})
+
+ tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
+ dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
+ check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})
+
+ def test_for_image_classification(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_for_image_classification(*config_and_inputs)
+
+ @slow
+ def test_model_from_pretrained(self):
+ model = TFConvNextModel.from_pretrained("facebook/convnext-tiny-224")
+ self.assertIsNotNone(model)
+
+
+# We will verify our results on an image of cute cats
+def prepare_img():
+ image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
+ return image
+
+
+@require_tf
+@require_vision
+class TFConvNextModelIntegrationTest(unittest.TestCase):
+ @cached_property
+ def default_feature_extractor(self):
+ return (
+ ConvNextFeatureExtractor.from_pretrained("facebook/convnext-tiny-224") if is_vision_available() else None
+ )
+
+ @slow
+ def test_inference_image_classification_head(self):
+ model = TFConvNextForImageClassification.from_pretrained("facebook/convnext-tiny-224")
+
+ feature_extractor = self.default_feature_extractor
+ image = prepare_img()
+ inputs = feature_extractor(images=image, return_tensors="tf")
+
+ # forward pass
+ outputs = model(**inputs)
+
+ # verify the logits
+ expected_shape = tf.TensorShape((1, 1000))
+ self.assertEqual(outputs.logits.shape, expected_shape)
+
+ expected_slice = tf.constant([-0.0260, -0.4739, 0.1911])
+
+ tf.debugging.assert_near(outputs.logits[0, :3], expected_slice, atol=1e-4)
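For completeness, the integration check can be reproduced outside the test harness along these lines (slow: downloads the checkpoint; `from_pt=True` may be needed until TF weights are on the Hub):

```python
import tensorflow as tf
from PIL import Image
from transformers import ConvNextFeatureExtractor, TFConvNextForImageClassification

model = TFConvNextForImageClassification.from_pretrained("facebook/convnext-tiny-224")
feature_extractor = ConvNextFeatureExtractor.from_pretrained("facebook/convnext-tiny-224")

image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
inputs = feature_extractor(images=image, return_tensors="tf")

outputs = model(**inputs)
print(outputs.logits.shape)                         # (1, 1000)
print(int(tf.math.argmax(outputs.logits, -1)[0]))   # predicted ImageNet class index
```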
diff --git a/tests/test_modeling_tf_common.py b/tests/test_modeling_tf_common.py
index f293d8126fe59..142bff7cae06e 100644
--- a/tests/test_modeling_tf_common.py
+++ b/tests/test_modeling_tf_common.py
@@ -474,8 +474,8 @@ def test_compile_tf_model(self):
),
"input_ids": tf.keras.Input(batch_shape=(2, max_input), name="input_ids", dtype="int32"),
}
- # TODO: A better way to handle vision models
- elif model_class.__name__ in ["TFViTModel", "TFViTForImageClassification", "TFCLIPVisionModel"]:
+ # `pixel_values` implies that the input is an image
+ elif model_class.main_input_name == "pixel_values":
inputs = tf.keras.Input(
batch_shape=(
3,