Commit: TensorFlow MobileViT (huggingface#18555)
* initial implementation.

* add: working model up to image classification.

* add: initial implementation that passes integration tests.

Co-authored-by: Amy <aeroberts4444@gmail.com>

* chore: formatting.

* add: tests (still breaking because of config mismatch).

Co-authored-by: Yih <2521628+ydshieh@users.noreply.github.com>

* add: corrected tests and remaining changes.

* fix code style and repo consistency.

* address PR comments.

* address Amy's comments.

* chore: remove from_pt argument.

* chore: add full-stop.

* fix: TFLite model conversion in the doc.

* Update src/transformers/models/mobilevit/modeling_tf_mobilevit.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/models/mobilevit/modeling_tf_mobilevit.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/models/mobilevit/modeling_tf_mobilevit.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/models/mobilevit/modeling_tf_mobilevit.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/models/mobilevit/modeling_tf_mobilevit.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* apply formatting.

* chore: remove comments from the example block.

* remove indentation in the example.

Co-authored-by: Amy <aeroberts4444@gmail.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
3 people authored and oneraghavan committed Sep 26, 2022
1 parent 36fead3 commit 51e2bfa
Showing 10 changed files with 1,740 additions and 4 deletions.
2 changes: 1 addition & 1 deletion docs/source/en/index.mdx
@@ -259,7 +259,7 @@ Flax), PyTorch, and/or TensorFlow.
| mBART | ✅ | ✅ | ✅ | ✅ | ✅ |
| Megatron-BERT | ❌ | ❌ | ✅ | ❌ | ❌ |
| MobileBERT | ✅ | ✅ | ✅ | ✅ | ❌ |
| MobileViT | ❌ | ❌ | ✅ | ❌ | ❌ |
| MobileViT | ❌ | ❌ | ✅ | ✅ | ❌ |
| MPNet | ✅ | ✅ | ✅ | ✅ | ❌ |
| MT5 | ✅ | ✅ | ✅ | ✅ | ✅ |
| MVP | ✅ | ✅ | ✅ | ❌ | ❌ |
47 changes: 45 additions & 2 deletions docs/source/en/model_doc/mobilevit.mdx
@@ -22,12 +22,40 @@ The abstract from the paper is the following:

Tips:

- MobileViT is more like a CNN than a Transformer model. It does not work on sequence data but on batches of images. Unlike ViT, there are no embeddings. The backbone model outputs a feature map.
- MobileViT is more like a CNN than a Transformer model. It does not work on sequence data but on batches of images. Unlike ViT, there are no embeddings. The backbone model outputs a feature map. You can follow [this tutorial](https://keras.io/examples/vision/mobilevit) for a lightweight introduction.
- One can use [`MobileViTFeatureExtractor`] to prepare images for the model (see the short example after this list). Note that if you do your own preprocessing, the pretrained checkpoints expect images to be in BGR pixel order (not RGB).
- The available image classification checkpoints are pre-trained on [ImageNet-1k](https://huggingface.co/datasets/imagenet-1k) (also referred to as ILSVRC 2012, a collection of 1.3 million images and 1,000 classes).
- The segmentation model uses a [DeepLabV3](https://arxiv.org/abs/1706.05587) head. The available semantic segmentation checkpoints are pre-trained on [PASCAL VOC](http://host.robots.ox.ac.uk/pascal/VOC/).
- As the name suggests, MobileViT was designed to be performant and efficient on mobile phones. The TensorFlow versions of the MobileViT models are fully compatible with [TensorFlow Lite](https://www.tensorflow.org/lite).
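
A minimal sketch of preparing an image and running the TensorFlow classification model (the checkpoint and the test image URL here are illustrative choices, not prescribed by the docs):

```py
import tensorflow as tf
from PIL import Image
import requests
from transformers import MobileViTFeatureExtractor, TFMobileViTForImageClassification

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

feature_extractor = MobileViTFeatureExtractor.from_pretrained("apple/mobilevit-xx-small")
model = TFMobileViTForImageClassification.from_pretrained("apple/mobilevit-xx-small")

# The feature extractor handles resizing and the BGR channel flip mentioned above.
inputs = feature_extractor(images=image, return_tensors="tf")
outputs = model(**inputs)
predicted_class_idx = int(tf.argmax(outputs.logits, axis=-1)[0])
print(model.config.id2label[predicted_class_idx])
```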

This model was contributed by [matthijs](https://huggingface.co/Matthijs). The original code and weights can be found [here](https://github.com/apple/ml-cvnets).
You can use the following code to convert a MobileViT checkpoint (either image classification or semantic segmentation) to a TensorFlow Lite model:

```py
from transformers import TFMobileViTForImageClassification
import tensorflow as tf


model_ckpt = "apple/mobilevit-xx-small"
model = TFMobileViTForImageClassification.from_pretrained(model_ckpt)

converter = tf.lite.TFLiteConverter.from_keras_model(model)
# Enable the default optimizations (includes post-training quantization).
converter.optimizations = [tf.lite.Optimize.DEFAULT]
# Fall back to select TF ops for operations without TFLite builtins.
converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS,
    tf.lite.OpsSet.SELECT_TF_OPS,
]
tflite_model = converter.convert()
tflite_filename = model_ckpt.split("/")[-1] + ".tflite"
with open(tflite_filename, "wb") as f:
    f.write(tflite_model)
```

The resulting model is just **about one MB** in size, making it a good fit for mobile applications where resources and
network bandwidth can be constrained.
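
As a quick sanity check, here is a minimal sketch of running the converted model with the TFLite interpreter (the filename follows from the snippet above; a random tensor stands in for a real preprocessed image):

```py
import numpy as np
import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path="mobilevit-xx-small.tflite")
interpreter.allocate_tensors()

input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# Use the input shape the model itself reports rather than hard-coding it.
dummy_input = np.random.uniform(0, 1, size=input_details[0]["shape"]).astype(np.float32)
interpreter.set_tensor(input_details[0]["index"], dummy_input)
interpreter.invoke()

logits = interpreter.get_tensor(output_details[0]["index"])
print(logits.shape)
```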


This model was contributed by [matthijs](https://huggingface.co/Matthijs). The TensorFlow version of the model was contributed by [sayakpaul](https://huggingface.co/sayakpaul). The original code and weights can be found [here](https://github.com/apple/ml-cvnets).


## MobileViTConfig
@@ -53,3 +81,18 @@ This model was contributed by [matthijs](https://huggingface.co/Matthijs). The o

[[autodoc]] MobileViTForSemanticSegmentation
- forward

## TFMobileViTModel

[[autodoc]] TFMobileViTModel
- call

## TFMobileViTForImageClassification

[[autodoc]] TFMobileViTForImageClassification
- call

## TFMobileViTForSemanticSegmentation

[[autodoc]] TFMobileViTForSemanticSegmentation
- call
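
For reference, a hedged semantic segmentation sketch with the TensorFlow model (the `apple/deeplabv3-mobilevit-xx-small` checkpoint and the test image URL are illustrative choices):

```py
import tensorflow as tf
from PIL import Image
import requests
from transformers import MobileViTFeatureExtractor, TFMobileViTForSemanticSegmentation

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

feature_extractor = MobileViTFeatureExtractor.from_pretrained("apple/deeplabv3-mobilevit-xx-small")
model = TFMobileViTForSemanticSegmentation.from_pretrained("apple/deeplabv3-mobilevit-xx-small")

inputs = feature_extractor(images=image, return_tensors="tf")
outputs = model(**inputs)

# Logits come back as (batch_size, num_labels, height, width); the per-pixel
# class is the argmax over the label axis. Resize to the original image size if needed.
predicted_mask = tf.argmax(outputs.logits, axis=1)
```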
14 changes: 14 additions & 0 deletions src/transformers/__init__.py
@@ -2398,6 +2398,15 @@
"TFMobileBertPreTrainedModel",
]
)
_import_structure["models.mobilevit"].extend(
[
"TF_MOBILEVIT_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFMobileViTPreTrainedModel",
"TFMobileViTModel",
"TFMobileViTForImageClassification",
"TFMobileViTForSemanticSegmentation",
]
)
_import_structure["models.mpnet"].extend(
[
"TF_MPNET_PRETRAINED_MODEL_ARCHIVE_LIST",
@@ -4847,6 +4856,7 @@
        from .models.mbart import TFMBartForConditionalGeneration, TFMBartModel, TFMBartPreTrainedModel
        from .models.mobilebert import (
            TF_MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
            TFMobileBertForMaskedLM,
            TFMobileBertForMultipleChoice,
            TFMobileBertForNextSentencePrediction,
@@ -4857,6 +4867,10 @@
            TFMobileBertMainLayer,
            TFMobileBertModel,
            TFMobileBertPreTrainedModel,
        )
        from .models.mobilevit import (
            TF_MOBILEVIT_PRETRAINED_MODEL_ARCHIVE_LIST,
            TFMobileViTForImageClassification,
            TFMobileViTForSemanticSegmentation,
            TFMobileViTModel,
            TFMobileViTPreTrainedModel,
        )
        from .models.mpnet import (
            TF_MPNET_PRETRAINED_MODEL_ARCHIVE_LIST,
31 changes: 31 additions & 0 deletions src/transformers/modeling_tf_outputs.py
@@ -685,6 +685,37 @@ class TFSemanticSegmenterOutput(ModelOutput):
    attentions: Optional[Tuple[tf.Tensor]] = None


@dataclass
class TFSemanticSegmenterOutputWithNoAttention(ModelOutput):
    """
    Base class for outputs of semantic segmentation models that do not output attention scores.

    Args:
        loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Classification (or regression if config.num_labels==1) loss.
        logits (`tf.Tensor` of shape `(batch_size, config.num_labels, logits_height, logits_width)`):
            Classification scores for each pixel.

            <Tip warning={true}>

            The logits returned do not necessarily have the same size as the `pixel_values` passed as inputs. This is
            to avoid doing two interpolations and losing some quality when a user needs to resize the logits to the
            original image size as post-processing. You should always check your logits shape and resize as needed.

            </Tip>

        hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + one for
            the output of each layer) of shape `(batch_size, patch_size, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
    """

    loss: Optional[tf.Tensor] = None
    logits: tf.Tensor = None
    hidden_states: Optional[Tuple[tf.Tensor]] = None
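

# Illustrative post-processing sketch for the tip above, assuming `outputs`
# comes from a segmentation forward pass and `image` is the original PIL image.
# Logits arrive as (batch, num_labels, height, width); tf.image.resize expects
# channels-last, so transpose before upsampling and taking the per-pixel argmax:
#
#     logits = tf.transpose(outputs.logits, [0, 2, 3, 1])
#     upsampled = tf.image.resize(logits, size=image.size[::-1], method="bilinear")
#     predicted_mask = tf.argmax(upsampled, axis=-1)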


@dataclass
class TFImageClassifierOutput(ModelOutput):
"""
3 changes: 3 additions & 0 deletions src/transformers/models/auto/modeling_tf_auto.py
@@ -59,6 +59,7 @@
("marian", "TFMarianModel"),
("mbart", "TFMBartModel"),
("mobilebert", "TFMobileBertModel"),
("mobilevit", "TFMobileViTModel"),
("mpnet", "TFMPNetModel"),
("mt5", "TFMT5Model"),
("openai-gpt", "TFOpenAIGPTModel"),
@@ -182,6 +183,7 @@
("convnext", "TFConvNextForImageClassification"),
("data2vec-vision", "TFData2VecVisionForImageClassification"),
("deit", ("TFDeiTForImageClassification", "TFDeiTForImageClassificationWithTeacher")),
("mobilevit", "TFMobileViTForImageClassification"),
("regnet", "TFRegNetForImageClassification"),
("resnet", "TFResNetForImageClassification"),
("segformer", "TFSegformerForImageClassification"),
@@ -194,6 +196,7 @@
    [
        # Model for Semantic Segmentation mapping
        ("data2vec-vision", "TFData2VecVisionForSemanticSegmentation"),
        ("mobilevit", "TFMobileViTForSemanticSegmentation"),
        ("segformer", "TFSegformerForSemanticSegmentation"),
    ]
)
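
# With the mappings above in place, the TF auto classes resolve the "mobilevit"
# model type to the new TF classes. An illustrative usage sketch:
#
#     from transformers import TFAutoModelForImageClassification
#
#     model = TFAutoModelForImageClassification.from_pretrained("apple/mobilevit-xx-small")
#     type(model).__name__  # "TFMobileViTForImageClassification"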
43 changes: 42 additions & 1 deletion src/transformers/models/mobilevit/__init__.py
@@ -17,7 +17,13 @@
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available
from ...utils import (
    OptionalDependencyNotAvailable,
    _LazyModule,
    is_tf_available,
    is_torch_available,
    is_vision_available,
)


_import_structure = {
@@ -46,6 +52,19 @@
"MobileViTPreTrainedModel",
]

try:
    if not is_tf_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["modeling_tf_mobilevit"] = [
        "TF_MOBILEVIT_PRETRAINED_MODEL_ARCHIVE_LIST",
        "TFMobileViTForImageClassification",
        "TFMobileViTForSemanticSegmentation",
        "TFMobileViTModel",
        "TFMobileViTPreTrainedModel",
    ]

if TYPE_CHECKING:
    from .configuration_mobilevit import MOBILEVIT_PRETRAINED_CONFIG_ARCHIVE_MAP, MobileViTConfig, MobileViTOnnxConfig
@@ -72,6 +91,28 @@
            MobileViTPreTrainedModel,
        )

    try:
        if not is_vision_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .feature_extraction_mobilevit import MobileViTFeatureExtractor

    try:
        if not is_tf_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .modeling_tf_mobilevit import (
            TF_MOBILEVIT_PRETRAINED_MODEL_ARCHIVE_LIST,
            TFMobileViTForImageClassification,
            TFMobileViTForSemanticSegmentation,
            TFMobileViTModel,
            TFMobileViTPreTrainedModel,
        )


else:
    import sys
