From c1489f500e06dea5dbb7370a00f13f5bc0e6d94a Mon Sep 17 00:00:00 2001
From: Amanpreet Singh
Date: Thu, 7 Apr 2022 08:37:08 -0700
Subject: [PATCH] [WIP] Add FLAVA model

This PR aims to add the [FLAVA](https://arxiv.org/abs/2112.04482) model to the transformers repo. The following
checklist delineates what needs to be done for this PR to be complete:

[x] Flava init
[x] Flava base models
[x] Flava layers
[x] Flava Configs
[x] Flava encoders
[x] Flava pretraining models
[ ] Flava classification/retrieval models (in progress)
[x] Documentation updates (in progress)
[x] Imports updates (in progress)
[x] Argstring updates
[x] Flava pretrained checkpoints (in progress)
[ ] Flava tests
[x] Flava processors (in progress)
[x] Sanity check
[x] Lint
---
 docs/source/en/_toctree.yml | 2 +
 docs/source/en/model_doc/flava.mdx | 93 +
 src/transformers/__init__.py | 43 +
 src/transformers/models/__init__.py | 1 +
 .../models/auto/configuration_auto.py | 3 +
 .../models/auto/feature_extraction_auto.py | 1 +
 src/transformers/models/auto/modeling_auto.py | 1 +
 .../models/auto/processing_auto.py | 1 +
 .../models/auto/tokenization_auto.py | 7 +
 src/transformers/models/flava/__init__.py | 77 +
 .../models/flava/configuration_flava.py | 599 +++++
 .../flava/convert_dalle_to_flava_codebook.py | 91 +
 .../convert_flava_original_pytorch_to_hf.py | 92 +
 .../models/flava/feature_extraction_flava.py | 366 +++
 .../models/flava/modeling_flava.py | 2044 +++++++++++++++++
 .../models/flava/processing_flava.py | 148 ++
 tests/flava/__init__.py | 0
 tests/flava/test_feature_extraction_flava.py | 414 ++++
 tests/flava/test_modeling_flava.py | 824 +++++++
 tests/flava/test_processor_flava.py | 204 ++
 20 files changed, 5011 insertions(+)
 create mode 100644 docs/source/en/model_doc/flava.mdx
 create mode 100644 src/transformers/models/flava/__init__.py
 create mode 100644 src/transformers/models/flava/configuration_flava.py
 create mode 100644 src/transformers/models/flava/convert_dalle_to_flava_codebook.py
 create mode 100644 src/transformers/models/flava/convert_flava_original_pytorch_to_hf.py
 create mode 100644 src/transformers/models/flava/feature_extraction_flava.py
 create mode 100644 src/transformers/models/flava/modeling_flava.py
 create mode 100644 src/transformers/models/flava/processing_flava.py
 create mode 100644 tests/flava/__init__.py
 create mode 100644 tests/flava/test_feature_extraction_flava.py
 create mode 100644 tests/flava/test_modeling_flava.py
 create mode 100644 tests/flava/test_processor_flava.py

diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml
index 69717477e1f230..90ce8c8d4f861c 100644
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -214,6 +214,8 @@
     title: Encoder Decoder Models
   - local: model_doc/flaubert
     title: FlauBERT
+  - local: model_doc/flava
+    title: FLAVA
   - local: model_doc/fnet
     title: FNet
   - local: model_doc/fsmt
diff --git a/docs/source/en/model_doc/flava.mdx b/docs/source/en/model_doc/flava.mdx
new file mode 100644
index 00000000000000..2c9478a77dc1e7
--- /dev/null
+++ b/docs/source/en/model_doc/flava.mdx
@@ -0,0 +1,93 @@
+
+
+# FLAVA
+
+## Overview
+
+The FLAVA model was proposed in [FLAVA: A Foundational Language And Vision Alignment Model](https://arxiv.org/abs/2112.04482)
+by Amanpreet Singh, Ronghang Hu, Vedanuj Goswami, Guillaume Couairon, Wojciech Galuba, Marcus Rohrbach, and Douwe Kiela,
+and was accepted at CVPR 2022.
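+
+A minimal usage sketch is shown below. It assumes that the `aps/flava-full` checkpoint referenced in this PR ships
+both the processor and model files, and that `FLAVAProcessor`/`FLAVAModel` follow the usual `transformers`
+processor/model calling conventions; the returned fields mirror `FLAVAModelOutput`.
+
+```python
+from PIL import Image
+import requests
+
+from transformers import FLAVAModel, FLAVAProcessor
+
+model = FLAVAModel.from_pretrained("aps/flava-full")
+processor = FLAVAProcessor.from_pretrained("aps/flava-full")
+
+url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+image = Image.open(requests.get(url, stream=True).raw)
+
+inputs = processor(text=["a photo of two cats"], images=[image], return_tensors="pt", padding=True)
+outputs = model(**inputs)
+
+image_embeddings = outputs.image_embeddings  # pooled image features
+text_embeddings = outputs.text_embeddings  # pooled text features
+multimodal_embeddings = outputs.multimodal_embeddings  # fused vision-language features
+```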
+ +The paper aims at creating a single unified foundation model which can work across vision, language +as well as vision-and-language multimodal tasks. + +The abstract from the paper is the following: + +State-of-the-art vision and vision-and-language models rely on large-scale visio-linguistic pretraining for obtaining good performance on a variety +of downstream tasks. Generally, such models are often either cross-modal (contrastive) or multi-modal +(with earlier fusion) but not both; and they often only target specific modalities or tasks. A promising +direction would be to use a single holistic universal model, as a "foundation", that targets all modalities +at once -- a true vision and language foundation model should be good at vision tasks, language tasks, and +cross- and multi-modal vision and language tasks. We introduce FLAVA as such a model and demonstrate +impressive performance on a wide range of 35 tasks spanning these target modalities. + + + + +This model was contributed by [aps](https://huggingface.co/aps). + + + +## FLAVAConfig + +[[autodoc]] FLAVAConfig + - from_configs + +## FLAVATextConfig + +[[autodoc]] FLAVATextConfig + +## FLAVAImageConfig + +[[autodoc]] FLAVAImageConfig + +## FLAVAMultimodalConfig + +[[autodoc]] FLAVAMultimodalConfig + +## FLAVACodebookConfig + +[[autodoc]] FLAVACodebookConfig + +## FLAVAForPretraining + +[[autodoc]] FLAVAForPretraining + - forward + +## FLAVAModel + +[[autodoc]] FLAVAModel + - forward + - get_text_features + - get_image_features + +## FLAVACodebook + +[[autodoc]] FLAVACodebook + - forward + - get_codebook_indices + - get_codebook_probs + +## FLAVATextModel + +[[autodoc]] FLAVATextModel + - forward + +## FLAVAImageModel + +[[autodoc]] FLAVAImageModel + - forward + +## FLAVAMultimodalModel + +[[autodoc]] FLAVAMultimodalModel + - forward diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 841b830c3b1db7..4b03299b61f7e1 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -191,6 +191,17 @@ "models.electra": ["ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP", "ElectraConfig", "ElectraTokenizer"], "models.encoder_decoder": ["EncoderDecoderConfig"], "models.flaubert": ["FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "FlaubertConfig", "FlaubertTokenizer"], + "models.flava": [ + "FLAVA_PRETRAINED_CONFIG_ARCHIVE_MAP", + "FLAVACodebookConfig", + "FLAVACodebookFeatureExtractor", + "FLAVAConfig", + "FLAVAFeatureExtractor", + "FLAVAImageConfig", + "FLAVAMultimodalConfig", + "FLAVAProcessor", + "FLAVATextConfig", + ], "models.fnet": ["FNET_PRETRAINED_CONFIG_ARCHIVE_MAP", "FNetConfig", "FNetTokenizer"], "models.fsmt": ["FSMT_PRETRAINED_CONFIG_ARCHIVE_MAP", "FSMTConfig", "FSMTTokenizer"], "models.funnel": ["FUNNEL_PRETRAINED_CONFIG_ARCHIVE_MAP", "FunnelConfig", "FunnelTokenizer"], @@ -986,6 +997,19 @@ "FlaubertWithLMHeadModel", ] ) + _import_structure["models.flava"].extend( + [ + "FLAVA_PRETRAINED_MODEL_ARCHIVE_LIST", + "FLAVACodebook", + "FLAVAForPretraining", + "FLAVAImageModel", + "FLAVALayer", + "FLAVAModel", + "FLAVAMultimodalModel", + "FLAVAPreTrainedModel", + "FLAVATextModel", + ] + ) _import_structure["models.fnet"].extend( [ "FNET_PRETRAINED_MODEL_ARCHIVE_LIST", @@ -2565,6 +2589,17 @@ from .models.electra import ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP, ElectraConfig, ElectraTokenizer from .models.encoder_decoder import EncoderDecoderConfig from .models.flaubert import FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, FlaubertConfig, FlaubertTokenizer + from .models.flava import ( + 
FLAVA_PRETRAINED_CONFIG_ARCHIVE_MAP, + FLAVACodebookConfig, + FLAVACodebookFeatureExtractor, + FLAVAConfig, + FLAVAFeatureExtractor, + FLAVAImageConfig, + FLAVAMultimodalConfig, + FLAVAProcessor, + FLAVATextConfig, + ) from .models.fnet import FNET_PRETRAINED_CONFIG_ARCHIVE_MAP, FNetConfig, FNetTokenizer from .models.fsmt import FSMT_PRETRAINED_CONFIG_ARCHIVE_MAP, FSMTConfig, FSMTTokenizer from .models.funnel import FUNNEL_PRETRAINED_CONFIG_ARCHIVE_MAP, FunnelConfig, FunnelTokenizer @@ -3238,6 +3273,14 @@ FlaubertModel, FlaubertWithLMHeadModel, ) + from .models.flava import ( + FLAVA_PRETRAINED_MODEL_ARCHIVE_LIST, + FLAVAForPretraining, + FLAVAImageModel, + FLAVAModel, + FLAVAPreTrainedModel, + FLAVATextModel, + ) from .models.fnet import ( FNET_PRETRAINED_MODEL_ARCHIVE_LIST, FNetForMaskedLM, diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index 7045f18c556154..1432fd61cc545c 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -54,6 +54,7 @@ electra, encoder_decoder, flaubert, + flava, fnet, fsmt, funnel, diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 1899c3c249fafa..9eb08492768218 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -62,6 +62,7 @@ ("canine", "CanineConfig"), ("roformer", "RoFormerConfig"), ("clip", "CLIPConfig"), + ("flava", "FLAVAConfig"), ("bigbird_pegasus", "BigBirdPegasusConfig"), ("deit", "DeiTConfig"), ("luke", "LukeConfig"), @@ -164,6 +165,7 @@ ("canine", "CANINE_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("roformer", "ROFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("clip", "CLIP_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("flava", "FLAVA_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("bigbird_pegasus", "BIGBIRD_PEGASUS_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("deit", "DEIT_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("luke", "LUKE_PRETRAINED_CONFIG_ARCHIVE_MAP"), @@ -258,6 +260,7 @@ ("canine", "Canine"), ("roformer", "RoFormer"), ("clip", "CLIP"), + ("flava", "flava"), ("bigbird_pegasus", "BigBirdPegasus"), ("deit", "DeiT"), ("luke", "LUKE"), diff --git a/src/transformers/models/auto/feature_extraction_auto.py b/src/transformers/models/auto/feature_extraction_auto.py index dc83fc133fad97..3221e1b286430e 100644 --- a/src/transformers/models/auto/feature_extraction_auto.py +++ b/src/transformers/models/auto/feature_extraction_auto.py @@ -47,6 +47,7 @@ ("detr", "DetrFeatureExtractor"), ("layoutlmv2", "LayoutLMv2FeatureExtractor"), ("clip", "CLIPFeatureExtractor"), + ("flava", "FLAVAFeatureExtractor"), ("perceiver", "PerceiverFeatureExtractor"), ("swin", "ViTFeatureExtractor"), ("vit_mae", "ViTFeatureExtractor"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index b0cfb47672491a..c21dfd1ed974ec 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -59,6 +59,7 @@ ("canine", "CanineModel"), ("roformer", "RoFormerModel"), ("clip", "CLIPModel"), + ("flava", "FLAVAModel"), ("bigbird_pegasus", "BigBirdPegasusModel"), ("deit", "DeiTModel"), ("luke", "LukeModel"), diff --git a/src/transformers/models/auto/processing_auto.py b/src/transformers/models/auto/processing_auto.py index b51ef9ef312e10..23ec7c400269ee 100644 --- a/src/transformers/models/auto/processing_auto.py +++ b/src/transformers/models/auto/processing_auto.py @@ -38,6 +38,7 @@ PROCESSOR_MAPPING_NAMES = OrderedDict( [ 
("clip", "CLIPProcessor"), + ("flava", "FLAVAProcessor"), ("layoutlmv2", "LayoutLMv2Processor"), ("layoutxlm", "LayoutXLMProcessor"), ("speech_to_text", "Speech2TextProcessor"), diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py index a604841340967d..f1ad8bb4b407e6 100644 --- a/src/transformers/models/auto/tokenization_auto.py +++ b/src/transformers/models/auto/tokenization_auto.py @@ -218,6 +218,13 @@ "CLIPTokenizerFast" if is_tokenizers_available() else None, ), ), + # ( + # "flava", + # ( + # "CLIPTokenizer", + # "CLIPTokenizerFast" if is_tokenizers_available() else None, + # ), + # ), ("wav2vec2_phoneme", ("Wav2Vec2PhonemeCTCTokenizer", None)), ( "perceiver", diff --git a/src/transformers/models/flava/__init__.py b/src/transformers/models/flava/__init__.py new file mode 100644 index 00000000000000..7027ae3798e82f --- /dev/null +++ b/src/transformers/models/flava/__init__.py @@ -0,0 +1,77 @@ +# flake8: noqa +# There's no way to ignore "F401 '...' imported but unused" warnings in this +# module, but to preserve other warnings. So, don't check this module at all. + +# Copyright 2021 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule, is_torch_available, is_vision_available + + +_import_structure = { + "configuration_flava": [ + "FLAVA_PRETRAINED_CONFIG_ARCHIVE_MAP", + "FLAVACodebookConfig", + "FLAVAConfig", + "FLAVAImageConfig", + "FLAVAMultimodalConfig", + "FLAVATextConfig", + ], +} + +if is_vision_available(): + _import_structure["feature_extraction_flava"] = ["FLAVACodebookFeatureExtractor", "FLAVAFeatureExtractor"] + _import_structure["processing_flava"] = ["FLAVAProcessor"] + +if is_torch_available(): + _import_structure["modeling_flava"] = [ + "FLAVA_PRETRAINED_MODEL_ARCHIVE_LIST", + "FLAVACodebook", + "FLAVAForPretraining", + "FLAVAImageModel", + "FLAVAModel", + "FLAVAMultimodalModel", + "FLAVAPreTrainedModel", + "FLAVATextModel", + ] + +if TYPE_CHECKING: + from .configuration_flava import ( + FLAVA_PRETRAINED_CONFIG_ARCHIVE_MAP, + FLAVAConfig, + FLAVAImageConfig, + FLAVAMultimodalConfig, + FLAVATextConfig, + ) + + if is_vision_available(): + from .feature_extraction_flava import FLAVACodebookFeatureExtractor, FLAVAFeatureExtractor + from .processing_flava import FLAVAProcessor + + if is_torch_available(): + from .modeling_flava import ( + FLAVA_PRETRAINED_MODEL_ARCHIVE_LIST, + FLAVACodebook, + FLAVAImageModel, + FLAVAModel, + FLAVAMultimodalModel, + FLAVAPreTrainedModel, + FLAVATextModel, + ) + +else: + import sys + + sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) diff --git a/src/transformers/models/flava/configuration_flava.py b/src/transformers/models/flava/configuration_flava.py new file mode 100644 index 00000000000000..179f1d4421beb7 --- /dev/null +++ b/src/transformers/models/flava/configuration_flava.py @@ -0,0 +1,599 @@ +# coding=utf-8 +# 
Copyright 2021 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" FLAVA model configuration""" + +import copy +import os +from typing import Union + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + +FLAVA_PRETRAINED_CONFIG_ARCHIVE_MAP = { + "flava-full": "https://huggingface.co/aps/flava-full/resolve/main/config.json", +} + + +class FLAVAImageConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`FLAVAImageModel`]. It is used to instantiate an + FLAVA model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the FLAVA + [full](https://huggingface.co/aps/flava-full) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + intermediate_size (`int`, *optional*, defaults to 3072): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` are supported. + hidden_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout probabilitiy for all fully connected layers in the embeddings, encoder, and pooler. + attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout ratio for the attention probabilities. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layer_norm_eps (`float`, *optional*, defaults to 1e-12): + The epsilon used by the layer normalization layers. + image_size (`int`, *optional*, defaults to `224`): + The size (resolution) of each image. + patch_size (`int`, *optional*, defaults to `16`): + The size (resolution) of each patch. + num_channels (`int`, *optional*, defaults to `3`): + The number of input channels. + qkv_bias (`bool`, *optional*, defaults to `True`): + Whether to add a bias to the queries, keys and values. + mask_token (`bool`, *optional*, defaults to True): + Whether to use a mask token or not. Used in MIM loss. + vocab_size (`int`, *optional*, defaults to 8192): + Vocabulary size of the [`FLAVACodebook`] used in conjunction with [`FLAVAImageModel`] for MIM. 
+ + Example: + + ```python + >>> from transformers import FLAVAImageModel, FLAVAImageConfig + + >>> # Initializing a FLAVAImageModel with style configuration + >>> configuration = FLAVAImageConfig() + + >>> # Initializing a FLAVAImageModel model from the style configuration + >>> model = FLAVAImageModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "flava_image_model" + + def __init__( + self, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + hidden_act="gelu", + hidden_dropout_prob=0.0, + attention_probs_dropout_prob=0.0, + initializer_range=0.02, + layer_norm_eps=1e-12, + image_size=224, + patch_size=16, + num_channels=3, + qkv_bias=True, + mask_token=True, + vocab_size=8192, + **kwargs + ): + super().__init__(**kwargs) + + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.qkv_bias = qkv_bias + self.mask_token = mask_token + self.vocab_size = vocab_size + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig": + + config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) + + # get the image config dict if we are loading from FLAVAConfig + if config_dict.get("model_type") == "flava": + config_dict = config_dict["image_config"] + + if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type: + logger.warning( + f"You are using a model of type {config_dict['model_type']} to instantiate a model of type " + f"{cls.model_type}. This is not supported for all configurations of models and can yield errors." + ) + + return cls.from_dict(config_dict, **kwargs) + + +class FLAVATextConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`FLAVATextModel`]. It is used to instantiate an + FLAVA model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the FLAVA + [full](https://huggingface.co/aps/flava-full) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 30522): + Vocabulary size of the BERT model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`FLAVATextModel`]. + type_vocab_size (`int`, *optional*, defaults to 2): + The vocabulary size of the `token_type_ids` passed when calling [`FLAVATextModel`]. Note that even though + text encoder allows `token_type_ids`'s value as 2, for text-only pretraining and fine-tuning, only 1 is + used similar to RoBERTa. + max_position_embeddings (`int`, *optional*, defaults to 512): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). 
+            Note that for VL (vision-language) tasks, a `max_length` of 77 is passed to the model.
+        position_embedding_type (`str`, *optional*, defaults to `"absolute"`):
+            Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For
+            positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to
+            [Self-Attention with Relative Position Representations (Shaw et al.)](https://arxiv.org/abs/1803.02155).
+            For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models
+            with Better Relative Position Embeddings (Huang et al.)](https://arxiv.org/abs/2009.13658).
+        hidden_size (`int`, *optional*, defaults to 768):
+            Dimensionality of the encoder layers and the pooler layer.
+        num_hidden_layers (`int`, *optional*, defaults to 12):
+            Number of hidden layers in the Transformer encoder.
+        num_attention_heads (`int`, *optional*, defaults to 12):
+            Number of attention heads for each attention layer in the Transformer encoder.
+        intermediate_size (`int`, *optional*, defaults to 3072):
+            Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
+        hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
+            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+            `"relu"`, `"selu"` and `"gelu_new"` are supported.
+        hidden_dropout_prob (`float`, *optional*, defaults to 0.0):
+            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+        attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        layer_norm_eps (`float`, *optional*, defaults to 1e-12):
+            The epsilon used by the layer normalization layers.
+        qkv_bias (`bool`, *optional*, defaults to `True`):
+            Whether to add a bias to the queries, keys and values.
+ + Example: + + ```python + >>> from transformers import FLAVATextModel, FLAVATextConfig + + >>> # Initializing a FLAVATextModel with style configuration + >>> configuration = FLAVATextConfig() + + >>> # Initializing a FLAVATextConfig from the style configuration + >>> model = FLAVATextModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + model_type = "flava_text_model" + + def __init__( + self, + vocab_size=30522, + type_vocab_size=2, + max_position_embeddings=512, + position_embedding_type="absolute", + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + hidden_act="gelu", + hidden_dropout_prob=0.0, + attention_probs_dropout_prob=0.0, + initializer_range=0.02, + layer_norm_eps=1e-12, + pad_token_id=0, + qkv_bias=True, + **kwargs + ): + super().__init__(**kwargs) + + self.vocab_size = vocab_size + self.type_vocab_size = type_vocab_size + self.max_position_embeddings = max_position_embeddings + self.position_embedding_type = position_embedding_type + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + self.qkv_bias = qkv_bias + self.pad_token_id = pad_token_id + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig": + + config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) + + # get the text config dict if we are loading from FLAVAConfig + if config_dict.get("model_type") == "flava": + config_dict = config_dict["text_config"] + + if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type: + logger.warning( + f"You are using a model of type {config_dict['model_type']} to instantiate a model of type " + f"{cls.model_type}. This is not supported for all configurations of models and can yield errors." + ) + + return cls.from_dict(config_dict, **kwargs) + + +class FLAVAMultimodalConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`FLAVAMultimodalModel`]. It is used to instantiate + an FLAVA model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the FLAVA + [full](https://huggingface.co/aps/flava-full) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + intermediate_size (`int`, *optional*, defaults to 3072): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. 
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` are supported. + hidden_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout probabilitiy for all fully connected layers in the embeddings, encoder, and pooler. + attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout ratio for the attention probabilities. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layer_norm_eps (`float`, *optional*, defaults to 1e-12): + The epsilon used by the layer normalization layers. + qkv_bias (`bool`, *optional*, defaults to `True`): + Whether to add a bias to the queries, keys and values. + use_cls_token (`bool`, *optional*, defaults to `True`): + Whether to use an extra CLS token for multimodal settings. Usually needed by the FLAVA model. + + + Example: + + ```python + >>> from transformers import FLAVAMultimodalModel, FLAVAMultimodalConfig + + >>> # Initializing a FLAVAMultimodalModel with style configuration + >>> configuration = FLAVAMultimodalConfig() + + >>> # Initializing a FLAVAMultimodalModel model from the style configuration + >>> model = FLAVAMultimodalModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "flava_multimodal_model" + + def __init__( + self, + hidden_size=768, + num_hidden_layers=6, + num_attention_heads=12, + intermediate_size=3072, + hidden_act="gelu", + hidden_dropout_prob=0.0, + attention_probs_dropout_prob=0.0, + initializer_range=0.02, + layer_norm_eps=1e-12, + qkv_bias=True, + use_cls_token=True, + **kwargs + ): + super().__init__(**kwargs) + + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + self.qkv_bias = qkv_bias + self.use_cls_token = use_cls_token + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig": + config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) + + # get the image config dict if we are loading from FLAVAConfig + if config_dict.get("model_type") == "flava": + config_dict = config_dict["multimodal_config"] + + if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type: + logger.warning( + f"You are using a model of type {config_dict['model_type']} to instantiate a model of type " + f"{cls.model_type}. This is not supported for all configurations of models and can yield errors." + ) + + return cls.from_dict(config_dict, **kwargs) + + +class FLAVACodebookConfig(PretrainedConfig): + model_type = "flava_codebook" + + r""" + [`FLAVACodebookConfig`] is the configuration class to store the configuration of a [`FLAVACodebook`]. It is used to + instantiate an FLAVA model according to the specified arguments, defining the model architecture. 
+    Instantiating a configuration with the defaults will yield a similar configuration to that of the FLAVA
+    [codebook](https://huggingface.co/aps/flava-codebook) architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+    Args:
+        num_groups (`int`, defaults to 4):
+            Number of groups to be created. Currently, this parameter does not affect the model; it is only used for
+            some internal calculations and estimates.
+        input_channels (`int`, defaults to 3):
+            Number of channels in the input image.
+        num_blocks_per_group (`int`, defaults to 2):
+            Number of conv-based blocks per group.
+        hidden_size (`int`, defaults to 256):
+            Size of the hidden dimension used in the blocks.
+        vocab_size (`int`, defaults to 8192):
+            Size of the output vocabulary for the codebook.
+        freeze (`bool`, defaults to `True`):
+            Whether to freeze the weights of the model.
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        kwargs (*optional*):
+            Dictionary of keyword arguments.
+
+    Example:
+
+    ```python
+    >>> from transformers import FLAVACodebook, FLAVACodebookConfig
+
+    >>> # Initializing a FLAVACodebook with style configuration
+    >>> configuration = FLAVACodebookConfig()
+
+    >>> # Initializing a FLAVACodebook model from the style configuration
+    >>> model = FLAVACodebook(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```"""
+
+    def __init__(
+        self,
+        num_groups=4,
+        input_channels=3,
+        num_blocks_per_group=2,
+        hidden_size=256,
+        vocab_size=8192,
+        freeze=True,
+        initializer_range=0.02,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.num_groups = num_groups
+        self.input_channels = input_channels
+        self.num_blocks_per_group = num_blocks_per_group
+        self.hidden_size = hidden_size
+        self.vocab_size = vocab_size
+        self.freeze = freeze
+        self.initializer_range = initializer_range
+
+
+class FLAVAConfig(PretrainedConfig):
+    r"""
+    [`FLAVAConfig`] is the configuration class to store the configuration of a [`FLAVAModel`]. It is used to
+    instantiate a FLAVA model according to the specified arguments, defining the text model, image model and
+    multimodal model configs.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+    Args:
+        text_config_dict (`dict`, *optional*):
+            Dictionary of configuration options used to initialize [`FLAVATextConfig`].
+        image_config_dict (`dict`, *optional*):
+            Dictionary of configuration options used to initialize [`FLAVAImageConfig`].
+        multimodal_config_dict (`dict`, *optional*):
+            Dictionary of configuration options used to initialize [`FLAVAMultimodalConfig`].
+        hidden_size (`int`, *optional*, defaults to 768):
+            Dimensionality of the encoder layers and the pooler layer.
+        layer_norm_eps (`float`, *optional*, defaults to 1e-12):
+            The epsilon used by the layer normalization layers.
+        projection_dim (`int`, *optional*, defaults to 768):
+            Dimensionality of the text and image projection layers.
+        logit_scale_init_value (`float`, *optional*, defaults to 2.6592):
+            The initial value of the *logit_scale* parameter. The default is used as per the original FLAVA/CLIP
+            implementation.
+ initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + ce_ignore_index (`int`, *optional*, defaults to -100): + Cross entropy index to ignore. + mim_weight (`float`, *optional*, defaults to 1.0): + Weight to be assigned to MIM unimodal loss + mlm_weight (`float`, *optional*, defaults to 1.0): + Weight to be assigned to MLM unimodal loss + global_contrastive_weight (`float`, *optional*, defaults to 1.0): + Weight to be assigned to global contrastive cross-alignment loss. + itm_weight (`float`, *optional*, defaults to 1.0): + Weight to be assigned to image-text matching multimodal loss. + mmm_image_weight (`float`, *optional*, defaults to 1.0): + Weight to be assigned to MMM loss's image part. + mmm_text_weight (`float`, *optional*, defaults to 1.0): + Weight to be assigned to MMM loss's text part. + global_backprop_contrastive (`bool`, *optional*, defaults to True): + Whether to use global backpropgation through all workers in contrastive loss. + skip_unmasked_multimodal_encoder (`bool`, *optional*, defaults to True): + Whether to skip running unmasked multimodal encoder whose outputs are not used by FLAVA losses. + + kwargs (*optional*): + Dictionary of keyword arguments. + + Example: + + ```python + >>> from transformers import FLAVAModel, FLAVAForPretraining, FLAVAConfig + + >>> # Initializing a FLAVAConfig with style configuration + >>> configuration = FLAVAConfig() + + >>> # Initializing a FLAVAModel and FLAVAForPretraining model from the style configuration + >>> model = FLAVAModel(configuration) + >>> model_pre = FLAVAForPretraining(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + >>> configuration_pre = model_pre.config + ``` + """ + + model_type = "flava" + is_composition = True + + def __init__( + self, + text_config_dict=None, + image_config_dict=None, + multimodal_config_dict=None, + hidden_size=768, + layer_norm_eps=1e-12, + projection_dim=768, + logit_scale_init_value=2.6592, + initializer_range=0.02, + ce_ignore_index=-100, + mim_weight=1.0, + mlm_weight=1.0, + global_contrastive_weight=1.0, + itm_weight=1.0, + mmm_image_weight=1.0, + mmm_text_weight=1.0, + global_backprop_contrastive=True, + skip_unmasked_multimodal_encoder=True, + **kwargs + ): + super().__init__( + text_config_dict=text_config_dict, + image_config_dict=image_config_dict, + multimodal_config_dict=multimodal_config_dict, + **kwargs, + ) + + if text_config_dict is None: + text_config_dict = {} + logger.info("text_config_dict is None. Initializing the FLAVATextConfig with default values.") + + if image_config_dict is None: + image_config_dict = {} + logger.info("image_config_dict is None. initializing the FLAVAImageConfig with default values.") + + if multimodal_config_dict is None: + multimodal_config_dict = {} + logger.info("multimodal_config_dict is None. 
initializing the FLAVAImageConfig with default values.") + + self.text_config = FLAVATextConfig(**text_config_dict) + self.image_config = FLAVAImageConfig(**image_config_dict) + self.multimodal_config = FLAVAMultimodalConfig(**multimodal_config_dict) + self.projection_dim = projection_dim + + self.hidden_size = hidden_size + self.layer_norm_eps = layer_norm_eps + self.initializer_range = initializer_range + self.logit_scale_init_value = logit_scale_init_value + self.initializer_factor = 1.0 + self.ce_ignore_index = ce_ignore_index + self.mim_weight = mim_weight + self.mlm_weight = mlm_weight + self.global_contrastive_weight = global_contrastive_weight + self.itm_weight = itm_weight + self.mmm_image_weight = mmm_image_weight + self.mmm_text_weight = mmm_text_weight + self.global_backprop_contrastive = global_backprop_contrastive + self.skip_unmasked_multimodal_encoder = skip_unmasked_multimodal_encoder + + @classmethod + def from_configs( + cls, + text_config: FLAVATextConfig, + image_config: FLAVAImageConfig, + multimodal_config: FLAVAMultimodalConfig, + **kwargs + ): + r""" + Instantiate a [`FLAVAConfig`] (or a derived class) from flava text model configuration, flava image model + configuration and flava multimodal model configuration. + + Returns: + [`FLAVAConfig`]: An instance of a configuration object + """ + + return cls( + text_config_dict=text_config.to_dict(), + image_config_dict=image_config.to_dict(), + multimodal_config_dict=multimodal_config.to_dict(), + **kwargs, + ) + + def to_dict(self): + """ + Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`]. + + Returns: + `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance, + """ + output = copy.deepcopy(self.__dict__) + output["text_config"] = self.text_config.to_dict() + output["image_config"] = self.image_config.to_dict() + output["multimodal_config"] = self.multimodal_config.to_dict() + output["model_type"] = self.__class__.model_type + return output diff --git a/src/transformers/models/flava/convert_dalle_to_flava_codebook.py b/src/transformers/models/flava/convert_dalle_to_flava_codebook.py new file mode 100644 index 00000000000000..59f54a21eb6150 --- /dev/null +++ b/src/transformers/models/flava/convert_dalle_to_flava_codebook.py @@ -0,0 +1,91 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
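+
+# Conversion sketch: this script ports the encoder of the original DALL-E dVAE (the `dall_e` package) into a
+# `FLAVACodebook` checkpoint. It loads the `dall_e.Encoder` weights, renames the `.w`/`.b` parameter suffixes to
+# `.weight`/`.bias`, checks that the summed parameter values still match, and saves the result with `save_pretrained`.
+# Illustrative invocation (the encoder checkpoint path/URL is a placeholder, not something defined in this PR):
+#
+#   python convert_dalle_to_flava_codebook.py \
+#       --checkpoint_path <path or URL to the DALL-E dVAE encoder checkpoint> \
+#       --pytorch_dump_folder_path ./flava-codebook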
+ +import argparse +import os + +import torch + +from transformers import FLAVACodebook, FLAVACodebookConfig + + +def rreplace(s, old, new, occurrence): + li = s.rsplit(old, occurrence) + return new.join(li) + + +def count_parameters(state_dict): + # encoder.embeddings are double copied in original FLAVA + return sum(param.float().sum() if "encoder.embeddings" not in key else 0 for key, param in state_dict.items()) + + +def upgrade_state_dict(state_dict): + upgrade = {} + + for key, value in state_dict.items(): + if key.endswith(".w"): + key = rreplace(key, ".w", ".weight", 1) + if key.endswith(".b"): + key = rreplace(key, ".b", ".bias", 1) + + upgrade[key] = value.float() + + return upgrade + + +@torch.no_grad() +def convert_dalle_checkpoint(checkpoint_path, pytorch_dump_folder_path, config_path=None): + """ + Copy/paste/tweak model's weights to transformers design. + """ + from dall_e import Encoder + + encoder = Encoder() + if os.path.exists(checkpoint_path): + ckpt = torch.load(checkpoint_path) + else: + ckpt = torch.hub.load_state_dict_from_url(checkpoint_path) + + if isinstance(ckpt, Encoder): + ckpt = ckpt.state_dict() + encoder.load_state_dict(ckpt) + + if config_path is not None: + config = FLAVACodebookConfig.from_pretrained(config_path) + else: + config = FLAVACodebookConfig() + + hf_model = FLAVACodebook(config).eval() + state_dict = encoder.state_dict() + + hf_state_dict = upgrade_state_dict(state_dict) + hf_model.load_state_dict(hf_state_dict) + hf_state_dict = hf_model.state_dict() + hf_count = count_parameters(hf_state_dict) + state_dict_count = count_parameters(state_dict) + + assert torch.allclose(hf_count, state_dict_count, atol=1e-3) + + hf_model.save_pretrained(pytorch_dump_folder_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.") + parser.add_argument("--checkpoint_path", default=None, type=str, help="Path to fairseq checkpoint") + parser.add_argument("--config_path", default=None, type=str, help="Path to hf config.json of model to convert") + args = parser.parse_args() + + convert_dalle_checkpoint(args.checkpoint_path, args.pytorch_dump_folder_path, args.config_path) diff --git a/src/transformers/models/flava/convert_flava_original_pytorch_to_hf.py b/src/transformers/models/flava/convert_flava_original_pytorch_to_hf.py new file mode 100644 index 00000000000000..c1630941ea9765 --- /dev/null +++ b/src/transformers/models/flava/convert_flava_original_pytorch_to_hf.py @@ -0,0 +1,92 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
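+
+# Conversion sketch: this script maps a checkpoint trained with the original FLAVA codebase onto the
+# `FLAVAForPretraining` layout. `upgrade_state_dict` below renames the original prefixes (e.g. `heads.cmd.*`,
+# `image_encoder.module`, `text_encoder.module`, `mm_encoder.module`) to their `flava.*` and head counterparts before
+# loading the weights and saving them with `save_pretrained`.
+# Illustrative invocation (paths are placeholders):
+#
+#   python convert_flava_original_pytorch_to_hf.py \
+#       --checkpoint_path <path or URL to the original FLAVA checkpoint> \
+#       --pytorch_dump_folder_path ./flava-full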
+ +import argparse +import os + +import torch + +from transformers import FLAVAConfig, FLAVAForPretraining + + +def count_parameters(state_dict): + # encoder.embeddings are double copied in original FLAVA + return sum(param.float().sum() if "encoder.embeddings" not in key else 0 for key, param in state_dict.items()) + + +def upgrade_state_dict(state_dict): + upgrade = {} + + for key, value in state_dict.items(): + if "text_encoder.embeddings" in key or "image_encoder.embeddings" in key: + continue + + key = key.replace("heads.cmd.mim_head.cls.predictions", "mmm_image_head") + key = key.replace("heads.cmd.mlm_head.cls.predictions", "mmm_text_head") + key = key.replace("heads.cmd.itm_head.cls", "itm_head") + key = key.replace("heads.cmd.itm_head.pooler", "itm_head.pooler") + key = key.replace("heads.cmd.clip_head.logit_scale", "flava.logit_scale") + key = key.replace("heads.fairseq_mlm.cls.predictions", "mlm_head") + key = key.replace("heads.imagenet.mim_head.cls.predictions", "mim_head") + key = key.replace("mm_text_projection", "flava.text_to_mm_projection") + key = key.replace("mm_image_projection", "flava.image_to_mm_projection") + key = key.replace("image_encoder.module", "flava.image_model") + key = key.replace("text_encoder.module", "flava.text_model") + key = key.replace("mm_encoder.module.encoder.cls_token", "flava.multimodal_model.cls_token") + key = key.replace("mm_encoder.module", "flava.multimodal_model") + key = key.replace("text_projection", "flava.text_projection") + key = key.replace("image_projection", "flava.image_projection") + + upgrade[key] = value.float() + + return upgrade + + +@torch.no_grad() +def convert_flava_checkpoint(checkpoint_path, pytorch_dump_folder_path, config_path=None): + """ + Copy/paste/tweak model's weights to transformers design. + """ + if config_path is not None: + config = FLAVAConfig.from_pretrained(config_path) + else: + config = FLAVAConfig() + + hf_model = FLAVAForPretraining(config).eval() + + if os.path.exists(checkpoint_path): + state_dict = torch.load(checkpoint_path, map_location="cpu") + else: + state_dict = torch.hub.load_state_dict_from_url(checkpoint_path, map_location="cpu") + + hf_state_dict = upgrade_state_dict(state_dict) + hf_model.load_state_dict(hf_state_dict) + hf_state_dict = hf_model.state_dict() + hf_count = count_parameters(hf_state_dict) + state_dict_count = count_parameters(state_dict) + + assert torch.allclose(hf_count, state_dict_count, atol=1e-3) + + hf_model.save_pretrained(pytorch_dump_folder_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.") + parser.add_argument("--checkpoint_path", default=None, type=str, help="Path to fairseq checkpoint") + parser.add_argument("--config_path", default=None, type=str, help="Path to hf config.json of model to convert") + args = parser.parse_args() + + convert_flava_checkpoint(args.checkpoint_path, args.pytorch_dump_folder_path, args.config_path) diff --git a/src/transformers/models/flava/feature_extraction_flava.py b/src/transformers/models/flava/feature_extraction_flava.py new file mode 100644 index 00000000000000..3207b309694324 --- /dev/null +++ b/src/transformers/models/flava/feature_extraction_flava.py @@ -0,0 +1,366 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Feature extractor class for FLAVA.""" + +import math +import random +from functools import lru_cache +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch +from PIL import Image + +from ...feature_extraction_utils import BatchFeature, FeatureExtractionMixin +from ...image_utils import ImageFeatureExtractionMixin, is_torch_tensor +from ...utils import TensorType, logging + + +logger = logging.get_logger(__name__) + + +# These values are taken from CLIP +FLAVA_IMAGE_MEAN = [0.48145466, 0.4578275, 0.40821073] +FLAVA_IMAGE_STD = [0.26862954, 0.26130258, 0.27577711] +FLAVA_CODEBOOK_MEAN = [0.0, 0.0, 0.0] +FLAVA_CODEBOOK_STD = [1.0, 1.0, 1.0] +LOGIT_LAPLACE_EPS: float = 0.1 + + +# Inspired from https://github.com/microsoft/unilm/blob/master/beit/masking_generator.py +class MaskingGenerator: + def __init__( + self, + input_size: Union[int, Tuple[int, int]], + total_mask_patches: int = 75, + mask_group_max_patches: int = None, + mask_group_min_patches: Optional[int] = 16, + mask_group_min_aspect_ratio: float = 0.3, + mask_group_max_aspect_ratio: Optional[float] = None, + ): + if not isinstance(input_size, tuple): + input_size = (input_size,) * 2 + self.height, self.width = input_size + + self.num_patches = self.height * self.width + self.total_mask_patches = total_mask_patches + + self.mask_group_min_patches = mask_group_min_patches + self.mask_group_max_patches = total_mask_patches if mask_group_max_patches is None else mask_group_max_patches + + mask_group_max_aspect_ratio = mask_group_max_aspect_ratio or 1 / mask_group_min_aspect_ratio + self.log_aspect_ratio = (math.log(mask_group_min_aspect_ratio), math.log(mask_group_max_aspect_ratio)) + + def __repr__(self): + repr_str = "MaskingGenerator(%d, %d -> [%d ~ %d], max = %d, %.3f ~ %.3f)" % ( + self.height, + self.width, + self.mask_group_min_patches, + self.mask_group_max_patches, + self.total_mask_patches, + self.log_aspect_ratio[0], + self.log_aspect_ratio[1], + ) + return repr_str + + def get_shape(self): + return self.height, self.width + + def _mask(self, mask, max_mask_patches): + delta = 0 + for _attempt in range(10): + target_area = random.uniform(self.mask_group_min_patches, max_mask_patches) + aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio)) + h = int(round(math.sqrt(target_area * aspect_ratio))) + w = int(round(math.sqrt(target_area / aspect_ratio))) + if w < self.width and h < self.height: + top = random.randint(0, self.height - h) + left = random.randint(0, self.width - w) + + num_masked = mask[top : top + h, left : left + w].sum() + # Overlap + if 0 < h * w - num_masked <= max_mask_patches: + for i in range(top, top + h): + for j in range(left, left + w): + if mask[i, j] == 0: + mask[i, j] = 1 + delta += 1 + + if delta > 0: + break + return delta + + def __call__(self): + mask = np.zeros(shape=self.get_shape(), dtype=int) + mask_count = 0 + while mask_count < self.total_mask_patches: + max_mask_patches = self.total_mask_patches - mask_count + max_mask_patches = min(max_mask_patches, self.mask_group_max_patches) + + delta = self._mask(mask, max_mask_patches) + if delta == 0: 
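+                # `_mask` could not place a new masked group within the remaining budget, so stop early instead of
+                # retrying indefinitely.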
+ break + else: + mask_count += delta + + return mask + + +class FLAVAFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin): + r""" + Constructs a FLAVA feature extractor. + + This feature extractor inherits from [`FeatureExtractionMixin`] which contains most of the main methods. Users + should refer to this superclass for more information regarding those methods. + + Args: + do_resize (`bool`, *optional*, defaults to `True`): + Whether to resize the input to a certain `size`. + size (`int`, *optional*, defaults to 224): + Resize the input to the given size. Only has an effect if `do_resize` is set to `True`. + resample (`int`, *optional*, defaults to `PIL.Image.BICUBIC`): + An optional resampling filter. This can be one of `PIL.Image.NEAREST`, `PIL.Image.BOX`, + `PIL.Image.BILINEAR`, `PIL.Image.HAMMING`, `PIL.Image.BICUBIC` or `PIL.Image.LANCZOS`. Only has an effect + if `do_resize` is set to `True`. + do_center_crop (`bool`, *optional*, defaults to `True`): + Whether to crop the input at the center. If the input size is smaller than `crop_size` along any edge, the + image is padded with 0's and then center cropped. + crop_size (`int`, *optional*, defaults to 224): + Desired output size when applying center-cropping. Only has an effect if `do_center_crop` is set to `True`. + do_normalize (`bool`, *optional*, defaults to `True`): + Whether or not to normalize the input with `image_mean` and `image_std`. + image_mean (`List[int]`, defaults to `[0.485, 0.456, 0.406]`): + The sequence of means for each channel, to be used when normalizing images. + image_std (`List[int]`, defaults to `[0.229, 0.224, 0.225]`): + The sequence of standard deviations for each channel, to be used when normalizing images. + """ + + model_input_names = ["pixel_values"] + + def __init__( + self, + do_resize=True, + size=224, + resample=Image.BICUBIC, + do_center_crop=True, + crop_size=224, + do_normalize=True, + image_mean=None, + image_std=None, + input_size_patches: int = 14, + total_mask_patches: int = 75, + mask_group_min_patches: Optional[int] = 16, + mask_group_max_patches: int = None, + mask_group_min_aspect_ratio: float = 0.3, + mask_group_max_aspect_ratio: Optional[float] = None, + **kwargs, + ): + super().__init__(**kwargs) + self.do_resize = do_resize + self.size = size + self.resample = resample + self.do_center_crop = do_center_crop + self.crop_size = crop_size + self.do_normalize = do_normalize + self.image_mean = image_mean if image_mean is not None else FLAVA_IMAGE_MEAN + self.image_std = image_std if image_std is not None else FLAVA_IMAGE_STD + self.input_size_patches = input_size_patches + self.total_mask_patches = total_mask_patches + self.mask_group_min_patches = mask_group_min_patches + self.mask_group_max_patches = mask_group_max_patches + self.mask_group_min_aspect_ratio = mask_group_min_aspect_ratio + self.mask_group_max_aspect_ratio = mask_group_max_aspect_ratio + + @property + @lru_cache() + def masking_generator(self): + return MaskingGenerator( + input_size=self.input_size_patches, + total_mask_patches=self.total_mask_patches, + mask_group_min_patches=self.mask_group_min_patches, + mask_group_max_patches=self.mask_group_max_patches, + mask_group_min_aspect_ratio=self.mask_group_min_aspect_ratio, + mask_group_max_aspect_ratio=self.mask_group_max_aspect_ratio, + ) + + def __call__( + self, + images: Union[ + Image.Image, np.ndarray, "torch.Tensor", List[Image.Image], List[np.ndarray], List["torch.Tensor"] # noqa + ], + return_masks: Optional[bool] = None, + return_tensors: 
Optional[Union[str, TensorType]] = None, + **kwargs + ) -> BatchFeature: + """ + Main method to prepare for the model one or several image(s). + + + + NumPy arrays and PyTorch tensors are converted to PIL images when resizing, so the most efficient is to pass + PIL images. + + + + Args: + images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`): + The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch + tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a + number of channels, H and W are image height and width. + + return_masks (`bool`, *optional*, defaults to None): + If True, the processor will return `bool_masked_pos` suggesting masks for image's patch version. + + + return_tensors (`str` or [`~utils.TensorType`], *optional*, defaults to `'np'`): + If set, will return tensors of a particular framework. Acceptable values are: + + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return NumPy `np.ndarray` objects. + - `'jax'`: Return JAX `jnp.ndarray` objects. + + Returns: + [`BatchFeature`]: A [`BatchFeature`] with the following fields: + + - **pixel_values** -- Pixel values to be fed to a model. + """ + # Input type checking for clearer error + if isinstance(images, (list, tuple)) and len(images) != 0: + self._ensure_format_supported(images[0]) + else: + self._ensure_format_supported(images) + + is_batched = bool( + isinstance(images, (list, tuple)) + and (isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0])) + ) + + if not is_batched: + images = [images] + + # transformations (resizing + center cropping + normalization) + if self.do_resize and self.size is not None and self.resample is not None: + images = [self.resize(image=image, size=self.size, resample=self.resample) for image in images] + if self.do_center_crop and self.crop_size is not None: + images = [self.center_crop(image, self.crop_size) for image in images] + if self.do_normalize: + images = [self.normalize(image=image, mean=self.image_mean, std=self.image_std) for image in images] + # return as BatchFeature + data = {"pixel_values": images} + + if return_masks: + masks = [self.masking_generator() for _ in images] + data["bool_masked_pos"] = masks + + encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors) + + return encoded_inputs + + +class FLAVACodebookFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin): + def __init__( + self, + do_resize=True, + size=112, + resample=Image.LANCZOS, + do_center_crop=True, + crop_size=112, + do_map_pixels=True, + do_normalize=True, + image_mean=None, + image_std=None, + **kwargs, + ): + super().__init__(**kwargs) + self.do_resize = do_resize + self.size = size + self.resample = resample + self.do_center_crop = do_center_crop + self.do_map_pixels = do_map_pixels + self.crop_size = crop_size + self.do_normalize = do_normalize + self.image_mean = image_mean if image_mean is not None else FLAVA_CODEBOOK_MEAN + self.image_std = image_std if image_std is not None else FLAVA_CODEBOOK_STD + + def map_pixels(self, x): + return (1 - 2 * LOGIT_LAPLACE_EPS) * x + LOGIT_LAPLACE_EPS + + def __call__( + self, + images: Union[ + Image.Image, np.ndarray, "torch.Tensor", List[Image.Image], List[np.ndarray], List["torch.Tensor"] # noqa + ], + return_tensors: Optional[Union[str, TensorType]] = None, + **kwargs + ) -> BatchFeature: 
+ """ + Main method to prepare for the model one or several image(s). + + + + NumPy arrays and PyTorch tensors are converted to PIL images when resizing, so the most efficient is to pass + PIL images. + + + + Args: + images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`): + The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch + tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a + number of channels, H and W are image height and width. + + return_tensors (`str` or [`~utils.TensorType`], *optional*, defaults to `'np'`): + If set, will return tensors of a particular framework. Acceptable values are: + + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return NumPy `np.ndarray` objects. + - `'jax'`: Return JAX `jnp.ndarray` objects. + + Returns: + [`BatchFeature`]: A [`BatchFeature`] with the following fields: + + - **pixel_values** -- Pixel values to be fed to a model. + """ + # Input type checking for clearer error + if isinstance(images, (list, tuple)) and len(images) != 0: + self._ensure_format_supported(images[0]) + else: + self._ensure_format_supported(images) + + is_batched = bool( + isinstance(images, (list, tuple)) + and (isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0])) + ) + + if not is_batched: + images = [images] + + # transformations (resizing + center cropping + normalization) + if self.do_resize and self.size is not None and self.resample is not None: + images = [self.resize(image=image, size=self.size, resample=self.resample) for image in images] + if self.do_center_crop and self.crop_size is not None: + images = [self.center_crop(image, self.crop_size) for image in images] + if self.do_normalize: + images = [self.normalize(image=image, mean=self.image_mean, std=self.image_std) for image in images] + if self.do_map_pixels: + images = [self.map_pixels(image) for image in images] + data = {"pixel_values": images} + + # return as BatchFeature + encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors) + + return encoded_inputs diff --git a/src/transformers/models/flava/modeling_flava.py b/src/transformers/models/flava/modeling_flava.py new file mode 100644 index 00000000000000..b500bba58c972d --- /dev/null +++ b/src/transformers/models/flava/modeling_flava.py @@ -0,0 +1,2044 @@ +# coding=utf-8 +# Copyright 2021 Meta Platforms authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" PyTorch FLAVA model.""" + +import collections +import math +from collections import OrderedDict +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Set, Tuple, Union + +import torch +import torch.utils.checkpoint +from packaging import version +from torch import nn + +from transformers.utils.doc import add_code_sample_docstrings + +from ...activations import ACT2FN +from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling +from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer +from ...utils import ( + ModelOutput, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) +from .configuration_flava import ( + FLAVACodebookConfig, + FLAVAConfig, + FLAVAImageConfig, + FLAVAMultimodalConfig, + FLAVATextConfig, +) + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "aps/flava-full" + +# Codebook docstring +_CHECKPOINT_FOR_CODEBOOK_DOC = "aps/flava-codebook" +_FEAT_EXTRACTOR_FOR_DOC = "FLAVAFeatureExtractor" +_CONFIG_CLASS_FOR_IMAGE_MODEL_DOC = "FLAVAImageConfig" +_CONFIG_CLASS_FOR_TEXT_MODEL_DOC = "FLAVATextConfig" +_CONFIG_CLASS_FOR_MULTIMODAL_MODEL_DOC = "FLAVAMultimodalConfig" +_TOKENIZER_FOR_DOC = "BertTokenizer" +_EXPECTED_IMAGE_OUTPUT_SHAPE = [1, 197, 768] + +FLAVA_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "aps/flava-full", + # See all flava models at https://huggingface.co/models?filter=flava +] +LOGIT_SCALE_CLAMP_MIN = 0 +LOGIT_SCALE_CLAMP_MAX = 4.6052 + +FLAVAPossibleConfigs = Union[FLAVATextConfig, FLAVAImageConfig, FLAVAMultimodalConfig] + + +def _build_codebook_conv2d(in_size: int, out_size: int, kernel_size: int): + return nn.Conv2d(in_size, out_size, kernel_size=kernel_size, padding=(kernel_size - 1) // 2) + + +@dataclass +class FLAVAModelOutput(ModelOutput): + """ + Output from FLAVAModel containing embeddings and outputs from individual encoders. + + Note that `image_embeddings` and `text_embeddigns` returned are similar to pooled output returned from a + transformer. If you want embeddings for contrastive loss or retrieval use a FLAVA model's `image_projection` and + `text_projection` layers on `image_embeddings` and `text_embeddings` respectively. + + Args: + image_embeddings(`torch.FloatTensor` of shape `(batch_size, output_dim`), *optional*, returned when `pixel_values` are present): + The image embeddings which are basically the pooled output of [`FLAVAImageModel`]. + image_output(`BaseModelOutputWithPooling`, *optional*, returned when `pixel_values` are present): + The output of the [`FLAVAImageModel`]. + text_embeddings(`torch.FloatTensor` of shape `(batch_size, output_dim`), *optional*, returned when `input_ids` are present): + The text embeddings which are basically the pooled output of [`FLAVATextModel`]. + text_output(`BaseModelOutputWithPooling`, *optional*, returned when `input_ids` are present): + The output of the [`FLAVATextModel`]. + multimodal_embeddings(`torch.FloatTensor` of shape `(batch_size, output_dim`), *optional*, returned when `input_ids` and `pixel_values` are present and `skip_multimodal_encoder` is `None` or `False`): + The multimodal embeddings which are basically the pooled output of [`FLAVATextModel`]. + multimodal_output(`BaseModelOutputWithPooling`, returned when `input_ids` and `pixel_values` are present and `skip_multimodal_encoder` is `None` or `False`): + The output of the [`FLAVAMultimodalModel`]. 
+ """ + + image_embeddings: Optional[torch.FloatTensor] = None + image_output: Optional[BaseModelOutputWithPooling] = None + text_embeddings: Optional[torch.FloatTensor] = None + text_output: Optional[BaseModelOutputWithPooling] = None + multimodal_embeddings: Optional[torch.FloatTensor] = None + multimodal_output: Optional[BaseModelOutputWithPooling] = None + + +@dataclass +class FLAVALosses(ModelOutput): + """Class representing pretraining losses from FLAVA model + + Args: + mim(`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `mim_labels` and `pixel_values` are present, `input_ids_masked` is absent and `mim_weight` > 0.: + Masked Image Modeling loss as used in BeIT calculated only for unimodal image data. + mlm(`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `mlm_labels` and `input_ids_masked` are present, `pixel_values` is absent and `mlm_weight` > 0.: + Masked Language Modeling loss as used in BERT calculated only for unimodal text data. + itm(`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `itm_labels`, `input_ids_masked`, `pixel_values` are present and `itm_weight` > 0.: + Image Text Matching (ITM) loss calculated for paired image-text data. Note that ITM loss is calculated on + masked pairs in FLAVA. + global_contrastive(`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `input_ids` and `pixel_values` are present and `global_contrastive_weight` > 0.: + Contrastive loss for image-text similarity similar to CLIP but calculated globally for paired image-text + data. This is calculated on unmasked images and texts. + mmm_image(`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `mim_labels`, `pixel_values` and `input_ids_masked` are present and `mmm_image_weight` > 0.: + Masked Multimodal Modeling loss's image component calculated on paired image-text data. + mmm_text(`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `mlm_labels`, `pixel_values` and `input_ids_masked` are present and `mmm_text_weight` > 0.: + Masked Multimodal Modeling loss's text component calculated on paired image-text data. + """ + + mim: Optional[torch.FloatTensor] = None + mlm: Optional[torch.FloatTensor] = None + itm: Optional[torch.FloatTensor] = None + global_contrastive: Optional[torch.FloatTensor] = None + mmm_image: Optional[torch.FloatTensor] = None + mmm_text: Optional[torch.FloatTensor] = None + + def all_none(self) -> bool: + all_none = True + for k, v in self.items(): + all_none = all_none or v is not None + return all_none + + +@dataclass +class FLAVAForPretrainingOutput(ModelOutput): + """ + Output from FLAVAForPretraining containing embeddings, and outputs from individual encoders. + + Note that `image_embeddings` and `text_embeddigns` returned are similar to pooled output returned from a + transformer. If you want embeddings for contrastive loss or retrieval use a FLAVA model's `image_projection` and + `text_projection` layers on `image_embeddings` and `text_embeddings` respectively. + + Args: + losses (`FLAVALosses`): + Losses for FLAVA Pretraining. Check `FLAVALosses` class description for the information on the keys. + image_embeddings(`torch.FloatTensor` of shape `(batch_size, output_dim`), *optional*, returned when `pixel_values` are present): + The image embeddings which are basically the pooled output of [`FLAVAImageModel`]. + image_output(`BaseModelOutputWithPooling`, *optional*, returned when `pixel_values` are present): + The output of the [`FLAVAImageModel`]. 
+        text_embeddings (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when `input_ids` are present):
+            The text embeddings which are basically the pooled output of [`FLAVATextModel`].
+        text_output (`BaseModelOutputWithPooling`, *optional*, returned when `input_ids` are present):
+            The output of the [`FLAVATextModel`].
+        multimodal_embeddings (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when `input_ids` and `pixel_values` are present and `skip_unmasked_multimodal_encoder` is `None` or `False`):
+            The multimodal embeddings which are basically the pooled output of [`FLAVAMultimodalModel`].
+        multimodal_output (`BaseModelOutputWithPooling`, returned when `input_ids` and `pixel_values` are present and `skip_unmasked_multimodal_encoder` is `None` or `False`):
+            The output of the [`FLAVAMultimodalModel`].
+
+        image_masked_embeddings (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when `pixel_values` are present):
+            The image embeddings which are basically the pooled output of [`FLAVAImageModel`]. Uses `bool_masked_pos`
+            to create masked images.
+        image_masked_output (`BaseModelOutputWithPooling`, *optional*, returned when `pixel_values` are present):
+            The output of the [`FLAVAImageModel`]. Uses `bool_masked_pos` to create masked images.
+        text_masked_embeddings (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when `input_ids_masked` are present):
+            The text embeddings which are basically the pooled output of [`FLAVATextModel`].
+        text_masked_output (`BaseModelOutputWithPooling`, *optional*, returned when `input_ids_masked` are present):
+            The output of the [`FLAVATextModel`].
+        multimodal_masked_embeddings (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when `input_ids_masked` and `pixel_values` are present):
+            The multimodal embeddings which are basically the pooled output of [`FLAVAMultimodalModel`].
+        multimodal_masked_output (`BaseModelOutputWithPooling`, returned when `input_ids_masked` and `pixel_values` are present):
+            The output of the [`FLAVAMultimodalModel`].
+
+
+        mim_logits:
+            (`torch.FloatTensor` of shape `(batch_size, image_num_patches, image_vocab_size)`, *optional*, returned
+            when `pixel_values` are present and `input_ids_masked` are not): The logits for MIM unimodal loss. Uses
+            `bool_masked_pos` to get masked patches.
+        mlm_logits:
+            (`torch.FloatTensor` of shape `(batch_size, text_seq_length, text_vocab_size)`, *optional*, returned when
+            `input_ids_masked` are present and `pixel_values` are not): The logits for MLM unimodal loss.
+        itm_logits:
+            (`torch.FloatTensor` of shape `(batch_size, 2)`, *optional*, returned when `input_ids_masked` and
+            `pixel_values` are present): The logits for ITM loss. Note that ITM loss is calculated on masked pairs in
+            FLAVA.
+        mmm_image_logits:
+            (`torch.FloatTensor` of shape `(batch_size, image_num_patches, image_vocab_size)`, *optional*, returned
+            when `pixel_values` and `input_ids_masked` are present): The logits for MMM image multimodal loss. Uses
+            `bool_masked_pos` to get masked patches.
+        mmm_text_logits:
+            (`torch.FloatTensor` of shape `(batch_size, text_seq_length, text_vocab_size)`, *optional*, returned when
+            `pixel_values` and `input_ids_masked` are present): The logits for MMM text multimodal loss.
+ contrastive_logits_per_image:(`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`): + The scaled dot product scores between `image_embeddings` and `text_embeddings` but passed through FLAVA's + `image_projection` and `text_projection` layers respectively. This represents the image-text similarity + scores. This is calculated on unmasked images and texts. + contrastive_logits_per_text:(`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`): + The scaled dot product scores between `text_embeddings` and `image_embeddings` but passed through FLAVA's + `text_projection` and `image_projection` layers respectively. This is calculated on unmasked images and + texts. + """ + + losses: FLAVALosses = None + image_embeddings: Optional[torch.FloatTensor] = None + image_output: Optional[BaseModelOutputWithPooling] = None + text_embeddings: Optional[torch.FloatTensor] = None + text_output: Optional[BaseModelOutputWithPooling] = None + multimodal_embeddings: Optional[torch.FloatTensor] = None + multimodal_output: Optional[BaseModelOutputWithPooling] = None + image_masked_embeddings: Optional[torch.FloatTensor] = None + image_masked_output: Optional[BaseModelOutputWithPooling] = None + text_masked_embeddings: Optional[torch.FloatTensor] = None + text_masked_output: Optional[BaseModelOutputWithPooling] = None + multimodal_masked_embeddings: Optional[torch.FloatTensor] = None + multimodal_masked_output: Optional[BaseModelOutputWithPooling] = None + mim_logits: Optional[torch.FloatTensor] = None + mlm_logits: Optional[torch.FloatTensor] = None + itm_logits: Optional[torch.FloatTensor] = None + contrastive_logits_per_image: Optional[torch.FloatTensor] = None + contrastive_logits_per_text: Optional[torch.FloatTensor] = None + mmm_image_logits: Optional[torch.FloatTensor] = None + mmm_text_logits: Optional[torch.FloatTensor] = None + + def to_tuple(self) -> Tuple[Any]: + return tuple( + self[k] if k not in ["text_output", "image_output", "multimodal_output"] else getattr(self, k).to_tuple() + for k in self.keys() + ) + + +# Inspired by +# https://github.com/rwightman/pytorch-image-models/blob/b9bd960a032c75ca6b808ddeed76bee5f3ed4972/timm/models/layers/helpers.py +# From PyTorch internals +def to_2tuple(x): + if isinstance(x, collections.abc.Iterable): + return x + return (x, x) + + +# Based on timm implementation, which can be found here: +# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/image_transformer.py + + +class FLAVAImageEmbeddings(nn.Module): + """ + Construct the CLS token, position and patch embeddings. Optionally, also the mask token. 
+ + """ + + def __init__(self, config: FLAVAImageConfig, use_mask_token: bool = False) -> None: + super().__init__() + + use_mask_token = use_mask_token or config.mask_token + self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) + self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) if use_mask_token else None + self.patch_embeddings = PatchEmbeddings( + image_size=config.image_size, + patch_size=config.patch_size, + num_channels=config.num_channels, + embed_dim=config.hidden_size, + ) + num_patches = self.patch_embeddings.num_patches + self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size)) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.config = config + + def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: + """ + This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher + resolution images. + + Source: + https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/image_transformer.py#L174 + """ + + npatch = embeddings.shape[1] - 1 + N = self.position_embeddings.shape[1] - 1 + if npatch == N and height == width: + return self.position_embeddings + class_pos_embed = self.position_embeddings[:, 0] + patch_pos_embed = self.position_embeddings[:, 1:] + dim = embeddings.shape[-1] + h0 = height // self.config.patch_size + w0 = width // self.config.patch_size + # we add a small number to avoid floating point error in the interpolation + # see discussion at https://github.com/facebookresearch/dino/issues/8 + h0, w0 = h0 + 0.1, w0 + 0.1 + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2), + scale_factor=(h0 / math.sqrt(N), w0 / math.sqrt(N)), + mode="bicubic", + align_corners=False, + ) + assert int(h0) == patch_pos_embed.shape[-2] and int(w0) == patch_pos_embed.shape[-1] + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) + + def forward( + self, + pixel_values: torch.Tensor, + bool_masked_pos: Optional[torch.BoolTensor] = None, + interpolate_pos_encoding: bool = False, + ) -> torch.Tensor: + batch_size, num_channels, height, width = pixel_values.shape + embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding) + + batch_size, seq_len, _ = embeddings.size() + if bool_masked_pos is not None: + mask_tokens = self.mask_token.expand(batch_size, seq_len, -1) + # B X H X W = B X HW + if bool_masked_pos.dim() == 3: + bool_masked_pos = bool_masked_pos.view(bool_masked_pos.size(0), -1) + # replace the masked visual tokens by mask_tokens + mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens) + embeddings = embeddings * (1.0 - mask) + mask_tokens * mask + + # add the [CLS] token to the embedded patch tokens + cls_tokens = self.cls_token.expand(batch_size, -1, -1) + embeddings = torch.cat((cls_tokens, embeddings), dim=1) + + # add positional encoding to each token + if interpolate_pos_encoding: + embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width) + else: + embeddings = embeddings + self.position_embeddings + + embeddings = self.dropout(embeddings) + + return embeddings + + +# Based on timm implementation, which can be found here: +# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/image_transformer.py +class 
PatchEmbeddings(nn.Module): + """ + Image to Patch Embedding. + + """ + + def __init__( + self, + image_size: int = 224, + patch_size: Union[int, Tuple[int, int]] = 16, + num_channels: int = 3, + embed_dim: int = 768, + ): + super().__init__() + image_size = to_2tuple(image_size) + patch_size = to_2tuple(patch_size) + num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) + self.image_size = image_size + self.patch_size = patch_size + self.num_patches = num_patches + + self.projection = nn.Conv2d(num_channels, embed_dim, kernel_size=patch_size, stride=patch_size) + + def forward(self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False) -> torch.Tensor: + batch_size, num_channels, height, width = pixel_values.shape + if not interpolate_pos_encoding: + if height != self.image_size[0] or width != self.image_size[1]: + raise ValueError( + f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})." + ) + x = self.projection(pixel_values).flatten(2).transpose(1, 2) + return x + + +class FLAVATextEmbeddings(nn.Module): + """Construct the embeddings from word, position and token_type embeddings.""" + + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) + self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") + self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1))) + if version.parse(torch.__version__) > version.parse("1.6.0"): + self.register_buffer( + "token_type_ids", + torch.zeros(self.position_ids.size(), dtype=torch.long), + persistent=False, + ) + + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + ): + input_shape = input_ids.size() + seq_length = input_shape[1] + + if position_ids is None: + position_ids = self.position_ids[:, :seq_length] + + # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs + # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves + # issue #5664 + if token_type_ids is None: + if hasattr(self, "token_type_ids"): + buffered_token_type_ids = self.token_type_ids[:, :seq_length] + buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length) + token_type_ids = buffered_token_type_ids_expanded + else: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device) + + inputs_embeds = self.word_embeddings(input_ids) + token_type_embeddings = self.token_type_embeddings(token_type_ids) + + embeddings = inputs_embeds + token_type_embeddings + if self.position_embedding_type == "absolute": + position_embeddings = self.position_embeddings(position_ids) + embeddings += 
position_embeddings
+        embeddings = self.LayerNorm(embeddings)
+        embeddings = self.dropout(embeddings)
+        return embeddings
+
+
+class FLAVASelfAttention(nn.Module):
+    def __init__(self, config: FLAVAPossibleConfigs) -> None:
+        super().__init__()
+        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
+            raise ValueError(
+                f"The hidden size {config.hidden_size} is not a multiple of the number of attention "
+                f"heads {config.num_attention_heads}."
+            )
+
+        self.num_attention_heads = config.num_attention_heads
+        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
+        self.all_head_size = self.num_attention_heads * self.attention_head_size
+
+        self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
+        self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
+        self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
+
+        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
+
+    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
+        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
+        x = x.view(*new_x_shape)
+        return x.permute(0, 2, 1, 3)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        output_attentions: bool = False,
+    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
+        mixed_query_layer = self.query(hidden_states)
+
+        key_layer = self.transpose_for_scores(self.key(hidden_states))
+        value_layer = self.transpose_for_scores(self.value(hidden_states))
+        query_layer = self.transpose_for_scores(mixed_query_layer)
+
+        # Take the dot product between "query" and "key" to get the raw attention scores.
+        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
+
+        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
+        if attention_mask is not None:
+            # Apply the attention mask (precomputed for all layers in BertModel forward() function)
+            attention_scores = attention_scores + attention_mask
+
+        # Normalize the attention scores to probabilities.
+        attention_probs = nn.functional.softmax(attention_scores, dim=-1)
+
+        # This is actually dropping out entire tokens to attend to, which might
+        # seem a bit unusual, but is taken from the original Transformer paper.
+        attention_probs = self.dropout(attention_probs)
+
+        # Mask heads if we want to
+        if head_mask is not None:
+            attention_probs = attention_probs * head_mask
+
+        context_layer = torch.matmul(attention_probs, value_layer)
+
+        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+        context_layer = context_layer.view(*new_context_layer_shape)
+
+        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
+
+        return outputs
+
+
+class FLAVASelfOutput(nn.Module):
+    """
+    The residual connection is defined in FLAVALayer instead of here (as is the case with other models), due to the
+    layernorm applied before each block.
+ """ + + def __init__(self, config: FLAVAPossibleConfigs) -> None: + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + + return hidden_states + + +class FLAVAAttention(nn.Module): + def __init__(self, config: FLAVAPossibleConfigs) -> None: + super().__init__() + self.attention = FLAVASelfAttention(config) + self.output = FLAVASelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads: Set[int]) -> None: + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads + ) + + # Prune linear layers + self.attention.query = prune_linear_layer(self.attention.query, index) + self.attention.key = prune_linear_layer(self.attention.key, index) + self.attention.value = prune_linear_layer(self.attention.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads) + self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: + self_outputs = self.attention( + hidden_states, attention_mask=attention_mask, head_mask=head_mask, output_attentions=output_attentions + ) + + attention_output = self.output(self_outputs[0], hidden_states) + + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + +class FLAVAIntermediate(nn.Module): + def __init__(self, config: FLAVAPossibleConfigs) -> None: + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + + return hidden_states + + +class FLAVAOutput(nn.Module): + def __init__(self, config: FLAVAPossibleConfigs) -> None: + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + + hidden_states = hidden_states + input_tensor + + return hidden_states + + +class FLAVALayer(nn.Module): + """This corresponds to the Block class in the timm implementation.""" + + def __init__(self, config: FLAVAPossibleConfigs) -> None: + super().__init__() + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = FLAVAAttention(config) + self.intermediate = FLAVAIntermediate(config) + self.output = FLAVAOutput(config) + + # TODO: Check fp32 layer norm 
possiblity + self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: + self_attention_outputs = self.attention( + self.layernorm_before(hidden_states), # in ViT, layernorm is applied before self-attention + attention_mask=attention_mask, + head_mask=head_mask, + output_attentions=output_attentions, + ) + attention_output = self_attention_outputs[0] + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + + # first residual connection + hidden_states = attention_output + hidden_states + + # in ViT, layernorm is also applied after self-attention + layer_output = self.layernorm_after(hidden_states) + layer_output = self.intermediate(layer_output) + + # second residual connection is done here + layer_output = self.output(layer_output, hidden_states) + + outputs = (layer_output,) + outputs + + return outputs + + +class FLAVAEncoder(nn.Module): + def __init__(self, config: FLAVAConfig) -> None: + super().__init__() + self.config = config + self.layer = nn.ModuleList([FLAVALayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ) -> Union[tuple, BaseModelOutput]: + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs, output_attentions) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(layer_module), + hidden_states, + attention_mask, + layer_head_mask, + ) + else: + layer_outputs = layer_module(hidden_states, attention_mask, layer_head_mask, output_attentions) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + +class FLAVAPooler(nn.Module): + def __init__(self, config: FLAVAPossibleConfigs): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states: torch.Tensor): + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. 
+ first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +FLAVA_START_DOCSTRING = r""" + This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it + as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and + behavior. + + Parameters: + config ([`{config}`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +FLAVA_INPUTS_DOCSTRING_COMMON = r""" + attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + [What are attention masks?](../glossary#attention-mask) + + head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + +FLAVA_IMAGE_INPUTS_DOCSTRING_BASE = r""" + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Pixel values can be obtained using [`FLAVAFeatureExtractor`]. See + [`FLAVAFeatureExtractor.__call__`] for details. + + bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, image_num_patches)`): + Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). + + interpolate_pos_encoding (`bool`, *optional*): + Whether to interpolate the pre-trained position encodings. +""" + +FLAVA_IMAGE_INPUTS_DOCSTRING = ( + r""" + Args: +""" + + FLAVA_IMAGE_INPUTS_DOCSTRING_BASE + + FLAVA_INPUTS_DOCSTRING_COMMON +) + +FLAVA_TEXT_INPUTS_DOCSTRING_BASE = r""" + input_ids (`torch.LongTensor` of shape `({0})`): + Indices of input sequence tokens in the vocabulary. Indices can be obtained using [`BertTokenizer`]. See + [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are input + IDs?](../glossary#input-ids) + + token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, + 1]`: + - 0 corresponds to a *sentence A* token, + - 1 corresponds to a *sentence B* token. + [What are token type IDs?](../glossary#token-type-ids) +""" +FLAVA_TEXT_INPUTS_DOCSTRING = ( + r""" + Args: +""" + + FLAVA_TEXT_INPUTS_DOCSTRING_BASE + + FLAVA_INPUTS_DOCSTRING_COMMON +) + +FLAVA_MULTIMODAL_INPUTS_DOCSTRING = ( + r""" + Args: + hidden_states (`torch.FloatTensor` of shape `(batch_size, image_num_patches + text_seq_len, hidden_size)`): + The concatenated hidden states of unimodal encoders. 
+""" + + FLAVA_INPUTS_DOCSTRING_COMMON +) + +FLAVA_MODEL_INPUTS_DOCSTRING = ( + r""" + Args: +""" + + FLAVA_IMAGE_INPUTS_DOCSTRING_BASE + + FLAVA_TEXT_INPUTS_DOCSTRING_BASE + + FLAVA_INPUTS_DOCSTRING_COMMON + + r""" + skip_multimodal_encoder (*bool*, *optional*): + Skip any calculations for multimodal encoder. Useful if multimodal encoding is not going to be used. +""" +) + + +FLAVA_PRETRAINING_INPUTS_DOCSTRING = ( + r""" + Args: + input_ids_masked (`torch.LongTensor` of shape `({0})`): + Indices of input sequence tokens in the vocabulary. These ones are the masked version of the original task + to be used with MLM. Indices can be obtained using [`BertTokenizer`] along with + [`DataCollatorForMaskedLanguageModeling`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. [What are input IDs?](../glossary#input-ids) + +""" + + FLAVA_TEXT_INPUTS_DOCSTRING_BASE + + FLAVA_IMAGE_INPUTS_DOCSTRING_BASE + + r""" + image_attention_mask (`torch.FloatTensor` of shape `({1})`, *optional*): + Mask to avoid performing attention on padding token indices specifically for images. Mask values selected + in `[0, 1]`: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + [What are attention masks?](../glossary#attention-mask) + + skip_unmasked_multimodal_encoder (*bool*, *optional*): + Skip any calculations for multimodal encoder for unmasked inputs. FLAVA pretraining doesn't need unmasked + multimodal embeddings or outputs as of now. + + mlm_labels (`torch.LongTensor` of shape `(batch_size, text_seq_len)`, *optional*): + Labels for computing the left-to-right language and multimodal masked modeling loss (next word + prediction). Indices should be in `[-100, 0, ..., text_config.vocab_size]` (see `input_ids` docstring) + Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with + labels n `[0, ..., text_config.vocab_size]` + + mim_labels (`torch.LongTensor` of shape `(batch_size, image_num_patches)`, *optional*): + Labels for computing the image and multimodal masked modeling loss. Indices should be in `[-100, 0, + ..., image_config.vocab_size]`. Tokens with indices set to `-100` are ignored (masked), the loss is + only computed for the tokens with labels n `[0, ..., image_config.vocab_size]`. See [`FLAVACodebook`] + to understand how to generate mim_labels. + + itm_labels (`torch.LongTensor` of shape `(batch_size, 1)`, *optional*): + Labels for computing the image-text matching loss. 0 means the pairs don't match and 1 means they + match. The pairs with 0 will be skipped for calculation of MMM and global contrastive losses as well. +""" + + FLAVA_INPUTS_DOCSTRING_COMMON +) + + +class FLAVAPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. 
+ """ + + config_class = FLAVAConfig + base_model_prefix = "flava" + supports_gradient_checkpointing = True + + def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None: + """Initialize the weights""" + if isinstance(module, (nn.Linear, nn.Conv2d)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + def _set_gradient_checkpointing(self, module: FLAVAEncoder, value: bool = False) -> None: + if isinstance(module, FLAVAEncoder): + module.gradient_checkpointing = value + + +@add_start_docstrings( + "The bare FLAVA Image Model transformer outputting raw hidden-states without any specific head on top.", + FLAVA_START_DOCSTRING.format(config="FLAVAImageConfig"), +) +class FLAVAImageModel(FLAVAPreTrainedModel): + config_class = FLAVAImageConfig + base_model_prefix = "flava.image_model" + main_input_name = "pixel_values" + + def __init__(self, config: FLAVAImageConfig, add_pooling_layer: bool = True): + super().__init__(config) + + self.config = config + + self.embeddings = FLAVAImageEmbeddings(config) + self.encoder = FLAVAEncoder(config) + + self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.pooler = FLAVAPooler(config) if add_pooling_layer else None + + self.post_init() + + def get_input_embeddings(self) -> nn.Module: + return self.embeddings.patch_embeddings + + def set_input_embeddings(self, value: nn.Module): + self.embeddings.patch_embeddings = value + + def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None: + """ + Prunes heads of the model. 
heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + @add_start_docstrings_to_model_forward(FLAVA_IMAGE_INPUTS_DOCSTRING.format("batch_size, image_num_patches")) + @add_code_sample_docstrings( + processor_class=_FEAT_EXTRACTOR_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=BaseModelOutputWithPooling, + config_class=_CONFIG_CLASS_FOR_IMAGE_MODEL_DOC, + modality="vision", + expected_output=_EXPECTED_IMAGE_OUTPUT_SHAPE, + ) + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + bool_masked_pos: Optional[torch.BoolTensor] = None, + interpolate_pos_encoding: Optional[bool] = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + embedding_output = self.embeddings( + pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding + ) + + encoder_outputs = self.encoder( + embedding_output, + attention_mask=attention_mask, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = encoder_outputs[0] + sequence_output = self.layernorm(sequence_output) + pooled_output = self.pooler(sequence_output) if self.pooler is not None else None + + if not return_dict: + return (sequence_output, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPooling( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +@add_start_docstrings( + "The bare FLAVA Text Model transformer outputting raw hidden-states without any specific head on top.", + FLAVA_START_DOCSTRING.format(config="FLAVATextConfig"), +) +class FLAVATextModel(FLAVAPreTrainedModel): + config_class = FLAVATextConfig + base_model_prefix = "flava.text_model" + + def __init__(self, config: FLAVATextConfig, add_pooling_layer: bool = True): + super().__init__(config) + self.config = config + + self.embeddings = FLAVATextEmbeddings(config) + self.encoder = FLAVAEncoder(config) + + self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.pooler = FLAVAPooler(config) if add_pooling_layer else None + + self.post_init() + + def get_input_embeddings(self) -> PatchEmbeddings: + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value: nn.Module): + self.embeddings.word_embeddings = value + + def 
_prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None: + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + @add_start_docstrings_to_model_forward(FLAVA_TEXT_INPUTS_DOCSTRING.format("batch_size, text_seq_length")) + @add_code_sample_docstrings( + processor_class=_TOKENIZER_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=BaseModelOutputWithPooling, + config_class=_CONFIG_CLASS_FOR_TEXT_MODEL_DOC, + ) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is None: + raise ValueError("You have to specify input_ids") + + input_shape = input_ids.size() + + if attention_mask is None: + attention_mask = torch.ones(input_shape, device=input_ids.device) + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + extended_attention_mask: torch.Tensor = self.get_extended_attention_mask( + attention_mask, input_shape, input_ids.device + ) + + embedding_output = self.embeddings( + input_ids=input_ids, + token_type_ids=token_type_ids, + position_ids=position_ids, + ) + + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = encoder_outputs[0] + sequence_output = self.layernorm(sequence_output) + pooled_output = self.pooler(sequence_output) if self.pooler is not None else None + + if not return_dict: + return (sequence_output, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPooling( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +@add_start_docstrings( + "The bare FLAVA Multimodal Model transformer outputting raw hidden-states without any specific head on top.", + FLAVA_START_DOCSTRING.format(config="FLAVAMultimodalConfig"), +) +class FLAVAMultimodalModel(FLAVAPreTrainedModel): + config_class = FLAVAMultimodalConfig + base_model_prefix = "flava.multimodal_model" + main_input_name = "hidden_states" + + def __init__(self, config: FLAVAMultimodalConfig, add_pooling_layer=True): + super().__init__(config) + self.config = config + self.use_cls_token = self.config.use_cls_token + if self.use_cls_token: + self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) + + 
self.encoder = FLAVAEncoder(config) + + self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.pooler = FLAVAPooler(config) if add_pooling_layer else None + + self.post_init() + + def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None: + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + @add_start_docstrings_to_model_forward( + FLAVA_MULTIMODAL_INPUTS_DOCSTRING.format("batch_size, image_num_patches + text_seq_len") + ) + @add_code_sample_docstrings( + processor_class=_TOKENIZER_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=BaseModelOutputWithPooling, + config_class=_CONFIG_CLASS_FOR_MULTIMODAL_MODEL_DOC, + ) + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + batch_size, seq_length, _ = hidden_states.size() + + if self.use_cls_token: + cls_tokens = self.cls_token.expand(batch_size, -1, -1) + hidden_states = torch.cat((cls_tokens, hidden_states), dim=1) + seq_length += 1 + + if attention_mask is None: + attention_mask = torch.ones((batch_size, seq_length), device=hidden_states.device) + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + extended_attention_mask: torch.Tensor = self.get_extended_attention_mask( + attention_mask, (batch_size, seq_length), hidden_states.device + ) + + encoder_outputs = self.encoder( + hidden_states, + attention_mask=extended_attention_mask, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = encoder_outputs[0] + sequence_output = self.layernorm(sequence_output) + pooled_output = self.pooler(sequence_output) if self.pooler is not None else None + + if not return_dict: + return (sequence_output, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPooling( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +@add_start_docstrings( + "The bare FLAVA Model transformer outputting raw hidden-states without any specific head on top.", + FLAVA_START_DOCSTRING.format(config="FLAVAConfig"), +) +class FLAVAModel(FLAVAPreTrainedModel): + config_class = FLAVAConfig + + def __init__(self, config: FLAVAConfig): + super().__init__(config) + + if not isinstance(config.text_config, FLAVATextConfig): + raise ValueError( + f"config.text_config is expected to be of type FLAVATextConfig but is of type 
{type(config.text_config)}." + ) + + if not isinstance(config.image_config, FLAVAImageConfig): + raise ValueError( + f"config.image_config is expected to be of type FLAVAImageConfig but is of type {type(config.image_config)}." + ) + + if not isinstance(config.multimodal_config, FLAVAMultimodalConfig): + raise ValueError( + "config.multimodal_config is expected to be of type FLAVAMultimodalConfig but " + + f"is of type {type(config.multimodal_config)}." + ) + + text_config = config.text_config + image_config = config.image_config + multimodal_config = config.multimodal_config + + self.projection_dim = config.projection_dim + self.text_hidden_size = text_config.hidden_size + self.image_hidden_size = image_config.hidden_size + self.mm_hidden_size = multimodal_config.hidden_size + + self.text_model = FLAVATextModel(text_config) + self.image_model = FLAVAImageModel(image_config) + self.multimodal_model = FLAVAMultimodalModel(multimodal_config) + + self.image_projection = nn.Linear(self.image_hidden_size, self.projection_dim) + self.text_projection = nn.Linear(self.text_hidden_size, self.projection_dim) + self.logit_scale = nn.Parameter(torch.ones([]) * self.config.logit_scale_init_value) + + self.image_to_mm_projection = nn.Linear(self.image_hidden_size, self.mm_hidden_size) + self.text_to_mm_projection = nn.Linear(self.text_hidden_size, self.mm_hidden_size) + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(FLAVA_TEXT_INPUTS_DOCSTRING.format("batch_size, text_seq_length")) + def get_text_features( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> torch.FloatTensor: + r""" + Returns: + text_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The text embeddings obtained by + applying the projection layer to the pooled output of [`FLAVATextModel`]. + + Examples: + + ```python + >>> from transformers import FLAVAProcessor, FLAVAModel + + >>> model = FLAVAModel.from_pretrained("{0}") + >>> processor = FLAVAProcessor.from_pretrained("{0}") + + >>> inputs = processor( + ... text=["a photo of a cat", "a photo of a dog"], max_length=77, padding="max_length", return_tensors="pt" + ... 
) + >>> text_features = model.get_text_features(**inputs) + ```""".format( + _CHECKPOINT_FOR_DOC + ) + text_outputs = self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = text_outputs[0] # last_hidden_state + text_features = self.text_projection(pooled_output) + + return text_features + + @add_start_docstrings_to_model_forward(FLAVA_IMAGE_INPUTS_DOCSTRING.format("batch_size, image_num_patches")) + def get_image_features( + self, + pixel_values: Optional[torch.Tensor] = None, + bool_masked_pos: Optional[torch.BoolTensor] = None, + interpolate_pos_encoding: Optional[bool] = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> torch.FloatTensor: + r""" + Returns: + image_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The image embeddings obtained by + applying the projection layer to the pooled output of [`FLAVAImageModel`]. + + Examples: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import FLAVAProcessor, FLAVAModel + + >>> model = FLAVAModel.from_pretrained("{0}") + >>> processor = FLAVAProcessor.from_pretrained("{0}") + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = processor(images=image, return_tensors="pt") + + >>> image_features = model.get_image_features(**inputs) + ```""".format( + _CHECKPOINT_FOR_DOC + ) + image_outputs = self.image_model( + pixel_values=pixel_values, + bool_masked_pos=bool_masked_pos, + attention_mask=attention_mask, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, + return_dict=return_dict, + ) + + pooled_output = image_outputs[0] # last_hidden_state + image_features = self.image_projection(pooled_output) + + return image_features + + @add_start_docstrings_to_model_forward( + FLAVA_MODEL_INPUTS_DOCSTRING.format("batch_size, image_num_patches + text_seq_len") + ) + @replace_return_docstrings(output_type=FLAVAModelOutput, config_class=FLAVAConfig) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + bool_masked_pos: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + image_attention_mask: Optional[torch.Tensor] = None, + skip_multimodal_encoder: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: bool = True, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, FLAVAOutput]: + r""" + Returns: + + Examples: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import FLAVAProcessor, FLAVAModel + + >>> model = FLAVAModel.from_pretrained("aps/flava-full") + >>> processor = FLAVAProcessor.from_pretrained("aps/flava-full") + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = processor(text=["a photo of a cat"], images=image, 
return_tensors="pt", padding=True) + + >>> outputs = model(**inputs) + >>> logits_per_image = outputs.contrastive_logits_per_image # this is the image-text similarity score + >>> probs = logits_per_image.softmax(dim=1) # we can take the softmax to get the label probabilities + ``` + """ + + return_dict = return_dict if return_dict is not None else self.config.return_dict + assert output_hidden_states is True, "FLAVA model requires hidden states to work." + image_embeddings = None + image_states = None + image_mm_projection = None + image_output = None + if pixel_values is not None: + image_output = self.image_model( + pixel_values=pixel_values, + bool_masked_pos=bool_masked_pos, + attention_mask=image_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + image_embeddings, image_states = image_output[0], image_output[2] + # Note that these states don't use final layernorm in the transformer model + image_mm_projection = self.image_to_mm_projection(image_states[-1]) + + text_embeddings = None + text_states = None + text_mm_projection = None + text_output = None + if input_ids is not None: + text_output = self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + token_type_ids=token_type_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + text_embeddings, text_states = text_output[0], text_output[2] + # Note that these states don't use final layernorm in the transformer model + text_mm_projection = self.text_to_mm_projection(text_states[-1]) + + multimodal_embeddings = None + multimodal_output = None + if image_mm_projection is not None and text_mm_projection is not None and not skip_multimodal_encoder: + multimodal_input = torch.cat([image_mm_projection, text_mm_projection], dim=1) + multimodal_output = self.multimodal_model(multimodal_input) + multimodal_embeddings = multimodal_output[0] + + return FLAVAModelOutput( + image_embeddings=image_embeddings, + image_output=image_output, + text_embeddings=text_embeddings, + text_output=text_output, + multimodal_embeddings=multimodal_embeddings, + multimodal_output=multimodal_output, + ) + + +class FLAVACodebookBlock(nn.Module): + def __init__(self, in_size: int, out_size: int, num_layers: int, **kwargs): + super().__init__() + + n_hid = out_size // 4 + self.post_gain = 1 / (num_layers**2) + + if in_size != out_size: + self.id_path = _build_codebook_conv2d(in_size, out_size, kernel_size=1) + else: + self.id_path = nn.Identity() + + self.res_path = nn.Sequential( + OrderedDict( + [ + ("relu_1", nn.ReLU()), + ("conv_1", _build_codebook_conv2d(in_size, n_hid, kernel_size=3)), + ("relu_2", nn.ReLU()), + ("conv_2", _build_codebook_conv2d(n_hid, n_hid, kernel_size=3)), + ("relu_3", nn.ReLU()), + ("conv_3", _build_codebook_conv2d(n_hid, n_hid, kernel_size=3)), + ("relu_4", nn.ReLU()), + ("conv_4", _build_codebook_conv2d(n_hid, out_size, kernel_size=1)), + ] + ) + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.id_path(x) + self.post_gain * self.res_path(x) + + +# Inspired by DALLE Encoder in https://github.com/openai/DALL-E/blob/5be4b236bc3ade6943662354117a0e83752cc322/dall_e/encoder.py#L42 +@add_start_docstrings( + """ + The FLAVA's codebook model inspired from DALL-E's original encoder. Outputs raw hidden states and can be used to + generate image tokens for an image based on DALL-E's vocab. To be used to generate labels for MIM. 
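A minimal, shape-level sketch of how `FLAVAModel.forward` above assembles the multimodal encoder input; the dimensions and the plain `nn.Linear` stand-ins are hypothetical illustrations, not values taken from this patch:

```python
# Hypothetical shape walk-through of the fusion step in FLAVAModel.forward above: the last
# hidden states of each encoder are projected to the shared multimodal width and concatenated
# along the sequence dimension before being handed to the multimodal encoder.
import torch
from torch import nn

batch_size, image_seq_len, text_seq_len = 2, 197, 77          # assumed sequence lengths
image_hidden_size = text_hidden_size = mm_hidden_size = 768   # assumed hidden sizes

image_states = torch.randn(batch_size, image_seq_len, image_hidden_size)  # stand-in for image_states[-1]
text_states = torch.randn(batch_size, text_seq_len, text_hidden_size)     # stand-in for text_states[-1]

image_to_mm_projection = nn.Linear(image_hidden_size, mm_hidden_size)
text_to_mm_projection = nn.Linear(text_hidden_size, mm_hidden_size)

multimodal_input = torch.cat(
    [image_to_mm_projection(image_states), text_to_mm_projection(text_states)], dim=1
)
print(multimodal_input.shape)  # torch.Size([2, 274, 768]) -> input to FLAVAMultimodalModel
```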
Use + `get_codebook_indices` to get image tokens for an image. + """, + FLAVA_START_DOCSTRING.format(config="FLAVACodebookConfig"), +) +class FLAVACodebook(FLAVAPreTrainedModel): + base_model_prefix = "" + config_class = FLAVACodebookConfig + main_input_name = "pixel_values" + + def __init__( + self, + config: FLAVACodebookConfig, + **kwargs: Any, + ): + super().__init__(config) + + self.config = config + self.num_groups = config.num_groups + self.input_channels = config.input_channels + self.num_blocks_per_group = config.num_blocks_per_group + self.hidden_size = config.hidden_size + self.vocab_size = config.vocab_size + + num_layers = self.num_groups * self.num_blocks_per_group + output_conv = _build_codebook_conv2d(8 * self.hidden_size, self.vocab_size, kernel_size=1) + + self.blocks = nn.Sequential( + OrderedDict( + [ + ("input", _build_codebook_conv2d(self.input_channels, 1 * self.hidden_size, kernel_size=7)), + ( + "group_1", + self._create_group( + num_layers, self.num_blocks_per_group, 1 * self.hidden_size, 1 * self.hidden_size + ), + ), + ( + "group_2", + self._create_group( + num_layers, self.num_blocks_per_group, 1 * self.hidden_size, 2 * self.hidden_size + ), + ), + ( + "group_3", + self._create_group( + num_layers, self.num_blocks_per_group, 2 * self.hidden_size, 4 * self.hidden_size + ), + ), + ( + "group_4", + self._create_group( + num_layers, + self.num_blocks_per_group, + 4 * self.hidden_size, + 8 * self.hidden_size, + use_pool=False, + ), + ), + ( + "output", + nn.Sequential(OrderedDict([("relu", nn.ReLU()), ("conv", output_conv)])), + ), + ] + ) + ) + self.post_init() + + if self.config.freeze: + self._freeze() + + def _freeze(self): + for param in self.parameters(): + param.requires_grad = False + + def _create_group( + self, + num_layers: int, + num_blocks_per_group: int, + in_size: int, + hidden_size: int, + use_pool: bool = True, + ): + blocks = OrderedDict() + for i in range(num_blocks_per_group): + if i == 0: + blocks[f"block_{i+1}"] = FLAVACodebookBlock(in_size, hidden_size, num_layers) + else: + blocks[f"block_{i+1}"] = FLAVACodebookBlock(hidden_size, hidden_size, num_layers) + + if use_pool: + blocks["pool"] = nn.MaxPool2d(kernel_size=2) + + return nn.Sequential(blocks) + + def get_codebook_indices(self, pixel_values: torch.Tensor) -> torch.Tensor: + """ + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Pixel values can be obtained using [`FLAVACodebookFeatureExtractor`]. See + [`FLAVACodebookFeatureExtractor.__call__`] for details. 
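A quick bit of spatial arithmetic for the codebook assembled above. This assumes the 112x112 crop used by `FLAVACodebookFeatureExtractor` later in this patch, and that `_build_codebook_conv2d` (not shown in this hunk) pads so the convolutions preserve resolution:

```python
# Only the MaxPool2d(kernel_size=2) layers after group_1, group_2 and group_3 downsample;
# group_4 and the 1x1 output conv keep the spatial size, so a 112x112 input ends up as a
# 14x14 grid of vocab_size logits, i.e. 196 image-token labels for MIM.
input_resolution = 112          # assumed crop size of FLAVACodebookFeatureExtractor
num_pooled_groups = 3           # groups created with use_pool=True
grid = input_resolution // (2 ** num_pooled_groups)
print(grid, grid * grid)        # 14 196
```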
+ + Examples: + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import FLAVACodebookFeatureExtractor, FLAVACodebook + + >>> model = FLAVAModel.from_pretrained("{0}") + >>> feature_extractor = FLAVACodebookFeaturExtractor.from_pretrained("{0}") + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = feature_extractor(image, return_mask=True, return_tensors="pt") + + >>> outputs = model.get_codebook_indices(**inputs) + ``` + """.format( + _CHECKPOINT_FOR_CODEBOOK_DOC + ) + z_logits = self.blocks(pixel_values) + return torch.argmax(z_logits, axis=1) + + def get_codebook_probs(self, pixel_values: torch.Tensor) -> torch.Tensor: + z_logits = self.blocks(pixel_values) + return nn.Softmax(dim=1)(z_logits) + + def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor: + """ + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Pixel values can be obtained using [`FLAVACodebookFeatureExtractor`]. See + [`FLAVACodebookFeatureExtractor.__call__`] for details. + + Examples: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import FLAVACodebookFeatureExtractor, FLAVACodebook + + >>> model = FLAVAModel.from_pretrained("{0}") + >>> feature_extractor = FLAVACodebookFeaturExtractor.from_pretrained("{0}") + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = feature_extractor([image], return_tensors="pt") + + >>> outputs = model(**inputs) + >>> print(outputs.shape) + (1, 196) + ``` + """.format( + _CHECKPOINT_FOR_CODEBOOK_DOC + ) + if len(pixel_values.shape) != 4: + raise ValueError(f"input shape {pixel_values.shape} is not 4d") + if pixel_values.shape[1] != self.input_channels: + raise ValueError(f"input has {pixel_values.shape[1]} channels but model built for {self.input_channels}") + return self.blocks(pixel_values) + + +class FLAVAPredictionHeadTransform(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + if isinstance(config.hidden_act, str): + self.transform_act_fn = ACT2FN[config.hidden_act] + else: + self.transform_act_fn = config.hidden_act + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.transform_act_fn(hidden_states) + hidden_states = self.LayerNorm(hidden_states) + return hidden_states + + +class FLAVAMaskedPredictionHead(nn.Module): + def __init__(self, config, weight=None): + super().__init__() + self.config = config + self.transform = FLAVAPredictionHeadTransform(config) + self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.bias = nn.Parameter(torch.zeros(config.vocab_size)) + if weight is not None: + self.decoder.weight = weight + + # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` + self.decoder.bias = self.bias + + def forward(self, x): + x = self.transform(x) + x = self.decoder(x) + return x + + +class FLAVAITMHead(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.pooler = FLAVAPooler(config) + self.seq_relationship = nn.Linear(config.hidden_size, 2) + + def forward(self, x): + x = self.pooler(x) + x = 
self.seq_relationship(x) + return x + + +class FLAVAGlobalContrastiveHead(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.global_backprop_contrastive = config.global_backprop_contrastive + + def forward(self, image_embeddings, text_embeddings, logit_scale): + temperature = torch.exp(logit_scale) + if not torch.distributed.is_available() or not torch.distributed.is_initialized(): + labels = torch.arange(image_embeddings.size(0), device=image_embeddings.device) + image_embeddings_all = [image_embeddings] + text_embeddings_all = [text_embeddings] + else: + local_batch_size = image_embeddings.size(0) + world_size = torch.distributed.get_world_size() + + if self.global_backprop_contrastive: + image_embeddings_all = torch.distributed.nn.functional.all_gather(image_embeddings) + text_embeddings_all = torch.distributed.nn.functional.all_gather(text_embeddings) + else: + image_embeddings_all = [torch.zeros_like(text_embeddings) for _ in range(world_size)] + text_embeddings_all = [torch.zeros_like(image_embeddings) for _ in range(world_size)] + torch.distributed.all_gather(image_embeddings_all, image_embeddings) + torch.distributed.all_gather(text_embeddings_all, text_embeddings) + + labels = local_batch_size * torch.distributed.get_rank() + torch.arange( + local_batch_size, device=image_embeddings.device + ) + + image_embeddings_all = torch.cat(image_embeddings_all) + text_embeddings_all = torch.cat(text_embeddings_all) + + logits_per_image = torch.matmul(image_embeddings, text_embeddings_all.transpose(0, 1)) * temperature + logits_per_text = torch.matmul(text_embeddings, image_embeddings_all.transpose(0, 1)) * temperature + + return logits_per_image, logits_per_text, labels + + +@add_start_docstrings( + """ + The FLAVA model for pretraining which outputs losses, embeddings, logits and transformer outputs.
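A single-process sketch of the quantities `FLAVAGlobalContrastiveHead` returns. Sizes are hypothetical, the embeddings are assumed to be already L2-normalized (as `FLAVAForPretraining` does below before calling the head), and in the distributed branch the only difference is that `labels` is offset by `rank * local_batch_size`:

```python
# Non-distributed branch of the global contrastive head, written out with toy tensors:
# matching image/text pairs sit on the diagonal, so the target is simply arange(batch_size).
import torch

batch_size, projection_dim = 4, 512                     # hypothetical sizes
image_embeddings = torch.nn.functional.normalize(torch.randn(batch_size, projection_dim), dim=-1)
text_embeddings = torch.nn.functional.normalize(torch.randn(batch_size, projection_dim), dim=-1)
logit_scale = torch.tensor(2.6592)                      # assumed CLIP-style log(1 / 0.07) init

temperature = torch.exp(logit_scale)
logits_per_image = image_embeddings @ text_embeddings.t() * temperature
logits_per_text = text_embeddings @ image_embeddings.t() * temperature
labels = torch.arange(batch_size)

loss = (
    torch.nn.functional.cross_entropy(logits_per_image, labels)
    + torch.nn.functional.cross_entropy(logits_per_text, labels)
) / 2
```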
+ """, + FLAVA_START_DOCSTRING.format(config="FLAVAConfig"), +) +class FLAVAForPretraining(FLAVAPreTrainedModel): + def __init__(self, config: FLAVAConfig): + super().__init__(config) + self.flava = FLAVAModel(config) + + # Levarage text and image encoder configs to create the masked + # head since it has the right vocab + self.mim_head = FLAVAMaskedPredictionHead(config.image_config) + self.mlm_head = FLAVAMaskedPredictionHead(config.text_config) + self.itm_head = FLAVAITMHead(config) + self.mmm_image_head = FLAVAMaskedPredictionHead(config.image_config) + self.mmm_text_head = FLAVAMaskedPredictionHead(config.text_config) + self.global_contrastive_head = FLAVAGlobalContrastiveHead(config) + + self.image_vocab_size = config.image_config.vocab_size + self.text_vocab_size = config.text_config.vocab_size + self.mlm_weight = config.mlm_weight + self.mim_weight = config.mim_weight + self.global_contrastive_weight = config.global_contrastive_weight + self.ce_ignore_index = config.ce_ignore_index + self.itm_weight = config.itm_weight + self.mmm_image_weight = config.mmm_image_weight + self.mmm_text_weight = config.mmm_text_weight + self.skip_unmasked_multimodal_encoder = config.skip_unmasked_multimodal_encoder + + def _resize_to_2d(self, x: torch.Tensor): + if x.dim() > 2: + x = x.view(x.size(0), -1) + return x + + @add_start_docstrings_to_model_forward( + FLAVA_PRETRAINING_INPUTS_DOCSTRING.format("batch_size, text_seq_len", "batch_size, image_num_patches") + ) + @replace_return_docstrings(output_type=FLAVAForPretrainingOutput, config_class=FLAVAConfig) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + input_ids_masked: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + bool_masked_pos: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + image_attention_mask: Optional[torch.Tensor] = None, + # TODO: Move to config with better name + skip_unmasked_multimodal_encoder: bool = None, + mlm_labels: Optional[torch.Tensor] = None, + mim_labels: Optional[torch.Tensor] = None, + itm_labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: bool = True, + return_dict: Optional[bool] = None, + ): + """ + Examples: + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import FLAVACodebookFeatureExtractor, FLAVACodebook, FLAVAForPretraining, FLAVAProcessor + + >>> codebook = FLAVACodebook.from_pretrained("aps/flava-codebook") + >>> codebook_feature_extractor = FLAVACodebookFeatureExtractor.from_pretrained("aps/flava-codebook") + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = codebook_feature_extractor([image], return_tensors="pt") + + >>> mim_labels = codebook.get_codebook_indices(**inputs) + + >>> model = FLAVAForPretraining.from_pretrained("aps/flava-full") + >>> processor = FLAVAProcessor.from_pretrained("aps/flava-full") + + >>> text = ["a photo of a cat"] + + >>> inputs = processor( + ... images=[image], text=text, return_masks=True, padding=True, max_length=77, return_tensors="pt" + ... 
) + >>> inputs["mim_labels"] = mim_labels + + >>> output = model(**inputs) + ``` + + Return: + + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + skip_unmasked_multimodal_encoder = ( + skip_unmasked_multimodal_encoder + if skip_unmasked_multimodal_encoder is not None + else self.skip_unmasked_multimodal_encoder + ) + + flava_output = self.flava( + input_ids=input_ids, + pixel_values=pixel_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + image_attention_mask=image_attention_mask, + # Don't need unmasked multimodal embedding for anything so skip it + # NOTE: ITM uses masked version + skip_multimodal_encoder=skip_unmasked_multimodal_encoder, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + # Pass true to have deterministic outputs + return_dict=True, + ) + + flava_masked_output = self.flava( + input_ids=input_ids_masked, + pixel_values=pixel_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + image_attention_mask=image_attention_mask, + bool_masked_pos=bool_masked_pos, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, + ) + + pos_mask = None + + image_embeddings = flava_output.image_embeddings + text_embeddings = flava_output.text_embeddings + image_masked_embeddings = flava_masked_output.image_embeddings + text_masked_embeddings = flava_masked_output.text_embeddings + multimodal_masked_embeddings = flava_masked_output.multimodal_embeddings + mim_loss = mlm_loss = mmm_text_loss = mmm_image_loss = gc_loss = itm_loss = None + mim_logits = ( + mlm_logits + ) = mmm_text_logits = mmm_image_logits = itm_logits = logits_per_image = logits_per_text = None + + # Unimodal MIM Loss + # If multimodal embeddings are present, we will calculate MMM loss + if self.mim_weight > 0 and image_masked_embeddings is not None and multimodal_masked_embeddings is None: + sequence_for_image = image_masked_embeddings + if mim_labels is not None: + mim_labels = self._resize_to_2d(mim_labels) + bool_masked_pos = self._resize_to_2d(bool_masked_pos) + mim_labels[bool_masked_pos.ne(True)] = self.ce_ignore_index + + sequence_for_image = sequence_for_image[:, -mim_labels.size(1) :, :] + masked_tokens = mim_labels.ne(self.ce_ignore_index) + mim_labels_filtered = mim_labels[masked_tokens] + sequence_for_image = sequence_for_image[masked_tokens, :] + mim_logits = self.mim_head(sequence_for_image) + mim_loss = nn.functional.cross_entropy( + mim_logits.view(-1, self.image_vocab_size), mim_labels_filtered.view(-1) + ) + mim_loss *= self.mim_weight + else: + mim_logits = self.mim_head(sequence_for_image) + + # Unimodal MLM Loss + if self.mlm_weight > 0 and text_masked_embeddings is not None and multimodal_masked_embeddings is None: + sequence_for_text = text_masked_embeddings + if mlm_labels is not None: + mlm_labels = self._resize_to_2d(mlm_labels) + sequence_for_text = sequence_for_text[:, -mlm_labels.size(1) :, :] + masked_tokens = mlm_labels.ne(self.ce_ignore_index) + mlm_labels_filtered = mlm_labels[masked_tokens] + sequence_for_text = sequence_for_text[masked_tokens, :] + mlm_logits = self.mlm_head(sequence_for_text) + mlm_loss = nn.functional.cross_entropy( + mlm_logits.view(-1, self.text_vocab_size), mlm_labels_filtered.view(-1) + ) + mlm_loss *= self.mlm_weight + else: + mlm_logits = self.mlm_head(sequence_for_text) + + # ITM Loss + if self.itm_weight > 0 and multimodal_masked_embeddings is not None: + itm_logits 
= self.itm_head(multimodal_masked_embeddings) + + if itm_labels is not None: + pos_pairs = itm_labels.ne(0) + pos_mask = torch.where(pos_pairs.any(), pos_pairs, pos_pairs.new([True])) + itm_loss = nn.functional.cross_entropy(itm_logits, itm_labels) + itm_loss *= self.itm_weight + + if multimodal_masked_embeddings is not None: + multimodal_masked_embeddings = multimodal_masked_embeddings[pos_mask] + + if mlm_labels is not None: + mlm_labels = mlm_labels[pos_mask] + + if mim_labels is not None: + mim_labels = mim_labels[pos_mask] + + # MMM Image Loss + if multimodal_masked_embeddings is not None and self.mmm_image_weight > 0: + sequence_for_image = multimodal_masked_embeddings + end_index = image_masked_embeddings.size(1) - 1 + sequence_for_image = sequence_for_image[:, 2 : 2 + end_index, :] + + if pos_mask is not None: + sequence_for_image = sequence_for_image[pos_mask] + if mim_labels is not None: + mim_labels = self._resize_to_2d(mim_labels) + bool_masked_pos = self._resize_to_2d(bool_masked_pos) + mim_labels[bool_masked_pos.ne(True)] = self.ce_ignore_index + + masked_tokens = mim_labels.ne(self.ce_ignore_index) + mim_labels_filtered = mim_labels[masked_tokens] + sequence_for_image = sequence_for_image[masked_tokens, :] + mmm_image_logits = self.mmm_image_head(sequence_for_image) + mmm_image_loss = nn.functional.cross_entropy( + mmm_image_logits.view(-1, self.image_vocab_size), mim_labels_filtered.view(-1) + ) + mmm_image_loss *= self.mmm_image_weight + else: + mmm_image_logits = self.mmm_image_head(sequence_for_image) + + # MMM Text Loss + if multimodal_masked_embeddings is not None and self.mmm_text_weight > 0: + sequence_for_text = multimodal_masked_embeddings + sequence_for_text = sequence_for_text[:, -text_masked_embeddings.size(1) :, :] + if pos_mask is not None: + sequence_for_text = sequence_for_text[pos_mask] + + if mlm_labels is not None: + mlm_labels = self._resize_to_2d(mlm_labels) + masked_tokens = mlm_labels.ne(self.ce_ignore_index) + mlm_labels_filtered = mlm_labels[masked_tokens] + sequence_for_text = sequence_for_text[masked_tokens, :] + mmm_text_logits = self.mmm_text_head(sequence_for_text) + mmm_text_loss = nn.functional.cross_entropy( + mmm_text_logits.view(-1, self.text_vocab_size), mlm_labels_filtered.view(-1) + ) + mmm_text_loss *= self.mmm_text_weight + else: + mmm_text_logits = self.mmm_text_head(sequence_for_text) + + # Global Contrastive Loss + if image_embeddings is not None and text_embeddings is not None and self.global_contrastive_weight > 0: + text_embedding = self.flava.text_projection(text_embeddings[:, 0, :]) + text_embedding = nn.functional.normalize(text_embedding, dim=-1) + + image_embedding = self.flava.image_projection(image_embeddings[:, 0, :]) + image_embedding = nn.functional.normalize(image_embedding, dim=-1) + + self.flava.logit_scale.data.clamp_(LOGIT_SCALE_CLAMP_MIN, LOGIT_SCALE_CLAMP_MAX) + + logits_per_image, logits_per_text, gc_labels = self.global_contrastive_head( + image_embedding, text_embedding, self.flava.logit_scale + ) + + # Apply ITM negative mask if any + if pos_mask is not None: + logits_per_image = logits_per_image[pos_mask] + logits_per_text = logits_per_text[pos_mask] + gc_labels = gc_labels[pos_mask] + + gc_loss_image = nn.functional.cross_entropy(logits_per_image, gc_labels) + gc_loss_text = nn.functional.cross_entropy(logits_per_text, gc_labels) + gc_loss = (gc_loss_image + gc_loss_text) / 2 + gc_loss *= self.global_contrastive_weight + + flava_losses = FLAVALosses( + mim=mim_loss, + mlm=mlm_loss, + itm=itm_loss, + 
global_contrastive=gc_loss, + mmm_image=mmm_image_loss, + mmm_text=mmm_text_loss, + ) + + if not return_dict: + output = ( + image_embeddings, + flava_output.image_output, + text_embeddings, + flava_output.text_output, + flava_output.multimodal_embeddings, + flava_output.multimodal_output, + image_masked_embeddings, + flava_masked_output.image_output, + text_masked_embeddings, + flava_masked_output.text_output, + multimodal_masked_embeddings, + flava_masked_output.multimodal_output, + mim_logits, + mlm_logits, + itm_logits, + logits_per_image, + logits_per_image, + mmm_image_logits, + mmm_text_logits, + ) + if flava_losses.all_none(): + return (flava_losses,) + output + else: + return output + + return FLAVAForPretrainingOutput( + losses=flava_losses, + image_embeddings=image_embeddings, + image_output=flava_output.image_output, + text_embeddings=text_embeddings, + text_output=flava_output.text_output, + multimodal_embeddings=flava_output.multimodal_embeddings, + multimodal_output=flava_output.multimodal_output, + image_masked_embeddings=image_masked_embeddings, + image_masked_output=flava_masked_output.image_output, + text_masked_embeddings=text_masked_embeddings, + text_masked_output=flava_masked_output.text_output, + multimodal_masked_embeddings=multimodal_masked_embeddings, + multimodal_masked_output=flava_masked_output.multimodal_output, + mim_logits=mim_logits, + mlm_logits=mlm_logits, + itm_logits=itm_logits, + contrastive_logits_per_image=logits_per_image, + contrastive_logits_per_text=logits_per_text, + mmm_image_logits=mmm_image_logits, + mmm_text_logits=mmm_text_logits, + ) diff --git a/src/transformers/models/flava/processing_flava.py b/src/transformers/models/flava/processing_flava.py new file mode 100644 index 00000000000000..e75074eed81298 --- /dev/null +++ b/src/transformers/models/flava/processing_flava.py @@ -0,0 +1,148 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Image/Text processor class for FLAVA +""" +from typing import List, Optional, Union + +import numpy as np +import torch +from PIL import Image + +from transformers.data.data_collator import DataCollatorForWholeWordMask, tolist + +from ...processing_utils import ProcessorMixin +from ...tokenization_utils_base import BatchEncoding, PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy +from ...utils import TensorType + + +class FLAVAProcessor(ProcessorMixin): + r""" + Constructs a FLAVA processor which wraps a FLAVA feature extractor and a FLAVA tokenizer into a single processor. + + [`FLAVAProcessor`] offers all the functionalities of [`FLAVAFeatureExtractor`] and [`FLAVATokenizerFast`]. See the + [`~FLAVAProcessor.__call__`] and [`~FLAVAProcessor.decode`] for more information. + + Args: + feature_extractor ([`FLAVAFeatureExtractor`]): + The feature extractor is a required input. + tokenizer ([`FLAVATokenizerFast`]): + The tokenizer is a required input. 
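Looking back at the masked-prediction losses in `FLAVAForPretraining.forward` above, here is a standalone sketch of the label filtering they all share; the sizes are toy values and a plain `nn.Linear` stands in for the real prediction heads:

```python
# Toy reproduction of the shared MIM/MLM/MMM bookkeeping above: unmasked positions are set to
# ce_ignore_index, then only the masked positions are run through the head and the cross-entropy.
import torch

ce_ignore_index = -100
batch_size, seq_len, hidden_size, vocab_size = 2, 6, 8, 50   # hypothetical sizes

labels = torch.randint(0, vocab_size, (batch_size, seq_len))
bool_masked_pos = torch.rand(batch_size, seq_len) < 0.4      # stand-in for the masking generator
bool_masked_pos[0, 0] = True                                 # ensure at least one masked position
labels[bool_masked_pos.ne(True)] = ce_ignore_index

sequence_output = torch.randn(batch_size, seq_len, hidden_size)
masked_tokens = labels.ne(ce_ignore_index)
labels_filtered = labels[masked_tokens]                      # (num_masked,)
sequence_filtered = sequence_output[masked_tokens, :]        # (num_masked, hidden_size)

head = torch.nn.Linear(hidden_size, vocab_size)              # stand-in for FLAVAMaskedPredictionHead
logits = head(sequence_filtered)
loss = torch.nn.functional.cross_entropy(logits.view(-1, vocab_size), labels_filtered.view(-1))
```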
+ """ + feature_extractor_class = "FLAVAFeatureExtractor" + tokenizer_class = ("BertTokenizer", "BertTokenizerFast") + + def __init__(self, feature_extractor, tokenizer, mlm_probability=0.15): + super().__init__(feature_extractor, tokenizer) + self.current_processor = self.feature_extractor + self.text_masker = DataCollatorForWholeWordMask(tokenizer, mlm=True, mlm_probability=mlm_probability) + + def __call__( + self, + images: Optional[ + Union[ + Image.Image, + np.ndarray, + "torch.Tensor", + List[Image.Image], + List[np.ndarray], + List["torch.Tensor"], # noqa + ] + ] = None, + text: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None, + add_special_tokens: bool = True, + padding: Union[bool, str, PaddingStrategy] = False, + truncation: Union[bool, str, TruncationStrategy] = False, + max_length: Optional[int] = None, + stride: int = 0, + pad_to_multiple_of: Optional[int] = None, + return_masks: Optional[bool] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + return_length: bool = False, + verbose: bool = True, + return_tensors: Optional[Union[str, TensorType]] = None, + **kwargs + ): + """ + This method uses [`FLAVAFeatureExtractor.__call__`] method to prepare image(s) for the model, and + [`BertTokenizerFast.__call__`] to prepare text for the model. + + Please refer to the docstring of the above two methods for more information. Other special args are mentioned + below: + + Args: + return_mask (`bool`, *optional*, defaults to None): + If True, the processor will return `bool_masked_pos` suggesting masks for image's patch version and + `input_ids_masked` and `mlm_labels` for MLM. + """ + + if text is None and images is None: + raise ValueError("You have to specify either text or images. Both cannot be none.") + + if text is not None: + encoding = self.tokenizer( + text=text, + add_special_tokens=add_special_tokens, + padding=padding, + truncation=truncation, + max_length=max_length, + stride=stride, + pad_to_multiple_of=pad_to_multiple_of, + return_token_type_ids=return_token_type_ids, + return_attention_mask=return_attention_mask, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask or return_masks, + return_offsets_mapping=return_offsets_mapping, + return_length=return_length, + verbose=verbose, + return_tensors=return_tensors, + **kwargs, + ) + if images is not None: + image_features = self.feature_extractor( + images, return_masks=return_masks, return_tensors=return_tensors, **kwargs + ) + + if return_masks and text is not None: + batch_masked = self.text_masker(tolist(encoding["input_ids"]), return_tensors=return_tensors) + encoding["input_ids_masked"] = batch_masked["input_ids"] + encoding["mlm_labels"] = batch_masked["labels"] + encoding.pop("special_tokens_mask") + + if text is not None and images is not None: + encoding.update(image_features) + return encoding + elif text is not None: + return encoding + else: + return BatchEncoding(data=dict(**image_features), tensor_type=return_tensors) + + def batch_decode(self, *args, **kwargs): + """ + This method forwards all its arguments to FLAVATokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please + refer to the docstring of this method for more information. 
+ """ + return self.tokenizer.batch_decode(*args, **kwargs) + + def decode(self, *args, **kwargs): + """ + This method forwards all its arguments to FLAVATokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to + the docstring of this method for more information. + """ + return self.tokenizer.decode(*args, **kwargs) diff --git a/tests/flava/__init__.py b/tests/flava/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/tests/flava/test_feature_extraction_flava.py b/tests/flava/test_feature_extraction_flava.py new file mode 100644 index 00000000000000..f6d37cecf7696c --- /dev/null +++ b/tests/flava/test_feature_extraction_flava.py @@ -0,0 +1,414 @@ +# coding=utf-8 +# Copyright 2021 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import random +import unittest + +import numpy as np + +from transformers.testing_utils import require_torch, require_vision +from transformers.utils import is_torch_available, is_vision_available + +from ..test_feature_extraction_common import FeatureExtractionSavingTestMixin, prepare_image_inputs + + +if is_torch_available(): + import torch + +if is_vision_available(): + from PIL import Image + + from transformers import FLAVACodebookFeatureExtractor, FLAVAFeatureExtractor + from transformers.models.flava.feature_extraction_flava import ( + FLAVA_CODEBOOK_MEAN, + FLAVA_CODEBOOK_STD, + FLAVA_IMAGE_MEAN, + FLAVA_IMAGE_STD, + ) + + +# TODO(aps): Add joint feature extractor useful for pretraining +class FLAVAFeatureExtractionTester(unittest.TestCase): + def __init__( + self, + parent, + batch_size=7, + num_channels=3, + min_resolution=30, + max_resolution=400, + do_resize=True, + size=224, + do_center_crop=True, + crop_size=224, + resample=Image.BICUBIC, + do_normalize=True, + image_mean=FLAVA_IMAGE_MEAN, + image_std=FLAVA_IMAGE_STD, + input_size_patches=14, + total_mask_patches=75, + mask_group_max_patches=None, + mask_group_min_patches=16, + mask_group_min_aspect_ratio=0.3, + mask_group_max_aspect_ratio=None, + ): + self.parent = parent + self.batch_size = batch_size + self.num_channels = num_channels + self.do_resize = do_resize + self.min_resolution = min_resolution + self.max_resolution = max_resolution + self.size = size + self.resample = resample + self.do_normalize = do_normalize + self.image_mean = image_mean + self.image_std = image_std + self.do_center_crop = do_center_crop + self.crop_size = crop_size + self.input_size_patches = input_size_patches + self.total_mask_patches = total_mask_patches + self.mask_group_max_patches = mask_group_max_patches + self.mask_group_min_patches = mask_group_min_patches + self.mask_group_min_aspect_ratio = mask_group_min_aspect_ratio + self.mask_group_max_aspect_ratio = mask_group_max_aspect_ratio + + def prepare_feat_extract_dict(self): + return { + "image_mean": self.image_mean, + "image_std": self.image_std, + "do_normalize": self.do_normalize, + "do_resize": self.do_resize, + "size": self.size, + "resample": self.resample, + "do_center_crop": self.do_center_crop, + 
"crop_size": self.crop_size, + "input_size_patches": self.input_size_patches, + "total_mask_patches": self.total_mask_patches, + "mask_group_max_patches": self.mask_group_max_patches, + "mask_group_min_patches": self.mask_group_min_patches, + "mask_group_min_aspect_ratio": self.mask_group_min_aspect_ratio, + "mask_group_max_aspect_ratio": self.mask_group_min_aspect_ratio, + } + + def get_expected_image_size(self): + return (self.size, self.size) if not isinstance(self.size, tuple) else self.size + + def get_expected_mask_size(self): + return ( + (self.input_size_patches, self.input_size_patches) + if not isinstance(self.input_size_patches, tuple) + else self.input_size_patches + ) + + +@require_torch +@require_vision +class FLAVAFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase): + + feature_extraction_class = FLAVAFeatureExtractor if is_vision_available() else None + + def setUp(self): + self.feature_extract_tester = FLAVAFeatureExtractionTester(self) + + @property + def feat_extract_dict(self): + return self.feature_extract_tester.prepare_feat_extract_dict() + + def test_feat_extract_properties(self): + feature_extractor = self.feature_extraction_class(**self.feat_extract_dict) + self.assertTrue(hasattr(feature_extractor, "image_mean")) + self.assertTrue(hasattr(feature_extractor, "image_std")) + self.assertTrue(hasattr(feature_extractor, "do_normalize")) + self.assertTrue(hasattr(feature_extractor, "do_resize")) + self.assertTrue(hasattr(feature_extractor, "resample")) + self.assertTrue(hasattr(feature_extractor, "crop_size")) + self.assertTrue(hasattr(feature_extractor, "do_center_crop")) + self.assertTrue(hasattr(feature_extractor, "masking_generator")) + + def test_batch_feature(self): + pass + + def test_call_pil(self): + # Initialize feature_extractor + feature_extractor = self.feature_extraction_class(**self.feat_extract_dict) + # create random PIL images + image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False) + for image in image_inputs: + self.assertIsInstance(image, Image.Image) + + # Test not batched input + encoded_images = feature_extractor(image_inputs[0], return_tensors="pt") + + # Test no bool masked pos + self.assertFalse("bool_masked_pos" in encoded_images) + + expected_height, expected_width = self.feature_extract_tester.get_expected_image_size() + + self.assertEqual( + encoded_images.pixel_values.shape, + (1, self.feature_extract_tester.num_channels, expected_height, expected_width), + ) + + # Test batched + encoded_images = feature_extractor(image_inputs, return_tensors="pt") + expected_height, expected_width = self.feature_extract_tester.get_expected_image_size() + + # Test no bool masked pos + self.assertFalse("bool_masked_pos" in encoded_images) + + self.assertEqual( + encoded_images.pixel_values.shape, + ( + self.feature_extract_tester.batch_size, + self.feature_extract_tester.num_channels, + expected_height, + expected_width, + ), + ) + + def _test_call_framework(self, instance_class, prepare_kwargs): + # Initialize feature_extractor + feature_extractor = self.feature_extraction_class(**self.feat_extract_dict) + # create random tensors + image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False, **prepare_kwargs) + for image in image_inputs: + self.assertIsInstance(image, instance_class) + + # Test not batched input + encoded_images = feature_extractor(image_inputs[0], return_tensors="pt") + + expected_height, expected_width = 
self.feature_extract_tester.get_expected_image_size() + self.assertEqual( + encoded_images.pixel_values.shape, + (1, self.feature_extract_tester.num_channels, expected_height, expected_width), + ) + + encoded_images = feature_extractor(image_inputs, return_masks=True, return_tensors="pt") + + expected_height, expected_width = self.feature_extract_tester.get_expected_image_size() + self.assertEqual( + encoded_images.pixel_values.shape, + ( + self.feature_extract_tester.batch_size, + self.feature_extract_tester.num_channels, + expected_height, + expected_width, + ), + ) + + expected_height, expected_width = self.feature_extract_tester.get_expected_mask_size() + self.assertEqual( + encoded_images.bool_masked_pos.shape, + ( + self.feature_extract_tester.batch_size, + expected_height, + expected_width, + ), + ) + + # Test batched + encoded_images = feature_extractor(image_inputs, return_tensors="pt").pixel_values + + expected_height, expected_width = self.feature_extract_tester.get_expected_image_size() + self.assertEqual( + encoded_images.shape, + ( + self.feature_extract_tester.batch_size, + self.feature_extract_tester.num_channels, + expected_height, + expected_width, + ), + ) + + # Test masking + encoded_images = feature_extractor(image_inputs, return_masks=True, return_tensors="pt") + + expected_height, expected_width = self.feature_extract_tester.get_expected_image_size() + self.assertEqual( + encoded_images.pixel_values.shape, + ( + self.feature_extract_tester.batch_size, + self.feature_extract_tester.num_channels, + expected_height, + expected_width, + ), + ) + + expected_height, expected_width = self.feature_extract_tester.get_expected_mask_size() + self.assertEqual( + encoded_images.bool_masked_pos.shape, + ( + self.feature_extract_tester.batch_size, + expected_height, + expected_width, + ), + ) + + def test_call_numpy(self): + self._test_call_framework(np.ndarray, prepare_kwargs={"numpify": True}) + + def test_call_pytorch(self): + self._test_call_framework( + torch.Tensor, + prepare_kwargs={ + "torchify": True, + }, + ) + + def test_masking(self): + # Initialize feature_extractor + random.seed(1234) + feature_extractor = self.feature_extraction_class(**self.feat_extract_dict) + image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False, torchify=True) + + # Test not batched input + encoded_images = feature_extractor(image_inputs[0], return_masks=True, return_tensors="pt") + self.assertEqual(encoded_images.bool_masked_pos.sum().item(), 75) + + +class FLAVACodebookFeatureExtractionTester(unittest.TestCase): + def __init__( + self, + parent, + batch_size=7, + num_channels=3, + min_resolution=30, + max_resolution=400, + do_resize=True, + size=112, + do_center_crop=True, + crop_size=112, + resample=Image.LANCZOS, + do_normalize=True, + image_mean=FLAVA_CODEBOOK_MEAN, + image_std=FLAVA_CODEBOOK_STD, + do_map_pixels=True, + ): + self.parent = parent + self.batch_size = batch_size + self.num_channels = num_channels + self.do_resize = do_resize + self.min_resolution = min_resolution + self.max_resolution = max_resolution + self.size = size + self.do_center_crop = do_center_crop + self.crop_size = crop_size + self.resample = resample + self.do_normalize = do_normalize + self.image_mean = image_mean + self.image_std = image_std + self.do_map_pixels = do_map_pixels + + def prepare_feat_extract_dict(self): + return { + "do_resize": self.do_resize, + "size": self.size, + "do_center_crop": self.do_center_crop, + "crop_size": self.crop_size, + "resample": self.resample, 
+ "do_normalize": self.do_normalize, + "image_mean": self.image_mean, + "image_std": self.image_std, + "do_map_pixels": self.do_map_pixels, + } + + def get_expected_image_size(self): + return (self.size, self.size) if not isinstance(self.size, tuple) else self.size + + +@require_torch +@require_vision +class FLAVACodebookFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase): + + feature_extraction_class = FLAVACodebookFeatureExtractor if is_vision_available() else None + + def setUp(self): + self.feature_extract_tester = FLAVACodebookFeatureExtractionTester(self) + + @property + def feat_extract_dict(self): + return self.feature_extract_tester.prepare_feat_extract_dict() + + def test_feat_extract_properties(self): + feature_extractor = self.feature_extraction_class(**self.feat_extract_dict) + self.assertTrue(hasattr(feature_extractor, "image_mean")) + self.assertTrue(hasattr(feature_extractor, "image_std")) + self.assertTrue(hasattr(feature_extractor, "do_normalize")) + self.assertTrue(hasattr(feature_extractor, "do_resize")) + self.assertTrue(hasattr(feature_extractor, "resample")) + self.assertTrue(hasattr(feature_extractor, "crop_size")) + self.assertTrue(hasattr(feature_extractor, "do_center_crop")) + + def test_batch_feature(self): + pass + + def test_call_pil(self): + # Initialize feature_extractor + feature_extractor = self.feature_extraction_class(**self.feat_extract_dict) + # create random PIL images + image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False) + for image in image_inputs: + self.assertIsInstance(image, Image.Image) + + # Test not batched input + encoded_images = feature_extractor(image_inputs[0], return_tensors="pt") + expected_height, expected_width = self.feature_extract_tester.get_expected_image_size() + self.assertEqual( + encoded_images.pixel_values.shape, + (1, self.feature_extract_tester.num_channels, expected_height, expected_width), + ) + + # Test batched + encoded_images = feature_extractor(image_inputs, return_tensors="pt") + expected_height, expected_width = self.feature_extract_tester.get_expected_image_size() + + def _test_call_framework(self, instance_class, prepare_kwargs): + # Initialize feature_extractor + feature_extractor = self.feature_extraction_class(**self.feat_extract_dict) + # create random tensors + image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False, **prepare_kwargs) + for image in image_inputs: + self.assertIsInstance(image, instance_class) + + # Test not batched input + encoded_images = feature_extractor(image_inputs[0], return_tensors="pt") + expected_height, expected_width = self.feature_extract_tester.get_expected_image_size() + self.assertEqual( + encoded_images.pixel_values.shape, + (1, self.feature_extract_tester.num_channels, expected_height, expected_width), + ) + + # Test batched + encoded_images = feature_extractor(image_inputs, return_tensors="pt").pixel_values + expected_height, expected_width = self.feature_extract_tester.get_expected_image_size() + self.assertEqual( + encoded_images.shape, + ( + self.feature_extract_tester.batch_size, + self.feature_extract_tester.num_channels, + expected_height, + expected_width, + ), + ) + + def test_call_numpy(self): + self._test_call_framework(np.ndarray, prepare_kwargs={"numpify": True}) + + def test_call_pytorch(self): + self._test_call_framework( + torch.Tensor, + prepare_kwargs={ + "torchify": True, + }, + ) diff --git a/tests/flava/test_modeling_flava.py b/tests/flava/test_modeling_flava.py 
new file mode 100644 index 00000000000000..e844c93da401e1 --- /dev/null +++ b/tests/flava/test_modeling_flava.py @@ -0,0 +1,824 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Testing suite for the PyTorch FLAVA model. """ + + +import inspect +import os +import tempfile +import unittest + +import numpy as np + +import requests +from transformers import FLAVAConfig, FLAVAImageConfig, FLAVAMultimodalConfig, FLAVATextConfig +from transformers.testing_utils import require_torch, require_vision, slow, torch_device +from transformers.utils import is_torch_available, is_vision_available + +from ..test_configuration_common import ConfigTester +from ..test_modeling_common import ( + ModelTesterMixin, + _config_zero_init, + floats_tensor, + ids_tensor, + random_attention_mask, +) + + +if is_torch_available(): + import torch + from torch import nn + + from transformers import FLAVAForPretraining, FLAVAImageModel, FLAVAModel, FLAVAMultimodalModel, FLAVATextModel + from transformers.models.flava.modeling_flava import FLAVA_PRETRAINED_MODEL_ARCHIVE_LIST + + +if is_vision_available(): + from PIL import Image + + from transformers import FLAVAProcessor + + +class FLAVAImageModelTester: + def __init__( + self, + parent, + batch_size=12, + hidden_size=32, + num_hidden_layers=5, + num_attention_heads=4, + intermediate_size=37, + hidden_act="gelu", + hidden_dropout_prob=0.0, + attention_probs_dropout_prob=0.0, + initializer_range=0.02, + layer_norm_eps=1e-12, + image_size=30, + patch_size=2, + num_channels=3, + qkv_bias=True, + mask_token=True, + vocab_size=8192, + ): + self.parent = parent + self.batch_size = batch_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.qkv_bias = qkv_bias + self.mask_token = mask_token + self.vocab_size = vocab_size + + def prepare_config_and_inputs(self): + pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size]) + num_patches = self.image_size // self.patch_size + bool_masked_pos = (torch.rand((self.batch_size, num_patches, num_patches)) < 0.9).long() + config = self.get_config() + return config, pixel_values, bool_masked_pos + + def get_config(self): + return FLAVAImageConfig( + hidden_size=self.hidden_size, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads, + intermediate_size=self.intermediate_size, + hidden_act=self.hidden_act, + hidden_dropout_prob=self.hidden_dropout_prob, + attention_probs_dropout_prob=self.attention_probs_dropout_prob, + 
initializer_range=self.initializer_range, + layer_norm_eps=self.layer_norm_eps, + image_size=self.image_size, + patch_size=self.patch_size, + num_channels=self.num_channels, + qkv_bias=self.qkv_bias, + mask_token=self.mask_token, + vocab_size=self.vocab_size, + ) + + def create_and_check_model(self, config, pixel_values, bool_masked_pos): + model = FLAVAImageModel(config=config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + result = model(pixel_values, bool_masked_pos) + # expected sequence length = num_patches + 1 (we add 1 for the [CLS] token) + image_size = (self.image_size, self.image_size) + patch_size = (self.patch_size, self.patch_size) + num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) + self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, num_patches + 1, self.hidden_size)) + self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size)) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + config, pixel_values, bool_masked_pos = config_and_inputs + inputs_dict = {"pixel_values": pixel_values, "bool_masked_pos": bool_masked_pos} + return config, inputs_dict + + +@require_torch +class FLAVAImageModelTest(ModelTesterMixin, unittest.TestCase): + """ + Here we also overwrite some of the tests of test_modeling_common.py, as FLAVA does not use input_ids, inputs_embeds, + attention_mask and seq_length. + """ + + all_model_classes = (FLAVAImageModel,) if is_torch_available() else () + + test_pruning = False + test_torchscript = False + test_resize_embeddings = False + test_head_masking = False + + def setUp(self): + self.model_tester = FLAVAImageModelTester(self) + self.config_tester = ConfigTester(self, config_class=FLAVAImageConfig, has_text_modality=False, hidden_size=37) + + def test_config(self): + self.config_tester.run_common_tests() + + def test_inputs_embeds(self): + # FLAVA does not use inputs_embeds + pass + + def test_model_common_attributes(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + self.assertIsInstance(model.get_input_embeddings(), (nn.Module)) + x = model.get_output_embeddings() + self.assertTrue(x is None or isinstance(x, nn.Linear)) + + def test_forward_signature(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + signature = inspect.signature(model.forward) + # signature.parameters is an OrderedDict => so arg_names order is deterministic + arg_names = [*signature.parameters.keys()] + + expected_arg_names = ["pixel_values"] + self.assertListEqual(arg_names[:1], expected_arg_names) + + def test_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model(*config_and_inputs) + + def test_attention_outputs(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.return_dict = True + + # in FLAVA, the seq_len equals the number of patches + 1 (we add 1 for the [CLS] token) + image_size = (self.model_tester.image_size, self.model_tester.image_size) + patch_size = (self.model_tester.patch_size, self.model_tester.patch_size) + num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) + seq_len = num_patches + 1 + + for model_class in self.all_model_classes: + 
inputs_dict["output_attentions"] = True + inputs_dict["output_hidden_states"] = False + config.return_dict = True + model = model_class(config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + attentions = outputs.attentions + self.assertEqual(len(attentions), self.model_tester.num_hidden_layers) + + # check that output_attentions also work using config + del inputs_dict["output_attentions"] + config.output_attentions = True + model = model_class(config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + attentions = outputs.attentions + self.assertEqual(len(attentions), self.model_tester.num_hidden_layers) + + out_len = len(outputs) + + # Check attention is always last and order is fine + inputs_dict["output_attentions"] = True + inputs_dict["output_hidden_states"] = True + model = model_class(config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + + added_hidden_states = 1 + self.assertEqual(out_len + added_hidden_states, len(outputs)) + + self_attentions = outputs.attentions + + self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers) + + self.assertListEqual( + list(self_attentions[0].shape[-3:]), + [self.model_tester.num_attention_heads, seq_len, seq_len], + ) + + def test_hidden_states_output(self): + def check_hidden_states_output(inputs_dict, config, model_class): + model = model_class(config) + model.to(torch_device) + model.eval() + + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + + hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states + + expected_num_layers = getattr( + self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1 + ) + self.assertEqual(len(hidden_states), expected_num_layers) + + # FLAVA has a different seq_length + image_size = (self.model_tester.image_size, self.model_tester.image_size) + patch_size = (self.model_tester.patch_size, self.model_tester.patch_size) + num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) + seq_length = num_patches + 1 + + self.assertListEqual( + list(hidden_states[0].shape[-2:]), + [seq_length, self.model_tester.hidden_size], + ) + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + inputs_dict["output_hidden_states"] = True + check_hidden_states_output(inputs_dict, config, model_class) + + # check that output_hidden_states also work using config + del inputs_dict["output_hidden_states"] + config.output_hidden_states = True + + check_hidden_states_output(inputs_dict, config, model_class) + + def test_training(self): + pass + + def test_training_gradient_checkpointing(self): + pass + + # skip this test as FLAVAImageModel has no base class and is + # not available in MODEL_MAPPING + def test_save_load_fast_init_from_base(self): + pass + + # skip this test as FLAVAImageModel has no base class and is + # not available in MODEL_MAPPING + def test_save_load_fast_init_to_base(self): + pass + + @slow + def test_model_from_pretrained(self): + for model_name in FLAVA_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: + model = FLAVAImageModel.from_pretrained(model_name) + self.assertIsNotNone(model) + + +class FLAVATextModelTester: + def __init__( + 
self, + parent, + batch_size=12, + seq_length=7, + is_training=True, + use_input_mask=True, + use_token_type_ids=True, + vocab_size=30522, + type_vocab_size=2, + max_position_embeddings=512, + position_embedding_type="absolute", + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + hidden_act="gelu", + hidden_dropout_prob=0.0, + attention_probs_dropout_prob=0.0, + initializer_range=0.02, + layer_norm_eps=1e-12, + pad_token_id=0, + qkv_bias=True, + ): + self.parent = parent + self.batch_size = batch_size + self.is_training = is_training + self.use_input_mask = use_input_mask + self.use_token_type_ids = use_token_type_ids + self.seq_length = seq_length + self.vocab_size = vocab_size + self.type_vocab_size = type_vocab_size + self.max_position_embeddings = max_position_embeddings + self.position_embedding_type = position_embedding_type + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + self.qkv_bias = qkv_bias + self.pad_token_id = pad_token_id + + def prepare_config_and_inputs(self): + input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) + + input_mask = None + if self.use_input_mask: + input_mask = random_attention_mask([self.batch_size, self.seq_length]) + + if input_mask is not None: + batch_size, seq_length = input_mask.shape + rnd_start_indices = np.random.randint(1, seq_length - 1, size=(batch_size,)) + for batch_idx, start_index in enumerate(rnd_start_indices): + input_mask[batch_idx, :start_index] = 1 + input_mask[batch_idx, start_index:] = 0 + + token_type_ids = None + + if self.use_token_type_ids: + token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size) + + config = self.get_config() + + return config, input_ids, token_type_ids, input_mask + + def get_config(self): + return FLAVATextConfig( + vocab_size=self.vocab_size, + type_vocab_size=self.type_vocab_size, + max_position_embeddings=self.max_position_embeddings, + position_embedding_type=self.position_embedding_type, + hidden_size=self.hidden_size, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads, + intermediate_size=self.intermediate_size, + hidden_act=self.hidden_act, + hidden_dropout_prob=self.hidden_dropout_prob, + attention_probs_dropout_prob=self.attention_probs_dropout_prob, + initializer_range=self.initializer_range, + layer_norm_eps=self.layer_norm_eps, + pad_token_id=self.pad_token_id, + qkv_bias=self.qkv_bias, + ) + + def create_and_check_model(self, config, input_ids, token_type_ids, input_mask): + model = FLAVATextModel(config=config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + result = model(input_ids, token_type_ids=token_type_ids, attention_mask=input_mask) + result = model(input_ids) + self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)) + self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size)) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + config, input_ids, token_type_ids, input_mask = config_and_inputs + inputs_dict = {"input_ids": input_ids, 
"token_type_ids": token_type_ids, "attention_mask": input_mask} + return config, inputs_dict + + +@require_torch +class FLAVATextModelTest(ModelTesterMixin, unittest.TestCase): + + all_model_classes = (FLAVATextModel,) if is_torch_available() else () + test_pruning = False + test_head_masking = False + test_torchscript = False + + def setUp(self): + self.model_tester = FLAVATextModelTester(self) + self.config_tester = ConfigTester(self, config_class=FLAVATextConfig, hidden_size=37) + + def test_config(self): + self.config_tester.run_common_tests() + + def test_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model(*config_and_inputs) + + def test_training(self): + pass + + def test_training_gradient_checkpointing(self): + pass + + def test_inputs_embeds(self): + # FLAVA does not use inputs_embeds + pass + + # skip this test as FLAVATextModel has no base class and is + # not available in MODEL_MAPPING + def test_save_load_fast_init_from_base(self): + pass + + # skip this test as FLAVATextModel has no base class and is + # not available in MODEL_MAPPING + def test_save_load_fast_init_to_base(self): + pass + + @slow + def test_model_from_pretrained(self): + for model_name in FLAVA_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: + model = FLAVATextModel.from_pretrained(model_name) + self.assertIsNotNone(model) + + +class FLAVAMultimodalModelTester: + def __init__( + self, + parent, + batch_size=12, + seq_length=44, + use_input_mask=True, + hidden_size=768, + num_hidden_layers=6, + num_attention_heads=12, + intermediate_size=3072, + hidden_act="gelu", + hidden_dropout_prob=0.0, + attention_probs_dropout_prob=0.0, + initializer_range=0.02, + layer_norm_eps=1e-12, + qkv_bias=True, + use_cls_token=True, + ): + self.parent = parent + self.batch_size = batch_size + self.seq_length = seq_length + self.use_input_mask = use_input_mask + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + self.qkv_bias = qkv_bias + self.use_cls_token = use_cls_token + + def prepare_config_and_inputs(self): + hidden_states = floats_tensor([self.batch_size, self.seq_length - 1, self.hidden_size]) + + input_mask = None + if self.use_input_mask: + input_mask = random_attention_mask([self.batch_size, self.seq_length]) + + if input_mask is not None: + batch_size, seq_length = input_mask.shape + rnd_start_indices = np.random.randint(1, seq_length - 1, size=(batch_size,)) + for batch_idx, start_index in enumerate(rnd_start_indices): + input_mask[batch_idx, :start_index] = 1 + input_mask[batch_idx, start_index:] = 0 + + config = self.get_config() + + return config, hidden_states, input_mask + + def get_config(self): + return FLAVAMultimodalConfig( + hidden_size=self.hidden_size, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads, + intermediate_size=self.intermediate_size, + hidden_act=self.hidden_act, + hidden_dropout_prob=self.hidden_dropout_prob, + attention_probs_dropout_prob=self.attention_probs_dropout_prob, + initializer_range=self.initializer_range, + layer_norm_eps=self.layer_norm_eps, + qkv_bias=self.qkv_bias, + use_cls_token=self.use_cls_token, + ) + + def 
create_and_check_model(self, config, hidden_states, input_mask):
+ model = FLAVAMultimodalModel(config=config)
+ model.to(torch_device)
+ model.eval()
+ with torch.no_grad():
+ result = model(hidden_states, attention_mask=input_mask)
+ result = model(hidden_states)
+ self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
+ self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size))
+
+ def prepare_config_and_inputs_for_common(self):
+ config_and_inputs = self.prepare_config_and_inputs()
+ config, hidden_states, input_mask = config_and_inputs
+ inputs_dict = {"hidden_states": hidden_states, "attention_mask": input_mask}
+ return config, inputs_dict
+
+
+@require_torch
+class FLAVAMultimodalModelTest(ModelTesterMixin, unittest.TestCase):
+
+ all_model_classes = (FLAVAMultimodalModel,) if is_torch_available() else ()
+ test_pruning = False
+ test_head_masking = False
+ test_resize_embeddings = False
+ test_torchscript = False
+
+ def setUp(self):
+ self.model_tester = FLAVAMultimodalModelTester(self)
+ self.config_tester = ConfigTester(
+ self, config_class=FLAVAMultimodalConfig, has_text_modality=False, hidden_size=37
+ )
+
+ def test_config(self):
+ self.config_tester.run_common_tests()
+
+ def test_model(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_model(*config_and_inputs)
+
+ def test_forward_signature(self):
+ config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+
+ for model_class in self.all_model_classes:
+ model = model_class(config)
+ signature = inspect.signature(model.forward)
+ # signature.parameters is an OrderedDict => so arg_names order is deterministic
+ arg_names = [*signature.parameters.keys()]
+
+ expected_arg_names = ["hidden_states"]
+ self.assertListEqual(arg_names[:1], expected_arg_names)
+
+ def test_model_common_attributes(self):
+ # No embedding in multimodal model
+ pass
+
+ def test_training(self):
+ pass
+
+ def test_training_gradient_checkpointing(self):
+ pass
+
+ def test_inputs_embeds(self):
+ # FLAVA does not use inputs_embeds
+ pass
+
+ # skip this test as FLAVAMultimodalModel has no base class and is
+ # not available in MODEL_MAPPING
+ def test_save_load_fast_init_from_base(self):
+ pass
+
+ # skip this test as FLAVAMultimodalModel has no base class and is
+ # not available in MODEL_MAPPING
+ def test_save_load_fast_init_to_base(self):
+ pass
+
+ @slow
+ def test_model_from_pretrained(self):
+ for model_name in FLAVA_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
+ model = FLAVAMultimodalModel.from_pretrained(model_name)
+ self.assertIsNotNone(model)
+
+
+class FLAVAModelTester:
+ def __init__(self, parent, is_training=True):
+ self.parent = parent
+ self.text_model_tester = FLAVATextModelTester(parent)
+ self.vision_model_tester = FLAVAImageModelTester(parent)
+ self.is_training = is_training
+
+ def prepare_config_and_inputs(self):
+ text_config, input_ids, token_type_ids, attention_mask = self.text_model_tester.prepare_config_and_inputs()
+ vision_config, pixel_values = self.vision_model_tester.prepare_config_and_inputs()
+
+ config = self.get_config()
+
+ return config, input_ids, token_type_ids, attention_mask, pixel_values
+
+ def get_config(self):
+ return FLAVAConfig.from_text_vision_configs(
+ self.text_model_tester.get_config(), self.vision_model_tester.get_config(), projection_dim=64
+ )
+
+ def create_and_check_model(self, config, input_ids, token_type_ids, attention_mask, pixel_values):
+ model = FLAVAModel(config).to(torch_device).eval()
+ with torch.no_grad():
+ result = model(input_ids, pixel_values, attention_mask)
+ self.parent.assertEqual(
+ result.logits_per_image.shape, (self.vision_model_tester.batch_size, self.text_model_tester.batch_size)
+ )
+ self.parent.assertEqual(
+ result.logits_per_text.shape, (self.text_model_tester.batch_size, self.vision_model_tester.batch_size)
+ )
+
+ def prepare_config_and_inputs_for_common(self):
+ config_and_inputs = self.prepare_config_and_inputs()
+ config, input_ids, token_type_ids, attention_mask, pixel_values = config_and_inputs
+ inputs_dict = {
+ "input_ids": input_ids,
+ "attention_mask": attention_mask,
+ "token_type_ids": token_type_ids,
+ "pixel_values": pixel_values,
+ "return_loss": True,
+ }
+ return config, inputs_dict
+
+
+@require_torch
+class FLAVAModelTest(ModelTesterMixin, unittest.TestCase):
+ all_model_classes = (FLAVAModel,) if is_torch_available() else ()
+ test_head_masking = False
+ test_pruning = False
+ test_resize_embeddings = False
+ test_attention_outputs = False
+
+ def setUp(self):
+ self.model_tester = FLAVAModelTester(self)
+
+ def test_model(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_model(*config_and_inputs)
+
+ # hidden_states are tested in individual model tests
+ def test_hidden_states_output(self):
+ pass
+
+ # input_embeds are tested in individual model tests
+ def test_inputs_embeds(self):
+ pass
+
+ # tested in individual model tests
+ def test_retain_grad_hidden_states_attentions(self):
+ pass
+
+ # FLAVAModel does not have input/output embeddings
+ def test_model_common_attributes(self):
+ pass
+
+ # override as the `logit_scale` parameter initialization is different for FLAVA
+ def test_initialization(self):
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ configs_no_init = _config_zero_init(config)
+ for model_class in self.all_model_classes:
+ model = model_class(config=configs_no_init)
+ for name, param in model.named_parameters():
+ if param.requires_grad:
+ # check if `logit_scale` is initialized as per the original implementation
+ if name == "logit_scale":
+ self.assertAlmostEqual(
+ param.data.item(),
+ np.log(1 / 0.07),
+ delta=1e-3,
+ msg=f"Parameter {name} of model {model_class} seems not properly initialized",
+ )
+ else:
+ self.assertIn(
+ ((param.data.mean() * 1e9).round() / 1e9).item(),
+ [0.0, 1.0],
+ msg=f"Parameter {name} of model {model_class} seems not properly initialized",
+ )
+
+ def _create_and_check_torchscript(self, config, inputs_dict):
+ if not self.test_torchscript:
+ return
+
+ configs_no_init = _config_zero_init(config) # To be sure we have no Nan
+ configs_no_init.torchscript = True
+ configs_no_init.return_dict = False
+ for model_class in self.all_model_classes:
+ model = model_class(config=configs_no_init)
+ model.to(torch_device)
+ model.eval()
+
+ try:
+ input_ids = inputs_dict["input_ids"]
+ pixel_values = inputs_dict["pixel_values"] # FLAVA needs pixel_values
+ traced_model = torch.jit.trace(model, (input_ids, pixel_values))
+ except RuntimeError:
+ self.fail("Couldn't trace module.")
+
+ with tempfile.TemporaryDirectory() as tmp_dir_name:
+ pt_file_name = os.path.join(tmp_dir_name, "traced_model.pt")
+
+ try:
+ torch.jit.save(traced_model, pt_file_name)
+ except Exception:
+ self.fail("Couldn't save module.")
+
+ try:
+ loaded_model = torch.jit.load(pt_file_name)
+ except Exception:
+ self.fail("Couldn't load module.")
+
+ model.to(torch_device)
+ model.eval()
+
+ loaded_model.to(torch_device)
+ loaded_model.eval()
+
+ model_state_dict = model.state_dict()
+ loaded_model_state_dict = loaded_model.state_dict()
+
+ self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys()))
+
+ models_equal = True
+ for layer_name, p1 in model_state_dict.items():
+ p2 = loaded_model_state_dict[layer_name]
+ if p1.data.ne(p2.data).sum() > 0:
+ models_equal = False
+
+ self.assertTrue(models_equal)
+
+ def test_load_vision_text_config(self):
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ # Save FLAVAConfig and check if we can load FLAVAImageConfig from it
+ with tempfile.TemporaryDirectory() as tmp_dir_name:
+ config.save_pretrained(tmp_dir_name)
+ vision_config = FLAVAImageConfig.from_pretrained(tmp_dir_name)
+ self.assertDictEqual(config.vision_config.to_dict(), vision_config.to_dict())
+
+ # Save FLAVAConfig and check if we can load FLAVATextConfig from it
+ with tempfile.TemporaryDirectory() as tmp_dir_name:
+ config.save_pretrained(tmp_dir_name)
+ text_config = FLAVATextConfig.from_pretrained(tmp_dir_name)
+ self.assertDictEqual(config.text_config.to_dict(), text_config.to_dict())
+
+ # overwrite from common since FLAVAModel returns FLAVAOutput
+ @slow
+ def test_model_from_pretrained(self):
+ for model_name in FLAVA_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
+ model = FLAVAModel.from_pretrained(model_name)
+ self.assertIsNotNone(model)
+
+
+# We will verify our results on an image of cute cats
+def prepare_img():
+ url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ im = Image.open(requests.get(url, stream=True).raw)
+ return im
+
+
+@require_vision
+@require_torch
+class FLAVAModelIntegrationTest(unittest.TestCase):
+ @slow
+ def test_inference(self):
+ model_name = ""
+ model = FLAVAForPretraining.from_pretrained(model_name).to(torch_device)
+ processor = FLAVAProcessor.from_pretrained(model_name)
+
+ image = prepare_img()
+ inputs = processor(
+ text=["a photo of a cat", "a photo of a dog"],
+ images=[image, image],
+ padding="max_length",
+ max_length=77,
+ return_tensors="pt",
+ return_masks=True,
+ ).to(torch_device)
+
+ # forward pass
+ with torch.no_grad():
+ outputs = model(**inputs)
+
+ # verify the logits
+ self.assertEqual(
+ outputs.contrastive_logits_per_image.shape,
+ torch.Size((inputs.pixel_values.shape[0], inputs.input_ids.shape[0])),
+ )
+ self.assertEqual(
+ outputs.contrastive_logits_per_text.shape,
+ torch.Size((inputs.input_ids.shape[0], inputs.pixel_values.shape[0])),
+ )
+
+ expected_logits = torch.tensor([[24.5701, 19.3049]], device=torch_device)
+
+ self.assertTrue(torch.allclose(outputs.contrastive_logits_per_image, expected_logits, atol=1e-3))
diff --git a/tests/flava/test_processor_flava.py b/tests/flava/test_processor_flava.py
new file mode 100644
index 00000000000000..b33d42ae5fa520
--- /dev/null
+++ b/tests/flava/test_processor_flava.py
@@ -0,0 +1,204 @@
+# Copyright 2021 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os +import shutil +import tempfile +import unittest + +import numpy as np +import pytest + +from transformers import BertTokenizer, BertTokenizerFast +from transformers.models.bert.tokenization_bert import VOCAB_FILES_NAMES +from transformers.testing_utils import require_vision +from transformers.utils import FEATURE_EXTRACTOR_NAME, is_vision_available + + +if is_vision_available(): + from PIL import Image + + from transformers import FLAVAFeatureExtractor, FLAVAProcessor + from transformers.models.flava.feature_extraction_flava import FLAVA_IMAGE_MEAN, FLAVA_IMAGE_STD + + +@require_vision +class FLAVAProcessorTest(unittest.TestCase): + def setUp(self): + self.tmpdirname = tempfile.mkdtemp() + vocab_tokens = [ + "[UNK]", + "[CLS]", + "[SEP]", + "[PAD]", + "[MASK]", + "want", + "##want", + "##ed", + "wa", + "un", + "runn", + "##ing", + ",", + "low", + "lowest", + ] + self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["vocab_file"]) + + with open(self.vocab_file, "w", encoding="utf-8") as fp: + fp.write("".join([x + "\n" for x in vocab_tokens])) + + feature_extractor_map = { + "image_mean": FLAVA_IMAGE_MEAN, + "image_std": FLAVA_IMAGE_STD, + "do_normalize": True, + "do_resize": True, + "size": 224, + "do_center_crop": True, + "crop_size": 224, + "input_size_patches": 14, + "total_mask_patches": 75, + "mask_group_max_patches": None, + "mask_group_min_patches": 16, + "mask_group_min_aspect_ratio": 0.3, + "mask_group_max_aspect_ratio": None, + } + + self.feature_extractor_file = os.path.join(self.tmpdirname, FEATURE_EXTRACTOR_NAME) + with open(self.feature_extractor_file, "w", encoding="utf-8") as fp: + json.dump(feature_extractor_map, fp) + + def get_tokenizer(self, **kwargs): + return BertTokenizer.from_pretrained(self.tmpdirname, **kwargs) + + def get_rust_tokenizer(self, **kwargs): + return BertTokenizerFast.from_pretrained(self.tmpdirname, **kwargs) + + def get_feature_extractor(self, **kwargs): + return FLAVAFeatureExtractor.from_pretrained(self.tmpdirname, **kwargs) + + def tearDown(self): + shutil.rmtree(self.tmpdirname) + + def prepare_image_inputs(self): + """This function prepares a list of PIL images, or a list of numpy arrays if one specifies numpify=True, + or a list of PyTorch tensors if one specifies torchify=True. 
+ """ + + image_inputs = [np.random.randint(255, size=(3, 30, 400), dtype=np.uint8)] + + image_inputs = [Image.fromarray(np.moveaxis(x, 0, -1)) for x in image_inputs] + + return image_inputs + + def test_save_load_pretrained_default(self): + tokenizer_slow = self.get_tokenizer() + tokenizer_fast = self.get_rust_tokenizer() + feature_extractor = self.get_feature_extractor() + + processor_slow = FLAVAProcessor(tokenizer=tokenizer_slow, feature_extractor=feature_extractor) + processor_slow.save_pretrained(self.tmpdirname) + processor_slow = FLAVAProcessor.from_pretrained(self.tmpdirname, use_fast=False) + + processor_fast = FLAVAProcessor(tokenizer=tokenizer_fast, feature_extractor=feature_extractor) + processor_fast.save_pretrained(self.tmpdirname) + processor_fast = FLAVAProcessor.from_pretrained(self.tmpdirname) + + self.assertEqual(processor_slow.tokenizer.get_vocab(), tokenizer_slow.get_vocab()) + self.assertEqual(processor_fast.tokenizer.get_vocab(), tokenizer_fast.get_vocab()) + self.assertEqual(tokenizer_slow.get_vocab(), tokenizer_fast.get_vocab()) + self.assertIsInstance(processor_slow.tokenizer, BertTokenizer) + self.assertIsInstance(processor_fast.tokenizer, BertTokenizerFast) + + self.assertEqual(processor_slow.feature_extractor.to_json_string(), feature_extractor.to_json_string()) + self.assertEqual(processor_fast.feature_extractor.to_json_string(), feature_extractor.to_json_string()) + self.assertIsInstance(processor_slow.feature_extractor, FLAVAFeatureExtractor) + self.assertIsInstance(processor_fast.feature_extractor, FLAVAFeatureExtractor) + + def test_save_load_pretrained_additional_features(self): + processor = FLAVAProcessor(tokenizer=self.get_tokenizer(), feature_extractor=self.get_feature_extractor()) + processor.save_pretrained(self.tmpdirname) + + tokenizer_add_kwargs = self.get_tokenizer(bos_token="(BOS)", eos_token="(EOS)") + feature_extractor_add_kwargs = self.get_feature_extractor(do_normalize=False, padding_value=1.0) + + processor = FLAVAProcessor.from_pretrained( + self.tmpdirname, bos_token="(BOS)", eos_token="(EOS)", do_normalize=False, padding_value=1.0 + ) + + self.assertEqual(processor.tokenizer.get_vocab(), tokenizer_add_kwargs.get_vocab()) + self.assertIsInstance(processor.tokenizer, BertTokenizerFast) + + self.assertEqual(processor.feature_extractor.to_json_string(), feature_extractor_add_kwargs.to_json_string()) + self.assertIsInstance(processor.feature_extractor, FLAVAFeatureExtractor) + + def test_feature_extractor(self): + feature_extractor = self.get_feature_extractor() + tokenizer = self.get_tokenizer() + + processor = FLAVAProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor) + + image_input = self.prepare_image_inputs() + + input_feat_extract = feature_extractor(image_input, return_tensors="np") + input_processor = processor(images=image_input, return_tensors="np") + + for key in input_feat_extract.keys(): + self.assertAlmostEqual(input_feat_extract[key].sum(), input_processor[key].sum(), delta=1e-2) + + def test_tokenizer(self): + feature_extractor = self.get_feature_extractor() + tokenizer = self.get_tokenizer() + + processor = FLAVAProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor) + + input_str = "lower newer" + + encoded_processor = processor(text=input_str) + + encoded_tok = tokenizer(input_str) + + for key in encoded_tok.keys(): + self.assertListEqual(encoded_tok[key], encoded_processor[key]) + + def test_processor(self): + feature_extractor = self.get_feature_extractor() + tokenizer = self.get_tokenizer() + 
+ processor = FLAVAProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor) + + input_str = "lower newer" + image_input = self.prepare_image_inputs() + + inputs = processor(text=input_str, images=image_input) + + self.assertListEqual(list(inputs.keys()), ["input_ids", "token_type_ids", "attention_mask", "pixel_values"]) + + # test if it raises when no input is passed + with pytest.raises(ValueError): + processor() + + def test_tokenizer_decode(self): + feature_extractor = self.get_feature_extractor() + tokenizer = self.get_tokenizer() + + processor = FLAVAProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor) + + predicted_ids = [[1, 4, 5, 8, 1, 0, 8], [3, 4, 3, 1, 1, 8, 9]] + + decoded_processor = processor.batch_decode(predicted_ids) + decoded_tok = tokenizer.batch_decode(predicted_ids) + + self.assertListEqual(decoded_tok, decoded_processor)
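
Example usage for reviewers: the snippet below is a minimal sketch of how the pieces exercised by `FLAVAModelIntegrationTest` are intended to fit together once checkpoints are published. It mirrors the integration test above; the checkpoint id is a placeholder (the test still sets `model_name = ""`), and the `padding`/`max_length`/`return_masks` arguments are copied from the test rather than a finalized recommended API.

```python
import requests
import torch
from PIL import Image

from transformers import FLAVAForPretraining, FLAVAProcessor

# Placeholder checkpoint id -- the pretrained hub name is not finalized in this PR.
model_name = "flava-full"

processor = FLAVAProcessor.from_pretrained(model_name)
model = FLAVAForPretraining.from_pretrained(model_name)

# Same COCO image used by the integration test.
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

inputs = processor(
    text=["a photo of a cat", "a photo of a dog"],
    images=[image, image],
    padding="max_length",
    max_length=77,
    return_tensors="pt",
    return_masks=True,
)

with torch.no_grad():
    outputs = model(**inputs)

# Image-to-text contrastive similarity scores, shape (num_images, num_texts).
print(outputs.contrastive_logits_per_image)
```

As in the integration test, `contrastive_logits_per_image` has shape `(num_images, num_texts)`, so each row scores one image against every text prompt.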