From 712429bc34aa763d14fe04a3168c0e52a94bd3a5 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Wed, 10 Aug 2022 18:27:09 +0530 Subject: [PATCH 01/20] initial implementation. --- .../models/mobilevit/modeling_tf_mobilevit.py | 521 ++++++++++++++++++ 1 file changed, 521 insertions(+) create mode 100644 src/transformers/models/mobilevit/modeling_tf_mobilevit.py diff --git a/src/transformers/models/mobilevit/modeling_tf_mobilevit.py b/src/transformers/models/mobilevit/modeling_tf_mobilevit.py new file mode 100644 index 0000000000000..69bc06e5affd1 --- /dev/null +++ b/src/transformers/models/mobilevit/modeling_tf_mobilevit.py @@ -0,0 +1,521 @@ +# coding=utf-8 +# Copyright 2022 Apple Inc. and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Original license: https://github.com/apple/ml-cvnets/blob/main/LICENSE +""" TensorFlow 2.0 MobileViT model.""" + +import math +from typing import Dict, Optional, Tuple, Union + +import tensorflow as tf + + +from ...activations_tf import get_tf_activation +from ...file_utils import ( + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + replace_return_docstrings, +) +from ...modeling_tf_outputs import TFBaseModelOutput, TFSemanticSegmenterOutput, TFSequenceClassifierOutput +from ...modeling_tf_utils import TFPreTrainedModel, TFSequenceClassificationLoss, keras_serializable, unpack_inputs +from ...tf_utils import shape_list, stable_softmax +from ...utils import logging +from .configuration_mobilevit import MobileViTConfig + + +logger = logging.get_logger(__name__) + +# General docstring +_CONFIG_FOR_DOC = "MobileViTConfig" +_FEAT_EXTRACTOR_FOR_DOC = "MobileViTFeatureExtractor" + +# Base docstring +_CHECKPOINT_FOR_DOC = "apple/mobilevit-small" +_EXPECTED_OUTPUT_SHAPE = [1, 640, 8, 8] + +# Image classification docstring +_IMAGE_CLASS_CHECKPOINT = "apple/mobilevit-small" +_IMAGE_CLASS_EXPECTED_OUTPUT = "tabby, tabby cat" + + +TF_MOBILEVIT_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "apple/mobilevit-small", + "apple/mobilevit-x-small", + "apple/mobilevit-xx-small", + "apple/deeplabv3-mobilevit-small", + "apple/deeplabv3-mobilevit-x-small", + "apple/deeplabv3-mobilevit-xx-small", + # See all MobileViT models at https://huggingface.co/models?filter=mobilevit +] + + +def make_divisible(value: int, divisor: int = 8, min_value: Optional[int] = None) -> int: + """ + Ensure that all layers have a channel count that is divisible by `divisor`. This function is taken from the + original TensorFlow repo. It can be seen here: + https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py + """ + if min_value is None: + min_value = divisor + new_value = max(min_value, int(value + divisor / 2) // divisor * divisor) + # Make sure that round down does not go down by more than 10%. 
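+    # For example, make_divisible(10, divisor=8) first rounds to 8; since 8 < 0.9 * 10,
+    # the value is bumped up by one divisor step to 16.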
+ if new_value < 0.9 * value: + new_value += divisor + return int(new_value) + + +class TFMobileViTConvLayer(tf.keras.layers.Layer): + def __init__( + self, + config: MobileViTConfig, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + groups: int = 1, + bias: bool = False, + dilation: int = 1, + use_normalization: bool = True, + use_activation: Union[bool, str] = True, + **kwargs + ) -> None: + super().__init__(**kwargs) + padding = int((kernel_size - 1) / 2) * dilation + self.padding = tf.keras.layers.ZeroPadding2D(padding) + + if out_channels % groups != 0: + raise ValueError(f"Output channels ({out_channels}) are not divisible by {groups} groups.") + + self.convolution = tf.keras.layers.Conv2D( + filters=out_channels, + kernel_size=kernel_size, + strides=stride, + padding="VALID", + dilation=dilation, + groups=groups, + use_bias=bias, + name="convolution" + ) + + if use_normalization: + self.normalization = tf.keras.layers.BatchNormalization( + epsilon=1e-5, momentum=0.1, + name="normalization" + ) + else: + self.normalization = None + + if use_activation: + if isinstance(use_activation, str): + self.activation = get_tf_activation(use_activation) + elif isinstance(config.hidden_act, str): + self.activation = get_tf_activation(config.hidden_act) + else: + self.activation = config.hidden_act + else: + self.activation = None + + def call(self, features: tf.Tensor) -> tf.Tensor: + features = self.convolution(self.padding(features)) + if self.normalization is not None: + features = self.normalization(features) + if self.activation is not None: + features = self.activation(features) + return features + + +class TFMobileViTInvertedResidual(tf.keras.layers.Layer): + """ + Inverted residual block (MobileNetv2): https://arxiv.org/abs/1801.04381 + """ + + def __init__( + self, config: MobileViTConfig, in_channels: int, out_channels: int, stride: int, dilation: int = 1, **kwargs + ) -> None: + super().__init__(**kwargs) + expanded_channels = make_divisible(int(round(in_channels * config.expand_ratio)), 8) + + if stride not in [1, 2]: + raise ValueError(f"Invalid stride {stride}.") + + self.use_residual = (stride == 1) and (in_channels == out_channels) + + self.expand_1x1 = TFMobileViTConvLayer( + config, in_channels=in_channels, out_channels=expanded_channels, kernel_size=1, name="expand_1x1" + ) + + self.conv_3x3 = TFMobileViTConvLayer( + config, + in_channels=expanded_channels, + out_channels=expanded_channels, + kernel_size=3, + strides=stride, + groups=expanded_channels, + dilation=dilation, + name="conv_3x3" + ) + + self.reduce_1x1 = TFMobileViTConvLayer( + config, + in_channels=expanded_channels, + out_channels=out_channels, + kernel_size=1, + use_activation=False, + name="reduce_1x1" + ) + + def call(self, features: tf.Tensor) -> tf.Tensor: + residual = features + + features = self.expand_1x1(features) + features = self.conv_3x3(features) + features = self.reduce_1x1(features) + + return residual + features if self.use_residual else features + + +class MobileViTMobileNetLayer(tf.keras.layers.Layer): + def __init__( + self, config: MobileViTConfig, in_channels: int, out_channels: int, stride: int = 1, num_stages: int = 1, **kwargs + ) -> None: + super().__init__( **kwargs) + + self.layers = [] + for i in range(num_stages): + layer = TFMobileViTInvertedResidual( + config, + in_channels=in_channels, + out_channels=out_channels, + stride=stride if i == 0 else 1, + name=f"layer.{i}" + ) + self.layers.append(layer) + in_channels = out_channels + + def call(self, 
features: tf.Tensor) -> tf.Tensor: + for layer_module in self.layers: + features = layer_module(features) + return features + + +class TFMobileViTSelfAttention(tf.keras.layers.Layer): + def __init__(self, config: MobileViTConfig, hidden_size: int, **kwargs) -> None: + super().__init__(**kwargs) + + if hidden_size % config.num_attention_heads != 0: + raise ValueError( + f"The hidden size {hidden_size,} is not a multiple of the number of attention " + f"heads {config.num_attention_heads}." + ) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + scale = tf.cast(self.attention_head_size, dtype=tf.float32) + self.scale = tf.math.sqrt(scale) + + self.query = tf.keras.layers.Dense(self.all_head_size, use_bias=config.qkv_bias, name="query") + self.key = tf.keras.layers.Dense(self.all_head_size, use_bias=config.qkv_bias, name="key") + self.value = tf.keras.layers.Dense(self.all_head_size, use_bias=config.qkv_bias, name="value") + + self.dropout = tf.keras.layers.Dropout(config.attention_probs_dropout_prob) + + def transpose_for_scores(self, x: tf.Tensor) -> tf.Tensor: + new_x_shape = shape_list(x)[:-1] + (self.num_attention_heads, self.attention_head_size) + x = tf.reshape(x, shape=new_x_shape) + return tf.transpose(x, perm=[0, 2, 1, 3]) + + def call(self, hidden_states: tf.Tensor) -> tf.Tensor: + batch_size = shape_list(hidden_states)[0] + + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + query_layer = self.transpose_for_scores(self.query(hidden_states)) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True) + attention_scores = attention_scores / self.scale + + # Normalize the attention scores to probabilities. + attention_probs = stable_softmax(attention_scores, axis=-1) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. 
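+        # Zeroing an attention probability here means the corresponding key/value position
+        # is simply ignored for that query in this head.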
+ attention_probs = self.dropout(attention_probs) + + context_layer = tf.matmul(attention_probs, value_layer) + + context_layer = tf.transpose(context_layer, perm=[0, 2, 1, 3]) + context_layer = tf.reshape(context_layer, shape=(batch_size, -1, self.all_head_size)) + return context_layer + + +class TFMobileViTSelfOutput(tf.keras.layers.Layer): + def __init__(self, config: MobileViTConfig, hidden_size: int, **kwargs) -> None: + super().__init__(**kwargs) + self.dense = tf.keras.layers.Dense(hidden_size, name="dense") + self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob) + + def call(self, hidden_states: tf.Tensor) -> tf.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + return hidden_states + + +class TFMobileViTAttention(tf.keras.layers.Layer): + def __init__(self, config: MobileViTConfig, hidden_size: int, **kwargs) -> None: + super().__init__(**kwargs) + self.attention = TFMobileViTSelfAttention(config, hidden_size, name="attention") + self.output = TFMobileViTSelfOutput(config, hidden_size, name="output") + + def call(self, hidden_states: tf.Tensor) -> tf.Tensor: + self_outputs = self.attention(hidden_states) + attention_output = self.output(self_outputs) + return attention_output + + +class TFMobileViTIntermediate(tf.keras.layers.Layer): + def __init__(self, config: MobileViTConfig, hidden_size: int, intermediate_size: int, **kwargs) -> None: + super().__init__(**kwargs) + self.dense = tf.keras.layers.Dense(intermediate_size, name="dense") + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = get_tf_activation(config.hidden_act) + else: + self.intermediate_act_fn = config.hidden_act + + def call(self, hidden_states: tf.Tensor) -> tf.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class TFMobileViTOutput(tf.keras.layers.Layer): + def __init__(self, config: MobileViTConfig, hidden_size: int, intermediate_size: int, **kwargs) -> None: + super().__init__(**kwargs) + self.dense = tf.keras.layers.Layer(hidden_size, name="dense") + self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob) + + def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor) -> tf.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = hidden_states + input_tensor + return hidden_states + + +class TFMobileViTTransformerLayer(tf.keras.layers.Layer): + def __init__(self, config: MobileViTConfig, hidden_size: int, intermediate_size: int, **kwargs) -> None: + super().__init__(**kwargs) + self.attention = TFMobileViTAttention(config, hidden_size, name="attention") + self.intermediate = TFMobileViTIntermediate(config, hidden_size, intermediate_size, name="intermediate") + self.output = TFMobileViTOutput(config, hidden_size, intermediate_size, name="output") + self.layernorm_before = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm_before") + self.layernorm_after = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm_after") + + def call(self, hidden_states: tf.Tensor) -> tf.Tensor: + attention_output = self.attention(self.layernorm_before(hidden_states)) + hidden_states = attention_output + hidden_states + + layer_output = self.layernorm_after(hidden_states) + layer_output = self.intermediate(layer_output) + layer_output = self.output(layer_output, hidden_states) + return layer_output + + +class 
TFMobileViTTransformer(tf.keras.layers.Layer): + def __init__(self, config: MobileViTConfig, hidden_size: int, num_stages: int, **kwargs) -> None: + super().__init__(**kwargs) + + self.layers = [] + for i in range(num_stages): + transformer_layer = TFMobileViTTransformerLayer( + config, + hidden_size=hidden_size, + intermediate_size=int(hidden_size * config.mlp_ratio), + name=f"layer.{i}" + ) + self.layers.append(transformer_layer) + + def call(self, hidden_states: tf.Tensor) -> tf.Tensor: + for layer_module in self.layers: + hidden_states = layer_module(hidden_states) + return hidden_states + + +class TFMobileViTLayer(tf.keras.layers.Layer): + """ + MobileViT block: https://arxiv.org/abs/2110.02178 + """ + + def __init__( + self, + config: MobileViTConfig, + in_channels: int, + out_channels: int, + stride: int, + hidden_size: int, + num_stages: int, + dilation: int = 1, + **kwargs + ) -> None: + super().__init__(**kwargs) + self.patch_width = config.patch_size + self.patch_height = config.patch_size + + if stride == 2: + self.downsampling_layer = TFMobileViTInvertedResidual( + config, + in_channels=in_channels, + out_channels=out_channels, + stride=stride if dilation == 1 else 1, + dilation=dilation // 2 if dilation > 1 else 1, + name="downsampling_layer" + ) + in_channels = out_channels + else: + self.downsampling_layer = None + + self.conv_kxk = TFMobileViTConvLayer( + config, + in_channels=in_channels, + out_channels=in_channels, + kernel_size=config.conv_kernel_size, + name="conv_kxk" + ) + + self.conv_1x1 = TFMobileViTConvLayer( + config, + in_channels=in_channels, + out_channels=hidden_size, + kernel_size=1, + use_normalization=False, + use_activation=False, + name="conv_1x1" + ) + + self.transformer = TFMobileViTTransformer( + config, + hidden_size=hidden_size, + num_stages=num_stages, + name="transformer" + ) + + self.layernorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm") + + self.conv_projection = TFMobileViTConvLayer( + config, in_channels=hidden_size, out_channels=in_channels, kernel_size=1, name="conv_projection" + ) + + self.fusion = TFMobileViTConvLayer( + config, in_channels=2 * in_channels, out_channels=in_channels, kernel_size=config.conv_kernel_size, name="fusion") + + def unfolding(self, features: tf.Tensor) -> Tuple[tf.Tensor, Dict]: + patch_width, patch_height = self.patch_width, self.patch_height + patch_area = tf.cast(patch_width * patch_height, "int32") + + batch_size, orig_height, orig_width, channels = shape_list(features) + + new_height = tf.cast(tf.math.ceil(orig_height / patch_height) * patch_height, "int32") + new_width = tf.cast(tf.math.ceil(orig_width / patch_width) * patch_width, "int32") + + interpolate = False + if new_width != orig_width or new_height != orig_height: + # Note: Padding can be done, but then it needs to be handled in attention function. 
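+            # Instead, bilinearly resize the feature map so that both spatial dimensions
+            # become exact multiples of the patch size.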
+ features = tf.image.resize( + features, size=(new_height, new_width), method="bilinear" + ) + interpolate = True + + # number of patches along width and height + num_patch_width = new_width // patch_width + num_patch_height = new_height // patch_height + num_patches = num_patch_height * num_patch_width + + # convert from shape (batch_size, orig_height, orig_width, channels) + # to the shape (batch_size * patch_area, num_patches, channels) + features = tf.transpose(features, [0, 3, 1, 2]) + patches = tf.reshape(features, ( + batch_size * channels * num_patch_height, patch_height, num_patch_width, patch_width + )) + patches = tf.transpose(patches, [0, 2, 1, 3]) + patches = tf.reshape(patches, (batch_size, channels, num_patches, patch_area)) + patches = tf.transpose(patches, [0, 3, 2, 1]) + patches = tf.reshape(patches, (batch_size * patch_area, num_patches, channels)) + + info_dict = { + "orig_size": (orig_height, orig_width), + "batch_size": batch_size, + "channels": channels, + "interpolate": interpolate, + "num_patches": num_patches, + "num_patches_width": num_patch_width, + "num_patches_height": num_patch_height, + } + return patches, info_dict + + def folding(self, patches: tf.Tensor, info_dict: Dict) -> tf.Tensor: + patch_width, patch_height = self.patch_width, self.patch_height + patch_area = int(patch_width * patch_height) + + batch_size = info_dict["batch_size"] + channels = info_dict["channels"] + num_patches = info_dict["num_patches"] + num_patch_height = info_dict["num_patches_height"] + num_patch_width = info_dict["num_patches_width"] + + # convert from shape (batch_size * patch_area, num_patches, channels) + # back to shape (batch_size, channels, orig_height, orig_width) + features = tf.reshape(patches, (batch_size, patch_area, num_patches, -1)) + features = tf.transpose(features, perm=(0, 3, 2, 1)) + features = tf.reshape(features, ( + batch_size * channels * num_patch_height, num_patch_width, patch_height, patch_width + )) + features = tf.transpose(features, perm=(0, 2, 1, 3)) + features = tf.reshape(features, ( + batch_size, channels, num_patch_height * patch_height, num_patch_width * patch_width + )) + features = tf.transpose(features, perm=(0, 2, 3, 1)) + + if info_dict["interpolate"]: + features = tf.image.resize( + features, size=info_dict["orig_size"], method="bilinear" + ) + + return features + + def call(self, features: tf.Tensor) -> tf.Tensor: + # reduce spatial dimensions if needed + if self.downsampling_layer: + features = self.downsampling_layer(features) + + residual = features + + # local representation + features = self.conv_kxk(features) + features = self.conv_1x1(features) + + # convert feature map to patches + patches, info_dict = self.unfolding(features) + + # learn global representations + patches = self.transformer(patches) + patches = self.layernorm(patches) + + # convert patches back to feature maps + features = self.folding(patches, info_dict) + + features = self.conv_projection(features) + features = self.fusion(tf.concat([residual, features], axis=-1)) + return features \ No newline at end of file From 2251838ec696d5bd5e1c82a084d894c4f1b38871 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Thu, 11 Aug 2022 18:14:59 +0530 Subject: [PATCH 02/20] add: working model till image classification. 
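
A rough usage sketch of the classification head added in this patch (the checkpoint
name is taken from the archive list; pass `from_pt=True` to `from_pretrained` if the
hub checkpoint only ships PyTorch weights):

    from PIL import Image
    import requests
    import tensorflow as tf

    from transformers import MobileViTFeatureExtractor
    from transformers.models.mobilevit.modeling_tf_mobilevit import TFMobileViTForImageClassification

    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    image = Image.open(requests.get(url, stream=True).raw)

    feature_extractor = MobileViTFeatureExtractor.from_pretrained("apple/mobilevit-small")
    model = TFMobileViTForImageClassification.from_pretrained("apple/mobilevit-small")

    # pixel_values come out in (batch_size, num_channels, height, width) format
    inputs = feature_extractor(images=image, return_tensors="tf")
    outputs = model(**inputs)
    predicted_class_idx = int(tf.math.argmax(outputs.logits, axis=-1)[0])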
--- .../models/mobilevit/modeling_tf_mobilevit.py | 489 ++++++++++++++++-- 1 file changed, 439 insertions(+), 50 deletions(-) diff --git a/src/transformers/models/mobilevit/modeling_tf_mobilevit.py b/src/transformers/models/mobilevit/modeling_tf_mobilevit.py index 69bc06e5affd1..568da5be92d7c 100644 --- a/src/transformers/models/mobilevit/modeling_tf_mobilevit.py +++ b/src/transformers/models/mobilevit/modeling_tf_mobilevit.py @@ -21,7 +21,6 @@ import tensorflow as tf - from ...activations_tf import get_tf_activation from ...file_utils import ( add_code_sample_docstrings, @@ -29,7 +28,12 @@ add_start_docstrings_to_model_forward, replace_return_docstrings, ) -from ...modeling_tf_outputs import TFBaseModelOutput, TFSemanticSegmenterOutput, TFSequenceClassifierOutput +from ...modeling_tf_outputs import ( + TFBaseModelOutput, + TFBaseModelOutputWithPooling, + TFImageClassifierOutputWithNoAttention, + TFSemanticSegmenterOutput, +) from ...modeling_tf_utils import TFPreTrainedModel, TFSequenceClassificationLoss, keras_serializable, unpack_inputs from ...tf_utils import shape_list, stable_softmax from ...utils import logging @@ -104,17 +108,14 @@ def __init__( kernel_size=kernel_size, strides=stride, padding="VALID", - dilation=dilation, + dilation_rate=dilation, groups=groups, use_bias=bias, - name="convolution" + name="convolution", ) if use_normalization: - self.normalization = tf.keras.layers.BatchNormalization( - epsilon=1e-5, momentum=0.1, - name="normalization" - ) + self.normalization = tf.keras.layers.BatchNormalization(epsilon=1e-5, momentum=0.1, name="normalization") else: self.normalization = None @@ -162,10 +163,10 @@ def __init__( in_channels=expanded_channels, out_channels=expanded_channels, kernel_size=3, - strides=stride, + stride=stride, groups=expanded_channels, dilation=dilation, - name="conv_3x3" + name="conv_3x3", ) self.reduce_1x1 = TFMobileViTConvLayer( @@ -174,7 +175,7 @@ def __init__( out_channels=out_channels, kernel_size=1, use_activation=False, - name="reduce_1x1" + name="reduce_1x1", ) def call(self, features: tf.Tensor) -> tf.Tensor: @@ -187,11 +188,17 @@ def call(self, features: tf.Tensor) -> tf.Tensor: return residual + features if self.use_residual else features -class MobileViTMobileNetLayer(tf.keras.layers.Layer): +class TFMobileViTMobileNetLayer(tf.keras.layers.Layer): def __init__( - self, config: MobileViTConfig, in_channels: int, out_channels: int, stride: int = 1, num_stages: int = 1, **kwargs + self, + config: MobileViTConfig, + in_channels: int, + out_channels: int, + stride: int = 1, + num_stages: int = 1, + **kwargs ) -> None: - super().__init__( **kwargs) + super().__init__(**kwargs) self.layers = [] for i in range(num_stages): @@ -200,7 +207,7 @@ def __init__( in_channels=in_channels, out_channels=out_channels, stride=stride if i == 0 else 1, - name=f"layer.{i}" + name=f"layer.{i}", ) self.layers.append(layer) in_channels = out_channels @@ -234,8 +241,8 @@ def __init__(self, config: MobileViTConfig, hidden_size: int, **kwargs) -> None: self.dropout = tf.keras.layers.Dropout(config.attention_probs_dropout_prob) def transpose_for_scores(self, x: tf.Tensor) -> tf.Tensor: - new_x_shape = shape_list(x)[:-1] + (self.num_attention_heads, self.attention_head_size) - x = tf.reshape(x, shape=new_x_shape) + batch_size = shape_list(x)[0] + x = tf.reshape(x, shape=(batch_size, -1, self.num_attention_heads, self.attention_head_size)) return tf.transpose(x, perm=[0, 2, 1, 3]) def call(self, hidden_states: tf.Tensor) -> tf.Tensor: @@ -279,11 +286,11 @@ class 
TFMobileViTAttention(tf.keras.layers.Layer): def __init__(self, config: MobileViTConfig, hidden_size: int, **kwargs) -> None: super().__init__(**kwargs) self.attention = TFMobileViTSelfAttention(config, hidden_size, name="attention") - self.output = TFMobileViTSelfOutput(config, hidden_size, name="output") + self.dense_output = TFMobileViTSelfOutput(config, hidden_size, name="output") def call(self, hidden_states: tf.Tensor) -> tf.Tensor: self_outputs = self.attention(hidden_states) - attention_output = self.output(self_outputs) + attention_output = self.dense_output(self_outputs) return attention_output @@ -305,7 +312,7 @@ def call(self, hidden_states: tf.Tensor) -> tf.Tensor: class TFMobileViTOutput(tf.keras.layers.Layer): def __init__(self, config: MobileViTConfig, hidden_size: int, intermediate_size: int, **kwargs) -> None: super().__init__(**kwargs) - self.dense = tf.keras.layers.Layer(hidden_size, name="dense") + self.dense = tf.keras.layers.Dense(hidden_size, name="dense") self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob) def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor) -> tf.Tensor: @@ -320,9 +327,13 @@ def __init__(self, config: MobileViTConfig, hidden_size: int, intermediate_size: super().__init__(**kwargs) self.attention = TFMobileViTAttention(config, hidden_size, name="attention") self.intermediate = TFMobileViTIntermediate(config, hidden_size, intermediate_size, name="intermediate") - self.output = TFMobileViTOutput(config, hidden_size, intermediate_size, name="output") - self.layernorm_before = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm_before") - self.layernorm_after = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm_after") + self.mobilevit_output = TFMobileViTOutput(config, hidden_size, intermediate_size, name="output") + self.layernorm_before = tf.keras.layers.LayerNormalization( + epsilon=config.layer_norm_eps, name="layernorm_before" + ) + self.layernorm_after = tf.keras.layers.LayerNormalization( + epsilon=config.layer_norm_eps, name="layernorm_after" + ) def call(self, hidden_states: tf.Tensor) -> tf.Tensor: attention_output = self.attention(self.layernorm_before(hidden_states)) @@ -330,7 +341,7 @@ def call(self, hidden_states: tf.Tensor) -> tf.Tensor: layer_output = self.layernorm_after(hidden_states) layer_output = self.intermediate(layer_output) - layer_output = self.output(layer_output, hidden_states) + layer_output = self.mobilevit_output(layer_output, hidden_states) return layer_output @@ -344,7 +355,7 @@ def __init__(self, config: MobileViTConfig, hidden_size: int, num_stages: int, * config, hidden_size=hidden_size, intermediate_size=int(hidden_size * config.mlp_ratio), - name=f"layer.{i}" + name=f"layer.{i}", ) self.layers.append(transformer_layer) @@ -381,7 +392,7 @@ def __init__( out_channels=out_channels, stride=stride if dilation == 1 else 1, dilation=dilation // 2 if dilation > 1 else 1, - name="downsampling_layer" + name="downsampling_layer", ) in_channels = out_channels else: @@ -392,7 +403,7 @@ def __init__( in_channels=in_channels, out_channels=in_channels, kernel_size=config.conv_kernel_size, - name="conv_kxk" + name="conv_kxk", ) self.conv_1x1 = TFMobileViTConvLayer( @@ -402,14 +413,11 @@ def __init__( kernel_size=1, use_normalization=False, use_activation=False, - name="conv_1x1" + name="conv_1x1", ) self.transformer = TFMobileViTTransformer( - config, - hidden_size=hidden_size, - num_stages=num_stages, - name="transformer" + config, 
hidden_size=hidden_size, num_stages=num_stages, name="transformer" ) self.layernorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm") @@ -419,7 +427,12 @@ def __init__( ) self.fusion = TFMobileViTConvLayer( - config, in_channels=2 * in_channels, out_channels=in_channels, kernel_size=config.conv_kernel_size, name="fusion") + config, + in_channels=2 * in_channels, + out_channels=in_channels, + kernel_size=config.conv_kernel_size, + name="fusion", + ) def unfolding(self, features: tf.Tensor) -> Tuple[tf.Tensor, Dict]: patch_width, patch_height = self.patch_width, self.patch_height @@ -433,9 +446,7 @@ def unfolding(self, features: tf.Tensor) -> Tuple[tf.Tensor, Dict]: interpolate = False if new_width != orig_width or new_height != orig_height: # Note: Padding can be done, but then it needs to be handled in attention function. - features = tf.image.resize( - features, size=(new_height, new_width), method="bilinear" - ) + features = tf.image.resize(features, size=(new_height, new_width), method="bilinear") interpolate = True # number of patches along width and height @@ -446,9 +457,9 @@ def unfolding(self, features: tf.Tensor) -> Tuple[tf.Tensor, Dict]: # convert from shape (batch_size, orig_height, orig_width, channels) # to the shape (batch_size * patch_area, num_patches, channels) features = tf.transpose(features, [0, 3, 1, 2]) - patches = tf.reshape(features, ( - batch_size * channels * num_patch_height, patch_height, num_patch_width, patch_width - )) + patches = tf.reshape( + features, (batch_size * channels * num_patch_height, patch_height, num_patch_width, patch_width) + ) patches = tf.transpose(patches, [0, 2, 1, 3]) patches = tf.reshape(patches, (batch_size, channels, num_patches, patch_area)) patches = tf.transpose(patches, [0, 3, 2, 1]) @@ -479,19 +490,17 @@ def folding(self, patches: tf.Tensor, info_dict: Dict) -> tf.Tensor: # back to shape (batch_size, channels, orig_height, orig_width) features = tf.reshape(patches, (batch_size, patch_area, num_patches, -1)) features = tf.transpose(features, perm=(0, 3, 2, 1)) - features = tf.reshape(features, ( - batch_size * channels * num_patch_height, num_patch_width, patch_height, patch_width - )) - features = tf.transpose(features, perm=(0, 2, 1, 3)) - features = tf.reshape(features, ( - batch_size, channels, num_patch_height * patch_height, num_patch_width * patch_width - )) + features = tf.reshape( + features, (batch_size * channels * num_patch_height, num_patch_width, patch_height, patch_width) + ) + features = tf.transpose(features, perm=(0, 2, 1, 3)) + features = tf.reshape( + features, (batch_size, channels, num_patch_height * patch_height, num_patch_width * patch_width) + ) features = tf.transpose(features, perm=(0, 2, 3, 1)) if info_dict["interpolate"]: - features = tf.image.resize( - features, size=info_dict["orig_size"], method="bilinear" - ) + features = tf.image.resize(features, size=info_dict["orig_size"], method="bilinear") return features @@ -518,4 +527,384 @@ def call(self, features: tf.Tensor) -> tf.Tensor: features = self.conv_projection(features) features = self.fusion(tf.concat([residual, features], axis=-1)) - return features \ No newline at end of file + return features + + +class TFMobileViTEncoder(tf.keras.layers.Layer): + def __init__(self, config: MobileViTConfig, **kwargs) -> None: + super().__init__(**kwargs) + self.config = config + + self.layers = [] + + # segmentation architectures like DeepLab and PSPNet modify the strides + # of the classification backbones + dilate_layer_4 
= dilate_layer_5 = False + if config.output_stride == 8: + dilate_layer_4 = True + dilate_layer_5 = True + elif config.output_stride == 16: + dilate_layer_5 = True + + dilation = 1 + + layer_1 = TFMobileViTMobileNetLayer( + config, + in_channels=config.neck_hidden_sizes[0], + out_channels=config.neck_hidden_sizes[1], + stride=1, + num_stages=1, + name="layer.0", + ) + self.layers.append(layer_1) + + layer_2 = TFMobileViTMobileNetLayer( + config, + in_channels=config.neck_hidden_sizes[1], + out_channels=config.neck_hidden_sizes[2], + stride=2, + num_stages=3, + name="layer.1", + ) + self.layers.append(layer_2) + + layer_3 = TFMobileViTLayer( + config, + in_channels=config.neck_hidden_sizes[2], + out_channels=config.neck_hidden_sizes[3], + stride=2, + hidden_size=config.hidden_sizes[0], + num_stages=2, + name="layer.2", + ) + self.layers.append(layer_3) + + if dilate_layer_4: + dilation *= 2 + + layer_4 = TFMobileViTLayer( + config, + in_channels=config.neck_hidden_sizes[3], + out_channels=config.neck_hidden_sizes[4], + stride=2, + hidden_size=config.hidden_sizes[1], + num_stages=4, + dilation=dilation, + name="layer.3", + ) + self.layers.append(layer_4) + + if dilate_layer_5: + dilation *= 2 + + layer_5 = TFMobileViTLayer( + config, + in_channels=config.neck_hidden_sizes[4], + out_channels=config.neck_hidden_sizes[5], + stride=2, + hidden_size=config.hidden_sizes[2], + num_stages=3, + dilation=dilation, + name="layer.4", + ) + self.layers.append(layer_5) + + def call( + self, + hidden_states: tf.Tensor, + output_hidden_states: bool = False, + return_dict: bool = True, + ) -> Union[tuple, TFBaseModelOutput]: + all_hidden_states = () if output_hidden_states else None + + for i, layer_module in enumerate(self.layers): + hidden_states = layer_module(hidden_states) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states] if v is not None) + + return TFBaseModelOutput(last_hidden_state=hidden_states, hidden_states=all_hidden_states) + + +@keras_serializable +class TFMobileViTMainLayer(tf.keras.layers.Layer): + config_class = MobileViTConfig + + def __init__(self, config: MobileViTConfig, expand_output: bool = True, **kwargs): + super().__init__(**kwargs) + self.config = config + self.expand_output = expand_output + + self.conv_stem = TFMobileViTConvLayer( + config, + in_channels=config.num_channels, + out_channels=config.neck_hidden_sizes[0], + kernel_size=3, + stride=2, + name="conv_stem", + ) + + self.encoder = TFMobileViTEncoder(config, name="encoder") + + if self.expand_output: + self.conv_1x1_exp = TFMobileViTConvLayer( + config, + in_channels=config.neck_hidden_sizes[5], + out_channels=config.neck_hidden_sizes[6], + kernel_size=1, + name="conv_1x1_exp", + ) + + self.pooler = tf.keras.layers.GlobalAveragePooling2D(data_format="channels_first", name="pooler") + + @unpack_inputs + def call( + self, + pixel_values: Optional[tf.Tensor] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[tf.Tensor], TFBaseModelOutputWithPooling]: + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # When running on CPU, `tf.keras.layers.Conv2D` doesn't support `NCHW` format. + # So change the input format from `NCHW` to `NHWC`. 
+ # shape = (batch_size, in_height, in_width, in_channels=num_channels) + pixel_values = tf.transpose(pixel_values, perm=(0, 2, 3, 1)) + + embedding_output = self.conv_stem(pixel_values) + + encoder_outputs = self.encoder( + embedding_output, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + if self.expand_output: + last_hidden_state = self.conv_1x1_exp(encoder_outputs[0]) + + # Change to NCHW output format to have uniformity in the modules + last_hidden_state = tf.transpose(last_hidden_state, perm=[0, 3, 1, 2]) + + # global average pooling: (batch_size, channels, height, width) -> (batch_size, channels) + pooled_output = self.pooler(last_hidden_state) + else: + last_hidden_state = encoder_outputs[0] + # Change to NCHW output format to have uniformity in the modules + last_hidden_state = tf.transpose(last_hidden_state, perm=[0, 3, 1, 2]) + pooled_output = None + + if not return_dict: + output = (last_hidden_state, pooled_output) if pooled_output is not None else (last_hidden_state,) + return output + encoder_outputs[1:] + + # Change the other hidden state outputs to NCHW as well + if output_hidden_states: + hidden_states = tuple([tf.transpose(h, perm=(0, 3, 1, 2)) for h in encoder_outputs[1]]) + + return TFBaseModelOutputWithPooling( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + hidden_states=hidden_states if output_hidden_states else encoder_outputs.hidden_states, + ) + + +class TFMobileViTPreTrainedModel(TFPreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = MobileViTConfig + base_model_prefix = "mobilevit" + main_input_name = "pixel_values" + + @property + def dummy_inputs(self) -> Dict[str, tf.Tensor]: + """ + Dummy inputs to build the network. + + Returns: + `Dict[str, tf.Tensor]`: The dummy inputs. + """ + VISION_DUMMY_INPUTS = tf.random.uniform( + shape=( + 3, + self.config.num_channels, + self.config.image_size, + self.config.image_size, + ), + dtype=tf.float32, + ) + return {"pixel_values": tf.constant(VISION_DUMMY_INPUTS)} + + @tf.function( + input_signature=[ + { + "pixel_values": tf.TensorSpec((None, None, None, None), tf.float32, name="pixel_values"), + } + ] + ) + def serving(self, inputs): + """ + Method used for serving the model. + + Args: + inputs (`Dict[str, tf.Tensor]`): + The input of the saved model as a dictionary of tensors. + """ + output = self.call(inputs) + return self.serving_output(output) + + +MOBILEVIT_START_DOCSTRING = r""" + This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a [tf.keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it + as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and + behavior. + + + + TF 2.0 models accepts two formats as inputs: + + - having all inputs as keyword arguments (like PyTorch models), or + - having all inputs as a list, tuple or dict in the first positional arguments. + + This second option is useful when using [`tf.keras.Model.fit`] method which currently requires having all the + tensors in the first argument of the model call function: `model(inputs)`. 
+ + + + Parameters: + config ([`MobileViTConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights. +""" + +MOBILEVIT_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`np.ndarray`, `tf.Tensor`, `List[tf.Tensor]`, `Dict[str, tf.Tensor]` or `Dict[str, np.ndarray]` and each example must have the shape `(batch_size, num_channels, height, width)`): + Pixel values. Pixel values can be obtained using [`MobileViTFeatureExtractor`]. See + [`MobileViTFeatureExtractor.__call__`] for details. + + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. This argument can be used only in eager mode, in graph mode the value in the config will be + used instead. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in + eager mode, in graph mode the value will always be set to True. +""" + + +@add_start_docstrings( + "The bare MobileViT model outputting raw hidden-states without any specific head on top.", + MOBILEVIT_START_DOCSTRING, +) +class TFMobileViTModel(TFMobileViTPreTrainedModel): + def __init__(self, config: MobileViTConfig, expand_output: bool = True, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + self.config = config + self.expand_output = expand_output + + self.mobilevit = TFMobileViTMainLayer(config, expand_output=expand_output, name="mobilevit") + + @unpack_inputs + @add_start_docstrings_to_model_forward(MOBILEVIT_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + processor_class=_FEAT_EXTRACTOR_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TFBaseModelOutputWithPooling, + config_class=_CONFIG_FOR_DOC, + modality="vision", + expected_output=_EXPECTED_OUTPUT_SHAPE, + ) + def call( + self, + pixel_values: Optional[tf.Tensor] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[tf.Tensor], TFBaseModelOutputWithPooling]: + + output = self.mobilevit(pixel_values, output_hidden_states, return_dict) + return output + + def serving_output(self, output: TFBaseModelOutputWithPooling) -> TFBaseModelOutputWithPooling: + # hidden_states not converted to Tensor with tf.convert_to_tensor as they are all of different dimensions + return TFBaseModelOutputWithPooling( + last_hidden_state=output.last_hidden_state, + pooler_output=output.pooler_output, + hidden_states=output.hidden_states, + ) + + +@add_start_docstrings( + """ + MobileViT model with an image classification head on top (a linear layer on top of the pooled features), e.g. for + ImageNet. 
+ """, + MOBILEVIT_START_DOCSTRING, +) +class TFMobileViTForImageClassification(TFMobileViTPreTrainedModel, TFSequenceClassificationLoss): + def __init__(self, config: MobileViTConfig, *inputs, **kwargs) -> None: + super().__init__(config, *inputs, **kwargs) + + self.num_labels = config.num_labels + self.mobilevit = TFMobileViTMainLayer(config, name="mobilevit") + + # Classifier head + self.dropout = tf.keras.layers.Dropout(config.classifier_dropout_prob) + self.classifier = ( + tf.keras.layers.Dense(config.num_labels, name="classifier") if config.num_labels > 0 else tf.identity + ) + + @unpack_inputs + @add_start_docstrings_to_model_forward(MOBILEVIT_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + processor_class=_FEAT_EXTRACTOR_FOR_DOC, + checkpoint=_IMAGE_CLASS_CHECKPOINT, + output_type=TFImageClassifierOutputWithNoAttention, + config_class=_CONFIG_FOR_DOC, + expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT, + ) + def call( + self, + pixel_values: Optional[tf.Tensor] = None, + output_hidden_states: Optional[bool] = None, + labels: Optional[tf.Tensor] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, TFImageClassifierOutputWithNoAttention]: + r""" + labels (`tf.Tensor` of shape `(batch_size,)`, *optional*): + Labels for computing the image classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss). If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.mobilevit(pixel_values, output_hidden_states=output_hidden_states, return_dict=return_dict) + + pooled_output = outputs.pooler_output if return_dict else outputs[1] + + logits = self.classifier(self.dropout(pooled_output)) + + logits = self.classifier(pooled_output) + loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TFImageClassifierOutputWithNoAttention( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + ) From db3ac6dbb676ce8de284fe3713c75385fd1b81eb Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Wed, 17 Aug 2022 09:21:18 +0530 Subject: [PATCH 03/20] add: initial implementation that passes intg tests. Co-authored-by: Amy --- .../models/mobilevit/modeling_tf_mobilevit.py | 247 +++++++++++++++++- 1 file changed, 246 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/mobilevit/modeling_tf_mobilevit.py b/src/transformers/models/mobilevit/modeling_tf_mobilevit.py index 568da5be92d7c..f200a979475ae 100644 --- a/src/transformers/models/mobilevit/modeling_tf_mobilevit.py +++ b/src/transformers/models/mobilevit/modeling_tf_mobilevit.py @@ -734,7 +734,7 @@ def dummy_inputs(self) -> Dict[str, tf.Tensor]: """ VISION_DUMMY_INPUTS = tf.random.uniform( shape=( - 3, + 1, # TODO: change to 3 later (sayakpaul). 
self.config.num_channels, self.config.image_size, self.config.image_size, @@ -908,3 +908,248 @@ def call( logits=logits, hidden_states=outputs.hidden_states, ) + + +class TFMobileViTASPPPooling(tf.keras.layers.Layer): + def __init__(self, config: MobileViTConfig, in_channels: int, out_channels: int, **kwargs) -> None: + super().__init__(**kwargs) + + self.global_pool = tf.keras.layers.GlobalAveragePooling2D(keepdims=True, name="global_pool") + + self.conv_1x1 = TFMobileViTConvLayer( + config, + in_channels=in_channels, + out_channels=out_channels, + kernel_size=1, + stride=1, + use_normalization=True, + use_activation="relu", + name="conv_1x1", + ) + + def call(self, features: tf.Tensor) -> tf.Tensor: + spatial_size = shape_list(features)[1:-1] + features = self.global_pool(features) + features = self.conv_1x1(features) + features = tf.image.resize(features, size=spatial_size, method="bilinear") + return features + + +class TFMobileViTASPP(tf.keras.layers.Layer): + """ + ASPP module defined in DeepLab papers: https://arxiv.org/abs/1606.00915, https://arxiv.org/abs/1706.05587 + """ + + def __init__(self, config: MobileViTConfig, **kwargs) -> None: + super().__init__(**kwargs) + + in_channels = config.neck_hidden_sizes[-2] + out_channels = config.aspp_out_channels + + if len(config.atrous_rates) != 3: + raise ValueError("Expected 3 values for atrous_rates") + + self.convs = [] + + in_projection = TFMobileViTConvLayer( + config, + in_channels=in_channels, + out_channels=out_channels, + kernel_size=1, + use_activation="relu", + name="convs.0", + ) + self.convs.append(in_projection) + + self.convs.extend( + [ + TFMobileViTConvLayer( + config, + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + dilation=rate, + use_activation="relu", + name=f"convs.{i + 1}", + ) + for i, rate in enumerate(config.atrous_rates) + ] + ) + + pool_layer = TFMobileViTASPPPooling( + config, in_channels, out_channels, name=f"convs.{len(config.atrous_rates) + 1}" + ) + self.convs.append(pool_layer) + + self.project = TFMobileViTConvLayer( + config, + in_channels=5 * out_channels, + out_channels=out_channels, + kernel_size=1, + use_activation="relu", + name="project", + ) + + self.dropout = tf.keras.layers.Dropout(config.aspp_dropout_prob) + + def call(self, features: tf.Tensor) -> tf.Tensor: + # since the hidden states were transposed to have `(batch_size, channels, height, width)` + # layout. 
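+        # Move the channel axis back to the last position (NHWC) before running the
+        # Keras convolution branches below.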
+ features = tf.transpose(features, perm=[0, 2, 3, 1]) + pyramid = [] + for conv in self.convs: + # print(f"From TFMobileViTASPP: {conv(features).shape}") + pyramid.append(conv(features)) + pyramid = tf.concat(pyramid, axis=-1) + + # print(f"From TFMobileViTASPP first convolution: {self.convs[0].convolution.kernel.shape}") + + pooled_features = self.project(pyramid) + pooled_features = self.dropout(pooled_features) + return pooled_features + + +class TFMobileViTDeepLabV3(tf.keras.layers.Layer): + """ + DeepLabv3 architecture: https://arxiv.org/abs/1706.05587 + """ + + def __init__(self, config: MobileViTConfig, **kwargs) -> None: + super().__init__(**kwargs) + self.aspp = TFMobileViTASPP(config, name="aspp") + + self.dropout = tf.keras.layers.Dropout(config.classifier_dropout_prob) + + self.classifier = TFMobileViTConvLayer( + config, + in_channels=config.aspp_out_channels, + out_channels=config.num_labels, + kernel_size=1, + use_normalization=False, + use_activation=False, + bias=True, + name="classifier", + ) + + def call(self, hidden_states: tf.Tensor) -> tf.Tensor: + features = self.aspp(hidden_states[-1]) + # print(f"From MobileViTDeepLabV3: {features.shape}") + features = self.dropout(features) + features = self.classifier(features) + return features + + +@add_start_docstrings( + """ + MobileViT model with a semantic segmentation head on top, e.g. for Pascal VOC. + """, + MOBILEVIT_START_DOCSTRING, +) +class TFMobileViTForSemanticSegmentation(TFMobileViTPreTrainedModel): + def __init__(self, config: MobileViTConfig, **kwargs) -> None: + super().__init__(config, **kwargs) + + self.num_labels = config.num_labels + self.mobilevit = TFMobileViTMainLayer(config, expand_output=False, name="mobilevit") + self.segmentation_head = TFMobileViTDeepLabV3(config, name="segmentation_head") + + def hf_compute_loss(self, logits, labels): + # upsample logits to the images' original size + # `labels` is of shape (batch_size, height, width) + label_interp_shape = shape_list(labels)[1:] + + upsampled_logits = tf.image.resize(logits, size=label_interp_shape, method="bilinear") + # compute weighted loss + loss_fct = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction="none") + + def masked_loss(real, pred): + unmasked_loss = loss_fct(real, pred) + mask = tf.cast(real != self.config.semantic_loss_ignore_index, dtype=unmasked_loss.dtype) + masked_loss = unmasked_loss * mask + # Reduction strategy in the similar spirit with + # https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_tf_utils.py#L210 + reduced_masked_loss = tf.reduce_sum(masked_loss) / tf.reduce_sum(mask) + return tf.reshape(reduced_masked_loss, (1,)) + + return masked_loss(labels, upsampled_logits) + + @unpack_inputs + @add_start_docstrings_to_model_forward(MOBILEVIT_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=TFSemanticSegmenterOutput, config_class=_CONFIG_FOR_DOC) + def call( + self, + pixel_values: Optional[tf.Tensor] = None, + labels: Optional[tf.Tensor] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, TFSemanticSegmenterOutput]: + r""" + labels (`tf.Tensor` of shape `(batch_size, height, width)`, *optional*): + Ground truth semantic segmentation maps for computing the loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels > 1`, a classification loss is computed (Cross-Entropy). 
+ + Returns: + + Examples: + + ```python + >>> from transformers import MobileViTFeatureExtractor, TFMobileViTForSemanticSegmentation + >>> from PIL import Image + >>> import requests + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> feature_extractor = MobileViTFeatureExtractor.from_pretrained("apple/deeplabv3-mobilevit-small") + >>> model = TFMobileViTForSemanticSegmentation.from_pretrained("apple/deeplabv3-mobilevit-small") + + >>> inputs = feature_extractor(images=image, return_tensors="tf") + + >>> outputs = model(**inputs) + + >>> # logits are of shape (batch_size, num_labels, height, width) + >>> logits = outputs.logits + ```""" + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.mobilevit( + pixel_values, + output_hidden_states=True, # we need the intermediate hidden states + return_dict=return_dict, + ) + + encoder_hidden_states = outputs.hidden_states if return_dict else outputs[1] + + logits = self.segmentation_head(encoder_hidden_states) + + loss = None + if labels is not None: + if not self.config.num_labels > 1: + raise ValueError("The number of labels should be greater than one") + else: + loss = self.hf_compute_loss(logits=logits, labels=labels) + + # make logits of shape (batch_size, num_labels, height, width) to + # keep them consistent across APIs + logits = tf.transpose(logits, perm=[0, 3, 1, 2]) + + if not return_dict: + if output_hidden_states: + output = (logits,) + outputs[1:] + else: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TFSemanticSegmenterOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states if output_hidden_states else None, + attentions=None, + ) + + def serving_output(self, output: TFSemanticSegmenterOutput) -> TFSemanticSegmenterOutput: + # hidden_states and attention not converted to Tensor with tf.convert_to_tensor as they are all of different dimensions + return TFSemanticSegmenterOutput(logits=output.logits, hidden_states=output.hidden_states, attentions=None) From f08fdaa0a22806466e49c063274afc9f7fb5725a Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Wed, 17 Aug 2022 09:22:48 +0530 Subject: [PATCH 04/20] chore: formatting. --- src/transformers/models/mobilevit/modeling_tf_mobilevit.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/transformers/models/mobilevit/modeling_tf_mobilevit.py b/src/transformers/models/mobilevit/modeling_tf_mobilevit.py index f200a979475ae..713c9217af01a 100644 --- a/src/transformers/models/mobilevit/modeling_tf_mobilevit.py +++ b/src/transformers/models/mobilevit/modeling_tf_mobilevit.py @@ -16,7 +16,6 @@ # Original license: https://github.com/apple/ml-cvnets/blob/main/LICENSE """ TensorFlow 2.0 MobileViT model.""" -import math from typing import Dict, Optional, Tuple, Union import tensorflow as tf @@ -734,7 +733,7 @@ def dummy_inputs(self) -> Dict[str, tf.Tensor]: """ VISION_DUMMY_INPUTS = tf.random.uniform( shape=( - 1, # TODO: change to 3 later (sayakpaul). + 1, # TODO: change to 3 later (sayakpaul). 
self.config.num_channels, self.config.image_size, self.config.image_size, @@ -993,7 +992,7 @@ def __init__(self, config: MobileViTConfig, **kwargs) -> None: self.dropout = tf.keras.layers.Dropout(config.aspp_dropout_prob) def call(self, features: tf.Tensor) -> tf.Tensor: - # since the hidden states were transposed to have `(batch_size, channels, height, width)` + # since the hidden states were transposed to have `(batch_size, channels, height, width)` # layout. features = tf.transpose(features, perm=[0, 2, 3, 1]) pyramid = [] From 8569374c525e4f521e70338f5893cb2ffbc6a88a Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Wed, 17 Aug 2022 11:19:02 +0530 Subject: [PATCH 05/20] add: tests (still breaking because of config mismatch). Coo-authored-by: Yih <2521628+ydshieh@users.noreply.github.com> --- src/transformers/modeling_tf_utils.py | 2 + .../models/mobilevit/modeling_tf_mobilevit.py | 2 +- .../mobilevit/test_modeling_tf_mobilevit.py | 366 ++++++++++++++++++ 3 files changed, 369 insertions(+), 1 deletion(-) create mode 100644 tests/models/mobilevit/test_modeling_tf_mobilevit.py diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py index 68ee4117a2f9d..f66b61bdf30f9 100644 --- a/src/transformers/modeling_tf_utils.py +++ b/src/transformers/modeling_tf_utils.py @@ -1017,6 +1017,8 @@ def framework(self) -> str: def __init__(self, config, *inputs, **kwargs): super().__init__(*inputs, **kwargs) + print(f"From modeling_tf_utils: {config}.") + print(f"From modeling_tf_utils: is it PretrainedConfig instance: {isinstance(config, PretrainedConfig)}") if not isinstance(config, PretrainedConfig): raise ValueError( f"Parameter config in `{self.__class__.__name__}(config)` should be an instance of class " diff --git a/src/transformers/models/mobilevit/modeling_tf_mobilevit.py b/src/transformers/models/mobilevit/modeling_tf_mobilevit.py index 713c9217af01a..d5806ce5c751e 100644 --- a/src/transformers/models/mobilevit/modeling_tf_mobilevit.py +++ b/src/transformers/models/mobilevit/modeling_tf_mobilevit.py @@ -733,7 +733,7 @@ def dummy_inputs(self) -> Dict[str, tf.Tensor]: """ VISION_DUMMY_INPUTS = tf.random.uniform( shape=( - 1, # TODO: change to 3 later (sayakpaul). + 3, self.config.num_channels, self.config.image_size, self.config.image_size, diff --git a/tests/models/mobilevit/test_modeling_tf_mobilevit.py b/tests/models/mobilevit/test_modeling_tf_mobilevit.py new file mode 100644 index 0000000000000..6347407bee4df --- /dev/null +++ b/tests/models/mobilevit/test_modeling_tf_mobilevit.py @@ -0,0 +1,366 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Testing suite for the TensorFlow MobileViT model. 
""" + + +import inspect +import unittest + +from transformers import MobileViTConfig +from transformers.configuration_utils import PretrainedConfig +from transformers.file_utils import is_tf_available, is_vision_available +from transformers.testing_utils import require_tf, slow + +from ...test_configuration_common import ConfigTester +from ...test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor + + +if is_tf_available(): + import numpy as np + import tensorflow as tf + + from src.transformers.models.mobilevit.modeling_tf_mobilevit import TFMobileViTForImageClassification, TFMobileViTForSemanticSegmentation, TFMobileViTModel + from src.transformers.models.mobilevit.modeling_tf_mobilevit import TF_MOBILEVIT_PRETRAINED_MODEL_ARCHIVE_LIST + + +if is_vision_available(): + from PIL import Image + + from transformers import MobileViTFeatureExtractor + + +class TFMobileViTConfigTester(ConfigTester): + def create_and_test_config_common_properties(self): + config = self.config_class(**self.inputs_dict) + self.parent.assertTrue(hasattr(config, "hidden_sizes")) + self.parent.assertTrue(hasattr(config, "neck_hidden_sizes")) + self.parent.assertTrue(hasattr(config, "num_attention_heads")) + + +class TFMobileViTModelTester: + def __init__( + self, + parent, + batch_size=13, + image_size=32, + patch_size=2, + num_channels=3, + last_hidden_size=640, + num_attention_heads=4, + hidden_act="silu", + conv_kernel_size=3, + output_stride=32, + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + classifier_dropout_prob=0.1, + initializer_range=0.02, + is_training=True, + use_labels=True, + num_labels=10, + scope=None, + ): + self.parent = parent + self.batch_size = batch_size + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.last_hidden_size = last_hidden_size + self.num_attention_heads = num_attention_heads + self.hidden_act = hidden_act + self.conv_kernel_size = conv_kernel_size + self.output_stride = output_stride + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.classifier_dropout_prob = classifier_dropout_prob + self.use_labels = use_labels + self.is_training = is_training + self.num_labels = num_labels + self.initializer_range = initializer_range + self.scope = scope + + def prepare_config_and_inputs(self): + pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size]) + + labels = None + pixel_labels = None + if self.use_labels: + labels = ids_tensor([self.batch_size], self.num_labels) + pixel_labels = ids_tensor([self.batch_size, self.image_size, self.image_size], self.num_labels) + + config = self.get_config() + + return config, pixel_values, labels, pixel_labels + + def get_config(self): + config = MobileViTConfig( + image_size=self.image_size, + patch_size=self.patch_size, + num_channels=self.num_channels, + num_attention_heads=self.num_attention_heads, + hidden_act=self.hidden_act, + conv_kernel_size=self.conv_kernel_size, + output_stride=self.output_stride, + hidden_dropout_prob=self.hidden_dropout_prob, + attention_probs_dropout_prob=self.attention_probs_dropout_prob, + classifier_dropout_prob=self.classifier_dropout_prob, + initializer_range=self.initializer_range, + ) + print(f"From test class: is it PretrainedConfig instance: {isinstance(config, PretrainedConfig)}") + print(f"From test class: {config}.") + return config + + def create_and_check_model(self, config, pixel_values, labels, 
pixel_labels): + model = TFMobileViTModel(config=config) + result = model(pixel_values, training=False) + expected_height = expected_width = self.image_size // self.output_stride + self.parent.assertEqual( + result.last_hidden_state.shape, + ( + self.batch_size, + self.last_hidden_size, + expected_height, + expected_width + ) + ) + + def create_and_check_for_image_classification(self, config, pixel_values, labels, pixel_labels): + config.num_labels = self.num_labels + model = TFMobileViTForImageClassification(config) + result = model(pixel_values, labels=labels, training=False) + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels)) + + def create_and_check_for_semantic_segmentation(self, config, pixel_values, labels, pixel_labels): + config.num_labels = self.num_labels + model = TFMobileViTForSemanticSegmentation(config) + expected_height = expected_width = self.image_size // self.output_stride + + result = model(pixel_values, training=False) + self.parent.assertEqual( + result.logits.shape, + ( + self.batch_size, + self.num_labels, + expected_height, + expected_width + ) + ) + + result = model(pixel_values, labels=pixel_labels, training=False) + self.parent.assertEqual( + result.logits.shape, + ( + self.batch_size, + self.num_labels, + expected_height, + expected_width + ) + ) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + config, pixel_values, labels, pixel_labels = config_and_inputs + inputs_dict = {"pixel_values": pixel_values} + return config, inputs_dict + + +@require_tf +class MobileViTModelTest(TFModelTesterMixin, unittest.TestCase): + """ + Here we also overwrite some of the tests of test_modeling_common.py, as MobileViT does not use input_ids, inputs_embeds, + attention_mask and seq_length. 
+ """ + + all_model_classes = ( + (TFMobileViTModel, TFMobileViTForImageClassification, TFMobileViTForSemanticSegmentation) + if is_tf_available() + else () + ) + + test_pruning = False + test_resize_embeddings = False + test_head_masking = False + has_attentions = False + test_onnx = False + + def setUp(self): + self.model_tester = TFMobileViTModelTester(self) + self.config_tester = TFMobileViTConfigTester(self, config_class=MobileViTConfig, has_text_modality=False) + + def test_config(self): + self.config_tester.run_common_tests() + + @unittest.skip(reason="MobileViT does not use inputs_embeds") + def test_inputs_embeds(self): + pass + + @unittest.skip(reason="MobileViT does not support input and output embeddings") + def test_model_common_attributes(self): + pass + + @unittest.skip(reason="MobileViT does not output attentions") + def test_attention_outputs(self): + pass + + @unittest.skip("Test was written for TF 1.x and isn't really relevant here") + def test_compile_tf_model(self): + pass + + def test_forward_signature(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + signature = inspect.signature(model.forward) + # signature.parameters is an OrderedDict => so arg_names order is deterministic + arg_names = [*signature.parameters.keys()] + + expected_arg_names = ["pixel_values"] + self.assertListEqual(arg_names[:1], expected_arg_names) + + def test_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + print(f"From test_model: {isinstance(config_and_inputs[0], PretrainedConfig)}") + self.model_tester.create_and_check_model(*config_and_inputs) + + def test_hidden_states_output(self): + def check_hidden_states_output(inputs_dict, config, model_class): + model = model_class(config) + + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + + hidden_states = outputs.hidden_states + + expected_num_stages = 5 + self.assertEqual(len(hidden_states), expected_num_stages) + + # MobileViT's feature maps are of shape (batch_size, num_channels, height, width) + # with the width and height being successively divided by 2. 
+ divisor = 2 + for i in range(len(hidden_states)): + self.assertListEqual( + list(hidden_states[i].shape[-2:]), + [self.model_tester.image_size // divisor, self.model_tester.image_size // divisor], + ) + divisor *= 2 + + self.assertEqual(self.model_tester.output_stride, divisor // 2) + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + inputs_dict["output_hidden_states"] = True + check_hidden_states_output(inputs_dict, config, model_class) + + # check that output_hidden_states also work using config + del inputs_dict["output_hidden_states"] + config.output_hidden_states = True + + check_hidden_states_output(inputs_dict, config, model_class) + + def test_for_image_classification(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_for_image_classification(*config_and_inputs) + + def test_for_semantic_segmentation(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_for_semantic_segmentation(*config_and_inputs) + + @unittest.skipIf( + not is_tf_available() or len(tf.config.list_physical_devices("GPU")) == 0, + reason="TF (<=2.8) does not support backprop for grouped convolutions on CPU.", + ) + def test_dataset_conversion(self): + super().test_dataset_conversion() + + def check_keras_fit_results(self, val_loss1, val_loss2, atol=2e-1, rtol=2e-1): + self.assertTrue(np.allclose(val_loss1, val_loss2, atol=atol, rtol=rtol)) + + @unittest.skipIf( + not is_tf_available() or len(tf.config.list_physical_devices("GPU")) == 0, + reason="TF (<=2.8) does not support backprop for grouped convolutions on CPU.", + ) + def test_keras_fit(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + # Since `TFMobileViTModel` cannot operate with the default `fit()` method. + if model_class.__name__ != "TFMobileViTModel": + model = model_class(config) + if getattr(model, "hf_compute_loss", None): + super().test_keras_fit() + + @slow + def test_model_from_pretrained(self): + for model_name in TF_MOBILEVIT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: + # `from_pt` will be removed. 
+ model = TFMobileViTModel.from_pretrained(model_name, from_pt=True)
+ self.assertIsNotNone(model)
+
+
+# We will verify our results on an image of cute cats
+def prepare_img():
+ image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
+ return image
+
+
+@require_tf
+class TFMobileViTModelIntegrationTest(unittest.TestCase):
+ @slow
+ def test_inference_image_classification_head(self):
+ # `from_pt` will be removed
+ model = TFMobileViTForImageClassification.from_pretrained("apple/mobilevit-xx-small", from_pt=True)
+
+ feature_extractor = MobileViTFeatureExtractor.from_pretrained("apple/mobilevit-xx-small")
+ image = prepare_img()
+ inputs = feature_extractor(images=image, return_tensors="tf")
+
+ # forward pass
+ outputs = model(**inputs, training=False)
+
+ # verify the logits
+ expected_shape = tf.TensorShape((1, 1000))
+ self.assertEqual(outputs.logits.shape, expected_shape)
+
+ expected_slice = tf.constant([-1.9364, -1.2327, -0.4653])
+
+ tf.debugging.assert_near(outputs.logits[0, :3], expected_slice, atol=1e-4, rtol=1e-04)
+
+ @slow
+ def test_inference_semantic_segmentation(self):
+ # `from_pt` will be removed
+ model = TFMobileViTForSemanticSegmentation.from_pretrained("apple/deeplabv3-mobilevit-xx-small", from_pt=True)
+
+ feature_extractor = MobileViTFeatureExtractor.from_pretrained("apple/deeplabv3-mobilevit-xx-small")
+
+ image = prepare_img()
+ inputs = feature_extractor(images=image, return_tensors="tf")
+
+ # forward pass
+ outputs = model(inputs.pixel_values, training=False)
+ logits = outputs.logits
+
+ # verify the logits
+ expected_shape = tf.TensorShape((1, 21, 32, 32))
+ self.assertEqual(logits.shape, expected_shape)
+
+ expected_slice = tf.constant(
+ [
+ [[6.9713, 6.9786, 7.2422], [7.2893, 7.2825, 7.4446], [7.6580, 7.8797, 7.9420]],
+ [[-10.6869, -10.3250, -10.3471], [-10.4228, -9.9868, -9.7132], [-11.0405, -11.0221, -10.7318]],
+ [[-3.3089, -2.8539, -2.6740], [-3.2706, -2.5621, -2.5108], [-3.2534, -2.6615, -2.6651]],
+ ]
+ )
+
+ tf.debugging.assert_near(logits[0, :3, :3, :3], expected_slice, rtol=1e-4, atol=1e-4)
From 6fcc70f98f9294bb059fd06541a63ed106ad3f48 Mon Sep 17 00:00:00 2001
From: Sayak Paul
Date: Mon, 22 Aug 2022 12:36:02 +0530
Subject: [PATCH 06/20] add: corrected tests and remaining changes.

---
 docs/source/en/index.mdx | 2 +-
 docs/source/en/model_doc/mobilevit.mdx | 40 +++++-
 src/transformers/__init__.py | 14 ++
 src/transformers/modeling_tf_utils.py | 2 -
 .../models/auto/modeling_tf_auto.py | 3 +
 src/transformers/models/mobilevit/__init__.py | 44 ++++++-
 .../models/mobilevit/modeling_tf_mobilevit.py | 27 ++--
 src/transformers/utils/dummy_tf_objects.py | 31 +++++
 .../mobilevit/test_modeling_tf_mobilevit.py | 123 +++++++++++++----
 utils/documentation_tests.txt | 1 +
 10 files changed, 240 insertions(+), 47 deletions(-)
diff --git a/docs/source/en/index.mdx b/docs/source/en/index.mdx
index 5c0d51d8b7afb..95ece95eff3e3 100644
--- a/docs/source/en/index.mdx
+++ b/docs/source/en/index.mdx
@@ -257,7 +257,7 @@ Flax), PyTorch, and/or TensorFlow.
| mBART | ✅ | ✅ | ✅ | ✅ | ✅ |
| Megatron-BERT | ❌ | ❌ | ✅ | ❌ | ❌ |
| MobileBERT | ✅ | ✅ | ✅ | ✅ | ❌ |
-| MobileViT | ❌ | ❌ | ✅ | ❌ | ❌ |
+| MobileViT | ❌ | ❌ | ✅ | ✅ | ❌ |
| MPNet | ✅ | ✅ | ✅ | ✅ | ❌ |
| MT5 | ✅ | ✅ | ✅ | ✅ | ✅ |
| MVP | ✅ | ✅ | ✅ | ❌ | ❌ |
diff --git a/docs/source/en/model_doc/mobilevit.mdx b/docs/source/en/model_doc/mobilevit.mdx
index f5fd403fd59ed..fd35d89267e67 100644
--- a/docs/source/en/model_doc/mobilevit.mdx
+++ b/docs/source/en/model_doc/mobilevit.mdx
@@ -22,12 +22,33 @@ The abstract from the paper is the following:
Tips:
-- MobileViT is more like a CNN than a Transformer model. It does not work on sequence data but on batches of images. Unlike ViT, there are no embeddings. The backbone model outputs a feature map.
+- MobileViT is more like a CNN than a Transformer model. It does not work on sequence data but on batches of images. Unlike ViT, there are no embeddings. The backbone model outputs a feature map. You can follow [this tutorial](https://keras.io/examples/vision/mobilevit) for a lightweight introduction.
- One can use [`MobileViTFeatureExtractor`] to prepare images for the model. Note that if you do your own preprocessing, the pretrained checkpoints expect images to be in BGR pixel order (not RGB).
- The available image classification checkpoints are pre-trained on [ImageNet-1k](https://huggingface.co/datasets/imagenet-1k) (also referred to as ILSVRC 2012, a collection of 1.3 million images and 1,000 classes).
- The segmentation model uses a [DeepLabV3](https://arxiv.org/abs/1706.05587) head. The available semantic segmentation checkpoints are pre-trained on [PASCAL VOC](http://host.robots.ox.ac.uk/pascal/VOC/).
+- As the name suggests MobileViT was designed to be performant and efficient on mobile phones. The TensorFlow versions of the MobileViT models are fully compatible with [TensorFlow Lite](https://www.tensorflow.org/lite).
-This model was contributed by [matthijs](https://huggingface.co/Matthijs). The original code and weights can be found [here](https://github.com/apple/ml-cvnets).
+ You can use the following code to convert a MobileViT checkpoint (be it image classification or semantic segmentation) to generate a
+ TensorFlow Lite model:
+
+ ```py
+ from transformers import TFMobileViTForImageClassification
+ import tensorflow as tf
+
+ model = TFMobileViTForImageClassification.from_pretrained("apple/mobilevit-xx-small")
+
+ converter = tf.lite.TFLiteConverter.from_keras_model(model)
+ converter.optimizations = [tf.lite.Optimize.DEFAULT]
+ tflite_model = converter.convert()
+ with open("mobilevit_xxs.tflite", "wb") as f:
+ f.write(tflite_model)
+ ```
+
+ The resulting model will be just **about an MB** making it a perfect fit for mobile applications where resources and network
+ bandwidth can be constrained.
+
+
+This model was contributed by [matthijs](https://huggingface.co/Matthijs). The TensorFlow version of the model was contributed by [sayakpaul](https://huggingface.co/sayakpaul) The original code and weights can be found [here](https://github.com/apple/ml-cvnets).
## MobileViTConfig
@@ -53,3 +74,18 @@ This model was contributed by [matthijs](https://huggingface.co/Matthijs).
The o [[autodoc]] MobileViTForSemanticSegmentation - forward + +## TFMobileViTModel + +[[autodoc]] TFMobileViTModel + - call + +## TFMobileViTForImageClassification + +[[autodoc]] TFMobileViTForImageClassification + - call + +## TFMobileViTForSemanticSegmentation + +[[autodoc]] TFMobileViTForSemanticSegmentation + - call diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index be2be2727f014..2e6203321dc3e 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -2376,6 +2376,15 @@ "TFMobileBertPreTrainedModel", ] ) + _import_structure["models.mobilevit"].extend( + [ + "TF_MOBILEVIT_PRETRAINED_MODEL_ARCHIVE_LIST", + "TFMobileViTPreTrainedModel", + "TFMobileViTModel", + "TFMobileViTForImageClassification", + "TFMobileViTForSemanticSegmentation", + ] + ) _import_structure["models.mpnet"].extend( [ "TF_MPNET_PRETRAINED_MODEL_ARCHIVE_LIST", @@ -4803,6 +4812,7 @@ from .models.mbart import TFMBartForConditionalGeneration, TFMBartModel, TFMBartPreTrainedModel from .models.mobilebert import ( TF_MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST, + TF_MOBILEVIT_PRETRAINED_MODEL_ARCHIVE_LIST, TFMobileBertForMaskedLM, TFMobileBertForMultipleChoice, TFMobileBertForNextSentencePrediction, @@ -4813,6 +4823,10 @@ TFMobileBertMainLayer, TFMobileBertModel, TFMobileBertPreTrainedModel, + TFMobileViTForImageClassification, + TFMobileViTForSemanticSegmentation, + TFMobileViTModel, + TFMobileViTPreTrainedModel, ) from .models.mpnet import ( TF_MPNET_PRETRAINED_MODEL_ARCHIVE_LIST, diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py index f66b61bdf30f9..68ee4117a2f9d 100644 --- a/src/transformers/modeling_tf_utils.py +++ b/src/transformers/modeling_tf_utils.py @@ -1017,8 +1017,6 @@ def framework(self) -> str: def __init__(self, config, *inputs, **kwargs): super().__init__(*inputs, **kwargs) - print(f"From modeling_tf_utils: {config}.") - print(f"From modeling_tf_utils: is it PretrainedConfig instance: {isinstance(config, PretrainedConfig)}") if not isinstance(config, PretrainedConfig): raise ValueError( f"Parameter config in `{self.__class__.__name__}(config)` should be an instance of class " diff --git a/src/transformers/models/auto/modeling_tf_auto.py b/src/transformers/models/auto/modeling_tf_auto.py index 6f9b15c131d60..8046d02acf8c3 100644 --- a/src/transformers/models/auto/modeling_tf_auto.py +++ b/src/transformers/models/auto/modeling_tf_auto.py @@ -58,6 +58,7 @@ ("marian", "TFMarianModel"), ("mbart", "TFMBartModel"), ("mobilebert", "TFMobileBertModel"), + ("mobilevit", "TFMobileViTModel"), ("mpnet", "TFMPNetModel"), ("mt5", "TFMT5Model"), ("openai-gpt", "TFOpenAIGPTModel"), @@ -179,6 +180,7 @@ ("convnext", "TFConvNextForImageClassification"), ("data2vec-vision", "TFData2VecVisionForImageClassification"), ("deit", ("TFDeiTForImageClassification", "TFDeiTForImageClassificationWithTeacher")), + ("mobilevit", "TFMobileViTForImageClassification"), ("regnet", "TFRegNetForImageClassification"), ("resnet", "TFResNetForImageClassification"), ("segformer", "TFSegformerForImageClassification"), @@ -191,6 +193,7 @@ [ # Model for Semantic Segmentation mapping ("data2vec-vision", "TFData2VecVisionForSemanticSegmentation"), + ("mobilevit", "TFMobileViTForSemanticSegmentation"), ("segformer", "TFSegformerForSemanticSegmentation"), ] ) diff --git a/src/transformers/models/mobilevit/__init__.py b/src/transformers/models/mobilevit/__init__.py index cd639f50323c4..00931ab57dc6b 100644 --- a/src/transformers/models/mobilevit/__init__.py +++ 
b/src/transformers/models/mobilevit/__init__.py @@ -15,9 +15,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from operator import is_ from typing import TYPE_CHECKING -from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available +from ...utils import ( + OptionalDependencyNotAvailable, + _LazyModule, + is_tf_available, + is_torch_available, + is_vision_available, +) _import_structure = { @@ -46,6 +53,19 @@ "MobileViTPreTrainedModel", ] +try: + if not is_tf_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["modeling_tf_mobilevit"] = [ + "TF_MOBILEVIT_PRETRAINED_MODEL_ARCHIVE_LIST", + "TFMobileViTForImageClassification", + "TFMobileViTForSemanticSegmentation", + "TFMobileViTModel", + "TFMobileViTPreTrainedModel", + ] if TYPE_CHECKING: from .configuration_mobilevit import MOBILEVIT_PRETRAINED_CONFIG_ARCHIVE_MAP, MobileViTConfig, MobileViTOnnxConfig @@ -72,6 +92,28 @@ MobileViTPreTrainedModel, ) + try: + if not is_vision_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .feature_extraction_mobilevit import MobileViTFeatureExtractor + + try: + if not is_tf_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .modeling_tf_mobilevit import ( + TF_MOBILEVIT_PRETRAINED_MODEL_ARCHIVE_LIST, + TFMobileViTForImageClassification, + TFMobileViTForSemanticSegmentation, + TFMobileViTModel, + TFMobileViTPreTrainedModel, + ) + else: import sys diff --git a/src/transformers/models/mobilevit/modeling_tf_mobilevit.py b/src/transformers/models/mobilevit/modeling_tf_mobilevit.py index d5806ce5c751e..3a9290ea0b906 100644 --- a/src/transformers/models/mobilevit/modeling_tf_mobilevit.py +++ b/src/transformers/models/mobilevit/modeling_tf_mobilevit.py @@ -129,7 +129,8 @@ def __init__( self.activation = None def call(self, features: tf.Tensor) -> tf.Tensor: - features = self.convolution(self.padding(features)) + padded_features = self.padding(features) + features = self.convolution(padded_features) if self.normalization is not None: features = self.normalization(features) if self.activation is not None: @@ -700,7 +701,17 @@ def call( if not return_dict: output = (last_hidden_state, pooled_output) if pooled_output is not None else (last_hidden_state,) - return output + encoder_outputs[1:] + + # Change to NCHW output format to have uniformity in the modules + if not self.expand_output: + remaining_encoder_outputs = encoder_outputs[1:] + remaining_encoder_outputs = tuple( + [tf.transpose(h, perm=(0, 3, 1, 2)) for h in remaining_encoder_outputs[0]] + ) + remaining_encoder_outputs = (remaining_encoder_outputs,) + return output + remaining_encoder_outputs + else: + return output + encoder_outputs[1:] # Change the other hidden state outputs to NCHW as well if output_hidden_states: @@ -733,7 +744,7 @@ def dummy_inputs(self) -> Dict[str, tf.Tensor]: """ VISION_DUMMY_INPUTS = tf.random.uniform( shape=( - 3, + 3, self.config.num_channels, self.config.image_size, self.config.image_size, @@ -908,6 +919,10 @@ def call( hidden_states=outputs.hidden_states, ) + def serving_output(self, output: TFImageClassifierOutputWithNoAttention) -> TFImageClassifierOutputWithNoAttention: + # hidden_states and attention not converted to Tensor with 
tf.convert_to_tensor as they are all of different dimensions + return TFImageClassifierOutputWithNoAttention(logits=output.logits, hidden_states=output.hidden_states) + class TFMobileViTASPPPooling(tf.keras.layers.Layer): def __init__(self, config: MobileViTConfig, in_channels: int, out_channels: int, **kwargs) -> None: @@ -993,16 +1008,13 @@ def __init__(self, config: MobileViTConfig, **kwargs) -> None: def call(self, features: tf.Tensor) -> tf.Tensor: # since the hidden states were transposed to have `(batch_size, channels, height, width)` - # layout. + # layout we transpose them back to have `(batch_size, height, width, channels)` layout. features = tf.transpose(features, perm=[0, 2, 3, 1]) pyramid = [] for conv in self.convs: - # print(f"From TFMobileViTASPP: {conv(features).shape}") pyramid.append(conv(features)) pyramid = tf.concat(pyramid, axis=-1) - # print(f"From TFMobileViTASPP first convolution: {self.convs[0].convolution.kernel.shape}") - pooled_features = self.project(pyramid) pooled_features = self.dropout(pooled_features) return pooled_features @@ -1032,7 +1044,6 @@ def __init__(self, config: MobileViTConfig, **kwargs) -> None: def call(self, hidden_states: tf.Tensor) -> tf.Tensor: features = self.aspp(hidden_states[-1]) - # print(f"From MobileViTDeepLabV3: {features.shape}") features = self.dropout(features) features = self.classifier(features) return features diff --git a/src/transformers/utils/dummy_tf_objects.py b/src/transformers/utils/dummy_tf_objects.py index 6df601ca646af..cc75015efb0b8 100644 --- a/src/transformers/utils/dummy_tf_objects.py +++ b/src/transformers/utils/dummy_tf_objects.py @@ -1542,6 +1542,37 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["tf"]) +TF_MOBILEVIT_PRETRAINED_MODEL_ARCHIVE_LIST = None + + +class TFMobileViTForImageClassification(metaclass=DummyObject): + _backends = ["tf"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["tf"]) + + +class TFMobileViTForSemanticSegmentation(metaclass=DummyObject): + _backends = ["tf"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["tf"]) + + +class TFMobileViTModel(metaclass=DummyObject): + _backends = ["tf"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["tf"]) + + +class TFMobileViTPreTrainedModel(metaclass=DummyObject): + _backends = ["tf"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["tf"]) + + TF_MPNET_PRETRAINED_MODEL_ARCHIVE_LIST = None diff --git a/tests/models/mobilevit/test_modeling_tf_mobilevit.py b/tests/models/mobilevit/test_modeling_tf_mobilevit.py index 6347407bee4df..330cacfb6c449 100644 --- a/tests/models/mobilevit/test_modeling_tf_mobilevit.py +++ b/tests/models/mobilevit/test_modeling_tf_mobilevit.py @@ -31,8 +31,8 @@ import numpy as np import tensorflow as tf - from src.transformers.models.mobilevit.modeling_tf_mobilevit import TFMobileViTForImageClassification, TFMobileViTForSemanticSegmentation, TFMobileViTModel - from src.transformers.models.mobilevit.modeling_tf_mobilevit import TF_MOBILEVIT_PRETRAINED_MODEL_ARCHIVE_LIST + from transformers import TFMobileViTForImageClassification, TFMobileViTForSemanticSegmentation, TFMobileViTModel + from transformers.models.mobilevit.modeling_tf_mobilevit import TF_MOBILEVIT_PRETRAINED_MODEL_ARCHIVE_LIST if is_vision_available(): @@ -104,7 +104,7 @@ def prepare_config_and_inputs(self): return config, pixel_values, labels, pixel_labels def get_config(self): - config = MobileViTConfig( + return MobileViTConfig( image_size=self.image_size, 
patch_size=self.patch_size, num_channels=self.num_channels, @@ -117,22 +117,13 @@ def get_config(self): classifier_dropout_prob=self.classifier_dropout_prob, initializer_range=self.initializer_range, ) - print(f"From test class: is it PretrainedConfig instance: {isinstance(config, PretrainedConfig)}") - print(f"From test class: {config}.") - return config def create_and_check_model(self, config, pixel_values, labels, pixel_labels): model = TFMobileViTModel(config=config) result = model(pixel_values, training=False) expected_height = expected_width = self.image_size // self.output_stride self.parent.assertEqual( - result.last_hidden_state.shape, - ( - self.batch_size, - self.last_hidden_size, - expected_height, - expected_width - ) + result.last_hidden_state.shape, (self.batch_size, self.last_hidden_size, expected_height, expected_width) ) def create_and_check_for_image_classification(self, config, pixel_values, labels, pixel_labels): @@ -145,27 +136,15 @@ def create_and_check_for_semantic_segmentation(self, config, pixel_values, label config.num_labels = self.num_labels model = TFMobileViTForSemanticSegmentation(config) expected_height = expected_width = self.image_size // self.output_stride - + result = model(pixel_values, training=False) self.parent.assertEqual( - result.logits.shape, - ( - self.batch_size, - self.num_labels, - expected_height, - expected_width - ) + result.logits.shape, (self.batch_size, self.num_labels, expected_height, expected_width) ) - + result = model(pixel_values, labels=pixel_labels, training=False) self.parent.assertEqual( - result.logits.shape, - ( - self.batch_size, - self.num_labels, - expected_height, - expected_width - ) + result.logits.shape, (self.batch_size, self.num_labels, expected_height, expected_width) ) def prepare_config_and_inputs_for_common(self): @@ -222,7 +201,7 @@ def test_forward_signature(self): for model_class in self.all_model_classes: model = model_class(config) - signature = inspect.signature(model.forward) + signature = inspect.signature(model.call) # signature.parameters is an OrderedDict => so arg_names order is deterministic arg_names = [*signature.parameters.keys()] @@ -231,7 +210,6 @@ def test_forward_signature(self): def test_model(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() - print(f"From test_model: {isinstance(config_and_inputs[0], PretrainedConfig)}") self.model_tester.create_and_check_model(*config_and_inputs) def test_hidden_states_output(self): @@ -301,11 +279,90 @@ def test_keras_fit(self): if getattr(model, "hf_compute_loss", None): super().test_keras_fit() + # The default test_loss_computation() uses -100 as a proxy ignore_index + # to test masked losses. Overridding to avoid -100 since semantic segmentation + # models use `semantic_loss_ignore_index` from the config. + def test_loss_computation(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + for model_class in self.all_model_classes: + # set an ignore index to correctly test the masked loss used in + # `TFMobileViTForSemanticSegmentation`. 
+ if model_class.__name__ != "TFMobileViTForSemanticSegmentation": + config.semantic_loss_ignore_index = 5 + + model = model_class(config) + if getattr(model, "hf_compute_loss", None): + # The number of elements in the loss should be the same as the number of elements in the label + prepared_for_class = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True) + added_label = prepared_for_class[ + sorted(list(prepared_for_class.keys() - inputs_dict.keys()), reverse=True)[0] + ] + expected_loss_size = added_label.shape.as_list()[:1] + + # Test that model correctly compute the loss with kwargs + prepared_for_class = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True) + possible_input_names = {"input_ids", "pixel_values", "input_features"} + input_name = possible_input_names.intersection(set(prepared_for_class)).pop() + model_input = prepared_for_class.pop(input_name) + + loss = model(model_input, **prepared_for_class)[0] + self.assertTrue(loss.shape.as_list() == expected_loss_size or loss.shape.as_list() == [1]) + + # Test that model correctly compute the loss when we mask some positions + prepared_for_class = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True) + possible_input_names = {"input_ids", "pixel_values", "input_features"} + input_name = possible_input_names.intersection(set(prepared_for_class)).pop() + model_input = prepared_for_class.pop(input_name) + if "labels" in prepared_for_class: + labels = prepared_for_class["labels"].numpy() + if len(labels.shape) > 1 and labels.shape[1] != 1: + # labels[0] = -100 + prepared_for_class["labels"] = tf.convert_to_tensor(labels) + loss = model(model_input, **prepared_for_class)[0] + self.assertTrue(loss.shape.as_list() == expected_loss_size or loss.shape.as_list() == [1]) + self.assertTrue(not np.any(np.isnan(loss.numpy()))) + + # Test that model correctly compute the loss with a dict + prepared_for_class = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True) + loss = model(prepared_for_class)[0] + self.assertTrue(loss.shape.as_list() == expected_loss_size or loss.shape.as_list() == [1]) + + # Test that model correctly compute the loss with a tuple + prepared_for_class = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True) + + # Get keys that were added with the _prepare_for_class function + label_keys = prepared_for_class.keys() - inputs_dict.keys() + signature = inspect.signature(model.call).parameters + signature_names = list(signature.keys()) + + # Create a dictionary holding the location of the tensors in the tuple + tuple_index_mapping = {0: input_name} + for label_key in label_keys: + label_key_index = signature_names.index(label_key) + tuple_index_mapping[label_key_index] = label_key + sorted_tuple_index_mapping = sorted(tuple_index_mapping.items()) + # Initialize a list with their default values, update the values and convert to a tuple + list_input = [] + + for name in signature_names: + if name != "kwargs": + list_input.append(signature[name].default) + + for index, value in sorted_tuple_index_mapping: + list_input[index] = prepared_for_class[value] + + tuple_input = tuple(list_input) + + # Send to model + loss = model(tuple_input[:-1])[0] + + self.assertTrue(loss.shape.as_list() == expected_loss_size or loss.shape.as_list() == [1]) + @slow def test_model_from_pretrained(self): for model_name in TF_MOBILEVIT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: # `from_pt` will be removed. 
- model = TFMobileViTModel.from_pretrained(model_name, from_pt=True) + model = TFMobileViTModel.from_pretrained(model_name, from_pt=True) self.assertIsNotNone(model) @@ -320,7 +377,7 @@ class TFMobileViTModelIntegrationTest(unittest.TestCase): @slow def test_inference_image_classification_head(self): # `from_pt` will be removed - model = TFMobileViTForImageClassification.from_pretrained("apple/mobilevit-xx-small", from_pt=True) + model = TFMobileViTForImageClassification.from_pretrained("apple/mobilevit-xx-small", from_pt=True) feature_extractor = MobileViTFeatureExtractor.from_pretrained("apple/mobilevit-xx-small") image = prepare_img() diff --git a/utils/documentation_tests.txt b/utils/documentation_tests.txt index 1941a7343a6bc..304cbf42b5fba 100644 --- a/utils/documentation_tests.txt +++ b/utils/documentation_tests.txt @@ -45,6 +45,7 @@ src/transformers/models/mbart/modeling_mbart.py src/transformers/models/mobilebert/modeling_mobilebert.py src/transformers/models/mobilebert/modeling_tf_mobilebert.py src/transformers/models/mobilevit/modeling_mobilevit.py +src/transformers/models/mobilevit/modeling_tf_mobilevit.py src/transformers/models/opt/modeling_opt.py src/transformers/models/opt/modeling_tf_opt.py src/transformers/models/owlvit/modeling_owlvit.py From cd72a53fc64ea651167878f843a0ea6cbe71da77 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Mon, 22 Aug 2022 12:44:12 +0530 Subject: [PATCH 07/20] fix code style and repo consistency. --- src/transformers/utils/dummy_tf_objects.py | 6 +++--- tests/models/mobilevit/test_modeling_tf_mobilevit.py | 1 - 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/src/transformers/utils/dummy_tf_objects.py b/src/transformers/utils/dummy_tf_objects.py index cc75015efb0b8..8572f59a35984 100644 --- a/src/transformers/utils/dummy_tf_objects.py +++ b/src/transformers/utils/dummy_tf_objects.py @@ -1472,6 +1472,9 @@ def __init__(self, *args, **kwargs): TF_MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST = None +TF_MOBILEVIT_PRETRAINED_MODEL_ARCHIVE_LIST = None + + class TFMobileBertForMaskedLM(metaclass=DummyObject): _backends = ["tf"] @@ -1542,9 +1545,6 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["tf"]) -TF_MOBILEVIT_PRETRAINED_MODEL_ARCHIVE_LIST = None - - class TFMobileViTForImageClassification(metaclass=DummyObject): _backends = ["tf"] diff --git a/tests/models/mobilevit/test_modeling_tf_mobilevit.py b/tests/models/mobilevit/test_modeling_tf_mobilevit.py index 330cacfb6c449..1e035ac7303c9 100644 --- a/tests/models/mobilevit/test_modeling_tf_mobilevit.py +++ b/tests/models/mobilevit/test_modeling_tf_mobilevit.py @@ -19,7 +19,6 @@ import unittest from transformers import MobileViTConfig -from transformers.configuration_utils import PretrainedConfig from transformers.file_utils import is_tf_available, is_vision_available from transformers.testing_utils import require_tf, slow From cc634b724fac298c9cef1556eda5428bfc9f0d58 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Thu, 25 Aug 2022 10:08:35 +0530 Subject: [PATCH 08/20] address PR comments. 
--- src/transformers/models/mobilevit/__init__.py | 1 - .../models/mobilevit/modeling_tf_mobilevit.py | 143 +++++++++--------- .../mobilevit/test_modeling_tf_mobilevit.py | 4 +- 3 files changed, 75 insertions(+), 73 deletions(-) diff --git a/src/transformers/models/mobilevit/__init__.py b/src/transformers/models/mobilevit/__init__.py index 00931ab57dc6b..e1e088f693ba2 100644 --- a/src/transformers/models/mobilevit/__init__.py +++ b/src/transformers/models/mobilevit/__init__.py @@ -15,7 +15,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from operator import is_ from typing import TYPE_CHECKING from ...utils import ( diff --git a/src/transformers/models/mobilevit/modeling_tf_mobilevit.py b/src/transformers/models/mobilevit/modeling_tf_mobilevit.py index 3a9290ea0b906..924aa6a56a973 100644 --- a/src/transformers/models/mobilevit/modeling_tf_mobilevit.py +++ b/src/transformers/models/mobilevit/modeling_tf_mobilevit.py @@ -84,7 +84,6 @@ class TFMobileViTConvLayer(tf.keras.layers.Layer): def __init__( self, config: MobileViTConfig, - in_channels: int, out_channels: int, kernel_size: int, stride: int = 1, @@ -96,6 +95,11 @@ def __init__( **kwargs ) -> None: super().__init__(**kwargs) + logger.warning( + f"\n{self.__class__.__name__} has backpropagation operations that are NOT supported on CPU. If you wish " + "to train/fine-tine this model, you need a GPU or a TPU" + ) + padding = int((kernel_size - 1) / 2) * dilation self.padding = tf.keras.layers.ZeroPadding2D(padding) @@ -128,11 +132,11 @@ def __init__( else: self.activation = None - def call(self, features: tf.Tensor) -> tf.Tensor: + def call(self, features: tf.Tensor, training: bool = False) -> tf.Tensor: padded_features = self.padding(features) features = self.convolution(padded_features) if self.normalization is not None: - features = self.normalization(features) + features = self.normalization(features, training=training) if self.activation is not None: features = self.activation(features) return features @@ -155,12 +159,11 @@ def __init__( self.use_residual = (stride == 1) and (in_channels == out_channels) self.expand_1x1 = TFMobileViTConvLayer( - config, in_channels=in_channels, out_channels=expanded_channels, kernel_size=1, name="expand_1x1" + config, out_channels=expanded_channels, kernel_size=1, name="expand_1x1" ) self.conv_3x3 = TFMobileViTConvLayer( config, - in_channels=expanded_channels, out_channels=expanded_channels, kernel_size=3, stride=stride, @@ -171,19 +174,18 @@ def __init__( self.reduce_1x1 = TFMobileViTConvLayer( config, - in_channels=expanded_channels, out_channels=out_channels, kernel_size=1, use_activation=False, name="reduce_1x1", ) - def call(self, features: tf.Tensor) -> tf.Tensor: + def call(self, features: tf.Tensor, training: bool = False) -> tf.Tensor: residual = features - features = self.expand_1x1(features) - features = self.conv_3x3(features) - features = self.reduce_1x1(features) + features = self.expand_1x1(features, training=training) + features = self.conv_3x3(features, training=training) + features = self.reduce_1x1(features, training=training) return residual + features if self.use_residual else features @@ -212,9 +214,9 @@ def __init__( self.layers.append(layer) in_channels = out_channels - def call(self, features: tf.Tensor) -> tf.Tensor: + def call(self, features: tf.Tensor, training: bool = False) -> tf.Tensor: for layer_module in self.layers: - features = 
layer_module(features) + features = layer_module(features, training=training) return features @@ -245,7 +247,7 @@ def transpose_for_scores(self, x: tf.Tensor) -> tf.Tensor: x = tf.reshape(x, shape=(batch_size, -1, self.num_attention_heads, self.attention_head_size)) return tf.transpose(x, perm=[0, 2, 1, 3]) - def call(self, hidden_states: tf.Tensor) -> tf.Tensor: + def call(self, hidden_states: tf.Tensor, training: bool = False) -> tf.Tensor: batch_size = shape_list(hidden_states)[0] key_layer = self.transpose_for_scores(self.key(hidden_states)) @@ -261,7 +263,7 @@ def call(self, hidden_states: tf.Tensor) -> tf.Tensor: # This is actually dropping out entire tokens to attend to, which might # seem a bit unusual, but is taken from the original Transformer paper. - attention_probs = self.dropout(attention_probs) + attention_probs = self.dropout(attention_probs, training=training) context_layer = tf.matmul(attention_probs, value_layer) @@ -276,9 +278,9 @@ def __init__(self, config: MobileViTConfig, hidden_size: int, **kwargs) -> None: self.dense = tf.keras.layers.Dense(hidden_size, name="dense") self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob) - def call(self, hidden_states: tf.Tensor) -> tf.Tensor: + def call(self, hidden_states: tf.Tensor, training: bool = False) -> tf.Tensor: hidden_states = self.dense(hidden_states) - hidden_states = self.dropout(hidden_states) + hidden_states = self.dropout(hidden_states, training=training) return hidden_states @@ -288,9 +290,12 @@ def __init__(self, config: MobileViTConfig, hidden_size: int, **kwargs) -> None: self.attention = TFMobileViTSelfAttention(config, hidden_size, name="attention") self.dense_output = TFMobileViTSelfOutput(config, hidden_size, name="output") - def call(self, hidden_states: tf.Tensor) -> tf.Tensor: - self_outputs = self.attention(hidden_states) - attention_output = self.dense_output(self_outputs) + def prune_heads(self, heads): + raise NotImplementedError + + def call(self, hidden_states: tf.Tensor, training: bool = False) -> tf.Tensor: + self_outputs = self.attention(hidden_states, training=training) + attention_output = self.dense_output(self_outputs, training=training) return attention_output @@ -315,9 +320,9 @@ def __init__(self, config: MobileViTConfig, hidden_size: int, intermediate_size: self.dense = tf.keras.layers.Dense(hidden_size, name="dense") self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob) - def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor) -> tf.Tensor: + def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor: hidden_states = self.dense(hidden_states) - hidden_states = self.dropout(hidden_states) + hidden_states = self.dropout(hidden_states, training=training) hidden_states = hidden_states + input_tensor return hidden_states @@ -335,13 +340,13 @@ def __init__(self, config: MobileViTConfig, hidden_size: int, intermediate_size: epsilon=config.layer_norm_eps, name="layernorm_after" ) - def call(self, hidden_states: tf.Tensor) -> tf.Tensor: - attention_output = self.attention(self.layernorm_before(hidden_states)) + def call(self, hidden_states: tf.Tensor, training: bool = False) -> tf.Tensor: + attention_output = self.attention(self.layernorm_before(hidden_states), training=training) hidden_states = attention_output + hidden_states layer_output = self.layernorm_after(hidden_states) layer_output = self.intermediate(layer_output) - layer_output = self.mobilevit_output(layer_output, hidden_states) + layer_output = 
self.mobilevit_output(layer_output, hidden_states, training=training) return layer_output @@ -359,9 +364,9 @@ def __init__(self, config: MobileViTConfig, hidden_size: int, num_stages: int, * ) self.layers.append(transformer_layer) - def call(self, hidden_states: tf.Tensor) -> tf.Tensor: + def call(self, hidden_states: tf.Tensor, training: bool = False) -> tf.Tensor: for layer_module in self.layers: - hidden_states = layer_module(hidden_states) + hidden_states = layer_module(hidden_states, training=training) return hidden_states @@ -400,7 +405,6 @@ def __init__( self.conv_kxk = TFMobileViTConvLayer( config, - in_channels=in_channels, out_channels=in_channels, kernel_size=config.conv_kernel_size, name="conv_kxk", @@ -408,7 +412,6 @@ def __init__( self.conv_1x1 = TFMobileViTConvLayer( config, - in_channels=in_channels, out_channels=hidden_size, kernel_size=1, use_normalization=False, @@ -423,12 +426,11 @@ def __init__( self.layernorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm") self.conv_projection = TFMobileViTConvLayer( - config, in_channels=hidden_size, out_channels=in_channels, kernel_size=1, name="conv_projection" + config, out_channels=in_channels, kernel_size=1, name="conv_projection" ) self.fusion = TFMobileViTConvLayer( config, - in_channels=2 * in_channels, out_channels=in_channels, kernel_size=config.conv_kernel_size, name="fusion", @@ -504,29 +506,29 @@ def folding(self, patches: tf.Tensor, info_dict: Dict) -> tf.Tensor: return features - def call(self, features: tf.Tensor) -> tf.Tensor: + def call(self, features: tf.Tensor, training: bool = False) -> tf.Tensor: # reduce spatial dimensions if needed if self.downsampling_layer: - features = self.downsampling_layer(features) + features = self.downsampling_layer(features, training=training) residual = features # local representation - features = self.conv_kxk(features) - features = self.conv_1x1(features) + features = self.conv_kxk(features, training=training) + features = self.conv_1x1(features, training=training) # convert feature map to patches patches, info_dict = self.unfolding(features) # learn global representations - patches = self.transformer(patches) + patches = self.transformer(patches, training=training) patches = self.layernorm(patches) # convert patches back to feature maps features = self.folding(patches, info_dict) - features = self.conv_projection(features) - features = self.fusion(tf.concat([residual, features], axis=-1)) + features = self.conv_projection(features, training=training) + features = self.fusion(tf.concat([residual, features], axis=-1), training=training) return features @@ -614,11 +616,12 @@ def call( hidden_states: tf.Tensor, output_hidden_states: bool = False, return_dict: bool = True, + training: bool = False, ) -> Union[tuple, TFBaseModelOutput]: all_hidden_states = () if output_hidden_states else None for i, layer_module in enumerate(self.layers): - hidden_states = layer_module(hidden_states) + hidden_states = layer_module(hidden_states, training=training) if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) @@ -640,7 +643,6 @@ def __init__(self, config: MobileViTConfig, expand_output: bool = True, **kwargs self.conv_stem = TFMobileViTConvLayer( config, - in_channels=config.num_channels, out_channels=config.neck_hidden_sizes[0], kernel_size=3, stride=2, @@ -652,7 +654,6 @@ def __init__(self, config: MobileViTConfig, expand_output: bool = True, **kwargs if self.expand_output: self.conv_1x1_exp = TFMobileViTConvLayer( config, - 
in_channels=config.neck_hidden_sizes[5], out_channels=config.neck_hidden_sizes[6], kernel_size=1, name="conv_1x1_exp", @@ -660,12 +661,20 @@ def __init__(self, config: MobileViTConfig, expand_output: bool = True, **kwargs self.pooler = tf.keras.layers.GlobalAveragePooling2D(data_format="channels_first", name="pooler") + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + raise NotImplementedError + @unpack_inputs def call( self, pixel_values: Optional[tf.Tensor] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + training: bool = False, ) -> Union[Tuple[tf.Tensor], TFBaseModelOutputWithPooling]: output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states @@ -677,12 +686,10 @@ def call( # shape = (batch_size, in_height, in_width, in_channels=num_channels) pixel_values = tf.transpose(pixel_values, perm=(0, 2, 3, 1)) - embedding_output = self.conv_stem(pixel_values) + embedding_output = self.conv_stem(pixel_values, training=training) encoder_outputs = self.encoder( - embedding_output, - output_hidden_states=output_hidden_states, - return_dict=return_dict, + embedding_output, output_hidden_states=output_hidden_states, return_dict=return_dict, training=training ) if self.expand_output: @@ -842,9 +849,10 @@ def call( pixel_values: Optional[tf.Tensor] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + training: bool = False, ) -> Union[Tuple[tf.Tensor], TFBaseModelOutputWithPooling]: - output = self.mobilevit(pixel_values, output_hidden_states, return_dict) + output = self.mobilevit(pixel_values, output_hidden_states, return_dict, training=training) return output def serving_output(self, output: TFBaseModelOutputWithPooling) -> TFBaseModelOutputWithPooling: @@ -891,6 +899,7 @@ def call( output_hidden_states: Optional[bool] = None, labels: Optional[tf.Tensor] = None, return_dict: Optional[bool] = None, + training: Optional[bool] = False, ) -> Union[tuple, TFImageClassifierOutputWithNoAttention]: r""" labels (`tf.Tensor` of shape `(batch_size,)`, *optional*): @@ -900,13 +909,13 @@ def call( """ return_dict = return_dict if return_dict is not None else self.config.use_return_dict - outputs = self.mobilevit(pixel_values, output_hidden_states=output_hidden_states, return_dict=return_dict) + outputs = self.mobilevit( + pixel_values, output_hidden_states=output_hidden_states, return_dict=return_dict, training=training + ) pooled_output = outputs.pooler_output if return_dict else outputs[1] - logits = self.classifier(self.dropout(pooled_output)) - - logits = self.classifier(pooled_output) + logits = self.classifier(self.dropout(pooled_output, training=training)) loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits) if not return_dict: @@ -925,14 +934,13 @@ def serving_output(self, output: TFImageClassifierOutputWithNoAttention) -> TFIm class TFMobileViTASPPPooling(tf.keras.layers.Layer): - def __init__(self, config: MobileViTConfig, in_channels: int, out_channels: int, **kwargs) -> None: + def __init__(self, config: MobileViTConfig, out_channels: int, **kwargs) -> None: super().__init__(**kwargs) self.global_pool = tf.keras.layers.GlobalAveragePooling2D(keepdims=True, name="global_pool") self.conv_1x1 = TFMobileViTConvLayer( config, - in_channels=in_channels, out_channels=out_channels, kernel_size=1, 
stride=1, @@ -941,10 +949,10 @@ def __init__(self, config: MobileViTConfig, in_channels: int, out_channels: int, name="conv_1x1", ) - def call(self, features: tf.Tensor) -> tf.Tensor: + def call(self, features: tf.Tensor, training: bool = False) -> tf.Tensor: spatial_size = shape_list(features)[1:-1] features = self.global_pool(features) - features = self.conv_1x1(features) + features = self.conv_1x1(features, training=training) features = tf.image.resize(features, size=spatial_size, method="bilinear") return features @@ -957,7 +965,6 @@ class TFMobileViTASPP(tf.keras.layers.Layer): def __init__(self, config: MobileViTConfig, **kwargs) -> None: super().__init__(**kwargs) - in_channels = config.neck_hidden_sizes[-2] out_channels = config.aspp_out_channels if len(config.atrous_rates) != 3: @@ -967,7 +974,6 @@ def __init__(self, config: MobileViTConfig, **kwargs) -> None: in_projection = TFMobileViTConvLayer( config, - in_channels=in_channels, out_channels=out_channels, kernel_size=1, use_activation="relu", @@ -979,7 +985,6 @@ def __init__(self, config: MobileViTConfig, **kwargs) -> None: [ TFMobileViTConvLayer( config, - in_channels=in_channels, out_channels=out_channels, kernel_size=3, dilation=rate, @@ -990,14 +995,11 @@ def __init__(self, config: MobileViTConfig, **kwargs) -> None: ] ) - pool_layer = TFMobileViTASPPPooling( - config, in_channels, out_channels, name=f"convs.{len(config.atrous_rates) + 1}" - ) + pool_layer = TFMobileViTASPPPooling(config, out_channels, name=f"convs.{len(config.atrous_rates) + 1}") self.convs.append(pool_layer) self.project = TFMobileViTConvLayer( config, - in_channels=5 * out_channels, out_channels=out_channels, kernel_size=1, use_activation="relu", @@ -1006,17 +1008,17 @@ def __init__(self, config: MobileViTConfig, **kwargs) -> None: self.dropout = tf.keras.layers.Dropout(config.aspp_dropout_prob) - def call(self, features: tf.Tensor) -> tf.Tensor: + def call(self, features: tf.Tensor, training: bool = False) -> tf.Tensor: # since the hidden states were transposed to have `(batch_size, channels, height, width)` # layout we transpose them back to have `(batch_size, height, width, channels)` layout. 
features = tf.transpose(features, perm=[0, 2, 3, 1]) pyramid = [] for conv in self.convs: - pyramid.append(conv(features)) + pyramid.append(conv(features, training=training)) pyramid = tf.concat(pyramid, axis=-1) - pooled_features = self.project(pyramid) - pooled_features = self.dropout(pooled_features) + pooled_features = self.project(pyramid, training=training) + pooled_features = self.dropout(pooled_features, training=training) return pooled_features @@ -1033,7 +1035,6 @@ def __init__(self, config: MobileViTConfig, **kwargs) -> None: self.classifier = TFMobileViTConvLayer( config, - in_channels=config.aspp_out_channels, out_channels=config.num_labels, kernel_size=1, use_normalization=False, @@ -1042,10 +1043,10 @@ def __init__(self, config: MobileViTConfig, **kwargs) -> None: name="classifier", ) - def call(self, hidden_states: tf.Tensor) -> tf.Tensor: - features = self.aspp(hidden_states[-1]) - features = self.dropout(features) - features = self.classifier(features) + def call(self, hidden_states: tf.Tensor, training: bool = False) -> tf.Tensor: + features = self.aspp(hidden_states[-1], training=training) + features = self.dropout(features, training=training) + features = self.classifier(features, training=training) return features @@ -1092,6 +1093,7 @@ def call( labels: Optional[tf.Tensor] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + training: bool = False, ) -> Union[tuple, TFSemanticSegmenterOutput]: r""" labels (`tf.Tensor` of shape `(batch_size, height, width)`, *optional*): @@ -1129,11 +1131,12 @@ def call( pixel_values, output_hidden_states=True, # we need the intermediate hidden states return_dict=return_dict, + training=training, ) encoder_hidden_states = outputs.hidden_states if return_dict else outputs[1] - logits = self.segmentation_head(encoder_hidden_states) + logits = self.segmentation_head(encoder_hidden_states, training=training) loss = None if labels is not None: diff --git a/tests/models/mobilevit/test_modeling_tf_mobilevit.py b/tests/models/mobilevit/test_modeling_tf_mobilevit.py index 1e035ac7303c9..ac7c9a2644860 100644 --- a/tests/models/mobilevit/test_modeling_tf_mobilevit.py +++ b/tests/models/mobilevit/test_modeling_tf_mobilevit.py @@ -256,7 +256,7 @@ def test_for_semantic_segmentation(self): @unittest.skipIf( not is_tf_available() or len(tf.config.list_physical_devices("GPU")) == 0, - reason="TF (<=2.8) does not support backprop for grouped convolutions on CPU.", + reason="TF does not support backprop for grouped convolutions on CPU.", ) def test_dataset_conversion(self): super().test_dataset_conversion() @@ -266,7 +266,7 @@ def check_keras_fit_results(self, val_loss1, val_loss2, atol=2e-1, rtol=2e-1): @unittest.skipIf( not is_tf_available() or len(tf.config.list_physical_devices("GPU")) == 0, - reason="TF (<=2.8) does not support backprop for grouped convolutions on CPU.", + reason="TF does not support backprop for grouped convolutions on CPU.", ) def test_keras_fit(self): config, _ = self.model_tester.prepare_config_and_inputs_for_common() From 1c9b6f246e5a6d36aca968455f61fa9c28a2c6fe Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Fri, 26 Aug 2022 07:48:21 +0530 Subject: [PATCH 09/20] address Amy's comments. 
--- src/transformers/modeling_tf_outputs.py | 31 ++++++++++++++++++ .../models/mobilevit/modeling_tf_mobilevit.py | 32 +++++++++---------- 2 files changed, 46 insertions(+), 17 deletions(-) diff --git a/src/transformers/modeling_tf_outputs.py b/src/transformers/modeling_tf_outputs.py index a1d3df074fe78..efb2412084a75 100644 --- a/src/transformers/modeling_tf_outputs.py +++ b/src/transformers/modeling_tf_outputs.py @@ -685,6 +685,37 @@ class TFSemanticSegmenterOutput(ModelOutput): attentions: Optional[Tuple[tf.Tensor]] = None +@dataclass +class TFSemanticSegmenterOutputWithNoAttention(ModelOutput): + """ + Base class for outputs of semantic segmentation models that do not output attention scores. + + Args: + loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Classification (or regression if config.num_labels==1) loss. + logits (`tf.Tensor` of shape `(batch_size, config.num_labels, logits_height, logits_width)`): + Classification scores for each pixel. + + + + The logits returned do not necessarily have the same size as the `pixel_values` passed as inputs. This is + to avoid doing two interpolations and lose some quality when a user needs to resize the logits to the + original image size as post-processing. You should always check your logits shape and resize as needed. + + + + hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + one for + the output of each layer) of shape `(batch_size, patch_size, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + """ + + loss: Optional[tf.Tensor] = None + logits: tf.Tensor = None + hidden_states: Optional[Tuple[tf.Tensor]] = None + + @dataclass class TFImageClassifierOutput(ModelOutput): """ diff --git a/src/transformers/models/mobilevit/modeling_tf_mobilevit.py b/src/transformers/models/mobilevit/modeling_tf_mobilevit.py index 924aa6a56a973..c86eb7a308dec 100644 --- a/src/transformers/models/mobilevit/modeling_tf_mobilevit.py +++ b/src/transformers/models/mobilevit/modeling_tf_mobilevit.py @@ -31,7 +31,7 @@ TFBaseModelOutput, TFBaseModelOutputWithPooling, TFImageClassifierOutputWithNoAttention, - TFSemanticSegmenterOutput, + TFSemanticSegmenterOutputWithNoAttention, ) from ...modeling_tf_utils import TFPreTrainedModel, TFSequenceClassificationLoss, keras_serializable, unpack_inputs from ...tf_utils import shape_list, stable_softmax @@ -243,12 +243,12 @@ def __init__(self, config: MobileViTConfig, hidden_size: int, **kwargs) -> None: self.dropout = tf.keras.layers.Dropout(config.attention_probs_dropout_prob) def transpose_for_scores(self, x: tf.Tensor) -> tf.Tensor: - batch_size = shape_list(x)[0] + batch_size = tf.shape(x)[0] x = tf.reshape(x, shape=(batch_size, -1, self.num_attention_heads, self.attention_head_size)) return tf.transpose(x, perm=[0, 2, 1, 3]) def call(self, hidden_states: tf.Tensor, training: bool = False) -> tf.Tensor: - batch_size = shape_list(hidden_states)[0] + batch_size = tf.shape(hidden_states)[0] key_layer = self.transpose_for_scores(self.key(hidden_states)) value_layer = self.transpose_for_scores(self.value(hidden_states)) @@ -440,7 +440,10 @@ def unfolding(self, features: tf.Tensor) -> Tuple[tf.Tensor, Dict]: patch_width, patch_height = self.patch_width, self.patch_height patch_area = tf.cast(patch_width * patch_height, "int32") - 
batch_size, orig_height, orig_width, channels = shape_list(features) + batch_size = tf.shape(features)[0] + orig_height = tf.shape(features)[1] + orig_width = tf.shape(features)[2] + channels = tf.shape(features)[3] new_height = tf.cast(tf.math.ceil(orig_height / patch_height) * patch_height, "int32") new_width = tf.cast(tf.math.ceil(orig_width / patch_width) * patch_width, "int32") @@ -750,12 +753,7 @@ def dummy_inputs(self) -> Dict[str, tf.Tensor]: `Dict[str, tf.Tensor]`: The dummy inputs. """ VISION_DUMMY_INPUTS = tf.random.uniform( - shape=( - 3, - self.config.num_channels, - self.config.image_size, - self.config.image_size, - ), + shape=(3, self.config.num_channels, self.config.image_size, self.config.image_size), dtype=tf.float32, ) return {"pixel_values": tf.constant(VISION_DUMMY_INPUTS)} @@ -1086,7 +1084,7 @@ def masked_loss(real, pred): @unpack_inputs @add_start_docstrings_to_model_forward(MOBILEVIT_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=TFSemanticSegmenterOutput, config_class=_CONFIG_FOR_DOC) + @replace_return_docstrings(output_type=TFSemanticSegmenterOutputWithNoAttention, config_class=_CONFIG_FOR_DOC) def call( self, pixel_values: Optional[tf.Tensor] = None, @@ -1094,7 +1092,7 @@ def call( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, training: bool = False, - ) -> Union[tuple, TFSemanticSegmenterOutput]: + ) -> Union[tuple, TFSemanticSegmenterOutputWithNoAttention]: r""" labels (`tf.Tensor` of shape `(batch_size, height, width)`, *optional*): Ground truth semantic segmentation maps for computing the loss. Indices should be in `[0, ..., @@ -1156,13 +1154,13 @@ def call( output = (logits,) + outputs[2:] return ((loss,) + output) if loss is not None else output - return TFSemanticSegmenterOutput( + return TFSemanticSegmenterOutputWithNoAttention( loss=loss, logits=logits, hidden_states=outputs.hidden_states if output_hidden_states else None, - attentions=None, ) - def serving_output(self, output: TFSemanticSegmenterOutput) -> TFSemanticSegmenterOutput: - # hidden_states and attention not converted to Tensor with tf.convert_to_tensor as they are all of different dimensions - return TFSemanticSegmenterOutput(logits=output.logits, hidden_states=output.hidden_states, attentions=None) + def serving_output( + self, output: TFSemanticSegmenterOutputWithNoAttention + ) -> TFSemanticSegmenterOutputWithNoAttention: + return TFSemanticSegmenterOutputWithNoAttention(logits=output.logits, hidden_states=output.hidden_states) From 82079a74268c8b633e61064058362a0e6e53294c Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Mon, 29 Aug 2022 13:51:06 +0530 Subject: [PATCH 10/20] chore: remove from_pt argument. --- tests/models/mobilevit/test_modeling_tf_mobilevit.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/tests/models/mobilevit/test_modeling_tf_mobilevit.py b/tests/models/mobilevit/test_modeling_tf_mobilevit.py index ac7c9a2644860..d46ee895ed71f 100644 --- a/tests/models/mobilevit/test_modeling_tf_mobilevit.py +++ b/tests/models/mobilevit/test_modeling_tf_mobilevit.py @@ -360,8 +360,7 @@ def test_loss_computation(self): @slow def test_model_from_pretrained(self): for model_name in TF_MOBILEVIT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: - # `from_pt` will be removed. 
- model = TFMobileViTModel.from_pretrained(model_name, from_pt=True) + model = TFMobileViTModel.from_pretrained(model_name) self.assertIsNotNone(model) @@ -375,8 +374,7 @@ def prepare_img(): class TFMobileViTModelIntegrationTest(unittest.TestCase): @slow def test_inference_image_classification_head(self): - # `from_pt` will be removed - model = TFMobileViTForImageClassification.from_pretrained("apple/mobilevit-xx-small", from_pt=True) + model = TFMobileViTForImageClassification.from_pretrained("apple/mobilevit-xx-small") feature_extractor = MobileViTFeatureExtractor.from_pretrained("apple/mobilevit-xx-small") image = prepare_img() @@ -396,7 +394,7 @@ def test_inference_image_classification_head(self): @slow def test_inference_semantic_segmentation(self): # `from_pt` will be removed - model = TFMobileViTForSemanticSegmentation.from_pretrained("apple/deeplabv3-mobilevit-xx-small", from_pt=True) + model = TFMobileViTForSemanticSegmentation.from_pretrained("apple/deeplabv3-mobilevit-xx-small") feature_extractor = MobileViTFeatureExtractor.from_pretrained("apple/deeplabv3-mobilevit-xx-small") From c0fbe3565d40555a41ecd0b2bebabc2c6f60bc75 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Mon, 29 Aug 2022 13:55:14 +0530 Subject: [PATCH 11/20] chore: add full-stop. --- docs/source/en/model_doc/mobilevit.mdx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/en/model_doc/mobilevit.mdx b/docs/source/en/model_doc/mobilevit.mdx index fd35d89267e67..55b441d324c72 100644 --- a/docs/source/en/model_doc/mobilevit.mdx +++ b/docs/source/en/model_doc/mobilevit.mdx @@ -48,7 +48,7 @@ Tips: bandwidth can be constrained. -This model was contributed by [matthijs](https://huggingface.co/Matthijs). The TensorFlow version of the model was contributed by [sayakpaul](https://huggingface.co/sayakpaul) The original code and weights can be found [here](https://github.com/apple/ml-cvnets). +This model was contributed by [matthijs](https://huggingface.co/Matthijs). The TensorFlow version of the model was contributed by [sayakpaul](https://huggingface.co/sayakpaul). The original code and weights can be found [here](https://github.com/apple/ml-cvnets). ## MobileViTConfig From 32cfd30cee185a090a80a6604b850c639b04203b Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Wed, 31 Aug 2022 12:29:33 +0530 Subject: [PATCH 12/20] fix: TFLite model conversion in the doc. --- docs/source/en/model_doc/mobilevit.mdx | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/docs/source/en/model_doc/mobilevit.mdx b/docs/source/en/model_doc/mobilevit.mdx index 55b441d324c72..89512f3f2381b 100644 --- a/docs/source/en/model_doc/mobilevit.mdx +++ b/docs/source/en/model_doc/mobilevit.mdx @@ -35,16 +35,23 @@ Tips: from transformers import TFMobileViTForImageClassification import tensorflow as tf - model = TFMobileViTForImageClassification.from_pretrained("apple/mobilevit-xx-small") + + model_ckpt = "apple/mobilevit-xx-small" + model = TFMobileViTForImageClassification.from_pretrained(model_ckpt) converter = tf.lite.TFLiteConverter.from_keras_model(model) converter.optimizations = [tf.lite.Optimize.DEFAULT] + converter.target_spec.supported_ops = [ + tf.lite.OpsSet.TFLITE_BUILTINS, # Enable TensorFlow Lite ops. + tf.lite.OpsSet.SELECT_TF_OPS, # Enable TensorFlow ops. 
+ ] tflite_model = converter.convert() - with open("mobilevit_xxs.tflite", "wb") as f: + tflite_filename = model_ckpt.split("/")[-1] + ".tflite" + with open(tflite_filename, "wb") as f: f.write(tflite_model) ``` - The resulting model will be just **about an MB** making it a perfect fit for mobile applications where resources and network + The resulting model will be just **about an MB** making it a good fit for mobile applications where resources and network bandwidth can be constrained. From b5593b9b63034bcae6686a16b3f65cb1e44f7908 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Thu, 1 Sep 2022 16:59:55 +0530 Subject: [PATCH 13/20] Update src/transformers/models/mobilevit/modeling_tf_mobilevit.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> --- src/transformers/models/mobilevit/modeling_tf_mobilevit.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/transformers/models/mobilevit/modeling_tf_mobilevit.py b/src/transformers/models/mobilevit/modeling_tf_mobilevit.py index c86eb7a308dec..947bd240eba9a 100644 --- a/src/transformers/models/mobilevit/modeling_tf_mobilevit.py +++ b/src/transformers/models/mobilevit/modeling_tf_mobilevit.py @@ -404,10 +404,7 @@ def __init__( self.downsampling_layer = None self.conv_kxk = TFMobileViTConvLayer( - config, - out_channels=in_channels, - kernel_size=config.conv_kernel_size, - name="conv_kxk", + config, out_channels=in_channels, kernel_size=config.conv_kernel_size, name="conv_kxk" ) self.conv_1x1 = TFMobileViTConvLayer( From 436532082ef932d51bdff02558ea219ca132f972 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Thu, 1 Sep 2022 17:00:06 +0530 Subject: [PATCH 14/20] Update src/transformers/models/mobilevit/modeling_tf_mobilevit.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> --- src/transformers/models/mobilevit/modeling_tf_mobilevit.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/transformers/models/mobilevit/modeling_tf_mobilevit.py b/src/transformers/models/mobilevit/modeling_tf_mobilevit.py index 947bd240eba9a..fb1ed6a7f324f 100644 --- a/src/transformers/models/mobilevit/modeling_tf_mobilevit.py +++ b/src/transformers/models/mobilevit/modeling_tf_mobilevit.py @@ -427,10 +427,7 @@ def __init__( ) self.fusion = TFMobileViTConvLayer( - config, - out_channels=in_channels, - kernel_size=config.conv_kernel_size, - name="fusion", + config, out_channels=in_channels, kernel_size=config.conv_kernel_size, name="fusion" ) def unfolding(self, features: tf.Tensor) -> Tuple[tf.Tensor, Dict]: From 7c93be01fdb26cd1abbefea7cf72096fa1dbd967 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Thu, 1 Sep 2022 17:01:04 +0530 Subject: [PATCH 15/20] Update src/transformers/models/mobilevit/modeling_tf_mobilevit.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> --- src/transformers/models/mobilevit/modeling_tf_mobilevit.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/transformers/models/mobilevit/modeling_tf_mobilevit.py b/src/transformers/models/mobilevit/modeling_tf_mobilevit.py index fb1ed6a7f324f..ee93798499c3b 100644 --- a/src/transformers/models/mobilevit/modeling_tf_mobilevit.py +++ b/src/transformers/models/mobilevit/modeling_tf_mobilevit.py @@ -442,11 +442,10 @@ def unfolding(self, features: tf.Tensor) -> Tuple[tf.Tensor, Dict]: new_height = tf.cast(tf.math.ceil(orig_height / patch_height) * patch_height, "int32") new_width = tf.cast(tf.math.ceil(orig_width / patch_width) * patch_width, "int32") - 
interpolate = False - if new_width != orig_width or new_height != orig_height: + interpolate = (new_width != orig_width or new_height != orig_height) + if interpolate: # Note: Padding can be done, but then it needs to be handled in attention function. features = tf.image.resize(features, size=(new_height, new_width), method="bilinear") - interpolate = True # number of patches along width and height num_patch_width = new_width // patch_width From 06cb3682ab0a17c64c179ff4e5291b8d0c2c574f Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Thu, 1 Sep 2022 17:01:15 +0530 Subject: [PATCH 16/20] Update src/transformers/models/mobilevit/modeling_tf_mobilevit.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> --- src/transformers/models/mobilevit/modeling_tf_mobilevit.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/transformers/models/mobilevit/modeling_tf_mobilevit.py b/src/transformers/models/mobilevit/modeling_tf_mobilevit.py index ee93798499c3b..4f5d658e2633e 100644 --- a/src/transformers/models/mobilevit/modeling_tf_mobilevit.py +++ b/src/transformers/models/mobilevit/modeling_tf_mobilevit.py @@ -649,10 +649,7 @@ def __init__(self, config: MobileViTConfig, expand_output: bool = True, **kwargs if self.expand_output: self.conv_1x1_exp = TFMobileViTConvLayer( - config, - out_channels=config.neck_hidden_sizes[6], - kernel_size=1, - name="conv_1x1_exp", + config, out_channels=config.neck_hidden_sizes[6], kernel_size=1, name="conv_1x1_exp" ) self.pooler = tf.keras.layers.GlobalAveragePooling2D(data_format="channels_first", name="pooler") From 560d7ca56f4e43ade06e56aa47015f8bc2e9894a Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Thu, 1 Sep 2022 17:01:26 +0530 Subject: [PATCH 17/20] Update src/transformers/models/mobilevit/modeling_tf_mobilevit.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> --- src/transformers/models/mobilevit/modeling_tf_mobilevit.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/src/transformers/models/mobilevit/modeling_tf_mobilevit.py b/src/transformers/models/mobilevit/modeling_tf_mobilevit.py index 4f5d658e2633e..fb5632b78bc77 100644 --- a/src/transformers/models/mobilevit/modeling_tf_mobilevit.py +++ b/src/transformers/models/mobilevit/modeling_tf_mobilevit.py @@ -910,11 +910,7 @@ def call( output = (logits,) + outputs[2:] return ((loss,) + output) if loss is not None else output - return TFImageClassifierOutputWithNoAttention( - loss=loss, - logits=logits, - hidden_states=outputs.hidden_states, - ) + return TFImageClassifierOutputWithNoAttention(loss=loss, logits=logits, hidden_states=outputs.hidden_states) def serving_output(self, output: TFImageClassifierOutputWithNoAttention) -> TFImageClassifierOutputWithNoAttention: # hidden_states and attention not converted to Tensor with tf.convert_to_tensor as they are all of different dimensions From 127a0f1f37fc22968e5e130e80cdb38306dc9d93 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Thu, 1 Sep 2022 17:16:43 +0530 Subject: [PATCH 18/20] apply formatting. 
--- src/transformers/models/mobilevit/modeling_tf_mobilevit.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/mobilevit/modeling_tf_mobilevit.py b/src/transformers/models/mobilevit/modeling_tf_mobilevit.py index fb5632b78bc77..c54bd6554dc35 100644 --- a/src/transformers/models/mobilevit/modeling_tf_mobilevit.py +++ b/src/transformers/models/mobilevit/modeling_tf_mobilevit.py @@ -442,7 +442,7 @@ def unfolding(self, features: tf.Tensor) -> Tuple[tf.Tensor, Dict]: new_height = tf.cast(tf.math.ceil(orig_height / patch_height) * patch_height, "int32") new_width = tf.cast(tf.math.ceil(orig_width / patch_width) * patch_width, "int32") - interpolate = (new_width != orig_width or new_height != orig_height) + interpolate = new_width != orig_width or new_height != orig_height if interpolate: # Note: Padding can be done, but then it needs to be handled in attention function. features = tf.image.resize(features, size=(new_height, new_width), method="bilinear") From 43ce94d622b6caaa6a16fc7f177c69f17408b12c Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Thu, 1 Sep 2022 18:04:22 +0530 Subject: [PATCH 19/20] chore: remove comments from the example block. --- docs/source/en/model_doc/mobilevit.mdx | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/en/model_doc/mobilevit.mdx b/docs/source/en/model_doc/mobilevit.mdx index 89512f3f2381b..3ca027d53172c 100644 --- a/docs/source/en/model_doc/mobilevit.mdx +++ b/docs/source/en/model_doc/mobilevit.mdx @@ -42,8 +42,8 @@ Tips: converter = tf.lite.TFLiteConverter.from_keras_model(model) converter.optimizations = [tf.lite.Optimize.DEFAULT] converter.target_spec.supported_ops = [ - tf.lite.OpsSet.TFLITE_BUILTINS, # Enable TensorFlow Lite ops. - tf.lite.OpsSet.SELECT_TF_OPS, # Enable TensorFlow ops. + tf.lite.OpsSet.TFLITE_BUILTINS, + tf.lite.OpsSet.SELECT_TF_OPS, ] tflite_model = converter.convert() tflite_filename = model_ckpt.split("/")[-1] + ".tflite" From 9b0037054e9a008cef0c4ad3ab8bfc7de369430a Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Thu, 1 Sep 2022 18:50:23 +0530 Subject: [PATCH 20/20] remove identation in the example. 
--- docs/source/en/model_doc/mobilevit.mdx | 38 +++++++++++++------------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/docs/source/en/model_doc/mobilevit.mdx b/docs/source/en/model_doc/mobilevit.mdx index 3ca027d53172c..5725bd5ce5835 100644 --- a/docs/source/en/model_doc/mobilevit.mdx +++ b/docs/source/en/model_doc/mobilevit.mdx @@ -31,25 +31,25 @@ Tips: You can use the following code to convert a MobileViT checkpoint (be it image classification or semantic segmentation) to generate a TensorFlow Lite model: - ```py - from transformers import TFMobileViTForImageClassification - import tensorflow as tf - - - model_ckpt = "apple/mobilevit-xx-small" - model = TFMobileViTForImageClassification.from_pretrained(model_ckpt) - - converter = tf.lite.TFLiteConverter.from_keras_model(model) - converter.optimizations = [tf.lite.Optimize.DEFAULT] - converter.target_spec.supported_ops = [ - tf.lite.OpsSet.TFLITE_BUILTINS, - tf.lite.OpsSet.SELECT_TF_OPS, - ] - tflite_model = converter.convert() - tflite_filename = model_ckpt.split("/")[-1] + ".tflite" - with open(tflite_filename, "wb") as f: - f.write(tflite_model) - ``` +```py +from transformers import TFMobileViTForImageClassification +import tensorflow as tf + + +model_ckpt = "apple/mobilevit-xx-small" +model = TFMobileViTForImageClassification.from_pretrained(model_ckpt) + +converter = tf.lite.TFLiteConverter.from_keras_model(model) +converter.optimizations = [tf.lite.Optimize.DEFAULT] +converter.target_spec.supported_ops = [ + tf.lite.OpsSet.TFLITE_BUILTINS, + tf.lite.OpsSet.SELECT_TF_OPS, +] +tflite_model = converter.convert() +tflite_filename = model_ckpt.split("/")[-1] + ".tflite" +with open(tflite_filename, "wb") as f: + f.write(tflite_model) +``` The resulting model will be just **about an MB** making it a good fit for mobile applications where resources and network bandwidth can be constrained.
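
As a complement to the conversion snippet in the final patch, below is a minimal sketch of how the resulting `.tflite` file could be exercised with the TensorFlow Lite interpreter. It is not part of the patch series above: the filename (`mobilevit-xx-small.tflite`) is simply the one the snippet derives from `model_ckpt`, the input is random placeholder data rather than a real preprocessed image, and the interpreter from the full `tensorflow` package is assumed, since the conversion enables `SELECT_TF_OPS`.

```py
import numpy as np
import tensorflow as tf

# Filename as produced by the conversion snippet: model_ckpt.split("/")[-1] + ".tflite"
interpreter = tf.lite.Interpreter(model_path="mobilevit-xx-small.tflite")
interpreter.allocate_tensors()

input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# Random pixel values as a placeholder; in practice these would come from
# MobileViTFeatureExtractor applied to a real image.
dummy_input = np.random.rand(*input_details[0]["shape"]).astype(input_details[0]["dtype"])

interpreter.set_tensor(input_details[0]["index"], dummy_input)
interpreter.invoke()

# Classification logits (or segmentation logits for the DeepLabV3 checkpoints).
logits = interpreter.get_tensor(output_details[0]["index"])
print(logits.shape)
```

Reading the input shape and dtype from `input_details` keeps the sketch independent of the checkpoint's image size, so the same few lines should work for the classification and segmentation variants alike.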