diff --git a/README.md b/README.md index cbc2e33ee7d42..105d2b03c833d 100644 --- a/README.md +++ b/README.md @@ -265,7 +265,7 @@ Min, Patrick Lewis, Ledell Wu, Sergey Edunov, Danqi Chen, and Wen-tau Yih. 1. **[EncoderDecoder](https://huggingface.co/docs/transformers/model_doc/encoder-decoder)** (from Google Research) released with the paper [Leveraging Pre-trained Checkpoints for Sequence Generation Tasks](https://arxiv.org/abs/1907.12461) by Sascha Rothe, Shashi Narayan, Aliaksei Severyn. 1. **[ELECTRA](https://huggingface.co/docs/transformers/model_doc/electra)** (from Google Research/Stanford University) released with the paper [ELECTRA: Pre-training text encoders as discriminators rather than generators](https://arxiv.org/abs/2003.10555) by Kevin Clark, Minh-Thang Luong, Quoc V. Le, Christopher D. Manning. 1. **[FlauBERT](https://huggingface.co/docs/transformers/model_doc/flaubert)** (from CNRS) released with the paper [FlauBERT: Unsupervised Language Model Pre-training for French](https://arxiv.org/abs/1912.05372) by Hang Le, Loïc Vial, Jibril Frej, Vincent Segonne, Maximin Coavoux, Benjamin Lecouteux, Alexandre Allauzen, Benoît Crabbé, Laurent Besacier, Didier Schwab. -1. **[FLAVA](https://huggingface.co/docs/transformers/model_doc_flava)** (from Facebook AI) released with the paper [FLAVA: A Foundational Language And Vision Alignment Model](https://arxiv.org/abs/2112.04482) by Amanpreet Singh, Ronghang Hu, Vedanuj Goswami, Guillaume Couairon, Wojciech Galuba, Marcus Rohrbach, and Douwe Kiela. +1. **[FLAVA](https://huggingface.co/docs/transformers/main/model_doc/flava)** (from Facebook AI) released with the paper [FLAVA: A Foundational Language And Vision Alignment Model](https://arxiv.org/abs/2112.04482) by Amanpreet Singh, Ronghang Hu, Vedanuj Goswami, Guillaume Couairon, Wojciech Galuba, Marcus Rohrbach, and Douwe Kiela. 1. **[FNet](https://huggingface.co/docs/transformers/model_doc/fnet)** (from Google Research) released with the paper [FNet: Mixing Tokens with Fourier Transforms](https://arxiv.org/abs/2105.03824) by James Lee-Thorp, Joshua Ainslie, Ilya Eckstein, Santiago Ontanon. 1. **[Funnel Transformer](https://huggingface.co/docs/transformers/model_doc/funnel)** (from CMU/Google Brain) released with the paper [Funnel-Transformer: Filtering out Sequential Redundancy for Efficient Language Processing](https://arxiv.org/abs/2006.03236) by Zihang Dai, Guokun Lai, Yiming Yang, Quoc V. Le. 1. **[GLPN](https://huggingface.co/docs/transformers/main/model_doc/glpn)** (from KAIST) released with the paper [Global-Local Path Networks for Monocular Depth Estimation with Vertical CutDepth](https://arxiv.org/abs/2201.07436) by Doyeon Kim, Woonghyun Ga, Pyungwhan Ahn, Donggyu Joo, Sehwan Chun, Junmo Kim. diff --git a/README_ko.md b/README_ko.md index bd518e8c10ef8..6899b3e7f5328 100644 --- a/README_ko.md +++ b/README_ko.md @@ -244,7 +244,7 @@ Flax, PyTorch, TensorFlow 설치 페이지에서 이들을 conda로 설치하는 1. **[ELECTRA](https://huggingface.co/docs/transformers/model_doc/electra)** (from Google Research/Stanford University) released with the paper [ELECTRA: Pre-training text encoders as discriminators rather than generators](https://arxiv.org/abs/2003.10555) by Kevin Clark, Minh-Thang Luong, Quoc V. Le, Christopher D. Manning. 1. 
**[EncoderDecoder](https://huggingface.co/docs/transformers/model_doc/encoder-decoder)** (from Google Research) released with the paper [Leveraging Pre-trained Checkpoints for Sequence Generation Tasks](https://arxiv.org/abs/1907.12461) by Sascha Rothe, Shashi Narayan, Aliaksei Severyn. 1. **[FlauBERT](https://huggingface.co/docs/transformers/model_doc/flaubert)** (from CNRS) released with the paper [FlauBERT: Unsupervised Language Model Pre-training for French](https://arxiv.org/abs/1912.05372) by Hang Le, Loïc Vial, Jibril Frej, Vincent Segonne, Maximin Coavoux, Benjamin Lecouteux, Alexandre Allauzen, Benoît Crabbé, Laurent Besacier, Didier Schwab. -1. **[FLAVA](https://huggingface.co/docs/transformers/model_doc_flava)** (from Facebook AI) released with the paper [FLAVA: A Foundational Language And Vision Alignment Model](https://arxiv.org/abs/2112.04482) by Amanpreet Singh, Ronghang Hu, Vedanuj Goswami, Guillaume Couairon, Wojciech Galuba, Marcus Rohrbach, and Douwe Kiela. +1. **[FLAVA](https://huggingface.co/docs/transformers/main/model_doc/flava)** (from Facebook AI) released with the paper [FLAVA: A Foundational Language And Vision Alignment Model](https://arxiv.org/abs/2112.04482) by Amanpreet Singh, Ronghang Hu, Vedanuj Goswami, Guillaume Couairon, Wojciech Galuba, Marcus Rohrbach, and Douwe Kiela. 1. **[FNet](https://huggingface.co/docs/transformers/model_doc/fnet)** (from Google Research) released with the paper [FNet: Mixing Tokens with Fourier Transforms](https://arxiv.org/abs/2105.03824) by James Lee-Thorp, Joshua Ainslie, Ilya Eckstein, Santiago Ontanon. 1. **[Funnel Transformer](https://huggingface.co/docs/transformers/model_doc/funnel)** (from CMU/Google Brain) released with the paper [Funnel-Transformer: Filtering out Sequential Redundancy for Efficient Language Processing](https://arxiv.org/abs/2006.03236) by Zihang Dai, Guokun Lai, Yiming Yang, Quoc V. Le. 1. **[GLPN](https://huggingface.co/docs/transformers/main/model_doc/glpn)** (from KAIST) released with the paper [Global-Local Path Networks for Monocular Depth Estimation with Vertical CutDepth](https://arxiv.org/abs/2201.07436) by Doyeon Kim, Woonghyun Ga, Pyungwhan Ahn, Donggyu Joo, Sehwan Chun, Junmo Kim. diff --git a/README_zh-hans.md b/README_zh-hans.md index 2abb8b5c0759d..1c6c1928acc8d 100644 --- a/README_zh-hans.md +++ b/README_zh-hans.md @@ -268,7 +268,7 @@ conda install -c huggingface transformers 1. **[ELECTRA](https://huggingface.co/docs/transformers/model_doc/electra)** (来自 Google Research/Stanford University) 伴随论文 [ELECTRA: Pre-training text encoders as discriminators rather than generators](https://arxiv.org/abs/2003.10555) 由 Kevin Clark, Minh-Thang Luong, Quoc V. Le, Christopher D. Manning 发布。 1. **[EncoderDecoder](https://huggingface.co/docs/transformers/model_doc/encoder-decoder)** (来自 Google Research) 伴随论文 [Leveraging Pre-trained Checkpoints for Sequence Generation Tasks](https://arxiv.org/abs/1907.12461) 由 Sascha Rothe, Shashi Narayan, Aliaksei Severyn 发布。 1. **[FlauBERT](https://huggingface.co/docs/transformers/model_doc/flaubert)** (来自 CNRS) 伴随论文 [FlauBERT: Unsupervised Language Model Pre-training for French](https://arxiv.org/abs/1912.05372) 由 Hang Le, Loïc Vial, Jibril Frej, Vincent Segonne, Maximin Coavoux, Benjamin Lecouteux, Alexandre Allauzen, Benoît Crabbé, Laurent Besacier, Didier Schwab 发布。 -1. 
**[FLAVA](https://huggingface.co/docs/transformers/model_doc_flava)** (来自 Facebook AI) 伴随论文 [FLAVA: A Foundational Language And Vision Alignment Model](https://arxiv.org/abs/2112.04482) 由 Amanpreet Singh, Ronghang Hu, Vedanuj Goswami, Guillaume Couairon, Wojciech Galuba, Marcus Rohrbach, and Douwe Kiela 发布。 +1. **[FLAVA](https://huggingface.co/docs/transformers/main/model_doc/flava)** (来自 Facebook AI) 伴随论文 [FLAVA: A Foundational Language And Vision Alignment Model](https://arxiv.org/abs/2112.04482) 由 Amanpreet Singh, Ronghang Hu, Vedanuj Goswami, Guillaume Couairon, Wojciech Galuba, Marcus Rohrbach, and Douwe Kiela 发布。 1. **[FNet](https://huggingface.co/docs/transformers/model_doc/fnet)** (来自 Google Research) 伴随论文 [FNet: Mixing Tokens with Fourier Transforms](https://arxiv.org/abs/2105.03824) 由 James Lee-Thorp, Joshua Ainslie, Ilya Eckstein, Santiago Ontanon 发布。 1. **[Funnel Transformer](https://huggingface.co/docs/transformers/model_doc/funnel)** (来自 CMU/Google Brain) 伴随论文 [Funnel-Transformer: Filtering out Sequential Redundancy for Efficient Language Processing](https://arxiv.org/abs/2006.03236) 由 Zihang Dai, Guokun Lai, Yiming Yang, Quoc V. Le 发布。 1. **[GLPN](https://huggingface.co/docs/transformers/main/model_doc/glpn)** (来自 KAIST) 伴随论文 [Global-Local Path Networks for Monocular Depth Estimation with Vertical CutDepth](https://arxiv.org/abs/2201.07436) 由 Doyeon Kim, Woonghyun Ga, Pyungwhan Ahn, Donggyu Joo, Sehwan Chun, Junmo Kim 发布。 diff --git a/README_zh-hant.md b/README_zh-hant.md index f5e9dfefdaff8..5b38ef8dccb0f 100644 --- a/README_zh-hant.md +++ b/README_zh-hant.md @@ -280,7 +280,7 @@ conda install -c huggingface transformers 1. **[ELECTRA](https://huggingface.co/docs/transformers/model_doc/electra)** (from Google Research/Stanford University) released with the paper [ELECTRA: Pre-training text encoders as discriminators rather than generators](https://arxiv.org/abs/2003.10555) by Kevin Clark, Minh-Thang Luong, Quoc V. Le, Christopher D. Manning. 1. **[EncoderDecoder](https://huggingface.co/docs/transformers/model_doc/encoder-decoder)** (from Google Research) released with the paper [Leveraging Pre-trained Checkpoints for Sequence Generation Tasks](https://arxiv.org/abs/1907.12461) by Sascha Rothe, Shashi Narayan, Aliaksei Severyn. 1. **[FlauBERT](https://huggingface.co/docs/transformers/model_doc/flaubert)** (from CNRS) released with the paper [FlauBERT: Unsupervised Language Model Pre-training for French](https://arxiv.org/abs/1912.05372) by Hang Le, Loïc Vial, Jibril Frej, Vincent Segonne, Maximin Coavoux, Benjamin Lecouteux, Alexandre Allauzen, Benoît Crabbé, Laurent Besacier, Didier Schwab. -1. **[FLAVA](https://huggingface.co/docs/transformers/model_doc_flava)** (from Facebook AI) released with the paper [FLAVA: A Foundational Language And Vision Alignment Model](https://arxiv.org/abs/2112.04482) by Amanpreet Singh, Ronghang Hu, Vedanuj Goswami, Guillaume Couairon, Wojciech Galuba, Marcus Rohrbach, and Douwe Kiela. +1. **[FLAVA](https://huggingface.co/docs/transformers/main/model_doc/flava)** (from Facebook AI) released with the paper [FLAVA: A Foundational Language And Vision Alignment Model](https://arxiv.org/abs/2112.04482) by Amanpreet Singh, Ronghang Hu, Vedanuj Goswami, Guillaume Couairon, Wojciech Galuba, Marcus Rohrbach, and Douwe Kiela. 1. 
**[FNet](https://huggingface.co/docs/transformers/model_doc/fnet)** (from Google Research) released with the paper [FNet: Mixing Tokens with Fourier Transforms](https://arxiv.org/abs/2105.03824) by James Lee-Thorp, Joshua Ainslie, Ilya Eckstein, Santiago Ontanon. 1. **[Funnel Transformer](https://huggingface.co/docs/transformers/model_doc/funnel)** (from CMU/Google Brain) released with the paper [Funnel-Transformer: Filtering out Sequential Redundancy for Efficient Language Processing](https://arxiv.org/abs/2006.03236) by Zihang Dai, Guokun Lai, Yiming Yang, Quoc V. Le. 1. **[GLPN](https://huggingface.co/docs/transformers/main/model_doc/glpn)** (from KAIST) released with the paper [Global-Local Path Networks for Monocular Depth Estimation with Vertical CutDepth](https://arxiv.org/abs/2201.07436) by Doyeon Kim, Woonghyun Ga, Pyungwhan Ahn, Donggyu Joo, Sehwan Chun, Junmo Kim. diff --git a/docs/source/en/index.mdx b/docs/source/en/index.mdx index 278d90adfff4d..457e798ae21aa 100644 --- a/docs/source/en/index.mdx +++ b/docs/source/en/index.mdx @@ -86,7 +86,7 @@ The library currently contains JAX, PyTorch and TensorFlow implementations, pret 1. **[EncoderDecoder](model_doc/encoder-decoder)** (from Google Research) released with the paper [Leveraging Pre-trained Checkpoints for Sequence Generation Tasks](https://arxiv.org/abs/1907.12461) by Sascha Rothe, Shashi Narayan, Aliaksei Severyn. 1. **[ELECTRA](model_doc/electra)** (from Google Research/Stanford University) released with the paper [ELECTRA: Pre-training text encoders as discriminators rather than generators](https://arxiv.org/abs/2003.10555) by Kevin Clark, Minh-Thang Luong, Quoc V. Le, Christopher D. Manning. 1. **[FlauBERT](model_doc/flaubert)** (from CNRS) released with the paper [FlauBERT: Unsupervised Language Model Pre-training for French](https://arxiv.org/abs/1912.05372) by Hang Le, Loïc Vial, Jibril Frej, Vincent Segonne, Maximin Coavoux, Benjamin Lecouteux, Alexandre Allauzen, Benoît Crabbé, Laurent Besacier, Didier Schwab. -1. **[FLAVA](model_doc_flava)** (from Facebook AI) released with the paper [FLAVA: A Foundational Language And Vision Alignment Model](https://arxiv.org/abs/2112.04482) by Amanpreet Singh, Ronghang Hu, Vedanuj Goswami, Guillaume Couairon, Wojciech Galuba, Marcus Rohrbach, and Douwe Kiela. +1. **[FLAVA](model_doc/flava)** (from Facebook AI) released with the paper [FLAVA: A Foundational Language And Vision Alignment Model](https://arxiv.org/abs/2112.04482) by Amanpreet Singh, Ronghang Hu, Vedanuj Goswami, Guillaume Couairon, Wojciech Galuba, Marcus Rohrbach, and Douwe Kiela. 1. **[FNet](model_doc/fnet)** (from Google Research) released with the paper [FNet: Mixing Tokens with Fourier Transforms](https://arxiv.org/abs/2105.03824) by James Lee-Thorp, Joshua Ainslie, Ilya Eckstein, Santiago Ontanon. 1. **[Funnel Transformer](model_doc/funnel)** (from CMU/Google Brain) released with the paper [Funnel-Transformer: Filtering out Sequential Redundancy for Efficient Language Processing](https://arxiv.org/abs/2006.03236) by Zihang Dai, Guokun Lai, Yiming Yang, Quoc V. Le. 1. **[GLPN](model_doc/glpn)** (from KAIST) released with the paper [Global-Local Path Networks for Monocular Depth Estimation with Vertical CutDepth](https://arxiv.org/abs/2201.07436) by Doyeon Kim, Woonghyun Ga, Pyungwhan Ahn, Donggyu Joo, Sehwan Chun, Junmo Kim. @@ -205,7 +205,7 @@ Flax), PyTorch, and/or TensorFlow. 
| Encoder decoder | ❌ | ❌ | ✅ | ✅ | ✅ | | FairSeq Machine-Translation | ✅ | ❌ | ✅ | ❌ | ❌ | | FlauBERT | ✅ | ❌ | ✅ | ✅ | ❌ | -| FLAVA | ❌ | ❌ | ✅ | ❌ | ❌ | +| Flava | ❌ | ❌ | ✅ | ❌ | ❌ | | FNet | ✅ | ✅ | ✅ | ❌ | ❌ | | Funnel Transformer | ✅ | ✅ | ✅ | ✅ | ❌ | | GLPN | ❌ | ❌ | ✅ | ❌ | ❌ | diff --git a/docs/source/en/model_doc/flava.mdx b/docs/source/en/model_doc/flava.mdx index b65e83b91c492..531268c85295d 100644 --- a/docs/source/en/model_doc/flava.mdx +++ b/docs/source/en/model_doc/flava.mdx @@ -1,4 +1,4 @@ - +This model was contributed by [aps](https://huggingface.co/aps). The original code can be found [here](https://github.com/facebookresearch/multimodal/tree/main/examples/flava). -This model was contributed by [aps](https://huggingface.co/aps). - +## FlavaConfig -## FLAVAConfig +[[autodoc]] FlavaConfig -[[autodoc]] FLAVAConfig - - from_configs +## FlavaTextConfig -## FLAVATextConfig +[[autodoc]] FlavaTextConfig -[[autodoc]] FLAVATextConfig +## FlavaImageConfig -## FLAVAImageConfig +[[autodoc]] FlavaImageConfig -[[autodoc]] FLAVAImageConfig +## FlavaMultimodalConfig -## FLAVAMultimodalConfig +[[autodoc]] FlavaMultimodalConfig -[[autodoc]] FLAVAMultimodalConfig +## FlavaImageCodebookConfig -## FLAVACodebookConfig +[[autodoc]] FlavaImageCodebookConfig -[[autodoc]] FLAVACodebookConfig +## FlavaProcessor -## FLAVAProcessor +[[autodoc]] FlavaProcessor -[[autodoc]] FLAVAProcessor +## FlavaFeatureExtractor -## FLAVAFeatureExtractor +[[autodoc]] FlavaFeatureExtractor -[[autodoc]] FLAVAFeatureExtractor +## FlavaForPreTraining -## FLAVACodebookFeatureExtractor - -[[autodoc]] FLAVACodebookFeatureExtractor - -## FLAVAForPreTraining - -[[autodoc]] FLAVAForPreTraining +[[autodoc]] FlavaForPreTraining - forward -## FLAVAModel +## FlavaModel -[[autodoc]] FLAVAModel +[[autodoc]] FlavaModel - forward - get_text_features - get_image_features -## FLAVACodebook +## FlavaImageCodebook -[[autodoc]] FLAVACodebook +[[autodoc]] FlavaImageCodebook - forward - get_codebook_indices - get_codebook_probs -## FLAVATextModel +## FlavaTextModel -[[autodoc]] FLAVATextModel +[[autodoc]] FlavaTextModel - forward -## FLAVAImageModel +## FlavaImageModel -[[autodoc]] FLAVAImageModel +[[autodoc]] FlavaImageModel - forward -## FLAVAMultimodalModel +## FlavaMultimodalModel -[[autodoc]] FLAVAMultimodalModel +[[autodoc]] FlavaMultimodalModel - forward diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index fd82a56bc9927..ff009fb9060c7 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -200,11 +200,11 @@ "models.flaubert": ["FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "FlaubertConfig", "FlaubertTokenizer"], "models.flava": [ "FLAVA_PRETRAINED_CONFIG_ARCHIVE_MAP", - "FLAVACodebookConfig", - "FLAVAConfig", - "FLAVAImageConfig", - "FLAVAMultimodalConfig", - "FLAVATextConfig", + "FlavaConfig", + "FlavaImageCodebookConfig", + "FlavaImageConfig", + "FlavaMultimodalConfig", + "FlavaTextConfig", ], "models.fnet": ["FNET_PRETRAINED_CONFIG_ARCHIVE_MAP", "FNetConfig"], "models.fsmt": ["FSMT_PRETRAINED_CONFIG_ARCHIVE_MAP", "FSMTConfig", "FSMTTokenizer"], @@ -576,9 +576,7 @@ _import_structure["models.deit"].append("DeiTFeatureExtractor") _import_structure["models.detr"].append("DetrFeatureExtractor") _import_structure["models.dpt"].append("DPTFeatureExtractor") - _import_structure["models.flava"].append("FLAVAFeatureExtractor") - _import_structure["models.flava"].append("FLAVACodebookFeatureExtractor") - _import_structure["models.flava"].append("FLAVAProcessor") + 
_import_structure["models.flava"].extend(["FlavaFeatureExtractor", "FlavaProcessor"]) _import_structure["models.glpn"].append("GLPNFeatureExtractor") _import_structure["models.imagegpt"].append("ImageGPTFeatureExtractor") _import_structure["models.layoutlmv2"].append("LayoutLMv2FeatureExtractor") @@ -1051,13 +1049,13 @@ _import_structure["models.flava"].extend( [ "FLAVA_PRETRAINED_MODEL_ARCHIVE_LIST", - "FLAVACodebook", - "FLAVAForPreTraining", - "FLAVAImageModel", - "FLAVAModel", - "FLAVAMultimodalModel", - "FLAVAPreTrainedModel", - "FLAVATextModel", + "FlavaForPreTraining", + "FlavaImageCodebook", + "FlavaImageModel", + "FlavaModel", + "FlavaMultimodalModel", + "FlavaPreTrainedModel", + "FlavaTextModel", ] ) _import_structure["models.fnet"].extend( @@ -2678,11 +2676,11 @@ from .models.flaubert import FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, FlaubertConfig, FlaubertTokenizer from .models.flava import ( FLAVA_PRETRAINED_CONFIG_ARCHIVE_MAP, - FLAVACodebookConfig, - FLAVAConfig, - FLAVAImageConfig, - FLAVAMultimodalConfig, - FLAVATextConfig, + FlavaConfig, + FlavaImageCodebookConfig, + FlavaImageConfig, + FlavaMultimodalConfig, + FlavaTextConfig, ) from .models.fnet import FNET_PRETRAINED_CONFIG_ARCHIVE_MAP, FNetConfig from .models.fsmt import FSMT_PRETRAINED_CONFIG_ARCHIVE_MAP, FSMTConfig, FSMTTokenizer @@ -3004,7 +3002,7 @@ from .models.deit import DeiTFeatureExtractor from .models.detr import DetrFeatureExtractor from .models.dpt import DPTFeatureExtractor - from .models.flava import FLAVACodebookFeatureExtractor, FLAVAFeatureExtractor, FLAVAProcessor + from .models.flava import FlavaFeatureExtractor, FlavaProcessor from .models.glpn import GLPNFeatureExtractor from .models.imagegpt import ImageGPTFeatureExtractor from .models.layoutlmv2 import LayoutLMv2FeatureExtractor, LayoutLMv2Processor @@ -3404,13 +3402,13 @@ ) from .models.flava import ( FLAVA_PRETRAINED_MODEL_ARCHIVE_LIST, - FLAVACodebook, - FLAVAForPreTraining, - FLAVAImageModel, - FLAVAModel, - FLAVAMultimodalModel, - FLAVAPreTrainedModel, - FLAVATextModel, + FlavaForPreTraining, + FlavaImageCodebook, + FlavaImageModel, + FlavaModel, + FlavaMultimodalModel, + FlavaPreTrainedModel, + FlavaTextModel, ) from .models.fnet import ( FNET_PRETRAINED_MODEL_ARCHIVE_LIST, diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 0a64dd8814d06..0f70406ef3eee 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -66,7 +66,7 @@ ("canine", "CanineConfig"), ("roformer", "RoFormerConfig"), ("clip", "CLIPConfig"), - ("flava", "FLAVAConfig"), + ("flava", "FlavaConfig"), ("bigbird_pegasus", "BigBirdPegasusConfig"), ("deit", "DeiTConfig"), ("luke", "LukeConfig"), @@ -270,7 +270,7 @@ ("canine", "Canine"), ("roformer", "RoFormer"), ("clip", "CLIP"), - ("flava", "FLAVA"), + ("flava", "Flava"), ("bigbird_pegasus", "BigBirdPegasus"), ("deit", "DeiT"), ("luke", "LUKE"), diff --git a/src/transformers/models/auto/feature_extraction_auto.py b/src/transformers/models/auto/feature_extraction_auto.py index d0aa5c4c5c49e..233f4ff6c9f7c 100644 --- a/src/transformers/models/auto/feature_extraction_auto.py +++ b/src/transformers/models/auto/feature_extraction_auto.py @@ -47,7 +47,7 @@ ("detr", "DetrFeatureExtractor"), ("layoutlmv2", "LayoutLMv2FeatureExtractor"), ("clip", "CLIPFeatureExtractor"), - ("flava", "FLAVAFeatureExtractor"), + ("flava", "FlavaFeatureExtractor"), ("perceiver", "PerceiverFeatureExtractor"), ("swin", 
"ViTFeatureExtractor"), ("vit_mae", "ViTFeatureExtractor"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 4cfe6a3edd80c..82e2dd2b83c77 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -62,7 +62,7 @@ ("canine", "CanineModel"), ("roformer", "RoFormerModel"), ("clip", "CLIPModel"), - ("flava", "FLAVAModel"), + ("flava", "FlavaModel"), ("bigbird_pegasus", "BigBirdPegasusModel"), ("deit", "DeiTModel"), ("luke", "LukeModel"), @@ -132,7 +132,7 @@ MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict( [ # Model for pre-training mapping - ("flava", "FLAVAForPreTraining"), + ("flava", "FlavaForPreTraining"), ("vit_mae", "ViTMAEForPreTraining"), ("fnet", "FNetForPreTraining"), ("visual_bert", "VisualBertForPreTraining"), diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py index 84cf6b8a7b6e4..0c0059c7c6575 100644 --- a/src/transformers/models/auto/tokenization_auto.py +++ b/src/transformers/models/auto/tokenization_auto.py @@ -226,13 +226,6 @@ "CLIPTokenizerFast" if is_tokenizers_available() else None, ), ), - # ( - # "flava", - # ( - # "CLIPTokenizer", - # "CLIPTokenizerFast" if is_tokenizers_available() else None, - # ), - # ), ("wav2vec2_phoneme", ("Wav2Vec2PhonemeCTCTokenizer", None)), ( "perceiver", diff --git a/src/transformers/models/flava/__init__.py b/src/transformers/models/flava/__init__.py index abf4f70f12bc0..29d8240032a43 100644 --- a/src/transformers/models/flava/__init__.py +++ b/src/transformers/models/flava/__init__.py @@ -2,7 +2,7 @@ # There's no way to ignore "F401 '...' imported but unused" warnings in this # module, but to preserve other warnings. So, don't check this module at all. -# Copyright 2021 Meta Platforms authors and The HuggingFace Team. All rights reserved. +# Copyright 2022 Meta Platforms authors and The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -17,60 +17,80 @@ # limitations under the License. 
from typing import TYPE_CHECKING -from ...utils import _LazyModule, is_torch_available, is_vision_available +from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available _import_structure = { "configuration_flava": [ "FLAVA_PRETRAINED_CONFIG_ARCHIVE_MAP", - "FLAVACodebookConfig", - "FLAVAConfig", - "FLAVAImageConfig", - "FLAVAMultimodalConfig", - "FLAVATextConfig", + "FlavaConfig", + "FlavaImageCodebookConfig", + "FlavaImageConfig", + "FlavaMultimodalConfig", + "FlavaTextConfig", ], } -if is_vision_available(): - _import_structure["feature_extraction_flava"] = ["FLAVACodebookFeatureExtractor", "FLAVAFeatureExtractor"] - _import_structure["processing_flava"] = ["FLAVAProcessor"] +try: + if not is_vision_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["feature_extraction_flava"] = ["FlavaFeatureExtractor"] + _import_structure["processing_flava"] = ["FlavaProcessor"] -if is_torch_available(): +try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: _import_structure["modeling_flava"] = [ "FLAVA_PRETRAINED_MODEL_ARCHIVE_LIST", - "FLAVACodebook", - "FLAVAForPreTraining", - "FLAVAImageModel", - "FLAVAModel", - "FLAVAMultimodalModel", - "FLAVAPreTrainedModel", - "FLAVATextModel", + "FlavaForPreTraining", + "FlavaImageCodebook", + "FlavaImageModel", + "FlavaModel", + "FlavaMultimodalModel", + "FlavaPreTrainedModel", + "FlavaTextModel", ] if TYPE_CHECKING: from .configuration_flava import ( FLAVA_PRETRAINED_CONFIG_ARCHIVE_MAP, - FLAVACodebookConfig, - FLAVAConfig, - FLAVAImageConfig, - FLAVAMultimodalConfig, - FLAVATextConfig, + FlavaConfig, + FlavaImageCodebookConfig, + FlavaImageConfig, + FlavaMultimodalConfig, + FlavaTextConfig, ) - if is_vision_available(): - from .feature_extraction_flava import FLAVACodebookFeatureExtractor, FLAVAFeatureExtractor - from .processing_flava import FLAVAProcessor + try: + if not is_vision_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .feature_extraction_flava import FlavaFeatureExtractor + from .processing_flava import FlavaProcessor - if is_torch_available(): + try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: from .modeling_flava import ( FLAVA_PRETRAINED_MODEL_ARCHIVE_LIST, - FLAVACodebook, - FLAVAForPreTraining, - FLAVAImageModel, - FLAVAModel, - FLAVAMultimodalModel, - FLAVAPreTrainedModel, - FLAVATextModel, + FlavaForPreTraining, + FlavaImageCodebook, + FlavaImageModel, + FlavaModel, + FlavaMultimodalModel, + FlavaPreTrainedModel, + FlavaTextModel, ) else: diff --git a/src/transformers/models/flava/configuration_flava.py b/src/transformers/models/flava/configuration_flava.py index 4d881d97949aa..c42c90086406b 100644 --- a/src/transformers/models/flava/configuration_flava.py +++ b/src/transformers/models/flava/configuration_flava.py @@ -1,5 +1,5 @@ # coding=utf-8 -# Copyright 2021 Meta Platforms authors and The HuggingFace Team. All rights reserved. +# Copyright 2022 Meta Platforms authors and The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,11 +12,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. -""" FLAVA model configuration""" +""" FLAVA model configurations""" import copy import os -from typing import Union +from typing import Any, Dict, Union from ...configuration_utils import PretrainedConfig from ...utils import logging @@ -25,16 +25,17 @@ logger = logging.get_logger(__name__) FLAVA_PRETRAINED_CONFIG_ARCHIVE_MAP = { - "flava-full": "https://huggingface.co/aps/flava-full/resolve/main/config.json", + "facebook/flava-full": "https://huggingface.co/facebook/flava-full/resolve/main/config.json", } -class FLAVAImageConfig(PretrainedConfig): +class FlavaImageConfig(PretrainedConfig): r""" - This is the configuration class to store the configuration of a [`FLAVAImageModel`]. It is used to instantiate an - FLAVA model according to the specified arguments, defining the model architecture. Instantiating a configuration - with the defaults will yield a similar configuration to that of the FLAVA - [full](https://huggingface.co/aps/flava-full) architecture. + This is the configuration class to store the configuration of a [`FlavaImageModel`]. It is used to instantiate an + FLAVA model according to the specified arguments, defining the model architecture. + + Instantiating a configuration with the defaults will yield a similar configuration to that of the FLAVA + [facebook/flava-full](https://huggingface.co/facebook/flava-full) architecture. Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the documentation from [`PretrainedConfig`] for more information. @@ -60,29 +61,30 @@ class FLAVAImageConfig(PretrainedConfig): The standard deviation of the truncated_normal_initializer for initializing all weight matrices. layer_norm_eps (`float`, *optional*, defaults to 1e-12): The epsilon used by the layer normalization layers. - image_size (`int`, *optional*, defaults to `224`): + image_size (`int`, *optional*, defaults to 224): The size (resolution) of each image. - patch_size (`int`, *optional*, defaults to `16`): + patch_size (`int`, *optional*, defaults to 16): The size (resolution) of each patch. - num_channels (`int`, *optional*, defaults to `3`): + num_channels (`int`, *optional*, defaults to 3): The number of input channels. qkv_bias (`bool`, *optional*, defaults to `True`): Whether to add a bias to the queries, keys and values. - mask_token (`bool`, *optional*, defaults to True): - Whether to use a mask token or not. Used in MIM loss. + mask_token (`bool`, *optional*, defaults to `True`): + Whether to use a mask token or not. Used in MIM (Masked Image Modeling) loss for FLAVA. vocab_size (`int`, *optional*, defaults to 8192): - Vocabulary size of the [`FLAVACodebook`] used in conjunction with [`FLAVAImageModel`] for MIM. + Vocabulary size of the [`FlavaImageCodebook`] used in conjunction with [`FlavaImageModel`] for MIM (Masked + Image Modeling) loss for FLAVA. 
Example: ```python - >>> from transformers import FLAVAImageModel, FLAVAImageConfig + >>> from transformers import FlavaImageModel, FlavaImageConfig - >>> # Initializing a FLAVAImageModel with style configuration - >>> configuration = FLAVAImageConfig() + >>> # Initializing a FlavaImageModel with style configuration + >>> configuration = FlavaImageConfig() - >>> # Initializing a FLAVAImageModel model from the style configuration - >>> model = FLAVAImageModel(configuration) + >>> # Initializing a FlavaImageModel model from the style configuration + >>> model = FlavaImageModel(configuration) >>> # Accessing the model configuration >>> configuration = model.config @@ -92,21 +94,21 @@ class FLAVAImageConfig(PretrainedConfig): def __init__( self, - hidden_size=768, - num_hidden_layers=12, - num_attention_heads=12, - intermediate_size=3072, - hidden_act="gelu", - hidden_dropout_prob=0.0, - attention_probs_dropout_prob=0.0, - initializer_range=0.02, - layer_norm_eps=1e-12, - image_size=224, - patch_size=16, - num_channels=3, - qkv_bias=True, - mask_token=True, - vocab_size=8192, + hidden_size: int = 768, + num_hidden_layers: int = 12, + num_attention_heads: int = 12, + intermediate_size: int = 3072, + hidden_act: int = "gelu", + hidden_dropout_prob: float = 0.0, + attention_probs_dropout_prob: float = 0.0, + initializer_range: float = 0.02, + layer_norm_eps: float = 1e-12, + image_size: int = 224, + patch_size: int = 16, + num_channels: int = 3, + qkv_bias: bool = True, + mask_token: bool = True, + vocab_size: int = 8192, **kwargs ): super().__init__(**kwargs) @@ -132,7 +134,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) - # get the image config dict if we are loading from FLAVAConfig + # get the image config dict if we are loading from FlavaConfig if config_dict.get("model_type") == "flava": config_dict = config_dict["image_config"] @@ -145,12 +147,13 @@ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], return cls.from_dict(config_dict, **kwargs) -class FLAVATextConfig(PretrainedConfig): +class FlavaTextConfig(PretrainedConfig): r""" - This is the configuration class to store the configuration of a [`FLAVATextModel`]. It is used to instantiate an - FLAVA model according to the specified arguments, defining the model architecture. Instantiating a configuration - with the defaults will yield a similar configuration to that of the FLAVA - [full](https://huggingface.co/aps/flava-full) architecture. + This is the configuration class to store the configuration of a [`FlavaTextModel`]. It is used to instantiate an + FLAVA model according to the specified arguments, defining the model architecture. + + Instantiating a configuration with the defaults will yield a similar configuration to that of the FLAVA + [facebook/flava-full](https://huggingface.co/facebook/flava-full) architecture. Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the documentation from [`PretrainedConfig`] for more information. @@ -159,9 +162,9 @@ class FLAVATextConfig(PretrainedConfig): Args: vocab_size (`int`, *optional*, defaults to 30522): Vocabulary size of the BERT model. Defines the number of different tokens that can be represented by the - `inputs_ids` passed when calling [`FLAVATextModel`]. + `inputs_ids` passed when calling [`FlavaTextModel`]. 
type_vocab_size (`int`, *optional*, defaults to 2): - The vocabulary size of the `token_type_ids` passed when calling [`FLAVATextModel`]. Note that even though + The vocabulary size of the `token_type_ids` passed when calling [`FlavaTextModel`]. Note that even though text encoder allows `token_type_ids`'s value as 2, for text-only pretraining and fine-tuning, only 1 is used similar to RoBERTa. max_position_embeddings (`int`, *optional*, defaults to 512): @@ -192,11 +195,11 @@ class FLAVATextConfig(PretrainedConfig): The standard deviation of the truncated_normal_initializer for initializing all weight matrices. layer_norm_eps (`float`, *optional*, defaults to 1e-12): The epsilon used by the layer normalization layers. - image_size (`int`, *optional*, defaults to `224`): + image_size (`int`, *optional*, defaults to 224): The size (resolution) of each image. - patch_size (`int`, *optional*, defaults to `16`): + patch_size (`int`, *optional*, defaults to 16): The size (resolution) of each patch. - num_channels (`int`, *optional*, defaults to `3`): + num_channels (`int`, *optional*, defaults to 3): The number of input channels. qkv_bias (`bool`, *optional*, defaults to `True`): Whether to add a bias to the queries, keys and values. @@ -204,13 +207,13 @@ class FLAVATextConfig(PretrainedConfig): Example: ```python - >>> from transformers import FLAVATextModel, FLAVATextConfig + >>> from transformers import FlavaTextModel, FlavaTextConfig - >>> # Initializing a FLAVATextModel with style configuration - >>> configuration = FLAVATextConfig() + >>> # Initializing a FlavaTextModel with style configuration + >>> configuration = FlavaTextConfig() - >>> # Initializing a FLAVATextConfig from the style configuration - >>> model = FLAVATextModel(configuration) + >>> # Initializing a FlavaTextConfig from the style configuration + >>> model = FlavaTextModel(configuration) >>> # Accessing the model configuration >>> configuration = model.config @@ -219,21 +222,21 @@ class FLAVATextConfig(PretrainedConfig): def __init__( self, - vocab_size=30522, - type_vocab_size=2, - max_position_embeddings=512, - position_embedding_type="absolute", - hidden_size=768, - num_hidden_layers=12, - num_attention_heads=12, - intermediate_size=3072, - hidden_act="gelu", - hidden_dropout_prob=0.0, - attention_probs_dropout_prob=0.0, - initializer_range=0.02, - layer_norm_eps=1e-12, - pad_token_id=0, - qkv_bias=True, + vocab_size: int = 30522, + type_vocab_size: int = 2, + max_position_embeddings: int = 512, + position_embedding_type: str = "absolute", + hidden_size: int = 768, + num_hidden_layers: int = 12, + num_attention_heads: int = 12, + intermediate_size: int = 3072, + hidden_act: str = "gelu", + hidden_dropout_prob: float = 0.0, + attention_probs_dropout_prob: float = 0.0, + initializer_range: float = 0.02, + layer_norm_eps: float = 1e-12, + pad_token_id: int = 0, + qkv_bias: bool = True, **kwargs ): super().__init__(**kwargs) @@ -259,7 +262,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) - # get the text config dict if we are loading from FLAVAConfig + # get the text config dict if we are loading from FlavaConfig if config_dict.get("model_type") == "flava": config_dict = config_dict["text_config"] @@ -272,12 +275,13 @@ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], return cls.from_dict(config_dict, **kwargs) -class FLAVAMultimodalConfig(PretrainedConfig): +class 
FlavaMultimodalConfig(PretrainedConfig): r""" - This is the configuration class to store the configuration of a [`FLAVAMultimodalModel`]. It is used to instantiate - an FLAVA model according to the specified arguments, defining the model architecture. Instantiating a configuration - with the defaults will yield a similar configuration to that of the FLAVA - [full](https://huggingface.co/aps/flava-full) architecture. + This is the configuration class to store the configuration of a [`FlavaMultimodalModel`]. It is used to instantiate + an FLAVA model according to the specified arguments, defining the model architecture. + + Instantiating a configuration with the defaults will yield a similar configuration to that of the FLAVA + [facebook/flava-full](https://huggingface.co/facebook/flava-full) architecture. Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the documentation from [`PretrainedConfig`] for more information. @@ -312,13 +316,13 @@ class FLAVAMultimodalConfig(PretrainedConfig): Example: ```python - >>> from transformers import FLAVAMultimodalModel, FLAVAMultimodalConfig + >>> from transformers import FlavaMultimodalModel, FlavaMultimodalConfig - >>> # Initializing a FLAVAMultimodalModel with style configuration - >>> configuration = FLAVAMultimodalConfig() + >>> # Initializing a FlavaMultimodalModel with style configuration + >>> configuration = FlavaMultimodalConfig() - >>> # Initializing a FLAVAMultimodalModel model from the style configuration - >>> model = FLAVAMultimodalModel(configuration) + >>> # Initializing a FlavaMultimodalModel model from the style configuration + >>> model = FlavaMultimodalModel(configuration) >>> # Accessing the model configuration >>> configuration = model.config @@ -328,17 +332,17 @@ class FLAVAMultimodalConfig(PretrainedConfig): def __init__( self, - hidden_size=768, - num_hidden_layers=6, - num_attention_heads=12, - intermediate_size=3072, - hidden_act="gelu", - hidden_dropout_prob=0.0, - attention_probs_dropout_prob=0.0, - initializer_range=0.02, - layer_norm_eps=1e-12, - qkv_bias=True, - use_cls_token=True, + hidden_size: int = 768, + num_hidden_layers: int = 6, + num_attention_heads: int = 12, + intermediate_size: int = 3072, + hidden_act: int = "gelu", + hidden_dropout_prob: int = 0.0, + attention_probs_dropout_prob: int = 0.0, + initializer_range: float = 0.02, + layer_norm_eps: float = 1e-12, + qkv_bias: bool = True, + use_cls_token: bool = True, **kwargs ): super().__init__(**kwargs) @@ -359,7 +363,7 @@ def __init__( def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig": config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) - # get the image config dict if we are loading from FLAVAConfig + # get the multimodal config dict if we are loading from FlavaConfig if config_dict.get("model_type") == "flava": config_dict = config_dict["multimodal_config"] @@ -372,14 +376,14 @@ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], return cls.from_dict(config_dict, **kwargs) -class FLAVACodebookConfig(PretrainedConfig): - model_type = "flava_codebook" +class FlavaImageCodebookConfig(PretrainedConfig): + model_type = "flava_image_codebook" r""" - [`FLAVACodebookConfig`] is the configuration class to store the configuration of a [`FLAVACodebook`]. It is used to - instantiate an FLAVA model according to the specified arguments, defining the model architecture. 
Instantiating a - configuration with the defaults will yield a similar configuration to that of the FLAVA - [codebook](https://huggingface.co/aps/flava-codebook) architecture + [`FlavaImageCodebookConfig`] is the configuration class to store the configuration of a [`FlavaImageCodebook`]. It + is used to instantiate an FLAVA model according to the specified arguments, defining the model architecture. + Instantiating a configuration with the defaults will yield a similar configuration to that of the FLAVA + [facebook/flava-image-codebook](https://huggingface.co/facebook/flava-image-codebook) architecture. Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the documentation from [`PretrainedConfig`] for more information. @@ -396,7 +400,7 @@ class FLAVACodebookConfig(PretrainedConfig): Size of hidden dim for the blocks. vocab_size (`int`, defaults to 8192): Size of the output vocabulary for the codebook. - freeze (`bool`, defaults to True): + freeze (`bool`, defaults to `True`): Whether to freeze the weights of the model. initializer_range (`float`, *optional*, defaults to 0.02): The standard deviation of the truncated_normal_initializer for initializing all weight matrices. @@ -406,13 +410,13 @@ class FLAVACodebookConfig(PretrainedConfig): Example: ```python - >>> from transformers import FLAVACodebook, FLAVACodebookConfig + >>> from transformers import FlavaImageCodebook, FlavaImageCodebookConfig - >>> # Initializing a FLAVACodebook with style configuration - >>> configuration = FLAVACodebookConfig() + >>> # Initializing a FlavaImageCodebook with style configuration + >>> configuration = FlavaImageCodebookConfig() - >>> # Initializing a FLAVACodebook model from the style configuration - >>> model = FLAVACodebook(configuration) + >>> # Initializing a FlavaImageCodebook model from the style configuration + >>> model = FlavaImageCodebook(configuration) >>> # Accessing the model configuration >>> configuration = model.config ``` @@ -420,13 +424,13 @@ class FLAVACodebookConfig(PretrainedConfig): def __init__( self, - num_groups=4, - input_channels=3, - num_blocks_per_group=2, - hidden_size=256, - vocab_size=8192, - freeze=True, - initializer_range=0.02, + num_groups: int = 4, + input_channels: int = 3, + num_blocks_per_group: int = 2, + hidden_size: int = 256, + vocab_size: int = 8192, + freeze: int = True, + initializer_range: float = 0.02, **kwargs, ): super().__init__(**kwargs) @@ -438,23 +442,41 @@ def __init__( self.freeze = freeze self.initializer_range = initializer_range + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig": + + config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) + + # get the image codebook config dict if we are loading from FlavaConfig + if config_dict.get("model_type") == "flava": + config_dict = config_dict["image_codebook_config"] + + if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type: + logger.warning( + f"You are using a model of type {config_dict['model_type']} to instantiate a model of type " + f"{cls.model_type}. This is not supported for all configurations of models and can yield errors." + ) + + return cls.from_dict(config_dict, **kwargs) + -class FLAVAConfig(PretrainedConfig): +class FlavaConfig(PretrainedConfig): r""" - [`FLAVAConfig`] is the configuration class to store the configuration of a [`FLAVAModel`]. 
It is used to - instantiate FLAVA model according to the specified arguments, defining the text model, image model and multimodal - model configs. + [`FlavaConfig`] is the configuration class to store the configuration of a [`FlavaModel`]. It is used to + instantiate FLAVA model according to the specified arguments, defining the text model, image model, image codebook + and multimodal model configs. Instantiating a configuration with the defaults will yield a similar configuration to + that of the FLAVA [facebook/flava-full](https://huggingface.co/facebook/flava-full) architecture. Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the documentation from [`PretrainedConfig`] for more information. Args: text_config_dict (`dict`, *optional*): - Dictionary of configuration options used to initialize [`FLAVATextConfig`]. + Dictionary of configuration options used to initialize [`FlavaTextConfig`]. image_config_dict (`dict`, *optional*): - Dictionary of configuration options used to initialize [`FLAVAImageConfig`]. + Dictionary of configuration options used to initialize [`FlavaImageConfig`]. multimodal_config_dict (`dict`, *optional*): - Dictionary of configuration options used to initialize [`FLAVAMultimodalConfig`]. + Dictionary of configuration options used to initialize [`FlavaMultimodalConfig`]. hidden_size (`int`, *optional*, defaults to 768): Dimensionality of the encoder layers and the pooler layer. layer_norm_eps (`float`, *optional*, defaults to 1e-12): @@ -469,9 +491,9 @@ class FLAVAConfig(PretrainedConfig): ce_ignore_index (`int`, *optional*, defaults to -100): Cross entropy index to ignore. mim_weight (`float`, *optional*, defaults to 1.0): - Weight to be assigned to MIM unimodal loss + Weight to be assigned to MIM (Masked Image Modeling) unimodal loss mlm_weight (`float`, *optional*, defaults to 1.0): - Weight to be assigned to MLM unimodal loss + Weight to be assigned to MLM (Masked Language Modeling) unimodal loss global_contrastive_weight (`float`, *optional*, defaults to 1.0): Weight to be assigned to global contrastive cross-alignment loss. itm_weight (`float`, *optional*, defaults to 1.0): @@ -480,11 +502,11 @@ class FLAVAConfig(PretrainedConfig): Weight to be assigned to MMM loss's image part. mmm_text_weight (`float`, *optional*, defaults to 1.0): Weight to be assigned to MMM loss's text part. - global_backprop_contrastive (`bool`, *optional*, defaults to True): + global_backprop_contrastive (`bool`, *optional*, defaults to `True`): Whether to use global backpropgation through all workers in contrastive loss. - skip_unmasked_multimodal_encoder (`bool`, *optional*, defaults to True): + skip_unmasked_multimodal_encoder (`bool`, *optional*, defaults to `True`): Whether to skip running unmasked multimodal encoder whose outputs are not used by FLAVA losses. 
- return_loss (`bool`, *optional*, defaults to True): + return_loss (`bool`, *optional*, defaults to `True`): Whether to return loss or not kwargs (*optional*): @@ -493,14 +515,14 @@ class FLAVAConfig(PretrainedConfig): Example: ```python - >>> from transformers import FLAVAModel, FLAVAForPreTraining, FLAVAConfig + >>> from transformers import FlavaModel, FlavaForPreTraining, FlavaConfig - >>> # Initializing a FLAVAConfig with style configuration - >>> configuration = FLAVAConfig() + >>> # Initializing a FlavaConfig with style configuration + >>> configuration = FlavaConfig() - >>> # Initializing a FLAVAModel and FLAVAForPreTraining model from the style configuration - >>> model = FLAVAModel(configuration) - >>> model_pre = FLAVAForPreTraining(configuration) + >>> # Initializing a FlavaModel and FlavaForPreTraining model from the style configuration + >>> model = FlavaModel(configuration) + >>> model_pre = FlavaForPreTraining(configuration) >>> # Accessing the model configuration >>> configuration = model.config @@ -513,49 +535,59 @@ class FLAVAConfig(PretrainedConfig): def __init__( self, - image_config_dict=None, - text_config_dict=None, - multimodal_config_dict=None, - hidden_size=768, - layer_norm_eps=1e-12, - projection_dim=768, - logit_scale_init_value=2.6592, - initializer_range=0.02, - ce_ignore_index=-100, - mim_weight=1.0, - mlm_weight=1.0, - global_contrastive_weight=1.0, - itm_weight=1.0, - mmm_image_weight=1.0, - mmm_text_weight=1.0, - global_backprop_contrastive=True, - skip_unmasked_multimodal_encoder=True, - return_loss=True, + image_config_dict: Dict[str, Any] = None, + text_config_dict: Dict[str, Any] = None, + multimodal_config_dict: Dict[str, Any] = None, + image_codebook_config_dict: Dict[str, Any] = None, + hidden_size: int = 768, + layer_norm_eps: float = 1e-12, + projection_dim: int = 768, + init_codebook: bool = True, + logit_scale_init_value: float = 2.6592, + initializer_range: float = 0.02, + ce_ignore_index: int = -100, + mim_weight: float = 1.0, + mlm_weight: float = 1.0, + global_contrastive_weight: float = 1.0, + itm_weight: float = 1.0, + mmm_image_weight: float = 1.0, + mmm_text_weight: float = 1.0, + global_backprop_contrastive: bool = True, + skip_unmasked_multimodal_encoder: bool = True, + return_loss: bool = True, **kwargs ): - super().__init__( - text_config_dict=text_config_dict, - image_config_dict=image_config_dict, - multimodal_config_dict=multimodal_config_dict, - **kwargs, - ) + super().__init__(**kwargs) if image_config_dict is None: image_config_dict = {} - logger.info("image_config_dict is None. initializing the FLAVAImageConfig with default values.") + logger.info("image_config_dict is None. initializing the FlavaImageConfig with default values.") if text_config_dict is None: text_config_dict = {} - logger.info("text_config_dict is None. Initializing the FLAVATextConfig with default values.") + logger.info("text_config_dict is None. Initializing the FlavaTextConfig with default values.") if multimodal_config_dict is None: multimodal_config_dict = {} - logger.info("multimodal_config_dict is None. initializing the FLAVAImageConfig with default values.") + logger.info("multimodal_config_dict is None. initializing the FlavaMultimodalConfig with default values.") + + if image_codebook_config_dict is None: + image_codebook_config_dict = {} + logger.info( + "image_codebook_config_dict is None. initializing the FlavaImageCodebookConfig with default values." 
+ ) + + self.image_config_dict = image_config_dict + self.text_config_dict = text_config_dict + self.multimodal_config_dict = multimodal_config_dict + self.image_codebook_config_dict = image_codebook_config_dict - self.image_config = FLAVAImageConfig(**image_config_dict) - self.text_config = FLAVATextConfig(**text_config_dict) - self.multimodal_config = FLAVAMultimodalConfig(**multimodal_config_dict) + self.image_config = FlavaImageConfig(**self.image_config_dict) + self.text_config = FlavaTextConfig(**self.text_config_dict) + self.multimodal_config = FlavaMultimodalConfig(**self.multimodal_config_dict) + self.image_codebook_config = FlavaImageCodebookConfig(**self.image_codebook_config_dict) self.projection_dim = projection_dim + self.init_codebook = init_codebook self.hidden_size = hidden_size self.layer_norm_eps = layer_norm_eps @@ -576,23 +608,25 @@ def __init__( @classmethod def from_configs( cls, - image_config: FLAVAImageConfig, - text_config: FLAVATextConfig, - multimodal_config: FLAVAMultimodalConfig, + image_config: FlavaImageConfig, + text_config: FlavaTextConfig, + multimodal_config: FlavaMultimodalConfig, + image_codebook_config: FlavaImageCodebookConfig, **kwargs ): r""" - Instantiate a [`FLAVAConfig`] (or a derived class) from flava text model configuration, flava image model - configuration and flava multimodal model configuration. + Instantiate a [`FlavaConfig`] (or a derived class) from flava text model configuration, flava image model + configuration, flava multimodal model and flava codebook model configuration. Returns: - [`FLAVAConfig`]: An instance of a configuration object + [`FlavaConfig`]: An instance of a configuration object """ return cls( image_config_dict=image_config.to_dict(), text_config_dict=text_config.to_dict(), multimodal_config_dict=multimodal_config.to_dict(), + image_codebook_config_dict=image_codebook_config.to_dict(), **kwargs, ) @@ -607,5 +641,6 @@ def to_dict(self): output["image_config"] = self.image_config.to_dict() output["text_config"] = self.text_config.to_dict() output["multimodal_config"] = self.multimodal_config.to_dict() + output["image_codebook_config"] = self.image_codebook_config.to_dict() output["model_type"] = self.__class__.model_type return output diff --git a/src/transformers/models/flava/convert_dalle_to_flava_codebook.py b/src/transformers/models/flava/convert_dalle_to_flava_codebook.py index bd99528ed9aeb..7b544125114c8 100644 --- a/src/transformers/models/flava/convert_dalle_to_flava_codebook.py +++ b/src/transformers/models/flava/convert_dalle_to_flava_codebook.py @@ -1,5 +1,5 @@ # coding=utf-8 -# Copyright 2021 Meta Platforms authors and The HuggingFace Team. All rights reserved. +# Copyright 2022 Meta Platforms authors and The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -18,7 +18,7 @@ import torch -from transformers import FLAVACodebook, FLAVACodebookConfig +from transformers import FlavaImageCodebook, FlavaImageCodebookConfig def rreplace(s, old, new, occurrence): @@ -34,7 +34,15 @@ def count_parameters(state_dict): def upgrade_state_dict(state_dict): upgrade = {} + group_keys = ["group_1", "group_2", "group_3", "group_4"] for key, value in state_dict.items(): + for group_key in group_keys: + if group_key in key: + key = key.replace(f"{group_key}.", f"{group_key}.group.") + + if "res_path" in key: + key = key.replace("res_path.", "res_path.path.") + if key.endswith(".w"): key = rreplace(key, ".w", ".weight", 1) if key.endswith(".b"): @@ -46,7 +54,7 @@ def upgrade_state_dict(state_dict): @torch.no_grad() -def convert_dalle_checkpoint(checkpoint_path, pytorch_dump_folder_path, config_path=None): +def convert_dalle_checkpoint(checkpoint_path, pytorch_dump_folder_path, config_path=None, save_checkpoint=True): """ Copy/paste/tweak model's weights to transformers design. """ @@ -63,11 +71,11 @@ def convert_dalle_checkpoint(checkpoint_path, pytorch_dump_folder_path, config_p encoder.load_state_dict(ckpt) if config_path is not None: - config = FLAVACodebookConfig.from_pretrained(config_path) + config = FlavaImageCodebookConfig.from_pretrained(config_path) else: - config = FLAVACodebookConfig() + config = FlavaImageCodebookConfig() - hf_model = FLAVACodebook(config).eval() + hf_model = FlavaImageCodebook(config).eval() state_dict = encoder.state_dict() hf_state_dict = upgrade_state_dict(state_dict) @@ -78,13 +86,16 @@ def convert_dalle_checkpoint(checkpoint_path, pytorch_dump_folder_path, config_p assert torch.allclose(hf_count, state_dict_count, atol=1e-3) - hf_model.save_pretrained(pytorch_dump_folder_path) + if save_checkpoint: + hf_model.save_pretrained(pytorch_dump_folder_path) + else: + return hf_state_dict if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.") - parser.add_argument("--checkpoint_path", default=None, type=str, help="Path to fairseq checkpoint") + parser.add_argument("--checkpoint_path", default=None, type=str, help="Path to flava checkpoint") parser.add_argument("--config_path", default=None, type=str, help="Path to hf config.json of model to convert") args = parser.parse_args() diff --git a/src/transformers/models/flava/convert_flava_original_pytorch_to_hf.py b/src/transformers/models/flava/convert_flava_original_pytorch_to_hf.py index e84c9a793441d..95ebb2bfdb236 100644 --- a/src/transformers/models/flava/convert_flava_original_pytorch_to_hf.py +++ b/src/transformers/models/flava/convert_flava_original_pytorch_to_hf.py @@ -1,5 +1,5 @@ # coding=utf-8 -# Copyright 2021 Meta Platforms authors and The HuggingFace Team. All rights reserved. +# Copyright 2022 Meta Platforms authors and The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -18,7 +18,8 @@ import torch -from transformers import FLAVAConfig, FLAVAForPreTraining +from transformers import FlavaConfig, FlavaForPreTraining +from transformers.models.flava.convert_dalle_to_flava_codebook import convert_dalle_checkpoint def count_parameters(state_dict): @@ -26,7 +27,7 @@ def count_parameters(state_dict): return sum(param.float().sum() if "encoder.embeddings" not in key else 0 for key, param in state_dict.items()) -def upgrade_state_dict(state_dict): +def upgrade_state_dict(state_dict, codebook_state_dict): upgrade = {} for key, value in state_dict.items(): @@ -51,31 +52,36 @@ def upgrade_state_dict(state_dict): upgrade[key] = value.float() + for key, value in codebook_state_dict.items(): + upgrade[f"image_codebook.{key}"] = value + return upgrade @torch.no_grad() -def convert_flava_checkpoint(checkpoint_path, pytorch_dump_folder_path, config_path=None): +def convert_flava_checkpoint(checkpoint_path, codebook_path, pytorch_dump_folder_path, config_path=None): """ Copy/paste/tweak model's weights to transformers design. """ if config_path is not None: - config = FLAVAConfig.from_pretrained(config_path) + config = FlavaConfig.from_pretrained(config_path) else: - config = FLAVAConfig() + config = FlavaConfig() + + hf_model = FlavaForPreTraining(config).eval() - hf_model = FLAVAForPreTraining(config).eval() + codebook_state_dict = convert_dalle_checkpoint(codebook_path, None, save_checkpoint=False) if os.path.exists(checkpoint_path): state_dict = torch.load(checkpoint_path, map_location="cpu") else: state_dict = torch.hub.load_state_dict_from_url(checkpoint_path, map_location="cpu") - hf_state_dict = upgrade_state_dict(state_dict) + hf_state_dict = upgrade_state_dict(state_dict, codebook_state_dict) hf_model.load_state_dict(hf_state_dict) hf_state_dict = hf_model.state_dict() hf_count = count_parameters(hf_state_dict) - state_dict_count = count_parameters(state_dict) + state_dict_count = count_parameters(state_dict) + count_parameters(codebook_state_dict) assert torch.allclose(hf_count, state_dict_count, atol=1e-3) @@ -85,8 +91,9 @@ def convert_flava_checkpoint(checkpoint_path, pytorch_dump_folder_path, config_p if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.") - parser.add_argument("--checkpoint_path", default=None, type=str, help="Path to fairseq checkpoint") + parser.add_argument("--checkpoint_path", default=None, type=str, help="Path to flava checkpoint") + parser.add_argument("--codebook_path", default=None, type=str, help="Path to flava codebook checkpoint") parser.add_argument("--config_path", default=None, type=str, help="Path to hf config.json of model to convert") args = parser.parse_args() - convert_flava_checkpoint(args.checkpoint_path, args.pytorch_dump_folder_path, args.config_path) + convert_flava_checkpoint(args.checkpoint_path, args.codebook_path, args.pytorch_dump_folder_path, args.config_path) diff --git a/src/transformers/models/flava/feature_extraction_flava.py b/src/transformers/models/flava/feature_extraction_flava.py index 0d5eaffaa1465..c3aba8c70b6ce 100644 --- a/src/transformers/models/flava/feature_extraction_flava.py +++ b/src/transformers/models/flava/feature_extraction_flava.py @@ -1,5 +1,5 @@ # coding=utf-8 -# Copyright 2021 Meta Platforms authors and The HuggingFace Team. All rights reserved. +# Copyright 2022 Meta Platforms authors and The HuggingFace Team. All rights reserved. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -17,7 +17,7 @@ import math import random from functools import lru_cache -from typing import List, Optional, Tuple, Union +from typing import Any, List, Optional, Tuple, Union import numpy as np from PIL import Image @@ -39,15 +39,15 @@ # Inspired from https://github.com/microsoft/unilm/blob/master/beit/masking_generator.py -class MaskingGenerator: +class FlavaMaskingGenerator: def __init__( self, - input_size: Union[int, Tuple[int, int]], + input_size: Union[int, Tuple[int, int]] = 14, total_mask_patches: int = 75, - mask_group_max_patches: int = None, - mask_group_min_patches: Optional[int] = 16, - mask_group_min_aspect_ratio: float = 0.3, - mask_group_max_aspect_ratio: Optional[float] = None, + mask_group_max_patches: Optional[int] = None, + mask_group_min_patches: int = 16, + mask_group_min_aspect_ratio: Optional[float] = 0.3, + mask_group_max_aspect_ratio: float = None, ): if not isinstance(input_size, tuple): input_size = (input_size,) * 2 @@ -82,17 +82,17 @@ def _mask(self, mask, max_mask_patches): for _attempt in range(10): target_area = random.uniform(self.mask_group_min_patches, max_mask_patches) aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio)) - h = int(round(math.sqrt(target_area * aspect_ratio))) - w = int(round(math.sqrt(target_area / aspect_ratio))) - if w < self.width and h < self.height: - top = random.randint(0, self.height - h) - left = random.randint(0, self.width - w) + height = int(round(math.sqrt(target_area * aspect_ratio))) + width = int(round(math.sqrt(target_area / aspect_ratio))) + if width < self.width and height < self.height: + top = random.randint(0, self.height - height) + left = random.randint(0, self.width - width) - num_masked = mask[top : top + h, left : left + w].sum() + num_masked = mask[top : top + height, left : left + width].sum() # Overlap - if 0 < h * w - num_masked <= max_mask_patches: - for i in range(top, top + h): - for j in range(left, left + w): + if 0 < height * width - num_masked <= max_mask_patches: + for i in range(top, top + height): + for j in range(left, left + width): if mask[i, j] == 0: mask[i, j] = 1 delta += 1 @@ -117,7 +117,7 @@ def __call__(self): return mask -class FLAVAFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin): +class FlavaFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin): r""" Constructs a FLAVA feature extractor. @@ -139,41 +139,86 @@ class FLAVAFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin) Desired output size when applying center-cropping. Only has an effect if `do_center_crop` is set to `True`. do_normalize (`bool`, *optional*, defaults to `True`): Whether or not to normalize the input with `image_mean` and `image_std`. - image_mean (`List[int]`, defaults to `[0.485, 0.456, 0.406]`): + image_mean (`Tuple[float, float, float]`, *optional*, defaults to `[0.485, 0.456, 0.406]`): The sequence of means for each channel, to be used when normalizing images. - image_std (`List[int]`, defaults to `[0.229, 0.224, 0.225]`): + image_std (`Tuple[float, float, float]`, *optional*, defaults to `[0.229, 0.224, 0.225]`): The sequence of standard deviations for each channel, to be used when normalizing images. + input_size_patches (`int`, *optional*, defaults to 14): + Number of patches in the image in height and width direction. 14x14 = 196 total patches. 
+ total_mask_patches (`int`, *optional*, defaults to 75): + Total number of patches that should be masked. + mask_group_min_patches (`int`, *optional*, defaults to 16): + Minimum number of patches that should be masked. + mask_group_max_patches (`int`, *optional*, defaults to None): + Maximum number of patches that should be masked. + mask_group_min_aspect_ratio (`float`, *optional*, defaults to 0.3): + Minimum aspect ratio of the mask window. + mask_group_max_aspect_ratio (`float`, *optional*, defaults to None): + Maximum aspect ratio of the mask window. + codebook_do_resize (`bool`, *optional*, defaults to `True`): + Whether to resize the input for codebook to a certain `codebook_size`. + codebook_size (`int`, *optional*, defaults to 112): + Resize the input for codebook to the given size. Only has an effect if `codebook_do_resize` is set to + `True`. + codebook_resample (`int`, *optional*, defaults to `PIL.Image.LANCZOS`): + An optional resampling filter. This can be one of `PIL.Image.NEAREST`, `PIL.Image.BOX`, + `PIL.Image.BILINEAR`, `PIL.Image.HAMMING`, `PIL.Image.BICUBIC` or `PIL.Image.LANCZOS`. Only has an effect + if `codebook_do_resize` is set to `True`. + codebook_do_center_crop (`bool`, *optional*, defaults to `True`): + Whether to crop the input for codebook at the center. If the input size is smaller than + `codebook_crop_size` along any edge, the image is padded with 0's and then center cropped. + codebook_crop_size (`int`, *optional*, defaults to 112): + Desired output size for codebook input when applying center-cropping. Only has an effect if + `codebook_do_center_crop` is set to `True`. + codebook_do_map_pixels (`bool`, *optional*, defaults to `True`): + Whether to map the codebook input pixels to `(1 - 2 * LOGIT_LAPLACE_EPS) * pixel + LOGIT_LAPLACE_EPS` + before feeding them to the codebook (see `map_pixels`). + codebook_do_normalize (`bool`, *optional*, defaults to `True`): + Whether or not to normalize the input for codebook with `codebook_image_mean` and `codebook_image_std`. + codebook_image_mean (`Tuple[float, float, float]`, *optional*, defaults to `[0, 0, 0]`): + The sequence of means for each channel, to be used when normalizing images for codebook. + codebook_image_std (`Tuple[float, float, float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`): + The sequence of standard deviations for each channel, to be used when normalizing images for codebook. 
+ """ model_input_names = ["pixel_values"] def __init__( self, - do_resize=True, - size=224, - resample=None, - do_center_crop=True, - crop_size=224, - do_normalize=True, - image_mean=None, - image_std=None, + do_resize: bool = True, + size: Union[int, Tuple[int, int]] = 224, + resample: int = Image.BICUBIC, + do_center_crop: bool = True, + crop_size: Union[int, Tuple[int, int]] = 224, + do_normalize: bool = True, + image_mean: Tuple[float, float, float] = FLAVA_IMAGE_MEAN, + image_std: Tuple[float, float, float] = FLAVA_IMAGE_STD, + # Mask related params input_size_patches: int = 14, total_mask_patches: int = 75, - mask_group_min_patches: Optional[int] = 16, - mask_group_max_patches: int = None, + mask_group_min_patches: int = 16, + mask_group_max_patches: Optional[int] = None, mask_group_min_aspect_ratio: float = 0.3, mask_group_max_aspect_ratio: Optional[float] = None, - **kwargs, + # Codebook related params + codebook_do_resize: bool = True, + codebook_size: bool = 112, + codebook_resample: int = Image.LANCZOS, + codebook_do_center_crop: bool = True, + codebook_crop_size: int = 112, + codebook_do_map_pixels: bool = True, + codebook_do_normalize: bool = True, + codebook_image_mean: Tuple[float, float, float] = FLAVA_CODEBOOK_MEAN, + codebook_image_std: Tuple[float, float, float] = FLAVA_CODEBOOK_STD, + **kwargs: Any, ): super().__init__(**kwargs) self.do_resize = do_resize self.size = size - self.resample = resample if resample is not None else Image.BICUBIC + self.resample = resample self.do_center_crop = do_center_crop self.crop_size = crop_size self.do_normalize = do_normalize - self.image_mean = image_mean if image_mean is not None else FLAVA_IMAGE_MEAN - self.image_std = image_std if image_std is not None else FLAVA_IMAGE_STD + self.image_mean = image_mean + self.image_std = image_std + self.input_size_patches = input_size_patches self.total_mask_patches = total_mask_patches self.mask_group_min_patches = mask_group_min_patches @@ -181,10 +226,20 @@ def __init__( self.mask_group_min_aspect_ratio = mask_group_min_aspect_ratio self.mask_group_max_aspect_ratio = mask_group_max_aspect_ratio + self.codebook_do_resize = codebook_do_resize + self.codebook_size = codebook_size + self.codebook_resample = codebook_resample + self.codebook_do_center_crop = codebook_do_center_crop + self.codebook_crop_size = codebook_crop_size + self.codebook_do_map_pixels = codebook_do_map_pixels + self.codebook_do_normalize = codebook_do_normalize + self.codebook_image_mean = codebook_image_mean + self.codebook_image_std = codebook_image_std + @property @lru_cache() def masking_generator(self): - return MaskingGenerator( + return FlavaMaskingGenerator( input_size=self.input_size_patches, total_mask_patches=self.total_mask_patches, mask_group_min_patches=self.mask_group_min_patches, @@ -193,14 +248,18 @@ def masking_generator(self): mask_group_max_aspect_ratio=self.mask_group_max_aspect_ratio, ) + def map_pixels(self, x): + return (1 - 2 * LOGIT_LAPLACE_EPS) * x + LOGIT_LAPLACE_EPS + def __call__( self, images: Union[ Image.Image, np.ndarray, "torch.Tensor", List[Image.Image], List[np.ndarray], List["torch.Tensor"] # noqa ], - return_masks: Optional[bool] = None, + return_image_mask: Optional[bool] = None, + return_codebook_pixels: Optional[bool] = None, return_tensors: Optional[Union[str, TensorType]] = None, - **kwargs + **kwargs: Any ) -> BatchFeature: """ Main method to prepare for the model one or several image(s). @@ -218,9 +277,12 @@ def __call__( tensor. 
In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a number of channels, H and W are image height and width. - return_masks (`bool`, *optional*, defaults to None): + return_image_mask (`bool`, *optional*, defaults to None): If True, the processor will return `bool_masked_pos` suggesting masks for image's patch version. + return_codebook_pixels (`bool`, *optional*, defaults to None): + If True, the processor will return `codebook_pixel_values` providing image pixels to be used with the + default FLAVA codebook. Used in pretraining by Masked Image Modeling (MIM) loss. return_tensors (`str` or [`~utils.TensorType`], *optional*, defaults to `'np'`): If set, will return tensors of a particular framework. Acceptable values are: @@ -249,6 +311,8 @@ def __call__( if not is_batched: images = [images] + images_for_codebook = images + # transformations (resizing + center cropping + normalization) if self.do_resize and self.size is not None and self.resample is not None: images = [self.resize(image=image, size=self.size, resample=self.resample) for image in images] @@ -259,106 +323,29 @@ def __call__( # return as BatchFeature data = {"pixel_values": images} - if return_masks: + if return_codebook_pixels: + images = images_for_codebook + if self.codebook_do_resize and self.codebook_size is not None and self.codebook_resample is not None: + images = [ + self.resize(image=image, size=self.codebook_size, resample=self.codebook_resample) + for image in images + ] + if self.codebook_do_center_crop and self.codebook_crop_size is not None: + images = [self.center_crop(image, self.codebook_crop_size) for image in images] + if self.codebook_do_normalize: + images = [ + self.normalize(image=image, mean=self.codebook_image_mean, std=self.codebook_image_std) + for image in images + ] + if self.codebook_do_map_pixels: + images = [self.map_pixels(image) for image in images] + + data["codebook_pixel_values"] = images + + if return_image_mask: masks = [self.masking_generator() for _ in images] data["bool_masked_pos"] = masks encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors) return encoded_inputs - - -class FLAVACodebookFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin): - def __init__( - self, - do_resize=True, - size=112, - resample=None, - do_center_crop=True, - crop_size=112, - do_map_pixels=True, - do_normalize=True, - image_mean=None, - image_std=None, - **kwargs, - ): - super().__init__(**kwargs) - self.do_resize = do_resize - self.size = size - self.resample = resample if resample is not None else Image.LANCZOS - self.do_center_crop = do_center_crop - self.do_map_pixels = do_map_pixels - self.crop_size = crop_size - self.do_normalize = do_normalize - self.image_mean = image_mean if image_mean is not None else FLAVA_CODEBOOK_MEAN - self.image_std = image_std if image_std is not None else FLAVA_CODEBOOK_STD - - def map_pixels(self, x): - return (1 - 2 * LOGIT_LAPLACE_EPS) * x + LOGIT_LAPLACE_EPS - - def __call__( - self, - images: Union[ - Image.Image, np.ndarray, "torch.Tensor", List[Image.Image], List[np.ndarray], List["torch.Tensor"] # noqa - ], - return_tensors: Optional[Union[str, TensorType]] = None, - **kwargs - ) -> BatchFeature: - """ - Main method to prepare for the model one or several image(s). - - - - NumPy arrays and PyTorch tensors are converted to PIL images when resizing, so the most efficient is to pass - PIL images. 
- - - - Args: - images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`): - The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch - tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a - number of channels, H and W are image height and width. - - return_tensors (`str` or [`~utils.TensorType`], *optional*, defaults to `'np'`): - If set, will return tensors of a particular framework. Acceptable values are: - - - `'tf'`: Return TensorFlow `tf.constant` objects. - - `'pt'`: Return PyTorch `torch.Tensor` objects. - - `'np'`: Return NumPy `np.ndarray` objects. - - `'jax'`: Return JAX `jnp.ndarray` objects. - - Returns: - [`BatchFeature`]: A [`BatchFeature`] with the following fields: - - - **pixel_values** -- Pixel values to be fed to a model. - """ - # Input type checking for clearer error - if isinstance(images, (list, tuple)) and len(images) != 0: - self._ensure_format_supported(images[0]) - else: - self._ensure_format_supported(images) - - is_batched = bool( - isinstance(images, (list, tuple)) - and (isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0])) - ) - - if not is_batched: - images = [images] - - # transformations (resizing + center cropping + normalization) - if self.do_resize and self.size is not None and self.resample is not None: - images = [self.resize(image=image, size=self.size, resample=self.resample) for image in images] - if self.do_center_crop and self.crop_size is not None: - images = [self.center_crop(image, self.crop_size) for image in images] - if self.do_normalize: - images = [self.normalize(image=image, mean=self.image_mean, std=self.image_std) for image in images] - if self.do_map_pixels: - images = [self.map_pixels(image) for image in images] - data = {"pixel_values": images} - - # return as BatchFeature - encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors) - - return encoded_inputs diff --git a/src/transformers/models/flava/modeling_flava.py b/src/transformers/models/flava/modeling_flava.py index 6d9e570e5d82d..0a00442ee42ca 100644 --- a/src/transformers/models/flava/modeling_flava.py +++ b/src/transformers/models/flava/modeling_flava.py @@ -1,5 +1,5 @@ # coding=utf-8 -# Copyright 2021 Meta Platforms authors and The HuggingFace Team. All rights reserved. +# Copyright 2022 Meta Platforms authors and The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
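Before moving on to the modeling code, here is a short usage sketch of the reworked feature extractor, assuming `FlavaFeatureExtractor` is importable from the top-level `transformers` package like the other feature extractors; the random image is only a stand-in for a real photo:

```python
import numpy as np
from PIL import Image
from transformers import FlavaFeatureExtractor

# Defaults: 224x224 model inputs, 112x112 LANCZOS-resampled codebook inputs, 14x14 patch masking.
feature_extractor = FlavaFeatureExtractor()
image = Image.fromarray(np.random.randint(0, 256, (256, 256, 3), dtype=np.uint8))

inputs = feature_extractor(
    image,
    return_image_mask=True,       # adds `bool_masked_pos` produced by FlavaMaskingGenerator
    return_codebook_pixels=True,  # adds `codebook_pixel_values` for the image codebook / MIM targets
    return_tensors="pt",
)

print(list(inputs.keys()))                    # ['pixel_values', 'codebook_pixel_values', 'bool_masked_pos']
print(inputs["pixel_values"].shape)           # torch.Size([1, 3, 224, 224])
print(inputs["codebook_pixel_values"].shape)  # torch.Size([1, 3, 112, 112])
```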
@@ -16,6 +16,7 @@ import collections import math +import random from collections import OrderedDict from dataclasses import dataclass from typing import Any, Dict, List, Optional, Set, Tuple, Union @@ -25,6 +26,8 @@ from packaging import version from torch import nn +from transformers import PreTrainedTokenizerBase +from transformers.tokenization_utils_base import BatchEncoding from transformers.utils.doc import add_code_sample_docstrings from ...activations import ACT2FN @@ -38,64 +41,60 @@ replace_return_docstrings, ) from .configuration_flava import ( - FLAVACodebookConfig, - FLAVAConfig, - FLAVAImageConfig, - FLAVAMultimodalConfig, - FLAVATextConfig, + FlavaConfig, + FlavaImageCodebookConfig, + FlavaImageConfig, + FlavaMultimodalConfig, + FlavaTextConfig, ) logger = logging.get_logger(__name__) -_CHECKPOINT_FOR_DOC = "aps/flava-full" +_CHECKPOINT_FOR_DOC = "facebook/flava-full" # Codebook docstring -_CHECKPOINT_FOR_CODEBOOK_DOC = "aps/flava-codebook" -_FEAT_EXTRACTOR_FOR_DOC = "FLAVAFeatureExtractor" -_CONFIG_CLASS_FOR_IMAGE_MODEL_DOC = "FLAVAImageConfig" -_CONFIG_CLASS_FOR_TEXT_MODEL_DOC = "FLAVATextConfig" -_CONFIG_CLASS_FOR_MULTIMODAL_MODEL_DOC = "FLAVAMultimodalConfig" +_CHECKPOINT_FOR_CODEBOOK_DOC = "facebook/flava-image-codebook" +_FEAT_EXTRACTOR_FOR_DOC = "FlavaFeatureExtractor" +_CONFIG_CLASS_FOR_IMAGE_MODEL_DOC = "FlavaImageConfig" +_CONFIG_CLASS_FOR_TEXT_MODEL_DOC = "FlavaTextConfig" +_CONFIG_CLASS_FOR_MULTIMODAL_MODEL_DOC = "FlavaMultimodalConfig" _TOKENIZER_FOR_DOC = "BertTokenizer" _EXPECTED_IMAGE_OUTPUT_SHAPE = [1, 197, 768] FLAVA_PRETRAINED_MODEL_ARCHIVE_LIST = [ - "aps/flava-full", + "facebook/flava-full", # See all flava models at https://huggingface.co/models?filter=flava ] -FLAVA_CODEBOOK_PRETRAINED_MODEL_ARCHIVE_LIST = ["aps/flava-codebook"] +FLAVA_CODEBOOK_PRETRAINED_MODEL_ARCHIVE_LIST = ["facebook/flava-image-codebook"] LOGIT_SCALE_CLAMP_MIN = 0 LOGIT_SCALE_CLAMP_MAX = 4.6052 -FLAVAPossibleConfigs = Union[FLAVATextConfig, FLAVAImageConfig, FLAVAMultimodalConfig] - - -def _build_codebook_conv2d(in_size: int, out_size: int, kernel_size: int): - return nn.Conv2d(in_size, out_size, kernel_size=kernel_size, padding=(kernel_size - 1) // 2) +FlavaPossibleConfigs = Union[FlavaTextConfig, FlavaImageConfig, FlavaMultimodalConfig] @dataclass -class FLAVAModelOutput(ModelOutput): +class FlavaModelOutput(ModelOutput): """ - Output from FLAVAModel containing embeddings and outputs from individual encoders. + Output from FlavaModel containing embeddings and outputs from individual encoders. Note that `image_embeddings` and `text_embeddigns` returned are similar to pooled output returned from a transformer. If you want embeddings for contrastive loss or retrieval use a FLAVA model's `image_projection` and `text_projection` layers on `image_embeddings` and `text_embeddings` respectively. Args: - image_embeddings(`torch.FloatTensor` of shape `(batch_size, output_dim`), *optional*, returned when `pixel_values` are present): - The image embeddings which are basically the pooled output of [`FLAVAImageModel`]. - image_output(`BaseModelOutputWithPooling`, *optional*, returned when `pixel_values` are present): - The output of the [`FLAVAImageModel`]. - text_embeddings(`torch.FloatTensor` of shape `(batch_size, output_dim`), *optional*, returned when `input_ids` are present): - The text embeddings which are basically the pooled output of [`FLAVATextModel`]. 
- text_output(`BaseModelOutputWithPooling`, *optional*, returned when `input_ids` are present): - The output of the [`FLAVATextModel`]. - multimodal_embeddings(`torch.FloatTensor` of shape `(batch_size, output_dim`), *optional*, returned when `input_ids` and `pixel_values` are present and `skip_multimodal_encoder` is `None` or `False`): - The multimodal embeddings which are basically the pooled output of [`FLAVATextModel`]. - multimodal_output(`BaseModelOutputWithPooling`, returned when `input_ids` and `pixel_values` are present and `skip_multimodal_encoder` is `None` or `False`): - The output of the [`FLAVAMultimodalModel`]. + image_embeddings (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when `pixel_values` are present): + The image embeddings which are basically the pooled output of [`FlavaImageModel`]. + image_output (`BaseModelOutputWithPooling`, *optional*, returned when `pixel_values` are present): + The output of the [`FlavaImageModel`]. + text_embeddings (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when `input_ids` are present): + The text embeddings which are basically the pooled output of [`FlavaTextModel`]. + text_output (`BaseModelOutputWithPooling`, *optional*, returned when `input_ids` are present): + The output of the [`FlavaTextModel`]. + multimodal_embeddings (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when `input_ids` and `pixel_values` are present and `skip_multimodal_encoder` is `None` or `False`): + The multimodal embeddings which are basically the pooled output of [`FlavaTextModel`]. + multimodal_output (`BaseModelOutputWithPooling`, returned when `input_ids` and `pixel_values` are present and `skip_multimodal_encoder` is `None` or `False`): + The output of the [`FlavaMultimodalModel`]. """ image_embeddings: Optional[torch.FloatTensor] = None @@ -113,23 +112,23 @@ def to_tuple(self) -> Tuple[Any]: @dataclass -class FLAVALosses(ModelOutput): +class FlavaLosses(ModelOutput): """Class representing pretraining losses from FLAVA model Args: - mim(`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `mim_labels` and `pixel_values` are present, `input_ids_masked` is absent and `mim_weight` > 0.: + mim (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `mim_labels` and `pixel_values` are present, `input_ids_masked` is absent and `mim_weight` > 0.: Masked Image Modeling loss as used in BeIT calculated only for unimodal image data. - mlm(`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `mlm_labels` and `input_ids_masked` are present, `pixel_values` is absent and `mlm_weight` > 0.: + mlm (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `mlm_labels` and `input_ids_masked` are present, `pixel_values` is absent and `mlm_weight` > 0.: Masked Language Modeling loss as used in BERT calculated only for unimodal text data. - itm(`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `itm_labels`, `input_ids_masked`, `pixel_values` are present and `itm_weight` > 0.: + itm (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `itm_labels`, `input_ids_masked`, `pixel_values` are present and `itm_weight` > 0.: Image Text Matching (ITM) loss calculated for paired image-text data. Note that ITM loss is calculated on masked pairs in FLAVA. 
- global_contrastive(`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `input_ids` and `pixel_values` are present and `global_contrastive_weight` > 0.: + global_contrastive (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `input_ids` and `pixel_values` are present and `global_contrastive_weight` > 0.: Contrastive loss for image-text similarity similar to CLIP but calculated globally for paired image-text data. This is calculated on unmasked images and texts. - mmm_image(`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `mim_labels`, `pixel_values` and `input_ids_masked` are present and `mmm_image_weight` > 0.: + mmm_image (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `mim_labels`, `pixel_values` and `input_ids_masked` are present and `mmm_image_weight` > 0.: Masked Multimodal Modeling loss's image component calculated on paired image-text data. - mmm_text(`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `mlm_labels`, `pixel_values` and `input_ids_masked` are present and `mmm_text_weight` > 0.: + mmm_text (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `mlm_labels`, `pixel_values` and `input_ids_masked` are present and `mmm_text_weight` > 0.: Masked Multimodal Modeling loss's text component calculated on paired image-text data. """ @@ -150,9 +149,9 @@ def all_none(self) -> bool: @dataclass -class FLAVAForPreTrainingOutput(ModelOutput): +class FlavaForPreTrainingOutput(ModelOutput): """ - Output from FLAVAForPreTraining containing embeddings, and outputs from individual encoders. + Output from FlavaForPreTraining containing embeddings, and outputs from individual encoders. Note that `image_embeddings` and `text_embeddings` returned are similar to pooled output returned from a transformer. If you want embeddings for contrastive loss or retrieval use a FLAVA model's `image_projection` and @@ -161,73 +160,62 @@ class FLAVAForPreTrainingOutput(ModelOutput): Args: loss (`torch.FloatTensor`, *optional*, returned when `return_loss` is True): Total loss calculated for this model. - loss_info (`FLAVALosses`): - Detailed info for FLAVA Pretraining losses. Check `FLAVALosses` class description for the information on + loss_info (`FlavaLosses`): + Detailed info for FLAVA Pretraining losses. Check `FlavaLosses` class description for the information on the keys. - image_embeddings(`torch.FloatTensor` of shape `(batch_size, output_dim`), *optional*, returned when `pixel_values` are present): - The image embeddings which are basically the pooled output of [`FLAVAImageModel`]. - image_output(`BaseModelOutputWithPooling`, *optional*, returned when `pixel_values` are present): - The output of the [`FLAVAImageModel`]. - text_embeddings(`torch.FloatTensor` of shape `(batch_size, output_dim`), *optional*, returned when `input_ids` are present): - The text embeddings which are basically the pooled output of [`FLAVATextModel`]. - text_output(`BaseModelOutputWithPooling`, *optional*, returned when `input_ids` are present): - The output of the [`FLAVATextModel`]. - multimodal_embeddings(`torch.FloatTensor` of shape `(batch_size, output_dim`), *optional*, returned when `input_ids` and `pixel_values` are present and `skip_unmasked_multimodal_encoder` is `None` or `False`): - The multimodal embeddings which are basically the pooled output of [`FLAVATextModel`]. 
- multimodal_output(`BaseModelOutputWithPooling`, returned when `input_ids` and `pixel_values` are present and `skip_unmasked_multimodal_encoder` is `None` or `False`): - The output of the [`FLAVAMultimodalModel`]. - - image_masked_embeddings(`torch.FloatTensor` of shape `(batch_size, output_dim`), *optional*, returned when `pixel_values` are present): - The image embeddings which are basically the pooled output of [`FLAVAImageModel`]. Uses `bool_masked_pos` + image_embeddings (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when `pixel_values` are present): + The image embeddings which are basically the pooled output of [`FlavaImageModel`]. + image_output (`BaseModelOutputWithPooling`, *optional*, returned when `pixel_values` are present): + The output of the [`FlavaImageModel`]. + text_embeddings (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when `input_ids` are present): + The text embeddings which are basically the pooled output of [`FlavaTextModel`]. + text_output (`BaseModelOutputWithPooling`, *optional*, returned when `input_ids` are present): + The output of the [`FlavaTextModel`]. + multimodal_embeddings (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when `input_ids` and `pixel_values` are present and `skip_unmasked_multimodal_encoder` is `None` or `False`): + The multimodal embeddings which are basically the pooled output of [`FlavaTextModel`]. + multimodal_output (`BaseModelOutputWithPooling`, returned when `input_ids` and `pixel_values` are present and `skip_unmasked_multimodal_encoder` is `None` or `False`): + The output of the [`FlavaMultimodalModel`]. + + image_masked_embeddings (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when `pixel_values` are present): + The image embeddings which are basically the pooled output of [`FlavaImageModel`]. Uses `bool_masked_pos` to create masked images. - image_masked_output(`BaseModelOutputWithPooling`, *optional*, returned when `pixel_values` are present): - The output of the [`FLAVAImageModel`]. Uses `bool_masked_pos` to create masked images. - text_masked_embeddings(`torch.FloatTensor` of shape `(batch_size, output_dim`), *optional*, returned when `input_ids_masked` are present): - The text embeddings which are basically the pooled output of [`FLAVATextModel`]. - text_masked_output(`BaseModelOutputWithPooling`, *optional*, returned when `input_ids_masked` are present): - The output of the [`FLAVATextModel`]. - multimodal_masked_embeddings(`torch.FloatTensor` of shape `(batch_size, output_dim`), *optional*, returned when `input_ids` and `pixel_values` are present): - The multimodal embeddings which are basically the pooled output of [`FLAVATextModel`]. - multimodal_masked_output(`BaseModelOutputWithPooling`, returned when `input_ids_masked` and `pixel_values` are present): - The output of the [`FLAVAMultimodalModel`]. - - - mim_logits: - (`torch.FloatTensor` of shape `(batch_size, num_image_patches, image_vocab_size)` or of shape - `(total_masked_patches, image_vocab_size)` , *optional*, returned when `pixel_values` are present and - `input_ids_masked` are not): The logits for MIM unimodal loss. Uses `book_masked_pos` to get masked - patches. The flattened output is returned when `bool_masked_pos` has some of the patches masked. 
- mlm_logits: - (`torch.FloatTensor` of shape `(batch_size, text_seq_length, text_vocab_size)` or of shape - `(total_masked_seq_length, text_vocab_size)`, *optional*, returned when `input_ids_masked` are present and - `pixel_values` are not): The logits for MLM unimodal loss. The flattened output is returned when - `input_ids_masked` has some of the tokens masked. - itm_logits: - (`torch.FloatTensor` of shape `(batch_size, 2)`, *optional*, returned when `input_ids_masked` and - `pixel_values` are present): The logits for ITM loss. Note that ITM loss is calculated on masked pairs in - FLAVA. - mmm_image_logits: - (`torch.FloatTensor` of shape `(batch_size, num_image_patches, image_vocab_size)` or of shape - `(total_masked_patches, image_vocab_size)`, *optional*, returned when `pixel_values` and `input_ids_masked` - are present): The logits for MMM image multimodal loss. Uses `book_masked_pos` to get masked patches. The - flattened output is returned when `bool_masked_pos` has some of the patches masked. - mmm_text_logits: - (`torch.FloatTensor` of shape `(batch_size, text_seq_length, text_vocab_size)` or of shape - `(`(total_masked_seq_length, text_vocab_size)`), *optional*, returned when `pixel_values` and - `input_ids_masked` are present): The logits for MMM text multimodal loss. The flattened output is returned - when `input_ids_masked` has some of the tokens masked. - contrastive_logits_per_image:(`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`): + image_masked_output (`BaseModelOutputWithPooling`, *optional*, returned when `pixel_values` are present): + The output of the [`FlavaImageModel`]. Uses `bool_masked_pos` to create masked images. + text_masked_embeddings (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when `input_ids_masked` are present): + The text embeddings which are basically the pooled output of [`FlavaTextModel`]. + text_masked_output (`BaseModelOutputWithPooling`, *optional*, returned when `input_ids_masked` are present): + The output of the [`FlavaTextModel`]. + multimodal_masked_embeddings (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when `input_ids` and `pixel_values` are present): + The multimodal embeddings which are basically the pooled output of [`FlavaTextModel`]. + multimodal_masked_output (`BaseModelOutputWithPooling`, returned when `input_ids_masked` and `pixel_values` are present): + The output of the [`FlavaMultimodalModel`]. + + mim_logits (`torch.FloatTensor` of shape `(batch_size, num_image_patches, image_vocab_size)` or of shape `(total_masked_patches, image_vocab_size)` , *optional*, returned when `pixel_values` are present and `input_ids_masked` are not): + The logits for MIM unimodal loss. Uses `book_masked_pos` to get masked patches. The flattened output is + returned when `bool_masked_pos` has some of the patches masked. + mlm_logits (`torch.FloatTensor` of shape `(batch_size, text_seq_length, text_vocab_size)` or of shape `(total_masked_seq_length, text_vocab_size)`, *optional*, returned when `input_ids_masked` are present and `pixel_values` are not): + The logits for MLM unimodal loss. The flattened output is returned when `input_ids_masked` has some of + the tokens masked. + itm_logits (`torch.FloatTensor` of shape `(batch_size, 2)`, *optional*, returned when `input_ids_masked` and `pixel_values` are present): + The logits for ITM loss. Note that ITM loss is calculated on masked pairs in FLAVA. 
+ mmm_image_logits (`torch.FloatTensor` of shape `(batch_size, num_image_patches, image_vocab_size)` or of shape `(total_masked_patches, image_vocab_size)`, *optional*, returned when `pixel_values` and `input_ids_masked` are present): + The logits for MMM image multimodal loss. Uses `bool_masked_pos` to get masked patches. The flattened + output is returned when `bool_masked_pos` has some of the patches masked. + mmm_text_logits (`torch.FloatTensor` of shape `(batch_size, text_seq_length, text_vocab_size)` or of shape `(total_masked_seq_length, text_vocab_size)`, *optional*, returned when `pixel_values` and `input_ids_masked` are present): + The logits for MMM text multimodal loss. The flattened output is returned when `input_ids_masked` has + some of the tokens masked. + contrastive_logits_per_image (`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`): The scaled dot product scores between `image_embeddings` and `text_embeddings` but passed through FLAVA's `image_projection` and `text_projection` layers respectively. This represents the image-text similarity scores. This is calculated on unmasked images and texts. - contrastive_logits_per_text:(`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`): + contrastive_logits_per_text (`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`): The scaled dot product scores between `text_embeddings` and `image_embeddings` but passed through FLAVA's `text_projection` and `image_projection` layers respectively. This is calculated on unmasked images and texts. """ loss: Optional[torch.FloatTensor] = None - loss_info: FLAVALosses = None + loss_info: FlavaLosses = None image_embeddings: Optional[torch.FloatTensor] = None image_output: Optional[BaseModelOutputWithPooling] = None text_embeddings: Optional[torch.FloatTensor] = None @@ -260,26 +248,194 @@ def to_tuple(self) -> Tuple[Any]: return tuple(self[k] if k not in transformer_outputs else getattr(self, k).to_tuple() for k in self.keys()) -# Inspired by -# https://github.com/rwightman/pytorch-image-models/blob/b9bd960a032c75ca6b808ddeed76bee5f3ed4972/timm/models/layers/helpers.py -# From PyTorch internals -def to_2tuple(x): - if isinstance(x, collections.abc.Iterable): - return x - return (x, x) +def add_masked_text( + examples: BatchEncoding, + tokenizer: PreTrainedTokenizerBase, + mlm_probability: float = 0.15, + ce_ignore_index: int = -100, +) -> Dict[str, Any]: + """ + Adds `input_ids_masked` and `mlm_labels` keys to the given examples object based on the mask generation logic in + the BERT paper. This function is to be used with FLAVA for unimodal Masked Language Modeling (MLM) pretraining loss + calculation. For multimodal MLM, please refer to `add_whole_word_masked_text` instead. + For the function to properly work, please set argument `return_special_tokens_mask=True` when calling the + [`FlavaProcessor`] for the example processing. -# Based on timm implementation, which can be found here: -# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/image_transformer.py + Args: + examples ([`BatchEncoding`]): Examples object returned from [`FlavaProcessor`]. + tokenizer (PreTrainedTokenizerBase): + Tokenizer used to tokenize the text. Usually [`BertTokenizer`] for FLAVA. Can be accessed from + [`FlavaProcessor`] object through `tokenizer` attribute. + mlm_probability (float, *optional*, defaults to 0.15): Probability of MLM masking. 
+ ce_ignore_index (int, *optional*, defaults to -100.): Index to ignore when calculating cross entropy loss. + """ + input_ids = examples["input_ids"] + special_tokens_mask = examples.pop("special_tokens_mask", None) + mlm_labels = input_ids.clone() + input_ids_masked = input_ids.clone() + + if mlm_probability > 0: + # We sample a few tokens in each sequence for MLM training (with probability `mlm_probability`) + probability_matrix = torch.full(mlm_labels.shape, mlm_probability, device=input_ids.device) + if special_tokens_mask is None: + special_tokens_mask = [ + tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in mlm_labels.tolist() + ] + special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.bool) + else: + special_tokens_mask = special_tokens_mask.bool() + + probability_matrix.masked_fill_(special_tokens_mask, value=0.0) + masked_indices = torch.bernoulli(probability_matrix).bool() + mlm_labels[~masked_indices] = ce_ignore_index # We only compute loss on masked tokens + + # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK]) + indices_replaced = torch.bernoulli(torch.full(mlm_labels.shape, 0.8)).bool() & masked_indices + input_ids_masked[indices_replaced] = tokenizer.convert_tokens_to_ids(tokenizer.mask_token) + + # 10% of the time, we replace masked input tokens with random word + indices_random = torch.bernoulli(torch.full(mlm_labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced + random_words = torch.randint(len(tokenizer), mlm_labels.shape, dtype=torch.long) + input_ids_masked[indices_random] = random_words[indices_random] + # The rest of the time (10% of the time) we keep the masked input tokens unchanged + else: + if tokenizer.pad_token_id is not None: + mlm_labels[mlm_labels == tokenizer.pad_token_id] = ce_ignore_index + + examples["mlm_labels"] = mlm_labels + examples["input_ids_masked"] = input_ids_masked + + return examples + + +def add_whole_word_masked_text( + examples: BatchEncoding, + tokenizer: PreTrainedTokenizerBase, + mlm_probability: float = 0.15, + ce_ignore_index: int = -100, + max_predictions: int = 512, +) -> Dict[str, Any]: + """ + Adds `input_ids_masked` and `mlm_labels` keys to the given examples object based on the mask generation logic in + BERT paper. This function is to be used with FLAVA for multimodal Masked Language Modeling (MLM) pretraining loss + calculation which falls under Multimodal Masked Modeling loss. For unimodal MLM, please refer to `add_masked_text` + instead. + For the function to properly work, please set argument `return_special_tokens_mask=True` when calling the + [`FlavaProcessor`] for the example processing. -class FLAVAImageEmbeddings(nn.Module): + Args: + examples ([`BatchEncoding`]): Examples object returned from [`FlavaProcessor`]. + tokenizer (PreTrainedTokenizerBase): + tokenizer used to tokenizer the text. Usually [`BertTokenizer`] for FLAVA. Can be accessed from + [`FlavaProcessor`] object through `tokenizer` attribute. + mlm_probability (float, *optional*, defaults to 0.15): Probability of MLM masking. + ce_ignore_index (int, *optional*, defaults to -100.): Index to ignore when calculating cross entropy loss. + max_predictions (int, *optional*, defaults to 512): Maximum masked predictions that can happen. """ - Construct the CLS token, position and patch embeddings. Optionally, also the mask token. 
+ input_ids = examples["input_ids"].clone() + if mlm_probability > 0: + is_batched = True + if input_ids.dim() != 2: + is_batched = False + input_ids = input_ids.unsqueeze(0) + + mask_labels = [] + for current_input_ids in input_ids: + ref_tokens = [] + for idx in current_input_ids.tolist(): + ref_tokens.append(tokenizer._convert_id_to_token(idx)) + cand_indexes = [] + for (i, token) in enumerate(ref_tokens): + if token == "[CLS]" or token == "[SEP]": + continue + + if len(cand_indexes) >= 1 and token.startswith("##"): + cand_indexes[-1].append(i) + else: + cand_indexes.append([i]) + random.shuffle(cand_indexes) + num_to_predict = min(max_predictions, max(1, int(round(len(ref_tokens) * mlm_probability)))) + masked_lms = [] + covered_indexes = set() + for index_set in cand_indexes: + if len(masked_lms) >= num_to_predict: + break + # If adding a whole-word mask would exceed the maximum number of + # predictions, then just skip this candidate. + if len(masked_lms) + len(index_set) > num_to_predict: + continue + is_any_index_covered = False + for index in index_set: + if index in covered_indexes: + is_any_index_covered = True + break + if is_any_index_covered: + continue + for index in index_set: + covered_indexes.add(index) + masked_lms.append(index) + + if len(covered_indexes) != len(masked_lms): + raise ValueError("Length of covered_indexes is not equal to length of masked_lms.") + current_mask_labels = [1 if i in covered_indexes else 0 for i in range(len(ref_tokens))] + mask_labels.append(torch.tensor(current_mask_labels, device=input_ids.device)) + + mask_labels = torch.stack(mask_labels) + labels = input_ids.clone() + probability_matrix = mask_labels + special_tokens_mask = examples.pop("special_tokens_mask", None) + + if special_tokens_mask is None: + special_tokens_mask = [ + tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist() + ] + special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.bool) + else: + special_tokens_mask = special_tokens_mask.bool() + + probability_matrix.masked_fill_(torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0) + if tokenizer._pad_token is not None: + padding_mask = labels.eq(tokenizer.pad_token_id) + probability_matrix.masked_fill_(padding_mask, value=0.0) + + masked_indices = probability_matrix.bool() + labels[~masked_indices] = ce_ignore_index # We only compute loss on masked tokens + + # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK]) + indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices + input_ids[indices_replaced] = tokenizer.mask_token_id + + # 10% of the time, we replace masked input tokens with random word + indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced + random_words = torch.randint(len(tokenizer), labels.shape, dtype=torch.long) + input_ids[indices_random] = random_words[indices_random] + + # The rest of the time (10% of the time) we keep the masked input tokens unchanged + if not is_batched: + input_ids = input_ids[0] + labels = labels[0] + else: + labels = examples["input_ids"].clone() + input_ids = examples["input_ids"].clone() + if tokenizer.pad_token_id is not None: + labels[labels == tokenizer.pad_token_id] = ce_ignore_index + + examples["mlm_labels"] = labels + examples["input_ids_masked"] = input_ids + return examples + + +# Based on timm implementation, which can be found here: +# 
https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/image_transformer.py +class FlavaImageEmbeddings(nn.Module): + """ + Construct the CLS token, position and patch embeddings. Optionally, also the mask token. """ - def __init__(self, config: FLAVAImageConfig, use_mask_token: bool = False) -> None: + def __init__(self, config: FlavaImageConfig, use_mask_token: bool = False) -> None: super().__init__() use_mask_token = use_mask_token or config.mask_token @@ -306,24 +462,28 @@ def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: """ npatch = embeddings.shape[1] - 1 - N = self.position_embeddings.shape[1] - 1 - if npatch == N and height == width: + num_pos = self.position_embeddings.shape[1] - 1 + if npatch == num_pos and height == width: return self.position_embeddings class_pos_embed = self.position_embeddings[:, 0] patch_pos_embed = self.position_embeddings[:, 1:] dim = embeddings.shape[-1] - h0 = height // self.config.patch_size - w0 = width // self.config.patch_size + num_h_patches = height // self.config.patch_size + num_w_patches = width // self.config.patch_size # we add a small number to avoid floating point error in the interpolation # see discussion at https://github.com/facebookresearch/dino/issues/8 - h0, w0 = h0 + 0.1, w0 + 0.1 + num_h_patches, num_w_patches = num_h_patches + 0.1, num_w_patches + 0.1 patch_pos_embed = nn.functional.interpolate( - patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2), - scale_factor=(h0 / math.sqrt(N), w0 / math.sqrt(N)), + patch_pos_embed.reshape(1, int(math.sqrt(num_pos)), int(math.sqrt(num_pos)), dim).permute(0, 3, 1, 2), + scale_factor=(num_h_patches / math.sqrt(num_pos), num_w_patches / math.sqrt(num_pos)), mode="bicubic", align_corners=False, ) - assert int(h0) == patch_pos_embed.shape[-2] and int(w0) == patch_pos_embed.shape[-1] + if int(num_h_patches) != patch_pos_embed.shape[-2] or int(num_w_patches) != patch_pos_embed.shape[-1]: + raise ValueError( + f"Number of patches for images ({int(num_h_patches), int(num_w_patches)}) don't match the " + f"shape of position embedding ({patch_pos_embed.shape[-2], patch_pos_embed.shape[-1]})" + ) patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) @@ -366,7 +526,6 @@ def forward( class PatchEmbeddings(nn.Module): """ Image to Patch Embedding. 
- """ def __init__( @@ -377,8 +536,10 @@ def __init__( embed_dim: int = 768, ): super().__init__() - image_size = to_2tuple(image_size) - patch_size = to_2tuple(patch_size) + if not isinstance(image_size, collections.abc.Iterable): + image_size = (image_size, image_size) + if not isinstance(patch_size, collections.abc.Iterable): + patch_size = (patch_size, patch_size) num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) self.image_size = image_size self.patch_size = patch_size @@ -397,7 +558,7 @@ def forward(self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = F return x -class FLAVATextEmbeddings(nn.Module): +class FlavaTextEmbeddings(nn.Module): """Construct the embeddings from word, position and token_type embeddings.""" def __init__(self, config): @@ -455,8 +616,8 @@ def forward( return embeddings -class FLAVASelfAttention(nn.Module): - def __init__(self, config: FLAVAPossibleConfigs) -> None: +class FlavaSelfAttention(nn.Module): + def __init__(self, config: FlavaPossibleConfigs) -> None: super().__init__() if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): raise ValueError( @@ -524,13 +685,13 @@ def forward( return outputs -class FLAVASelfOutput(nn.Module): +class FlavaSelfOutput(nn.Module): """ - The residual connection is defined in ViTLayer instead of here (as is the case with other models), due to the - layernorm applied before each block. + The residual connection is defined in FlavaLayer (same as ViTLayer) instead of here (as is the case with other + models), due to the layernorm applied before each block. """ - def __init__(self, config: FLAVAPossibleConfigs) -> None: + def __init__(self, config: FlavaPossibleConfigs) -> None: super().__init__() self.dense = nn.Linear(config.hidden_size, config.hidden_size) self.dropout = nn.Dropout(config.hidden_dropout_prob) @@ -543,11 +704,11 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to return hidden_states -class FLAVAAttention(nn.Module): - def __init__(self, config: FLAVAPossibleConfigs) -> None: +class FlavaAttention(nn.Module): + def __init__(self, config: FlavaPossibleConfigs) -> None: super().__init__() - self.attention = FLAVASelfAttention(config) - self.output = FLAVASelfOutput(config) + self.attention = FlavaSelfAttention(config) + self.output = FlavaSelfOutput(config) self.pruned_heads = set() def prune_heads(self, heads: Set[int]) -> None: @@ -585,8 +746,8 @@ def forward( return outputs -class FLAVAIntermediate(nn.Module): - def __init__(self, config: FLAVAPossibleConfigs) -> None: +class FlavaIntermediate(nn.Module): + def __init__(self, config: FlavaPossibleConfigs) -> None: super().__init__() self.dense = nn.Linear(config.hidden_size, config.intermediate_size) if isinstance(config.hidden_act, str): @@ -602,8 +763,8 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return hidden_states -class FLAVAOutput(nn.Module): - def __init__(self, config: FLAVAPossibleConfigs) -> None: +class FlavaOutput(nn.Module): + def __init__(self, config: FlavaPossibleConfigs) -> None: super().__init__() self.dense = nn.Linear(config.intermediate_size, config.hidden_size) self.dropout = nn.Dropout(config.hidden_dropout_prob) @@ -617,16 +778,16 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to return hidden_states -class FLAVALayer(nn.Module): +class FlavaLayer(nn.Module): """This corresponds to the Block class in the timm implementation.""" - def __init__(self, config: 
FLAVAPossibleConfigs) -> None: + def __init__(self, config: FlavaPossibleConfigs) -> None: super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward self.seq_len_dim = 1 - self.attention = FLAVAAttention(config) - self.intermediate = FLAVAIntermediate(config) - self.output = FLAVAOutput(config) + self.attention = FlavaAttention(config) + self.intermediate = FlavaIntermediate(config) + self.output = FlavaOutput(config) # TODO: Check fp32 layer norm possiblity self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) @@ -663,11 +824,11 @@ def forward( return outputs -class FLAVAEncoder(nn.Module): - def __init__(self, config: FLAVAConfig) -> None: +class FlavaEncoder(nn.Module): + def __init__(self, config: FlavaConfig) -> None: super().__init__() self.config = config - self.layer = nn.ModuleList([FLAVALayer(config) for _ in range(config.num_hidden_layers)]) + self.layer = nn.ModuleList([FlavaLayer(config) for _ in range(config.num_hidden_layers)]) self.gradient_checkpointing = False def forward( @@ -716,14 +877,12 @@ def custom_forward(*inputs): if not return_dict: return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) return BaseModelOutput( - last_hidden_state=hidden_states, - hidden_states=all_hidden_states, - attentions=all_self_attentions, + last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_self_attentions ) -class FLAVAPooler(nn.Module): - def __init__(self, config: FLAVAPossibleConfigs): +class FlavaPooler(nn.Module): + def __init__(self, config: FlavaPossibleConfigs): super().__init__() self.dense = nn.Linear(config.hidden_size, config.hidden_size) self.activation = nn.Tanh() @@ -773,9 +932,10 @@ def forward(self, hidden_states: torch.Tensor): """ FLAVA_IMAGE_INPUTS_DOCSTRING_BASE = r""" + Args: pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): - Pixel values. Pixel values can be obtained using [`FLAVAFeatureExtractor`]. See - [`FLAVAFeatureExtractor.__call__`] for details. + Pixel values. Pixel values can be obtained using [`FlavaFeatureExtractor`]. See + [`FlavaFeatureExtractor.__call__`] for details. bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, image_num_patches)`): Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). @@ -784,15 +944,10 @@ def forward(self, hidden_states: torch.Tensor): Whether to interpolate the pre-trained position encodings. """ -FLAVA_IMAGE_INPUTS_DOCSTRING = ( - r""" - Args: -""" - + FLAVA_IMAGE_INPUTS_DOCSTRING_BASE - + FLAVA_INPUTS_DOCSTRING_COMMON -) +FLAVA_IMAGE_INPUTS_DOCSTRING = FLAVA_IMAGE_INPUTS_DOCSTRING_BASE + FLAVA_INPUTS_DOCSTRING_COMMON FLAVA_TEXT_INPUTS_DOCSTRING_BASE = r""" + Args: input_ids (`torch.LongTensor` of shape `({0})`): Indices of input sequence tokens in the vocabulary. Indices can be obtained using [`BertTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are input @@ -805,13 +960,8 @@ def forward(self, hidden_states: torch.Tensor): - 1 corresponds to a *sentence B* token. 
[What are token type IDs?](../glossary#token-type-ids) """ -FLAVA_TEXT_INPUTS_DOCSTRING = ( - r""" - Args: -""" - + FLAVA_TEXT_INPUTS_DOCSTRING_BASE - + FLAVA_INPUTS_DOCSTRING_COMMON -) + +FLAVA_TEXT_INPUTS_DOCSTRING = FLAVA_TEXT_INPUTS_DOCSTRING_BASE + FLAVA_INPUTS_DOCSTRING_COMMON FLAVA_MULTIMODAL_INPUTS_DOCSTRING = ( r""" @@ -822,17 +972,17 @@ def forward(self, hidden_states: torch.Tensor): + FLAVA_INPUTS_DOCSTRING_COMMON ) -FLAVA_MODEL_INPUTS_DOCSTRING = ( - r""" +FLAVA_MODEL_INPUTS_DOCSTRING_BASE = r""" Args: -""" - + FLAVA_IMAGE_INPUTS_DOCSTRING_BASE - + FLAVA_TEXT_INPUTS_DOCSTRING_BASE - + FLAVA_INPUTS_DOCSTRING_COMMON - + r""" skip_multimodal_encoder (*bool*, *optional*): Skip any calculations for multimodal encoder. Useful if multimodal encoding is not going to be used. """ + +FLAVA_MODEL_INPUTS_DOCSTRING = ( + FLAVA_IMAGE_INPUTS_DOCSTRING_BASE + + FLAVA_TEXT_INPUTS_DOCSTRING_BASE + + FLAVA_INPUTS_DOCSTRING_COMMON + + FLAVA_MODEL_INPUTS_DOCSTRING_BASE ) @@ -862,15 +1012,16 @@ def forward(self, hidden_states: torch.Tensor): mlm_labels (`torch.LongTensor` of shape `(batch_size, text_seq_len)`, *optional*): Labels for computing the left-to-right language and multimodal masked modeling loss (next word prediction). - Indices should be in `[-100, 0, ..., text_config.vocab_size]` (see `input_ids` docstring) Tokens with - indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels n `[0, - ..., text_config.vocab_size]` + Indices should be in `[-100, 0, ..., text_config.vocab_size - 1]` (see `input_ids` docstring). Tokens with + indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, + ..., text_config.vocab_size - 1]`. mim_labels (`torch.LongTensor` of shape `(batch_size, image_num_patches)`, *optional*): Labels for computing the image and multimodal masked modeling loss. Indices should be in `[-100, 0, ..., - image_config.vocab_size]`. Tokens with indices set to `-100` are ignored (masked), the loss is only - computed for the tokens with labels n `[0, ..., image_config.vocab_size]`. See [`FLAVACodebook`] to - understand how to generate mim_labels. + image_config.vocab_size - 1]`. Tokens with indices set to `-100` are ignored (masked), the loss is only + computed for the tokens with labels in `[0, ..., image_config.vocab_size - 1]`. If not passed, they are + generated automatically using the image codebook assigned to the model. By default, it uses + [`FlavaImageCodebook`]. See [`FlavaImageCodebook`] to understand how to generate mim_labels. itm_labels (`torch.LongTensor` of shape `(batch_size, 1)`, *optional*): Labels for computing the image-text matching loss. 0 means the pairs don't match and 1 means they match. @@ -882,14 +1033,20 @@ def forward(self, hidden_states: torch.Tensor): + FLAVA_INPUTS_DOCSTRING_COMMON ) +FLAVA_PRETRAINING_START_DOCSTRING_EXTRA = r""" + Parameters: + image_codebook ([`nn.Module`]): If passed, the image codebook will be set to this. Otherwise, it will + be initialized from the `image_codebook_config` defined in the config; the codebook is the first + parameter of the constructor. +""" - -class FLAVAPreTrainedModel(PreTrainedModel): + +class FlavaPreTrainedModel(PreTrainedModel): """ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models. 
""" - config_class = FLAVAConfig + config_class = FlavaConfig base_model_prefix = "flava" supports_gradient_checkpointing = True @@ -909,30 +1066,30 @@ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> No module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module: FLAVAEncoder, value: bool = False) -> None: - if isinstance(module, FLAVAEncoder): + def _set_gradient_checkpointing(self, module: FlavaEncoder, value: bool = False) -> None: + if isinstance(module, FlavaEncoder): module.gradient_checkpointing = value @add_start_docstrings( "The bare FLAVA Image Model transformer outputting raw hidden-states without any specific head on top.", - FLAVA_START_DOCSTRING.format(config="FLAVAImageConfig"), + FLAVA_START_DOCSTRING.format(config="FlavaImageConfig"), ) -class FLAVAImageModel(FLAVAPreTrainedModel): - config_class = FLAVAImageConfig +class FlavaImageModel(FlavaPreTrainedModel): + config_class = FlavaImageConfig base_model_prefix = "flava.image_model" main_input_name = "pixel_values" - def __init__(self, config: FLAVAImageConfig, add_pooling_layer: bool = True): + def __init__(self, config: FlavaImageConfig, add_pooling_layer: bool = True): super().__init__(config) self.config = config - self.embeddings = FLAVAImageEmbeddings(config) - self.encoder = FLAVAEncoder(config) + self.embeddings = FlavaImageEmbeddings(config) + self.encoder = FlavaEncoder(config) self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - self.pooler = FLAVAPooler(config) if add_pooling_layer else None + self.pooler = FlavaPooler(config) if add_pooling_layer else None self.post_init() @@ -1015,21 +1172,21 @@ def forward( @add_start_docstrings( "The bare FLAVA Text Model transformer outputting raw hidden-states without any specific head on top.", - FLAVA_START_DOCSTRING.format(config="FLAVATextConfig"), + FLAVA_START_DOCSTRING.format(config="FlavaTextConfig"), ) -class FLAVATextModel(FLAVAPreTrainedModel): - config_class = FLAVATextConfig +class FlavaTextModel(FlavaPreTrainedModel): + config_class = FlavaTextConfig base_model_prefix = "flava.text_model" - def __init__(self, config: FLAVATextConfig, add_pooling_layer: bool = True): + def __init__(self, config: FlavaTextConfig, add_pooling_layer: bool = True): super().__init__(config) self.config = config - self.embeddings = FLAVATextEmbeddings(config) - self.encoder = FLAVAEncoder(config) + self.embeddings = FlavaTextEmbeddings(config) + self.encoder = FlavaEncoder(config) self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - self.pooler = FLAVAPooler(config) if add_pooling_layer else None + self.pooler = FlavaPooler(config) if add_pooling_layer else None self.post_init() @@ -1120,24 +1277,24 @@ def forward( @add_start_docstrings( "The bare FLAVA Multimodal Model transformer outputting raw hidden-states without any specific head on top.", - FLAVA_START_DOCSTRING.format(config="FLAVAMultimodalConfig"), + FLAVA_START_DOCSTRING.format(config="FlavaMultimodalConfig"), ) -class FLAVAMultimodalModel(FLAVAPreTrainedModel): - config_class = FLAVAMultimodalConfig +class FlavaMultimodalModel(FlavaPreTrainedModel): + config_class = FlavaMultimodalConfig base_model_prefix = "flava.multimodal_model" main_input_name = "hidden_states" - def __init__(self, config: FLAVAMultimodalConfig, add_pooling_layer=True): + def __init__(self, config: FlavaMultimodalConfig, add_pooling_layer=True): super().__init__(config) self.config = config self.use_cls_token = 
self.config.use_cls_token if self.use_cls_token: self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) - self.encoder = FLAVAEncoder(config) + self.encoder = FlavaEncoder(config) self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - self.pooler = FLAVAPooler(config) if add_pooling_layer else None + self.pooler = FlavaPooler(config) if add_pooling_layer else None self.post_init() @@ -1218,27 +1375,27 @@ def forward( @add_start_docstrings( "The bare FLAVA Model transformer outputting raw hidden-states without any specific head on top.", - FLAVA_START_DOCSTRING.format(config="FLAVAConfig"), + FLAVA_START_DOCSTRING.format(config="FlavaConfig"), ) -class FLAVAModel(FLAVAPreTrainedModel): - config_class = FLAVAConfig +class FlavaModel(FlavaPreTrainedModel): + config_class = FlavaConfig - def __init__(self, config: FLAVAConfig): + def __init__(self, config: FlavaConfig): super().__init__(config) - if not isinstance(config.text_config, FLAVATextConfig): + if not isinstance(config.text_config, FlavaTextConfig): raise ValueError( - f"config.text_config is expected to be of type FLAVATextConfig but is of type {type(config.text_config)}." + f"config.text_config is expected to be of type FlavaTextConfig but is of type {type(config.text_config)}." ) - if not isinstance(config.image_config, FLAVAImageConfig): + if not isinstance(config.image_config, FlavaImageConfig): raise ValueError( - f"config.image_config is expected to be of type FLAVAImageConfig but is of type {type(config.image_config)}." + f"config.image_config is expected to be of type FlavaImageConfig but is of type {type(config.image_config)}." ) - if not isinstance(config.multimodal_config, FLAVAMultimodalConfig): + if not isinstance(config.multimodal_config, FlavaMultimodalConfig): raise ValueError( - "config.multimodal_config is expected to be of type FLAVAMultimodalConfig but " + "config.multimodal_config is expected to be of type FlavaMultimodalConfig but " + f"is of type {type(config.multimodal_config)}." ) @@ -1251,9 +1408,9 @@ def __init__(self, config: FLAVAConfig): self.image_hidden_size = image_config.hidden_size self.mm_hidden_size = multimodal_config.hidden_size - self.text_model = FLAVATextModel(text_config) - self.image_model = FLAVAImageModel(image_config) - self.multimodal_model = FLAVAMultimodalModel(multimodal_config) + self.text_model = FlavaTextModel(text_config) + self.image_model = FlavaImageModel(image_config) + self.multimodal_model = FlavaMultimodalModel(multimodal_config) self.image_projection = nn.Linear(self.image_hidden_size, self.projection_dim) self.text_projection = nn.Linear(self.text_hidden_size, self.projection_dim) @@ -1278,15 +1435,15 @@ def get_text_features( r""" Returns: text_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The text embeddings obtained by - applying the projection layer to the pooled output of [`FLAVATextModel`]. + applying the projection layer to the pooled output of [`FlavaTextModel`]. Examples: ```python - >>> from transformers import FLAVAProcessor, FLAVAModel + >>> from transformers import FlavaProcessor, FlavaModel - >>> model = FLAVAModel.from_pretrained("{0}") - >>> processor = FLAVAProcessor.from_pretrained("{0}") + >>> model = FlavaModel.from_pretrained("{0}") + >>> processor = FlavaProcessor.from_pretrained("{0}") >>> inputs = processor( ... 
text=["a photo of a cat", "a photo of a dog"], max_length=77, padding="max_length", return_tensors="pt" @@ -1325,17 +1482,17 @@ def get_image_features( r""" Returns: image_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The image embeddings obtained by - applying the projection layer to the pooled output of [`FLAVAImageModel`]. + applying the projection layer to the pooled output of [`FlavaImageModel`]. Examples: ```python >>> from PIL import Image >>> import requests - >>> from transformers import FLAVAProcessor, FLAVAModel + >>> from transformers import FlavaProcessor, FlavaModel - >>> model = FLAVAModel.from_pretrained("{0}") - >>> processor = FLAVAProcessor.from_pretrained("{0}") + >>> model = FlavaModel.from_pretrained("{0}") + >>> processor = FlavaProcessor.from_pretrained("{0}") >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" >>> image = Image.open(requests.get(url, stream=True).raw) @@ -1365,7 +1522,7 @@ def get_image_features( @add_start_docstrings_to_model_forward( FLAVA_MODEL_INPUTS_DOCSTRING.format("batch_size, image_num_patches + text_seq_len") ) - @replace_return_docstrings(output_type=FLAVAModelOutput, config_class=FLAVAConfig) + @replace_return_docstrings(output_type=FlavaModelOutput, config_class=FlavaConfig) def forward( self, input_ids: Optional[torch.LongTensor] = None, @@ -1379,7 +1536,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: bool = True, return_dict: Optional[bool] = None, - ) -> Union[Tuple, FLAVAOutput]: + ) -> Union[Tuple, FlavaOutput]: r""" Returns: @@ -1388,10 +1545,10 @@ def forward( ```python >>> from PIL import Image >>> import requests - >>> from transformers import FLAVAProcessor, FLAVAModel + >>> from transformers import FlavaProcessor, FlavaModel - >>> model = FLAVAModel.from_pretrained("aps/flava-full") - >>> processor = FLAVAProcessor.from_pretrained("aps/flava-full") + >>> model = FlavaModel.from_pretrained("facebook/flava-full") + >>> processor = FlavaProcessor.from_pretrained("facebook/flava-full") >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" >>> image = Image.open(requests.get(url, stream=True).raw) @@ -1405,7 +1562,8 @@ def forward( """ return_dict = return_dict if return_dict is not None else self.config.return_dict - assert output_hidden_states is True, "FLAVA model requires hidden states to work." + if not output_hidden_states: + raise ValueError("FLAVA model requires hidden states to work. 
Please set `output_hidden_states=True`") image_embeddings = None image_states = None image_mm_projection = None @@ -1459,7 +1617,7 @@ def forward( multimodal_output, ) - return FLAVAModelOutput( + return FlavaModelOutput( image_embeddings=image_embeddings, image_output=image_output, text_embeddings=text_embeddings, @@ -1469,55 +1627,81 @@ def forward( ) -class FLAVACodebookBlock(nn.Module): +class FlavaImageCodebookResPath(nn.Module): + def __init__(self, in_size: int, out_size: int, **kwargs): + super().__init__() + hid_size = out_size // 4 + + path = OrderedDict() + path["relu_1"] = nn.ReLU() + path["conv_1"] = nn.Conv2d(in_size, hid_size, kernel_size=3, padding=1) + path["relu_2"] = nn.ReLU() + path["conv_2"] = nn.Conv2d(hid_size, hid_size, kernel_size=3, padding=1) + path["relu_3"] = nn.ReLU() + path["conv_3"] = nn.Conv2d(hid_size, hid_size, kernel_size=3, padding=1) + path["relu_4"] = nn.ReLU() + path["conv_4"] = nn.Conv2d(hid_size, out_size, kernel_size=1, padding=0) + + self.path = nn.Sequential(path) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.path(x) + + +class FlavaImageCodebookBlock(nn.Module): def __init__(self, in_size: int, out_size: int, num_layers: int, **kwargs): super().__init__() - n_hid = out_size // 4 self.post_gain = 1 / (num_layers**2) if in_size != out_size: - self.id_path = _build_codebook_conv2d(in_size, out_size, kernel_size=1) + self.id_path = nn.Conv2d(in_size, out_size, kernel_size=1, padding=0) else: self.id_path = nn.Identity() - self.res_path = nn.Sequential( - OrderedDict( - [ - ("relu_1", nn.ReLU()), - ("conv_1", _build_codebook_conv2d(in_size, n_hid, kernel_size=3)), - ("relu_2", nn.ReLU()), - ("conv_2", _build_codebook_conv2d(n_hid, n_hid, kernel_size=3)), - ("relu_3", nn.ReLU()), - ("conv_3", _build_codebook_conv2d(n_hid, n_hid, kernel_size=3)), - ("relu_4", nn.ReLU()), - ("conv_4", _build_codebook_conv2d(n_hid, out_size, kernel_size=1)), - ] - ) - ) + self.res_path = FlavaImageCodebookResPath(in_size, out_size) def forward(self, x: torch.Tensor) -> torch.Tensor: return self.id_path(x) + self.post_gain * self.res_path(x) +class FlavaImageCodebookLayerGroup(nn.Module): + def __init__(self, num_blocks: int, num_layers: int, in_size: int, out_size: int, use_pool: bool = True): + super().__init__() + blocks = OrderedDict() + for i in range(num_blocks): + if i == 0: + blocks[f"block_{i+1}"] = FlavaImageCodebookBlock(in_size, out_size, num_layers) + else: + blocks[f"block_{i+1}"] = FlavaImageCodebookBlock(out_size, out_size, num_layers) + + if use_pool: + blocks["pool"] = nn.MaxPool2d(kernel_size=2) + + self.group = nn.Sequential(blocks) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.group(x) + + # Inspired by DALLE Encoder in https://github.com/openai/DALL-E/blob/5be4b236bc3ade6943662354117a0e83752cc322/dall_e/encoder.py#L42 @add_start_docstrings( """ - The FLAVA's codebook model inspired from DALL-E's original encoder. Outputs raw hidden states and can be used to - generate image tokens for an image based on DALL-E's vocab. To be used to generate labels for MIM. Use + The FLAVA's image codebook model inspired from DALL-E's original encoder. Outputs raw hidden states and can be used + to generate image tokens for an image based on DALL-E's vocab. Used to generate labels for MIM. Use `get_codebook_indices` to get image tokens for an image. 
""", - FLAVA_START_DOCSTRING.format(config="FLAVACodebookConfig"), + FLAVA_START_DOCSTRING.format(config="FlavaImageCodebookConfig"), ) -class FLAVACodebook(FLAVAPreTrainedModel): +class FlavaImageCodebook(FlavaPreTrainedModel): base_model_prefix = "" - config_class = FLAVACodebookConfig + config_class = FlavaImageCodebookConfig main_input_name = "pixel_values" supports_gradient_checkpointing = False def __init__( self, - config: FLAVACodebookConfig, + config: FlavaImageCodebookConfig, **kwargs: Any, ): super().__init__(config) @@ -1530,96 +1714,56 @@ def __init__( self.vocab_size = config.vocab_size num_layers = self.num_groups * self.num_blocks_per_group - output_conv = _build_codebook_conv2d(8 * self.hidden_size, self.vocab_size, kernel_size=1) - - self.blocks = nn.Sequential( - OrderedDict( - [ - ("input", _build_codebook_conv2d(self.input_channels, 1 * self.hidden_size, kernel_size=7)), - ( - "group_1", - self._create_group( - num_layers, self.num_blocks_per_group, 1 * self.hidden_size, 1 * self.hidden_size - ), - ), - ( - "group_2", - self._create_group( - num_layers, self.num_blocks_per_group, 1 * self.hidden_size, 2 * self.hidden_size - ), - ), - ( - "group_3", - self._create_group( - num_layers, self.num_blocks_per_group, 2 * self.hidden_size, 4 * self.hidden_size - ), - ), - ( - "group_4", - self._create_group( - num_layers, - self.num_blocks_per_group, - 4 * self.hidden_size, - 8 * self.hidden_size, - use_pool=False, - ), - ), - ( - "output", - nn.Sequential(OrderedDict([("relu", nn.ReLU()), ("conv", output_conv)])), - ), - ] - ) - ) - self.post_init() - if self.config.freeze: - self._freeze() - - def _freeze(self): - for param in self.parameters(): - param.requires_grad = False + output_blocks = OrderedDict() + output_blocks["relu"] = nn.ReLU() + output_blocks["conv"] = nn.Conv2d(8 * self.hidden_size, self.vocab_size, kernel_size=1, padding=0) - def _create_group( - self, - num_layers: int, - num_blocks_per_group: int, - in_size: int, - hidden_size: int, - use_pool: bool = True, - ): blocks = OrderedDict() - for i in range(num_blocks_per_group): - if i == 0: - blocks[f"block_{i+1}"] = FLAVACodebookBlock(in_size, hidden_size, num_layers) - else: - blocks[f"block_{i+1}"] = FLAVACodebookBlock(hidden_size, hidden_size, num_layers) + blocks["input"] = nn.Conv2d(self.input_channels, 1 * self.hidden_size, kernel_size=7, padding=3) + blocks["group_1"] = FlavaImageCodebookLayerGroup( + self.num_blocks_per_group, num_layers, 1 * self.hidden_size, 1 * self.hidden_size + ) + blocks["group_2"] = FlavaImageCodebookLayerGroup( + self.num_blocks_per_group, num_layers, 1 * self.hidden_size, 2 * self.hidden_size + ) + blocks["group_3"] = FlavaImageCodebookLayerGroup( + self.num_blocks_per_group, num_layers, 2 * self.hidden_size, 4 * self.hidden_size + ) + blocks["group_4"] = FlavaImageCodebookLayerGroup( + self.num_blocks_per_group, num_layers, 4 * self.hidden_size, 8 * self.hidden_size, use_pool=False + ) + blocks["output"] = nn.Sequential(output_blocks) - if use_pool: - blocks["pool"] = nn.MaxPool2d(kernel_size=2) + self.blocks = nn.Sequential(blocks) - return nn.Sequential(blocks) + self.post_init() + + if self.config.freeze: + for param in self.parameters(): + param.requires_grad = False def get_codebook_indices(self, pixel_values: torch.Tensor) -> torch.Tensor: """ Args: pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): - Pixel values. Pixel values can be obtained using [`FLAVACodebookFeatureExtractor`]. 
See - [`FLAVACodebookFeatureExtractor.__call__`] for details. + Pixel values. Codebook pixel values can be obtained using [`FlavaFeatureExtractor`] by passing + `return_codebook_pixels=True`. See [`FlavaFeatureExtractor.__call__`] for details. Examples: ```python >>> from PIL import Image >>> import requests - >>> from transformers import FLAVACodebookFeatureExtractor, FLAVACodebook + >>> from transformers import FlavaFeatureExtractor, FlavaImageCodebook - >>> model = FLAVAModel.from_pretrained("{0}") - >>> feature_extractor = FLAVACodebookFeaturExtractor.from_pretrained("{0}") + >>> model = FlavaImageCodebook.from_pretrained("{0}") + >>> feature_extractor = FlavaFeatureExtractor.from_pretrained("{0}") >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" >>> image = Image.open(requests.get(url, stream=True).raw) - >>> inputs = feature_extractor(image, return_mask=True, return_tensors="pt") + >>> inputs = feature_extractor([image], return_codebook_pixels=True, return_tensors="pt") + >>> inputs = dict(pixel_values=inputs.codebook_pixel_values) >>> outputs = model.get_codebook_indices(**inputs) ``` @@ -1637,23 +1781,24 @@ def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor: """ Args: pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): - Pixel values. Pixel values can be obtained using [`FLAVACodebookFeatureExtractor`]. See - [`FLAVACodebookFeatureExtractor.__call__`] for details. + Pixel values. Codebook pixel values can be obtained using [`FlavaFeatureExtractor`] by passing + `return_codebook_pixels=True`. See [`FlavaFeatureExtractor.__call__`] for details. Examples: ```python >>> from PIL import Image >>> import requests - >>> from transformers import FLAVACodebookFeatureExtractor, FLAVACodebook + >>> from transformers import FlavaFeatureExtractor, FlavaImageCodebook - >>> model = FLAVAModel.from_pretrained("{0}") - >>> feature_extractor = FLAVACodebookFeaturExtractor.from_pretrained("{0}") + >>> model = FlavaImageCodebook.from_pretrained("{0}") + >>> feature_extractor = FlavaFeatureExtractor.from_pretrained("{0}") >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" >>> image = Image.open(requests.get(url, stream=True).raw) - >>> inputs = feature_extractor([image], return_tensors="pt") + >>> inputs = feature_extractor([image], return_codebook_pixels=True, return_tensors="pt") + >>> inputs = dict(pixel_values=inputs.codebook_pixel_values) >>> outputs = model(**inputs) >>> print(outputs.shape) @@ -1669,7 +1814,7 @@ def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor: return self.blocks(pixel_values) -class FLAVAPredictionHeadTransform(nn.Module): +class FlavaPredictionHeadTransform(nn.Module): def __init__(self, config): super().__init__() self.dense = nn.Linear(config.hidden_size, config.hidden_size) @@ -1686,11 +1831,11 @@ def forward(self, hidden_states): return hidden_states -class FLAVAMaskedPredictionHead(nn.Module): +class FlavaMaskedPredictionHead(nn.Module): def __init__(self, config, weight=None): super().__init__() self.config = config - self.transform = FLAVAPredictionHeadTransform(config) + self.transform = FlavaPredictionHeadTransform(config) self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False) self.bias = nn.Parameter(torch.zeros(config.vocab_size)) if weight is not None: @@ -1705,11 +1850,11 @@ def forward(self, x): return x -class FLAVAITMHead(nn.Module): +class FlavaITMHead(nn.Module): def __init__(self, config): super().__init__() self.config = config - 
self.pooler = FLAVAPooler(config) + self.pooler = FlavaPooler(config) self.seq_relationship = nn.Linear(config.hidden_size, 2) def forward(self, x): @@ -1718,7 +1863,7 @@ def forward(self, x): return x -class FLAVAGlobalContrastiveHead(nn.Module): +class FlavaGlobalContrastiveHead(nn.Module): def __init__(self, config): super().__init__() self.config = config @@ -1760,21 +1905,25 @@ def forward(self, image_embeddings, text_embeddings, logit_scale): """ The FLAVA model for pretraining which outputs losses, embeddings, logits and transformer outputs. """, - FLAVA_START_DOCSTRING.format(config="FLAVAConfig"), + FLAVA_START_DOCSTRING.format(config="FlavaConfig") + FLAVA_PRETRAINING_START_DOCSTRING_EXTRA, ) -class FLAVAForPreTraining(FLAVAPreTrainedModel): - def __init__(self, config: FLAVAConfig): +class FlavaForPreTraining(FlavaPreTrainedModel): + def __init__(self, config: FlavaConfig, image_codebook: Optional[nn.Module] = None): super().__init__(config) - self.flava = FLAVAModel(config) + self.flava = FlavaModel(config) + + self.image_codebook = image_codebook + if self.image_codebook is None and config.init_codebook: + self.image_codebook = FlavaImageCodebook(config.image_codebook_config) # Levarage text and image encoder configs to create the masked # head since it has the right vocab - self.mim_head = FLAVAMaskedPredictionHead(config.image_config) - self.mlm_head = FLAVAMaskedPredictionHead(config.text_config) - self.itm_head = FLAVAITMHead(config) - self.mmm_image_head = FLAVAMaskedPredictionHead(config.image_config) - self.mmm_text_head = FLAVAMaskedPredictionHead(config.text_config) - self.global_contrastive_head = FLAVAGlobalContrastiveHead(config) + self.mim_head = FlavaMaskedPredictionHead(config.image_config) + self.mlm_head = FlavaMaskedPredictionHead(config.text_config) + self.itm_head = FlavaITMHead(config) + self.mmm_image_head = FlavaMaskedPredictionHead(config.image_config) + self.mmm_text_head = FlavaMaskedPredictionHead(config.text_config) + self.global_contrastive_head = FlavaGlobalContrastiveHead(config) self.image_vocab_size = config.image_config.vocab_size self.text_vocab_size = config.text_config.vocab_size @@ -1797,12 +1946,13 @@ def _resize_to_2d(self, x: torch.Tensor): @add_start_docstrings_to_model_forward( FLAVA_PRETRAINING_INPUTS_DOCSTRING.format("batch_size, text_seq_len", "batch_size, image_num_patches") ) - @replace_return_docstrings(output_type=FLAVAForPreTrainingOutput, config_class=FLAVAConfig) + @replace_return_docstrings(output_type=FlavaForPreTrainingOutput, config_class=FlavaConfig) def forward( self, input_ids: Optional[torch.LongTensor] = None, input_ids_masked: Optional[torch.LongTensor] = None, pixel_values: Optional[torch.FloatTensor] = None, + codebook_pixel_values: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.Tensor] = None, token_type_ids: Optional[torch.Tensor] = None, bool_masked_pos: Optional[torch.Tensor] = None, @@ -1822,27 +1972,26 @@ def forward( ```python >>> from PIL import Image >>> import requests - >>> from transformers import FLAVACodebookFeatureExtractor, FLAVACodebook, FLAVAForPreTraining, FLAVAProcessor - - >>> codebook = FLAVACodebook.from_pretrained("aps/flava-codebook") - >>> codebook_feature_extractor = FLAVACodebookFeatureExtractor.from_pretrained("aps/flava-codebook") + >>> from transformers import FlavaForPreTraining, FlavaProcessor >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" >>> image = Image.open(requests.get(url, stream=True).raw) - >>> inputs = 
codebook_feature_extractor([image], return_tensors="pt") - - >>> mim_labels = codebook.get_codebook_indices(**inputs) - - >>> model = FLAVAForPreTraining.from_pretrained("aps/flava-full") - >>> processor = FLAVAProcessor.from_pretrained("aps/flava-full") + >>> model = FlavaForPreTraining.from_pretrained("facebook/flava-full") + >>> processor = FlavaProcessor.from_pretrained("facebook/flava-full") >>> text = ["a photo of a cat"] >>> inputs = processor( - ... images=[image], text=text, return_masks=True, padding=True, max_length=77, return_tensors="pt" + ... images=[image], + ... text=text, + ... return_masks=True, + ... return_codebook_pixels=True, + ... padding=True, + ... max_length=77, + ... return_tensors="pt", ... ) - >>> inputs["mim_labels"] = mim_labels + >>> output = model(**inputs) ``` @@ -1859,6 +2008,13 @@ def forward( else self.skip_unmasked_multimodal_encoder ) + if input_ids_masked is None and input_ids is not None: + logger.warning( + "`input_ids_masked` isn't passed which means MLM loss won't be calculated correctly" + "Setting it to `input_ids` so that model can work. Please pass it if this is unintentional..." + ) + input_ids_masked = input_ids + flava_output = self.flava( input_ids=input_ids, pixel_values=pixel_values, @@ -1894,15 +2050,31 @@ def forward( image_masked_embeddings = flava_masked_output.image_embeddings text_masked_embeddings = flava_masked_output.text_embeddings multimodal_masked_embeddings = flava_masked_output.multimodal_embeddings - total_loss = mim_loss = mlm_loss = mmm_text_loss = mmm_image_loss = gc_loss = itm_loss = None - mim_logits = ( - mlm_logits - ) = mmm_text_logits = mmm_image_logits = itm_logits = logits_per_image = logits_per_text = None + total_loss = mim_loss = mlm_loss = mmm_text_loss = mmm_image_loss = gc_loss = itm_loss = None + mim_logits = mlm_logits = mmm_text_logits = mmm_image_logits = None + itm_logits = logits_per_image = logits_per_text = None + + # Calculate mim_labels if necessary from the image_codebook + if image_masked_embeddings is not None or multimodal_masked_embeddings is not None: + if mim_labels is None and return_loss: + if self.image_codebook is None: + raise RuntimeError( + "`return_loss` is set to True but the image codebook is not initialized and no `mim_labels` " + " have been passed. Reinstantiate the model with `init_codebook` set to True or " + "pass in your custom `mim_labels`" + ) + if codebook_pixel_values is None: + raise ValueError( + "`codebook_pixel_value` are required to generate `mim_labels` if loss is expected. 
" + "Call `FlavaProcessor` with `return_codebook_pixels` set to True" + ) + mim_labels = self.image_codebook.get_codebook_indices(codebook_pixel_values) # Unimodal MIM Loss # If multimodal embeddings are present, we will calculate MMM loss if self.mim_weight > 0 and image_masked_embeddings is not None and multimodal_masked_embeddings is None: sequence_for_image = image_masked_embeddings + if mim_labels is not None: mim_labels = self._resize_to_2d(mim_labels) bool_masked_pos = self._resize_to_2d(bool_masked_pos) @@ -2031,7 +2203,7 @@ def forward( gc_loss = (gc_loss_image + gc_loss_text) / 2 gc_loss *= self.global_contrastive_weight - flava_losses = FLAVALosses( + flava_losses = FlavaLosses( mim=mim_loss, mlm=mlm_loss, itm=itm_loss, @@ -2076,7 +2248,7 @@ def forward( # Filter None as transformer by default won't handle it return tuple(x for x in output if x is None) - return FLAVAForPreTrainingOutput( + return FlavaForPreTrainingOutput( loss=total_loss, loss_info=flava_losses, image_embeddings=image_embeddings, diff --git a/src/transformers/models/flava/processing_flava.py b/src/transformers/models/flava/processing_flava.py index 062d2c85b5925..15489f05d02e3 100644 --- a/src/transformers/models/flava/processing_flava.py +++ b/src/transformers/models/flava/processing_flava.py @@ -1,5 +1,5 @@ # coding=utf-8 -# Copyright 2021 Meta Platforms authors and The HuggingFace Team. All rights reserved. +# Copyright 2022 Meta Platforms authors and The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -19,33 +19,28 @@ import numpy as np -from transformers.data.data_collator import DataCollatorForWholeWordMask, tolist - from ...processing_utils import ProcessorMixin from ...tokenization_utils_base import BatchEncoding, PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy from ...utils import TensorType -class FLAVAProcessor(ProcessorMixin): +class FlavaProcessor(ProcessorMixin): r""" Constructs a FLAVA processor which wraps a FLAVA feature extractor and a FLAVA tokenizer into a single processor. - [`FLAVAProcessor`] offers all the functionalities of [`FLAVAFeatureExtractor`] and [`FLAVATokenizerFast`]. See the - [`~FLAVAProcessor.__call__`] and [`~FLAVAProcessor.decode`] for more information. + [`FlavaProcessor`] offers all the functionalities of [`FlavaFeatureExtractor`] and [`BertTokenizerFast`]. See the + [`~FlavaProcessor.__call__`] and [`~FlavaProcessor.decode`] for more information. Args: - feature_extractor ([`FLAVAFeatureExtractor`]): - The feature extractor is a required input. - tokenizer ([`FLAVATokenizerFast`]): - The tokenizer is a required input. + feature_extractor ([`FlavaFeatureExtractor`]): The feature extractor is a required input. + tokenizer ([`BertTokenizerFast`]): The tokenizer is a required input. 
""" - feature_extractor_class = "FLAVAFeatureExtractor" + feature_extractor_class = "FlavaFeatureExtractor" tokenizer_class = ("BertTokenizer", "BertTokenizerFast") - def __init__(self, feature_extractor, tokenizer, mlm_probability=0.15): + def __init__(self, feature_extractor, tokenizer): super().__init__(feature_extractor, tokenizer) self.current_processor = self.feature_extractor - self.text_masker = DataCollatorForWholeWordMask(tokenizer, mlm=True, mlm_probability=mlm_probability) def __call__( self, @@ -66,7 +61,8 @@ def __call__( max_length: Optional[int] = None, stride: int = 0, pad_to_multiple_of: Optional[int] = None, - return_masks: Optional[bool] = None, + return_image_mask: Optional[bool] = None, + return_codebook_pixels: Optional[bool] = None, return_token_type_ids: Optional[bool] = None, return_attention_mask: Optional[bool] = None, return_overflowing_tokens: bool = False, @@ -81,13 +77,7 @@ def __call__( This method uses [`FLAVAFeatureExtractor.__call__`] method to prepare image(s) for the model, and [`BertTokenizerFast.__call__`] to prepare text for the model. - Please refer to the docstring of the above two methods for more information. Other special args are mentioned - below: - - Args: - return_mask (`bool`, *optional*, defaults to None): - If True, the processor will return `bool_masked_pos` suggesting masks for image's patch version and - `input_ids_masked` and `mlm_labels` for MLM. + Please refer to the docstring of the above two methods for more information. """ if text is None and images is None: @@ -105,7 +95,7 @@ def __call__( return_token_type_ids=return_token_type_ids, return_attention_mask=return_attention_mask, return_overflowing_tokens=return_overflowing_tokens, - return_special_tokens_mask=return_special_tokens_mask or return_masks, + return_special_tokens_mask=return_special_tokens_mask, return_offsets_mapping=return_offsets_mapping, return_length=return_length, verbose=verbose, @@ -114,15 +104,13 @@ def __call__( ) if images is not None: image_features = self.feature_extractor( - images, return_masks=return_masks, return_tensors=return_tensors, **kwargs + images, + return_image_mask=return_image_mask, + return_codebook_pixels=return_codebook_pixels, + return_tensors=return_tensors, + **kwargs, ) - if return_masks and text is not None: - batch_masked = self.text_masker(tolist(encoding["input_ids"]), return_tensors=return_tensors) - encoding["input_ids_masked"] = batch_masked["input_ids"] - encoding["mlm_labels"] = batch_masked["labels"] - encoding.pop("special_tokens_mask") - if text is not None and images is not None: encoding.update(image_features) return encoding @@ -133,14 +121,14 @@ def __call__( def batch_decode(self, *args, **kwargs): """ - This method forwards all its arguments to FLAVATokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please + This method forwards all its arguments to BertTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please refer to the docstring of this method for more information. """ return self.tokenizer.batch_decode(*args, **kwargs) def decode(self, *args, **kwargs): """ - This method forwards all its arguments to FLAVATokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to + This method forwards all its arguments to BertTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to the docstring of this method for more information. 
""" return self.tokenizer.decode(*args, **kwargs) diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 5fda47525a745..487eb6b2dafde 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -1783,49 +1783,49 @@ def __init__(self, *args, **kwargs): FLAVA_PRETRAINED_MODEL_ARCHIVE_LIST = None -class FLAVACodebook(metaclass=DummyObject): +class FlavaForPreTraining(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) -class FLAVAForPreTraining(metaclass=DummyObject): +class FlavaImageCodebook(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) -class FLAVAImageModel(metaclass=DummyObject): +class FlavaImageModel(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) -class FLAVAModel(metaclass=DummyObject): +class FlavaModel(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) -class FLAVAMultimodalModel(metaclass=DummyObject): +class FlavaMultimodalModel(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) -class FLAVAPreTrainedModel(metaclass=DummyObject): +class FlavaPreTrainedModel(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) -class FLAVATextModel(metaclass=DummyObject): +class FlavaTextModel(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): diff --git a/src/transformers/utils/dummy_vision_objects.py b/src/transformers/utils/dummy_vision_objects.py index f59dc74766975..b9c7d3e97f606 100644 --- a/src/transformers/utils/dummy_vision_objects.py +++ b/src/transformers/utils/dummy_vision_objects.py @@ -59,21 +59,14 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["vision"]) -class FLAVACodebookFeatureExtractor(metaclass=DummyObject): +class FlavaFeatureExtractor(metaclass=DummyObject): _backends = ["vision"] def __init__(self, *args, **kwargs): requires_backends(self, ["vision"]) -class FLAVAFeatureExtractor(metaclass=DummyObject): - _backends = ["vision"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["vision"]) - - -class FLAVAProcessor(metaclass=DummyObject): +class FlavaProcessor(metaclass=DummyObject): _backends = ["vision"] def __init__(self, *args, **kwargs): diff --git a/tests/flava/__init__.py b/tests/models/flava/__init__.py similarity index 100% rename from tests/flava/__init__.py rename to tests/models/flava/__init__.py diff --git a/tests/flava/test_feature_extraction_flava.py b/tests/models/flava/test_feature_extraction_flava.py similarity index 68% rename from tests/flava/test_feature_extraction_flava.py rename to tests/models/flava/test_feature_extraction_flava.py index 4067b3d9c829b..793aa913aeb04 100644 --- a/tests/flava/test_feature_extraction_flava.py +++ b/tests/models/flava/test_feature_extraction_flava.py @@ -1,5 +1,5 @@ # coding=utf-8 -# Copyright 2021 HuggingFace Inc. +# Copyright 2022 Meta Platforms authors and HuggingFace Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -21,7 +21,7 @@ from transformers.testing_utils import require_torch, require_vision from transformers.utils import is_torch_available, is_vision_available -from ..test_feature_extraction_common import FeatureExtractionSavingTestMixin, prepare_image_inputs +from ...test_feature_extraction_common import FeatureExtractionSavingTestMixin, prepare_image_inputs if is_torch_available(): @@ -30,7 +30,7 @@ if is_vision_available(): from PIL import Image - from transformers import FLAVACodebookFeatureExtractor, FLAVAFeatureExtractor + from transformers import FlavaFeatureExtractor from transformers.models.flava.feature_extraction_flava import ( FLAVA_CODEBOOK_MEAN, FLAVA_CODEBOOK_STD, @@ -41,8 +41,7 @@ FLAVA_IMAGE_MEAN = FLAVA_IMAGE_STD = FLAVA_CODEBOOK_MEAN = FLAVA_CODEBOOK_STD = None -# TODO(aps): Add joint feature extractor useful for pretraining -class FLAVAFeatureExtractionTester(unittest.TestCase): +class FlavaFeatureExtractionTester(unittest.TestCase): def __init__( self, parent, @@ -64,6 +63,15 @@ def __init__( mask_group_min_patches=16, mask_group_min_aspect_ratio=0.3, mask_group_max_aspect_ratio=None, + codebook_do_resize=True, + codebook_size=112, + codebook_resample=None, + codebook_do_center_crop=True, + codebook_crop_size=112, + codebook_do_map_pixels=True, + codebook_do_normalize=True, + codebook_image_mean=FLAVA_CODEBOOK_MEAN, + codebook_image_std=FLAVA_CODEBOOK_STD, ): self.parent = parent self.batch_size = batch_size @@ -78,6 +86,7 @@ def __init__( self.image_std = image_std self.do_center_crop = do_center_crop self.crop_size = crop_size + self.input_size_patches = input_size_patches self.total_mask_patches = total_mask_patches self.mask_group_max_patches = mask_group_max_patches @@ -85,6 +94,16 @@ def __init__( self.mask_group_min_aspect_ratio = mask_group_min_aspect_ratio self.mask_group_max_aspect_ratio = mask_group_max_aspect_ratio + self.codebook_do_resize = codebook_do_resize + self.codebook_size = codebook_size + self.codebook_resample = codebook_resample if codebook_resample is not None else Image.LANCZOS + self.codebook_do_center_crop = codebook_do_center_crop + self.codebook_crop_size = codebook_crop_size + self.codebook_do_map_pixels = codebook_do_map_pixels + self.codebook_do_normalize = codebook_do_normalize + self.codebook_image_mean = codebook_image_mean + self.codebook_image_std = codebook_image_std + def prepare_feat_extract_dict(self): return { "image_mean": self.image_mean, @@ -101,6 +120,15 @@ def prepare_feat_extract_dict(self): "mask_group_min_patches": self.mask_group_min_patches, "mask_group_min_aspect_ratio": self.mask_group_min_aspect_ratio, "mask_group_max_aspect_ratio": self.mask_group_min_aspect_ratio, + "codebook_do_resize": self.codebook_do_resize, + "codebook_size": self.codebook_size, + "codebook_resample": self.codebook_resample, + "codebook_do_center_crop": self.codebook_do_center_crop, + "codebook_crop_size": self.codebook_crop_size, + "codebook_do_map_pixels": self.codebook_do_map_pixels, + "codebook_do_normalize": self.codebook_do_normalize, + "codebook_image_mean": self.codebook_image_mean, + "codebook_image_std": self.codebook_image_std, } def get_expected_image_size(self): @@ -113,15 +141,22 @@ def get_expected_mask_size(self): else self.input_size_patches ) + def get_expected_codebook_image_size(self): + if not isinstance(self.codebook_size, tuple): + return (self.codebook_size, self.codebook_size) + else: + return self.codebook_size + @require_torch @require_vision -class FLAVAFeatureExtractionTest(FeatureExtractionSavingTestMixin, 
unittest.TestCase): +class FlavaFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase): - feature_extraction_class = FLAVAFeatureExtractor if is_vision_available() else None + feature_extraction_class = FlavaFeatureExtractor if is_vision_available() else None + maxDiff = None def setUp(self): - self.feature_extract_tester = FLAVAFeatureExtractionTester(self) + self.feature_extract_tester = FlavaFeatureExtractionTester(self) @property def feat_extract_dict(self): @@ -137,6 +172,15 @@ def test_feat_extract_properties(self): self.assertTrue(hasattr(feature_extractor, "crop_size")) self.assertTrue(hasattr(feature_extractor, "do_center_crop")) self.assertTrue(hasattr(feature_extractor, "masking_generator")) + self.assertTrue(hasattr(feature_extractor, "codebook_do_resize")) + self.assertTrue(hasattr(feature_extractor, "codebook_size")) + self.assertTrue(hasattr(feature_extractor, "codebook_resample")) + self.assertTrue(hasattr(feature_extractor, "codebook_do_center_crop")) + self.assertTrue(hasattr(feature_extractor, "codebook_crop_size")) + self.assertTrue(hasattr(feature_extractor, "codebook_do_map_pixels")) + self.assertTrue(hasattr(feature_extractor, "codebook_do_normalize")) + self.assertTrue(hasattr(feature_extractor, "codebook_image_mean")) + self.assertTrue(hasattr(feature_extractor, "codebook_image_std")) def test_batch_feature(self): pass @@ -196,7 +240,7 @@ def _test_call_framework(self, instance_class, prepare_kwargs): (1, self.feature_extract_tester.num_channels, expected_height, expected_width), ) - encoded_images = feature_extractor(image_inputs, return_masks=True, return_tensors="pt") + encoded_images = feature_extractor(image_inputs, return_image_mask=True, return_tensors="pt") expected_height, expected_width = self.feature_extract_tester.get_expected_image_size() self.assertEqual( @@ -234,7 +278,7 @@ def _test_call_framework(self, instance_class, prepare_kwargs): ) # Test masking - encoded_images = feature_extractor(image_inputs, return_masks=True, return_tensors="pt") + encoded_images = feature_extractor(image_inputs, return_image_mask=True, return_tensors="pt") expected_height, expected_width = self.feature_extract_tester.get_expected_image_size() self.assertEqual( @@ -261,12 +305,7 @@ def test_call_numpy(self): self._test_call_framework(np.ndarray, prepare_kwargs={"numpify": True}) def test_call_pytorch(self): - self._test_call_framework( - torch.Tensor, - prepare_kwargs={ - "torchify": True, - }, - ) + self._test_call_framework(torch.Tensor, prepare_kwargs={"torchify": True}) def test_masking(self): # Initialize feature_extractor @@ -275,87 +314,10 @@ def test_masking(self): image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False, torchify=True) # Test not batched input - encoded_images = feature_extractor(image_inputs[0], return_masks=True, return_tensors="pt") + encoded_images = feature_extractor(image_inputs[0], return_image_mask=True, return_tensors="pt") self.assertEqual(encoded_images.bool_masked_pos.sum().item(), 75) - -class FLAVACodebookFeatureExtractionTester(unittest.TestCase): - def __init__( - self, - parent, - batch_size=7, - num_channels=3, - min_resolution=30, - max_resolution=400, - do_resize=True, - size=112, - do_center_crop=True, - crop_size=112, - resample=None, - do_normalize=True, - image_mean=FLAVA_CODEBOOK_MEAN, - image_std=FLAVA_CODEBOOK_STD, - do_map_pixels=True, - ): - self.parent = parent - self.batch_size = batch_size - self.num_channels = num_channels - self.do_resize = do_resize - 
self.min_resolution = min_resolution - self.max_resolution = max_resolution - self.size = size - self.do_center_crop = do_center_crop - self.crop_size = crop_size - self.resample = resample if resample is not None else Image.LANCZOS - self.do_normalize = do_normalize - self.image_mean = image_mean - self.image_std = image_std - self.do_map_pixels = do_map_pixels - - def prepare_feat_extract_dict(self): - return { - "do_resize": self.do_resize, - "size": self.size, - "do_center_crop": self.do_center_crop, - "crop_size": self.crop_size, - "resample": self.resample, - "do_normalize": self.do_normalize, - "image_mean": self.image_mean, - "image_std": self.image_std, - "do_map_pixels": self.do_map_pixels, - } - - def get_expected_image_size(self): - return (self.size, self.size) if not isinstance(self.size, tuple) else self.size - - -@require_torch -@require_vision -class FLAVACodebookFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase): - - feature_extraction_class = FLAVACodebookFeatureExtractor if is_vision_available() else None - - def setUp(self): - self.feature_extract_tester = FLAVACodebookFeatureExtractionTester(self) - - @property - def feat_extract_dict(self): - return self.feature_extract_tester.prepare_feat_extract_dict() - - def test_feat_extract_properties(self): - feature_extractor = self.feature_extraction_class(**self.feat_extract_dict) - self.assertTrue(hasattr(feature_extractor, "image_mean")) - self.assertTrue(hasattr(feature_extractor, "image_std")) - self.assertTrue(hasattr(feature_extractor, "do_normalize")) - self.assertTrue(hasattr(feature_extractor, "do_resize")) - self.assertTrue(hasattr(feature_extractor, "resample")) - self.assertTrue(hasattr(feature_extractor, "crop_size")) - self.assertTrue(hasattr(feature_extractor, "do_center_crop")) - - def test_batch_feature(self): - pass - - def test_call_pil(self): + def test_codebook_pixels(self): # Initialize feature_extractor feature_extractor = self.feature_extraction_class(**self.feat_extract_dict) # create random PIL images @@ -364,38 +326,18 @@ def test_call_pil(self): self.assertIsInstance(image, Image.Image) # Test not batched input - encoded_images = feature_extractor(image_inputs[0], return_tensors="pt") - expected_height, expected_width = self.feature_extract_tester.get_expected_image_size() + encoded_images = feature_extractor(image_inputs[0], return_codebook_pixels=True, return_tensors="pt") + expected_height, expected_width = self.feature_extract_tester.get_expected_codebook_image_size() self.assertEqual( - encoded_images.pixel_values.shape, + encoded_images.codebook_pixel_values.shape, (1, self.feature_extract_tester.num_channels, expected_height, expected_width), ) # Test batched - encoded_images = feature_extractor(image_inputs, return_tensors="pt") - expected_height, expected_width = self.feature_extract_tester.get_expected_image_size() - - def _test_call_framework(self, instance_class, prepare_kwargs): - # Initialize feature_extractor - feature_extractor = self.feature_extraction_class(**self.feat_extract_dict) - # create random tensors - image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False, **prepare_kwargs) - for image in image_inputs: - self.assertIsInstance(image, instance_class) - - # Test not batched input - encoded_images = feature_extractor(image_inputs[0], return_tensors="pt") - expected_height, expected_width = self.feature_extract_tester.get_expected_image_size() + encoded_images = feature_extractor(image_inputs, 
return_codebook_pixels=True, return_tensors="pt") + expected_height, expected_width = self.feature_extract_tester.get_expected_codebook_image_size() self.assertEqual( - encoded_images.pixel_values.shape, - (1, self.feature_extract_tester.num_channels, expected_height, expected_width), - ) - - # Test batched - encoded_images = feature_extractor(image_inputs, return_tensors="pt").pixel_values - expected_height, expected_width = self.feature_extract_tester.get_expected_image_size() - self.assertEqual( - encoded_images.shape, + encoded_images.codebook_pixel_values.shape, ( self.feature_extract_tester.batch_size, self.feature_extract_tester.num_channels, @@ -403,14 +345,3 @@ def _test_call_framework(self, instance_class, prepare_kwargs): expected_width, ), ) - - def test_call_numpy(self): - self._test_call_framework(np.ndarray, prepare_kwargs={"numpify": True}) - - def test_call_pytorch(self): - self._test_call_framework( - torch.Tensor, - prepare_kwargs={ - "torchify": True, - }, - ) diff --git a/tests/flava/test_modeling_flava.py b/tests/models/flava/test_modeling_flava.py similarity index 88% rename from tests/flava/test_modeling_flava.py rename to tests/models/flava/test_modeling_flava.py index 215907110ff87..95dffe1e61964 100644 --- a/tests/flava/test_modeling_flava.py +++ b/tests/models/flava/test_modeling_flava.py @@ -1,5 +1,5 @@ # coding=utf-8 -# Copyright 2021 The HuggingFace Inc. team. All rights reserved. +# Copyright 2022 Meta Platforms authors and The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -24,12 +24,18 @@ import numpy as np import requests -from transformers import FLAVACodebookConfig, FLAVAConfig, FLAVAImageConfig, FLAVAMultimodalConfig, FLAVATextConfig +from transformers import ( + FlavaConfig, + FlavaImageCodebookConfig, + FlavaImageConfig, + FlavaMultimodalConfig, + FlavaTextConfig, +) from transformers.testing_utils import require_torch, require_vision, slow, torch_device from transformers.utils import is_torch_available, is_vision_available -from ..test_configuration_common import ConfigTester -from ..test_modeling_common import ( +from ...test_configuration_common import ConfigTester +from ...test_modeling_common import ( ModelTesterMixin, _config_zero_init, floats_tensor, @@ -43,30 +49,31 @@ from torch import nn from transformers import ( - FLAVACodebook, - FLAVAForPreTraining, - FLAVAImageModel, - FLAVAModel, - FLAVAMultimodalModel, - FLAVATextModel, + FlavaForPreTraining, + FlavaImageCodebook, + FlavaImageModel, + FlavaModel, + FlavaMultimodalModel, + FlavaTextModel, ) from transformers.models.flava.modeling_flava import ( FLAVA_CODEBOOK_PRETRAINED_MODEL_ARCHIVE_LIST, FLAVA_PRETRAINED_MODEL_ARCHIVE_LIST, + add_whole_word_masked_text, ) else: - FLAVAModel = None - FLAVAForPreTraining = None + FlavaModel = None + FlavaForPreTraining = None torch = {} if is_vision_available(): from PIL import Image - from transformers import FLAVACodebookFeatureExtractor, FLAVAProcessor + from transformers import FlavaProcessor -class FLAVAImageModelTester: +class FlavaImageModelTester: def __init__( self, parent, @@ -115,7 +122,7 @@ def prepare_config_and_inputs(self): return config, pixel_values, bool_masked_pos def get_config(self): - return FLAVAImageConfig( + return FlavaImageConfig( hidden_size=self.hidden_size, num_hidden_layers=self.num_hidden_layers, num_attention_heads=self.num_attention_heads, @@ -134,7 +141,7 @@ def get_config(self): ) 
def create_and_check_model(self, config, pixel_values, bool_masked_pos): - model = FLAVAImageModel(config=config) + model = FlavaImageModel(config=config) model.to(torch_device) model.eval() with torch.no_grad(): @@ -154,13 +161,13 @@ def prepare_config_and_inputs_for_common(self): @require_torch -class FLAVAImageModelTest(ModelTesterMixin, unittest.TestCase): +class FlavaImageModelTest(ModelTesterMixin, unittest.TestCase): """ Here we also overwrite some of the tests of test_modeling_common.py, as FLAVA does not use input_ids, inputs_embeds, attention_mask and seq_length. """ - all_model_classes = (FLAVAImageModel,) if is_torch_available() else () + all_model_classes = (FlavaImageModel,) if is_torch_available() else () test_pruning = False test_torchscript = False @@ -168,8 +175,8 @@ class FLAVAImageModelTest(ModelTesterMixin, unittest.TestCase): test_head_masking = False def setUp(self): - self.model_tester = FLAVAImageModelTester(self) - self.config_tester = ConfigTester(self, config_class=FLAVAImageConfig, has_text_modality=False, hidden_size=37) + self.model_tester = FlavaImageModelTester(self) + self.config_tester = ConfigTester(self, config_class=FlavaImageConfig, has_text_modality=False, hidden_size=37) def test_config(self): self.config_tester.run_common_tests() @@ -304,12 +311,12 @@ def test_training(self): def test_training_gradient_checkpointing(self): pass - # skip this test as FLAVAImageModel has no base class and is + # skip this test as FlavaImageModel has no base class and is # not available in MODEL_MAPPING def test_save_load_fast_init_from_base(self): pass - # skip this test as FLAVAImageModel has no base class and is + # skip this test as FlavaImageModel has no base class and is # not available in MODEL_MAPPING def test_save_load_fast_init_to_base(self): pass @@ -317,11 +324,11 @@ def test_save_load_fast_init_to_base(self): @slow def test_model_from_pretrained(self): for model_name in FLAVA_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: - model = FLAVAImageModel.from_pretrained(model_name) + model = FlavaImageModel.from_pretrained(model_name) self.assertIsNotNone(model) -class FLAVATextModelTester: +class FlavaTextModelTester: def __init__( self, parent, @@ -392,7 +399,7 @@ def prepare_config_and_inputs(self): return config, input_ids, token_type_ids, input_mask def get_config(self): - return FLAVATextConfig( + return FlavaTextConfig( vocab_size=self.vocab_size, type_vocab_size=self.type_vocab_size, max_position_embeddings=self.max_position_embeddings, @@ -411,7 +418,7 @@ def get_config(self): ) def create_and_check_model(self, config, input_ids, token_type_ids, input_mask): - model = FLAVATextModel(config=config) + model = FlavaTextModel(config=config) model.to(torch_device) model.eval() with torch.no_grad(): @@ -428,16 +435,16 @@ def prepare_config_and_inputs_for_common(self): @require_torch -class FLAVATextModelTest(ModelTesterMixin, unittest.TestCase): +class FlavaTextModelTest(ModelTesterMixin, unittest.TestCase): - all_model_classes = (FLAVATextModel,) if is_torch_available() else () + all_model_classes = (FlavaTextModel,) if is_torch_available() else () test_pruning = False test_head_masking = False test_torchscript = False def setUp(self): - self.model_tester = FLAVATextModelTester(self) - self.config_tester = ConfigTester(self, config_class=FLAVATextConfig, hidden_size=37) + self.model_tester = FlavaTextModelTester(self) + self.config_tester = ConfigTester(self, config_class=FlavaTextConfig, hidden_size=37) def test_config(self): self.config_tester.run_common_tests() @@ 
-456,12 +463,12 @@ def test_inputs_embeds(self): # FLAVA does not use inputs_embeds pass - # skip this test as FLAVATextModel has no base class and is + # skip this test as FlavaTextModel has no base class and is # not available in MODEL_MAPPING def test_save_load_fast_init_from_base(self): pass - # skip this test as FLAVATextModel has no base class and is + # skip this test as FlavaTextModel has no base class and is # not available in MODEL_MAPPING def test_save_load_fast_init_to_base(self): pass @@ -469,11 +476,11 @@ def test_save_load_fast_init_to_base(self): @slow def test_model_from_pretrained(self): for model_name in FLAVA_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: - model = FLAVATextModel.from_pretrained(model_name) + model = FlavaTextModel.from_pretrained(model_name) self.assertIsNotNone(model) -class FLAVAMultimodalModelTester: +class FlavaMultimodalModelTester: def __init__( self, parent, @@ -529,7 +536,7 @@ def prepare_config_and_inputs(self): return config, hidden_states, input_mask def get_config(self): - return FLAVAMultimodalConfig( + return FlavaMultimodalConfig( hidden_size=self.hidden_size, num_hidden_layers=self.num_hidden_layers, num_attention_heads=self.num_attention_heads, @@ -545,7 +552,7 @@ def get_config(self): ) def create_and_check_model(self, config, hidden_states, input_mask): - model = FLAVAMultimodalModel(config=config) + model = FlavaMultimodalModel(config=config) model.to(torch_device) model.eval() with torch.no_grad(): @@ -562,18 +569,18 @@ def prepare_config_and_inputs_for_common(self): @require_torch -class FLAVAMultimodalModelTest(ModelTesterMixin, unittest.TestCase): +class FlavaMultimodalModelTest(ModelTesterMixin, unittest.TestCase): - all_model_classes = (FLAVAMultimodalModel,) if is_torch_available() else () + all_model_classes = (FlavaMultimodalModel,) if is_torch_available() else () test_pruning = False test_head_masking = False test_resize_embeddings = False test_torchscript = False def setUp(self): - self.model_tester = FLAVAMultimodalModelTester(self) + self.model_tester = FlavaMultimodalModelTester(self) self.config_tester = ConfigTester( - self, config_class=FLAVAMultimodalConfig, has_text_modality=False, hidden_size=37 + self, config_class=FlavaMultimodalConfig, has_text_modality=False, hidden_size=37 ) def test_config(self): @@ -609,12 +616,12 @@ def test_inputs_embeds(self): # FLAVA does not use inputs_embeds pass - # skip this test as FLAVAMultimodalModel has no base class and is + # skip this test as FlavaMultimodalModel has no base class and is # not available in MODEL_MAPPING def test_save_load_fast_init_from_base(self): pass - # skip this test as FLAVAMultimodalModel has no base class and is + # skip this test as FlavaMultimodalModel has no base class and is # not available in MODEL_MAPPING def test_save_load_fast_init_to_base(self): pass @@ -622,18 +629,12 @@ def test_save_load_fast_init_to_base(self): @slow def test_model_from_pretrained(self): for model_name in FLAVA_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: - model = FLAVAMultimodalModel.from_pretrained(model_name) + model = FlavaMultimodalModel.from_pretrained(model_name) self.assertIsNotNone(model) -class FLAVACodebookTester: - def __init__( - self, - parent, - batch_size=12, - image_size=112, - num_channels=3, - ): +class FlavaImageCodebookTester: + def __init__(self, parent, batch_size=12, image_size=112, num_channels=3): self.parent = parent self.batch_size = batch_size self.image_size = image_size @@ -646,10 +647,10 @@ def prepare_config_and_inputs(self): return config, pixel_values def 
get_config(self):
-        return FLAVACodebookConfig()
+        return FlavaImageCodebookConfig()

     def create_and_check_model(self, config, pixel_values):
-        model = FLAVACodebook(config=config)
+        model = FlavaImageCodebook(config=config)
         model.to(torch_device)
         model.eval()
         with torch.no_grad():
@@ -666,9 +667,9 @@ def prepare_config_and_inputs_for_common(self):


 @require_torch
-class FLAVACodebookTest(ModelTesterMixin, unittest.TestCase):
+class FlavaImageCodebookTest(ModelTesterMixin, unittest.TestCase):

-    all_model_classes = (FLAVACodebook,) if is_torch_available() else ()
+    all_model_classes = (FlavaImageCodebook,) if is_torch_available() else ()
     test_pruning = False
     test_head_masking = False
     test_resize_embeddings = False
@@ -676,8 +677,8 @@ class FLAVACodebookTest(ModelTesterMixin, unittest.TestCase):
     has_attentions = False

     def setUp(self):
-        self.model_tester = FLAVACodebookTester(self)
-        self.config_tester = ConfigTester(self, config_class=FLAVACodebookConfig, has_text_modality=False)
+        self.model_tester = FlavaImageCodebookTester(self)
+        self.config_tester = ConfigTester(self, config_class=FlavaImageCodebookConfig, has_text_modality=False)

     def test_model(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
@@ -719,12 +720,12 @@ def test_inputs_embeds(self):
     def test_model_outputs_equivalence(self):
         pass

-    # skip this test as FLAVACodebook has no base class and is
+    # skip this test as FlavaImageCodebook has no base class and is
     # not available in MODEL_MAPPING
     def test_save_load_fast_init_from_base(self):
         pass

-    # skip this test as FLAVACodebook has no base class and is
+    # skip this test as FlavaImageCodebook has no base class and is
     # not available in MODEL_MAPPING
     def test_save_load_fast_init_to_base(self):
         pass
@@ -732,12 +733,12 @@ def test_save_load_fast_init_to_base(self):
     @slow
     def test_model_from_pretrained(self):
         for model_name in FLAVA_CODEBOOK_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
-            model = FLAVACodebook.from_pretrained(model_name)
+            model = FlavaImageCodebook.from_pretrained(model_name)
             self.assertIsNotNone(model)


-class FLAVAModelTester:
-    model_class = FLAVAModel
+class FlavaModelTester:
+    model_class = FlavaModel

     def __init__(
         self,
@@ -749,11 +750,12 @@ def __init__(
         layer_norm_eps=1e-12,
     ):
         self.parent = parent
-        self.image_model_tester = FLAVAImageModelTester(parent)
-        self.text_model_tester = FLAVATextModelTester(parent)
-        self.multimodal_model_tester = FLAVAMultimodalModelTester(parent)
+        self.image_model_tester = FlavaImageModelTester(parent)
+        self.text_model_tester = FlavaTextModelTester(parent)
+        self.multimodal_model_tester = FlavaMultimodalModelTester(parent)
+        self.image_codebook_tester = FlavaImageCodebookTester(parent)
         self.is_training = is_training
-        self.config_tester = ConfigTester(self, config_class=FLAVAConfig, hidden_size=37)
+        self.config_tester = ConfigTester(self, config_class=FlavaConfig, hidden_size=37)
         self.hidden_size = hidden_size
         self.projection_dim = projection_dim
         self.initializer_range = initializer_range
@@ -777,10 +779,11 @@ def prepare_config_and_inputs_for_common(self):
         }

     def get_config(self):
-        return FLAVAConfig.from_configs(
+        return FlavaConfig.from_configs(
             self.image_model_tester.get_config(),
             self.text_model_tester.get_config(),
             self.multimodal_model_tester.get_config(),
+            self.image_codebook_tester.get_config(),
             hidden_size=self.hidden_size,
             projection_dim=self.projection_dim,
             initializer_range=self.initializer_range,
@@ -792,13 +795,7 @@ def create_and_check_model(self, config, inputs):
         self._test_model(config, inputs, test_text=True)
         self._test_model(config, inputs, test_image=True, test_text=True)

-    def _test_model(
-        self,
-        config,
-        inputs,
-        test_image=False,
-        test_text=False,
-    ):
+    def _test_model(self, config, inputs, test_image=False, test_text=False):
         model = self.model_class(config).to(torch_device).eval()
         with torch.no_grad():
             result = model(
@@ -846,9 +843,9 @@ def _test_model(


 @require_torch
-class FLAVAModelTest(ModelTesterMixin, unittest.TestCase):
-    all_model_classes = (FLAVAModel,) if is_torch_available() else ()
-    class_for_tester = FLAVAModelTester
+class FlavaModelTest(ModelTesterMixin, unittest.TestCase):
+    all_model_classes = (FlavaModel,) if is_torch_available() else ()
+    class_for_tester = FlavaModelTester
     test_head_masking = False
     test_pruning = False
     test_resize_embeddings = False
@@ -873,7 +870,7 @@ def test_inputs_embeds(self):
     def test_retain_grad_hidden_states_attentions(self):
         pass

-    # FLAVAModel does not have input/output embeddings
+    # FlavaModel does not have input/output embeddings
     def test_model_common_attributes(self):
         pass

@@ -965,34 +962,34 @@ def _create_and_check_torchscript(self, config, inputs_dict):
     def test_load_image_text_config(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

-        # Save FLAVAConfig and check if we can load FLAVAImageConfig from it
+        # Save FlavaConfig and check if we can load FlavaImageConfig from it
         with tempfile.TemporaryDirectory() as tmp_dir_name:
             config.save_pretrained(tmp_dir_name)
-            image_config = FLAVAImageConfig.from_pretrained(tmp_dir_name)
+            image_config = FlavaImageConfig.from_pretrained(tmp_dir_name)
             self.assertDictEqual(config.image_config.to_dict(), image_config.to_dict())

-        # Save FLAVAConfig and check if we can load FLAVATextConfig from it
+        # Save FlavaConfig and check if we can load FlavaTextConfig from it
         with tempfile.TemporaryDirectory() as tmp_dir_name:
             config.save_pretrained(tmp_dir_name)
-            text_config = FLAVATextConfig.from_pretrained(tmp_dir_name)
+            text_config = FlavaTextConfig.from_pretrained(tmp_dir_name)
             self.assertDictEqual(config.text_config.to_dict(), text_config.to_dict())

-        # Save FLAVAConfig and check if we can load FLAVAMultimodalConfig from it
+        # Save FlavaConfig and check if we can load FlavaMultimodalConfig from it
         with tempfile.TemporaryDirectory() as tmp_dir_name:
             config.save_pretrained(tmp_dir_name)
-            multimodal_config = FLAVAMultimodalConfig.from_pretrained(tmp_dir_name)
+            multimodal_config = FlavaMultimodalConfig.from_pretrained(tmp_dir_name)
             self.assertDictEqual(config.multimodal_config.to_dict(), multimodal_config.to_dict())

-    # overwrite from common since FLAVAModel/TFFLAVAModel return FLAVAOutput/TFFLAVAOutput
+    # overwrite from common since FlavaModel/TFFlavaModel return FLAVAOutput/TFFLAVAOutput
     @slow
     def test_model_from_pretrained(self):
         for model_name in FLAVA_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
-            model = FLAVAModel.from_pretrained(model_name)
+            model = FlavaModel.from_pretrained(model_name)
             self.assertIsNotNone(model)


-class FLAVAForPreTrainingTester(FLAVAModelTester):
-    model_class = FLAVAForPreTraining
+class FlavaForPreTrainingTester(FlavaModelTester):
+    model_class = FlavaForPreTraining

     def prepare_config_and_inputs_for_common(self):
         _, pixel_values, bool_masked_pos = self.image_model_tester.prepare_config_and_inputs()
@@ -1023,13 +1020,7 @@ def prepare_config_and_inputs_for_common(self):
             "return_loss": True,
         }

-    def _test_model(
-        self,
-        config,
-        inputs,
-        test_image=False,
-        test_text=False,
-    ):
+    def _test_model(self, config, inputs, test_image=False, test_text=False):
         model = self.model_class(config).to(torch_device).eval()
         with torch.no_grad():
             result = model(
@@ -1145,9 +1136,9 @@ def _test_model(


 @require_torch
-class FLAVAForPreTrainingTest(FLAVAModelTest):
-    all_model_classes = (FLAVAForPreTraining,) if is_torch_available() else ()
-    class_for_tester = FLAVAForPreTrainingTester
+class FlavaForPreTrainingTest(FlavaModelTest):
+    all_model_classes = (FlavaForPreTraining,) if is_torch_available() else ()
+    class_for_tester = FlavaForPreTrainingTester
     test_torchscript = False
@@ -1160,12 +1151,12 @@ def prepare_img():

 @require_vision
 @require_torch
-class FLAVAModelIntegrationTest(unittest.TestCase):
+class FlavaModelIntegrationTest(unittest.TestCase):
     @slow
     def test_inference(self):
-        model_name = "aps/flava-full"
-        model = FLAVAModel.from_pretrained(model_name).to(torch_device)
-        processor = FLAVAProcessor.from_pretrained(model_name)
+        model_name = "facebook/flava-full"
+        model = FlavaModel.from_pretrained(model_name).to(torch_device)
+        processor = FlavaProcessor.from_pretrained(model_name)

         image = prepare_img()
         inputs = processor(
@@ -1174,7 +1165,6 @@ def test_inference(self):
             padding="max_length",
             max_length=77,
             return_tensors="pt",
-            return_masks=False,
         ).to(torch_device)

         # forward pass
@@ -1189,29 +1179,28 @@

 @require_vision
 @require_torch
-class FLAVAForPreTrainingIntegrationTest(unittest.TestCase):
+class FlavaForPreTrainingIntegrationTest(unittest.TestCase):
     @slow
     def test_inference(self):
-        model_name = "aps/flava-full"
-        codebook_name = "aps/flava-codebook"
-        model = FLAVAForPreTraining.from_pretrained(model_name).to(torch_device)
-        codebook = FLAVACodebook.from_pretrained(codebook_name).to(torch_device)
-        codebook_fe = FLAVACodebookFeatureExtractor.from_pretrained(codebook_name)
-        processor = FLAVAProcessor.from_pretrained(model_name)
+        model_name = "facebook/flava-full"
+        model = FlavaForPreTraining.from_pretrained(model_name).to(torch_device)
+        processor = FlavaProcessor.from_pretrained(model_name)
         torch.manual_seed(1)
         random.seed(1)

         image = prepare_img()
-        mim_labels = codebook.get_codebook_indices(**codebook_fe([image, image], return_tensors="pt").to(torch_device))
         inputs = processor(
             text=["a photo of a cat", "a photo of a dog"],
             images=[image, image],
             padding="max_length",
             max_length=77,
             return_tensors="pt",
-            return_masks=True,
-        ).to(torch_device)
-        inputs["mim_labels"] = mim_labels
+            return_codebook_pixels=True,
+            return_special_tokens_mask=True,
+            return_image_mask=True,
+        )
+        inputs = add_whole_word_masked_text(inputs, processor.tokenizer)
+        inputs = inputs.to(torch_device)

         # forward pass
         with torch.no_grad():
diff --git a/tests/flava/test_processor_flava.py b/tests/models/flava/test_processor_flava.py
similarity index 71%
rename from tests/flava/test_processor_flava.py
rename to tests/models/flava/test_processor_flava.py
index b33d42ae5fa52..21cc84d5f299a 100644
--- a/tests/flava/test_processor_flava.py
+++ b/tests/models/flava/test_processor_flava.py
@@ -1,4 +1,4 @@
-# Copyright 2021 The HuggingFace Team. All rights reserved.
+# Copyright 2022 Meta Platforms authors and The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -14,6 +14,7 @@
 import json
 import os
+import random
 import shutil
 import tempfile
 import unittest
@@ -30,31 +31,23 @@
 if is_vision_available():
     from PIL import Image

-    from transformers import FLAVAFeatureExtractor, FLAVAProcessor
-    from transformers.models.flava.feature_extraction_flava import FLAVA_IMAGE_MEAN, FLAVA_IMAGE_STD
+    from transformers import FlavaFeatureExtractor, FlavaProcessor
+    from transformers.models.flava.feature_extraction_flava import (
+        FLAVA_CODEBOOK_MEAN,
+        FLAVA_CODEBOOK_STD,
+        FLAVA_IMAGE_MEAN,
+        FLAVA_IMAGE_STD,
+    )


 @require_vision
-class FLAVAProcessorTest(unittest.TestCase):
+class FlavaProcessorTest(unittest.TestCase):
     def setUp(self):
         self.tmpdirname = tempfile.mkdtemp()
-        vocab_tokens = [
-            "[UNK]",
-            "[CLS]",
-            "[SEP]",
-            "[PAD]",
-            "[MASK]",
-            "want",
-            "##want",
-            "##ed",
-            "wa",
-            "un",
-            "runn",
-            "##ing",
-            ",",
-            "low",
-            "lowest",
-        ]
+
+        # fmt: off
+        vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]", "want", "##want", "##ed", "wa", "un", "runn", "##ing", ",", "low", "lowest"]
+        # fmt: on
         self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["vocab_file"])
         with open(self.vocab_file, "w", encoding="utf-8") as fp:
@@ -74,6 +67,15 @@ def setUp(self):
             "mask_group_min_patches": 16,
             "mask_group_min_aspect_ratio": 0.3,
             "mask_group_max_aspect_ratio": None,
+            "codebook_do_resize": True,
+            "codebook_size": 112,
+            "codebook_resample": None,
+            "codebook_do_center_crop": True,
+            "codebook_crop_size": 112,
+            "codebook_do_map_pixels": True,
+            "codebook_do_normalize": True,
+            "codebook_image_mean": FLAVA_CODEBOOK_MEAN,
+            "codebook_image_std": FLAVA_CODEBOOK_STD,
         }

         self.feature_extractor_file = os.path.join(self.tmpdirname, FEATURE_EXTRACTOR_NAME)
@@ -87,7 +89,7 @@ def get_rust_tokenizer(self, **kwargs):
         return BertTokenizerFast.from_pretrained(self.tmpdirname, **kwargs)

     def get_feature_extractor(self, **kwargs):
-        return FLAVAFeatureExtractor.from_pretrained(self.tmpdirname, **kwargs)
+        return FlavaFeatureExtractor.from_pretrained(self.tmpdirname, **kwargs)

     def tearDown(self):
         shutil.rmtree(self.tmpdirname)
@@ -108,13 +110,13 @@ def test_save_load_pretrained_default(self):
         tokenizer_fast = self.get_rust_tokenizer()
         feature_extractor = self.get_feature_extractor()

-        processor_slow = FLAVAProcessor(tokenizer=tokenizer_slow, feature_extractor=feature_extractor)
+        processor_slow = FlavaProcessor(tokenizer=tokenizer_slow, feature_extractor=feature_extractor)
         processor_slow.save_pretrained(self.tmpdirname)
-        processor_slow = FLAVAProcessor.from_pretrained(self.tmpdirname, use_fast=False)
+        processor_slow = FlavaProcessor.from_pretrained(self.tmpdirname, use_fast=False)

-        processor_fast = FLAVAProcessor(tokenizer=tokenizer_fast, feature_extractor=feature_extractor)
+        processor_fast = FlavaProcessor(tokenizer=tokenizer_fast, feature_extractor=feature_extractor)
         processor_fast.save_pretrained(self.tmpdirname)
-        processor_fast = FLAVAProcessor.from_pretrained(self.tmpdirname)
+        processor_fast = FlavaProcessor.from_pretrained(self.tmpdirname)

         self.assertEqual(processor_slow.tokenizer.get_vocab(), tokenizer_slow.get_vocab())
         self.assertEqual(processor_fast.tokenizer.get_vocab(), tokenizer_fast.get_vocab())
@@ -124,17 +126,17 @@ def test_save_load_pretrained_default(self):
         self.assertEqual(processor_slow.feature_extractor.to_json_string(), feature_extractor.to_json_string())
         self.assertEqual(processor_fast.feature_extractor.to_json_string(), feature_extractor.to_json_string())
-        self.assertIsInstance(processor_slow.feature_extractor, FLAVAFeatureExtractor)
-        self.assertIsInstance(processor_fast.feature_extractor, FLAVAFeatureExtractor)
+        self.assertIsInstance(processor_slow.feature_extractor, FlavaFeatureExtractor)
+        self.assertIsInstance(processor_fast.feature_extractor, FlavaFeatureExtractor)

     def test_save_load_pretrained_additional_features(self):
-        processor = FLAVAProcessor(tokenizer=self.get_tokenizer(), feature_extractor=self.get_feature_extractor())
+        processor = FlavaProcessor(tokenizer=self.get_tokenizer(), feature_extractor=self.get_feature_extractor())
         processor.save_pretrained(self.tmpdirname)

         tokenizer_add_kwargs = self.get_tokenizer(bos_token="(BOS)", eos_token="(EOS)")
         feature_extractor_add_kwargs = self.get_feature_extractor(do_normalize=False, padding_value=1.0)

-        processor = FLAVAProcessor.from_pretrained(
+        processor = FlavaProcessor.from_pretrained(
             self.tmpdirname, bos_token="(BOS)", eos_token="(EOS)", do_normalize=False, padding_value=1.0
         )

@@ -142,13 +144,13 @@ def test_save_load_pretrained_additional_features(self):
         self.assertIsInstance(processor.tokenizer, BertTokenizerFast)

         self.assertEqual(processor.feature_extractor.to_json_string(), feature_extractor_add_kwargs.to_json_string())
-        self.assertIsInstance(processor.feature_extractor, FLAVAFeatureExtractor)
+        self.assertIsInstance(processor.feature_extractor, FlavaFeatureExtractor)

     def test_feature_extractor(self):
         feature_extractor = self.get_feature_extractor()
         tokenizer = self.get_tokenizer()

-        processor = FLAVAProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)
+        processor = FlavaProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)

         image_input = self.prepare_image_inputs()

@@ -158,11 +160,24 @@ def test_feature_extractor(self):
         for key in input_feat_extract.keys():
             self.assertAlmostEqual(input_feat_extract[key].sum(), input_processor[key].sum(), delta=1e-2)

+        # With rest of the args
+        random.seed(1234)
+        input_feat_extract = feature_extractor(
+            image_input, return_image_mask=True, return_codebook_pixels=True, return_tensors="np"
+        )
+        random.seed(1234)
+        input_processor = processor(
+            images=image_input, return_image_mask=True, return_codebook_pixels=True, return_tensors="np"
+        )
+
+        for key in input_feat_extract.keys():
+            self.assertAlmostEqual(input_feat_extract[key].sum(), input_processor[key].sum(), delta=1e-2)
+
     def test_tokenizer(self):
         feature_extractor = self.get_feature_extractor()
         tokenizer = self.get_tokenizer()

-        processor = FLAVAProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)
+        processor = FlavaProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)

         input_str = "lower newer"

@@ -177,7 +192,7 @@ def test_processor(self):
         feature_extractor = self.get_feature_extractor()
         tokenizer = self.get_tokenizer()

-        processor = FLAVAProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)
+        processor = FlavaProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)

         input_str = "lower newer"
         image_input = self.prepare_image_inputs()
@@ -186,6 +201,21 @@ def test_processor(self):

         self.assertListEqual(list(inputs.keys()), ["input_ids", "token_type_ids", "attention_mask", "pixel_values"])

+        # add extra args
+        inputs = processor(text=input_str, images=image_input, return_codebook_pixels=True, return_image_mask=True)
+
+        self.assertListEqual(
+            list(inputs.keys()),
+            [
+                "input_ids",
+                "token_type_ids",
+                "attention_mask",
+                "pixel_values",
+                "codebook_pixel_values",
+                "bool_masked_pos",
+            ],
+        )
+
         # test if it raises when no input is passed
         with pytest.raises(ValueError):
             processor()
@@ -194,7 +224,7 @@ def test_tokenizer_decode(self):
         feature_extractor = self.get_feature_extractor()
         tokenizer = self.get_tokenizer()

-        processor = FLAVAProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)
+        processor = FlavaProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)

         predicted_ids = [[1, 4, 5, 8, 1, 0, 8], [3, 4, 3, 1, 1, 8, 9]]

diff --git a/utils/check_repo.py b/utils/check_repo.py
index 2fd4de5699a6c..aa9faadb0a6f6 100644
--- a/utils/check_repo.py
+++ b/utils/check_repo.py
@@ -146,10 +146,10 @@
     "DetrForSegmentation",
     "DPRReader",
     "FlaubertForQuestionAnswering",
-    "FLAVACodebook",
-    "FLAVATextModel",
-    "FLAVAImageModel",
-    "FLAVAMultimodalModel",
+    "FlavaImageCodebook",
+    "FlavaTextModel",
+    "FlavaImageModel",
+    "FlavaMultimodalModel",
    "GPT2DoubleHeadsModel",
    "LukeForMaskedLM",
    "LukeForEntityClassification",
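
For reference, a minimal sketch (not part of the diff) of how the renamed public API is exercised after this change, mirroring `FlavaModelIntegrationTest` above. The checkpoint name and processor arguments are taken from the updated tests; the local image path is a placeholder, and the output attribute names assume the `FlavaModelOutput` fields shipped with the FLAVA release.

```python
# Hedged usage sketch of the renamed FLAVA classes; "cats.png" is a placeholder path.
import torch
from PIL import Image

from transformers import FlavaModel, FlavaProcessor

model_name = "facebook/flava-full"  # checkpoint name used in the updated integration tests
model = FlavaModel.from_pretrained(model_name).eval()
processor = FlavaProcessor.from_pretrained(model_name)

image = Image.open("cats.png")
inputs = processor(
    text=["a photo of a cat"],
    images=[image],
    padding="max_length",
    max_length=77,
    return_tensors="pt",
)

with torch.no_grad():
    outputs = model(**inputs)

# Assumed FlavaModelOutput fields: one embedding tensor per modality plus a fused one.
print(outputs.image_embeddings.shape, outputs.text_embeddings.shape, outputs.multimodal_embeddings.shape)
```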