diff --git a/README.md b/README.md index ec8a0fd2e8b392..5e17e33b204cc1 100644 --- a/README.md +++ b/README.md @@ -385,6 +385,7 @@ Current number of checkpoints: ![](https://img.shields.io/endpoint?url=https://h 1. **[Vision Transformer (ViT)](https://huggingface.co/docs/transformers/model_doc/vit)** (from Google AI) released with the paper [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929) by Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, Jakob Uszkoreit, Neil Houlsby. 1. **[VisualBERT](https://huggingface.co/docs/transformers/model_doc/visual_bert)** (from UCLA NLP) released with the paper [VisualBERT: A Simple and Performant Baseline for Vision and Language](https://arxiv.org/pdf/1908.03557) by Liunian Harold Li, Mark Yatskar, Da Yin, Cho-Jui Hsieh, Kai-Wei Chang. 1. **[ViTMAE](https://huggingface.co/docs/transformers/model_doc/vit_mae)** (from Meta AI) released with the paper [Masked Autoencoders Are Scalable Vision Learners](https://arxiv.org/abs/2111.06377) by Kaiming He, Xinlei Chen, Saining Xie, Yanghao Li, Piotr Dollár, Ross Girshick. +1. **[ViTMSN](https://huggingface.co/docs/transformers/main/model_doc/vit_msn)** (from Meta AI) released with the paper [Masked Siamese Networks for Label-Efficient Learning](https://arxiv.org/abs/2204.07141) by Mahmoud Assran, Mathilde Caron, Ishan Misra, Piotr Bojanowski, Florian Bordes, Pascal Vincent, Armand Joulin, Michael Rabbat, Nicolas Ballas. 1. **[Wav2Vec2](https://huggingface.co/docs/transformers/model_doc/wav2vec2)** (from Facebook AI) released with the paper [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations](https://arxiv.org/abs/2006.11477) by Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael Auli. 1. **[Wav2Vec2-Conformer](https://huggingface.co/docs/transformers/model_doc/wav2vec2-conformer)** (from Facebook AI) released with the paper [FAIRSEQ S2T: Fast Speech-to-Text Modeling with FAIRSEQ](https://arxiv.org/abs/2010.05171) by Changhan Wang, Yun Tang, Xutai Ma, Anne Wu, Sravya Popuri, Dmytro Okhonko, Juan Pino. 1. **[Wav2Vec2Phoneme](https://huggingface.co/docs/transformers/model_doc/wav2vec2_phoneme)** (from Facebook AI) released with the paper [Simple and Effective Zero-shot Cross-lingual Phoneme Recognition](https://arxiv.org/abs/2109.11680) by Qiantong Xu, Alexei Baevski, Michael Auli. diff --git a/README_ko.md b/README_ko.md index 8a2df2fcbc8887..f53075ff5fe6f9 100644 --- a/README_ko.md +++ b/README_ko.md @@ -335,6 +335,7 @@ Flax, PyTorch, TensorFlow 설치 페이지에서 이들을 conda로 설치하는 1. **[Vision Transformer (ViT)](https://huggingface.co/docs/transformers/model_doc/vit)** (from Google AI) released with the paper [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929) by Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, Jakob Uszkoreit, Neil Houlsby. 1. **[VisualBERT](https://huggingface.co/docs/transformers/model_doc/visual_bert)** (from UCLA NLP) released with the paper [VisualBERT: A Simple and Performant Baseline for Vision and Language](https://arxiv.org/pdf/1908.03557) by Liunian Harold Li, Mark Yatskar, Da Yin, Cho-Jui Hsieh, Kai-Wei Chang. 1. **[ViTMAE](https://huggingface.co/docs/transformers/model_doc/vit_mae)** (from Meta AI) released with the paper [Masked Autoencoders Are Scalable Vision Learners](https://arxiv.org/abs/2111.06377) by Kaiming He, Xinlei Chen, Saining Xie, Yanghao Li, Piotr Dollár, Ross Girshick. +1. **[ViTMSN](https://huggingface.co/docs/transformers/main/model_doc/vit_msn)** (from Meta AI) released with the paper [Masked Siamese Networks for Label-Efficient Learning](https://arxiv.org/abs/2204.07141) by Mahmoud Assran, Mathilde Caron, Ishan Misra, Piotr Bojanowski, Florian Bordes, Pascal Vincent, Armand Joulin, Michael Rabbat, Nicolas Ballas. 1. **[Wav2Vec2](https://huggingface.co/docs/transformers/model_doc/wav2vec2)** (from Facebook AI) released with the paper [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations](https://arxiv.org/abs/2006.11477) by Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael Auli. 1. **[Wav2Vec2-Conformer](https://huggingface.co/docs/transformers/model_doc/wav2vec2-conformer)** (from Facebook AI) released with the paper [FAIRSEQ S2T: Fast Speech-to-Text Modeling with FAIRSEQ](https://arxiv.org/abs/2010.05171) by Changhan Wang, Yun Tang, Xutai Ma, Anne Wu, Sravya Popuri, Dmytro Okhonko, Juan Pino. 1. **[Wav2Vec2Phoneme](https://huggingface.co/docs/transformers/model_doc/wav2vec2_phoneme)** (from Facebook AI) released with the paper [Simple and Effective Zero-shot Cross-lingual Phoneme Recognition](https://arxiv.org/abs/2109.11680) by Qiantong Xu, Alexei Baevski, Michael Auli. diff --git a/README_zh-hans.md b/README_zh-hans.md index 88611a5f672bd0..2843a8eb29a08c 100644 --- a/README_zh-hans.md +++ b/README_zh-hans.md @@ -359,6 +359,7 @@ conda install -c huggingface transformers 1. **[Vision Transformer (ViT)](https://huggingface.co/docs/transformers/model_doc/vit)** (来自 Google AI) 伴随论文 [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929) 由 Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, Jakob Uszkoreit, Neil Houlsby 发布。 1. **[VisualBERT](https://huggingface.co/docs/transformers/model_doc/visual_bert)** (来自 UCLA NLP) 伴随论文 [VisualBERT: A Simple and Performant Baseline for Vision and Language](https://arxiv.org/pdf/1908.03557) 由 Liunian Harold Li, Mark Yatskar, Da Yin, Cho-Jui Hsieh, Kai-Wei Chang 发布。 1. **[ViTMAE](https://huggingface.co/docs/transformers/model_doc/vit_mae)** (来自 Meta AI) 伴随论文 [Masked Autoencoders Are Scalable Vision Learners](https://arxiv.org/abs/2111.06377) 由 Kaiming He, Xinlei Chen, Saining Xie, Yanghao Li, Piotr Dollár, Ross Girshick 发布。 +1. **[ViTMSN](https://huggingface.co/docs/transformers/main/model_doc/vit_msn)** (来自 Meta AI) 伴随论文 [Masked Siamese Networks for Label-Efficient Learning](https://arxiv.org/abs/2204.07141) by Mahmoud Assran, Mathilde Caron, Ishan Misra, Piotr Bojanowski, Florian Bordes, Pascal Vincent, Armand Joulin, Michael Rabbat, Nicolas Ballas 发布. 1. **[Wav2Vec2](https://huggingface.co/docs/transformers/model_doc/wav2vec2)** (来自 Facebook AI) 伴随论文 [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations](https://arxiv.org/abs/2006.11477) 由 Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael Auli 发布。 1. **[Wav2Vec2-Conformer](https://huggingface.co/docs/transformers/model_doc/wav2vec2-conformer)** (来自 Facebook AI) 伴随论文 [FAIRSEQ S2T: Fast Speech-to-Text Modeling with FAIRSEQ](https://arxiv.org/abs/2010.05171) 由 Changhan Wang, Yun Tang, Xutai Ma, Anne Wu, Sravya Popuri, Dmytro Okhonko, Juan Pino 发布。 1. **[Wav2Vec2Phoneme](https://huggingface.co/docs/transformers/model_doc/wav2vec2_phoneme)** (来自 Facebook AI) 伴随论文 [Simple and Effective Zero-shot Cross-lingual Phoneme Recognition](https://arxiv.org/abs/2109.11680) 由 Qiantong Xu, Alexei Baevski, Michael Auli 发布。 diff --git a/README_zh-hant.md b/README_zh-hant.md index b84d3ea8ca974b..8f74b97e98549f 100644 --- a/README_zh-hant.md +++ b/README_zh-hant.md @@ -371,6 +371,7 @@ conda install -c huggingface transformers 1. **[Vision Transformer (ViT)](https://huggingface.co/docs/transformers/model_doc/vit)** (from Google AI) released with the paper [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929) by Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, Jakob Uszkoreit, Neil Houlsby. 1. **[VisualBERT](https://huggingface.co/docs/transformers/model_doc/visual_bert)** (from UCLA NLP) released with the paper [VisualBERT: A Simple and Performant Baseline for Vision and Language](https://arxiv.org/pdf/1908.03557) by Liunian Harold Li, Mark Yatskar, Da Yin, Cho-Jui Hsieh, Kai-Wei Chang. 1. **[ViTMAE](https://huggingface.co/docs/transformers/model_doc/vit_mae)** (from Meta AI) released with the paper [Masked Autoencoders Are Scalable Vision Learners](https://arxiv.org/abs/2111.06377) by Kaiming He, Xinlei Chen, Saining Xie, Yanghao Li, Piotr Dollár, Ross Girshick. +1. **[ViTMSN](https://huggingface.co/docs/transformers/main/model_doc/vit_msn)** (from Meta AI) released with the paper [Masked Siamese Networks for Label-Efficient Learning](https://arxiv.org/abs/2204.07141) by Mahmoud Assran, Mathilde Caron, Ishan Misra, Piotr Bojanowski, Florian Bordes, Pascal Vincent, Armand Joulin, Michael Rabbat, Nicolas Ballas. 1. **[Wav2Vec2](https://huggingface.co/docs/transformers/model_doc/wav2vec2)** (from Facebook AI) released with the paper [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations](https://arxiv.org/abs/2006.11477) by Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael Auli. 1. **[Wav2Vec2-Conformer](https://huggingface.co/docs/transformers/model_doc/wav2vec2-conformer)** (from Facebook AI) released with the paper [FAIRSEQ S2T: Fast Speech-to-Text Modeling with FAIRSEQ](https://arxiv.org/abs/2010.05171) by Changhan Wang, Yun Tang, Xutai Ma, Anne Wu, Sravya Popuri, Dmytro Okhonko, Juan Pino. 1. **[Wav2Vec2Phoneme](https://huggingface.co/docs/transformers/model_doc/wav2vec2_phoneme)** (from Facebook AI) released with the paper [Simple and Effective Zero-shot Cross-lingual Phoneme Recognition](https://arxiv.org/abs/2109.11680) by Qiantong Xu, Alexei Baevski, Michael Auli. diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index a4cd1005e3e83f..223c5d2a6998fc 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -408,6 +408,8 @@ title: Vision Transformer (ViT) - local: model_doc/vit_mae title: ViTMAE + - local: model_doc/vit_msn + title: ViTMSN - local: model_doc/yolos title: YOLOS title: Vision models diff --git a/docs/source/en/index.mdx b/docs/source/en/index.mdx index 40685a5c2fa8da..e6a3d912b27437 100644 --- a/docs/source/en/index.mdx +++ b/docs/source/en/index.mdx @@ -175,6 +175,7 @@ The documentation is organized into five sections: 1. **[Vision Transformer (ViT)](model_doc/vit)** (from Google AI) released with the paper [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929) by Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, Jakob Uszkoreit, Neil Houlsby. 1. **[VisualBERT](model_doc/visual_bert)** (from UCLA NLP) released with the paper [VisualBERT: A Simple and Performant Baseline for Vision and Language](https://arxiv.org/pdf/1908.03557) by Liunian Harold Li, Mark Yatskar, Da Yin, Cho-Jui Hsieh, Kai-Wei Chang. 1. **[ViTMAE](model_doc/vit_mae)** (from Meta AI) released with the paper [Masked Autoencoders Are Scalable Vision Learners](https://arxiv.org/abs/2111.06377) by Kaiming He, Xinlei Chen, Saining Xie, Yanghao Li, Piotr Dollár, Ross Girshick. +1. **[ViTMSN](model_doc/vit_msn)** (from Meta AI) released with the paper [Masked Siamese Networks for Label-Efficient Learning](https://arxiv.org/abs/2204.07141) by Mahmoud Assran, Mathilde Caron, Ishan Misra, Piotr Bojanowski, Florian Bordes, Pascal Vincent, Armand Joulin, Michael Rabbat, Nicolas Ballas. 1. **[Wav2Vec2](model_doc/wav2vec2)** (from Facebook AI) released with the paper [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations](https://arxiv.org/abs/2006.11477) by Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael Auli. 1. **[Wav2Vec2-Conformer](model_doc/wav2vec2-conformer)** (from Facebook AI) released with the paper [FAIRSEQ S2T: Fast Speech-to-Text Modeling with FAIRSEQ](https://arxiv.org/abs/2010.05171) by Changhan Wang, Yun Tang, Xutai Ma, Anne Wu, Sravya Popuri, Dmytro Okhonko, Juan Pino. 1. **[Wav2Vec2Phoneme](model_doc/wav2vec2_phoneme)** (from Facebook AI) released with the paper [Simple and Effective Zero-shot Cross-lingual Phoneme Recognition](https://arxiv.org/abs/2109.11680) by Qiantong Xu, Alexei Baevski, Michael Auli. @@ -318,6 +319,7 @@ Flax), PyTorch, and/or TensorFlow. | VisualBERT | ❌ | ❌ | ✅ | ❌ | ❌ | | ViT | ❌ | ❌ | ✅ | ✅ | ✅ | | ViTMAE | ❌ | ❌ | ✅ | ✅ | ❌ | +| ViTMSN | ❌ | ❌ | ✅ | ❌ | ❌ | | Wav2Vec2 | ✅ | ❌ | ✅ | ✅ | ✅ | | Wav2Vec2-Conformer | ❌ | ❌ | ✅ | ❌ | ❌ | | WavLM | ❌ | ❌ | ✅ | ❌ | ❌ | diff --git a/docs/source/en/model_doc/vit_msn.mdx b/docs/source/en/model_doc/vit_msn.mdx new file mode 100644 index 00000000000000..07faed51e6cb57 --- /dev/null +++ b/docs/source/en/model_doc/vit_msn.mdx @@ -0,0 +1,64 @@ + + +# ViTMSN + +## Overview + +The ViTMSN model was proposed in [Masked Siamese Networks for Label-Efficient Learning](https://arxiv.org/abs/2204.07141) by Mahmoud Assran, Mathilde Caron, Ishan Misra, Piotr Bojanowski, Florian Bordes, +Pascal Vincent, Armand Joulin, Michael Rabbat, Nicolas Ballas. The paper presents a joint-embedding architecture to match the prototypes +of masked patches with that of the unmasked patches. With this setup, their method yields excellent performance in the low-shot and extreme low-shot +regimes. + +The abstract from the paper is the following: + +*We propose Masked Siamese Networks (MSN), a self-supervised learning framework for learning image representations. Our +approach matches the representation of an image view containing randomly masked patches to the representation of the original +unmasked image. This self-supervised pre-training strategy is particularly scalable when applied to Vision Transformers since only the +unmasked patches are processed by the network. As a result, MSNs improve the scalability of joint-embedding architectures, +while producing representations of a high semantic level that perform competitively on low-shot image classification. For instance, +on ImageNet-1K, with only 5,000 annotated images, our base MSN model achieves 72.4% top-1 accuracy, +and with 1% of ImageNet-1K labels, we achieve 75.7% top-1 accuracy, setting a new state-of-the-art for self-supervised learning on this benchmark.* + +Tips: + +- MSN (masked siamese networks) is a method for self-supervised pre-training of Vision Transformers (ViTs). The pre-training +objective is to match the prototypes assigned to the unmasked views of the images to that of the masked views of the same images. +- The authors have only released pre-trained weights of the backbone (ImageNet-1k pre-training). So, to use that on your own image classification dataset, +use the [`ViTMSNForImageClassification`] class which is initialized from [`ViTMSNModel`]. Follow +[this notebook](https://github.com/huggingface/notebooks/blob/main/examples/image_classification.ipynb) for a detailed tutorial on fine-tuning. +- MSN is particularly useful in the low-shot and extreme low-shot regimes. Notably, it achieves 75.7% top-1 accuracy with only 1% of ImageNet-1K +labels when fine-tuned. + + +drawing + + MSN architecture. Taken from the original paper. + +This model was contributed by [sayakpaul](https://huggingface.co/sayakpaul). The original code can be found [here](https://github.com/facebookresearch/msn). + + +## ViTMSNConfig + +[[autodoc]] ViTMSNConfig + + +## ViTMSNModel + +[[autodoc]] ViTMSNModel + - forward + + +## ViTMSNForImageClassification + +[[autodoc]] ViTMSNForImageClassification + - forward diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 123c56bdd53267..50fb4d2c0b8a7e 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -359,6 +359,7 @@ "models.visual_bert": ["VISUAL_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "VisualBertConfig"], "models.vit": ["VIT_PRETRAINED_CONFIG_ARCHIVE_MAP", "ViTConfig"], "models.vit_mae": ["VIT_MAE_PRETRAINED_CONFIG_ARCHIVE_MAP", "ViTMAEConfig"], + "models.vit_msn": ["VIT_MSN_PRETRAINED_CONFIG_ARCHIVE_MAP", "ViTMSNConfig"], "models.wav2vec2": [ "WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP", "Wav2Vec2Config", @@ -1966,6 +1967,14 @@ "ViTMAEPreTrainedModel", ] ) + _import_structure["models.vit_msn"].extend( + [ + "VIT_MSN_PRETRAINED_MODEL_ARCHIVE_LIST", + "ViTMSNModel", + "ViTMSNForImageClassification", + "ViTMSNPreTrainedModel", + ] + ) _import_structure["models.videomae"].extend( [ "VIDEOMAE_PRETRAINED_MODEL_ARCHIVE_LIST", @@ -3251,6 +3260,7 @@ from .models.visual_bert import VISUAL_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, VisualBertConfig from .models.vit import VIT_PRETRAINED_CONFIG_ARCHIVE_MAP, ViTConfig from .models.vit_mae import VIT_MAE_PRETRAINED_CONFIG_ARCHIVE_MAP, ViTMAEConfig + from .models.vit_msn import VIT_MSN_PRETRAINED_CONFIG_ARCHIVE_MAP, ViTMSNConfig from .models.wav2vec2 import ( WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP, Wav2Vec2Config, @@ -4586,6 +4596,12 @@ ViTMAEModel, ViTMAEPreTrainedModel, ) + from .models.vit_msn import ( + VIT_MSN_PRETRAINED_MODEL_ARCHIVE_LIST, + ViTMSNForImageClassification, + ViTMSNModel, + ViTMSNPreTrainedModel, + ) from .models.wav2vec2 import ( WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST, Wav2Vec2ForAudioFrameClassification, diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index 8a6622a9f35b6f..18c21cdf1865b3 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -150,6 +150,7 @@ visual_bert, vit, vit_mae, + vit_msn, wav2vec2, wav2vec2_conformer, wav2vec2_phoneme, diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index dce73cb3903656..39c48b217ff53f 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -145,6 +145,7 @@ ("visual_bert", "VisualBertConfig"), ("vit", "ViTConfig"), ("vit_mae", "ViTMAEConfig"), + ("vit_msn", "ViTMSNConfig"), ("wav2vec2", "Wav2Vec2Config"), ("wav2vec2-conformer", "Wav2Vec2ConformerConfig"), ("wavlm", "WavLMConfig"), @@ -266,6 +267,7 @@ ("visual_bert", "VISUAL_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("vit", "VIT_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("vit_mae", "VIT_MAE_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("vit_msn", "VIT_MSN_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("wav2vec2", "WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("wav2vec2-conformer", "WAV2VEC2_CONFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("xclip", "X_CLIP_PRETRAINED_CONFIG_ARCHIVE_MAP"), @@ -418,6 +420,7 @@ ("visual_bert", "VisualBERT"), ("vit", "ViT"), ("vit_mae", "ViTMAE"), + ("vit_msn", "ViTMSN"), ("wav2vec2", "Wav2Vec2"), ("wav2vec2-conformer", "Wav2Vec2-Conformer"), ("wav2vec2_phoneme", "Wav2Vec2Phoneme"), diff --git a/src/transformers/models/auto/feature_extraction_auto.py b/src/transformers/models/auto/feature_extraction_auto.py index 3f2265875f2a69..73fe1ad42ad195 100644 --- a/src/transformers/models/auto/feature_extraction_auto.py +++ b/src/transformers/models/auto/feature_extraction_auto.py @@ -74,6 +74,7 @@ ("vilt", "ViltFeatureExtractor"), ("vit", "ViTFeatureExtractor"), ("vit_mae", "ViTFeatureExtractor"), + ("vit_msn", "ViTFeatureExtractor"), ("wav2vec2", "Wav2Vec2FeatureExtractor"), ("wav2vec2-conformer", "Wav2Vec2FeatureExtractor"), ("xclip", "CLIPFeatureExtractor"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index d7c5f1772f13bd..936e9c8bdc479c 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -139,6 +139,7 @@ ("visual_bert", "VisualBertModel"), ("vit", "ViTModel"), ("vit_mae", "ViTMAEModel"), + ("vit_msn", "ViTMSNModel"), ("wav2vec2", "Wav2Vec2Model"), ("wav2vec2-conformer", "Wav2Vec2ConformerModel"), ("wavlm", "WavLMModel"), @@ -367,6 +368,7 @@ ("swinv2", "Swinv2ForImageClassification"), ("van", "VanForImageClassification"), ("vit", "ViTForImageClassification"), + ("vit_msn", "ViTMSNForImageClassification"), ] ) diff --git a/src/transformers/models/vit_msn/__init__.py b/src/transformers/models/vit_msn/__init__.py new file mode 100644 index 00000000000000..832e730c5881c6 --- /dev/null +++ b/src/transformers/models/vit_msn/__init__.py @@ -0,0 +1,57 @@ +# flake8: noqa +# There's no way to ignore "F401 '...' imported but unused" warnings in this +# module, but to preserve other warnings. So, don't check this module at all. + +# Copyright 2020 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available + + +_import_structure = {"configuration_vit_msn": ["VIT_MSN_PRETRAINED_CONFIG_ARCHIVE_MAP", "ViTMSNConfig"]} + +try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["modeling_vit_msn"] = [ + "VIT_MSN_PRETRAINED_MODEL_ARCHIVE_LIST", + "ViTMSNModel", + "ViTMSNForImageClassification", + "ViTMSNPreTrainedModel", + ] + +if TYPE_CHECKING: + from .configuration_vit_msn import VIT_MSN_PRETRAINED_CONFIG_ARCHIVE_MAP, ViTMSNConfig + + try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .modeling_vit_msn import ( + VIT_MSN_PRETRAINED_MODEL_ARCHIVE_LIST, + ViTMSNForImageClassification, + ViTMSNModel, + ViTMSNPreTrainedModel, + ) + +else: + import sys + + sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) diff --git a/src/transformers/models/vit_msn/configuration_vit_msn.py b/src/transformers/models/vit_msn/configuration_vit_msn.py new file mode 100644 index 00000000000000..057824e5d4e133 --- /dev/null +++ b/src/transformers/models/vit_msn/configuration_vit_msn.py @@ -0,0 +1,117 @@ +# coding=utf-8 +# Copyright 2022 Facebook AI and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" ViT MSN model configuration""" + + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + +VIT_MSN_PRETRAINED_CONFIG_ARCHIVE_MAP = { + "sayakpaul/vit-msn-base": "https://huggingface.co/sayakpaul/vit-msn-base/resolve/main/config.json", + # See all ViT MSN models at https://huggingface.co/models?filter=vit_msn +} + + +class ViTMSNConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`ViTMSNModel`]. It is used to instantiate an ViT + MSN model according to the specified arguments, defining the model architecture. Instantiating a configuration with + the defaults will yield a similar configuration to that of the ViT + [facebook/vit_msn_base](https://huggingface.co/facebook/vit_msn_base) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + intermediate_size (`int`, *optional*, defaults to 3072): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` are supported. + hidden_dropout_prob (`float`, *optional*, defaults to 0.0): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layer_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the layer normalization layers. + image_size (`int`, *optional*, defaults to 224): + The size (resolution) of each image. + patch_size (`int`, *optional*, defaults to 16): + The size (resolution) of each patch. + num_channels (`int`, *optional*, defaults to 3): + The number of input channels. + qkv_bias (`bool`, *optional*, defaults to `True`): + Whether to add a bias to the queries, keys and values. + + Example: + + ```python + >>> from transformers import ViTMSNModel, ViTMSNConfig + + >>> # Initializing a ViT MSN vit-msn-base style configuration + >>> configuration = ViTConfig() + + >>> # Initializing a model from the vit-msn-base style configuration + >>> model = ViTMSNModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + model_type = "vit_msn" + + def __init__( + self, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + hidden_act="gelu", + hidden_dropout_prob=0.0, + attention_probs_dropout_prob=0.0, + initializer_range=0.02, + layer_norm_eps=1e-06, + image_size=224, + patch_size=16, + num_channels=3, + qkv_bias=True, + **kwargs + ): + super().__init__(**kwargs) + + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.qkv_bias = qkv_bias diff --git a/src/transformers/models/vit_msn/convert_msn_to_pytorch.py b/src/transformers/models/vit_msn/convert_msn_to_pytorch.py new file mode 100644 index 00000000000000..535f5f742d631f --- /dev/null +++ b/src/transformers/models/vit_msn/convert_msn_to_pytorch.py @@ -0,0 +1,236 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert ViT MSN checkpoints from the original repository: https://github.com/facebookresearch/msn""" + +import argparse + +import torch +from PIL import Image + +import requests +from transformers import ViTFeatureExtractor, ViTMSNConfig, ViTMSNModel +from transformers.image_utils import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD + + +torch.set_grad_enabled(False) + + +# here we list all keys to be renamed (original name on the left, our name on the right) +def create_rename_keys(config, base_model=False): + rename_keys = [] + for i in range(config.num_hidden_layers): + # encoder layers: output projection, 2 feedforward neural networks and 2 layernorms + rename_keys.append((f"module.blocks.{i}.norm1.weight", f"vit.encoder.layer.{i}.layernorm_before.weight")) + rename_keys.append((f"module.blocks.{i}.norm1.bias", f"vit.encoder.layer.{i}.layernorm_before.bias")) + rename_keys.append( + (f"module.blocks.{i}.attn.proj.weight", f"vit.encoder.layer.{i}.attention.output.dense.weight") + ) + rename_keys.append((f"module.blocks.{i}.attn.proj.bias", f"vit.encoder.layer.{i}.attention.output.dense.bias")) + rename_keys.append((f"module.blocks.{i}.norm2.weight", f"vit.encoder.layer.{i}.layernorm_after.weight")) + rename_keys.append((f"module.blocks.{i}.norm2.bias", f"vit.encoder.layer.{i}.layernorm_after.bias")) + rename_keys.append((f"module.blocks.{i}.mlp.fc1.weight", f"vit.encoder.layer.{i}.intermediate.dense.weight")) + rename_keys.append((f"module.blocks.{i}.mlp.fc1.bias", f"vit.encoder.layer.{i}.intermediate.dense.bias")) + rename_keys.append((f"module.blocks.{i}.mlp.fc2.weight", f"vit.encoder.layer.{i}.output.dense.weight")) + rename_keys.append((f"module.blocks.{i}.mlp.fc2.bias", f"vit.encoder.layer.{i}.output.dense.bias")) + + # projection layer + position embeddings + rename_keys.extend( + [ + ("module.cls_token", "vit.embeddings.cls_token"), + ("module.patch_embed.proj.weight", "vit.embeddings.patch_embeddings.projection.weight"), + ("module.patch_embed.proj.bias", "vit.embeddings.patch_embeddings.projection.bias"), + ("module.pos_embed", "vit.embeddings.position_embeddings"), + ] + ) + + if base_model: + # layernorm + pooler + rename_keys.extend( + [ + ("module.norm.weight", "layernorm.weight"), + ("module.norm.bias", "layernorm.bias"), + ] + ) + + # if just the base model, we should remove "vit" from all keys that start with "vit" + rename_keys = [(pair[0], pair[1][4:]) if pair[1].startswith("vit") else pair for pair in rename_keys] + else: + # layernorm + classification head + rename_keys.extend( + [ + ("norm.weight", "vit.layernorm.weight"), + ("norm.bias", "vit.layernorm.bias"), + ("head.weight", "classifier.weight"), + ("head.bias", "classifier.bias"), + ] + ) + + return rename_keys + + +# we split up the matrix of each encoder layer into queries, keys and values +def read_in_q_k_v(state_dict, config, base_model=False): + for i in range(config.num_hidden_layers): + if base_model: + prefix = "" + else: + prefix = "vit." + # read in weights + bias of input projection layer (in timm, this is a single matrix + bias) + in_proj_weight = state_dict.pop(f"module.blocks.{i}.attn.qkv.weight") + in_proj_bias = state_dict.pop(f"module.blocks.{i}.attn.qkv.bias") + # next, add query, keys and values (in that order) to the state dict + state_dict[f"{prefix}encoder.layer.{i}.attention.attention.query.weight"] = in_proj_weight[ + : config.hidden_size, : + ] + state_dict[f"{prefix}encoder.layer.{i}.attention.attention.query.bias"] = in_proj_bias[: config.hidden_size] + state_dict[f"{prefix}encoder.layer.{i}.attention.attention.key.weight"] = in_proj_weight[ + config.hidden_size : config.hidden_size * 2, : + ] + state_dict[f"{prefix}encoder.layer.{i}.attention.attention.key.bias"] = in_proj_bias[ + config.hidden_size : config.hidden_size * 2 + ] + state_dict[f"{prefix}encoder.layer.{i}.attention.attention.value.weight"] = in_proj_weight[ + -config.hidden_size :, : + ] + state_dict[f"{prefix}encoder.layer.{i}.attention.attention.value.bias"] = in_proj_bias[-config.hidden_size :] + + +def remove_classification_head_(state_dict): + ignore_keys = ["head.weight", "head.bias"] + for k in ignore_keys: + state_dict.pop(k, None) + + +def remove_projection_head(state_dict): + # projection head is used in the self-supervised pre-training in MSN, + # for downstream task it's not needed. + ignore_keys = [ + "module.fc.fc1.weight", + "module.fc.fc1.bias", + "module.fc.bn1.weight", + "module.fc.bn1.bias", + "module.fc.bn1.running_mean", + "module.fc.bn1.running_var", + "module.fc.bn1.num_batches_tracked", + "module.fc.fc2.weight", + "module.fc.fc2.bias", + "module.fc.bn2.weight", + "module.fc.bn2.bias", + "module.fc.bn2.running_mean", + "module.fc.bn2.running_var", + "module.fc.bn2.num_batches_tracked", + "module.fc.fc3.weight", + "module.fc.fc3.bias", + ] + for k in ignore_keys: + state_dict.pop(k, None) + + +def rename_key(dct, old, new): + val = dct.pop(old) + dct[new] = val + + +def convert_vit_msn_checkpoint(checkpoint_url, pytorch_dump_folder_path): + config = ViTMSNConfig() + config.num_labels = 1000 + + if "s16" in checkpoint_url: + config.hidden_size = 384 + config.intermediate_size = 1536 + config.num_attention_heads = 6 + elif "l16" in checkpoint_url: + config.hidden_size = 1024 + config.intermediate_size = 4096 + config.num_hidden_layers = 24 + config.num_attention_heads = 16 + config.hidden_dropout_prob = 0.1 + elif "b4" in checkpoint_url: + config.patch_size = 4 + elif "l7" in checkpoint_url: + config.patch_size = 7 + config.hidden_size = 1024 + config.intermediate_size = 4096 + config.num_hidden_layers = 24 + config.num_attention_heads = 16 + config.hidden_dropout_prob = 0.1 + + model = ViTMSNModel(config) + + state_dict = torch.hub.load_state_dict_from_url(checkpoint_url, map_location="cpu")["target_encoder"] + + feature_extractor = ViTFeatureExtractor(size=config.image_size) + + remove_projection_head(state_dict) + rename_keys = create_rename_keys(config, base_model=True) + + for src, dest in rename_keys: + rename_key(state_dict, src, dest) + read_in_q_k_v(state_dict, config, base_model=True) + + model.load_state_dict(state_dict) + model.eval() + + url = "http://images.cocodataset.org/val2017/000000039769.jpg" + + image = Image.open(requests.get(url, stream=True).raw) + feature_extractor = ViTFeatureExtractor( + size=config.image_size, image_mean=IMAGENET_DEFAULT_MEAN, image_std=IMAGENET_DEFAULT_STD + ) + inputs = feature_extractor(images=image, return_tensors="pt") + + # forward pass + torch.manual_seed(2) + outputs = model(**inputs) + last_hidden_state = outputs.last_hidden_state + + # The following Colab Notebook was used to generate these outputs: + # https://colab.research.google.com/gist/sayakpaul/3672419a04f5997827503fd84079bdd1/scratchpad.ipynb + if "s16" in checkpoint_url: + expected_slice = torch.tensor([[-1.0915, -1.4876, -1.1809]]) + elif "b16" in checkpoint_url: + expected_slice = torch.tensor([[14.2889, -18.9045, 11.7281]]) + elif "l16" in checkpoint_url: + expected_slice = torch.tensor([[41.5028, -22.8681, 45.6475]]) + elif "b4" in checkpoint_url: + expected_slice = torch.tensor([[-4.3868, 5.2932, -0.4137]]) + else: + expected_slice = torch.tensor([[-0.1792, -0.6465, 2.4263]]) + + # verify logits + assert torch.allclose(last_hidden_state[:, 0, :3], expected_slice, atol=1e-4) + + print(f"Saving model to {pytorch_dump_folder_path}") + model.save_pretrained(pytorch_dump_folder_path) + + print(f"Saving feature extractor to {pytorch_dump_folder_path}") + feature_extractor.save_pretrained(pytorch_dump_folder_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument( + "--checkpoint_url", + default="https://dl.fbaipublicfiles.com/msn/vits16_800ep.pth.tar", + type=str, + help="URL of the checkpoint you'd like to convert.", + ) + parser.add_argument( + "--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory." + ) + + args = parser.parse_args() + convert_vit_msn_checkpoint(args.checkpoint_url, args.pytorch_dump_folder_path) diff --git a/src/transformers/models/vit_msn/modeling_vit_msn.py b/src/transformers/models/vit_msn/modeling_vit_msn.py new file mode 100644 index 00000000000000..314b3dbd5bfdbd --- /dev/null +++ b/src/transformers/models/vit_msn/modeling_vit_msn.py @@ -0,0 +1,695 @@ +# coding=utf-8 +# Copyright 2022 Facebook AI and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch ViT MSN (masked siamese network) model.""" + + +import collections.abc +import math +from typing import Dict, List, Optional, Set, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from ...activations import ACT2FN +from ...modeling_outputs import BaseModelOutput, ImageClassifierOutput +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer +from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings +from .configuration_vit_msn import ViTMSNConfig + + +logger = logging.get_logger(__name__) + + +_CONFIG_FOR_DOC = "ViTMSNConfig" +_CHECKPOINT_FOR_DOC = "sayakpaul/vit-msn-small" +VIT_MSN_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "sayakpaul/vit-msn-small", + # See all ViTMSN models at https://huggingface.co/models?filter=vit_msn +] + + +class ViTMSNEmbeddings(nn.Module): + """ + Construct the CLS token, position and patch embeddings. Optionally, also the mask token. + """ + + def __init__(self, config: ViTMSNConfig, use_mask_token: bool = False) -> None: + super().__init__() + + self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) + self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) if use_mask_token else None + self.patch_embeddings = ViTMSNPatchEmbeddings(config) + num_patches = self.patch_embeddings.num_patches + self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size)) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.config = config + + def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: + """ + This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher + resolution images. + + Source: + https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174 + """ + + num_patches = embeddings.shape[1] - 1 + num_positions = self.position_embeddings.shape[1] - 1 + if num_patches == num_positions and height == width: + return self.position_embeddings + class_pos_embed = self.position_embeddings[:, 0] + patch_pos_embed = self.position_embeddings[:, 1:] + dim = embeddings.shape[-1] + patch_window_height = height // self.config.patch_size + patch_window_width = width // self.config.patch_size + # we add a small number to avoid floating point error in the interpolation + # see discussion at https://github.com/facebookresearch/dino/issues/8 + patch_window_height, patch_window_width = patch_window_height + 0.1, patch_window_width + 0.1 + patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim) + patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed, + scale_factor=( + patch_window_height / math.sqrt(num_positions), + patch_window_width / math.sqrt(num_positions), + ), + mode="bicubic", + align_corners=False, + ) + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) + + def forward( + self, + pixel_values: torch.Tensor, + bool_masked_pos: Optional[torch.BoolTensor] = None, + interpolate_pos_encoding: bool = False, + ) -> torch.Tensor: + batch_size, num_channels, height, width = pixel_values.shape + embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding) + + if bool_masked_pos is not None: + seq_length = embeddings.shape[1] + mask_tokens = self.mask_token.expand(batch_size, seq_length, -1) + # replace the masked visual tokens by mask_tokens + mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens) + embeddings = embeddings * (1.0 - mask) + mask_tokens * mask + + # add the [CLS] token to the embedded patch tokens + cls_tokens = self.cls_token.expand(batch_size, -1, -1) + embeddings = torch.cat((cls_tokens, embeddings), dim=1) + + # add positional encoding to each token + if interpolate_pos_encoding: + embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width) + else: + embeddings = embeddings + self.position_embeddings + + embeddings = self.dropout(embeddings) + + return embeddings + + +# Copied from transformers.models.vit.modeling_vit.ViTPatchEmbeddings with ViT->ViTMSN +class ViTMSNPatchEmbeddings(nn.Module): + """ + This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial + `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a + Transformer. + """ + + def __init__(self, config): + super().__init__() + image_size, patch_size = config.image_size, config.patch_size + num_channels, hidden_size = config.num_channels, config.hidden_size + + image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size) + patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size) + num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.num_patches = num_patches + + self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size) + + def forward(self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False) -> torch.Tensor: + batch_size, num_channels, height, width = pixel_values.shape + if num_channels != self.num_channels: + raise ValueError( + "Make sure that the channel dimension of the pixel values match with the one set in the configuration." + ) + if not interpolate_pos_encoding: + if height != self.image_size[0] or width != self.image_size[1]: + raise ValueError( + f"Input image size ({height}*{width}) doesn't match model" + f" ({self.image_size[0]}*{self.image_size[1]})." + ) + embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2) + return embeddings + + +# Copied from transformers.models.vit.modeling_vit.ViTSelfAttention with ViT->ViTMSN +class ViTMSNSelfAttention(nn.Module): + def __init__(self, config: ViTMSNConfig) -> None: + super().__init__() + if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): + raise ValueError( + f"The hidden size {config.hidden_size,} is not a multiple of the number of attention " + f"heads {config.num_attention_heads}." + ) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) + self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) + self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + + def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False + ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: + mixed_query_layer = self.query(hidden_states) + + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + query_layer = self.transpose_for_scores(mixed_query_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + + # Normalize the attention scores to probabilities. + attention_probs = nn.functional.softmax(attention_scores, dim=-1) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = attention_probs * head_mask + + context_layer = torch.matmul(attention_probs, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(new_context_layer_shape) + + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + + return outputs + + +# Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->ViTMSN +class ViTMSNSelfOutput(nn.Module): + """ + The residual connection is defined in ViTMSNLayer instead of here (as is the case with other models), due to the + layernorm applied before each block. + """ + + def __init__(self, config: ViTMSNConfig) -> None: + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + + return hidden_states + + +# Copied from transformers.models.vit.modeling_vit.ViTAttention with ViT->ViTMSN +class ViTMSNAttention(nn.Module): + def __init__(self, config: ViTMSNConfig) -> None: + super().__init__() + self.attention = ViTMSNSelfAttention(config) + self.output = ViTMSNSelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads: Set[int]) -> None: + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads + ) + + # Prune linear layers + self.attention.query = prune_linear_layer(self.attention.query, index) + self.attention.key = prune_linear_layer(self.attention.key, index) + self.attention.value = prune_linear_layer(self.attention.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads) + self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states: torch.Tensor, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: + self_outputs = self.attention(hidden_states, head_mask, output_attentions) + + attention_output = self.output(self_outputs[0], hidden_states) + + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + +# Copied from transformers.models.vit.modeling_vit.ViTIntermediate with ViT->ViTMSN +class ViTMSNIntermediate(nn.Module): + def __init__(self, config: ViTMSNConfig) -> None: + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + + return hidden_states + + +# Copied from transformers.models.vit.modeling_vit.ViTOutput with ViT->ViTMSN +class ViTMSNOutput(nn.Module): + def __init__(self, config: ViTMSNConfig) -> None: + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + + hidden_states = hidden_states + input_tensor + + return hidden_states + + +# Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->ViTMSN +class ViTMSNLayer(nn.Module): + """This corresponds to the Block class in the timm implementation.""" + + def __init__(self, config: ViTMSNConfig) -> None: + super().__init__() + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = ViTMSNAttention(config) + self.intermediate = ViTMSNIntermediate(config) + self.output = ViTMSNOutput(config) + self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: + self_attention_outputs = self.attention( + self.layernorm_before(hidden_states), # in ViTMSN, layernorm is applied before self-attention + head_mask, + output_attentions=output_attentions, + ) + attention_output = self_attention_outputs[0] + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + + # first residual connection + hidden_states = attention_output + hidden_states + + # in ViTMSN, layernorm is also applied after self-attention + layer_output = self.layernorm_after(hidden_states) + layer_output = self.intermediate(layer_output) + + # second residual connection is done here + layer_output = self.output(layer_output, hidden_states) + + outputs = (layer_output,) + outputs + + return outputs + + +# Copied from transformers.models.vit.modeling_vit.ViTEncoder with ViT->ViTMSN +class ViTMSNEncoder(nn.Module): + def __init__(self, config: ViTMSNConfig) -> None: + super().__init__() + self.config = config + self.layer = nn.ModuleList([ViTMSNLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ) -> Union[tuple, BaseModelOutput]: + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs, output_attentions) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(layer_module), + hidden_states, + layer_head_mask, + ) + else: + layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + +class ViTMSNPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = ViTMSNConfig + base_model_prefix = "vit" + main_input_name = "pixel_values" + supports_gradient_checkpointing = True + + # todo: Resort to https://github.com/facebookresearch/msn/blob/main/src/deit.py#L200-#L211 + # when creating pre-training scripts. + def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None: + """Initialize the weights""" + if isinstance(module, (nn.Linear, nn.Conv2d)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + def _set_gradient_checkpointing(self, module: ViTMSNEncoder, value: bool = False) -> None: + if isinstance(module, ViTMSNEncoder): + module.gradient_checkpointing = value + + +VIT_MSN_START_DOCSTRING = r""" + This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it + as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and + behavior. + + Parameters: + config ([`ViTMSNConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +VIT_MSN_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Pixel values can be obtained using [`AutoFeatureExtractor`]. See + [`AutoFeatureExtractor.__call__`] for details. + + head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + interpolate_pos_encoding (`bool`, *optional*): + Whether to interpolate the pre-trained position encodings. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare ViTMSN Model outputting raw hidden-states without any specific head on top.", + VIT_MSN_START_DOCSTRING, +) +class ViTMSNModel(ViTMSNPreTrainedModel): + def __init__(self, config: ViTMSNConfig, use_mask_token: bool = False): + super().__init__(config) + self.config = config + + self.embeddings = ViTMSNEmbeddings(config, use_mask_token=use_mask_token) + self.encoder = ViTMSNEncoder(config) + + self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> ViTMSNPatchEmbeddings: + return self.embeddings.patch_embeddings + + def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None: + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + @add_start_docstrings_to_model_forward(VIT_MSN_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + bool_masked_pos: Optional[torch.BoolTensor] = None, + head_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: Optional[bool] = None, + return_dict: Optional[bool] = None, + ): + r""" + Returns: + + Examples: + + ```python + >>> from transformers import AutoFeatureExtractor, ViTMSNModel + >>> import torch + >>> from PIL import Image + >>> import requests + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/vit-msn-small") + >>> model = ViTMSNModel.from_pretrained("facebook/vit-msn-small") + >>> inputs = feature_extractor(images=image, return_tensors="pt") + >>> with torch.no_grad(): + ... outputs = model(**inputs) + >>> last_hidden_states = outputs.last_hidden_state + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + embedding_output = self.embeddings( + pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding + ) + + encoder_outputs = self.encoder( + embedding_output, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = encoder_outputs[0] + sequence_output = self.layernorm(sequence_output) + + if not return_dict: + head_outputs = (sequence_output,) + return head_outputs + encoder_outputs[1:] + + return BaseModelOutput( + last_hidden_state=sequence_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +# Caution: We don't have the weights for the classification head yet. This class +# is here for the users that are interested to fine-tune the base model (ViTMSNModel). +@add_start_docstrings( + """ + ViTMSN Model with an image classification head on top e.g. for ImageNet. + """, + VIT_MSN_START_DOCSTRING, +) +class ViTMSNForImageClassification(ViTMSNPreTrainedModel): + def __init__(self, config: ViTMSNConfig) -> None: + super().__init__(config) + + self.num_labels = config.num_labels + self.vit = ViTMSNModel(config) + + # Classifier head + self.classifier = nn.Linear(config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity() + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(VIT_MSN_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=ImageClassifierOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, ImageClassifierOutput]: + r""" + Returns: + + Examples: + + ```python + >>> from transformers import AutoFeatureExtractor, ViTMSNForImageClassification + >>> import torch + >>> from PIL import Image + >>> import requests + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/vit-msn-small") + >>> model = ViTMSNForImageClassification.from_pretrained("facebook/vit-msn-small") + + >>> inputs = feature_extractor(images=image, return_tensors="pt") + >>> with torch.no_grad(): + ... logits = model(**inputs).logits + >>> # model predicts one of the 1000 ImageNet classes + >>> predicted_label = logits.argmax(-1).item() + >>> print(model.config.id2label[predicted_label]) + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.vit( + pixel_values, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + logits = self.classifier(sequence_output[:, 0, :]) + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return ImageClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index b656cee9c89bdc..e9f1bae358f3ac 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -5150,6 +5150,30 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) +VIT_MSN_PRETRAINED_MODEL_ARCHIVE_LIST = None + + +class ViTMSNForImageClassification(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class ViTMSNModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class ViTMSNPreTrainedModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST = None diff --git a/tests/models/vit_msn/__init__.py b/tests/models/vit_msn/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/tests/models/vit_msn/test_modeling_vit_msn.py b/tests/models/vit_msn/test_modeling_vit_msn.py new file mode 100644 index 00000000000000..b858da42b3d4e3 --- /dev/null +++ b/tests/models/vit_msn/test_modeling_vit_msn.py @@ -0,0 +1,239 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Testing suite for the PyTorch ViTMSN model. """ + + +import inspect +import unittest + +from transformers import ViTMSNConfig +from transformers.testing_utils import require_torch, require_vision, slow, torch_device +from transformers.utils import cached_property, is_torch_available, is_vision_available + +from ...test_configuration_common import ConfigTester +from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor + + +if is_torch_available(): + import torch + from torch import nn + + from transformers import ViTMSNForImageClassification, ViTMSNModel + from transformers.models.vit_msn.modeling_vit_msn import VIT_MSN_PRETRAINED_MODEL_ARCHIVE_LIST + + +if is_vision_available(): + from PIL import Image + + from transformers import ViTFeatureExtractor + + +class ViTMAEModelTester: + def __init__( + self, + parent, + batch_size=13, + image_size=30, + patch_size=2, + num_channels=3, + is_training=True, + use_labels=True, + hidden_size=32, + num_hidden_layers=5, + num_attention_heads=4, + intermediate_size=37, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + type_sequence_label_size=10, + initializer_range=0.02, + scope=None, + ): + self.parent = parent + self.batch_size = batch_size + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.is_training = is_training + self.use_labels = use_labels + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.type_sequence_label_size = type_sequence_label_size + self.initializer_range = initializer_range + self.scope = scope + + # in ViT MSN, the seq length equals the number of patches + 1 (we add 1 for the [CLS] token) + num_patches = (image_size // patch_size) ** 2 + self.seq_length = num_patches + 1 + + def prepare_config_and_inputs(self): + pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size]) + + labels = None + if self.use_labels: + labels = ids_tensor([self.batch_size], self.type_sequence_label_size) + + config = self.get_config() + + return config, pixel_values, labels + + def get_config(self): + return ViTMSNConfig( + image_size=self.image_size, + patch_size=self.patch_size, + num_channels=self.num_channels, + hidden_size=self.hidden_size, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads, + intermediate_size=self.intermediate_size, + hidden_act=self.hidden_act, + hidden_dropout_prob=self.hidden_dropout_prob, + attention_probs_dropout_prob=self.attention_probs_dropout_prob, + initializer_range=self.initializer_range, + ) + + def create_and_check_model(self, config, pixel_values, labels): + model = ViTMSNModel(config=config) + model.to(torch_device) + model.eval() + result = model(pixel_values) + self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)) + + def create_and_check_for_image_classification(self, config, pixel_values, labels): + config.num_labels = self.type_sequence_label_size + model = ViTMSNForImageClassification(config) + model.to(torch_device) + model.eval() + result = model(pixel_values, labels=labels) + print("Pixel and labels shape: {pixel_values.shape}, {labels.shape}") + print("Labels: {labels}") + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size)) + + # test greyscale images + config.num_channels = 1 + model = ViTMSNForImageClassification(config) + model.to(torch_device) + model.eval() + + pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size]) + result = model(pixel_values) + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size)) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + config, pixel_values, labels = config_and_inputs + inputs_dict = {"pixel_values": pixel_values} + return config, inputs_dict + + +@require_torch +class ViTMSNModelTest(ModelTesterMixin, unittest.TestCase): + """ + Here we also overwrite some of the tests of test_modeling_common.py, as ViTMAE does not use input_ids, inputs_embeds, + attention_mask and seq_length. + """ + + all_model_classes = (ViTMSNModel, ViTMSNForImageClassification) if is_torch_available() else () + + test_pruning = False + test_torchscript = False + test_resize_embeddings = False + test_head_masking = False + + def setUp(self): + self.model_tester = ViTMAEModelTester(self) + self.config_tester = ConfigTester(self, config_class=ViTMSNConfig, has_text_modality=False, hidden_size=37) + + def test_config(self): + self.config_tester.run_common_tests() + + @unittest.skip(reason="ViTMAE does not use inputs_embeds") + def test_inputs_embeds(self): + pass + + def test_model_common_attributes(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + self.assertIsInstance(model.get_input_embeddings(), (nn.Module)) + x = model.get_output_embeddings() + self.assertTrue(x is None or isinstance(x, nn.Linear)) + + def test_forward_signature(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + signature = inspect.signature(model.forward) + # signature.parameters is an OrderedDict => so arg_names order is deterministic + arg_names = [*signature.parameters.keys()] + + expected_arg_names = ["pixel_values"] + self.assertListEqual(arg_names[:1], expected_arg_names) + + def test_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model(*config_and_inputs) + + def test_for_image_classification(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_for_image_classification(*config_and_inputs) + + @slow + def test_model_from_pretrained(self): + for model_name in VIT_MSN_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: + model = ViTMSNModel.from_pretrained(model_name) + self.assertIsNotNone(model) + + +# We will verify our results on an image of cute cats +def prepare_img(): + image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png") + return image + + +@require_torch +@require_vision +class ViTMSNModelIntegrationTest(unittest.TestCase): + @cached_property + def default_feature_extractor(self): + return ViTFeatureExtractor.from_pretrained("sayakpaul/vit-msn-small") if is_vision_available() else None + + @slow + def test_inference_image_classification_head(self): + torch.manual_seed(2) + model = ViTMSNForImageClassification.from_pretrained("sayakpaul/vit-msn-small").to(torch_device) + + feature_extractor = self.default_feature_extractor + image = prepare_img() + inputs = feature_extractor(images=image, return_tensors="pt").to(torch_device) + + # forward pass + with torch.no_grad(): + outputs = model(**inputs) + + # verify the logits + expected_shape = torch.Size((1, 1000)) + self.assertEqual(outputs.logits.shape, expected_shape) + + expected_slice = torch.tensor([-0.0803, -0.4454, -0.2375]).to(torch_device) + + self.assertTrue(torch.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4)) diff --git a/utils/documentation_tests.txt b/utils/documentation_tests.txt index eee74d8940f595..3a7dfa55041d73 100644 --- a/utils/documentation_tests.txt +++ b/utils/documentation_tests.txt @@ -85,6 +85,7 @@ src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.p src/transformers/models/vit/modeling_vit.py src/transformers/models/vit/modeling_tf_vit.py src/transformers/models/vit_mae/modeling_vit_mae.py +src/transformers/models/vit_msn/modeling_vit_msn.py src/transformers/models/wav2vec2/modeling_wav2vec2.py src/transformers/models/wav2vec2/tokenization_wav2vec2.py src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py