diff --git a/docs/source/index.mdx b/docs/source/index.mdx index a1c32e5dadfcd..1ad5f0a908848 100644 --- a/docs/source/index.mdx +++ b/docs/source/index.mdx @@ -179,7 +179,7 @@ Flax), PyTorch, and/or TensorFlow. | Canine | ✅ | ❌ | ✅ | ❌ | ❌ | | CLIP | ✅ | ✅ | ✅ | ✅ | ✅ | | ConvBERT | ✅ | ✅ | ✅ | ✅ | ❌ | -| ConvNext | ❌ | ❌ | ✅ | ❌ | ❌ | +| ConvNext | ❌ | ❌ | ✅ | ✅ | ❌ | | CTRL | ✅ | ❌ | ✅ | ✅ | ❌ | | DeBERTa | ✅ | ✅ | ✅ | ✅ | ❌ | | DeBERTa-v2 | ✅ | ❌ | ✅ | ✅ | ❌ | diff --git a/docs/source/model_doc/convnext.mdx b/docs/source/model_doc/convnext.mdx index e3a04d371e64c..4d46248565f94 100644 --- a/docs/source/model_doc/convnext.mdx +++ b/docs/source/model_doc/convnext.mdx @@ -37,7 +37,8 @@ alt="drawing" width="600"/> ConvNeXT architecture. Taken from the original paper. -This model was contributed by [nielsr](https://huggingface.co/nielsr). The original code can be found [here](https://github.com/facebookresearch/ConvNeXt). +This model was contributed by [nielsr](https://huggingface.co/nielsr). The TensorFlow version of the model was contributed by [ariG23498](https://github.com/ariG23498), +[gante](https://github.com/gante), and [sayakpaul](https://github.com/sayakpaul) (equal contribution). The original code can be found [here](https://github.com/facebookresearch/ConvNeXt). ## ConvNeXT specific outputs @@ -63,4 +64,16 @@ This model was contributed by [nielsr](https://huggingface.co/nielsr). The origi ## ConvNextForImageClassification [[autodoc]] ConvNextForImageClassification - - forward \ No newline at end of file + - forward + + +## TFConvNextModel + +[[autodoc]] TFConvNextModel + - call + + +## TFConvNextForImageClassification + +[[autodoc]] TFConvNextForImageClassification + - call \ No newline at end of file diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index d97e582c35809..3e858cebee74c 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -1743,6 +1743,13 @@ "TFConvBertPreTrainedModel", ] ) + _import_structure["models.convnext"].extend( + [ + "TFConvNextForImageClassification", + "TFConvNextModel", + "TFConvNextPreTrainedModel", + ] + ) _import_structure["models.ctrl"].extend( [ "TF_CTRL_PRETRAINED_MODEL_ARCHIVE_LIST", @@ -3751,6 +3758,7 @@ TFConvBertModel, TFConvBertPreTrainedModel, ) + from .models.convnext import TFConvNextForImageClassification, TFConvNextModel, TFConvNextPreTrainedModel from .models.ctrl import ( TF_CTRL_PRETRAINED_MODEL_ARCHIVE_LIST, TFCTRLForSequenceClassification, diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py index 4debfad93153e..2ab3e79381171 100644 --- a/src/transformers/modeling_tf_utils.py +++ b/src/transformers/modeling_tf_utils.py @@ -311,9 +311,10 @@ def booleans_processing(config, **kwargs): final_booleans = {} if tf.executing_eagerly(): - final_booleans["output_attentions"] = ( - kwargs["output_attentions"] if kwargs["output_attentions"] is not None else config.output_attentions - ) + # Pure conv models (such as ConvNext) do not have `output_attentions` + final_booleans["output_attentions"] = kwargs.get("output_attentions", None) + if final_booleans["output_attentions"] is None: + final_booleans["output_attentions"] = config.output_attentions final_booleans["output_hidden_states"] = ( kwargs["output_hidden_states"] if kwargs["output_hidden_states"] is not None
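Editor's note on the `booleans_processing` hunk above: a pure conv model like ConvNext has no attention layers, so its `call` signature never declares `output_attentions` and the key can be absent from `kwargs` altogether; the old `kwargs["output_attentions"]` lookup would raise a `KeyError`. A minimal sketch of the difference (toy dict, for illustration only):

```python
# `kwargs` as assembled for a model whose `call` has no `output_attentions`
# parameter: the key is simply absent.
kwargs = {"output_hidden_states": None, "return_dict": None}

# Old behaviour: kwargs["output_attentions"] raises KeyError.
# New behaviour: the missing key degrades to None, and only then does the
# code fall back to `config.output_attentions` (a default every
# PretrainedConfig carries).
output_attentions = kwargs.get("output_attentions", None)
print(output_attentions)  # None
```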
diff --git a/src/transformers/models/auto/modeling_tf_auto.py b/src/transformers/models/auto/modeling_tf_auto.py index cd4158bc7dd46..1b95cfa01d545 100644 --- a/src/transformers/models/auto/modeling_tf_auto.py +++ b/src/transformers/models/auto/modeling_tf_auto.py @@ -36,6 +36,7 @@ ("rembert", "TFRemBertModel"), ("roformer", "TFRoFormerModel"), ("convbert", "TFConvBertModel"), + ("convnext", "TFConvNextModel"), ("led", "TFLEDModel"), ("lxmert", "TFLxmertModel"), ("mt5", "TFMT5Model"), @@ -155,6 +156,7 @@ [ # Model for Image-classsification ("vit", "TFViTForImageClassification"), + ("convnext", "TFConvNextForImageClassification"), ] ) diff --git a/src/transformers/models/convnext/__init__.py b/src/transformers/models/convnext/__init__.py index cdc064d3c994a..a627c462e9ba4 100644 --- a/src/transformers/models/convnext/__init__.py +++ b/src/transformers/models/convnext/__init__.py @@ -18,7 +18,7 @@ from typing import TYPE_CHECKING # rely on isort to merge the imports -from ...file_utils import _LazyModule, is_torch_available, is_vision_available +from ...file_utils import _LazyModule, is_tf_available, is_torch_available, is_vision_available _import_structure = { @@ -36,6 +36,12 @@ "ConvNextPreTrainedModel", ] +if is_tf_available(): + _import_structure["modeling_tf_convnext"] = [ + "TFConvNextForImageClassification", + "TFConvNextModel", + "TFConvNextPreTrainedModel", + ] if TYPE_CHECKING: from .configuration_convnext import CONVNEXT_PRETRAINED_CONFIG_ARCHIVE_MAP, ConvNextConfig @@ -51,6 +57,9 @@ ConvNextPreTrainedModel, ) + if is_tf_available(): + from .modeling_tf_convnext import TFConvNextForImageClassification, TFConvNextModel, TFConvNextPreTrainedModel + else: import sys diff --git a/src/transformers/models/convnext/configuration_convnext.py b/src/transformers/models/convnext/configuration_convnext.py index 8d99c657cc639..74067ad337bbf 100644 --- a/src/transformers/models/convnext/configuration_convnext.py +++ b/src/transformers/models/convnext/configuration_convnext.py @@ -85,6 +85,7 @@ def __init__( is_encoder_decoder=False, layer_scale_init_value=1e-6, drop_path_rate=0.0, + image_size=224, **kwargs ): super().__init__(**kwargs) @@ -99,3 +100,4 @@ def __init__( self.layer_norm_eps = layer_norm_eps self.layer_scale_init_value = layer_scale_init_value self.drop_path_rate = drop_path_rate + self.image_size = image_size
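A note on the `image_size` attribute added to `ConvNextConfig` above: unlike the PyTorch model, the TF port must materialize concrete dummy inputs to build the network, and it reads the spatial size from the config (see `dummy_inputs` in the new modeling file below). A quick check, assuming only what this diff adds:

```python
from transformers import ConvNextConfig

# The new attribute defaults to 224 and is consumed by
# `TFConvNextPreTrainedModel.dummy_inputs` to create tensors of shape
# (3, num_channels, image_size, image_size) when tracing the model.
config = ConvNextConfig()
print(config.image_size)  # 224
```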
diff --git a/src/transformers/models/convnext/modeling_tf_convnext.py b/src/transformers/models/convnext/modeling_tf_convnext.py new file mode 100644 index 0000000000000..fbb436059340f --- /dev/null +++ b/src/transformers/models/convnext/modeling_tf_convnext.py @@ -0,0 +1,618 @@ +# coding=utf-8 +# Copyright 2022 Meta Platforms Inc. and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" TF 2.0 ConvNext model.""" + + +from typing import Dict, Optional, Tuple, Union + +import numpy as np +import tensorflow as tf + +from ...activations_tf import get_tf_activation +from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings +from ...modeling_tf_outputs import TFBaseModelOutput, TFBaseModelOutputWithPooling, TFSequenceClassifierOutput +from ...modeling_tf_utils import ( + TFModelInputType, + TFPreTrainedModel, + TFSequenceClassificationLoss, + get_initializer, + input_processing, + keras_serializable, +) +from ...utils import logging +from .configuration_convnext import ConvNextConfig + + +logger = logging.get_logger(__name__) + + +_CONFIG_FOR_DOC = "ConvNextConfig" +_CHECKPOINT_FOR_DOC = "facebook/convnext-tiny-224" + + +class TFConvNextDropPath(tf.keras.layers.Layer): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + References: + (1) github.com:rwightman/pytorch-image-models + """ + + def __init__(self, drop_path, **kwargs): + super().__init__(**kwargs) + self.drop_path = drop_path + + def call(self, x, training=None): + if training: + keep_prob = 1 - self.drop_path + shape = (tf.shape(x)[0],) + (1,) * (len(tf.shape(x)) - 1) + random_tensor = keep_prob + tf.random.uniform(shape, 0, 1) + random_tensor = tf.floor(random_tensor) + return (x / keep_prob) * random_tensor + return x + + +class TFConvNextEmbeddings(tf.keras.layers.Layer): + """This class is comparable to (and inspired by) the SwinEmbeddings class + found in src/transformers/models/swin/modeling_swin.py. + """ + + def __init__(self, config, **kwargs): + super().__init__(**kwargs) + self.patch_embeddings = tf.keras.layers.Conv2D( + filters=config.hidden_sizes[0], + kernel_size=config.patch_size, + strides=config.patch_size, + name="patch_embeddings", + kernel_initializer=get_initializer(config.initializer_range), + bias_initializer="zeros", + ) + self.layernorm = tf.keras.layers.LayerNormalization(epsilon=1e-6, name="layernorm") + + def call(self, pixel_values): + if isinstance(pixel_values, dict): + pixel_values = pixel_values["pixel_values"] + + # When running on CPU, `tf.keras.layers.Conv2D` doesn't support `NCHW` format. + # So change the input format from `NCHW` to `NHWC`. + # shape = (batch_size, in_height, in_width, in_channels=num_channels) + pixel_values = tf.transpose(pixel_values, perm=(0, 2, 3, 1)) + + embeddings = self.patch_embeddings(pixel_values) + embeddings = self.layernorm(embeddings) + return embeddings + + +class TFConvNextLayer(tf.keras.layers.Layer): + """This corresponds to the `Block` class in the original implementation. + + There are two equivalent implementations: (1) [DwConv, LayerNorm (channels_first), 1x1 Conv, GELU, 1x1 Conv], all in (N, C, + H, W); (2) [DwConv, Permute to (N, H, W, C), LayerNorm (channels_last), Linear, GELU, Linear], then permute back. + + The authors used (2) as they find it slightly faster in PyTorch. Since we already permuted the inputs to follow + NHWC ordering, we can just apply the operations straight-away without the permutation. + + Args: + config ([`ConvNextConfig`]): Model configuration class. + dim (`int`): Number of input channels. + drop_path (`float`): Stochastic depth rate. Default: 0.0.
+ """ + + def __init__(self, config, dim, drop_path=0.0, **kwargs): + super().__init__(**kwargs) + self.dim = dim + self.config = config + self.dwconv = tf.keras.layers.Conv2D( + filters=dim, + kernel_size=7, + padding="same", + groups=dim, + kernel_initializer=get_initializer(config.initializer_range), + bias_initializer="zeros", + name="dwconv", + ) # depthwise conv + self.layernorm = tf.keras.layers.LayerNormalization( + epsilon=1e-6, + name="layernorm", + ) + self.pwconv1 = tf.keras.layers.Dense( + units=4 * dim, + kernel_initializer=get_initializer(config.initializer_range), + bias_initializer="zeros", + name="pwconv1", + ) # pointwise/1x1 convs, implemented with linear layers + self.act = get_tf_activation(config.hidden_act) + self.pwconv2 = tf.keras.layers.Dense( + units=dim, + kernel_initializer=get_initializer(config.initializer_range), + bias_initializer="zeros", + name="pwconv2", + ) + # Using `layers.Activation` instead of `tf.identity` to better control `training` + # behaviour. + self.drop_path = ( + TFConvNextDropPath(drop_path, name="drop_path") + if drop_path > 0.0 + else tf.keras.layers.Activation("linear", name="drop_path") + ) + + def build(self, input_shape: tf.TensorShape): + # PT's `nn.Parameters` must be mapped to a TF layer weight to inherit the same name hierarchy (and vice-versa) + self.layer_scale_parameter = ( + self.add_weight( + shape=(self.dim,), + initializer=tf.keras.initializers.Constant(value=self.config.layer_scale_init_value), + trainable=True, + name="layer_scale_parameter", + ) + if self.config.layer_scale_init_value > 0 + else None + ) + super().build(input_shape) + + def call(self, hidden_states, training=False): + input = hidden_states + x = self.dwconv(hidden_states) + x = self.layernorm(x) + x = self.pwconv1(x) + x = self.act(x) + x = self.pwconv2(x) + + if self.layer_scale_parameter is not None: + x = self.layer_scale_parameter * x + + x = input + self.drop_path(x, training=training) + return x + + +class TFConvNextStage(tf.keras.layers.Layer): + """ConvNext stage, consisting of an optional downsampling layer + multiple residual blocks. + + Args: + config ([`ConvNextConfig`]): Model configuration class. + in_channels (`int`): Number of input channels. + out_channels (`int`): Number of output channels. + depth (`int`): Number of residual blocks. + drop_path_rates (`List[float]`): Stochastic depth rates for each layer. + """ + + def __init__( + self, config, in_channels, out_channels, kernel_size=2, stride=2, depth=2, drop_path_rates=None, **kwargs + ): + super().__init__(**kwargs) + if in_channels != out_channels or stride > 1: + self.downsampling_layer = [ + tf.keras.layers.LayerNormalization( + epsilon=1e-6, + name="downsampling_layer.0", + ), + # Inputs to this layer will follow NHWC format since we + # transposed the inputs from NCHW to NHWC in the `TFConvNextEmbeddings` + # layer. All the outputs throughout the model will be in NHWC + # from this point on until the output where we again change to + # NCHW.
+ tf.keras.layers.Conv2D( + filters=out_channels, + kernel_size=kernel_size, + strides=stride, + kernel_initializer=get_initializer(config.initializer_range), + bias_initializer="zeros", + name="downsampling_layer.1", + ), + ] + else: + self.downsampling_layer = [tf.identity] + + drop_path_rates = drop_path_rates or [0.0] * depth + self.layers = [ + TFConvNextLayer( + config, + dim=out_channels, + drop_path=drop_path_rates[j], + name=f"layers.{j}", + ) + for j in range(depth) + ] + + def call(self, hidden_states): + for layer in self.downsampling_layer: + hidden_states = layer(hidden_states) + for layer in self.layers: + hidden_states = layer(hidden_states) + return hidden_states + + +class TFConvNextEncoder(tf.keras.layers.Layer): + def __init__(self, config, **kwargs): + super().__init__(**kwargs) + self.stages = [] + drop_path_rates = [x for x in tf.linspace(0.0, config.drop_path_rate, sum(config.depths))] + cur = 0 + prev_chs = config.hidden_sizes[0] + for i in range(config.num_stages): + out_chs = config.hidden_sizes[i] + stage = TFConvNextStage( + config, + in_channels=prev_chs, + out_channels=out_chs, + stride=2 if i > 0 else 1, + depth=config.depths[i], + drop_path_rates=drop_path_rates[cur : cur + config.depths[i]], + name=f"stages.{i}", + ) + self.stages.append(stage) + cur += config.depths[i] + prev_chs = out_chs + + def call(self, hidden_states, output_hidden_states=False, return_dict=True): + all_hidden_states = () if output_hidden_states else None + + for i, layer_module in enumerate(self.stages): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + hidden_states = layer_module(hidden_states) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states] if v is not None) + + return TFBaseModelOutput(last_hidden_state=hidden_states, hidden_states=all_hidden_states) + + +@keras_serializable +class TFConvNextMainLayer(tf.keras.layers.Layer): + config_class = ConvNextConfig + + def __init__(self, config: ConvNextConfig, add_pooling_layer: bool = True, **kwargs): + super().__init__(**kwargs) + + self.config = config + self.embeddings = TFConvNextEmbeddings(config, name="embeddings") + self.encoder = TFConvNextEncoder(config, name="encoder") + self.layernorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm") + # We are setting the `data_format` like so because from here on we will revert to the + # NCHW output format + self.pooler = tf.keras.layers.GlobalAvgPool2D(data_format="channels_first") if add_pooling_layer else None + + def call( + self, + pixel_values: Optional[TFModelInputType] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + training: bool = False, + **kwargs, + ) -> Union[TFBaseModelOutputWithPooling, Tuple[tf.Tensor]]: + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + inputs = input_processing( + func=self.call, + config=self.config, + input_ids=pixel_values, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + kwargs_call=kwargs, + ) + + if "input_ids" in inputs: + inputs["pixel_values"] = inputs.pop("input_ids") + + if inputs["pixel_values"] is None: + raise ValueError("You have to specify pixel_values") + + embedding_output =
self.embeddings(inputs["pixel_values"], training=inputs["training"]) + + encoder_outputs = self.encoder( + embedding_output, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=inputs["training"], + ) + + last_hidden_state = encoder_outputs[0] + # Change to NCHW output format to have uniformity in the modules + last_hidden_state = tf.transpose(last_hidden_state, perm=(0, 3, 1, 2)) + pooled_output = self.layernorm(self.pooler(last_hidden_state)) + + # Change the other hidden state outputs to NCHW as well + if output_hidden_states: + hidden_states = tuple([tf.transpose(h, perm=(0, 3, 1, 2)) for h in encoder_outputs[1]]) + + if not return_dict: + return (last_hidden_state, pooled_output) + encoder_outputs[1:] + + return TFBaseModelOutputWithPooling( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + hidden_states=hidden_states if output_hidden_states else encoder_outputs.hidden_states, + ) + + +class TFConvNextPreTrainedModel(TFPreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = ConvNextConfig + base_model_prefix = "convnext" + main_input_name = "pixel_values" + + @property + def dummy_inputs(self) -> Dict[str, tf.Tensor]: + """ + Dummy inputs to build the network. + + Returns: + `Dict[str, tf.Tensor]`: The dummy inputs. + """ + VISION_DUMMY_INPUTS = tf.random.uniform( + shape=( + 3, + self.config.num_channels, + self.config.image_size, + self.config.image_size, + ), + dtype=tf.float32, + ) + return {"pixel_values": tf.constant(VISION_DUMMY_INPUTS)} + + @tf.function( + input_signature=[ + { + "pixel_values": tf.TensorSpec((None, None, None, None), tf.float32, name="pixel_values"), + } + ] + ) + def serving(self, inputs): + """ + Method used for serving the model. + + Args: + inputs (`Dict[str, tf.Tensor]`): + The input of the saved model as a dictionary of tensors. + """ + return self.call(inputs) + + +CONVNEXT_START_DOCSTRING = r""" + This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a [tf.keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it + as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matters related to general usage and + behavior. + + <Tip> + + TF 2.0 models accept two formats as inputs: + + - having all inputs as keyword arguments (like PyTorch models), or + - having all inputs as a list, tuple or dict in the first positional arguments. + + This second option is useful when using the [`tf.keras.Model.fit`] method, which currently requires having all the + tensors in the first argument of the model call function: `model(inputs)`. + + </Tip> + + Parameters: + config ([`ConvNextConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights. +""" + +CONVNEXT_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`np.ndarray`, `tf.Tensor`, `List[tf.Tensor]`, `Dict[str, tf.Tensor]` or `Dict[str, np.ndarray]`, and each example must have the shape `(batch_size, num_channels, height, width)`): + Pixel values. Pixel values can be obtained using [`ConvNextFeatureExtractor`].
See + [`ConvNextFeatureExtractor.__call__`] for details. + + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. This argument can be used only in eager mode, in graph mode the value in the config will be + used instead. + return_dict (`bool`, *optional*): + Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple. This argument can be used + in eager mode, in graph mode the value will always be set to True. +""" + + +@add_start_docstrings( + "The bare ConvNext model outputting raw features without any specific head on top.", + CONVNEXT_START_DOCSTRING, +) +class TFConvNextModel(TFConvNextPreTrainedModel): + def __init__(self, config, *inputs, add_pooling_layer=True, **kwargs): + super().__init__(config, *inputs, **kwargs) + self.convnext = TFConvNextMainLayer(config, add_pooling_layer=add_pooling_layer, name="convnext") + + @add_start_docstrings_to_model_forward(CONVNEXT_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=TFBaseModelOutputWithPooling, config_class=_CONFIG_FOR_DOC) + def call( + self, + pixel_values: Optional[TFModelInputType] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + training: bool = False, + **kwargs, + ) -> Union[TFBaseModelOutputWithPooling, Tuple[tf.Tensor]]: + r""" + Returns: + + Examples: + + ```python + >>> from transformers import ConvNextFeatureExtractor, TFConvNextModel + >>> from PIL import Image + >>> import requests + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> feature_extractor = ConvNextFeatureExtractor.from_pretrained("facebook/convnext-tiny-224") + >>> model = TFConvNextModel.from_pretrained("facebook/convnext-tiny-224") + + >>> inputs = feature_extractor(images=image, return_tensors="tf") + >>> outputs = model(**inputs) + >>> last_hidden_states = outputs.last_hidden_state + ```""" + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + inputs = input_processing( + func=self.call, + config=self.config, + input_ids=pixel_values, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + kwargs_call=kwargs, + ) + + if "input_ids" in inputs: + inputs["pixel_values"] = inputs.pop("input_ids") + + if inputs["pixel_values"] is None: + raise ValueError("You have to specify pixel_values") + + outputs = self.convnext( + pixel_values=inputs["pixel_values"], + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=inputs["training"], + ) + + if not return_dict: + return (outputs[0],) + outputs[1:] + + return TFBaseModelOutputWithPooling( + last_hidden_state=outputs.last_hidden_state, + pooler_output=outputs.pooler_output, + hidden_states=outputs.hidden_states, + ) + + +@add_start_docstrings( + """ + ConvNext Model with an image classification head on top (a linear layer on top of the pooled features), e.g. for + ImageNet. 
+ """, + CONVNEXT_START_DOCSTRING, +) +class TFConvNextForImageClassification(TFConvNextPreTrainedModel, TFSequenceClassificationLoss): + def __init__(self, config: ConvNextConfig, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + + self.num_labels = config.num_labels + self.convnext = TFConvNextMainLayer(config, name="convnext") + + # Classifier head + self.classifier = tf.keras.layers.Dense( + units=config.num_labels, + kernel_initializer=get_initializer(config.initializer_range), + bias_initializer="zeros", + name="classifier", + ) + + @add_start_docstrings_to_model_forward(CONVNEXT_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=TFSequenceClassifierOutput, config_class=_CONFIG_FOR_DOC) + def call( + self, + pixel_values: Optional[TFModelInputType] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + labels: Optional[Union[np.ndarray, tf.Tensor]] = None, + training: Optional[bool] = False, + **kwargs, + ) -> Union[TFSequenceClassifierOutput, Tuple[tf.Tensor]]: + r""" + labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size,)`, *optional*): + Labels for computing the image classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss); if + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + + Returns: + + Examples: + + ```python + >>> from transformers import ConvNextFeatureExtractor, TFConvNextForImageClassification + >>> import tensorflow as tf + >>> from PIL import Image + >>> import requests + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> feature_extractor = ConvNextFeatureExtractor.from_pretrained("facebook/convnext-tiny-224") + >>> model = TFConvNextForImageClassification.from_pretrained("facebook/convnext-tiny-224") + + >>> inputs = feature_extractor(images=image, return_tensors="tf") + >>> outputs = model(**inputs) + >>> logits = outputs.logits + >>> # model predicts one of the 1000 ImageNet classes + >>> predicted_class_idx = tf.math.argmax(logits, axis=-1)[0] + >>> print("Predicted class:", model.config.id2label[int(predicted_class_idx)]) + ```""" + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + inputs = input_processing( + func=self.call, + config=self.config, + input_ids=pixel_values, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + labels=labels, + training=training, + kwargs_call=kwargs, + ) + + if "input_ids" in inputs: + inputs["pixel_values"] = inputs.pop("input_ids") + + if inputs["pixel_values"] is None: + raise ValueError("You have to specify pixel_values") + + outputs = self.convnext( + inputs["pixel_values"], + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=inputs["training"], + ) + + pooled_output = outputs.pooler_output if return_dict else outputs[1] + + logits = self.classifier(pooled_output) + loss = None if inputs["labels"] is None else self.hf_compute_loss(labels=inputs["labels"], logits=logits) + + if not inputs["return_dict"]: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TFSequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + )
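To make the NCHW/NHWC handling in the modeling file above concrete: the TF model accepts PyTorch-style `(batch, channels, height, width)` inputs, transposes to NHWC internally (CPU `Conv2D` kernels do not support NCHW), and transposes back before returning, so its outputs line up with the PyTorch model's. A smoke test with randomly initialized weights (shapes follow the default config; this is illustrative and not part of the PR's test suite):

```python
import tensorflow as tf

from transformers import ConvNextConfig, TFConvNextModel

# Default config: hidden_sizes=[96, 192, 384, 768], image_size=224; the
# last hidden state is downsampled by 32x and returned in NCHW layout.
config = ConvNextConfig()
model = TFConvNextModel(config)

pixel_values = tf.random.uniform((1, config.num_channels, config.image_size, config.image_size))
outputs = model(pixel_values)
print(outputs.last_hidden_state.shape)  # (1, 768, 7, 7)
```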
diff --git a/src/transformers/utils/dummy_tf_objects.py b/src/transformers/utils/dummy_tf_objects.py index 6bba825a88978..ae7ffee3fb9e8 100644 --- a/src/transformers/utils/dummy_tf_objects.py +++ b/src/transformers/utils/dummy_tf_objects.py @@ -641,6 +641,27 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["tf"]) +class TFConvNextForImageClassification(metaclass=DummyObject): + _backends = ["tf"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["tf"]) + + +class TFConvNextModel(metaclass=DummyObject): + _backends = ["tf"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["tf"]) + + +class TFConvNextPreTrainedModel(metaclass=DummyObject): + _backends = ["tf"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["tf"]) + + TF_CTRL_PRETRAINED_MODEL_ARCHIVE_LIST = None diff --git a/tests/convnext/test_modeling_tf_convnext.py b/tests/convnext/test_modeling_tf_convnext.py new file mode 100644 index 0000000000000..880e006f1abf2 --- /dev/null +++ b/tests/convnext/test_modeling_tf_convnext.py @@ -0,0 +1,281 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Testing suite for the TensorFlow ConvNext model. """ + +import inspect +import unittest +from typing import List, Tuple + +from transformers import ConvNextConfig +from transformers.file_utils import cached_property, is_tf_available, is_vision_available +from transformers.testing_utils import require_tf, require_vision, slow + +from ..test_configuration_common import ConfigTester +from ..test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor + + +if is_tf_available(): + import tensorflow as tf + + from transformers import TFConvNextForImageClassification, TFConvNextModel + + +if is_vision_available(): + from PIL import Image + + from transformers import ConvNextFeatureExtractor + + +class TFConvNextModelTester: + def __init__( + self, + parent, + batch_size=13, + image_size=32, + num_channels=3, + num_stages=4, + hidden_sizes=[10, 20, 30, 40], + depths=[2, 2, 3, 2], + is_training=True, + use_labels=True, + intermediate_size=37, + hidden_act="gelu", + type_sequence_label_size=10, + initializer_range=0.02, + num_labels=3, + scope=None, + ): + self.parent = parent + self.batch_size = batch_size + self.image_size = image_size + self.num_channels = num_channels + self.num_stages = num_stages + self.hidden_sizes = hidden_sizes + self.depths = depths + self.is_training = is_training + self.use_labels = use_labels + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.type_sequence_label_size = type_sequence_label_size + self.initializer_range = initializer_range + self.scope = scope + + def prepare_config_and_inputs(self): + pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size]) + + labels = None + if self.use_labels: + labels = ids_tensor([self.batch_size], self.type_sequence_label_size) + + config = self.get_config() +
+ return config, pixel_values, labels + + def get_config(self): + return ConvNextConfig( + num_channels=self.num_channels, + hidden_sizes=self.hidden_sizes, + depths=self.depths, + num_stages=self.num_stages, + hidden_act=self.hidden_act, + is_decoder=False, + initializer_range=self.initializer_range, + ) + + def create_and_check_model(self, config, pixel_values, labels): + model = TFConvNextModel(config=config) + result = model(pixel_values, training=False) + # expected last hidden states: B, C, H // 32, W // 32 + self.parent.assertEqual( + result.last_hidden_state.shape, + (self.batch_size, self.hidden_sizes[-1], self.image_size // 32, self.image_size // 32), + ) + + def create_and_check_for_image_classification(self, config, pixel_values, labels): + config.num_labels = self.type_sequence_label_size + model = TFConvNextForImageClassification(config) + result = model(pixel_values, labels=labels, training=False) + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size)) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + config, pixel_values, labels = config_and_inputs + inputs_dict = {"pixel_values": pixel_values} + return config, inputs_dict + + +@require_tf +class TFConvNextModelTest(TFModelTesterMixin, unittest.TestCase): + """ + Here we also overwrite some of the tests of test_modeling_common.py, as ConvNext does not use input_ids, inputs_embeds, + attention_mask and seq_length. + """ + + all_model_classes = (TFConvNextModel, TFConvNextForImageClassification) if is_tf_available() else () + + test_pruning = False + test_onnx = False + test_resize_embeddings = False + test_head_masking = False + + def setUp(self): + self.model_tester = TFConvNextModelTester(self) + self.config_tester = ConfigTester( + self, + config_class=ConvNextConfig, + has_text_modality=False, + hidden_size=37, + ) + + @unittest.skip(reason="ConvNext does not use inputs_embeds") + def test_inputs_embeds(self): + pass + + @unittest.skip(reason="ConvNext does not support input and output embeddings") + def test_model_common_attributes(self): + pass + + def test_forward_signature(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + signature = inspect.signature(model.call) + # signature.parameters is an OrderedDict => so arg_names order is deterministic + arg_names = [*signature.parameters.keys()] + + expected_arg_names = ["pixel_values"] + self.assertListEqual(arg_names[:1], expected_arg_names) + + def test_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model(*config_and_inputs) + + @unittest.skip(reason="Model doesn't have attention layers") + def test_attention_outputs(self): + pass + + def test_hidden_states_output(self): + def check_hidden_states_output(inputs_dict, config, model_class): + model = model_class(config) + + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states + + expected_num_stages = self.model_tester.num_stages + self.assertEqual(len(hidden_states), expected_num_stages + 1) + + # ConvNext's feature maps are of shape (batch_size, num_channels, height, width) + self.assertListEqual( + list(hidden_states[0].shape[-2:]), + [self.model_tester.image_size // 4, self.model_tester.image_size // 4], + ) + + config, 
inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + inputs_dict["output_hidden_states"] = True + check_hidden_states_output(inputs_dict, config, model_class) + + # check that output_hidden_states also work using config + del inputs_dict["output_hidden_states"] + config.output_hidden_states = True + + check_hidden_states_output(inputs_dict, config, model_class) + + # Since ConvNext does not have any attention we need to rewrite this test. + def test_model_outputs_equivalence(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + def check_equivalence(model, tuple_inputs, dict_inputs, additional_kwargs={}): + tuple_output = model(tuple_inputs, return_dict=False, **additional_kwargs) + dict_output = model(dict_inputs, return_dict=True, **additional_kwargs).to_tuple() + + def recursive_check(tuple_object, dict_object): + if isinstance(tuple_object, (List, Tuple)): + for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object): + recursive_check(tuple_iterable_value, dict_iterable_value) + elif tuple_object is None: + return + else: + self.assertTrue( + all(tf.equal(tuple_object, dict_object)), + msg=f"Tuple and dict output are not equal. Difference: {tf.math.reduce_max(tf.abs(tuple_object - dict_object))}", + ) + + recursive_check(tuple_output, dict_output) + + for model_class in self.all_model_classes: + model = model_class(config) + + tuple_inputs = self._prepare_for_class(inputs_dict, model_class) + dict_inputs = self._prepare_for_class(inputs_dict, model_class) + check_equivalence(model, tuple_inputs, dict_inputs) + + tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True) + dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True) + check_equivalence(model, tuple_inputs, dict_inputs) + + tuple_inputs = self._prepare_for_class(inputs_dict, model_class) + dict_inputs = self._prepare_for_class(inputs_dict, model_class) + check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True}) + + tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True) + dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True) + check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True}) + + def test_for_image_classification(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_for_image_classification(*config_and_inputs) + + @slow + def test_model_from_pretrained(self): + model = TFConvNextModel.from_pretrained("facebook/convnext-tiny-224") + self.assertIsNotNone(model) + + +# We will verify our results on an image of cute cats +def prepare_img(): + image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png") + return image + + +@require_tf +@require_vision +class TFConvNextModelIntegrationTest(unittest.TestCase): + @cached_property + def default_feature_extractor(self): + return ( + ConvNextFeatureExtractor.from_pretrained("facebook/convnext-tiny-224") if is_vision_available() else None + ) + + @slow + def test_inference_image_classification_head(self): + model = TFConvNextForImageClassification.from_pretrained("facebook/convnext-tiny-224") + + feature_extractor = self.default_feature_extractor + image = prepare_img() + inputs = feature_extractor(images=image, return_tensors="tf") + + # forward pass + outputs = model(**inputs) + + # verify the logits + expected_shape = 
tf.TensorShape((1, 1000)) + self.assertEqual(outputs.logits.shape, expected_shape) + + expected_slice = tf.constant([-0.0260, -0.4739, 0.1911]) + + tf.debugging.assert_near(outputs.logits[0, :3], expected_slice, atol=1e-4) diff --git a/tests/test_modeling_tf_common.py b/tests/test_modeling_tf_common.py index f293d8126fe59..142bff7cae06e 100644 --- a/tests/test_modeling_tf_common.py +++ b/tests/test_modeling_tf_common.py @@ -474,8 +474,8 @@ def test_compile_tf_model(self): ), "input_ids": tf.keras.Input(batch_shape=(2, max_input), name="input_ids", dtype="int32"), } - # TODO: A better way to handle vision models - elif model_class.__name__ in ["TFViTModel", "TFViTForImageClassification", "TFCLIPVisionModel"]: + # `pixel_values` implies that the input is an image + elif model_class.main_input_name == "pixel_values": inputs = tf.keras.Input( batch_shape=( 3,
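The final hunk is truncated above, but its intent is visible in what remains: instead of hard-coding vision model class names, `test_compile_tf_model` now keys off `main_input_name == "pixel_values"`, which the new `TFConvNextPreTrainedModel` sets. An illustrative check of the dispatch rule (not part of the diff):

```python
from transformers import TFConvBertModel, TFConvNextModel

# Vision models override `main_input_name`; text models keep the
# TFPreTrainedModel default of "input_ids".
assert TFConvNextModel.main_input_name == "pixel_values"
assert TFConvBertModel.main_input_name == "input_ids"
```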