diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml
index 69717477e1f230..90ce8c8d4f861c 100644
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -214,6 +214,8 @@
title: Encoder Decoder Models
- local: model_doc/flaubert
title: FlauBERT
+ - local: model_doc/flava
+ title: FLAVA
- local: model_doc/fnet
title: FNet
- local: model_doc/fsmt
diff --git a/docs/source/en/model_doc/flava.mdx b/docs/source/en/model_doc/flava.mdx
new file mode 100644
index 00000000000000..2c9478a77dc1e7
--- /dev/null
+++ b/docs/source/en/model_doc/flava.mdx
@@ -0,0 +1,93 @@
+
+
+# FLAVA
+
+## Overview
+
+The FLAVA model was proposed in [FLAVA: A Foundational Language And Vision Alignment Model
+](https://arxiv.org/abs/2112.04482) by Amanpreet Singh, Ronghang Hu, Vedanuj Goswami, Guillaume Couairon, Wojciech Galuba, Marcus Rohrbach, and Douwe Kiela, and was accepted at CVPR 2022.
+
+The paper aims at creating a single unified foundation model that can work across vision, language,
+as well as vision-and-language multimodal tasks.
+
+The abstract from the paper is the following:
+
+State-of-the-art vision and vision-and-language models rely on large-scale visio-linguistic pretraining for obtaining good performance on a variety
+of downstream tasks. Generally, such models are often either cross-modal (contrastive) or multi-modal
+(with earlier fusion) but not both; and they often only target specific modalities or tasks. A promising
+direction would be to use a single holistic universal model, as a "foundation", that targets all modalities
+at once -- a true vision and language foundation model should be good at vision tasks, language tasks, and
+cross- and multi-modal vision and language tasks. We introduce FLAVA as such a model and demonstrate
+impressive performance on a wide range of 35 tasks spanning these target modalities.
+
+
+
+
+This model was contributed by [aps](https://huggingface.co/aps).
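+
+The snippet below is a minimal sketch of how the pretrained checkpoint could be used to embed an image-text pair. It
+is illustrative only: it assumes the `aps/flava-full` checkpoint ships a `FLAVAProcessor`, that the processor accepts
+`text=`/`images=` keyword arguments, and that the model returns the `image_embeddings`/`text_embeddings` fields
+documented in `FLAVAModelOutput`.
+
+```python
+import requests
+from PIL import Image
+
+from transformers import FLAVAModel, FLAVAProcessor
+
+model = FLAVAModel.from_pretrained("aps/flava-full")
+processor = FLAVAProcessor.from_pretrained("aps/flava-full")
+
+url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+image = Image.open(requests.get(url, stream=True).raw)
+
+inputs = processor(text=["a photo of two cats"], images=[image], return_tensors="pt", padding=True)
+outputs = model(**inputs)
+
+# Pooled unimodal states; project them with `image_projection`/`text_projection` for contrastive retrieval
+image_embeddings = outputs.image_embeddings
+text_embeddings = outputs.text_embeddings
+```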
+
+
+
+## FLAVAConfig
+
+[[autodoc]] FLAVAConfig
+ - from_configs
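+
+As a quick illustration of `from_configs` (a sketch based only on the signatures introduced in this PR, not an
+official example), a composite configuration can be built from the three sub-configurations:
+
+```python
+from transformers import FLAVAConfig, FLAVAImageConfig, FLAVAMultimodalConfig, FLAVATextConfig
+
+# Build a FLAVAConfig from explicit sub-configs; the defaults match the aps/flava-full architecture
+text_config = FLAVATextConfig()
+image_config = FLAVAImageConfig(image_size=224, patch_size=16)
+multimodal_config = FLAVAMultimodalConfig(num_hidden_layers=6)
+
+config = FLAVAConfig.from_configs(text_config, image_config, multimodal_config)
+```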
+
+## FLAVATextConfig
+
+[[autodoc]] FLAVATextConfig
+
+## FLAVAImageConfig
+
+[[autodoc]] FLAVAImageConfig
+
+## FLAVAMultimodalConfig
+
+[[autodoc]] FLAVAMultimodalConfig
+
+## FLAVACodebookConfig
+
+[[autodoc]] FLAVACodebookConfig
+
+## FLAVAForPretraining
+
+[[autodoc]] FLAVAForPretraining
+ - forward
+
+## FLAVAModel
+
+[[autodoc]] FLAVAModel
+ - forward
+ - get_text_features
+ - get_image_features
+
+## FLAVACodebook
+
+[[autodoc]] FLAVACodebook
+ - forward
+ - get_codebook_indices
+ - get_codebook_probs
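+
+A rough sketch of how the codebook could be used to produce the discrete image token ids that serve as MIM targets
+(treat the exact `get_codebook_indices` call signature and the availability of a preprocessor config on the
+`aps/flava-codebook` checkpoint as assumptions):
+
+```python
+import requests
+from PIL import Image
+
+from transformers import FLAVACodebook, FLAVACodebookFeatureExtractor
+
+model = FLAVACodebook.from_pretrained("aps/flava-codebook")
+feature_extractor = FLAVACodebookFeatureExtractor.from_pretrained("aps/flava-codebook")
+
+url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+image = Image.open(requests.get(url, stream=True).raw)
+
+inputs = feature_extractor(images=[image], return_tensors="pt")
+indices = model.get_codebook_indices(**inputs)  # discrete visual tokens used as masked-image-modeling targets
+```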
+
+## FLAVATextModel
+
+[[autodoc]] FLAVATextModel
+ - forward
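+
+An illustrative sketch of the standalone text encoder (the tokenizer class follows `_TOKENIZER_FOR_DOC` in
+`modeling_flava.py`; whether the `aps/flava-full` repo hosts tokenizer files is an assumption):
+
+```python
+from transformers import BertTokenizer, FLAVATextModel
+
+model = FLAVATextModel.from_pretrained("aps/flava-full")
+tokenizer = BertTokenizer.from_pretrained("aps/flava-full")
+
+inputs = tokenizer(["a photo of two cats"], padding=True, return_tensors="pt")
+outputs = model(**inputs)
+last_hidden_state = outputs.last_hidden_state  # per-token hidden states from the text encoder
+```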
+
+## FLAVAImageModel
+
+[[autodoc]] FLAVAImageModel
+ - forward
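+
+Similarly, for the standalone image encoder. The expected `[1, 197, 768]` output shape comes from
+`_EXPECTED_IMAGE_OUTPUT_SHAPE` in `modeling_flava.py` (a 224x224 image with patch size 16 gives 196 patches plus the
+CLS token); the forward signature is assumed to follow the usual ViT-style `pixel_values` input:
+
+```python
+import requests
+from PIL import Image
+
+from transformers import FLAVAFeatureExtractor, FLAVAImageModel
+
+model = FLAVAImageModel.from_pretrained("aps/flava-full")
+feature_extractor = FLAVAFeatureExtractor.from_pretrained("aps/flava-full")
+
+url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+image = Image.open(requests.get(url, stream=True).raw)
+
+inputs = feature_extractor(images=[image], return_tensors="pt")
+outputs = model(**inputs)
+print(outputs.last_hidden_state.shape)  # expected: torch.Size([1, 197, 768])
+```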
+
+## FLAVAMultimodalModel
+
+[[autodoc]] FLAVAMultimodalModel
+ - forward
diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py
index 841b830c3b1db7..4b03299b61f7e1 100755
--- a/src/transformers/__init__.py
+++ b/src/transformers/__init__.py
@@ -191,6 +191,17 @@
"models.electra": ["ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP", "ElectraConfig", "ElectraTokenizer"],
"models.encoder_decoder": ["EncoderDecoderConfig"],
"models.flaubert": ["FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "FlaubertConfig", "FlaubertTokenizer"],
+ "models.flava": [
+ "FLAVA_PRETRAINED_CONFIG_ARCHIVE_MAP",
+ "FLAVACodebookConfig",
+ "FLAVACodebookFeatureExtractor",
+ "FLAVAConfig",
+ "FLAVAFeatureExtractor",
+ "FLAVAImageConfig",
+ "FLAVAMultimodalConfig",
+ "FLAVAProcessor",
+ "FLAVATextConfig",
+ ],
"models.fnet": ["FNET_PRETRAINED_CONFIG_ARCHIVE_MAP", "FNetConfig", "FNetTokenizer"],
"models.fsmt": ["FSMT_PRETRAINED_CONFIG_ARCHIVE_MAP", "FSMTConfig", "FSMTTokenizer"],
"models.funnel": ["FUNNEL_PRETRAINED_CONFIG_ARCHIVE_MAP", "FunnelConfig", "FunnelTokenizer"],
@@ -986,6 +997,19 @@
"FlaubertWithLMHeadModel",
]
)
+ _import_structure["models.flava"].extend(
+ [
+ "FLAVA_PRETRAINED_MODEL_ARCHIVE_LIST",
+ "FLAVACodebook",
+ "FLAVAForPretraining",
+ "FLAVAImageModel",
+ "FLAVALayer",
+ "FLAVAModel",
+ "FLAVAMultimodalModel",
+ "FLAVAPreTrainedModel",
+ "FLAVATextModel",
+ ]
+ )
_import_structure["models.fnet"].extend(
[
"FNET_PRETRAINED_MODEL_ARCHIVE_LIST",
@@ -2565,6 +2589,17 @@
from .models.electra import ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP, ElectraConfig, ElectraTokenizer
from .models.encoder_decoder import EncoderDecoderConfig
from .models.flaubert import FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, FlaubertConfig, FlaubertTokenizer
+ from .models.flava import (
+ FLAVA_PRETRAINED_CONFIG_ARCHIVE_MAP,
+ FLAVACodebookConfig,
+ FLAVACodebookFeatureExtractor,
+ FLAVAConfig,
+ FLAVAFeatureExtractor,
+ FLAVAImageConfig,
+ FLAVAMultimodalConfig,
+ FLAVAProcessor,
+ FLAVATextConfig,
+ )
from .models.fnet import FNET_PRETRAINED_CONFIG_ARCHIVE_MAP, FNetConfig, FNetTokenizer
from .models.fsmt import FSMT_PRETRAINED_CONFIG_ARCHIVE_MAP, FSMTConfig, FSMTTokenizer
from .models.funnel import FUNNEL_PRETRAINED_CONFIG_ARCHIVE_MAP, FunnelConfig, FunnelTokenizer
@@ -3238,6 +3273,14 @@
FlaubertModel,
FlaubertWithLMHeadModel,
)
+        from .models.flava import (
+            FLAVA_PRETRAINED_MODEL_ARCHIVE_LIST,
+            FLAVACodebook,
+            FLAVAForPretraining,
+            FLAVAImageModel,
+            FLAVALayer,
+            FLAVAModel,
+            FLAVAMultimodalModel,
+            FLAVAPreTrainedModel,
+            FLAVATextModel,
+        )
from .models.fnet import (
FNET_PRETRAINED_MODEL_ARCHIVE_LIST,
FNetForMaskedLM,
diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py
index 7045f18c556154..1432fd61cc545c 100644
--- a/src/transformers/models/__init__.py
+++ b/src/transformers/models/__init__.py
@@ -54,6 +54,7 @@
electra,
encoder_decoder,
flaubert,
+ flava,
fnet,
fsmt,
funnel,
diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py
index 1899c3c249fafa..9eb08492768218 100644
--- a/src/transformers/models/auto/configuration_auto.py
+++ b/src/transformers/models/auto/configuration_auto.py
@@ -62,6 +62,7 @@
("canine", "CanineConfig"),
("roformer", "RoFormerConfig"),
("clip", "CLIPConfig"),
+ ("flava", "FLAVAConfig"),
("bigbird_pegasus", "BigBirdPegasusConfig"),
("deit", "DeiTConfig"),
("luke", "LukeConfig"),
@@ -164,6 +165,7 @@
("canine", "CANINE_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("roformer", "ROFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("clip", "CLIP_PRETRAINED_CONFIG_ARCHIVE_MAP"),
+ ("flava", "FLAVA_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("bigbird_pegasus", "BIGBIRD_PEGASUS_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("deit", "DEIT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("luke", "LUKE_PRETRAINED_CONFIG_ARCHIVE_MAP"),
@@ -258,6 +260,7 @@
("canine", "Canine"),
("roformer", "RoFormer"),
("clip", "CLIP"),
+        ("flava", "FLAVA"),
("bigbird_pegasus", "BigBirdPegasus"),
("deit", "DeiT"),
("luke", "LUKE"),
diff --git a/src/transformers/models/auto/feature_extraction_auto.py b/src/transformers/models/auto/feature_extraction_auto.py
index dc83fc133fad97..3221e1b286430e 100644
--- a/src/transformers/models/auto/feature_extraction_auto.py
+++ b/src/transformers/models/auto/feature_extraction_auto.py
@@ -47,6 +47,7 @@
("detr", "DetrFeatureExtractor"),
("layoutlmv2", "LayoutLMv2FeatureExtractor"),
("clip", "CLIPFeatureExtractor"),
+ ("flava", "FLAVAFeatureExtractor"),
("perceiver", "PerceiverFeatureExtractor"),
("swin", "ViTFeatureExtractor"),
("vit_mae", "ViTFeatureExtractor"),
diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py
index b0cfb47672491a..c21dfd1ed974ec 100644
--- a/src/transformers/models/auto/modeling_auto.py
+++ b/src/transformers/models/auto/modeling_auto.py
@@ -59,6 +59,7 @@
("canine", "CanineModel"),
("roformer", "RoFormerModel"),
("clip", "CLIPModel"),
+ ("flava", "FLAVAModel"),
("bigbird_pegasus", "BigBirdPegasusModel"),
("deit", "DeiTModel"),
("luke", "LukeModel"),
diff --git a/src/transformers/models/auto/processing_auto.py b/src/transformers/models/auto/processing_auto.py
index b51ef9ef312e10..23ec7c400269ee 100644
--- a/src/transformers/models/auto/processing_auto.py
+++ b/src/transformers/models/auto/processing_auto.py
@@ -38,6 +38,7 @@
PROCESSOR_MAPPING_NAMES = OrderedDict(
[
("clip", "CLIPProcessor"),
+ ("flava", "FLAVAProcessor"),
("layoutlmv2", "LayoutLMv2Processor"),
("layoutxlm", "LayoutXLMProcessor"),
("speech_to_text", "Speech2TextProcessor"),
diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py
index a604841340967d..f1ad8bb4b407e6 100644
--- a/src/transformers/models/auto/tokenization_auto.py
+++ b/src/transformers/models/auto/tokenization_auto.py
@@ -218,6 +218,13 @@
"CLIPTokenizerFast" if is_tokenizers_available() else None,
),
),
+ # (
+ # "flava",
+ # (
+ # "CLIPTokenizer",
+ # "CLIPTokenizerFast" if is_tokenizers_available() else None,
+ # ),
+ # ),
("wav2vec2_phoneme", ("Wav2Vec2PhonemeCTCTokenizer", None)),
(
"perceiver",
diff --git a/src/transformers/models/flava/__init__.py b/src/transformers/models/flava/__init__.py
new file mode 100644
index 00000000000000..7027ae3798e82f
--- /dev/null
+++ b/src/transformers/models/flava/__init__.py
@@ -0,0 +1,77 @@
+# flake8: noqa
+# There's no way to ignore "F401 '...' imported but unused" warnings in this
+# module, but to preserve other warnings. So, don't check this module at all.
+
+# Copyright 2021 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule, is_torch_available, is_vision_available
+
+
+_import_structure = {
+ "configuration_flava": [
+ "FLAVA_PRETRAINED_CONFIG_ARCHIVE_MAP",
+ "FLAVACodebookConfig",
+ "FLAVAConfig",
+ "FLAVAImageConfig",
+ "FLAVAMultimodalConfig",
+ "FLAVATextConfig",
+ ],
+}
+
+if is_vision_available():
+ _import_structure["feature_extraction_flava"] = ["FLAVACodebookFeatureExtractor", "FLAVAFeatureExtractor"]
+ _import_structure["processing_flava"] = ["FLAVAProcessor"]
+
+if is_torch_available():
+ _import_structure["modeling_flava"] = [
+ "FLAVA_PRETRAINED_MODEL_ARCHIVE_LIST",
+ "FLAVACodebook",
+ "FLAVAForPretraining",
+ "FLAVAImageModel",
+ "FLAVAModel",
+ "FLAVAMultimodalModel",
+ "FLAVAPreTrainedModel",
+ "FLAVATextModel",
+ ]
+
+if TYPE_CHECKING:
+    from .configuration_flava import (
+        FLAVA_PRETRAINED_CONFIG_ARCHIVE_MAP,
+        FLAVACodebookConfig,
+        FLAVAConfig,
+        FLAVAImageConfig,
+        FLAVAMultimodalConfig,
+        FLAVATextConfig,
+    )
+
+ if is_vision_available():
+ from .feature_extraction_flava import FLAVACodebookFeatureExtractor, FLAVAFeatureExtractor
+ from .processing_flava import FLAVAProcessor
+
+ if is_torch_available():
+        from .modeling_flava import (
+            FLAVA_PRETRAINED_MODEL_ARCHIVE_LIST,
+            FLAVACodebook,
+            FLAVAForPretraining,
+            FLAVAImageModel,
+            FLAVAModel,
+            FLAVAMultimodalModel,
+            FLAVAPreTrainedModel,
+            FLAVATextModel,
+        )
+
+else:
+ import sys
+
+ sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
diff --git a/src/transformers/models/flava/configuration_flava.py b/src/transformers/models/flava/configuration_flava.py
new file mode 100644
index 00000000000000..179f1d4421beb7
--- /dev/null
+++ b/src/transformers/models/flava/configuration_flava.py
@@ -0,0 +1,599 @@
+# coding=utf-8
+# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" FLAVA model configuration"""
+
+import copy
+import os
+from typing import Union
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+FLAVA_PRETRAINED_CONFIG_ARCHIVE_MAP = {
+ "flava-full": "https://huggingface.co/aps/flava-full/resolve/main/config.json",
+}
+
+
+class FLAVAImageConfig(PretrainedConfig):
+ r"""
+    This is the configuration class to store the configuration of a [`FLAVAImageModel`]. It is used to instantiate a
+ FLAVA model according to the specified arguments, defining the model architecture. Instantiating a configuration
+ with the defaults will yield a similar configuration to that of the FLAVA
+ [full](https://huggingface.co/aps/flava-full) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+
+ Args:
+ hidden_size (`int`, *optional*, defaults to 768):
+ Dimensionality of the encoder layers and the pooler layer.
+ num_hidden_layers (`int`, *optional*, defaults to 12):
+ Number of hidden layers in the Transformer encoder.
+ num_attention_heads (`int`, *optional*, defaults to 12):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ intermediate_size (`int`, *optional*, defaults to 3072):
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+ `"relu"`, `"selu"` and `"gelu_new"` are supported.
+        hidden_dropout_prob (`float`, *optional*, defaults to 0.0):
+            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+        attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ layer_norm_eps (`float`, *optional*, defaults to 1e-12):
+ The epsilon used by the layer normalization layers.
+ image_size (`int`, *optional*, defaults to `224`):
+ The size (resolution) of each image.
+ patch_size (`int`, *optional*, defaults to `16`):
+ The size (resolution) of each patch.
+ num_channels (`int`, *optional*, defaults to `3`):
+ The number of input channels.
+ qkv_bias (`bool`, *optional*, defaults to `True`):
+ Whether to add a bias to the queries, keys and values.
+        mask_token (`bool`, *optional*, defaults to `True`):
+            Whether to use a mask token or not. Used in MIM (Masked Image Modeling) loss.
+ vocab_size (`int`, *optional*, defaults to 8192):
+ Vocabulary size of the [`FLAVACodebook`] used in conjunction with [`FLAVAImageModel`] for MIM.
+
+ Example:
+
+ ```python
+ >>> from transformers import FLAVAImageModel, FLAVAImageConfig
+
+    >>> # Initializing a FLAVAImageConfig with the aps/flava-full style configuration
+    >>> configuration = FLAVAImageConfig()
+
+    >>> # Initializing a FLAVAImageModel from the aps/flava-full style configuration
+ >>> model = FLAVAImageModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "flava_image_model"
+
+ def __init__(
+ self,
+ hidden_size=768,
+ num_hidden_layers=12,
+ num_attention_heads=12,
+ intermediate_size=3072,
+ hidden_act="gelu",
+ hidden_dropout_prob=0.0,
+ attention_probs_dropout_prob=0.0,
+ initializer_range=0.02,
+ layer_norm_eps=1e-12,
+ image_size=224,
+ patch_size=16,
+ num_channels=3,
+ qkv_bias=True,
+ mask_token=True,
+ vocab_size=8192,
+ **kwargs
+ ):
+ super().__init__(**kwargs)
+
+ self.hidden_size = hidden_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.intermediate_size = intermediate_size
+ self.hidden_act = hidden_act
+ self.hidden_dropout_prob = hidden_dropout_prob
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
+ self.initializer_range = initializer_range
+ self.layer_norm_eps = layer_norm_eps
+ self.image_size = image_size
+ self.patch_size = patch_size
+ self.num_channels = num_channels
+ self.qkv_bias = qkv_bias
+ self.mask_token = mask_token
+ self.vocab_size = vocab_size
+
+ @classmethod
+ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
+
+ config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
+
+ # get the image config dict if we are loading from FLAVAConfig
+ if config_dict.get("model_type") == "flava":
+ config_dict = config_dict["image_config"]
+
+ if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
+ logger.warning(
+ f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
+ f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
+ )
+
+ return cls.from_dict(config_dict, **kwargs)
+
+
+class FLAVATextConfig(PretrainedConfig):
+ r"""
+    This is the configuration class to store the configuration of a [`FLAVATextModel`]. It is used to instantiate a
+ FLAVA model according to the specified arguments, defining the model architecture. Instantiating a configuration
+ with the defaults will yield a similar configuration to that of the FLAVA
+ [full](https://huggingface.co/aps/flava-full) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 30522):
+ Vocabulary size of the BERT model. Defines the number of different tokens that can be represented by the
+ `inputs_ids` passed when calling [`FLAVATextModel`].
+        type_vocab_size (`int`, *optional*, defaults to 2):
+            The vocabulary size of the `token_type_ids` passed when calling [`FLAVATextModel`]. Note that even though
+            the text encoder allows a `token_type_ids` vocabulary size of 2, only 1 is actually used for text-only
+            pretraining and fine-tuning, similar to RoBERTa.
+        max_position_embeddings (`int`, *optional*, defaults to 512):
+            The maximum sequence length that this model might ever be used with. Typically set this to something large
+            just in case (e.g., 512 or 1024 or 2048). For vision-language (VL) tasks, the `max_length` passed to the
+            model is 77.
+ position_embedding_type (`str`, *optional*, defaults to `"absolute"`):
+ Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For
+ positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to
+ [Self-Attention with Relative Position Representations (Shaw et al.)](https://arxiv.org/abs/1803.02155).
+ For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models
+ with Better Relative Position Embeddings (Huang et al.)](https://arxiv.org/abs/2009.13658).
+ hidden_size (`int`, *optional*, defaults to 768):
+ Dimensionality of the encoder layers and the pooler layer.
+ num_hidden_layers (`int`, *optional*, defaults to 12):
+ Number of hidden layers in the Transformer encoder.
+ num_attention_heads (`int`, *optional*, defaults to 12):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ intermediate_size (`int`, *optional*, defaults to 3072):
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+ `"relu"`, `"selu"` and `"gelu_new"` are supported.
+        hidden_dropout_prob (`float`, *optional*, defaults to 0.0):
+            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+        attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ layer_norm_eps (`float`, *optional*, defaults to 1e-12):
+ The epsilon used by the layer normalization layers.
+        pad_token_id (`int`, *optional*, defaults to 0):
+            The id of the padding token.
+ qkv_bias (`bool`, *optional*, defaults to `True`):
+ Whether to add a bias to the queries, keys and values.
+
+ Example:
+
+ ```python
+ >>> from transformers import FLAVATextModel, FLAVATextConfig
+
+    >>> # Initializing a FLAVATextConfig with the aps/flava-full style configuration
+    >>> configuration = FLAVATextConfig()
+
+    >>> # Initializing a FLAVATextModel from the aps/flava-full style configuration
+ >>> model = FLAVATextModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+ model_type = "flava_text_model"
+
+ def __init__(
+ self,
+ vocab_size=30522,
+ type_vocab_size=2,
+ max_position_embeddings=512,
+ position_embedding_type="absolute",
+ hidden_size=768,
+ num_hidden_layers=12,
+ num_attention_heads=12,
+ intermediate_size=3072,
+ hidden_act="gelu",
+ hidden_dropout_prob=0.0,
+ attention_probs_dropout_prob=0.0,
+ initializer_range=0.02,
+ layer_norm_eps=1e-12,
+ pad_token_id=0,
+ qkv_bias=True,
+ **kwargs
+ ):
+ super().__init__(**kwargs)
+
+ self.vocab_size = vocab_size
+ self.type_vocab_size = type_vocab_size
+ self.max_position_embeddings = max_position_embeddings
+ self.position_embedding_type = position_embedding_type
+ self.hidden_size = hidden_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.intermediate_size = intermediate_size
+ self.hidden_act = hidden_act
+ self.hidden_dropout_prob = hidden_dropout_prob
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
+ self.initializer_range = initializer_range
+ self.layer_norm_eps = layer_norm_eps
+ self.qkv_bias = qkv_bias
+ self.pad_token_id = pad_token_id
+
+ @classmethod
+ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
+
+ config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
+
+ # get the text config dict if we are loading from FLAVAConfig
+ if config_dict.get("model_type") == "flava":
+ config_dict = config_dict["text_config"]
+
+ if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
+ logger.warning(
+ f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
+ f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
+ )
+
+ return cls.from_dict(config_dict, **kwargs)
+
+
+class FLAVAMultimodalConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`FLAVAMultimodalModel`]. It is used to instantiate
+    a FLAVA model according to the specified arguments, defining the model architecture. Instantiating a configuration
+ with the defaults will yield a similar configuration to that of the FLAVA
+ [full](https://huggingface.co/aps/flava-full) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+
+ Args:
+ hidden_size (`int`, *optional*, defaults to 768):
+ Dimensionality of the encoder layers and the pooler layer.
+ num_hidden_layers (`int`, *optional*, defaults to 12):
+ Number of hidden layers in the Transformer encoder.
+ num_attention_heads (`int`, *optional*, defaults to 12):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ intermediate_size (`int`, *optional*, defaults to 3072):
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+ `"relu"`, `"selu"` and `"gelu_new"` are supported.
+        hidden_dropout_prob (`float`, *optional*, defaults to 0.0):
+            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+        attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ layer_norm_eps (`float`, *optional*, defaults to 1e-12):
+ The epsilon used by the layer normalization layers.
+ qkv_bias (`bool`, *optional*, defaults to `True`):
+ Whether to add a bias to the queries, keys and values.
+ use_cls_token (`bool`, *optional*, defaults to `True`):
+ Whether to use an extra CLS token for multimodal settings. Usually needed by the FLAVA model.
+
+
+ Example:
+
+ ```python
+ >>> from transformers import FLAVAMultimodalModel, FLAVAMultimodalConfig
+
+    >>> # Initializing a FLAVAMultimodalConfig with the aps/flava-full style configuration
+    >>> configuration = FLAVAMultimodalConfig()
+
+    >>> # Initializing a FLAVAMultimodalModel from the aps/flava-full style configuration
+ >>> model = FLAVAMultimodalModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "flava_multimodal_model"
+
+ def __init__(
+ self,
+ hidden_size=768,
+ num_hidden_layers=6,
+ num_attention_heads=12,
+ intermediate_size=3072,
+ hidden_act="gelu",
+ hidden_dropout_prob=0.0,
+ attention_probs_dropout_prob=0.0,
+ initializer_range=0.02,
+ layer_norm_eps=1e-12,
+ qkv_bias=True,
+ use_cls_token=True,
+ **kwargs
+ ):
+ super().__init__(**kwargs)
+
+ self.hidden_size = hidden_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.intermediate_size = intermediate_size
+ self.hidden_act = hidden_act
+ self.hidden_dropout_prob = hidden_dropout_prob
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
+ self.initializer_range = initializer_range
+ self.layer_norm_eps = layer_norm_eps
+ self.qkv_bias = qkv_bias
+ self.use_cls_token = use_cls_token
+
+ @classmethod
+ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
+ config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
+
+ # get the image config dict if we are loading from FLAVAConfig
+ if config_dict.get("model_type") == "flava":
+ config_dict = config_dict["multimodal_config"]
+
+ if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
+ logger.warning(
+ f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
+ f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
+ )
+
+ return cls.from_dict(config_dict, **kwargs)
+
+
+class FLAVACodebookConfig(PretrainedConfig):
+    r"""
+    [`FLAVACodebookConfig`] is the configuration class to store the configuration of a [`FLAVACodebook`]. It is used
+    to instantiate a FLAVA model according to the specified arguments, defining the model architecture. Instantiating
+    a configuration with the defaults will yield a similar configuration to that of the FLAVA
+    [codebook](https://huggingface.co/aps/flava-codebook) architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+    Args:
+        num_groups (`int`, *optional*, defaults to 4):
+            Number of groups to be created. This parameter does not affect the model at the moment and is only used
+            for some internal calculations and estimations.
+        input_channels (`int`, *optional*, defaults to 3):
+            Number of channels in the image to be passed.
+        num_blocks_per_group (`int`, *optional*, defaults to 2):
+            Number of conv-based blocks per group.
+        hidden_size (`int`, *optional*, defaults to 256):
+            Size of hidden dim for the blocks.
+        vocab_size (`int`, *optional*, defaults to 8192):
+            Size of the output vocabulary for the codebook.
+        freeze (`bool`, *optional*, defaults to `True`):
+            Whether to freeze the weights of the model.
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        kwargs (*optional*):
+            Dictionary of keyword arguments.
+
+    Example:
+
+    ```python
+    >>> from transformers import FLAVACodebook, FLAVACodebookConfig
+
+    >>> # Initializing a FLAVACodebookConfig with the aps/flava-codebook style configuration
+    >>> configuration = FLAVACodebookConfig()
+
+    >>> # Initializing a FLAVACodebook model from the aps/flava-codebook style configuration
+    >>> model = FLAVACodebook(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```"""
+
+    model_type = "flava_codebook"
+
+ def __init__(
+ self,
+ num_groups=4,
+ input_channels=3,
+ num_blocks_per_group=2,
+ hidden_size=256,
+ vocab_size=8192,
+ freeze=True,
+ initializer_range=0.02,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.num_groups = num_groups
+ self.input_channels = input_channels
+ self.num_blocks_per_group = num_blocks_per_group
+ self.hidden_size = hidden_size
+ self.vocab_size = vocab_size
+ self.freeze = freeze
+ self.initializer_range = initializer_range
+
+
+class FLAVAConfig(PretrainedConfig):
+ r"""
+    [`FLAVAConfig`] is the configuration class to store the configuration of a [`FLAVAModel`]. It is used to
+    instantiate a FLAVA model according to the specified arguments, defining the text model, image model and
+    multimodal model configs.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ text_config_dict (`dict`, *optional*):
+ Dictionary of configuration options used to initialize [`FLAVATextConfig`].
+ image_config_dict (`dict`, *optional*):
+ Dictionary of configuration options used to initialize [`FLAVAImageConfig`].
+ multimodal_config_dict (`dict`, *optional*):
+ Dictionary of configuration options used to initialize [`FLAVAMultimodalConfig`].
+ hidden_size (`int`, *optional*, defaults to 768):
+ Dimensionality of the encoder layers and the pooler layer.
+ layer_norm_eps (`float`, *optional*, defaults to 1e-12):
+ The epsilon used by the layer normalization layers.
+        projection_dim (`int`, *optional*, defaults to 768):
+            Dimensionality of text and image projection layers.
+ logit_scale_init_value (`float`, *optional*, defaults to 2.6592):
+            The initial value of the *logit_scale* parameter. Default is used as per the original FLAVA/CLIP
+ implementation.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ ce_ignore_index (`int`, *optional*, defaults to -100):
+ Cross entropy index to ignore.
+        mim_weight (`float`, *optional*, defaults to 1.0):
+            Weight to be assigned to MIM (Masked Image Modeling) unimodal loss.
+        mlm_weight (`float`, *optional*, defaults to 1.0):
+            Weight to be assigned to MLM (Masked Language Modeling) unimodal loss.
+ global_contrastive_weight (`float`, *optional*, defaults to 1.0):
+ Weight to be assigned to global contrastive cross-alignment loss.
+ itm_weight (`float`, *optional*, defaults to 1.0):
+ Weight to be assigned to image-text matching multimodal loss.
+ mmm_image_weight (`float`, *optional*, defaults to 1.0):
+ Weight to be assigned to MMM loss's image part.
+ mmm_text_weight (`float`, *optional*, defaults to 1.0):
+ Weight to be assigned to MMM loss's text part.
+        global_backprop_contrastive (`bool`, *optional*, defaults to `True`):
+            Whether to use global backpropagation through all workers in contrastive loss.
+        skip_unmasked_multimodal_encoder (`bool`, *optional*, defaults to `True`):
+            Whether to skip running the unmasked multimodal encoder whose outputs are not used by FLAVA losses.
+
+ kwargs (*optional*):
+ Dictionary of keyword arguments.
+
+ Example:
+
+ ```python
+ >>> from transformers import FLAVAModel, FLAVAForPretraining, FLAVAConfig
+
+    >>> # Initializing a FLAVAConfig with the aps/flava-full style configuration
+    >>> configuration = FLAVAConfig()
+
+    >>> # Initializing a FLAVAModel and a FLAVAForPretraining model from the aps/flava-full style configuration
+ >>> model = FLAVAModel(configuration)
+ >>> model_pre = FLAVAForPretraining(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ >>> configuration_pre = model_pre.config
+ ```
+ """
+
+ model_type = "flava"
+ is_composition = True
+
+ def __init__(
+ self,
+ text_config_dict=None,
+ image_config_dict=None,
+ multimodal_config_dict=None,
+ hidden_size=768,
+ layer_norm_eps=1e-12,
+ projection_dim=768,
+ logit_scale_init_value=2.6592,
+ initializer_range=0.02,
+ ce_ignore_index=-100,
+ mim_weight=1.0,
+ mlm_weight=1.0,
+ global_contrastive_weight=1.0,
+ itm_weight=1.0,
+ mmm_image_weight=1.0,
+ mmm_text_weight=1.0,
+ global_backprop_contrastive=True,
+ skip_unmasked_multimodal_encoder=True,
+ **kwargs
+ ):
+ super().__init__(
+ text_config_dict=text_config_dict,
+ image_config_dict=image_config_dict,
+ multimodal_config_dict=multimodal_config_dict,
+ **kwargs,
+ )
+
+ if text_config_dict is None:
+ text_config_dict = {}
+ logger.info("text_config_dict is None. Initializing the FLAVATextConfig with default values.")
+
+        if image_config_dict is None:
+            image_config_dict = {}
+            logger.info("image_config_dict is None. Initializing the FLAVAImageConfig with default values.")
+
+        if multimodal_config_dict is None:
+            multimodal_config_dict = {}
+            logger.info("multimodal_config_dict is None. Initializing the FLAVAMultimodalConfig with default values.")
+
+ self.text_config = FLAVATextConfig(**text_config_dict)
+ self.image_config = FLAVAImageConfig(**image_config_dict)
+ self.multimodal_config = FLAVAMultimodalConfig(**multimodal_config_dict)
+ self.projection_dim = projection_dim
+
+ self.hidden_size = hidden_size
+ self.layer_norm_eps = layer_norm_eps
+ self.initializer_range = initializer_range
+ self.logit_scale_init_value = logit_scale_init_value
+ self.initializer_factor = 1.0
+ self.ce_ignore_index = ce_ignore_index
+ self.mim_weight = mim_weight
+ self.mlm_weight = mlm_weight
+ self.global_contrastive_weight = global_contrastive_weight
+ self.itm_weight = itm_weight
+ self.mmm_image_weight = mmm_image_weight
+ self.mmm_text_weight = mmm_text_weight
+ self.global_backprop_contrastive = global_backprop_contrastive
+ self.skip_unmasked_multimodal_encoder = skip_unmasked_multimodal_encoder
+
+ @classmethod
+ def from_configs(
+ cls,
+ text_config: FLAVATextConfig,
+ image_config: FLAVAImageConfig,
+ multimodal_config: FLAVAMultimodalConfig,
+ **kwargs
+ ):
+ r"""
+ Instantiate a [`FLAVAConfig`] (or a derived class) from flava text model configuration, flava image model
+ configuration and flava multimodal model configuration.
+
+ Returns:
+ [`FLAVAConfig`]: An instance of a configuration object
+ """
+
+ return cls(
+ text_config_dict=text_config.to_dict(),
+ image_config_dict=image_config.to_dict(),
+ multimodal_config_dict=multimodal_config.to_dict(),
+ **kwargs,
+ )
+
+ def to_dict(self):
+ """
+ Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
+
+ Returns:
+            `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance.
+ """
+ output = copy.deepcopy(self.__dict__)
+ output["text_config"] = self.text_config.to_dict()
+ output["image_config"] = self.image_config.to_dict()
+ output["multimodal_config"] = self.multimodal_config.to_dict()
+ output["model_type"] = self.__class__.model_type
+ return output
diff --git a/src/transformers/models/flava/convert_dalle_to_flava_codebook.py b/src/transformers/models/flava/convert_dalle_to_flava_codebook.py
new file mode 100644
index 00000000000000..59f54a21eb6150
--- /dev/null
+++ b/src/transformers/models/flava/convert_dalle_to_flava_codebook.py
@@ -0,0 +1,91 @@
+# coding=utf-8
+# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import os
+
+import torch
+
+from transformers import FLAVACodebook, FLAVACodebookConfig
+
+
+def rreplace(s, old, new, occurrence):
+ li = s.rsplit(old, occurrence)
+ return new.join(li)
+
+
+def count_parameters(state_dict):
+ # encoder.embeddings are double copied in original FLAVA
+ return sum(param.float().sum() if "encoder.embeddings" not in key else 0 for key, param in state_dict.items())
+
+
+def upgrade_state_dict(state_dict):
+ upgrade = {}
+
+ for key, value in state_dict.items():
+ if key.endswith(".w"):
+ key = rreplace(key, ".w", ".weight", 1)
+ if key.endswith(".b"):
+ key = rreplace(key, ".b", ".bias", 1)
+
+ upgrade[key] = value.float()
+
+ return upgrade
+
+
+@torch.no_grad()
+def convert_dalle_checkpoint(checkpoint_path, pytorch_dump_folder_path, config_path=None):
+ """
+ Copy/paste/tweak model's weights to transformers design.
+ """
+ from dall_e import Encoder
+
+ encoder = Encoder()
+ if os.path.exists(checkpoint_path):
+ ckpt = torch.load(checkpoint_path)
+ else:
+ ckpt = torch.hub.load_state_dict_from_url(checkpoint_path)
+
+ if isinstance(ckpt, Encoder):
+ ckpt = ckpt.state_dict()
+ encoder.load_state_dict(ckpt)
+
+ if config_path is not None:
+ config = FLAVACodebookConfig.from_pretrained(config_path)
+ else:
+ config = FLAVACodebookConfig()
+
+ hf_model = FLAVACodebook(config).eval()
+ state_dict = encoder.state_dict()
+
+ hf_state_dict = upgrade_state_dict(state_dict)
+ hf_model.load_state_dict(hf_state_dict)
+ hf_state_dict = hf_model.state_dict()
+ hf_count = count_parameters(hf_state_dict)
+ state_dict_count = count_parameters(state_dict)
+
+ assert torch.allclose(hf_count, state_dict_count, atol=1e-3)
+
+ hf_model.save_pretrained(pytorch_dump_folder_path)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
+    parser.add_argument("--checkpoint_path", default=None, type=str, help="Path or URL to the DALL-E encoder checkpoint")
+ parser.add_argument("--config_path", default=None, type=str, help="Path to hf config.json of model to convert")
+ args = parser.parse_args()
+
+ convert_dalle_checkpoint(args.checkpoint_path, args.pytorch_dump_folder_path, args.config_path)
diff --git a/src/transformers/models/flava/convert_flava_original_pytorch_to_hf.py b/src/transformers/models/flava/convert_flava_original_pytorch_to_hf.py
new file mode 100644
index 00000000000000..c1630941ea9765
--- /dev/null
+++ b/src/transformers/models/flava/convert_flava_original_pytorch_to_hf.py
@@ -0,0 +1,92 @@
+# coding=utf-8
+# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import os
+
+import torch
+
+from transformers import FLAVAConfig, FLAVAForPretraining
+
+
+def count_parameters(state_dict):
+ # encoder.embeddings are double copied in original FLAVA
+ return sum(param.float().sum() if "encoder.embeddings" not in key else 0 for key, param in state_dict.items())
+
+
+def upgrade_state_dict(state_dict):
+ upgrade = {}
+
+ for key, value in state_dict.items():
+ if "text_encoder.embeddings" in key or "image_encoder.embeddings" in key:
+ continue
+
+ key = key.replace("heads.cmd.mim_head.cls.predictions", "mmm_image_head")
+ key = key.replace("heads.cmd.mlm_head.cls.predictions", "mmm_text_head")
+ key = key.replace("heads.cmd.itm_head.cls", "itm_head")
+ key = key.replace("heads.cmd.itm_head.pooler", "itm_head.pooler")
+ key = key.replace("heads.cmd.clip_head.logit_scale", "flava.logit_scale")
+ key = key.replace("heads.fairseq_mlm.cls.predictions", "mlm_head")
+ key = key.replace("heads.imagenet.mim_head.cls.predictions", "mim_head")
+ key = key.replace("mm_text_projection", "flava.text_to_mm_projection")
+ key = key.replace("mm_image_projection", "flava.image_to_mm_projection")
+ key = key.replace("image_encoder.module", "flava.image_model")
+ key = key.replace("text_encoder.module", "flava.text_model")
+ key = key.replace("mm_encoder.module.encoder.cls_token", "flava.multimodal_model.cls_token")
+ key = key.replace("mm_encoder.module", "flava.multimodal_model")
+ key = key.replace("text_projection", "flava.text_projection")
+ key = key.replace("image_projection", "flava.image_projection")
+
+ upgrade[key] = value.float()
+
+ return upgrade
+
+
+@torch.no_grad()
+def convert_flava_checkpoint(checkpoint_path, pytorch_dump_folder_path, config_path=None):
+ """
+ Copy/paste/tweak model's weights to transformers design.
+ """
+ if config_path is not None:
+ config = FLAVAConfig.from_pretrained(config_path)
+ else:
+ config = FLAVAConfig()
+
+ hf_model = FLAVAForPretraining(config).eval()
+
+ if os.path.exists(checkpoint_path):
+ state_dict = torch.load(checkpoint_path, map_location="cpu")
+ else:
+ state_dict = torch.hub.load_state_dict_from_url(checkpoint_path, map_location="cpu")
+
+ hf_state_dict = upgrade_state_dict(state_dict)
+ hf_model.load_state_dict(hf_state_dict)
+ hf_state_dict = hf_model.state_dict()
+ hf_count = count_parameters(hf_state_dict)
+ state_dict_count = count_parameters(state_dict)
+
+ assert torch.allclose(hf_count, state_dict_count, atol=1e-3)
+
+ hf_model.save_pretrained(pytorch_dump_folder_path)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
+    parser.add_argument("--checkpoint_path", default=None, type=str, help="Path or URL to the original FLAVA checkpoint")
+ parser.add_argument("--config_path", default=None, type=str, help="Path to hf config.json of model to convert")
+ args = parser.parse_args()
+
+ convert_flava_checkpoint(args.checkpoint_path, args.pytorch_dump_folder_path, args.config_path)
diff --git a/src/transformers/models/flava/feature_extraction_flava.py b/src/transformers/models/flava/feature_extraction_flava.py
new file mode 100644
index 00000000000000..3207b309694324
--- /dev/null
+++ b/src/transformers/models/flava/feature_extraction_flava.py
@@ -0,0 +1,366 @@
+# coding=utf-8
+# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Feature extractor class for FLAVA."""
+
+import math
+import random
+from functools import lru_cache
+from typing import List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+from PIL import Image
+
+from ...feature_extraction_utils import BatchFeature, FeatureExtractionMixin
+from ...image_utils import ImageFeatureExtractionMixin, is_torch_tensor
+from ...utils import TensorType, logging
+
+
+logger = logging.get_logger(__name__)
+
+
+# These values are taken from CLIP
+FLAVA_IMAGE_MEAN = [0.48145466, 0.4578275, 0.40821073]
+FLAVA_IMAGE_STD = [0.26862954, 0.26130258, 0.27577711]
+FLAVA_CODEBOOK_MEAN = [0.0, 0.0, 0.0]
+FLAVA_CODEBOOK_STD = [1.0, 1.0, 1.0]
+LOGIT_LAPLACE_EPS: float = 0.1
+
+
+# Inspired from https://github.com/microsoft/unilm/blob/master/beit/masking_generator.py
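+# Generates a random block-wise boolean mask over image patches for masked image modeling (MIM), following the
+# BEiT masking strategy referenced above.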
+class MaskingGenerator:
+ def __init__(
+ self,
+ input_size: Union[int, Tuple[int, int]],
+ total_mask_patches: int = 75,
+        mask_group_max_patches: Optional[int] = None,
+        mask_group_min_patches: int = 16,
+ mask_group_min_aspect_ratio: float = 0.3,
+ mask_group_max_aspect_ratio: Optional[float] = None,
+ ):
+ if not isinstance(input_size, tuple):
+ input_size = (input_size,) * 2
+ self.height, self.width = input_size
+
+ self.num_patches = self.height * self.width
+ self.total_mask_patches = total_mask_patches
+
+ self.mask_group_min_patches = mask_group_min_patches
+ self.mask_group_max_patches = total_mask_patches if mask_group_max_patches is None else mask_group_max_patches
+
+ mask_group_max_aspect_ratio = mask_group_max_aspect_ratio or 1 / mask_group_min_aspect_ratio
+ self.log_aspect_ratio = (math.log(mask_group_min_aspect_ratio), math.log(mask_group_max_aspect_ratio))
+
+ def __repr__(self):
+ repr_str = "MaskingGenerator(%d, %d -> [%d ~ %d], max = %d, %.3f ~ %.3f)" % (
+ self.height,
+ self.width,
+ self.mask_group_min_patches,
+ self.mask_group_max_patches,
+ self.total_mask_patches,
+ self.log_aspect_ratio[0],
+ self.log_aspect_ratio[1],
+ )
+ return repr_str
+
+ def get_shape(self):
+ return self.height, self.width
+
+ def _mask(self, mask, max_mask_patches):
+ delta = 0
+ for _attempt in range(10):
+ target_area = random.uniform(self.mask_group_min_patches, max_mask_patches)
+ aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))
+ h = int(round(math.sqrt(target_area * aspect_ratio)))
+ w = int(round(math.sqrt(target_area / aspect_ratio)))
+ if w < self.width and h < self.height:
+ top = random.randint(0, self.height - h)
+ left = random.randint(0, self.width - w)
+
+ num_masked = mask[top : top + h, left : left + w].sum()
+ # Overlap
+ if 0 < h * w - num_masked <= max_mask_patches:
+ for i in range(top, top + h):
+ for j in range(left, left + w):
+ if mask[i, j] == 0:
+ mask[i, j] = 1
+ delta += 1
+
+ if delta > 0:
+ break
+ return delta
+
+ def __call__(self):
+ mask = np.zeros(shape=self.get_shape(), dtype=int)
+ mask_count = 0
+ while mask_count < self.total_mask_patches:
+ max_mask_patches = self.total_mask_patches - mask_count
+ max_mask_patches = min(max_mask_patches, self.mask_group_max_patches)
+
+ delta = self._mask(mask, max_mask_patches)
+ if delta == 0:
+ break
+ else:
+ mask_count += delta
+
+ return mask
+
+
+class FLAVAFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
+ r"""
+ Constructs a FLAVA feature extractor.
+
+ This feature extractor inherits from [`FeatureExtractionMixin`] which contains most of the main methods. Users
+ should refer to this superclass for more information regarding those methods.
+
+ Args:
+ do_resize (`bool`, *optional*, defaults to `True`):
+ Whether to resize the input to a certain `size`.
+ size (`int`, *optional*, defaults to 224):
+ Resize the input to the given size. Only has an effect if `do_resize` is set to `True`.
+ resample (`int`, *optional*, defaults to `PIL.Image.BICUBIC`):
+ An optional resampling filter. This can be one of `PIL.Image.NEAREST`, `PIL.Image.BOX`,
+ `PIL.Image.BILINEAR`, `PIL.Image.HAMMING`, `PIL.Image.BICUBIC` or `PIL.Image.LANCZOS`. Only has an effect
+ if `do_resize` is set to `True`.
+ do_center_crop (`bool`, *optional*, defaults to `True`):
+ Whether to crop the input at the center. If the input size is smaller than `crop_size` along any edge, the
+ image is padded with 0's and then center cropped.
+ crop_size (`int`, *optional*, defaults to 224):
+ Desired output size when applying center-cropping. Only has an effect if `do_center_crop` is set to `True`.
+ do_normalize (`bool`, *optional*, defaults to `True`):
+ Whether or not to normalize the input with `image_mean` and `image_std`.
+        image_mean (`List[float]`, *optional*, defaults to `[0.48145466, 0.4578275, 0.40821073]`):
+            The sequence of means for each channel, to be used when normalizing images.
+        image_std (`List[float]`, *optional*, defaults to `[0.26862954, 0.26130258, 0.27577711]`):
+            The sequence of standard deviations for each channel, to be used when normalizing images.
+ """
+
+ model_input_names = ["pixel_values"]
+
+ def __init__(
+ self,
+ do_resize=True,
+ size=224,
+ resample=Image.BICUBIC,
+ do_center_crop=True,
+ crop_size=224,
+ do_normalize=True,
+ image_mean=None,
+ image_std=None,
+ input_size_patches: int = 14,
+ total_mask_patches: int = 75,
+        mask_group_min_patches: int = 16,
+        mask_group_max_patches: Optional[int] = None,
+ mask_group_min_aspect_ratio: float = 0.3,
+ mask_group_max_aspect_ratio: Optional[float] = None,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.do_resize = do_resize
+ self.size = size
+ self.resample = resample
+ self.do_center_crop = do_center_crop
+ self.crop_size = crop_size
+ self.do_normalize = do_normalize
+ self.image_mean = image_mean if image_mean is not None else FLAVA_IMAGE_MEAN
+ self.image_std = image_std if image_std is not None else FLAVA_IMAGE_STD
+ self.input_size_patches = input_size_patches
+ self.total_mask_patches = total_mask_patches
+ self.mask_group_min_patches = mask_group_min_patches
+ self.mask_group_max_patches = mask_group_max_patches
+ self.mask_group_min_aspect_ratio = mask_group_min_aspect_ratio
+ self.mask_group_max_aspect_ratio = mask_group_max_aspect_ratio
+
+ @property
+ @lru_cache()
+ def masking_generator(self):
+ return MaskingGenerator(
+ input_size=self.input_size_patches,
+ total_mask_patches=self.total_mask_patches,
+ mask_group_min_patches=self.mask_group_min_patches,
+ mask_group_max_patches=self.mask_group_max_patches,
+ mask_group_min_aspect_ratio=self.mask_group_min_aspect_ratio,
+ mask_group_max_aspect_ratio=self.mask_group_max_aspect_ratio,
+ )
+
+ def __call__(
+ self,
+ images: Union[
+ Image.Image, np.ndarray, "torch.Tensor", List[Image.Image], List[np.ndarray], List["torch.Tensor"] # noqa
+ ],
+ return_masks: Optional[bool] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ **kwargs
+ ) -> BatchFeature:
+ """
+ Main method to prepare for the model one or several image(s).
+
+
+
+ NumPy arrays and PyTorch tensors are converted to PIL images when resizing, so the most efficient is to pass
+ PIL images.
+
+
+
+ Args:
+ images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
+ The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
+ tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
+ number of channels, H and W are image height and width.
+
+            return_masks (`bool`, *optional*, defaults to `None`):
+                If `True`, the feature extractor will also return `bool_masked_pos`, a boolean mask over the image
+                patches, to be used for masked image modeling.
+
+
+ return_tensors (`str` or [`~utils.TensorType`], *optional*, defaults to `'np'`):
+ If set, will return tensors of a particular framework. Acceptable values are:
+
+ - `'tf'`: Return TensorFlow `tf.constant` objects.
+ - `'pt'`: Return PyTorch `torch.Tensor` objects.
+ - `'np'`: Return NumPy `np.ndarray` objects.
+ - `'jax'`: Return JAX `jnp.ndarray` objects.
+
+ Returns:
+ [`BatchFeature`]: A [`BatchFeature`] with the following fields:
+
+            - **pixel_values** -- Pixel values to be fed to a model.
+            - **bool_masked_pos** -- Boolean masked positions (returned when `return_masks=True`), indicating which
+              patches are masked for masked image modeling.
+ """
+ # Input type checking for clearer error
+ if isinstance(images, (list, tuple)) and len(images) != 0:
+ self._ensure_format_supported(images[0])
+ else:
+ self._ensure_format_supported(images)
+
+ is_batched = bool(
+ isinstance(images, (list, tuple))
+ and (isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]))
+ )
+
+ if not is_batched:
+ images = [images]
+
+ # transformations (resizing + center cropping + normalization)
+ if self.do_resize and self.size is not None and self.resample is not None:
+ images = [self.resize(image=image, size=self.size, resample=self.resample) for image in images]
+ if self.do_center_crop and self.crop_size is not None:
+ images = [self.center_crop(image, self.crop_size) for image in images]
+ if self.do_normalize:
+ images = [self.normalize(image=image, mean=self.image_mean, std=self.image_std) for image in images]
+ # return as BatchFeature
+ data = {"pixel_values": images}
+
+ if return_masks:
+ masks = [self.masking_generator() for _ in images]
+ data["bool_masked_pos"] = masks
+
+ encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)
+
+ return encoded_inputs
+
+
+class FLAVACodebookFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
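+    r"""
+    Constructs a FLAVA codebook feature extractor: it resizes, center-crops, normalizes and pixel-maps images into
+    the format expected by [`FLAVACodebook`].
+    """
+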
+ def __init__(
+ self,
+ do_resize=True,
+ size=112,
+ resample=Image.LANCZOS,
+ do_center_crop=True,
+ crop_size=112,
+ do_map_pixels=True,
+ do_normalize=True,
+ image_mean=None,
+ image_std=None,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.do_resize = do_resize
+ self.size = size
+ self.resample = resample
+ self.do_center_crop = do_center_crop
+ self.do_map_pixels = do_map_pixels
+ self.crop_size = crop_size
+ self.do_normalize = do_normalize
+ self.image_mean = image_mean if image_mean is not None else FLAVA_CODEBOOK_MEAN
+ self.image_std = image_std if image_std is not None else FLAVA_CODEBOOK_STD
+
+ def map_pixels(self, x):
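+        # Squeeze pixel values from [0, 1] into [eps, 1 - eps], as expected by the DALL-E style (logit-Laplace) codebook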
+ return (1 - 2 * LOGIT_LAPLACE_EPS) * x + LOGIT_LAPLACE_EPS
+
+ def __call__(
+ self,
+ images: Union[
+ Image.Image, np.ndarray, "torch.Tensor", List[Image.Image], List[np.ndarray], List["torch.Tensor"] # noqa
+ ],
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ **kwargs
+ ) -> BatchFeature:
+ """
+ Main method to prepare for the model one or several image(s).
+
+
+
+ NumPy arrays and PyTorch tensors are converted to PIL images when resizing, so the most efficient is to pass
+ PIL images.
+
+
+
+ Args:
+ images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
+ The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
+ tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
+ number of channels, H and W are image height and width.
+
+ return_tensors (`str` or [`~utils.TensorType`], *optional*, defaults to `'np'`):
+ If set, will return tensors of a particular framework. Acceptable values are:
+
+ - `'tf'`: Return TensorFlow `tf.constant` objects.
+ - `'pt'`: Return PyTorch `torch.Tensor` objects.
+ - `'np'`: Return NumPy `np.ndarray` objects.
+ - `'jax'`: Return JAX `jnp.ndarray` objects.
+
+ Returns:
+ [`BatchFeature`]: A [`BatchFeature`] with the following fields:
+
+ - **pixel_values** -- Pixel values to be fed to a model.
+ """
+ # Input type checking for clearer error
+ if isinstance(images, (list, tuple)) and len(images) != 0:
+ self._ensure_format_supported(images[0])
+ else:
+ self._ensure_format_supported(images)
+
+ is_batched = bool(
+ isinstance(images, (list, tuple))
+ and (isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]))
+ )
+
+ if not is_batched:
+ images = [images]
+
+ # transformations (resizing + center cropping + normalization)
+ if self.do_resize and self.size is not None and self.resample is not None:
+ images = [self.resize(image=image, size=self.size, resample=self.resample) for image in images]
+ if self.do_center_crop and self.crop_size is not None:
+ images = [self.center_crop(image, self.crop_size) for image in images]
+ if self.do_normalize:
+ images = [self.normalize(image=image, mean=self.image_mean, std=self.image_std) for image in images]
+ if self.do_map_pixels:
+ images = [self.map_pixels(image) for image in images]
+ data = {"pixel_values": images}
+
+ # return as BatchFeature
+ encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)
+
+ return encoded_inputs
diff --git a/src/transformers/models/flava/modeling_flava.py b/src/transformers/models/flava/modeling_flava.py
new file mode 100644
index 00000000000000..b500bba58c972d
--- /dev/null
+++ b/src/transformers/models/flava/modeling_flava.py
@@ -0,0 +1,2044 @@
+# coding=utf-8
+# Copyright 2021 Meta Platforms authors and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" PyTorch FLAVA model."""
+
+import collections
+import math
+from collections import OrderedDict
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Set, Tuple, Union
+
+import torch
+import torch.utils.checkpoint
+from packaging import version
+from torch import nn
+
+from ...activations import ACT2FN
+from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
+from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
+from ...utils import (
+    ModelOutput,
+    add_code_sample_docstrings,
+    add_start_docstrings,
+    add_start_docstrings_to_model_forward,
+    logging,
+    replace_return_docstrings,
+)
+from .configuration_flava import (
+ FLAVACodebookConfig,
+ FLAVAConfig,
+ FLAVAImageConfig,
+ FLAVAMultimodalConfig,
+ FLAVATextConfig,
+)
+
+
+logger = logging.get_logger(__name__)
+
+_CHECKPOINT_FOR_DOC = "aps/flava-full"
+
+# Codebook docstring
+_CHECKPOINT_FOR_CODEBOOK_DOC = "aps/flava-codebook"
+_FEAT_EXTRACTOR_FOR_DOC = "FLAVAFeatureExtractor"
+_CONFIG_CLASS_FOR_IMAGE_MODEL_DOC = "FLAVAImageConfig"
+_CONFIG_CLASS_FOR_TEXT_MODEL_DOC = "FLAVATextConfig"
+_CONFIG_CLASS_FOR_MULTIMODAL_MODEL_DOC = "FLAVAMultimodalConfig"
+_TOKENIZER_FOR_DOC = "BertTokenizer"
+_EXPECTED_IMAGE_OUTPUT_SHAPE = [1, 197, 768]
+
+FLAVA_PRETRAINED_MODEL_ARCHIVE_LIST = [
+ "aps/flava-full",
+ # See all flava models at https://huggingface.co/models?filter=flava
+]
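+# Clamp bounds for the learnable contrastive logit scale: exp(0) = 1 and exp(4.6052) ~= 100, so the temperature stays
+# in [1, 100] after exponentiation (the same bound CLIP uses).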
+LOGIT_SCALE_CLAMP_MIN = 0
+LOGIT_SCALE_CLAMP_MAX = 4.6052
+
+FLAVAPossibleConfigs = Union[FLAVATextConfig, FLAVAImageConfig, FLAVAMultimodalConfig]
+
+
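+# Convolution helper used by the FLAVA codebook: the padding keeps the spatial size unchanged for odd kernel sizes.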
+def _build_codebook_conv2d(in_size: int, out_size: int, kernel_size: int):
+ return nn.Conv2d(in_size, out_size, kernel_size=kernel_size, padding=(kernel_size - 1) // 2)
+
+
+@dataclass
+class FLAVAModelOutput(ModelOutput):
+ """
+ Output from FLAVAModel containing embeddings and outputs from individual encoders.
+
+    Note that `image_embeddings` and `text_embeddings` returned are similar to pooled output returned from a
+ transformer. If you want embeddings for contrastive loss or retrieval use a FLAVA model's `image_projection` and
+ `text_projection` layers on `image_embeddings` and `text_embeddings` respectively.
+
+ Args:
+        image_embeddings(`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when `pixel_values` are present):
+ The image embeddings which are basically the pooled output of [`FLAVAImageModel`].
+ image_output(`BaseModelOutputWithPooling`, *optional*, returned when `pixel_values` are present):
+ The output of the [`FLAVAImageModel`].
+        text_embeddings(`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when `input_ids` are present):
+ The text embeddings which are basically the pooled output of [`FLAVATextModel`].
+ text_output(`BaseModelOutputWithPooling`, *optional*, returned when `input_ids` are present):
+ The output of the [`FLAVATextModel`].
+        multimodal_embeddings(`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when `input_ids` and `pixel_values` are present and `skip_multimodal_encoder` is `None` or `False`):
+            The multimodal embeddings which are basically the pooled output of [`FLAVAMultimodalModel`].
+ multimodal_output(`BaseModelOutputWithPooling`, returned when `input_ids` and `pixel_values` are present and `skip_multimodal_encoder` is `None` or `False`):
+ The output of the [`FLAVAMultimodalModel`].
+ """
+
+ image_embeddings: Optional[torch.FloatTensor] = None
+ image_output: Optional[BaseModelOutputWithPooling] = None
+ text_embeddings: Optional[torch.FloatTensor] = None
+ text_output: Optional[BaseModelOutputWithPooling] = None
+ multimodal_embeddings: Optional[torch.FloatTensor] = None
+ multimodal_output: Optional[BaseModelOutputWithPooling] = None
+
+
+@dataclass
+class FLAVALosses(ModelOutput):
+ """Class representing pretraining losses from FLAVA model
+
+ Args:
+        mim(`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `mim_labels` and `pixel_values` are present, `input_ids_masked` is absent and `mim_weight` > 0.):
+ Masked Image Modeling loss as used in BeIT calculated only for unimodal image data.
+        mlm(`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `mlm_labels` and `input_ids_masked` are present, `pixel_values` is absent and `mlm_weight` > 0.):
+ Masked Language Modeling loss as used in BERT calculated only for unimodal text data.
+        itm(`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `itm_labels`, `input_ids_masked`, `pixel_values` are present and `itm_weight` > 0.):
+ Image Text Matching (ITM) loss calculated for paired image-text data. Note that ITM loss is calculated on
+ masked pairs in FLAVA.
+        global_contrastive(`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `input_ids` and `pixel_values` are present and `global_contrastive_weight` > 0.):
+ Contrastive loss for image-text similarity similar to CLIP but calculated globally for paired image-text
+ data. This is calculated on unmasked images and texts.
+        mmm_image(`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `mim_labels`, `pixel_values` and `input_ids_masked` are present and `mmm_image_weight` > 0.):
+ Masked Multimodal Modeling loss's image component calculated on paired image-text data.
+        mmm_text(`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `mlm_labels`, `pixel_values` and `input_ids_masked` are present and `mmm_text_weight` > 0.):
+ Masked Multimodal Modeling loss's text component calculated on paired image-text data.
+ """
+
+ mim: Optional[torch.FloatTensor] = None
+ mlm: Optional[torch.FloatTensor] = None
+ itm: Optional[torch.FloatTensor] = None
+ global_contrastive: Optional[torch.FloatTensor] = None
+ mmm_image: Optional[torch.FloatTensor] = None
+ mmm_text: Optional[torch.FloatTensor] = None
+
+    def all_none(self) -> bool:
+        all_none = True
+        for v in self.values():
+            if v is not None:
+                all_none = False
+                break
+        return all_none
+
+
+@dataclass
+class FLAVAForPretrainingOutput(ModelOutput):
+ """
+ Output from FLAVAForPretraining containing embeddings, and outputs from individual encoders.
+
+    Note that `image_embeddings` and `text_embeddings` returned are similar to pooled output returned from a
+ transformer. If you want embeddings for contrastive loss or retrieval use a FLAVA model's `image_projection` and
+ `text_projection` layers on `image_embeddings` and `text_embeddings` respectively.
+
+ Args:
+ losses (`FLAVALosses`):
+ Losses for FLAVA Pretraining. Check `FLAVALosses` class description for the information on the keys.
+        image_embeddings(`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when `pixel_values` are present):
+ The image embeddings which are basically the pooled output of [`FLAVAImageModel`].
+ image_output(`BaseModelOutputWithPooling`, *optional*, returned when `pixel_values` are present):
+ The output of the [`FLAVAImageModel`].
+        text_embeddings(`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when `input_ids` are present):
+ The text embeddings which are basically the pooled output of [`FLAVATextModel`].
+ text_output(`BaseModelOutputWithPooling`, *optional*, returned when `input_ids` are present):
+ The output of the [`FLAVATextModel`].
+        multimodal_embeddings(`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when `input_ids` and `pixel_values` are present and `skip_unmasked_multimodal_encoder` is `None` or `False`):
+            The multimodal embeddings which are basically the pooled output of [`FLAVAMultimodalModel`].
+ multimodal_output(`BaseModelOutputWithPooling`, returned when `input_ids` and `pixel_values` are present and `skip_unmasked_multimodal_encoder` is `None` or `False`):
+ The output of the [`FLAVAMultimodalModel`].
+
+        image_masked_embeddings(`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when `pixel_values` are present):
+ The image embeddings which are basically the pooled output of [`FLAVAImageModel`]. Uses `bool_masked_pos`
+ to create masked images.
+ image_masked_output(`BaseModelOutputWithPooling`, *optional*, returned when `pixel_values` are present):
+ The output of the [`FLAVAImageModel`]. Uses `bool_masked_pos` to create masked images.
+        text_masked_embeddings(`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when `input_ids_masked` are present):
+ The text embeddings which are basically the pooled output of [`FLAVATextModel`].
+ text_masked_output(`BaseModelOutputWithPooling`, *optional*, returned when `input_ids_masked` are present):
+ The output of the [`FLAVATextModel`].
+        multimodal_masked_embeddings(`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when `input_ids_masked` and `pixel_values` are present):
+            The multimodal embeddings which are basically the pooled output of [`FLAVAMultimodalModel`].
+ multimodal_masked_output(`BaseModelOutputWithPooling`, returned when `input_ids_masked` and `pixel_values` are present):
+ The output of the [`FLAVAMultimodalModel`].
+
+
+ mim_logits:
+ (`torch.FloatTensor` of shape `(batch_size, image_num_patches, image_vocab_size)`, *optional*, returned
+ when `pixel_values` are present and `input_ids_masked` are not): The logits for MIM unimodal loss. Uses
+            `bool_masked_pos` to get masked images.
+ mlm_logits:
+ (`torch.FloatTensor` of shape `(batch_size, text_seq_length, text_vocab_size)`, *optional*, returned when
+ `input_ids_masked` are present and `pixel_values` are not): The logits for MLM unimodal loss.
+ itm_logits:
+ (`torch.FloatTensor` of shape `(batch_size, 2)`, *optional*, returned when `input_ids_masked` and
+ `pixel_values` are present): The logits for ITM loss. Note that ITM loss is calculated on masked pairs in
+ FLAVA.
+ mmm_image_logits:
+ (`torch.FloatTensor` of shape `(batch_size, image_num_patches, image_vocab_size)`, *optional*, returned
+ when `pixel_values` and `input_ids_masked` are present): The logits for MMM image multimodal loss. Uses
+            `bool_masked_pos` to get masked images.
+ mmm_text_logits:
+ (`torch.FloatTensor` of shape `(batch_size, text_seq_length, text_vocab_size)`, *optional*, returned when
+ `pixel_values` and `input_ids_masked` are present): The logits for MMM text multimodal loss. Uses
+            `bool_masked_pos` to get masked images.
+        contrastive_logits_per_image(`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`):
+ The scaled dot product scores between `image_embeddings` and `text_embeddings` but passed through FLAVA's
+ `image_projection` and `text_projection` layers respectively. This represents the image-text similarity
+ scores. This is calculated on unmasked images and texts.
+        contrastive_logits_per_text(`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`):
+ The scaled dot product scores between `text_embeddings` and `image_embeddings` but passed through FLAVA's
+ `text_projection` and `image_projection` layers respectively. This is calculated on unmasked images and
+ texts.
+ """
+
+ losses: FLAVALosses = None
+ image_embeddings: Optional[torch.FloatTensor] = None
+ image_output: Optional[BaseModelOutputWithPooling] = None
+ text_embeddings: Optional[torch.FloatTensor] = None
+ text_output: Optional[BaseModelOutputWithPooling] = None
+ multimodal_embeddings: Optional[torch.FloatTensor] = None
+ multimodal_output: Optional[BaseModelOutputWithPooling] = None
+ image_masked_embeddings: Optional[torch.FloatTensor] = None
+ image_masked_output: Optional[BaseModelOutputWithPooling] = None
+ text_masked_embeddings: Optional[torch.FloatTensor] = None
+ text_masked_output: Optional[BaseModelOutputWithPooling] = None
+ multimodal_masked_embeddings: Optional[torch.FloatTensor] = None
+ multimodal_masked_output: Optional[BaseModelOutputWithPooling] = None
+ mim_logits: Optional[torch.FloatTensor] = None
+ mlm_logits: Optional[torch.FloatTensor] = None
+ itm_logits: Optional[torch.FloatTensor] = None
+ contrastive_logits_per_image: Optional[torch.FloatTensor] = None
+ contrastive_logits_per_text: Optional[torch.FloatTensor] = None
+ mmm_image_logits: Optional[torch.FloatTensor] = None
+ mmm_text_logits: Optional[torch.FloatTensor] = None
+
+    def to_tuple(self) -> Tuple[Any]:
+        transformer_outputs = [
+            "text_output",
+            "image_output",
+            "multimodal_output",
+            "text_masked_output",
+            "image_masked_output",
+            "multimodal_masked_output",
+        ]
+        return tuple(self[k] if k not in transformer_outputs else getattr(self, k).to_tuple() for k in self.keys())
+
+
+# Inspired by
+# https://github.com/rwightman/pytorch-image-models/blob/b9bd960a032c75ca6b808ddeed76bee5f3ed4972/timm/models/layers/helpers.py
+# From PyTorch internals
+def to_2tuple(x):
+ if isinstance(x, collections.abc.Iterable):
+ return x
+ return (x, x)
+
+
+# Based on timm implementation, which can be found here:
+# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
+
+
+class FLAVAImageEmbeddings(nn.Module):
+ """
+ Construct the CLS token, position and patch embeddings. Optionally, also the mask token.
+
+ """
+
+ def __init__(self, config: FLAVAImageConfig, use_mask_token: bool = False) -> None:
+ super().__init__()
+
+ use_mask_token = use_mask_token or config.mask_token
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
+ self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) if use_mask_token else None
+ self.patch_embeddings = PatchEmbeddings(
+ image_size=config.image_size,
+ patch_size=config.patch_size,
+ num_channels=config.num_channels,
+ embed_dim=config.hidden_size,
+ )
+ num_patches = self.patch_embeddings.num_patches
+ self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size))
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+ self.config = config
+
+ def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
+ """
+ This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
+ resolution images.
+
+ Source:
+        https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
+ """
+
+ npatch = embeddings.shape[1] - 1
+ N = self.position_embeddings.shape[1] - 1
+ if npatch == N and height == width:
+ return self.position_embeddings
+ class_pos_embed = self.position_embeddings[:, 0]
+ patch_pos_embed = self.position_embeddings[:, 1:]
+ dim = embeddings.shape[-1]
+ h0 = height // self.config.patch_size
+ w0 = width // self.config.patch_size
+ # we add a small number to avoid floating point error in the interpolation
+ # see discussion at https://github.com/facebookresearch/dino/issues/8
+ h0, w0 = h0 + 0.1, w0 + 0.1
+ patch_pos_embed = nn.functional.interpolate(
+ patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
+ scale_factor=(h0 / math.sqrt(N), w0 / math.sqrt(N)),
+ mode="bicubic",
+ align_corners=False,
+ )
+ assert int(h0) == patch_pos_embed.shape[-2] and int(w0) == patch_pos_embed.shape[-1]
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+ return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
+
+ def forward(
+ self,
+ pixel_values: torch.Tensor,
+ bool_masked_pos: Optional[torch.BoolTensor] = None,
+ interpolate_pos_encoding: bool = False,
+ ) -> torch.Tensor:
+ batch_size, num_channels, height, width = pixel_values.shape
+ embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
+
+ batch_size, seq_len, _ = embeddings.size()
+ if bool_masked_pos is not None:
+ mask_tokens = self.mask_token.expand(batch_size, seq_len, -1)
+            # B X H X W -> B X HW: flatten the height/width dimensions of the boolean mask
+ if bool_masked_pos.dim() == 3:
+ bool_masked_pos = bool_masked_pos.view(bool_masked_pos.size(0), -1)
+ # replace the masked visual tokens by mask_tokens
+ mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
+ embeddings = embeddings * (1.0 - mask) + mask_tokens * mask
+
+ # add the [CLS] token to the embedded patch tokens
+ cls_tokens = self.cls_token.expand(batch_size, -1, -1)
+ embeddings = torch.cat((cls_tokens, embeddings), dim=1)
+
+ # add positional encoding to each token
+ if interpolate_pos_encoding:
+ embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
+ else:
+ embeddings = embeddings + self.position_embeddings
+
+ embeddings = self.dropout(embeddings)
+
+ return embeddings
+
+
+# Based on timm implementation, which can be found here:
+# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
+class PatchEmbeddings(nn.Module):
+ """
+ Image to Patch Embedding.
+
+ """
+
+ def __init__(
+ self,
+ image_size: int = 224,
+ patch_size: Union[int, Tuple[int, int]] = 16,
+ num_channels: int = 3,
+ embed_dim: int = 768,
+ ):
+ super().__init__()
+ image_size = to_2tuple(image_size)
+ patch_size = to_2tuple(patch_size)
+ num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
+ self.image_size = image_size
+ self.patch_size = patch_size
+ self.num_patches = num_patches
+
+ self.projection = nn.Conv2d(num_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
+
+ def forward(self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
+ batch_size, num_channels, height, width = pixel_values.shape
+ if not interpolate_pos_encoding:
+ if height != self.image_size[0] or width != self.image_size[1]:
+ raise ValueError(
+ f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})."
+ )
+ x = self.projection(pixel_values).flatten(2).transpose(1, 2)
+ return x
+
+
+class FLAVATextEmbeddings(nn.Module):
+ """Construct the embeddings from word, position and token_type embeddings."""
+
+ def __init__(self, config):
+ super().__init__()
+ self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
+ self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
+ self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
+
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
+ # any TensorFlow checkpoint file
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
+ self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
+ if version.parse(torch.__version__) > version.parse("1.6.0"):
+ self.register_buffer(
+ "token_type_ids",
+ torch.zeros(self.position_ids.size(), dtype=torch.long),
+ persistent=False,
+ )
+
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ token_type_ids: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ ):
+ input_shape = input_ids.size()
+ seq_length = input_shape[1]
+
+ if position_ids is None:
+ position_ids = self.position_ids[:, :seq_length]
+
+ # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs
+ # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves
+ # issue #5664
+ if token_type_ids is None:
+ if hasattr(self, "token_type_ids"):
+ buffered_token_type_ids = self.token_type_ids[:, :seq_length]
+ buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
+ token_type_ids = buffered_token_type_ids_expanded
+ else:
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
+
+ inputs_embeds = self.word_embeddings(input_ids)
+ token_type_embeddings = self.token_type_embeddings(token_type_ids)
+
+ embeddings = inputs_embeds + token_type_embeddings
+ if self.position_embedding_type == "absolute":
+ position_embeddings = self.position_embeddings(position_ids)
+ embeddings += position_embeddings
+ embeddings = self.LayerNorm(embeddings)
+ embeddings = self.dropout(embeddings)
+ return embeddings
+
+
+class FLAVASelfAttention(nn.Module):
+ def __init__(self, config: FLAVAPossibleConfigs) -> None:
+ super().__init__()
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
+ raise ValueError(
+                f"The hidden size {config.hidden_size} is not a multiple of the number of attention "
+ f"heads {config.num_attention_heads}."
+ )
+
+ self.num_attention_heads = config.num_attention_heads
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
+
+ self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
+ self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
+ self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
+
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
+
+ def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
+ new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
+ x = x.view(*new_x_shape)
+ return x.permute(0, 2, 1, 3)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
+ mixed_query_layer = self.query(hidden_states)
+
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
+ query_layer = self.transpose_for_scores(mixed_query_layer)
+
+ # Take the dot product between "query" and "key" to get the raw attention scores.
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
+
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
+ if attention_mask is not None:
+ # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
+ attention_scores = attention_scores + attention_mask
+
+ # Normalize the attention scores to probabilities.
+ attention_probs = nn.functional.softmax(attention_scores, dim=-1)
+
+ # This is actually dropping out entire tokens to attend to, which might
+ # seem a bit unusual, but is taken from the original Transformer paper.
+ attention_probs = self.dropout(attention_probs)
+
+ # Mask heads if we want to
+ if head_mask is not None:
+ attention_probs = attention_probs * head_mask
+
+ context_layer = torch.matmul(attention_probs, value_layer)
+
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+ context_layer = context_layer.view(*new_context_layer_shape)
+
+ outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
+
+ return outputs
+
+
+class FLAVASelfOutput(nn.Module):
+ """
+    The residual connection is defined in FLAVALayer instead of here (as is the case with other models), due to the
+ layernorm applied before each block.
+ """
+
+ def __init__(self, config: FLAVAPossibleConfigs) -> None:
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+
+ return hidden_states
+
+
+class FLAVAAttention(nn.Module):
+ def __init__(self, config: FLAVAPossibleConfigs) -> None:
+ super().__init__()
+ self.attention = FLAVASelfAttention(config)
+ self.output = FLAVASelfOutput(config)
+ self.pruned_heads = set()
+
+ def prune_heads(self, heads: Set[int]) -> None:
+ if len(heads) == 0:
+ return
+ heads, index = find_pruneable_heads_and_indices(
+ heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads
+ )
+
+ # Prune linear layers
+ self.attention.query = prune_linear_layer(self.attention.query, index)
+ self.attention.key = prune_linear_layer(self.attention.key, index)
+ self.attention.value = prune_linear_layer(self.attention.value, index)
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
+
+ # Update hyper params and store pruned heads
+ self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads)
+ self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
+ self.pruned_heads = self.pruned_heads.union(heads)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
+ self_outputs = self.attention(
+ hidden_states, attention_mask=attention_mask, head_mask=head_mask, output_attentions=output_attentions
+ )
+
+ attention_output = self.output(self_outputs[0], hidden_states)
+
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
+ return outputs
+
+
+class FLAVAIntermediate(nn.Module):
+ def __init__(self, config: FLAVAPossibleConfigs) -> None:
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
+ if isinstance(config.hidden_act, str):
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
+ else:
+ self.intermediate_act_fn = config.hidden_act
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.intermediate_act_fn(hidden_states)
+
+ return hidden_states
+
+
+class FLAVAOutput(nn.Module):
+ def __init__(self, config: FLAVAPossibleConfigs) -> None:
+ super().__init__()
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+
+ hidden_states = hidden_states + input_tensor
+
+ return hidden_states
+
+
+class FLAVALayer(nn.Module):
+ """This corresponds to the Block class in the timm implementation."""
+
+ def __init__(self, config: FLAVAPossibleConfigs) -> None:
+ super().__init__()
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
+ self.seq_len_dim = 1
+ self.attention = FLAVAAttention(config)
+ self.intermediate = FLAVAIntermediate(config)
+ self.output = FLAVAOutput(config)
+
+        # TODO: Check fp32 layer norm possibility
+ self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
+ self_attention_outputs = self.attention(
+ self.layernorm_before(hidden_states), # in ViT, layernorm is applied before self-attention
+ attention_mask=attention_mask,
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ )
+ attention_output = self_attention_outputs[0]
+ outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
+
+ # first residual connection
+ hidden_states = attention_output + hidden_states
+
+ # in ViT, layernorm is also applied after self-attention
+ layer_output = self.layernorm_after(hidden_states)
+ layer_output = self.intermediate(layer_output)
+
+ # second residual connection is done here
+ layer_output = self.output(layer_output, hidden_states)
+
+ outputs = (layer_output,) + outputs
+
+ return outputs
+
+
+class FLAVAEncoder(nn.Module):
+ def __init__(self, config: FLAVAConfig) -> None:
+ super().__init__()
+ self.config = config
+ self.layer = nn.ModuleList([FLAVALayer(config) for _ in range(config.num_hidden_layers)])
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ output_hidden_states: bool = False,
+ return_dict: bool = True,
+ ) -> Union[tuple, BaseModelOutput]:
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attentions = () if output_attentions else None
+
+ for i, layer_module in enumerate(self.layer):
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ layer_head_mask = head_mask[i] if head_mask is not None else None
+
+ if self.gradient_checkpointing and self.training:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs, output_attentions)
+
+ return custom_forward
+
+ layer_outputs = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(layer_module),
+ hidden_states,
+ attention_mask,
+ layer_head_mask,
+ )
+ else:
+ layer_outputs = layer_module(hidden_states, attention_mask, layer_head_mask, output_attentions)
+
+ hidden_states = layer_outputs[0]
+
+ if output_attentions:
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
+ return BaseModelOutput(
+ last_hidden_state=hidden_states,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attentions,
+ )
+
+
+class FLAVAPooler(nn.Module):
+ def __init__(self, config: FLAVAPossibleConfigs):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ self.activation = nn.Tanh()
+
+ def forward(self, hidden_states: torch.Tensor):
+ # We "pool" the model by simply taking the hidden state corresponding
+ # to the first token.
+ first_token_tensor = hidden_states[:, 0]
+ pooled_output = self.dense(first_token_tensor)
+ pooled_output = self.activation(pooled_output)
+ return pooled_output
+
+
+FLAVA_START_DOCSTRING = r"""
+ This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
+ as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
+ behavior.
+
+ Parameters:
+ config ([`{config}`]): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+FLAVA_INPUTS_DOCSTRING_COMMON = r"""
+ attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+ [What are attention masks?](../glossary#attention-mask)
+
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+FLAVA_IMAGE_INPUTS_DOCSTRING_BASE = r"""
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+ Pixel values. Pixel values can be obtained using [`FLAVAFeatureExtractor`]. See
+ [`FLAVAFeatureExtractor.__call__`] for details.
+
+ bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, image_num_patches)`):
+ Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
+
+ interpolate_pos_encoding (`bool`, *optional*):
+ Whether to interpolate the pre-trained position encodings.
+"""
+
+FLAVA_IMAGE_INPUTS_DOCSTRING = (
+ r"""
+ Args:
+"""
+ + FLAVA_IMAGE_INPUTS_DOCSTRING_BASE
+ + FLAVA_INPUTS_DOCSTRING_COMMON
+)
+
+FLAVA_TEXT_INPUTS_DOCSTRING_BASE = r"""
+ input_ids (`torch.LongTensor` of shape `({0})`):
+ Indices of input sequence tokens in the vocabulary. Indices can be obtained using [`BertTokenizer`]. See
+ [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are input
+ IDs?](../glossary#input-ids)
+
+ token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
+ Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
+ 1]`:
+ - 0 corresponds to a *sentence A* token,
+ - 1 corresponds to a *sentence B* token.
+ [What are token type IDs?](../glossary#token-type-ids)
+"""
+FLAVA_TEXT_INPUTS_DOCSTRING = (
+ r"""
+ Args:
+"""
+ + FLAVA_TEXT_INPUTS_DOCSTRING_BASE
+ + FLAVA_INPUTS_DOCSTRING_COMMON
+)
+
+FLAVA_MULTIMODAL_INPUTS_DOCSTRING = (
+ r"""
+ Args:
+ hidden_states (`torch.FloatTensor` of shape `(batch_size, image_num_patches + text_seq_len, hidden_size)`):
+ The concatenated hidden states of unimodal encoders.
+"""
+ + FLAVA_INPUTS_DOCSTRING_COMMON
+)
+
+FLAVA_MODEL_INPUTS_DOCSTRING = (
+ r"""
+ Args:
+"""
+ + FLAVA_IMAGE_INPUTS_DOCSTRING_BASE
+ + FLAVA_TEXT_INPUTS_DOCSTRING_BASE
+ + FLAVA_INPUTS_DOCSTRING_COMMON
+ + r"""
+ skip_multimodal_encoder (*bool*, *optional*):
+ Skip any calculations for multimodal encoder. Useful if multimodal encoding is not going to be used.
+"""
+)
+
+
+FLAVA_PRETRAINING_INPUTS_DOCSTRING = (
+ r"""
+ Args:
+ input_ids_masked (`torch.LongTensor` of shape `({0})`):
+            Indices of input sequence tokens in the vocabulary. These are the masked versions of the original text,
+            to be used with MLM. Indices can be obtained using [`BertTokenizer`] along with
+            [`DataCollatorForLanguageModeling`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details. [What are input IDs?](../glossary#input-ids)
+
+"""
+ + FLAVA_TEXT_INPUTS_DOCSTRING_BASE
+ + FLAVA_IMAGE_INPUTS_DOCSTRING_BASE
+ + r"""
+ image_attention_mask (`torch.FloatTensor` of shape `({1})`, *optional*):
+ Mask to avoid performing attention on padding token indices specifically for images. Mask values selected
+ in `[0, 1]`:
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+ [What are attention masks?](../glossary#attention-mask)
+
+ skip_unmasked_multimodal_encoder (*bool*, *optional*):
+ Skip any calculations for multimodal encoder for unmasked inputs. FLAVA pretraining doesn't need unmasked
+ multimodal embeddings or outputs as of now.
+
+ mlm_labels (`torch.LongTensor` of shape `(batch_size, text_seq_len)`, *optional*):
+            Labels for computing the masked language and multimodal masked modeling loss. Indices should be in
+            `[-100, 0, ..., text_config.vocab_size]` (see `input_ids` docstring). Tokens with indices set to `-100`
+            are ignored (masked), the loss is only computed for the tokens with labels in
+            `[0, ..., text_config.vocab_size]`
+
+ mim_labels (`torch.LongTensor` of shape `(batch_size, image_num_patches)`, *optional*):
+ Labels for computing the image and multimodal masked modeling loss. Indices should be in `[-100, 0,
+ ..., image_config.vocab_size]`. Tokens with indices set to `-100` are ignored (masked), the loss is
+                only computed for the tokens with labels in `[0, ..., image_config.vocab_size]`. See [`FLAVACodebook`]
+ to understand how to generate mim_labels.
+
+ itm_labels (`torch.LongTensor` of shape `(batch_size, 1)`, *optional*):
+ Labels for computing the image-text matching loss. 0 means the pairs don't match and 1 means they
+ match. The pairs with 0 will be skipped for calculation of MMM and global contrastive losses as well.
+"""
+ + FLAVA_INPUTS_DOCSTRING_COMMON
+)
+
+
+class FLAVAPreTrainedModel(PreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = FLAVAConfig
+ base_model_prefix = "flava"
+ supports_gradient_checkpointing = True
+
+ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
+ """Initialize the weights"""
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
+ # Slightly different from the TF version which uses truncated_normal for initialization
+ # cf https://github.com/pytorch/pytorch/pull/5617
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+
+ def _set_gradient_checkpointing(self, module: FLAVAEncoder, value: bool = False) -> None:
+ if isinstance(module, FLAVAEncoder):
+ module.gradient_checkpointing = value
+
+
+@add_start_docstrings(
+ "The bare FLAVA Image Model transformer outputting raw hidden-states without any specific head on top.",
+ FLAVA_START_DOCSTRING.format(config="FLAVAImageConfig"),
+)
+class FLAVAImageModel(FLAVAPreTrainedModel):
+ config_class = FLAVAImageConfig
+ base_model_prefix = "flava.image_model"
+ main_input_name = "pixel_values"
+
+ def __init__(self, config: FLAVAImageConfig, add_pooling_layer: bool = True):
+ super().__init__(config)
+
+ self.config = config
+
+ self.embeddings = FLAVAImageEmbeddings(config)
+ self.encoder = FLAVAEncoder(config)
+
+ self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.pooler = FLAVAPooler(config) if add_pooling_layer else None
+
+ self.post_init()
+
+ def get_input_embeddings(self) -> nn.Module:
+ return self.embeddings.patch_embeddings
+
+ def set_input_embeddings(self, value: nn.Module):
+ self.embeddings.patch_embeddings = value
+
+ def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None:
+ """
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
+ class PreTrainedModel
+ """
+ for layer, heads in heads_to_prune.items():
+ self.encoder.layer[layer].attention.prune_heads(heads)
+
+ @add_start_docstrings_to_model_forward(FLAVA_IMAGE_INPUTS_DOCSTRING.format("batch_size, image_num_patches"))
+ @add_code_sample_docstrings(
+ processor_class=_FEAT_EXTRACTOR_FOR_DOC,
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=BaseModelOutputWithPooling,
+ config_class=_CONFIG_CLASS_FOR_IMAGE_MODEL_DOC,
+ modality="vision",
+ expected_output=_EXPECTED_IMAGE_OUTPUT_SHAPE,
+ )
+ def forward(
+ self,
+ pixel_values: Optional[torch.Tensor] = None,
+ bool_masked_pos: Optional[torch.BoolTensor] = None,
+ interpolate_pos_encoding: Optional[bool] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ):
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if pixel_values is None:
+ raise ValueError("You have to specify pixel_values")
+
+ # Prepare head mask if needed
+ # 1.0 in head_mask indicate we keep the head
+ # attention_probs has shape bsz x n_heads x N x N
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
+
+ embedding_output = self.embeddings(
+ pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding
+ )
+
+ encoder_outputs = self.encoder(
+ embedding_output,
+ attention_mask=attention_mask,
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ sequence_output = encoder_outputs[0]
+ sequence_output = self.layernorm(sequence_output)
+ pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
+
+ if not return_dict:
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
+
+ return BaseModelOutputWithPooling(
+ last_hidden_state=sequence_output,
+ pooler_output=pooled_output,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ )
+
+
+@add_start_docstrings(
+ "The bare FLAVA Text Model transformer outputting raw hidden-states without any specific head on top.",
+ FLAVA_START_DOCSTRING.format(config="FLAVATextConfig"),
+)
+class FLAVATextModel(FLAVAPreTrainedModel):
+ config_class = FLAVATextConfig
+ base_model_prefix = "flava.text_model"
+
+ def __init__(self, config: FLAVATextConfig, add_pooling_layer: bool = True):
+ super().__init__(config)
+ self.config = config
+
+ self.embeddings = FLAVATextEmbeddings(config)
+ self.encoder = FLAVAEncoder(config)
+
+ self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.pooler = FLAVAPooler(config) if add_pooling_layer else None
+
+ self.post_init()
+
+    def get_input_embeddings(self) -> nn.Module:
+ return self.embeddings.word_embeddings
+
+ def set_input_embeddings(self, value: nn.Module):
+ self.embeddings.word_embeddings = value
+
+ def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None:
+ """
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
+ class PreTrainedModel
+ """
+ for layer, heads in heads_to_prune.items():
+ self.encoder.layer[layer].attention.prune_heads(heads)
+
+ @add_start_docstrings_to_model_forward(FLAVA_TEXT_INPUTS_DOCSTRING.format("batch_size, text_seq_length"))
+ @add_code_sample_docstrings(
+ processor_class=_TOKENIZER_FOR_DOC,
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=BaseModelOutputWithPooling,
+ config_class=_CONFIG_CLASS_FOR_TEXT_MODEL_DOC,
+ )
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ token_type_ids: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ):
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if input_ids is None:
+ raise ValueError("You have to specify input_ids")
+
+ input_shape = input_ids.size()
+
+ if attention_mask is None:
+ attention_mask = torch.ones(input_shape, device=input_ids.device)
+
+ # Prepare head mask if needed
+ # 1.0 in head_mask indicate we keep the head
+ # attention_probs has shape bsz x n_heads x N x N
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(
+ attention_mask, input_shape, input_ids.device
+ )
+
+ embedding_output = self.embeddings(
+ input_ids=input_ids,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ )
+
+ encoder_outputs = self.encoder(
+ embedding_output,
+ attention_mask=extended_attention_mask,
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ sequence_output = encoder_outputs[0]
+ sequence_output = self.layernorm(sequence_output)
+ pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
+
+ if not return_dict:
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
+
+ return BaseModelOutputWithPooling(
+ last_hidden_state=sequence_output,
+ pooler_output=pooled_output,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ )
+
+
+@add_start_docstrings(
+ "The bare FLAVA Multimodal Model transformer outputting raw hidden-states without any specific head on top.",
+ FLAVA_START_DOCSTRING.format(config="FLAVAMultimodalConfig"),
+)
+class FLAVAMultimodalModel(FLAVAPreTrainedModel):
+ config_class = FLAVAMultimodalConfig
+ base_model_prefix = "flava.multimodal_model"
+ main_input_name = "hidden_states"
+
+ def __init__(self, config: FLAVAMultimodalConfig, add_pooling_layer=True):
+ super().__init__(config)
+ self.config = config
+ self.use_cls_token = self.config.use_cls_token
+ if self.use_cls_token:
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
+
+ self.encoder = FLAVAEncoder(config)
+
+ self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.pooler = FLAVAPooler(config) if add_pooling_layer else None
+
+ self.post_init()
+
+ def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None:
+ """
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
+ class PreTrainedModel
+ """
+ for layer, heads in heads_to_prune.items():
+ self.encoder.layer[layer].attention.prune_heads(heads)
+
+ @add_start_docstrings_to_model_forward(
+ FLAVA_MULTIMODAL_INPUTS_DOCSTRING.format("batch_size, image_num_patches + text_seq_len")
+ )
+ @add_code_sample_docstrings(
+ processor_class=_TOKENIZER_FOR_DOC,
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=BaseModelOutputWithPooling,
+ config_class=_CONFIG_CLASS_FOR_MULTIMODAL_MODEL_DOC,
+ )
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ):
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ batch_size, seq_length, _ = hidden_states.size()
+
+ if self.use_cls_token:
+ cls_tokens = self.cls_token.expand(batch_size, -1, -1)
+ hidden_states = torch.cat((cls_tokens, hidden_states), dim=1)
+ seq_length += 1
+
+ if attention_mask is None:
+ attention_mask = torch.ones((batch_size, seq_length), device=hidden_states.device)
+
+ # Prepare head mask if needed
+ # 1.0 in head_mask indicate we keep the head
+ # attention_probs has shape bsz x n_heads x N x N
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(
+ attention_mask, (batch_size, seq_length), hidden_states.device
+ )
+
+ encoder_outputs = self.encoder(
+ hidden_states,
+ attention_mask=extended_attention_mask,
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ sequence_output = encoder_outputs[0]
+ sequence_output = self.layernorm(sequence_output)
+ pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
+
+ if not return_dict:
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
+
+ return BaseModelOutputWithPooling(
+ last_hidden_state=sequence_output,
+ pooler_output=pooled_output,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ )
+
+
+@add_start_docstrings(
+ "The bare FLAVA Model transformer outputting raw hidden-states without any specific head on top.",
+ FLAVA_START_DOCSTRING.format(config="FLAVAConfig"),
+)
+class FLAVAModel(FLAVAPreTrainedModel):
+ config_class = FLAVAConfig
+
+ def __init__(self, config: FLAVAConfig):
+ super().__init__(config)
+
+ if not isinstance(config.text_config, FLAVATextConfig):
+ raise ValueError(
+ f"config.text_config is expected to be of type FLAVATextConfig but is of type {type(config.text_config)}."
+ )
+
+ if not isinstance(config.image_config, FLAVAImageConfig):
+ raise ValueError(
+ f"config.image_config is expected to be of type FLAVAImageConfig but is of type {type(config.image_config)}."
+ )
+
+ if not isinstance(config.multimodal_config, FLAVAMultimodalConfig):
+ raise ValueError(
+ "config.multimodal_config is expected to be of type FLAVAMultimodalConfig but "
+ + f"is of type {type(config.multimodal_config)}."
+ )
+
+ text_config = config.text_config
+ image_config = config.image_config
+ multimodal_config = config.multimodal_config
+
+ self.projection_dim = config.projection_dim
+ self.text_hidden_size = text_config.hidden_size
+ self.image_hidden_size = image_config.hidden_size
+ self.mm_hidden_size = multimodal_config.hidden_size
+
+ self.text_model = FLAVATextModel(text_config)
+ self.image_model = FLAVAImageModel(image_config)
+ self.multimodal_model = FLAVAMultimodalModel(multimodal_config)
+
+ self.image_projection = nn.Linear(self.image_hidden_size, self.projection_dim)
+ self.text_projection = nn.Linear(self.text_hidden_size, self.projection_dim)
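+        # learnable temperature (logit scale) for the image-text contrastive logits, initialized from the config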
+ self.logit_scale = nn.Parameter(torch.ones([]) * self.config.logit_scale_init_value)
+
+ self.image_to_mm_projection = nn.Linear(self.image_hidden_size, self.mm_hidden_size)
+ self.text_to_mm_projection = nn.Linear(self.text_hidden_size, self.mm_hidden_size)
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @add_start_docstrings_to_model_forward(FLAVA_TEXT_INPUTS_DOCSTRING.format("batch_size, text_seq_length"))
+ def get_text_features(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ token_type_ids: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> torch.FloatTensor:
+ r"""
+ Returns:
+            text_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The text embeddings obtained by
+ applying the projection layer to the pooled output of [`FLAVATextModel`].
+
+ Examples:
+
+ ```python
+ >>> from transformers import FLAVAProcessor, FLAVAModel
+
+        >>> model = FLAVAModel.from_pretrained("aps/flava-full")
+        >>> processor = FLAVAProcessor.from_pretrained("aps/flava-full")
+
+ >>> inputs = processor(
+ ... text=["a photo of a cat", "a photo of a dog"], max_length=77, padding="max_length", return_tensors="pt"
+ ... )
+ >>> text_features = model.get_text_features(**inputs)
+        ```"""
+ text_outputs = self.text_model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ pooled_output = text_outputs[0] # last_hidden_state
+ text_features = self.text_projection(pooled_output)
+
+ return text_features
+
+ @add_start_docstrings_to_model_forward(FLAVA_IMAGE_INPUTS_DOCSTRING.format("batch_size, image_num_patches"))
+ def get_image_features(
+ self,
+ pixel_values: Optional[torch.Tensor] = None,
+ bool_masked_pos: Optional[torch.BoolTensor] = None,
+ interpolate_pos_encoding: Optional[bool] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> torch.FloatTensor:
+ r"""
+ Returns:
+            image_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The image embeddings obtained by
+ applying the projection layer to the pooled output of [`FLAVAImageModel`].
+
+ Examples:
+
+ ```python
+ >>> from PIL import Image
+ >>> import requests
+ >>> from transformers import FLAVAProcessor, FLAVAModel
+
+        >>> model = FLAVAModel.from_pretrained("aps/flava-full")
+        >>> processor = FLAVAProcessor.from_pretrained("aps/flava-full")
+
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> inputs = processor(images=image, return_tensors="pt")
+
+ >>> image_features = model.get_image_features(**inputs)
+        ```"""
+ image_outputs = self.image_model(
+ pixel_values=pixel_values,
+ bool_masked_pos=bool_masked_pos,
+ attention_mask=attention_mask,
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ interpolate_pos_encoding=interpolate_pos_encoding,
+ return_dict=return_dict,
+ )
+
+ pooled_output = image_outputs[0] # last_hidden_state
+ image_features = self.image_projection(pooled_output)
+
+ return image_features
+
+ @add_start_docstrings_to_model_forward(
+ FLAVA_MODEL_INPUTS_DOCSTRING.format("batch_size, image_num_patches + text_seq_len")
+ )
+ @replace_return_docstrings(output_type=FLAVAModelOutput, config_class=FLAVAConfig)
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ token_type_ids: Optional[torch.Tensor] = None,
+ bool_masked_pos: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ image_attention_mask: Optional[torch.Tensor] = None,
+ skip_multimodal_encoder: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: bool = True,
+ return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, FLAVAModelOutput]:
+ r"""
+ Returns:
+
+ Examples:
+
+ ```python
+ >>> from PIL import Image
+ >>> import requests
+ >>> from transformers import FLAVAProcessor, FLAVAModel
+
+ >>> model = FLAVAModel.from_pretrained("aps/flava-full")
+ >>> processor = FLAVAProcessor.from_pretrained("aps/flava-full")
+
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> inputs = processor(text=["a photo of a cat"], images=image, return_tensors="pt", padding=True)
+
+ >>> outputs = model(**inputs)
+ >>> logits_per_image = outputs.contrastive_logits_per_image # this is the image-text similarity score
+ >>> probs = logits_per_image.softmax(dim=1) # we can take the softmax to get the label probabilities
+ ```
+ """
+
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        if not output_hidden_states:
+            raise ValueError("FLAVA model requires hidden states to work. Please set `output_hidden_states=True`.")
+ image_embeddings = None
+ image_states = None
+ image_mm_projection = None
+ image_output = None
+ if pixel_values is not None:
+ image_output = self.image_model(
+ pixel_values=pixel_values,
+ bool_masked_pos=bool_masked_pos,
+ attention_mask=image_attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ image_embeddings, image_states = image_output[0], image_output[2]
+ # Note that these states don't use final layernorm in the transformer model
+ image_mm_projection = self.image_to_mm_projection(image_states[-1])
+
+ text_embeddings = None
+ text_states = None
+ text_mm_projection = None
+ text_output = None
+ if input_ids is not None:
+ text_output = self.text_model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ token_type_ids=token_type_ids,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ text_embeddings, text_states = text_output[0], text_output[2]
+ # Note that these states don't use final layernorm in the transformer model
+ text_mm_projection = self.text_to_mm_projection(text_states[-1])
+
+ multimodal_embeddings = None
+ multimodal_output = None
+ if image_mm_projection is not None and text_mm_projection is not None and not skip_multimodal_encoder:
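+            # Image projections come first, then text projections; `FLAVAForPretraining` relies on this ordering
+            # when slicing the multimodal sequence back into its image and text parts for the MMM losses.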
+ multimodal_input = torch.cat([image_mm_projection, text_mm_projection], dim=1)
+ multimodal_output = self.multimodal_model(multimodal_input)
+ multimodal_embeddings = multimodal_output[0]
+
+ return FLAVAModelOutput(
+ image_embeddings=image_embeddings,
+ image_output=image_output,
+ text_embeddings=text_embeddings,
+ text_output=text_output,
+ multimodal_embeddings=multimodal_embeddings,
+ multimodal_output=multimodal_output,
+ )
+
+
+class FLAVACodebookBlock(nn.Module):
+ def __init__(self, in_size: int, out_size: int, num_layers: int, **kwargs):
+ super().__init__()
+
+ n_hid = out_size // 4
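+        # Scale the residual branch by 1 / num_layers**2, as in the original DALL-E encoder, so that activations
+        # stay well-behaved as more blocks are stacked.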
+ self.post_gain = 1 / (num_layers**2)
+
+ if in_size != out_size:
+ self.id_path = _build_codebook_conv2d(in_size, out_size, kernel_size=1)
+ else:
+ self.id_path = nn.Identity()
+
+ self.res_path = nn.Sequential(
+ OrderedDict(
+ [
+ ("relu_1", nn.ReLU()),
+ ("conv_1", _build_codebook_conv2d(in_size, n_hid, kernel_size=3)),
+ ("relu_2", nn.ReLU()),
+ ("conv_2", _build_codebook_conv2d(n_hid, n_hid, kernel_size=3)),
+ ("relu_3", nn.ReLU()),
+ ("conv_3", _build_codebook_conv2d(n_hid, n_hid, kernel_size=3)),
+ ("relu_4", nn.ReLU()),
+ ("conv_4", _build_codebook_conv2d(n_hid, out_size, kernel_size=1)),
+ ]
+ )
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return self.id_path(x) + self.post_gain * self.res_path(x)
+
+
+# Inspired by DALLE Encoder in https://github.com/openai/DALL-E/blob/5be4b236bc3ade6943662354117a0e83752cc322/dall_e/encoder.py#L42
+@add_start_docstrings(
+ """
+    The FLAVA codebook model, inspired by DALL-E's original encoder. It outputs raw hidden states and can be used to
+    generate image tokens for an image based on DALL-E's vocabulary. It is used to generate labels for MIM. Use
+    `get_codebook_indices` to get the image tokens for an image.
+ """,
+ FLAVA_START_DOCSTRING.format(config="FLAVACodebookConfig"),
+)
+class FLAVACodebook(FLAVAPreTrainedModel):
+ base_model_prefix = ""
+ config_class = FLAVACodebookConfig
+ main_input_name = "pixel_values"
+
+ def __init__(
+ self,
+ config: FLAVACodebookConfig,
+ **kwargs: Any,
+ ):
+ super().__init__(config)
+
+ self.config = config
+ self.num_groups = config.num_groups
+ self.input_channels = config.input_channels
+ self.num_blocks_per_group = config.num_blocks_per_group
+ self.hidden_size = config.hidden_size
+ self.vocab_size = config.vocab_size
+
+ num_layers = self.num_groups * self.num_blocks_per_group
+ output_conv = _build_codebook_conv2d(8 * self.hidden_size, self.vocab_size, kernel_size=1)
+
+ self.blocks = nn.Sequential(
+ OrderedDict(
+ [
+ ("input", _build_codebook_conv2d(self.input_channels, 1 * self.hidden_size, kernel_size=7)),
+ (
+ "group_1",
+ self._create_group(
+ num_layers, self.num_blocks_per_group, 1 * self.hidden_size, 1 * self.hidden_size
+ ),
+ ),
+ (
+ "group_2",
+ self._create_group(
+ num_layers, self.num_blocks_per_group, 1 * self.hidden_size, 2 * self.hidden_size
+ ),
+ ),
+ (
+ "group_3",
+ self._create_group(
+ num_layers, self.num_blocks_per_group, 2 * self.hidden_size, 4 * self.hidden_size
+ ),
+ ),
+ (
+ "group_4",
+ self._create_group(
+ num_layers,
+ self.num_blocks_per_group,
+ 4 * self.hidden_size,
+ 8 * self.hidden_size,
+ use_pool=False,
+ ),
+ ),
+ (
+ "output",
+ nn.Sequential(OrderedDict([("relu", nn.ReLU()), ("conv", output_conv)])),
+ ),
+ ]
+ )
+ )
+ self.post_init()
+
+ if self.config.freeze:
+ self._freeze()
+
+ def _freeze(self):
+ for param in self.parameters():
+ param.requires_grad = False
+
+ def _create_group(
+ self,
+ num_layers: int,
+ num_blocks_per_group: int,
+ in_size: int,
+ hidden_size: int,
+ use_pool: bool = True,
+ ):
+ blocks = OrderedDict()
+ for i in range(num_blocks_per_group):
+ if i == 0:
+ blocks[f"block_{i+1}"] = FLAVACodebookBlock(in_size, hidden_size, num_layers)
+ else:
+ blocks[f"block_{i+1}"] = FLAVACodebookBlock(hidden_size, hidden_size, num_layers)
+
+ if use_pool:
+ blocks["pool"] = nn.MaxPool2d(kernel_size=2)
+
+ return nn.Sequential(blocks)
+
+ def get_codebook_indices(self, pixel_values: torch.Tensor) -> torch.Tensor:
+ """
+ Args:
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+ Pixel values. Pixel values can be obtained using [`FLAVACodebookFeatureExtractor`]. See
+ [`FLAVACodebookFeatureExtractor.__call__`] for details.
+
+ Examples:
+ ```python
+ >>> from PIL import Image
+ >>> import requests
+ >>> from transformers import FLAVACodebookFeatureExtractor, FLAVACodebook
+
+        >>> model = FLAVACodebook.from_pretrained("{0}")
+        >>> feature_extractor = FLAVACodebookFeatureExtractor.from_pretrained("{0}")
+
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+        >>> inputs = feature_extractor([image], return_tensors="pt")
+
+ >>> outputs = model.get_codebook_indices(**inputs)
+ ```
+ """.format(
+ _CHECKPOINT_FOR_CODEBOOK_DOC
+ )
+ z_logits = self.blocks(pixel_values)
+        return torch.argmax(z_logits, dim=1)
+
+ def get_codebook_probs(self, pixel_values: torch.Tensor) -> torch.Tensor:
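+        """
+        Counterpart of [`~FLAVACodebook.get_codebook_indices`]: returns the softmax probabilities over the codebook
+        vocabulary for each spatial position instead of the argmax indices.
+
+        Args:
+            pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+                Pixel values. Pixel values can be obtained using [`FLAVACodebookFeatureExtractor`]. See
+                [`FLAVACodebookFeatureExtractor.__call__`] for details.
+        """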
+ z_logits = self.blocks(pixel_values)
+ return nn.Softmax(dim=1)(z_logits)
+
+ def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
+ """
+ Args:
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+ Pixel values. Pixel values can be obtained using [`FLAVACodebookFeatureExtractor`]. See
+ [`FLAVACodebookFeatureExtractor.__call__`] for details.
+
+ Examples:
+
+ ```python
+ >>> from PIL import Image
+ >>> import requests
+ >>> from transformers import FLAVACodebookFeatureExtractor, FLAVACodebook
+
+        >>> model = FLAVACodebook.from_pretrained("{0}")
+        >>> feature_extractor = FLAVACodebookFeatureExtractor.from_pretrained("{0}")
+
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> inputs = feature_extractor([image], return_tensors="pt")
+
+ >>> outputs = model(**inputs)
+ >>> print(outputs.shape)
+        torch.Size([1, 8192, 14, 14])
+ ```
+ """.format(
+ _CHECKPOINT_FOR_CODEBOOK_DOC
+ )
+        if len(pixel_values.shape) != 4:
+            raise ValueError(
+                f"Input shape {pixel_values.shape} is not 4D; expected `(batch_size, num_channels, height, width)`."
+            )
+        if pixel_values.shape[1] != self.input_channels:
+            raise ValueError(
+                f"Input has {pixel_values.shape[1]} channels but the model was built for {self.input_channels}."
+            )
+ return self.blocks(pixel_values)
+
+
+class FLAVAPredictionHeadTransform(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ if isinstance(config.hidden_act, str):
+ self.transform_act_fn = ACT2FN[config.hidden_act]
+ else:
+ self.transform_act_fn = config.hidden_act
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+ def forward(self, hidden_states):
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.transform_act_fn(hidden_states)
+ hidden_states = self.LayerNorm(hidden_states)
+ return hidden_states
+
+
+class FLAVAMaskedPredictionHead(nn.Module):
+ def __init__(self, config, weight=None):
+ super().__init__()
+ self.config = config
+ self.transform = FLAVAPredictionHeadTransform(config)
+ self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+ self.bias = nn.Parameter(torch.zeros(config.vocab_size))
+ if weight is not None:
+ self.decoder.weight = weight
+
+ # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
+ self.decoder.bias = self.bias
+
+ def forward(self, x):
+ x = self.transform(x)
+ x = self.decoder(x)
+ return x
+
+
+class FLAVAITMHead(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.pooler = FLAVAPooler(config)
+ self.seq_relationship = nn.Linear(config.hidden_size, 2)
+
+ def forward(self, x):
+ x = self.pooler(x)
+ x = self.seq_relationship(x)
+ return x
+
+
+class FLAVAGlobalContrastiveHead(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.global_backprop_contrastive = config.global_backprop_contrastive
+
+ def forward(self, image_embeddings, text_embeddings, logit_scale):
+ temperature = torch.exp(logit_scale)
+ if not torch.distributed.is_available() or not torch.distributed.is_initialized():
+ labels = torch.arange(image_embeddings.size(0), device=image_embeddings.device)
+ image_embeddings_all = [image_embeddings]
+ text_embeddings_all = [text_embeddings]
+ else:
+ local_batch_size = image_embeddings.size(0)
+ world_size = torch.distributed.get_world_size()
+
+ if self.global_backprop_contrastive:
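+                # `torch.distributed.nn.functional.all_gather` is differentiable, so gradients flow back to the
+                # embeddings gathered from every worker, unlike `torch.distributed.all_gather` used below.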
+                image_embeddings_all = torch.distributed.nn.functional.all_gather(image_embeddings)
+                text_embeddings_all = torch.distributed.nn.functional.all_gather(text_embeddings)
+ else:
+                image_embeddings_all = [torch.zeros_like(image_embeddings) for _ in range(world_size)]
+                text_embeddings_all = [torch.zeros_like(text_embeddings) for _ in range(world_size)]
+ torch.distributed.all_gather(image_embeddings_all, image_embeddings)
+ torch.distributed.all_gather(text_embeddings_all, text_embeddings)
+
+ labels = local_batch_size * torch.distributed.get_rank() + torch.arange(
+ local_batch_size, device=image_embeddings.device
+ )
+
+ image_embeddings_all = torch.cat(image_embeddings_all)
+ text_embeddings_all = torch.cat(text_embeddings_all)
+
+ logits_per_image = torch.matmul(image_embeddings, text_embeddings_all.transpose(0, 1)) * temperature
+ logits_per_text = torch.matmul(text_embeddings, image_embeddings_all.transpose(0, 1)) * temperature
+
+ return logits_per_image, logits_per_text, labels
+
+
+@add_start_docstrings(
+ """
+    The FLAVA model for pretraining, which outputs losses, embeddings, logits and transformer outputs.
+ """,
+ FLAVA_START_DOCSTRING.format(config="FLAVAConfig"),
+)
+class FLAVAForPretraining(FLAVAPreTrainedModel):
+ def __init__(self, config: FLAVAConfig):
+ super().__init__(config)
+ self.flava = FLAVAModel(config)
+
+        # Leverage the text and image encoder configs to create the masked
+        # prediction heads since they carry the right vocab sizes
+ self.mim_head = FLAVAMaskedPredictionHead(config.image_config)
+ self.mlm_head = FLAVAMaskedPredictionHead(config.text_config)
+ self.itm_head = FLAVAITMHead(config)
+ self.mmm_image_head = FLAVAMaskedPredictionHead(config.image_config)
+ self.mmm_text_head = FLAVAMaskedPredictionHead(config.text_config)
+ self.global_contrastive_head = FLAVAGlobalContrastiveHead(config)
+
+ self.image_vocab_size = config.image_config.vocab_size
+ self.text_vocab_size = config.text_config.vocab_size
+ self.mlm_weight = config.mlm_weight
+ self.mim_weight = config.mim_weight
+ self.global_contrastive_weight = config.global_contrastive_weight
+ self.ce_ignore_index = config.ce_ignore_index
+ self.itm_weight = config.itm_weight
+ self.mmm_image_weight = config.mmm_image_weight
+ self.mmm_text_weight = config.mmm_text_weight
+ self.skip_unmasked_multimodal_encoder = config.skip_unmasked_multimodal_encoder
+
+ def _resize_to_2d(self, x: torch.Tensor):
+ if x.dim() > 2:
+ x = x.view(x.size(0), -1)
+ return x
+
+ @add_start_docstrings_to_model_forward(
+ FLAVA_PRETRAINING_INPUTS_DOCSTRING.format("batch_size, text_seq_len", "batch_size, image_num_patches")
+ )
+ @replace_return_docstrings(output_type=FLAVAForPretrainingOutput, config_class=FLAVAConfig)
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ input_ids_masked: Optional[torch.LongTensor] = None,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ token_type_ids: Optional[torch.Tensor] = None,
+ bool_masked_pos: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ image_attention_mask: Optional[torch.Tensor] = None,
+ # TODO: Move to config with better name
+        skip_unmasked_multimodal_encoder: Optional[bool] = None,
+ mlm_labels: Optional[torch.Tensor] = None,
+ mim_labels: Optional[torch.Tensor] = None,
+ itm_labels: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: bool = True,
+ return_dict: Optional[bool] = None,
+ ):
+ """
+ Examples:
+ ```python
+ >>> from PIL import Image
+ >>> import requests
+ >>> from transformers import FLAVACodebookFeatureExtractor, FLAVACodebook, FLAVAForPretraining, FLAVAProcessor
+
+ >>> codebook = FLAVACodebook.from_pretrained("aps/flava-codebook")
+ >>> codebook_feature_extractor = FLAVACodebookFeatureExtractor.from_pretrained("aps/flava-codebook")
+
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> inputs = codebook_feature_extractor([image], return_tensors="pt")
+
+ >>> mim_labels = codebook.get_codebook_indices(**inputs)
+
+ >>> model = FLAVAForPretraining.from_pretrained("aps/flava-full")
+ >>> processor = FLAVAProcessor.from_pretrained("aps/flava-full")
+
+ >>> text = ["a photo of a cat"]
+
+ >>> inputs = processor(
+ ... images=[image], text=text, return_masks=True, padding=True, max_length=77, return_tensors="pt"
+ ... )
+ >>> inputs["mim_labels"] = mim_labels
+
+ >>> output = model(**inputs)
+ ```
+
+        Returns:
+
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ skip_unmasked_multimodal_encoder = (
+ skip_unmasked_multimodal_encoder
+ if skip_unmasked_multimodal_encoder is not None
+ else self.skip_unmasked_multimodal_encoder
+ )
+
+ flava_output = self.flava(
+ input_ids=input_ids,
+ pixel_values=pixel_values,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ image_attention_mask=image_attention_mask,
+ # Don't need unmasked multimodal embedding for anything so skip it
+ # NOTE: ITM uses masked version
+ skip_multimodal_encoder=skip_unmasked_multimodal_encoder,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+            # Pass `return_dict=True` so the outputs below can be accessed by attribute
+ return_dict=True,
+ )
+
+ flava_masked_output = self.flava(
+ input_ids=input_ids_masked,
+ pixel_values=pixel_values,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ image_attention_mask=image_attention_mask,
+ bool_masked_pos=bool_masked_pos,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=True,
+ )
+
+ pos_mask = None
+
+ image_embeddings = flava_output.image_embeddings
+ text_embeddings = flava_output.text_embeddings
+ image_masked_embeddings = flava_masked_output.image_embeddings
+ text_masked_embeddings = flava_masked_output.text_embeddings
+ multimodal_masked_embeddings = flava_masked_output.multimodal_embeddings
+ mim_loss = mlm_loss = mmm_text_loss = mmm_image_loss = gc_loss = itm_loss = None
+        mim_logits = mlm_logits = mmm_text_logits = mmm_image_logits = None
+        itm_logits = logits_per_image = logits_per_text = None
+
+ # Unimodal MIM Loss
+ # If multimodal embeddings are present, we will calculate MMM loss
+ if self.mim_weight > 0 and image_masked_embeddings is not None and multimodal_masked_embeddings is None:
+ sequence_for_image = image_masked_embeddings
+ if mim_labels is not None:
+ mim_labels = self._resize_to_2d(mim_labels)
+ bool_masked_pos = self._resize_to_2d(bool_masked_pos)
+ mim_labels[bool_masked_pos.ne(True)] = self.ce_ignore_index
+
+ sequence_for_image = sequence_for_image[:, -mim_labels.size(1) :, :]
+ masked_tokens = mim_labels.ne(self.ce_ignore_index)
+ mim_labels_filtered = mim_labels[masked_tokens]
+ sequence_for_image = sequence_for_image[masked_tokens, :]
+ mim_logits = self.mim_head(sequence_for_image)
+ mim_loss = nn.functional.cross_entropy(
+ mim_logits.view(-1, self.image_vocab_size), mim_labels_filtered.view(-1)
+ )
+ mim_loss *= self.mim_weight
+ else:
+ mim_logits = self.mim_head(sequence_for_image)
+
+ # Unimodal MLM Loss
+ if self.mlm_weight > 0 and text_masked_embeddings is not None and multimodal_masked_embeddings is None:
+ sequence_for_text = text_masked_embeddings
+ if mlm_labels is not None:
+ mlm_labels = self._resize_to_2d(mlm_labels)
+ sequence_for_text = sequence_for_text[:, -mlm_labels.size(1) :, :]
+ masked_tokens = mlm_labels.ne(self.ce_ignore_index)
+ mlm_labels_filtered = mlm_labels[masked_tokens]
+ sequence_for_text = sequence_for_text[masked_tokens, :]
+ mlm_logits = self.mlm_head(sequence_for_text)
+ mlm_loss = nn.functional.cross_entropy(
+ mlm_logits.view(-1, self.text_vocab_size), mlm_labels_filtered.view(-1)
+ )
+ mlm_loss *= self.mlm_weight
+ else:
+ mlm_logits = self.mlm_head(sequence_for_text)
+
+ # ITM Loss
+ if self.itm_weight > 0 and multimodal_masked_embeddings is not None:
+ itm_logits = self.itm_head(multimodal_masked_embeddings)
+
+ if itm_labels is not None:
+ pos_pairs = itm_labels.ne(0)
+ pos_mask = torch.where(pos_pairs.any(), pos_pairs, pos_pairs.new([True]))
+ itm_loss = nn.functional.cross_entropy(itm_logits, itm_labels)
+ itm_loss *= self.itm_weight
+
+            if pos_mask is not None:
+                if multimodal_masked_embeddings is not None:
+                    multimodal_masked_embeddings = multimodal_masked_embeddings[pos_mask]
+
+                if mlm_labels is not None:
+                    mlm_labels = mlm_labels[pos_mask]
+
+                if mim_labels is not None:
+                    mim_labels = mim_labels[pos_mask]
+
+ # MMM Image Loss
+ if multimodal_masked_embeddings is not None and self.mmm_image_weight > 0:
+ sequence_for_image = multimodal_masked_embeddings
+ end_index = image_masked_embeddings.size(1) - 1
+ sequence_for_image = sequence_for_image[:, 2 : 2 + end_index, :]
+
+ if pos_mask is not None:
+ sequence_for_image = sequence_for_image[pos_mask]
+ if mim_labels is not None:
+ mim_labels = self._resize_to_2d(mim_labels)
+ bool_masked_pos = self._resize_to_2d(bool_masked_pos)
+ mim_labels[bool_masked_pos.ne(True)] = self.ce_ignore_index
+
+ masked_tokens = mim_labels.ne(self.ce_ignore_index)
+ mim_labels_filtered = mim_labels[masked_tokens]
+ sequence_for_image = sequence_for_image[masked_tokens, :]
+ mmm_image_logits = self.mmm_image_head(sequence_for_image)
+ mmm_image_loss = nn.functional.cross_entropy(
+ mmm_image_logits.view(-1, self.image_vocab_size), mim_labels_filtered.view(-1)
+ )
+ mmm_image_loss *= self.mmm_image_weight
+ else:
+ mmm_image_logits = self.mmm_image_head(sequence_for_image)
+
+ # MMM Text Loss
+ if multimodal_masked_embeddings is not None and self.mmm_text_weight > 0:
+ sequence_for_text = multimodal_masked_embeddings
+ sequence_for_text = sequence_for_text[:, -text_masked_embeddings.size(1) :, :]
+ if pos_mask is not None:
+ sequence_for_text = sequence_for_text[pos_mask]
+
+ if mlm_labels is not None:
+ mlm_labels = self._resize_to_2d(mlm_labels)
+ masked_tokens = mlm_labels.ne(self.ce_ignore_index)
+ mlm_labels_filtered = mlm_labels[masked_tokens]
+ sequence_for_text = sequence_for_text[masked_tokens, :]
+ mmm_text_logits = self.mmm_text_head(sequence_for_text)
+ mmm_text_loss = nn.functional.cross_entropy(
+ mmm_text_logits.view(-1, self.text_vocab_size), mlm_labels_filtered.view(-1)
+ )
+ mmm_text_loss *= self.mmm_text_weight
+ else:
+ mmm_text_logits = self.mmm_text_head(sequence_for_text)
+
+ # Global Contrastive Loss
+ if image_embeddings is not None and text_embeddings is not None and self.global_contrastive_weight > 0:
+ text_embedding = self.flava.text_projection(text_embeddings[:, 0, :])
+ text_embedding = nn.functional.normalize(text_embedding, dim=-1)
+
+ image_embedding = self.flava.image_projection(image_embeddings[:, 0, :])
+ image_embedding = nn.functional.normalize(image_embedding, dim=-1)
+
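+            # Clamp the learnable temperature, as in CLIP, to keep the contrastive logits in a stable range.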
+ self.flava.logit_scale.data.clamp_(LOGIT_SCALE_CLAMP_MIN, LOGIT_SCALE_CLAMP_MAX)
+
+ logits_per_image, logits_per_text, gc_labels = self.global_contrastive_head(
+ image_embedding, text_embedding, self.flava.logit_scale
+ )
+
+ # Apply ITM negative mask if any
+ if pos_mask is not None:
+ logits_per_image = logits_per_image[pos_mask]
+ logits_per_text = logits_per_text[pos_mask]
+ gc_labels = gc_labels[pos_mask]
+
+ gc_loss_image = nn.functional.cross_entropy(logits_per_image, gc_labels)
+ gc_loss_text = nn.functional.cross_entropy(logits_per_text, gc_labels)
+ gc_loss = (gc_loss_image + gc_loss_text) / 2
+ gc_loss *= self.global_contrastive_weight
+
+ flava_losses = FLAVALosses(
+ mim=mim_loss,
+ mlm=mlm_loss,
+ itm=itm_loss,
+ global_contrastive=gc_loss,
+ mmm_image=mmm_image_loss,
+ mmm_text=mmm_text_loss,
+ )
+
+ if not return_dict:
+ output = (
+ image_embeddings,
+ flava_output.image_output,
+ text_embeddings,
+ flava_output.text_output,
+ flava_output.multimodal_embeddings,
+ flava_output.multimodal_output,
+ image_masked_embeddings,
+ flava_masked_output.image_output,
+ text_masked_embeddings,
+ flava_masked_output.text_output,
+ multimodal_masked_embeddings,
+ flava_masked_output.multimodal_output,
+ mim_logits,
+ mlm_logits,
+ itm_logits,
+ logits_per_image,
+                logits_per_text,
+ mmm_image_logits,
+ mmm_text_logits,
+ )
+            if not flava_losses.all_none():
+ return (flava_losses,) + output
+ else:
+ return output
+
+ return FLAVAForPretrainingOutput(
+ losses=flava_losses,
+ image_embeddings=image_embeddings,
+ image_output=flava_output.image_output,
+ text_embeddings=text_embeddings,
+ text_output=flava_output.text_output,
+ multimodal_embeddings=flava_output.multimodal_embeddings,
+ multimodal_output=flava_output.multimodal_output,
+ image_masked_embeddings=image_masked_embeddings,
+ image_masked_output=flava_masked_output.image_output,
+ text_masked_embeddings=text_masked_embeddings,
+ text_masked_output=flava_masked_output.text_output,
+ multimodal_masked_embeddings=multimodal_masked_embeddings,
+ multimodal_masked_output=flava_masked_output.multimodal_output,
+ mim_logits=mim_logits,
+ mlm_logits=mlm_logits,
+ itm_logits=itm_logits,
+ contrastive_logits_per_image=logits_per_image,
+ contrastive_logits_per_text=logits_per_text,
+ mmm_image_logits=mmm_image_logits,
+ mmm_text_logits=mmm_text_logits,
+ )
diff --git a/src/transformers/models/flava/processing_flava.py b/src/transformers/models/flava/processing_flava.py
new file mode 100644
index 00000000000000..e75074eed81298
--- /dev/null
+++ b/src/transformers/models/flava/processing_flava.py
@@ -0,0 +1,148 @@
+# coding=utf-8
+# Copyright 2021 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Image/Text processor class for FLAVA
+"""
+from typing import List, Optional, Union
+
+import numpy as np
+import torch
+from PIL import Image
+
+from ...data.data_collator import DataCollatorForWholeWordMask, tolist
+
+from ...processing_utils import ProcessorMixin
+from ...tokenization_utils_base import BatchEncoding, PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
+from ...utils import TensorType
+
+
+class FLAVAProcessor(ProcessorMixin):
+ r"""
+ Constructs a FLAVA processor which wraps a FLAVA feature extractor and a FLAVA tokenizer into a single processor.
+
+    [`FLAVAProcessor`] offers all the functionalities of [`FLAVAFeatureExtractor`] and [`BertTokenizerFast`]. See the
+ [`~FLAVAProcessor.__call__`] and [`~FLAVAProcessor.decode`] for more information.
+
+ Args:
+ feature_extractor ([`FLAVAFeatureExtractor`]):
+ The feature extractor is a required input.
+        tokenizer ([`BertTokenizerFast`]):
+ The tokenizer is a required input.
+ """
+ feature_extractor_class = "FLAVAFeatureExtractor"
+ tokenizer_class = ("BertTokenizer", "BertTokenizerFast")
+
+ def __init__(self, feature_extractor, tokenizer, mlm_probability=0.15):
+ super().__init__(feature_extractor, tokenizer)
+ self.current_processor = self.feature_extractor
+ self.text_masker = DataCollatorForWholeWordMask(tokenizer, mlm=True, mlm_probability=mlm_probability)
+
+ def __call__(
+ self,
+ images: Optional[
+ Union[
+ Image.Image,
+ np.ndarray,
+ "torch.Tensor",
+ List[Image.Image],
+ List[np.ndarray],
+ List["torch.Tensor"], # noqa
+ ]
+ ] = None,
+ text: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None,
+ add_special_tokens: bool = True,
+ padding: Union[bool, str, PaddingStrategy] = False,
+ truncation: Union[bool, str, TruncationStrategy] = False,
+ max_length: Optional[int] = None,
+ stride: int = 0,
+ pad_to_multiple_of: Optional[int] = None,
+ return_masks: Optional[bool] = None,
+ return_token_type_ids: Optional[bool] = None,
+ return_attention_mask: Optional[bool] = None,
+ return_overflowing_tokens: bool = False,
+ return_special_tokens_mask: bool = False,
+ return_offsets_mapping: bool = False,
+ return_length: bool = False,
+ verbose: bool = True,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ **kwargs
+ ):
+ """
+ This method uses [`FLAVAFeatureExtractor.__call__`] method to prepare image(s) for the model, and
+ [`BertTokenizerFast.__call__`] to prepare text for the model.
+
+ Please refer to the docstring of the above two methods for more information. Other special args are mentioned
+ below:
+
+ Args:
+            return_masks (`bool`, *optional*, defaults to `None`):
+                If `True`, the processor will return `bool_masked_pos` suggesting masks for the image's patches, along
+                with `input_ids_masked` and `mlm_labels` for MLM.
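+
+        Example (an illustrative sketch reusing the `aps/flava-full` checkpoint and COCO image used elsewhere in this
+        PR; the exact keys returned depend on the arguments passed):
+
+        ```python
+        >>> from PIL import Image
+        >>> import requests
+        >>> from transformers import FLAVAProcessor
+
+        >>> processor = FLAVAProcessor.from_pretrained("aps/flava-full")
+
+        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+        >>> image = Image.open(requests.get(url, stream=True).raw)
+
+        >>> inputs = processor(
+        ...     images=[image], text=["a photo of a cat"], return_masks=True, padding=True, return_tensors="pt"
+        ... )
+        ```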
+ """
+
+ if text is None and images is None:
+ raise ValueError("You have to specify either text or images. Both cannot be none.")
+
+ if text is not None:
+ encoding = self.tokenizer(
+ text=text,
+ add_special_tokens=add_special_tokens,
+ padding=padding,
+ truncation=truncation,
+ max_length=max_length,
+ stride=stride,
+ pad_to_multiple_of=pad_to_multiple_of,
+ return_token_type_ids=return_token_type_ids,
+ return_attention_mask=return_attention_mask,
+ return_overflowing_tokens=return_overflowing_tokens,
+ return_special_tokens_mask=return_special_tokens_mask or return_masks,
+ return_offsets_mapping=return_offsets_mapping,
+ return_length=return_length,
+ verbose=verbose,
+ return_tensors=return_tensors,
+ **kwargs,
+ )
+ if images is not None:
+ image_features = self.feature_extractor(
+ images, return_masks=return_masks, return_tensors=return_tensors, **kwargs
+ )
+
+ if return_masks and text is not None:
+ batch_masked = self.text_masker(tolist(encoding["input_ids"]), return_tensors=return_tensors)
+ encoding["input_ids_masked"] = batch_masked["input_ids"]
+ encoding["mlm_labels"] = batch_masked["labels"]
+ encoding.pop("special_tokens_mask")
+
+ if text is not None and images is not None:
+ encoding.update(image_features)
+ return encoding
+ elif text is not None:
+ return encoding
+ else:
+ return BatchEncoding(data=dict(**image_features), tensor_type=return_tensors)
+
+ def batch_decode(self, *args, **kwargs):
+ """
+        This method forwards all its arguments to BertTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
+ refer to the docstring of this method for more information.
+ """
+ return self.tokenizer.batch_decode(*args, **kwargs)
+
+ def decode(self, *args, **kwargs):
+ """
+        This method forwards all its arguments to BertTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
+ the docstring of this method for more information.
+ """
+ return self.tokenizer.decode(*args, **kwargs)
diff --git a/tests/flava/__init__.py b/tests/flava/__init__.py
new file mode 100644
index 00000000000000..e69de29bb2d1d6
diff --git a/tests/flava/test_feature_extraction_flava.py b/tests/flava/test_feature_extraction_flava.py
new file mode 100644
index 00000000000000..f6d37cecf7696c
--- /dev/null
+++ b/tests/flava/test_feature_extraction_flava.py
@@ -0,0 +1,414 @@
+# coding=utf-8
+# Copyright 2021 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import random
+import unittest
+
+import numpy as np
+
+from transformers.testing_utils import require_torch, require_vision
+from transformers.utils import is_torch_available, is_vision_available
+
+from ..test_feature_extraction_common import FeatureExtractionSavingTestMixin, prepare_image_inputs
+
+
+if is_torch_available():
+ import torch
+
+if is_vision_available():
+ from PIL import Image
+
+ from transformers import FLAVACodebookFeatureExtractor, FLAVAFeatureExtractor
+ from transformers.models.flava.feature_extraction_flava import (
+ FLAVA_CODEBOOK_MEAN,
+ FLAVA_CODEBOOK_STD,
+ FLAVA_IMAGE_MEAN,
+ FLAVA_IMAGE_STD,
+ )
+
+
+# TODO(aps): Add joint feature extractor useful for pretraining
+class FLAVAFeatureExtractionTester(unittest.TestCase):
+ def __init__(
+ self,
+ parent,
+ batch_size=7,
+ num_channels=3,
+ min_resolution=30,
+ max_resolution=400,
+ do_resize=True,
+ size=224,
+ do_center_crop=True,
+ crop_size=224,
+ resample=Image.BICUBIC,
+ do_normalize=True,
+ image_mean=FLAVA_IMAGE_MEAN,
+ image_std=FLAVA_IMAGE_STD,
+ input_size_patches=14,
+ total_mask_patches=75,
+ mask_group_max_patches=None,
+ mask_group_min_patches=16,
+ mask_group_min_aspect_ratio=0.3,
+ mask_group_max_aspect_ratio=None,
+ ):
+ self.parent = parent
+ self.batch_size = batch_size
+ self.num_channels = num_channels
+ self.do_resize = do_resize
+ self.min_resolution = min_resolution
+ self.max_resolution = max_resolution
+ self.size = size
+ self.resample = resample
+ self.do_normalize = do_normalize
+ self.image_mean = image_mean
+ self.image_std = image_std
+ self.do_center_crop = do_center_crop
+ self.crop_size = crop_size
+ self.input_size_patches = input_size_patches
+ self.total_mask_patches = total_mask_patches
+ self.mask_group_max_patches = mask_group_max_patches
+ self.mask_group_min_patches = mask_group_min_patches
+ self.mask_group_min_aspect_ratio = mask_group_min_aspect_ratio
+ self.mask_group_max_aspect_ratio = mask_group_max_aspect_ratio
+
+ def prepare_feat_extract_dict(self):
+ return {
+ "image_mean": self.image_mean,
+ "image_std": self.image_std,
+ "do_normalize": self.do_normalize,
+ "do_resize": self.do_resize,
+ "size": self.size,
+ "resample": self.resample,
+ "do_center_crop": self.do_center_crop,
+ "crop_size": self.crop_size,
+ "input_size_patches": self.input_size_patches,
+ "total_mask_patches": self.total_mask_patches,
+ "mask_group_max_patches": self.mask_group_max_patches,
+ "mask_group_min_patches": self.mask_group_min_patches,
+ "mask_group_min_aspect_ratio": self.mask_group_min_aspect_ratio,
+ "mask_group_max_aspect_ratio": self.mask_group_min_aspect_ratio,
+ }
+
+ def get_expected_image_size(self):
+ return (self.size, self.size) if not isinstance(self.size, tuple) else self.size
+
+ def get_expected_mask_size(self):
+ return (
+ (self.input_size_patches, self.input_size_patches)
+ if not isinstance(self.input_size_patches, tuple)
+ else self.input_size_patches
+ )
+
+
+@require_torch
+@require_vision
+class FLAVAFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
+
+ feature_extraction_class = FLAVAFeatureExtractor if is_vision_available() else None
+
+ def setUp(self):
+ self.feature_extract_tester = FLAVAFeatureExtractionTester(self)
+
+ @property
+ def feat_extract_dict(self):
+ return self.feature_extract_tester.prepare_feat_extract_dict()
+
+ def test_feat_extract_properties(self):
+ feature_extractor = self.feature_extraction_class(**self.feat_extract_dict)
+ self.assertTrue(hasattr(feature_extractor, "image_mean"))
+ self.assertTrue(hasattr(feature_extractor, "image_std"))
+ self.assertTrue(hasattr(feature_extractor, "do_normalize"))
+ self.assertTrue(hasattr(feature_extractor, "do_resize"))
+ self.assertTrue(hasattr(feature_extractor, "resample"))
+ self.assertTrue(hasattr(feature_extractor, "crop_size"))
+ self.assertTrue(hasattr(feature_extractor, "do_center_crop"))
+ self.assertTrue(hasattr(feature_extractor, "masking_generator"))
+
+ def test_batch_feature(self):
+ pass
+
+ def test_call_pil(self):
+ # Initialize feature_extractor
+ feature_extractor = self.feature_extraction_class(**self.feat_extract_dict)
+ # create random PIL images
+ image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False)
+ for image in image_inputs:
+ self.assertIsInstance(image, Image.Image)
+
+ # Test not batched input
+ encoded_images = feature_extractor(image_inputs[0], return_tensors="pt")
+
+ # Test no bool masked pos
+ self.assertFalse("bool_masked_pos" in encoded_images)
+
+ expected_height, expected_width = self.feature_extract_tester.get_expected_image_size()
+
+ self.assertEqual(
+ encoded_images.pixel_values.shape,
+ (1, self.feature_extract_tester.num_channels, expected_height, expected_width),
+ )
+
+ # Test batched
+ encoded_images = feature_extractor(image_inputs, return_tensors="pt")
+ expected_height, expected_width = self.feature_extract_tester.get_expected_image_size()
+
+ # Test no bool masked pos
+ self.assertFalse("bool_masked_pos" in encoded_images)
+
+ self.assertEqual(
+ encoded_images.pixel_values.shape,
+ (
+ self.feature_extract_tester.batch_size,
+ self.feature_extract_tester.num_channels,
+ expected_height,
+ expected_width,
+ ),
+ )
+
+ def _test_call_framework(self, instance_class, prepare_kwargs):
+ # Initialize feature_extractor
+ feature_extractor = self.feature_extraction_class(**self.feat_extract_dict)
+ # create random tensors
+ image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False, **prepare_kwargs)
+ for image in image_inputs:
+ self.assertIsInstance(image, instance_class)
+
+ # Test not batched input
+ encoded_images = feature_extractor(image_inputs[0], return_tensors="pt")
+
+ expected_height, expected_width = self.feature_extract_tester.get_expected_image_size()
+ self.assertEqual(
+ encoded_images.pixel_values.shape,
+ (1, self.feature_extract_tester.num_channels, expected_height, expected_width),
+ )
+
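+        # Test batched input with masks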
+ encoded_images = feature_extractor(image_inputs, return_masks=True, return_tensors="pt")
+
+ expected_height, expected_width = self.feature_extract_tester.get_expected_image_size()
+ self.assertEqual(
+ encoded_images.pixel_values.shape,
+ (
+ self.feature_extract_tester.batch_size,
+ self.feature_extract_tester.num_channels,
+ expected_height,
+ expected_width,
+ ),
+ )
+
+ expected_height, expected_width = self.feature_extract_tester.get_expected_mask_size()
+ self.assertEqual(
+ encoded_images.bool_masked_pos.shape,
+ (
+ self.feature_extract_tester.batch_size,
+ expected_height,
+ expected_width,
+ ),
+ )
+
+ # Test batched
+ encoded_images = feature_extractor(image_inputs, return_tensors="pt").pixel_values
+
+ expected_height, expected_width = self.feature_extract_tester.get_expected_image_size()
+ self.assertEqual(
+ encoded_images.shape,
+ (
+ self.feature_extract_tester.batch_size,
+ self.feature_extract_tester.num_channels,
+ expected_height,
+ expected_width,
+ ),
+ )
+
+ # Test masking
+ encoded_images = feature_extractor(image_inputs, return_masks=True, return_tensors="pt")
+
+ expected_height, expected_width = self.feature_extract_tester.get_expected_image_size()
+ self.assertEqual(
+ encoded_images.pixel_values.shape,
+ (
+ self.feature_extract_tester.batch_size,
+ self.feature_extract_tester.num_channels,
+ expected_height,
+ expected_width,
+ ),
+ )
+
+ expected_height, expected_width = self.feature_extract_tester.get_expected_mask_size()
+ self.assertEqual(
+ encoded_images.bool_masked_pos.shape,
+ (
+ self.feature_extract_tester.batch_size,
+ expected_height,
+ expected_width,
+ ),
+ )
+
+ def test_call_numpy(self):
+ self._test_call_framework(np.ndarray, prepare_kwargs={"numpify": True})
+
+ def test_call_pytorch(self):
+ self._test_call_framework(
+ torch.Tensor,
+ prepare_kwargs={
+ "torchify": True,
+ },
+ )
+
+ def test_masking(self):
+ # Initialize feature_extractor
+ random.seed(1234)
+ feature_extractor = self.feature_extraction_class(**self.feat_extract_dict)
+ image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False, torchify=True)
+
+ # Test not batched input
+ encoded_images = feature_extractor(image_inputs[0], return_masks=True, return_tensors="pt")
+ self.assertEqual(encoded_images.bool_masked_pos.sum().item(), 75)
+
+
+class FLAVACodebookFeatureExtractionTester(unittest.TestCase):
+ def __init__(
+ self,
+ parent,
+ batch_size=7,
+ num_channels=3,
+ min_resolution=30,
+ max_resolution=400,
+ do_resize=True,
+ size=112,
+ do_center_crop=True,
+ crop_size=112,
+ resample=Image.LANCZOS,
+ do_normalize=True,
+ image_mean=FLAVA_CODEBOOK_MEAN,
+ image_std=FLAVA_CODEBOOK_STD,
+ do_map_pixels=True,
+ ):
+ self.parent = parent
+ self.batch_size = batch_size
+ self.num_channels = num_channels
+ self.do_resize = do_resize
+ self.min_resolution = min_resolution
+ self.max_resolution = max_resolution
+ self.size = size
+ self.do_center_crop = do_center_crop
+ self.crop_size = crop_size
+ self.resample = resample
+ self.do_normalize = do_normalize
+ self.image_mean = image_mean
+ self.image_std = image_std
+ self.do_map_pixels = do_map_pixels
+
+ def prepare_feat_extract_dict(self):
+ return {
+ "do_resize": self.do_resize,
+ "size": self.size,
+ "do_center_crop": self.do_center_crop,
+ "crop_size": self.crop_size,
+ "resample": self.resample,
+ "do_normalize": self.do_normalize,
+ "image_mean": self.image_mean,
+ "image_std": self.image_std,
+ "do_map_pixels": self.do_map_pixels,
+ }
+
+ def get_expected_image_size(self):
+ return (self.size, self.size) if not isinstance(self.size, tuple) else self.size
+
+
+@require_torch
+@require_vision
+class FLAVACodebookFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
+
+ feature_extraction_class = FLAVACodebookFeatureExtractor if is_vision_available() else None
+
+ def setUp(self):
+ self.feature_extract_tester = FLAVACodebookFeatureExtractionTester(self)
+
+ @property
+ def feat_extract_dict(self):
+ return self.feature_extract_tester.prepare_feat_extract_dict()
+
+ def test_feat_extract_properties(self):
+ feature_extractor = self.feature_extraction_class(**self.feat_extract_dict)
+ self.assertTrue(hasattr(feature_extractor, "image_mean"))
+ self.assertTrue(hasattr(feature_extractor, "image_std"))
+ self.assertTrue(hasattr(feature_extractor, "do_normalize"))
+ self.assertTrue(hasattr(feature_extractor, "do_resize"))
+ self.assertTrue(hasattr(feature_extractor, "resample"))
+ self.assertTrue(hasattr(feature_extractor, "crop_size"))
+ self.assertTrue(hasattr(feature_extractor, "do_center_crop"))
+
+ def test_batch_feature(self):
+ pass
+
+ def test_call_pil(self):
+ # Initialize feature_extractor
+ feature_extractor = self.feature_extraction_class(**self.feat_extract_dict)
+ # create random PIL images
+ image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False)
+ for image in image_inputs:
+ self.assertIsInstance(image, Image.Image)
+
+ # Test not batched input
+ encoded_images = feature_extractor(image_inputs[0], return_tensors="pt")
+ expected_height, expected_width = self.feature_extract_tester.get_expected_image_size()
+ self.assertEqual(
+ encoded_images.pixel_values.shape,
+ (1, self.feature_extract_tester.num_channels, expected_height, expected_width),
+ )
+
+ # Test batched
+ encoded_images = feature_extractor(image_inputs, return_tensors="pt")
+        expected_height, expected_width = self.feature_extract_tester.get_expected_image_size()
+        self.assertEqual(
+            encoded_images.pixel_values.shape,
+            (
+                self.feature_extract_tester.batch_size,
+                self.feature_extract_tester.num_channels,
+                expected_height,
+                expected_width,
+            ),
+        )
+
+ def _test_call_framework(self, instance_class, prepare_kwargs):
+ # Initialize feature_extractor
+ feature_extractor = self.feature_extraction_class(**self.feat_extract_dict)
+ # create random tensors
+ image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False, **prepare_kwargs)
+ for image in image_inputs:
+ self.assertIsInstance(image, instance_class)
+
+ # Test not batched input
+ encoded_images = feature_extractor(image_inputs[0], return_tensors="pt")
+ expected_height, expected_width = self.feature_extract_tester.get_expected_image_size()
+ self.assertEqual(
+ encoded_images.pixel_values.shape,
+ (1, self.feature_extract_tester.num_channels, expected_height, expected_width),
+ )
+
+ # Test batched
+ encoded_images = feature_extractor(image_inputs, return_tensors="pt").pixel_values
+ expected_height, expected_width = self.feature_extract_tester.get_expected_image_size()
+ self.assertEqual(
+ encoded_images.shape,
+ (
+ self.feature_extract_tester.batch_size,
+ self.feature_extract_tester.num_channels,
+ expected_height,
+ expected_width,
+ ),
+ )
+
+ def test_call_numpy(self):
+ self._test_call_framework(np.ndarray, prepare_kwargs={"numpify": True})
+
+ def test_call_pytorch(self):
+ self._test_call_framework(
+ torch.Tensor,
+ prepare_kwargs={
+ "torchify": True,
+ },
+ )
diff --git a/tests/flava/test_modeling_flava.py b/tests/flava/test_modeling_flava.py
new file mode 100644
index 00000000000000..e844c93da401e1
--- /dev/null
+++ b/tests/flava/test_modeling_flava.py
@@ -0,0 +1,824 @@
+# coding=utf-8
+# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" Testing suite for the PyTorch FLAVA model. """
+
+
+import inspect
+import os
+import tempfile
+import unittest
+
+import numpy as np
+
+import requests
+from transformers import FLAVAConfig, FLAVAImageConfig, FLAVAMultimodalConfig, FLAVATextConfig
+from transformers.testing_utils import require_torch, require_vision, slow, torch_device
+from transformers.utils import is_torch_available, is_vision_available
+
+from ..test_configuration_common import ConfigTester
+from ..test_modeling_common import (
+ ModelTesterMixin,
+ _config_zero_init,
+ floats_tensor,
+ ids_tensor,
+ random_attention_mask,
+)
+
+
+if is_torch_available():
+ import torch
+ from torch import nn
+
+ from transformers import FLAVAForPretraining, FLAVAImageModel, FLAVAModel, FLAVAMultimodalModel, FLAVATextModel
+ from transformers.models.flava.modeling_flava import FLAVA_PRETRAINED_MODEL_ARCHIVE_LIST
+
+
+if is_vision_available():
+ from PIL import Image
+
+ from transformers import FLAVAProcessor
+
+
+class FLAVAImageModelTester:
+ def __init__(
+ self,
+ parent,
+ batch_size=12,
+ hidden_size=32,
+ num_hidden_layers=5,
+ num_attention_heads=4,
+ intermediate_size=37,
+ hidden_act="gelu",
+ hidden_dropout_prob=0.0,
+ attention_probs_dropout_prob=0.0,
+ initializer_range=0.02,
+ layer_norm_eps=1e-12,
+ image_size=30,
+ patch_size=2,
+ num_channels=3,
+ qkv_bias=True,
+ mask_token=True,
+ vocab_size=8192,
+ ):
+ self.parent = parent
+ self.batch_size = batch_size
+ self.hidden_size = hidden_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.intermediate_size = intermediate_size
+ self.hidden_act = hidden_act
+ self.hidden_dropout_prob = hidden_dropout_prob
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
+ self.initializer_range = initializer_range
+ self.layer_norm_eps = layer_norm_eps
+ self.image_size = image_size
+ self.patch_size = patch_size
+ self.num_channels = num_channels
+ self.qkv_bias = qkv_bias
+ self.mask_token = mask_token
+ self.vocab_size = vocab_size
+
+ def prepare_config_and_inputs(self):
+ pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
+ num_patches = self.image_size // self.patch_size
+ bool_masked_pos = (torch.rand((self.batch_size, num_patches, num_patches)) < 0.9).long()
+ config = self.get_config()
+ return config, pixel_values, bool_masked_pos
+
+ def get_config(self):
+ return FLAVAImageConfig(
+ hidden_size=self.hidden_size,
+ num_hidden_layers=self.num_hidden_layers,
+ num_attention_heads=self.num_attention_heads,
+ intermediate_size=self.intermediate_size,
+ hidden_act=self.hidden_act,
+ hidden_dropout_prob=self.hidden_dropout_prob,
+ attention_probs_dropout_prob=self.attention_probs_dropout_prob,
+ initializer_range=self.initializer_range,
+ layer_norm_eps=self.layer_norm_eps,
+ image_size=self.image_size,
+ patch_size=self.patch_size,
+ num_channels=self.num_channels,
+ qkv_bias=self.qkv_bias,
+ mask_token=self.mask_token,
+ vocab_size=self.vocab_size,
+ )
+
+ def create_and_check_model(self, config, pixel_values, bool_masked_pos):
+ model = FLAVAImageModel(config=config)
+ model.to(torch_device)
+ model.eval()
+ with torch.no_grad():
+ result = model(pixel_values, bool_masked_pos)
+ # expected sequence length = num_patches + 1 (we add 1 for the [CLS] token)
+ image_size = (self.image_size, self.image_size)
+ patch_size = (self.patch_size, self.patch_size)
+ num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
+ self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, num_patches + 1, self.hidden_size))
+ self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size))
+
+ def prepare_config_and_inputs_for_common(self):
+ config_and_inputs = self.prepare_config_and_inputs()
+ config, pixel_values, bool_masked_pos = config_and_inputs
+ inputs_dict = {"pixel_values": pixel_values, "bool_masked_pos": bool_masked_pos}
+ return config, inputs_dict
+
+
+@require_torch
+class FLAVAImageModelTest(ModelTesterMixin, unittest.TestCase):
+ """
+ Here we also overwrite some of the tests of test_modeling_common.py, as FLAVA does not use input_ids, inputs_embeds,
+ attention_mask and seq_length.
+ """
+
+ all_model_classes = (FLAVAImageModel,) if is_torch_available() else ()
+
+ test_pruning = False
+ test_torchscript = False
+ test_resize_embeddings = False
+ test_head_masking = False
+
+ def setUp(self):
+ self.model_tester = FLAVAImageModelTester(self)
+ self.config_tester = ConfigTester(self, config_class=FLAVAImageConfig, has_text_modality=False, hidden_size=37)
+
+ def test_config(self):
+ self.config_tester.run_common_tests()
+
+ def test_inputs_embeds(self):
+ # FLAVA does not use inputs_embeds
+ pass
+
+ def test_model_common_attributes(self):
+ config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+
+ for model_class in self.all_model_classes:
+ model = model_class(config)
+ self.assertIsInstance(model.get_input_embeddings(), (nn.Module))
+ x = model.get_output_embeddings()
+ self.assertTrue(x is None or isinstance(x, nn.Linear))
+
+ def test_forward_signature(self):
+ config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+
+ for model_class in self.all_model_classes:
+ model = model_class(config)
+ signature = inspect.signature(model.forward)
+ # signature.parameters is an OrderedDict => so arg_names order is deterministic
+ arg_names = [*signature.parameters.keys()]
+
+ expected_arg_names = ["pixel_values"]
+ self.assertListEqual(arg_names[:1], expected_arg_names)
+
+ def test_model(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_model(*config_and_inputs)
+
+ def test_attention_outputs(self):
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ config.return_dict = True
+
+ # in FLAVA, the seq_len equals the number of patches + 1 (we add 1 for the [CLS] token)
+ image_size = (self.model_tester.image_size, self.model_tester.image_size)
+ patch_size = (self.model_tester.patch_size, self.model_tester.patch_size)
+ num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
+ seq_len = num_patches + 1
+
+ for model_class in self.all_model_classes:
+ inputs_dict["output_attentions"] = True
+ inputs_dict["output_hidden_states"] = False
+ config.return_dict = True
+ model = model_class(config)
+ model.to(torch_device)
+ model.eval()
+ with torch.no_grad():
+ outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+ attentions = outputs.attentions
+ self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
+
+ # check that output_attentions also work using config
+ del inputs_dict["output_attentions"]
+ config.output_attentions = True
+ model = model_class(config)
+ model.to(torch_device)
+ model.eval()
+ with torch.no_grad():
+ outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+ attentions = outputs.attentions
+ self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
+
+ out_len = len(outputs)
+
+ # Check attention is always last and order is fine
+ inputs_dict["output_attentions"] = True
+ inputs_dict["output_hidden_states"] = True
+ model = model_class(config)
+ model.to(torch_device)
+ model.eval()
+ with torch.no_grad():
+ outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+
+ added_hidden_states = 1
+ self.assertEqual(out_len + added_hidden_states, len(outputs))
+
+ self_attentions = outputs.attentions
+
+ self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers)
+
+ self.assertListEqual(
+ list(self_attentions[0].shape[-3:]),
+ [self.model_tester.num_attention_heads, seq_len, seq_len],
+ )
+
+ def test_hidden_states_output(self):
+ def check_hidden_states_output(inputs_dict, config, model_class):
+ model = model_class(config)
+ model.to(torch_device)
+ model.eval()
+
+ with torch.no_grad():
+ outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+
+ hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states
+
+ expected_num_layers = getattr(
+ self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1
+ )
+ self.assertEqual(len(hidden_states), expected_num_layers)
+
+ # FLAVA has a different seq_length
+ image_size = (self.model_tester.image_size, self.model_tester.image_size)
+ patch_size = (self.model_tester.patch_size, self.model_tester.patch_size)
+ num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
+ seq_length = num_patches + 1
+
+ self.assertListEqual(
+ list(hidden_states[0].shape[-2:]),
+ [seq_length, self.model_tester.hidden_size],
+ )
+
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ for model_class in self.all_model_classes:
+ inputs_dict["output_hidden_states"] = True
+ check_hidden_states_output(inputs_dict, config, model_class)
+
+ # check that output_hidden_states also work using config
+ del inputs_dict["output_hidden_states"]
+ config.output_hidden_states = True
+
+ check_hidden_states_output(inputs_dict, config, model_class)
+
+ def test_training(self):
+ pass
+
+ def test_training_gradient_checkpointing(self):
+ pass
+
+ # skip this test as FLAVAImageModel has no base class and is
+ # not available in MODEL_MAPPING
+ def test_save_load_fast_init_from_base(self):
+ pass
+
+ # skip this test as FLAVAImageModel has no base class and is
+ # not available in MODEL_MAPPING
+ def test_save_load_fast_init_to_base(self):
+ pass
+
+ @slow
+ def test_model_from_pretrained(self):
+ for model_name in FLAVA_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
+ model = FLAVAImageModel.from_pretrained(model_name)
+ self.assertIsNotNone(model)
+
+
+class FLAVATextModelTester:
+ def __init__(
+ self,
+ parent,
+ batch_size=12,
+ seq_length=7,
+ is_training=True,
+ use_input_mask=True,
+ use_token_type_ids=True,
+ vocab_size=30522,
+ type_vocab_size=2,
+ max_position_embeddings=512,
+ position_embedding_type="absolute",
+ hidden_size=768,
+ num_hidden_layers=12,
+ num_attention_heads=12,
+ intermediate_size=3072,
+ hidden_act="gelu",
+ hidden_dropout_prob=0.0,
+ attention_probs_dropout_prob=0.0,
+ initializer_range=0.02,
+ layer_norm_eps=1e-12,
+ pad_token_id=0,
+ qkv_bias=True,
+ ):
+ self.parent = parent
+ self.batch_size = batch_size
+ self.is_training = is_training
+ self.use_input_mask = use_input_mask
+ self.use_token_type_ids = use_token_type_ids
+ self.seq_length = seq_length
+ self.vocab_size = vocab_size
+ self.type_vocab_size = type_vocab_size
+ self.max_position_embeddings = max_position_embeddings
+ self.position_embedding_type = position_embedding_type
+ self.hidden_size = hidden_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.intermediate_size = intermediate_size
+ self.hidden_act = hidden_act
+ self.hidden_dropout_prob = hidden_dropout_prob
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
+ self.initializer_range = initializer_range
+ self.layer_norm_eps = layer_norm_eps
+ self.qkv_bias = qkv_bias
+ self.pad_token_id = pad_token_id
+
+ def prepare_config_and_inputs(self):
+ input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
+
+ input_mask = None
+ if self.use_input_mask:
+ input_mask = random_attention_mask([self.batch_size, self.seq_length])
+
+ if input_mask is not None:
+ batch_size, seq_length = input_mask.shape
+ rnd_start_indices = np.random.randint(1, seq_length - 1, size=(batch_size,))
+ for batch_idx, start_index in enumerate(rnd_start_indices):
+ input_mask[batch_idx, :start_index] = 1
+ input_mask[batch_idx, start_index:] = 0
+
+ token_type_ids = None
+
+ if self.use_token_type_ids:
+ token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
+
+ config = self.get_config()
+
+ return config, input_ids, token_type_ids, input_mask
+
+ def get_config(self):
+ return FLAVATextConfig(
+ vocab_size=self.vocab_size,
+ type_vocab_size=self.type_vocab_size,
+ max_position_embeddings=self.max_position_embeddings,
+ position_embedding_type=self.position_embedding_type,
+ hidden_size=self.hidden_size,
+ num_hidden_layers=self.num_hidden_layers,
+ num_attention_heads=self.num_attention_heads,
+ intermediate_size=self.intermediate_size,
+ hidden_act=self.hidden_act,
+ hidden_dropout_prob=self.hidden_dropout_prob,
+ attention_probs_dropout_prob=self.attention_probs_dropout_prob,
+ initializer_range=self.initializer_range,
+ layer_norm_eps=self.layer_norm_eps,
+ pad_token_id=self.pad_token_id,
+ qkv_bias=self.qkv_bias,
+ )
+
+ def create_and_check_model(self, config, input_ids, token_type_ids, input_mask):
+ model = FLAVATextModel(config=config)
+ model.to(torch_device)
+ model.eval()
+ with torch.no_grad():
+ result = model(input_ids, token_type_ids=token_type_ids, attention_mask=input_mask)
+ result = model(input_ids)
+ self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
+ self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size))
+
+ def prepare_config_and_inputs_for_common(self):
+ config_and_inputs = self.prepare_config_and_inputs()
+ config, input_ids, token_type_ids, input_mask = config_and_inputs
+ inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "attention_mask": input_mask}
+ return config, inputs_dict
+
+
+@require_torch
+class FLAVATextModelTest(ModelTesterMixin, unittest.TestCase):
+
+ all_model_classes = (FLAVATextModel,) if is_torch_available() else ()
+ test_pruning = False
+ test_head_masking = False
+ test_torchscript = False
+
+ def setUp(self):
+ self.model_tester = FLAVATextModelTester(self)
+ self.config_tester = ConfigTester(self, config_class=FLAVATextConfig, hidden_size=37)
+
+ def test_config(self):
+ self.config_tester.run_common_tests()
+
+ def test_model(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_model(*config_and_inputs)
+
+ def test_training(self):
+ pass
+
+ def test_training_gradient_checkpointing(self):
+ pass
+
+ def test_inputs_embeds(self):
+ # FLAVA does not use inputs_embeds
+ pass
+
+ # skip this test as FLAVATextModel has no base class and is
+ # not available in MODEL_MAPPING
+ def test_save_load_fast_init_from_base(self):
+ pass
+
+ # skip this test as FLAVATextModel has no base class and is
+ # not available in MODEL_MAPPING
+ def test_save_load_fast_init_to_base(self):
+ pass
+
+ @slow
+ def test_model_from_pretrained(self):
+ for model_name in FLAVA_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
+ model = FLAVATextModel.from_pretrained(model_name)
+ self.assertIsNotNone(model)
+
+
+class FLAVAMultimodalModelTester:
+ def __init__(
+ self,
+ parent,
+ batch_size=12,
+ seq_length=44,
+ use_input_mask=True,
+ hidden_size=768,
+ num_hidden_layers=6,
+ num_attention_heads=12,
+ intermediate_size=3072,
+ hidden_act="gelu",
+ hidden_dropout_prob=0.0,
+ attention_probs_dropout_prob=0.0,
+ initializer_range=0.02,
+ layer_norm_eps=1e-12,
+ qkv_bias=True,
+ use_cls_token=True,
+ ):
+ self.parent = parent
+ self.batch_size = batch_size
+ self.seq_length = seq_length
+ self.use_input_mask = use_input_mask
+ self.hidden_size = hidden_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.intermediate_size = intermediate_size
+ self.hidden_act = hidden_act
+ self.hidden_dropout_prob = hidden_dropout_prob
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
+ self.initializer_range = initializer_range
+ self.layer_norm_eps = layer_norm_eps
+ self.qkv_bias = qkv_bias
+ self.use_cls_token = use_cls_token
+
+ def prepare_config_and_inputs(self):
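+        # pass seq_length - 1 hidden states: with use_cls_token=True the multimodal encoder prepends a CLS
+        # embedding, so the output length matches the full seq_length asserted in create_and_check_model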
+ hidden_states = floats_tensor([self.batch_size, self.seq_length - 1, self.hidden_size])
+
+ input_mask = None
+ if self.use_input_mask:
+ input_mask = random_attention_mask([self.batch_size, self.seq_length])
+
+ if input_mask is not None:
+ batch_size, seq_length = input_mask.shape
+ rnd_start_indices = np.random.randint(1, seq_length - 1, size=(batch_size,))
+ for batch_idx, start_index in enumerate(rnd_start_indices):
+ input_mask[batch_idx, :start_index] = 1
+ input_mask[batch_idx, start_index:] = 0
+
+ config = self.get_config()
+
+ return config, hidden_states, input_mask
+
+ def get_config(self):
+ return FLAVAMultimodalConfig(
+ hidden_size=self.hidden_size,
+ num_hidden_layers=self.num_hidden_layers,
+ num_attention_heads=self.num_attention_heads,
+ intermediate_size=self.intermediate_size,
+ hidden_act=self.hidden_act,
+ hidden_dropout_prob=self.hidden_dropout_prob,
+ attention_probs_dropout_prob=self.attention_probs_dropout_prob,
+ initializer_range=self.initializer_range,
+ layer_norm_eps=self.layer_norm_eps,
+ qkv_bias=self.qkv_bias,
+ use_cls_token=self.use_cls_token,
+ )
+
+ def create_and_check_model(self, config, hidden_states, input_mask):
+ model = FLAVAMultimodalModel(config=config)
+ model.to(torch_device)
+ model.eval()
+ with torch.no_grad():
+ result = model(hidden_states, attention_mask=input_mask)
+ result = model(hidden_states)
+ self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
+ self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size))
+
+ def prepare_config_and_inputs_for_common(self):
+ config_and_inputs = self.prepare_config_and_inputs()
+ config, hidden_states, input_mask = config_and_inputs
+ inputs_dict = {"hidden_states": hidden_states, "attention_mask": input_mask}
+ return config, inputs_dict
+
+
+@require_torch
+class FLAVAMultimodalModelTest(ModelTesterMixin, unittest.TestCase):
+
+ all_model_classes = (FLAVAMultimodalModel,) if is_torch_available() else ()
+ test_pruning = False
+ test_head_masking = False
+ test_resize_embeddings = False
+ test_torchscript = False
+
+ def setUp(self):
+ self.model_tester = FLAVAMultimodalModelTester(self)
+ self.config_tester = ConfigTester(
+ self, config_class=FLAVAMultimodalConfig, has_text_modality=False, hidden_size=37
+ )
+
+ def test_config(self):
+ self.config_tester.run_common_tests()
+
+ def test_model(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_model(*config_and_inputs)
+
+ def test_forward_signature(self):
+ config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+
+ for model_class in self.all_model_classes:
+ model = model_class(config)
+ signature = inspect.signature(model.forward)
+ # signature.parameters is an OrderedDict => so arg_names order is deterministic
+ arg_names = [*signature.parameters.keys()]
+
+ expected_arg_names = ["hidden_states"]
+ self.assertListEqual(arg_names[:1], expected_arg_names)
+
+ def test_model_common_attributes(self):
+ # No embedding in multimodal model
+ pass
+
+ def test_training(self):
+ pass
+
+ def test_training_gradient_checkpointing(self):
+ pass
+
+ def test_inputs_embeds(self):
+ # FLAVA does not use inputs_embeds
+ pass
+
+ # skip this test as FLAVAMultimodalModel has no base class and is
+ # not available in MODEL_MAPPING
+ def test_save_load_fast_init_from_base(self):
+ pass
+
+ # skip this test as FLAVAMultimodalModel has no base class and is
+ # not available in MODEL_MAPPING
+ def test_save_load_fast_init_to_base(self):
+ pass
+
+ @slow
+ def test_model_from_pretrained(self):
+ for model_name in FLAVA_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
+ model = FLAVAMultimodalModel.from_pretrained(model_name)
+ self.assertIsNotNone(model)
+
+
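+# Composite tester: reuses the text and image testers so FLAVAModel is built from matching sub-configs and inputs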
+class FLAVAModelTester:
+ def __init__(self, parent, is_training=True):
+ self.parent = parent
+ self.text_model_tester = FLAVATextModelTester(parent)
+ self.vision_model_tester = FLAVAImageModelTester(parent)
+ self.is_training = is_training
+
+ def prepare_config_and_inputs(self):
+        text_config, input_ids, token_type_ids, attention_mask = self.text_model_tester.prepare_config_and_inputs()
+ vision_config, pixel_values = self.vision_model_tester.prepare_config_and_inputs()
+
+ config = self.get_config()
+
+ return config, input_ids, token_type_ids, attention_mask, pixel_values
+
+ def get_config(self):
+ return FLAVAConfig.from_text_vision_configs(
+ self.text_model_tester.get_config(), self.vision_model_tester.get_config(), projection_dim=64
+ )
+
+    def create_and_check_model(self, config, input_ids, token_type_ids, attention_mask, pixel_values):
+ model = FLAVAModel(config).to(torch_device).eval()
+ with torch.no_grad():
+ result = model(input_ids, pixel_values, attention_mask)
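+        # contrastive logits: one row per image and one column per text (and the transpose for text-to-image)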
+ self.parent.assertEqual(
+ result.logits_per_image.shape, (self.vision_model_tester.batch_size, self.text_model_tester.batch_size)
+ )
+ self.parent.assertEqual(
+ result.logits_per_text.shape, (self.text_model_tester.batch_size, self.vision_model_tester.batch_size)
+ )
+
+ def prepare_config_and_inputs_for_common(self):
+ config_and_inputs = self.prepare_config_and_inputs()
+ config, input_ids, token_type_ids, attention_mask, pixel_values = config_and_inputs
+ inputs_dict = {
+ "input_ids": input_ids,
+ "attention_mask": attention_mask,
+ "token_type_ids": token_type_ids,
+ "pixel_values": pixel_values,
+ "return_loss": True,
+ }
+ return config, inputs_dict
+
+
+@require_torch
+class FLAVAModelTest(ModelTesterMixin, unittest.TestCase):
+ all_model_classes = (FLAVAModel,) if is_torch_available() else ()
+ test_head_masking = False
+ test_pruning = False
+ test_resize_embeddings = False
+ test_attention_outputs = False
+
+ def setUp(self):
+ self.model_tester = FLAVAModelTester(self)
+
+ def test_model(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_model(*config_and_inputs)
+
+ # hidden_states are tested in individual model tests
+ def test_hidden_states_output(self):
+ pass
+
+ # input_embeds are tested in individual model tests
+ def test_inputs_embeds(self):
+ pass
+
+ # tested in individual model tests
+ def test_retain_grad_hidden_states_attentions(self):
+ pass
+
+ # FLAVAModel does not have input/output embeddings
+ def test_model_common_attributes(self):
+ pass
+
+    # override as the `logit_scale` parameter initialization is different for FLAVA
+ def test_initialization(self):
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ configs_no_init = _config_zero_init(config)
+ for model_class in self.all_model_classes:
+ model = model_class(config=configs_no_init)
+ for name, param in model.named_parameters():
+ if param.requires_grad:
+                # check if `logit_scale` is initialized as per the original implementation
+ if name == "logit_scale":
+ self.assertAlmostEqual(
+ param.data.item(),
+ np.log(1 / 0.07),
+ delta=1e-3,
+ msg=f"Parameter {name} of model {model_class} seems not properly initialized",
+ )
+ else:
+ self.assertIn(
+ ((param.data.mean() * 1e9).round() / 1e9).item(),
+ [0.0, 1.0],
+ msg=f"Parameter {name} of model {model_class} seems not properly initialized",
+ )
+
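+    # override the common torchscript test: tracing FLAVAModel requires both input_ids and pixel_values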
+ def _create_and_check_torchscript(self, config, inputs_dict):
+ if not self.test_torchscript:
+ return
+
+ configs_no_init = _config_zero_init(config) # To be sure we have no Nan
+ configs_no_init.torchscript = True
+ configs_no_init.return_dict = False
+ for model_class in self.all_model_classes:
+ model = model_class(config=configs_no_init)
+ model.to(torch_device)
+ model.eval()
+
+ try:
+ input_ids = inputs_dict["input_ids"]
+ pixel_values = inputs_dict["pixel_values"] # FLAVA needs pixel_values
+ traced_model = torch.jit.trace(model, (input_ids, pixel_values))
+ except RuntimeError:
+ self.fail("Couldn't trace module.")
+
+ with tempfile.TemporaryDirectory() as tmp_dir_name:
+ pt_file_name = os.path.join(tmp_dir_name, "traced_model.pt")
+
+ try:
+ torch.jit.save(traced_model, pt_file_name)
+ except Exception:
+ self.fail("Couldn't save module.")
+
+ try:
+ loaded_model = torch.jit.load(pt_file_name)
+ except Exception:
+ self.fail("Couldn't load module.")
+
+ model.to(torch_device)
+ model.eval()
+
+ loaded_model.to(torch_device)
+ loaded_model.eval()
+
+ model_state_dict = model.state_dict()
+ loaded_model_state_dict = loaded_model.state_dict()
+
+ self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys()))
+
+ models_equal = True
+ for layer_name, p1 in model_state_dict.items():
+ p2 = loaded_model_state_dict[layer_name]
+ if p1.data.ne(p2.data).sum() > 0:
+ models_equal = False
+
+ self.assertTrue(models_equal)
+
+ def test_load_vision_text_config(self):
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ # Save FLAVAConfig and check if we can load FLAVAImageConfig from it
+ with tempfile.TemporaryDirectory() as tmp_dir_name:
+ config.save_pretrained(tmp_dir_name)
+ vision_config = FLAVAImageConfig.from_pretrained(tmp_dir_name)
+ self.assertDictEqual(config.vision_config.to_dict(), vision_config.to_dict())
+
+ # Save FLAVAConfig and check if we can load FLAVATextConfig from it
+ with tempfile.TemporaryDirectory() as tmp_dir_name:
+ config.save_pretrained(tmp_dir_name)
+ text_config = FLAVATextConfig.from_pretrained(tmp_dir_name)
+ self.assertDictEqual(config.text_config.to_dict(), text_config.to_dict())
+
+    # overwrite from common since FLAVAModel returns FLAVAOutput
+ @slow
+ def test_model_from_pretrained(self):
+ for model_name in FLAVA_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
+ model = FLAVAModel.from_pretrained(model_name)
+ self.assertIsNotNone(model)
+
+
+# We will verify our results on an image of cute cats
+def prepare_img():
+ url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ im = Image.open(requests.get(url, stream=True).raw)
+ return im
+
+
+@require_vision
+@require_torch
+class FLAVAModelIntegrationTest(unittest.TestCase):
+ @slow
+ def test_inference(self):
+        model_name = ""  # placeholder for the pretrained checkpoint name
+ model = FLAVAForPretraining.from_pretrained(model_name).to(torch_device)
+ processor = FLAVAProcessor.from_pretrained(model_name)
+
+ image = prepare_img()
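+        # return_masks=True asks the processor for the image patch masks used by the masked-image-modeling objective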
+ inputs = processor(
+ text=["a photo of a cat", "a photo of a dog"],
+ images=[image, image],
+ padding="max_length",
+ max_length=77,
+ return_tensors="pt",
+ return_masks=True,
+ ).to(torch_device)
+
+ # forward pass
+ with torch.no_grad():
+ outputs = model(**inputs)
+
+ # verify the logits
+ self.assertEqual(
+ outputs.contrastive_logits_per_image.shape,
+ torch.Size((inputs.pixel_values.shape[0], inputs.input_ids.shape[0])),
+ )
+ self.assertEqual(
+ outputs.contrastive_logits_per_text.shape,
+ torch.Size((inputs.input_ids.shape[0], inputs.pixel_values.shape[0])),
+ )
+
+ expected_logits = torch.tensor([[24.5701, 19.3049]], device=torch_device)
+
+        self.assertTrue(torch.allclose(outputs.contrastive_logits_per_image, expected_logits, atol=1e-3))
diff --git a/tests/flava/test_processor_flava.py b/tests/flava/test_processor_flava.py
new file mode 100644
index 00000000000000..b33d42ae5fa520
--- /dev/null
+++ b/tests/flava/test_processor_flava.py
@@ -0,0 +1,204 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import os
+import shutil
+import tempfile
+import unittest
+
+import numpy as np
+import pytest
+
+from transformers import BertTokenizer, BertTokenizerFast
+from transformers.models.bert.tokenization_bert import VOCAB_FILES_NAMES
+from transformers.testing_utils import require_vision
+from transformers.utils import FEATURE_EXTRACTOR_NAME, is_vision_available
+
+
+if is_vision_available():
+ from PIL import Image
+
+ from transformers import FLAVAFeatureExtractor, FLAVAProcessor
+ from transformers.models.flava.feature_extraction_flava import FLAVA_IMAGE_MEAN, FLAVA_IMAGE_STD
+
+
+@require_vision
+class FLAVAProcessorTest(unittest.TestCase):
+ def setUp(self):
+ self.tmpdirname = tempfile.mkdtemp()
+ vocab_tokens = [
+ "[UNK]",
+ "[CLS]",
+ "[SEP]",
+ "[PAD]",
+ "[MASK]",
+ "want",
+ "##want",
+ "##ed",
+ "wa",
+ "un",
+ "runn",
+ "##ing",
+ ",",
+ "low",
+ "lowest",
+ ]
+ self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["vocab_file"])
+
+ with open(self.vocab_file, "w", encoding="utf-8") as fp:
+ fp.write("".join([x + "\n" for x in vocab_tokens]))
+
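+        # feature extractor settings: the *_patches and mask_group_* entries configure the random patch mask
+        # used for masked image modeling during pretraining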
+ feature_extractor_map = {
+ "image_mean": FLAVA_IMAGE_MEAN,
+ "image_std": FLAVA_IMAGE_STD,
+ "do_normalize": True,
+ "do_resize": True,
+ "size": 224,
+ "do_center_crop": True,
+ "crop_size": 224,
+ "input_size_patches": 14,
+ "total_mask_patches": 75,
+ "mask_group_max_patches": None,
+ "mask_group_min_patches": 16,
+ "mask_group_min_aspect_ratio": 0.3,
+ "mask_group_max_aspect_ratio": None,
+ }
+
+ self.feature_extractor_file = os.path.join(self.tmpdirname, FEATURE_EXTRACTOR_NAME)
+ with open(self.feature_extractor_file, "w", encoding="utf-8") as fp:
+ json.dump(feature_extractor_map, fp)
+
+ def get_tokenizer(self, **kwargs):
+ return BertTokenizer.from_pretrained(self.tmpdirname, **kwargs)
+
+ def get_rust_tokenizer(self, **kwargs):
+ return BertTokenizerFast.from_pretrained(self.tmpdirname, **kwargs)
+
+ def get_feature_extractor(self, **kwargs):
+ return FLAVAFeatureExtractor.from_pretrained(self.tmpdirname, **kwargs)
+
+ def tearDown(self):
+ shutil.rmtree(self.tmpdirname)
+
+ def prepare_image_inputs(self):
+ """This function prepares a list of PIL images, or a list of numpy arrays if one specifies numpify=True,
+ or a list of PyTorch tensors if one specifies torchify=True.
+ """
+
+ image_inputs = [np.random.randint(255, size=(3, 30, 400), dtype=np.uint8)]
+
+ image_inputs = [Image.fromarray(np.moveaxis(x, 0, -1)) for x in image_inputs]
+
+ return image_inputs
+
+ def test_save_load_pretrained_default(self):
+ tokenizer_slow = self.get_tokenizer()
+ tokenizer_fast = self.get_rust_tokenizer()
+ feature_extractor = self.get_feature_extractor()
+
+ processor_slow = FLAVAProcessor(tokenizer=tokenizer_slow, feature_extractor=feature_extractor)
+ processor_slow.save_pretrained(self.tmpdirname)
+ processor_slow = FLAVAProcessor.from_pretrained(self.tmpdirname, use_fast=False)
+
+ processor_fast = FLAVAProcessor(tokenizer=tokenizer_fast, feature_extractor=feature_extractor)
+ processor_fast.save_pretrained(self.tmpdirname)
+ processor_fast = FLAVAProcessor.from_pretrained(self.tmpdirname)
+
+ self.assertEqual(processor_slow.tokenizer.get_vocab(), tokenizer_slow.get_vocab())
+ self.assertEqual(processor_fast.tokenizer.get_vocab(), tokenizer_fast.get_vocab())
+ self.assertEqual(tokenizer_slow.get_vocab(), tokenizer_fast.get_vocab())
+ self.assertIsInstance(processor_slow.tokenizer, BertTokenizer)
+ self.assertIsInstance(processor_fast.tokenizer, BertTokenizerFast)
+
+ self.assertEqual(processor_slow.feature_extractor.to_json_string(), feature_extractor.to_json_string())
+ self.assertEqual(processor_fast.feature_extractor.to_json_string(), feature_extractor.to_json_string())
+ self.assertIsInstance(processor_slow.feature_extractor, FLAVAFeatureExtractor)
+ self.assertIsInstance(processor_fast.feature_extractor, FLAVAFeatureExtractor)
+
+ def test_save_load_pretrained_additional_features(self):
+ processor = FLAVAProcessor(tokenizer=self.get_tokenizer(), feature_extractor=self.get_feature_extractor())
+ processor.save_pretrained(self.tmpdirname)
+
+ tokenizer_add_kwargs = self.get_tokenizer(bos_token="(BOS)", eos_token="(EOS)")
+ feature_extractor_add_kwargs = self.get_feature_extractor(do_normalize=False, padding_value=1.0)
+
+ processor = FLAVAProcessor.from_pretrained(
+ self.tmpdirname, bos_token="(BOS)", eos_token="(EOS)", do_normalize=False, padding_value=1.0
+ )
+
+ self.assertEqual(processor.tokenizer.get_vocab(), tokenizer_add_kwargs.get_vocab())
+ self.assertIsInstance(processor.tokenizer, BertTokenizerFast)
+
+ self.assertEqual(processor.feature_extractor.to_json_string(), feature_extractor_add_kwargs.to_json_string())
+ self.assertIsInstance(processor.feature_extractor, FLAVAFeatureExtractor)
+
+ def test_feature_extractor(self):
+ feature_extractor = self.get_feature_extractor()
+ tokenizer = self.get_tokenizer()
+
+ processor = FLAVAProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)
+
+ image_input = self.prepare_image_inputs()
+
+ input_feat_extract = feature_extractor(image_input, return_tensors="np")
+ input_processor = processor(images=image_input, return_tensors="np")
+
+ for key in input_feat_extract.keys():
+ self.assertAlmostEqual(input_feat_extract[key].sum(), input_processor[key].sum(), delta=1e-2)
+
+ def test_tokenizer(self):
+ feature_extractor = self.get_feature_extractor()
+ tokenizer = self.get_tokenizer()
+
+ processor = FLAVAProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)
+
+ input_str = "lower newer"
+
+ encoded_processor = processor(text=input_str)
+
+ encoded_tok = tokenizer(input_str)
+
+ for key in encoded_tok.keys():
+ self.assertListEqual(encoded_tok[key], encoded_processor[key])
+
+ def test_processor(self):
+ feature_extractor = self.get_feature_extractor()
+ tokenizer = self.get_tokenizer()
+
+ processor = FLAVAProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)
+
+ input_str = "lower newer"
+ image_input = self.prepare_image_inputs()
+
+ inputs = processor(text=input_str, images=image_input)
+
+ self.assertListEqual(list(inputs.keys()), ["input_ids", "token_type_ids", "attention_mask", "pixel_values"])
+
+ # test if it raises when no input is passed
+ with pytest.raises(ValueError):
+ processor()
+
+ def test_tokenizer_decode(self):
+ feature_extractor = self.get_feature_extractor()
+ tokenizer = self.get_tokenizer()
+
+ processor = FLAVAProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)
+
+ predicted_ids = [[1, 4, 5, 8, 1, 0, 8], [3, 4, 3, 1, 1, 8, 9]]
+
+ decoded_processor = processor.batch_decode(predicted_ids)
+ decoded_tok = tokenizer.batch_decode(predicted_ids)
+
+ self.assertListEqual(decoded_tok, decoded_processor)