diff --git a/README.md b/README.md
index ec8a0fd2e8b392..5e17e33b204cc1 100644
--- a/README.md
+++ b/README.md
@@ -385,6 +385,7 @@ Current number of checkpoints: ![](https://img.shields.io/endpoint?url=https://h
1. **[Vision Transformer (ViT)](https://huggingface.co/docs/transformers/model_doc/vit)** (from Google AI) released with the paper [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929) by Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, Jakob Uszkoreit, Neil Houlsby.
1. **[VisualBERT](https://huggingface.co/docs/transformers/model_doc/visual_bert)** (from UCLA NLP) released with the paper [VisualBERT: A Simple and Performant Baseline for Vision and Language](https://arxiv.org/pdf/1908.03557) by Liunian Harold Li, Mark Yatskar, Da Yin, Cho-Jui Hsieh, Kai-Wei Chang.
1. **[ViTMAE](https://huggingface.co/docs/transformers/model_doc/vit_mae)** (from Meta AI) released with the paper [Masked Autoencoders Are Scalable Vision Learners](https://arxiv.org/abs/2111.06377) by Kaiming He, Xinlei Chen, Saining Xie, Yanghao Li, Piotr Dollár, Ross Girshick.
+1. **[ViTMSN](https://huggingface.co/docs/transformers/main/model_doc/vit_msn)** (from Meta AI) released with the paper [Masked Siamese Networks for Label-Efficient Learning](https://arxiv.org/abs/2204.07141) by Mahmoud Assran, Mathilde Caron, Ishan Misra, Piotr Bojanowski, Florian Bordes, Pascal Vincent, Armand Joulin, Michael Rabbat, Nicolas Ballas.
1. **[Wav2Vec2](https://huggingface.co/docs/transformers/model_doc/wav2vec2)** (from Facebook AI) released with the paper [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations](https://arxiv.org/abs/2006.11477) by Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael Auli.
1. **[Wav2Vec2-Conformer](https://huggingface.co/docs/transformers/model_doc/wav2vec2-conformer)** (from Facebook AI) released with the paper [FAIRSEQ S2T: Fast Speech-to-Text Modeling with FAIRSEQ](https://arxiv.org/abs/2010.05171) by Changhan Wang, Yun Tang, Xutai Ma, Anne Wu, Sravya Popuri, Dmytro Okhonko, Juan Pino.
1. **[Wav2Vec2Phoneme](https://huggingface.co/docs/transformers/model_doc/wav2vec2_phoneme)** (from Facebook AI) released with the paper [Simple and Effective Zero-shot Cross-lingual Phoneme Recognition](https://arxiv.org/abs/2109.11680) by Qiantong Xu, Alexei Baevski, Michael Auli.
diff --git a/README_ko.md b/README_ko.md
index 8a2df2fcbc8887..f53075ff5fe6f9 100644
--- a/README_ko.md
+++ b/README_ko.md
@@ -335,6 +335,7 @@ Flax, PyTorch, TensorFlow 설치 페이지에서 이들을 conda로 설치하는
1. **[Vision Transformer (ViT)](https://huggingface.co/docs/transformers/model_doc/vit)** (from Google AI) released with the paper [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929) by Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, Jakob Uszkoreit, Neil Houlsby.
1. **[VisualBERT](https://huggingface.co/docs/transformers/model_doc/visual_bert)** (from UCLA NLP) released with the paper [VisualBERT: A Simple and Performant Baseline for Vision and Language](https://arxiv.org/pdf/1908.03557) by Liunian Harold Li, Mark Yatskar, Da Yin, Cho-Jui Hsieh, Kai-Wei Chang.
1. **[ViTMAE](https://huggingface.co/docs/transformers/model_doc/vit_mae)** (from Meta AI) released with the paper [Masked Autoencoders Are Scalable Vision Learners](https://arxiv.org/abs/2111.06377) by Kaiming He, Xinlei Chen, Saining Xie, Yanghao Li, Piotr Dollár, Ross Girshick.
+1. **[ViTMSN](https://huggingface.co/docs/transformers/main/model_doc/vit_msn)** (from Meta AI) released with the paper [Masked Siamese Networks for Label-Efficient Learning](https://arxiv.org/abs/2204.07141) by Mahmoud Assran, Mathilde Caron, Ishan Misra, Piotr Bojanowski, Florian Bordes, Pascal Vincent, Armand Joulin, Michael Rabbat, Nicolas Ballas.
1. **[Wav2Vec2](https://huggingface.co/docs/transformers/model_doc/wav2vec2)** (from Facebook AI) released with the paper [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations](https://arxiv.org/abs/2006.11477) by Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael Auli.
1. **[Wav2Vec2-Conformer](https://huggingface.co/docs/transformers/model_doc/wav2vec2-conformer)** (from Facebook AI) released with the paper [FAIRSEQ S2T: Fast Speech-to-Text Modeling with FAIRSEQ](https://arxiv.org/abs/2010.05171) by Changhan Wang, Yun Tang, Xutai Ma, Anne Wu, Sravya Popuri, Dmytro Okhonko, Juan Pino.
1. **[Wav2Vec2Phoneme](https://huggingface.co/docs/transformers/model_doc/wav2vec2_phoneme)** (from Facebook AI) released with the paper [Simple and Effective Zero-shot Cross-lingual Phoneme Recognition](https://arxiv.org/abs/2109.11680) by Qiantong Xu, Alexei Baevski, Michael Auli.
diff --git a/README_zh-hans.md b/README_zh-hans.md
index 88611a5f672bd0..2843a8eb29a08c 100644
--- a/README_zh-hans.md
+++ b/README_zh-hans.md
@@ -359,6 +359,7 @@ conda install -c huggingface transformers
1. **[Vision Transformer (ViT)](https://huggingface.co/docs/transformers/model_doc/vit)** (来自 Google AI) 伴随论文 [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929) 由 Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, Jakob Uszkoreit, Neil Houlsby 发布。
1. **[VisualBERT](https://huggingface.co/docs/transformers/model_doc/visual_bert)** (来自 UCLA NLP) 伴随论文 [VisualBERT: A Simple and Performant Baseline for Vision and Language](https://arxiv.org/pdf/1908.03557) 由 Liunian Harold Li, Mark Yatskar, Da Yin, Cho-Jui Hsieh, Kai-Wei Chang 发布。
1. **[ViTMAE](https://huggingface.co/docs/transformers/model_doc/vit_mae)** (来自 Meta AI) 伴随论文 [Masked Autoencoders Are Scalable Vision Learners](https://arxiv.org/abs/2111.06377) 由 Kaiming He, Xinlei Chen, Saining Xie, Yanghao Li, Piotr Dollár, Ross Girshick 发布。
+1. **[ViTMSN](https://huggingface.co/docs/transformers/main/model_doc/vit_msn)** (来自 Meta AI) 伴随论文 [Masked Siamese Networks for Label-Efficient Learning](https://arxiv.org/abs/2204.07141) 由 Mahmoud Assran, Mathilde Caron, Ishan Misra, Piotr Bojanowski, Florian Bordes, Pascal Vincent, Armand Joulin, Michael Rabbat, Nicolas Ballas 发布。
1. **[Wav2Vec2](https://huggingface.co/docs/transformers/model_doc/wav2vec2)** (来自 Facebook AI) 伴随论文 [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations](https://arxiv.org/abs/2006.11477) 由 Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael Auli 发布。
1. **[Wav2Vec2-Conformer](https://huggingface.co/docs/transformers/model_doc/wav2vec2-conformer)** (来自 Facebook AI) 伴随论文 [FAIRSEQ S2T: Fast Speech-to-Text Modeling with FAIRSEQ](https://arxiv.org/abs/2010.05171) 由 Changhan Wang, Yun Tang, Xutai Ma, Anne Wu, Sravya Popuri, Dmytro Okhonko, Juan Pino 发布。
1. **[Wav2Vec2Phoneme](https://huggingface.co/docs/transformers/model_doc/wav2vec2_phoneme)** (来自 Facebook AI) 伴随论文 [Simple and Effective Zero-shot Cross-lingual Phoneme Recognition](https://arxiv.org/abs/2109.11680) 由 Qiantong Xu, Alexei Baevski, Michael Auli 发布。
diff --git a/README_zh-hant.md b/README_zh-hant.md
index b84d3ea8ca974b..8f74b97e98549f 100644
--- a/README_zh-hant.md
+++ b/README_zh-hant.md
@@ -371,6 +371,7 @@ conda install -c huggingface transformers
1. **[Vision Transformer (ViT)](https://huggingface.co/docs/transformers/model_doc/vit)** (from Google AI) released with the paper [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929) by Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, Jakob Uszkoreit, Neil Houlsby.
1. **[VisualBERT](https://huggingface.co/docs/transformers/model_doc/visual_bert)** (from UCLA NLP) released with the paper [VisualBERT: A Simple and Performant Baseline for Vision and Language](https://arxiv.org/pdf/1908.03557) by Liunian Harold Li, Mark Yatskar, Da Yin, Cho-Jui Hsieh, Kai-Wei Chang.
1. **[ViTMAE](https://huggingface.co/docs/transformers/model_doc/vit_mae)** (from Meta AI) released with the paper [Masked Autoencoders Are Scalable Vision Learners](https://arxiv.org/abs/2111.06377) by Kaiming He, Xinlei Chen, Saining Xie, Yanghao Li, Piotr Dollár, Ross Girshick.
+1. **[ViTMSN](https://huggingface.co/docs/transformers/main/model_doc/vit_msn)** (from Meta AI) released with the paper [Masked Siamese Networks for Label-Efficient Learning](https://arxiv.org/abs/2204.07141) by Mahmoud Assran, Mathilde Caron, Ishan Misra, Piotr Bojanowski, Florian Bordes, Pascal Vincent, Armand Joulin, Michael Rabbat, Nicolas Ballas.
1. **[Wav2Vec2](https://huggingface.co/docs/transformers/model_doc/wav2vec2)** (from Facebook AI) released with the paper [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations](https://arxiv.org/abs/2006.11477) by Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael Auli.
1. **[Wav2Vec2-Conformer](https://huggingface.co/docs/transformers/model_doc/wav2vec2-conformer)** (from Facebook AI) released with the paper [FAIRSEQ S2T: Fast Speech-to-Text Modeling with FAIRSEQ](https://arxiv.org/abs/2010.05171) by Changhan Wang, Yun Tang, Xutai Ma, Anne Wu, Sravya Popuri, Dmytro Okhonko, Juan Pino.
1. **[Wav2Vec2Phoneme](https://huggingface.co/docs/transformers/model_doc/wav2vec2_phoneme)** (from Facebook AI) released with the paper [Simple and Effective Zero-shot Cross-lingual Phoneme Recognition](https://arxiv.org/abs/2109.11680) by Qiantong Xu, Alexei Baevski, Michael Auli.
diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml
index a4cd1005e3e83f..223c5d2a6998fc 100644
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -408,6 +408,8 @@
title: Vision Transformer (ViT)
- local: model_doc/vit_mae
title: ViTMAE
+ - local: model_doc/vit_msn
+ title: ViTMSN
- local: model_doc/yolos
title: YOLOS
title: Vision models
diff --git a/docs/source/en/index.mdx b/docs/source/en/index.mdx
index 40685a5c2fa8da..e6a3d912b27437 100644
--- a/docs/source/en/index.mdx
+++ b/docs/source/en/index.mdx
@@ -175,6 +175,7 @@ The documentation is organized into five sections:
1. **[Vision Transformer (ViT)](model_doc/vit)** (from Google AI) released with the paper [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929) by Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, Jakob Uszkoreit, Neil Houlsby.
1. **[VisualBERT](model_doc/visual_bert)** (from UCLA NLP) released with the paper [VisualBERT: A Simple and Performant Baseline for Vision and Language](https://arxiv.org/pdf/1908.03557) by Liunian Harold Li, Mark Yatskar, Da Yin, Cho-Jui Hsieh, Kai-Wei Chang.
1. **[ViTMAE](model_doc/vit_mae)** (from Meta AI) released with the paper [Masked Autoencoders Are Scalable Vision Learners](https://arxiv.org/abs/2111.06377) by Kaiming He, Xinlei Chen, Saining Xie, Yanghao Li, Piotr Dollár, Ross Girshick.
+1. **[ViTMSN](model_doc/vit_msn)** (from Meta AI) released with the paper [Masked Siamese Networks for Label-Efficient Learning](https://arxiv.org/abs/2204.07141) by Mahmoud Assran, Mathilde Caron, Ishan Misra, Piotr Bojanowski, Florian Bordes, Pascal Vincent, Armand Joulin, Michael Rabbat, Nicolas Ballas.
1. **[Wav2Vec2](model_doc/wav2vec2)** (from Facebook AI) released with the paper [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations](https://arxiv.org/abs/2006.11477) by Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael Auli.
1. **[Wav2Vec2-Conformer](model_doc/wav2vec2-conformer)** (from Facebook AI) released with the paper [FAIRSEQ S2T: Fast Speech-to-Text Modeling with FAIRSEQ](https://arxiv.org/abs/2010.05171) by Changhan Wang, Yun Tang, Xutai Ma, Anne Wu, Sravya Popuri, Dmytro Okhonko, Juan Pino.
1. **[Wav2Vec2Phoneme](model_doc/wav2vec2_phoneme)** (from Facebook AI) released with the paper [Simple and Effective Zero-shot Cross-lingual Phoneme Recognition](https://arxiv.org/abs/2109.11680) by Qiantong Xu, Alexei Baevski, Michael Auli.
@@ -318,6 +319,7 @@ Flax), PyTorch, and/or TensorFlow.
| VisualBERT | ❌ | ❌ | ✅ | ❌ | ❌ |
| ViT | ❌ | ❌ | ✅ | ✅ | ✅ |
| ViTMAE | ❌ | ❌ | ✅ | ✅ | ❌ |
+| ViTMSN | ❌ | ❌ | ✅ | ❌ | ❌ |
| Wav2Vec2 | ✅ | ❌ | ✅ | ✅ | ✅ |
| Wav2Vec2-Conformer | ❌ | ❌ | ✅ | ❌ | ❌ |
| WavLM | ❌ | ❌ | ✅ | ❌ | ❌ |
diff --git a/docs/source/en/model_doc/vit_msn.mdx b/docs/source/en/model_doc/vit_msn.mdx
new file mode 100644
index 00000000000000..07faed51e6cb57
--- /dev/null
+++ b/docs/source/en/model_doc/vit_msn.mdx
@@ -0,0 +1,64 @@
+<!--Copyright 2022 The HuggingFace Team. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
+the License. You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
+an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
+specific language governing permissions and limitations under the License.
+-->
+
+# ViTMSN
+
+## Overview
+
+The ViTMSN model was proposed in [Masked Siamese Networks for Label-Efficient Learning](https://arxiv.org/abs/2204.07141) by Mahmoud Assran, Mathilde Caron, Ishan Misra, Piotr Bojanowski, Florian Bordes,
+Pascal Vincent, Armand Joulin, Michael Rabbat, Nicolas Ballas. The paper presents a joint-embedding architecture to match the prototypes
+of masked patches with those of the unmasked patches. With this setup, the method yields excellent performance in the low-shot and extreme low-shot
+regimes.
+
+The abstract from the paper is the following:
+
+*We propose Masked Siamese Networks (MSN), a self-supervised learning framework for learning image representations. Our
+approach matches the representation of an image view containing randomly masked patches to the representation of the original
+unmasked image. This self-supervised pre-training strategy is particularly scalable when applied to Vision Transformers since only the
+unmasked patches are processed by the network. As a result, MSNs improve the scalability of joint-embedding architectures,
+while producing representations of a high semantic level that perform competitively on low-shot image classification. For instance,
+on ImageNet-1K, with only 5,000 annotated images, our base MSN model achieves 72.4% top-1 accuracy,
+and with 1% of ImageNet-1K labels, we achieve 75.7% top-1 accuracy, setting a new state-of-the-art for self-supervised learning on this benchmark.*
+
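+As a rough illustration of the matching objective described above, consider the following minimal sketch. It is not part of
+this library's API: the tensors are made up, and the prototype count and temperatures are illustrative (the full method also
+adds a mean-entropy maximization regularizer to avoid collapse onto a few prototypes).
+
+```python
+import torch
+import torch.nn.functional as F
+
+# hypothetical sizes; MSN learns the prototypes jointly with the encoder
+num_prototypes, hidden_dim = 1024, 768
+prototypes = F.normalize(torch.randn(num_prototypes, hidden_dim), dim=-1)
+
+
+def soft_assignments(representations, temperature):
+    # cosine similarity to each prototype, turned into a distribution
+    representations = F.normalize(representations, dim=-1)
+    return F.softmax(representations @ prototypes.T / temperature, dim=-1)
+
+
+# [CLS] representations of the masked (anchor) and unmasked (target) views
+z_masked = torch.randn(8, hidden_dim)
+z_unmasked = torch.randn(8, hidden_dim)
+
+anchor = soft_assignments(z_masked, temperature=0.1)
+with torch.no_grad():
+    # a lower temperature sharpens the target distribution
+    target = soft_assignments(z_unmasked, temperature=0.025)
+
+# cross-entropy between the anchor and target assignments
+loss = -(target * anchor.log()).sum(dim=-1).mean()
+```
+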
+Tips:
+
+- MSN (masked siamese networks) is a method for self-supervised pre-training of Vision Transformers (ViTs). The pre-training
+objective is to match the prototypes assigned to the unmasked views of the images to those of the masked views of the same images.
+- The authors have only released pre-trained weights of the backbone (ImageNet-1k pre-training). To use them on your own image classification dataset,
+use the [`ViTMSNForImageClassification`] class, which is initialized from [`ViTMSNModel`]. Follow
+[this notebook](https://github.com/huggingface/notebooks/blob/main/examples/image_classification.ipynb) for a detailed tutorial on fine-tuning; a short usage sketch is also given below.
+- MSN is particularly useful in the low-shot and extreme low-shot regimes. Notably, it achieves 75.7% top-1 accuracy with only 1% of ImageNet-1K
+labels when fine-tuned.
+
+
+<small> MSN architecture. Taken from the <a href="https://arxiv.org/abs/2204.07141">original paper</a>. </small>
+
+This model was contributed by [sayakpaul](https://huggingface.co/sayakpaul). The original code can be found [here](https://github.com/facebookresearch/msn).
+
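+A minimal usage sketch for image classification is shown below. Note that only the backbone weights have been released, so the
+classification head here is randomly initialized and its predictions are only meaningful after fine-tuning; the checkpoint name
+assumes the `facebook/vit-msn-small` weights referenced in this document.
+
+```python
+import torch
+import requests
+from PIL import Image
+
+from transformers import AutoFeatureExtractor, ViTMSNForImageClassification
+
+url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+image = Image.open(requests.get(url, stream=True).raw)
+
+feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/vit-msn-small")
+model = ViTMSNForImageClassification.from_pretrained("facebook/vit-msn-small")
+
+inputs = feature_extractor(images=image, return_tensors="pt")
+with torch.no_grad():
+    logits = model(**inputs).logits
+
+# with a freshly initialized head, this is a random guess until the model is fine-tuned
+predicted_label = logits.argmax(-1).item()
+```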
+
+## ViTMSNConfig
+
+[[autodoc]] ViTMSNConfig
+
+
+## ViTMSNModel
+
+[[autodoc]] ViTMSNModel
+ - forward
+
+
+## ViTMSNForImageClassification
+
+[[autodoc]] ViTMSNForImageClassification
+ - forward
diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py
index 123c56bdd53267..50fb4d2c0b8a7e 100755
--- a/src/transformers/__init__.py
+++ b/src/transformers/__init__.py
@@ -359,6 +359,7 @@
"models.visual_bert": ["VISUAL_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "VisualBertConfig"],
"models.vit": ["VIT_PRETRAINED_CONFIG_ARCHIVE_MAP", "ViTConfig"],
"models.vit_mae": ["VIT_MAE_PRETRAINED_CONFIG_ARCHIVE_MAP", "ViTMAEConfig"],
+ "models.vit_msn": ["VIT_MSN_PRETRAINED_CONFIG_ARCHIVE_MAP", "ViTMSNConfig"],
"models.wav2vec2": [
"WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP",
"Wav2Vec2Config",
@@ -1966,6 +1967,14 @@
"ViTMAEPreTrainedModel",
]
)
+ _import_structure["models.vit_msn"].extend(
+ [
+ "VIT_MSN_PRETRAINED_MODEL_ARCHIVE_LIST",
+ "ViTMSNModel",
+ "ViTMSNForImageClassification",
+ "ViTMSNPreTrainedModel",
+ ]
+ )
_import_structure["models.videomae"].extend(
[
"VIDEOMAE_PRETRAINED_MODEL_ARCHIVE_LIST",
@@ -3251,6 +3260,7 @@
from .models.visual_bert import VISUAL_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, VisualBertConfig
from .models.vit import VIT_PRETRAINED_CONFIG_ARCHIVE_MAP, ViTConfig
from .models.vit_mae import VIT_MAE_PRETRAINED_CONFIG_ARCHIVE_MAP, ViTMAEConfig
+ from .models.vit_msn import VIT_MSN_PRETRAINED_CONFIG_ARCHIVE_MAP, ViTMSNConfig
from .models.wav2vec2 import (
WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP,
Wav2Vec2Config,
@@ -4586,6 +4596,12 @@
ViTMAEModel,
ViTMAEPreTrainedModel,
)
+ from .models.vit_msn import (
+ VIT_MSN_PRETRAINED_MODEL_ARCHIVE_LIST,
+ ViTMSNForImageClassification,
+ ViTMSNModel,
+ ViTMSNPreTrainedModel,
+ )
from .models.wav2vec2 import (
WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST,
Wav2Vec2ForAudioFrameClassification,
diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py
index 8a6622a9f35b6f..18c21cdf1865b3 100644
--- a/src/transformers/models/__init__.py
+++ b/src/transformers/models/__init__.py
@@ -150,6 +150,7 @@
visual_bert,
vit,
vit_mae,
+ vit_msn,
wav2vec2,
wav2vec2_conformer,
wav2vec2_phoneme,
diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py
index dce73cb3903656..39c48b217ff53f 100644
--- a/src/transformers/models/auto/configuration_auto.py
+++ b/src/transformers/models/auto/configuration_auto.py
@@ -145,6 +145,7 @@
("visual_bert", "VisualBertConfig"),
("vit", "ViTConfig"),
("vit_mae", "ViTMAEConfig"),
+ ("vit_msn", "ViTMSNConfig"),
("wav2vec2", "Wav2Vec2Config"),
("wav2vec2-conformer", "Wav2Vec2ConformerConfig"),
("wavlm", "WavLMConfig"),
@@ -266,6 +267,7 @@
("visual_bert", "VISUAL_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("vit", "VIT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("vit_mae", "VIT_MAE_PRETRAINED_CONFIG_ARCHIVE_MAP"),
+ ("vit_msn", "VIT_MSN_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("wav2vec2", "WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("wav2vec2-conformer", "WAV2VEC2_CONFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("xclip", "X_CLIP_PRETRAINED_CONFIG_ARCHIVE_MAP"),
@@ -418,6 +420,7 @@
("visual_bert", "VisualBERT"),
("vit", "ViT"),
("vit_mae", "ViTMAE"),
+ ("vit_msn", "ViTMSN"),
("wav2vec2", "Wav2Vec2"),
("wav2vec2-conformer", "Wav2Vec2-Conformer"),
("wav2vec2_phoneme", "Wav2Vec2Phoneme"),
diff --git a/src/transformers/models/auto/feature_extraction_auto.py b/src/transformers/models/auto/feature_extraction_auto.py
index 3f2265875f2a69..73fe1ad42ad195 100644
--- a/src/transformers/models/auto/feature_extraction_auto.py
+++ b/src/transformers/models/auto/feature_extraction_auto.py
@@ -74,6 +74,7 @@
("vilt", "ViltFeatureExtractor"),
("vit", "ViTFeatureExtractor"),
("vit_mae", "ViTFeatureExtractor"),
+ ("vit_msn", "ViTFeatureExtractor"),
("wav2vec2", "Wav2Vec2FeatureExtractor"),
("wav2vec2-conformer", "Wav2Vec2FeatureExtractor"),
("xclip", "CLIPFeatureExtractor"),
diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py
index d7c5f1772f13bd..936e9c8bdc479c 100644
--- a/src/transformers/models/auto/modeling_auto.py
+++ b/src/transformers/models/auto/modeling_auto.py
@@ -139,6 +139,7 @@
("visual_bert", "VisualBertModel"),
("vit", "ViTModel"),
("vit_mae", "ViTMAEModel"),
+ ("vit_msn", "ViTMSNModel"),
("wav2vec2", "Wav2Vec2Model"),
("wav2vec2-conformer", "Wav2Vec2ConformerModel"),
("wavlm", "WavLMModel"),
@@ -367,6 +368,7 @@
("swinv2", "Swinv2ForImageClassification"),
("van", "VanForImageClassification"),
("vit", "ViTForImageClassification"),
+ ("vit_msn", "ViTMSNForImageClassification"),
]
)
diff --git a/src/transformers/models/vit_msn/__init__.py b/src/transformers/models/vit_msn/__init__.py
new file mode 100644
index 00000000000000..832e730c5881c6
--- /dev/null
+++ b/src/transformers/models/vit_msn/__init__.py
@@ -0,0 +1,57 @@
+# flake8: noqa
+# There's no way to ignore "F401 '...' imported but unused" warnings in this
+# module, but to preserve other warnings. So, don't check this module at all.
+
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available
+
+
+_import_structure = {"configuration_vit_msn": ["VIT_MSN_PRETRAINED_CONFIG_ARCHIVE_MAP", "ViTMSNConfig"]}
+
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["modeling_vit_msn"] = [
+ "VIT_MSN_PRETRAINED_MODEL_ARCHIVE_LIST",
+ "ViTMSNModel",
+ "ViTMSNForImageClassification",
+ "ViTMSNPreTrainedModel",
+ ]
+
+if TYPE_CHECKING:
+ from .configuration_vit_msn import VIT_MSN_PRETRAINED_CONFIG_ARCHIVE_MAP, ViTMSNConfig
+
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .modeling_vit_msn import (
+ VIT_MSN_PRETRAINED_MODEL_ARCHIVE_LIST,
+ ViTMSNForImageClassification,
+ ViTMSNModel,
+ ViTMSNPreTrainedModel,
+ )
+
+else:
+ import sys
+
+ sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
diff --git a/src/transformers/models/vit_msn/configuration_vit_msn.py b/src/transformers/models/vit_msn/configuration_vit_msn.py
new file mode 100644
index 00000000000000..057824e5d4e133
--- /dev/null
+++ b/src/transformers/models/vit_msn/configuration_vit_msn.py
@@ -0,0 +1,117 @@
+# coding=utf-8
+# Copyright 2022 Facebook AI and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" ViT MSN model configuration"""
+
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+VIT_MSN_PRETRAINED_CONFIG_ARCHIVE_MAP = {
+ "sayakpaul/vit-msn-base": "https://huggingface.co/sayakpaul/vit-msn-base/resolve/main/config.json",
+ # See all ViT MSN models at https://huggingface.co/models?filter=vit_msn
+}
+
+
+class ViTMSNConfig(PretrainedConfig):
+ r"""
+    This is the configuration class to store the configuration of a [`ViTMSNModel`]. It is used to instantiate a ViT
+    MSN model according to the specified arguments, defining the model architecture. Instantiating a configuration with
+    the defaults will yield a similar configuration to that of the ViT MSN
+    [facebook/vit_msn_base](https://huggingface.co/facebook/vit_msn_base) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+
+ Args:
+ hidden_size (`int`, *optional*, defaults to 768):
+ Dimensionality of the encoder layers and the pooler layer.
+ num_hidden_layers (`int`, *optional*, defaults to 12):
+ Number of hidden layers in the Transformer encoder.
+ num_attention_heads (`int`, *optional*, defaults to 12):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ intermediate_size (`int`, *optional*, defaults to 3072):
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+ `"relu"`, `"selu"` and `"gelu_new"` are supported.
+ hidden_dropout_prob (`float`, *optional*, defaults to 0.0):
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+ attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ layer_norm_eps (`float`, *optional*, defaults to 1e-06):
+ The epsilon used by the layer normalization layers.
+ image_size (`int`, *optional*, defaults to 224):
+ The size (resolution) of each image.
+ patch_size (`int`, *optional*, defaults to 16):
+ The size (resolution) of each patch.
+ num_channels (`int`, *optional*, defaults to 3):
+ The number of input channels.
+ qkv_bias (`bool`, *optional*, defaults to `True`):
+ Whether to add a bias to the queries, keys and values.
+
+ Example:
+
+ ```python
+ >>> from transformers import ViTMSNModel, ViTMSNConfig
+
+ >>> # Initializing a ViT MSN vit-msn-base style configuration
+    >>> configuration = ViTMSNConfig()
+
+ >>> # Initializing a model from the vit-msn-base style configuration
+ >>> model = ViTMSNModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+ model_type = "vit_msn"
+
+ def __init__(
+ self,
+ hidden_size=768,
+ num_hidden_layers=12,
+ num_attention_heads=12,
+ intermediate_size=3072,
+ hidden_act="gelu",
+ hidden_dropout_prob=0.0,
+ attention_probs_dropout_prob=0.0,
+ initializer_range=0.02,
+ layer_norm_eps=1e-06,
+ image_size=224,
+ patch_size=16,
+ num_channels=3,
+ qkv_bias=True,
+ **kwargs
+ ):
+ super().__init__(**kwargs)
+
+ self.hidden_size = hidden_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.intermediate_size = intermediate_size
+ self.hidden_act = hidden_act
+ self.hidden_dropout_prob = hidden_dropout_prob
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
+ self.initializer_range = initializer_range
+ self.layer_norm_eps = layer_norm_eps
+ self.image_size = image_size
+ self.patch_size = patch_size
+ self.num_channels = num_channels
+ self.qkv_bias = qkv_bias
diff --git a/src/transformers/models/vit_msn/convert_msn_to_pytorch.py b/src/transformers/models/vit_msn/convert_msn_to_pytorch.py
new file mode 100644
index 00000000000000..535f5f742d631f
--- /dev/null
+++ b/src/transformers/models/vit_msn/convert_msn_to_pytorch.py
@@ -0,0 +1,236 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Convert ViT MSN checkpoints from the original repository: https://github.com/facebookresearch/msn"""
+
+import argparse
+
+import torch
+from PIL import Image
+
+import requests
+from transformers import ViTFeatureExtractor, ViTMSNConfig, ViTMSNModel
+from transformers.image_utils import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+
+
+torch.set_grad_enabled(False)
+
+
+# here we list all keys to be renamed (original name on the left, our name on the right)
+def create_rename_keys(config, base_model=False):
+ rename_keys = []
+ for i in range(config.num_hidden_layers):
+ # encoder layers: output projection, 2 feedforward neural networks and 2 layernorms
+ rename_keys.append((f"module.blocks.{i}.norm1.weight", f"vit.encoder.layer.{i}.layernorm_before.weight"))
+ rename_keys.append((f"module.blocks.{i}.norm1.bias", f"vit.encoder.layer.{i}.layernorm_before.bias"))
+ rename_keys.append(
+ (f"module.blocks.{i}.attn.proj.weight", f"vit.encoder.layer.{i}.attention.output.dense.weight")
+ )
+ rename_keys.append((f"module.blocks.{i}.attn.proj.bias", f"vit.encoder.layer.{i}.attention.output.dense.bias"))
+ rename_keys.append((f"module.blocks.{i}.norm2.weight", f"vit.encoder.layer.{i}.layernorm_after.weight"))
+ rename_keys.append((f"module.blocks.{i}.norm2.bias", f"vit.encoder.layer.{i}.layernorm_after.bias"))
+ rename_keys.append((f"module.blocks.{i}.mlp.fc1.weight", f"vit.encoder.layer.{i}.intermediate.dense.weight"))
+ rename_keys.append((f"module.blocks.{i}.mlp.fc1.bias", f"vit.encoder.layer.{i}.intermediate.dense.bias"))
+ rename_keys.append((f"module.blocks.{i}.mlp.fc2.weight", f"vit.encoder.layer.{i}.output.dense.weight"))
+ rename_keys.append((f"module.blocks.{i}.mlp.fc2.bias", f"vit.encoder.layer.{i}.output.dense.bias"))
+
+ # projection layer + position embeddings
+ rename_keys.extend(
+ [
+ ("module.cls_token", "vit.embeddings.cls_token"),
+ ("module.patch_embed.proj.weight", "vit.embeddings.patch_embeddings.projection.weight"),
+ ("module.patch_embed.proj.bias", "vit.embeddings.patch_embeddings.projection.bias"),
+ ("module.pos_embed", "vit.embeddings.position_embeddings"),
+ ]
+ )
+
+ if base_model:
+ # layernorm + pooler
+ rename_keys.extend(
+ [
+ ("module.norm.weight", "layernorm.weight"),
+ ("module.norm.bias", "layernorm.bias"),
+ ]
+ )
+
+        # if just the base model, remove the "vit." prefix from all keys that start with "vit"
+ rename_keys = [(pair[0], pair[1][4:]) if pair[1].startswith("vit") else pair for pair in rename_keys]
+ else:
+ # layernorm + classification head
+ rename_keys.extend(
+ [
+ ("norm.weight", "vit.layernorm.weight"),
+ ("norm.bias", "vit.layernorm.bias"),
+ ("head.weight", "classifier.weight"),
+ ("head.bias", "classifier.bias"),
+ ]
+ )
+
+ return rename_keys
+
+
+# we split up the matrix of each encoder layer into queries, keys and values
+def read_in_q_k_v(state_dict, config, base_model=False):
+ for i in range(config.num_hidden_layers):
+ if base_model:
+ prefix = ""
+ else:
+ prefix = "vit."
+ # read in weights + bias of input projection layer (in timm, this is a single matrix + bias)
+ in_proj_weight = state_dict.pop(f"module.blocks.{i}.attn.qkv.weight")
+ in_proj_bias = state_dict.pop(f"module.blocks.{i}.attn.qkv.bias")
+ # next, add query, keys and values (in that order) to the state dict
+ state_dict[f"{prefix}encoder.layer.{i}.attention.attention.query.weight"] = in_proj_weight[
+ : config.hidden_size, :
+ ]
+ state_dict[f"{prefix}encoder.layer.{i}.attention.attention.query.bias"] = in_proj_bias[: config.hidden_size]
+ state_dict[f"{prefix}encoder.layer.{i}.attention.attention.key.weight"] = in_proj_weight[
+ config.hidden_size : config.hidden_size * 2, :
+ ]
+ state_dict[f"{prefix}encoder.layer.{i}.attention.attention.key.bias"] = in_proj_bias[
+ config.hidden_size : config.hidden_size * 2
+ ]
+ state_dict[f"{prefix}encoder.layer.{i}.attention.attention.value.weight"] = in_proj_weight[
+ -config.hidden_size :, :
+ ]
+ state_dict[f"{prefix}encoder.layer.{i}.attention.attention.value.bias"] = in_proj_bias[-config.hidden_size :]
+
+
+def remove_classification_head_(state_dict):
+ ignore_keys = ["head.weight", "head.bias"]
+ for k in ignore_keys:
+ state_dict.pop(k, None)
+
+
+def remove_projection_head(state_dict):
+    # the projection head is used during self-supervised pre-training in MSN;
+    # it is not needed for downstream tasks.
+ ignore_keys = [
+ "module.fc.fc1.weight",
+ "module.fc.fc1.bias",
+ "module.fc.bn1.weight",
+ "module.fc.bn1.bias",
+ "module.fc.bn1.running_mean",
+ "module.fc.bn1.running_var",
+ "module.fc.bn1.num_batches_tracked",
+ "module.fc.fc2.weight",
+ "module.fc.fc2.bias",
+ "module.fc.bn2.weight",
+ "module.fc.bn2.bias",
+ "module.fc.bn2.running_mean",
+ "module.fc.bn2.running_var",
+ "module.fc.bn2.num_batches_tracked",
+ "module.fc.fc3.weight",
+ "module.fc.fc3.bias",
+ ]
+ for k in ignore_keys:
+ state_dict.pop(k, None)
+
+
+def rename_key(dct, old, new):
+ val = dct.pop(old)
+ dct[new] = val
+
+
+def convert_vit_msn_checkpoint(checkpoint_url, pytorch_dump_folder_path):
+ config = ViTMSNConfig()
+ config.num_labels = 1000
+
+ if "s16" in checkpoint_url:
+ config.hidden_size = 384
+ config.intermediate_size = 1536
+ config.num_attention_heads = 6
+ elif "l16" in checkpoint_url:
+ config.hidden_size = 1024
+ config.intermediate_size = 4096
+ config.num_hidden_layers = 24
+ config.num_attention_heads = 16
+ config.hidden_dropout_prob = 0.1
+ elif "b4" in checkpoint_url:
+ config.patch_size = 4
+ elif "l7" in checkpoint_url:
+ config.patch_size = 7
+ config.hidden_size = 1024
+ config.intermediate_size = 4096
+ config.num_hidden_layers = 24
+ config.num_attention_heads = 16
+ config.hidden_dropout_prob = 0.1
+
+ model = ViTMSNModel(config)
+
+ state_dict = torch.hub.load_state_dict_from_url(checkpoint_url, map_location="cpu")["target_encoder"]
+
+ remove_projection_head(state_dict)
+ rename_keys = create_rename_keys(config, base_model=True)
+
+ for src, dest in rename_keys:
+ rename_key(state_dict, src, dest)
+ read_in_q_k_v(state_dict, config, base_model=True)
+
+ model.load_state_dict(state_dict)
+ model.eval()
+
+ url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+
+ image = Image.open(requests.get(url, stream=True).raw)
+ feature_extractor = ViTFeatureExtractor(
+ size=config.image_size, image_mean=IMAGENET_DEFAULT_MEAN, image_std=IMAGENET_DEFAULT_STD
+ )
+ inputs = feature_extractor(images=image, return_tensors="pt")
+
+ # forward pass
+ torch.manual_seed(2)
+ outputs = model(**inputs)
+ last_hidden_state = outputs.last_hidden_state
+
+ # The following Colab Notebook was used to generate these outputs:
+ # https://colab.research.google.com/gist/sayakpaul/3672419a04f5997827503fd84079bdd1/scratchpad.ipynb
+ if "s16" in checkpoint_url:
+ expected_slice = torch.tensor([[-1.0915, -1.4876, -1.1809]])
+ elif "b16" in checkpoint_url:
+ expected_slice = torch.tensor([[14.2889, -18.9045, 11.7281]])
+ elif "l16" in checkpoint_url:
+ expected_slice = torch.tensor([[41.5028, -22.8681, 45.6475]])
+ elif "b4" in checkpoint_url:
+ expected_slice = torch.tensor([[-4.3868, 5.2932, -0.4137]])
+ else:
+ expected_slice = torch.tensor([[-0.1792, -0.6465, 2.4263]])
+
+    # verify the last hidden states
+ assert torch.allclose(last_hidden_state[:, 0, :3], expected_slice, atol=1e-4)
+
+ print(f"Saving model to {pytorch_dump_folder_path}")
+ model.save_pretrained(pytorch_dump_folder_path)
+
+ print(f"Saving feature extractor to {pytorch_dump_folder_path}")
+ feature_extractor.save_pretrained(pytorch_dump_folder_path)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ # Required parameters
+ parser.add_argument(
+ "--checkpoint_url",
+ default="https://dl.fbaipublicfiles.com/msn/vits16_800ep.pth.tar",
+ type=str,
+ help="URL of the checkpoint you'd like to convert.",
+ )
+ parser.add_argument(
+ "--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory."
+ )
+
+ args = parser.parse_args()
+ convert_vit_msn_checkpoint(args.checkpoint_url, args.pytorch_dump_folder_path)
diff --git a/src/transformers/models/vit_msn/modeling_vit_msn.py b/src/transformers/models/vit_msn/modeling_vit_msn.py
new file mode 100644
index 00000000000000..314b3dbd5bfdbd
--- /dev/null
+++ b/src/transformers/models/vit_msn/modeling_vit_msn.py
@@ -0,0 +1,695 @@
+# coding=utf-8
+# Copyright 2022 Facebook AI and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" PyTorch ViT MSN (masked siamese network) model."""
+
+
+import collections.abc
+import math
+from typing import Dict, List, Optional, Set, Tuple, Union
+
+import torch
+import torch.utils.checkpoint
+from torch import nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+
+from ...activations import ACT2FN
+from ...modeling_outputs import BaseModelOutput, ImageClassifierOutput
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
+from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
+from .configuration_vit_msn import ViTMSNConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+_CONFIG_FOR_DOC = "ViTMSNConfig"
+_CHECKPOINT_FOR_DOC = "sayakpaul/vit-msn-small"
+VIT_MSN_PRETRAINED_MODEL_ARCHIVE_LIST = [
+ "sayakpaul/vit-msn-small",
+ # See all ViTMSN models at https://huggingface.co/models?filter=vit_msn
+]
+
+
+class ViTMSNEmbeddings(nn.Module):
+ """
+ Construct the CLS token, position and patch embeddings. Optionally, also the mask token.
+ """
+
+ def __init__(self, config: ViTMSNConfig, use_mask_token: bool = False) -> None:
+ super().__init__()
+
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
+ self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) if use_mask_token else None
+ self.patch_embeddings = ViTMSNPatchEmbeddings(config)
+ num_patches = self.patch_embeddings.num_patches
+ self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size))
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+ self.config = config
+
+ def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
+ """
+        This method interpolates the pre-trained position encodings so that the model can be used on
+        higher-resolution images.
+
+ Source:
+ https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
+ """
+
+ num_patches = embeddings.shape[1] - 1
+ num_positions = self.position_embeddings.shape[1] - 1
+ if num_patches == num_positions and height == width:
+ return self.position_embeddings
+ class_pos_embed = self.position_embeddings[:, 0]
+ patch_pos_embed = self.position_embeddings[:, 1:]
+ dim = embeddings.shape[-1]
+ patch_window_height = height // self.config.patch_size
+ patch_window_width = width // self.config.patch_size
+ # we add a small number to avoid floating point error in the interpolation
+ # see discussion at https://github.com/facebookresearch/dino/issues/8
+ patch_window_height, patch_window_width = patch_window_height + 0.1, patch_window_width + 0.1
+ patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
+ patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
+ patch_pos_embed = nn.functional.interpolate(
+ patch_pos_embed,
+ scale_factor=(
+ patch_window_height / math.sqrt(num_positions),
+ patch_window_width / math.sqrt(num_positions),
+ ),
+ mode="bicubic",
+ align_corners=False,
+ )
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+ return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
+
+ def forward(
+ self,
+ pixel_values: torch.Tensor,
+ bool_masked_pos: Optional[torch.BoolTensor] = None,
+ interpolate_pos_encoding: bool = False,
+ ) -> torch.Tensor:
+ batch_size, num_channels, height, width = pixel_values.shape
+ embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
+
+ if bool_masked_pos is not None:
+ seq_length = embeddings.shape[1]
+ mask_tokens = self.mask_token.expand(batch_size, seq_length, -1)
+ # replace the masked visual tokens by mask_tokens
+ mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
+ embeddings = embeddings * (1.0 - mask) + mask_tokens * mask
+
+ # add the [CLS] token to the embedded patch tokens
+ cls_tokens = self.cls_token.expand(batch_size, -1, -1)
+ embeddings = torch.cat((cls_tokens, embeddings), dim=1)
+
+ # add positional encoding to each token
+ if interpolate_pos_encoding:
+ embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
+ else:
+ embeddings = embeddings + self.position_embeddings
+
+ embeddings = self.dropout(embeddings)
+
+ return embeddings
+
+
+# Copied from transformers.models.vit.modeling_vit.ViTPatchEmbeddings with ViT->ViTMSN
+class ViTMSNPatchEmbeddings(nn.Module):
+ """
+ This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
+ `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
+ Transformer.
+ """
+
+ def __init__(self, config):
+ super().__init__()
+ image_size, patch_size = config.image_size, config.patch_size
+ num_channels, hidden_size = config.num_channels, config.hidden_size
+
+ image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
+ patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
+ num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
+ self.image_size = image_size
+ self.patch_size = patch_size
+ self.num_channels = num_channels
+ self.num_patches = num_patches
+
+ self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
+
+ def forward(self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
+ batch_size, num_channels, height, width = pixel_values.shape
+ if num_channels != self.num_channels:
+ raise ValueError(
+ "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
+ )
+ if not interpolate_pos_encoding:
+ if height != self.image_size[0] or width != self.image_size[1]:
+ raise ValueError(
+ f"Input image size ({height}*{width}) doesn't match model"
+ f" ({self.image_size[0]}*{self.image_size[1]})."
+ )
+ embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
+ return embeddings
+
+
+# Copied from transformers.models.vit.modeling_vit.ViTSelfAttention with ViT->ViTMSN
+class ViTMSNSelfAttention(nn.Module):
+ def __init__(self, config: ViTMSNConfig) -> None:
+ super().__init__()
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
+ raise ValueError(
+ f"The hidden size {config.hidden_size,} is not a multiple of the number of attention "
+ f"heads {config.num_attention_heads}."
+ )
+
+ self.num_attention_heads = config.num_attention_heads
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
+
+ self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
+ self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
+ self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
+
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
+
+ def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
+ new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
+ x = x.view(new_x_shape)
+ return x.permute(0, 2, 1, 3)
+
+ def forward(
+ self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
+ ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
+ mixed_query_layer = self.query(hidden_states)
+
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
+ query_layer = self.transpose_for_scores(mixed_query_layer)
+
+ # Take the dot product between "query" and "key" to get the raw attention scores.
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
+
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
+
+ # Normalize the attention scores to probabilities.
+ attention_probs = nn.functional.softmax(attention_scores, dim=-1)
+
+ # This is actually dropping out entire tokens to attend to, which might
+ # seem a bit unusual, but is taken from the original Transformer paper.
+ attention_probs = self.dropout(attention_probs)
+
+ # Mask heads if we want to
+ if head_mask is not None:
+ attention_probs = attention_probs * head_mask
+
+ context_layer = torch.matmul(attention_probs, value_layer)
+
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+ context_layer = context_layer.view(new_context_layer_shape)
+
+ outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
+
+ return outputs
+
+
+# Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->ViTMSN
+class ViTMSNSelfOutput(nn.Module):
+ """
+ The residual connection is defined in ViTMSNLayer instead of here (as is the case with other models), due to the
+ layernorm applied before each block.
+ """
+
+ def __init__(self, config: ViTMSNConfig) -> None:
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+
+ return hidden_states
+
+
+# Copied from transformers.models.vit.modeling_vit.ViTAttention with ViT->ViTMSN
+class ViTMSNAttention(nn.Module):
+ def __init__(self, config: ViTMSNConfig) -> None:
+ super().__init__()
+ self.attention = ViTMSNSelfAttention(config)
+ self.output = ViTMSNSelfOutput(config)
+ self.pruned_heads = set()
+
+ def prune_heads(self, heads: Set[int]) -> None:
+ if len(heads) == 0:
+ return
+ heads, index = find_pruneable_heads_and_indices(
+ heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads
+ )
+
+ # Prune linear layers
+ self.attention.query = prune_linear_layer(self.attention.query, index)
+ self.attention.key = prune_linear_layer(self.attention.key, index)
+ self.attention.value = prune_linear_layer(self.attention.value, index)
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
+
+ # Update hyper params and store pruned heads
+ self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads)
+ self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
+ self.pruned_heads = self.pruned_heads.union(heads)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ head_mask: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
+ self_outputs = self.attention(hidden_states, head_mask, output_attentions)
+
+ attention_output = self.output(self_outputs[0], hidden_states)
+
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
+ return outputs
+
+
+# Copied from transformers.models.vit.modeling_vit.ViTIntermediate with ViT->ViTMSN
+class ViTMSNIntermediate(nn.Module):
+ def __init__(self, config: ViTMSNConfig) -> None:
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
+ if isinstance(config.hidden_act, str):
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
+ else:
+ self.intermediate_act_fn = config.hidden_act
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.intermediate_act_fn(hidden_states)
+
+ return hidden_states
+
+
+# Copied from transformers.models.vit.modeling_vit.ViTOutput with ViT->ViTMSN
+class ViTMSNOutput(nn.Module):
+ def __init__(self, config: ViTMSNConfig) -> None:
+ super().__init__()
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+
+ hidden_states = hidden_states + input_tensor
+
+ return hidden_states
+
+
+# Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->ViTMSN
+class ViTMSNLayer(nn.Module):
+ """This corresponds to the Block class in the timm implementation."""
+
+ def __init__(self, config: ViTMSNConfig) -> None:
+ super().__init__()
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
+ self.seq_len_dim = 1
+ self.attention = ViTMSNAttention(config)
+ self.intermediate = ViTMSNIntermediate(config)
+ self.output = ViTMSNOutput(config)
+ self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ head_mask: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
+ self_attention_outputs = self.attention(
+ self.layernorm_before(hidden_states), # in ViTMSN, layernorm is applied before self-attention
+ head_mask,
+ output_attentions=output_attentions,
+ )
+ attention_output = self_attention_outputs[0]
+ outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
+
+ # first residual connection
+ hidden_states = attention_output + hidden_states
+
+ # in ViTMSN, layernorm is also applied after self-attention
+ layer_output = self.layernorm_after(hidden_states)
+ layer_output = self.intermediate(layer_output)
+
+ # second residual connection is done here
+ layer_output = self.output(layer_output, hidden_states)
+
+ outputs = (layer_output,) + outputs
+
+ return outputs
+
+
+# Copied from transformers.models.vit.modeling_vit.ViTEncoder with ViT->ViTMSN
+class ViTMSNEncoder(nn.Module):
+ def __init__(self, config: ViTMSNConfig) -> None:
+ super().__init__()
+ self.config = config
+ self.layer = nn.ModuleList([ViTMSNLayer(config) for _ in range(config.num_hidden_layers)])
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ head_mask: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ output_hidden_states: bool = False,
+ return_dict: bool = True,
+ ) -> Union[tuple, BaseModelOutput]:
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attentions = () if output_attentions else None
+
+ for i, layer_module in enumerate(self.layer):
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ layer_head_mask = head_mask[i] if head_mask is not None else None
+
+ if self.gradient_checkpointing and self.training:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs, output_attentions)
+
+ return custom_forward
+
+ layer_outputs = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(layer_module),
+ hidden_states,
+ layer_head_mask,
+ )
+ else:
+ layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions)
+
+ hidden_states = layer_outputs[0]
+
+ if output_attentions:
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
+ return BaseModelOutput(
+ last_hidden_state=hidden_states,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attentions,
+ )
+
+
+class ViTMSNPreTrainedModel(PreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = ViTMSNConfig
+ base_model_prefix = "vit"
+ main_input_name = "pixel_values"
+ supports_gradient_checkpointing = True
+
+    # TODO: refer to https://github.com/facebookresearch/msn/blob/main/src/deit.py#L200-#L211
+    # when creating pre-training scripts.
+ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
+ """Initialize the weights"""
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
+ # Slightly different from the TF version which uses truncated_normal for initialization
+ # cf https://github.com/pytorch/pytorch/pull/5617
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+
+ def _set_gradient_checkpointing(self, module: ViTMSNEncoder, value: bool = False) -> None:
+ if isinstance(module, ViTMSNEncoder):
+ module.gradient_checkpointing = value
+
+
+VIT_MSN_START_DOCSTRING = r"""
+ This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
+    as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and
+    behavior.
+
+ Parameters:
+ config ([`ViTMSNConfig`]): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+VIT_MSN_INPUTS_DOCSTRING = r"""
+ Args:
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+ Pixel values. Pixel values can be obtained using [`AutoFeatureExtractor`]. See
+ [`AutoFeatureExtractor.__call__`] for details.
+
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ interpolate_pos_encoding (`bool`, *optional*):
+ Whether to interpolate the pre-trained position encodings.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+@add_start_docstrings(
+ "The bare ViTMSN Model outputting raw hidden-states without any specific head on top.",
+ VIT_MSN_START_DOCSTRING,
+)
+class ViTMSNModel(ViTMSNPreTrainedModel):
+ def __init__(self, config: ViTMSNConfig, use_mask_token: bool = False):
+ super().__init__(config)
+ self.config = config
+
+ self.embeddings = ViTMSNEmbeddings(config, use_mask_token=use_mask_token)
+ self.encoder = ViTMSNEncoder(config)
+
+ self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self) -> ViTMSNPatchEmbeddings:
+ return self.embeddings.patch_embeddings
+
+ def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None:
+ """
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
+ class PreTrainedModel
+ """
+ for layer, heads in heads_to_prune.items():
+ self.encoder.layer[layer].attention.prune_heads(heads)
+
+ @add_start_docstrings_to_model_forward(VIT_MSN_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC)
+ def forward(
+ self,
+ pixel_values: Optional[torch.Tensor] = None,
+ bool_masked_pos: Optional[torch.BoolTensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ interpolate_pos_encoding: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+    ) -> Union[tuple, BaseModelOutput]:
+ r"""
+ Returns:
+
+ Examples:
+
+ ```python
+ >>> from transformers import AutoFeatureExtractor, ViTMSNModel
+ >>> import torch
+ >>> from PIL import Image
+ >>> import requests
+
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/vit-msn-small")
+ >>> model = ViTMSNModel.from_pretrained("facebook/vit-msn-small")
+ >>> inputs = feature_extractor(images=image, return_tensors="pt")
+ >>> with torch.no_grad():
+ ... outputs = model(**inputs)
+ >>> last_hidden_states = outputs.last_hidden_state
+ ```"""
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if pixel_values is None:
+ raise ValueError("You have to specify pixel_values")
+
+ # Prepare head mask if needed
+        # 1.0 in head_mask indicates we keep the head
+ # attention_probs has shape bsz x n_heads x N x N
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
+
+ embedding_output = self.embeddings(
+ pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding
+ )
+
+ encoder_outputs = self.encoder(
+ embedding_output,
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ sequence_output = encoder_outputs[0]
+ sequence_output = self.layernorm(sequence_output)
+
+ if not return_dict:
+ head_outputs = (sequence_output,)
+ return head_outputs + encoder_outputs[1:]
+
+ return BaseModelOutput(
+ last_hidden_state=sequence_output,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ )
+
+
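Since `forward` accepts `interpolate_pos_encoding`, the backbone can run at resolutions other than the pre-training one. A sketch; the 16-pixel patch size and 224x224 pre-training resolution are assumptions about the small checkpoint:

```python
import torch
from transformers import ViTMSNModel

model = ViTMSNModel.from_pretrained("facebook/vit-msn-small")
pixel_values = torch.randn(1, 3, 384, 384)  # larger than the assumed 224x224 pre-training size
with torch.no_grad():
    outputs = model(pixel_values, interpolate_pos_encoding=True)
# sequence length = 1 [CLS] token + (384 // 16) ** 2 patches, assuming patch_size=16
print(outputs.last_hidden_state.shape)
```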
+# Caution: We don't have the weights for the classification head yet. This class
+# is here for users who are interested in fine-tuning the base model (ViTMSNModel).
+@add_start_docstrings(
+ """
+    ViTMSN Model with an image classification head on top (a linear layer on top of the final hidden state of the
+    [CLS] token), e.g. for ImageNet.
+ """,
+ VIT_MSN_START_DOCSTRING,
+)
+class ViTMSNForImageClassification(ViTMSNPreTrainedModel):
+ def __init__(self, config: ViTMSNConfig) -> None:
+ super().__init__(config)
+
+ self.num_labels = config.num_labels
+ self.vit = ViTMSNModel(config)
+
+ # Classifier head
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity()
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @add_start_docstrings_to_model_forward(VIT_MSN_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=ImageClassifierOutput, config_class=_CONFIG_FOR_DOC)
+ def forward(
+ self,
+ pixel_values: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ interpolate_pos_encoding: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[tuple, ImageClassifierOutput]:
+ r"""
+ Returns:
+
+ Examples:
+
+ ```python
+ >>> from transformers import AutoFeatureExtractor, ViTMSNForImageClassification
+ >>> import torch
+ >>> from PIL import Image
+ >>> import requests
+
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/vit-msn-small")
+ >>> model = ViTMSNForImageClassification.from_pretrained("facebook/vit-msn-small")
+
+ >>> inputs = feature_extractor(images=image, return_tensors="pt")
+ >>> with torch.no_grad():
+ ... logits = model(**inputs).logits
+ >>> # model predicts one of the 1000 ImageNet classes
+ >>> predicted_label = logits.argmax(-1).item()
+ >>> print(model.config.id2label[predicted_label])
+ ```"""
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.vit(
+ pixel_values,
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ interpolate_pos_encoding=interpolate_pos_encoding,
+ return_dict=return_dict,
+ )
+
+ sequence_output = outputs[0]
+
+ logits = self.classifier(sequence_output[:, 0, :])
+
+ loss = None
+ if labels is not None:
+ if self.config.problem_type is None:
+ if self.num_labels == 1:
+ self.config.problem_type = "regression"
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+ self.config.problem_type = "single_label_classification"
+ else:
+ self.config.problem_type = "multi_label_classification"
+
+ if self.config.problem_type == "regression":
+ loss_fct = MSELoss()
+ if self.num_labels == 1:
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
+ else:
+ loss = loss_fct(logits, labels)
+ elif self.config.problem_type == "single_label_classification":
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+ elif self.config.problem_type == "multi_label_classification":
+ loss_fct = BCEWithLogitsLoss()
+ loss = loss_fct(logits, labels)
+
+ if not return_dict:
+ output = (logits,) + outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return ImageClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
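Because the released MSN checkpoints ship without classification-head weights (see the caution above), the head is freshly initialized and needs fine-tuning. A minimal sketch, assuming a hypothetical 10-class task and random stand-in data; with integer labels and `num_labels > 1`, the `problem_type` logic above selects `CrossEntropyLoss`:

```python
import torch
from transformers import ViTMSNForImageClassification

# num_labels=10 is a hypothetical task size; since the checkpoint carries no head
# weights, the 10-way classifier below is randomly initialized.
model = ViTMSNForImageClassification.from_pretrained("facebook/vit-msn-small", num_labels=10)
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

pixel_values = torch.randn(8, 3, 224, 224)      # stand-in batch
labels = torch.randint(0, 10, (8,))             # hypothetical integer class labels
loss = model(pixel_values, labels=labels).loss  # CrossEntropyLoss under the hood
loss.backward()
optimizer.step()
```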
diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py
index b656cee9c89bdc..e9f1bae358f3ac 100644
--- a/src/transformers/utils/dummy_pt_objects.py
+++ b/src/transformers/utils/dummy_pt_objects.py
@@ -5150,6 +5150,30 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
+VIT_MSN_PRETRAINED_MODEL_ARCHIVE_LIST = None
+
+
+class ViTMSNForImageClassification(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class ViTMSNModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class ViTMSNPreTrainedModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST = None
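The dummy entries mirror the real ViTMSN symbols so that importing them without PyTorch installed fails lazily, at instantiation time, rather than at import time. An illustrative sketch of that behavior (not the library source):

```python
# Illustrative only: with PyTorch missing, `ViTMSNModel` resolves to the dummy
# class above, and requires_backends raises at construction time.
from transformers.utils import is_torch_available

if not is_torch_available():
    from transformers import ViTMSNModel  # the import itself succeeds

    try:
        ViTMSNModel()  # DummyObject metaclass -> requires_backends(self, ["torch"])
    except ImportError as err:
        print(err)  # message points the user to the PyTorch install instructions
```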
diff --git a/tests/models/vit_msn/__init__.py b/tests/models/vit_msn/__init__.py
new file mode 100644
index 00000000000000..e69de29bb2d1d6
diff --git a/tests/models/vit_msn/test_modeling_vit_msn.py b/tests/models/vit_msn/test_modeling_vit_msn.py
new file mode 100644
index 00000000000000..b858da42b3d4e3
--- /dev/null
+++ b/tests/models/vit_msn/test_modeling_vit_msn.py
@@ -0,0 +1,239 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" Testing suite for the PyTorch ViTMSN model. """
+
+
+import inspect
+import unittest
+
+from transformers import ViTMSNConfig
+from transformers.testing_utils import require_torch, require_vision, slow, torch_device
+from transformers.utils import cached_property, is_torch_available, is_vision_available
+
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
+
+
+if is_torch_available():
+ import torch
+ from torch import nn
+
+ from transformers import ViTMSNForImageClassification, ViTMSNModel
+ from transformers.models.vit_msn.modeling_vit_msn import VIT_MSN_PRETRAINED_MODEL_ARCHIVE_LIST
+
+
+if is_vision_available():
+ from PIL import Image
+
+ from transformers import ViTFeatureExtractor
+
+
+class ViTMSNModelTester:
+ def __init__(
+ self,
+ parent,
+ batch_size=13,
+ image_size=30,
+ patch_size=2,
+ num_channels=3,
+ is_training=True,
+ use_labels=True,
+ hidden_size=32,
+ num_hidden_layers=5,
+ num_attention_heads=4,
+ intermediate_size=37,
+ hidden_act="gelu",
+ hidden_dropout_prob=0.1,
+ attention_probs_dropout_prob=0.1,
+ type_sequence_label_size=10,
+ initializer_range=0.02,
+ scope=None,
+ ):
+ self.parent = parent
+ self.batch_size = batch_size
+ self.image_size = image_size
+ self.patch_size = patch_size
+ self.num_channels = num_channels
+ self.is_training = is_training
+ self.use_labels = use_labels
+ self.hidden_size = hidden_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.intermediate_size = intermediate_size
+ self.hidden_act = hidden_act
+ self.hidden_dropout_prob = hidden_dropout_prob
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
+ self.type_sequence_label_size = type_sequence_label_size
+ self.initializer_range = initializer_range
+ self.scope = scope
+
+ # in ViT MSN, the seq length equals the number of patches + 1 (we add 1 for the [CLS] token)
+ num_patches = (image_size // patch_size) ** 2
+ self.seq_length = num_patches + 1
+
+ def prepare_config_and_inputs(self):
+ pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
+
+ labels = None
+ if self.use_labels:
+ labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
+
+ config = self.get_config()
+
+ return config, pixel_values, labels
+
+ def get_config(self):
+ return ViTMSNConfig(
+ image_size=self.image_size,
+ patch_size=self.patch_size,
+ num_channels=self.num_channels,
+ hidden_size=self.hidden_size,
+ num_hidden_layers=self.num_hidden_layers,
+ num_attention_heads=self.num_attention_heads,
+ intermediate_size=self.intermediate_size,
+ hidden_act=self.hidden_act,
+ hidden_dropout_prob=self.hidden_dropout_prob,
+ attention_probs_dropout_prob=self.attention_probs_dropout_prob,
+ initializer_range=self.initializer_range,
+ )
+
+ def create_and_check_model(self, config, pixel_values, labels):
+ model = ViTMSNModel(config=config)
+ model.to(torch_device)
+ model.eval()
+ result = model(pixel_values)
+ self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
+
+ def create_and_check_for_image_classification(self, config, pixel_values, labels):
+ config.num_labels = self.type_sequence_label_size
+ model = ViTMSNForImageClassification(config)
+ model.to(torch_device)
+ model.eval()
+ result = model(pixel_values, labels=labels)
+        print(f"Pixel values shape: {pixel_values.shape}, labels shape: {labels.shape}")
+        print(f"Labels: {labels}")
+ self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))
+
+ # test greyscale images
+ config.num_channels = 1
+ model = ViTMSNForImageClassification(config)
+ model.to(torch_device)
+ model.eval()
+
+ pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size])
+ result = model(pixel_values)
+ self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))
+
+ def prepare_config_and_inputs_for_common(self):
+ config_and_inputs = self.prepare_config_and_inputs()
+ config, pixel_values, labels = config_and_inputs
+ inputs_dict = {"pixel_values": pixel_values}
+ return config, inputs_dict
+
+
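For the tester's geometry above: with `image_size=30` and `patch_size=2`, the sequence length works out as below (a quick worked check, not part of the test suite):

```python
# Worked check of the tester's seq_length computation.
image_size, patch_size = 30, 2
num_patches = (image_size // patch_size) ** 2  # 15 ** 2 = 225
seq_length = num_patches + 1                   # +1 for the [CLS] token -> 226
assert seq_length == 226
```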
+@require_torch
+class ViTMSNModelTest(ModelTesterMixin, unittest.TestCase):
+ """
+    Here we also overwrite some of the tests of test_modeling_common.py, as ViTMSN does not use input_ids,
+    inputs_embeds, attention_mask and seq_length.
+ """
+
+ all_model_classes = (ViTMSNModel, ViTMSNForImageClassification) if is_torch_available() else ()
+
+ test_pruning = False
+ test_torchscript = False
+ test_resize_embeddings = False
+ test_head_masking = False
+
+ def setUp(self):
+        self.model_tester = ViTMSNModelTester(self)
+ self.config_tester = ConfigTester(self, config_class=ViTMSNConfig, has_text_modality=False, hidden_size=37)
+
+ def test_config(self):
+ self.config_tester.run_common_tests()
+
+    @unittest.skip(reason="ViTMSN does not use inputs_embeds")
+ def test_inputs_embeds(self):
+ pass
+
+ def test_model_common_attributes(self):
+ config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+
+ for model_class in self.all_model_classes:
+ model = model_class(config)
+            self.assertIsInstance(model.get_input_embeddings(), nn.Module)
+ x = model.get_output_embeddings()
+ self.assertTrue(x is None or isinstance(x, nn.Linear))
+
+ def test_forward_signature(self):
+ config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+
+ for model_class in self.all_model_classes:
+ model = model_class(config)
+ signature = inspect.signature(model.forward)
+ # signature.parameters is an OrderedDict => so arg_names order is deterministic
+ arg_names = [*signature.parameters.keys()]
+
+ expected_arg_names = ["pixel_values"]
+ self.assertListEqual(arg_names[:1], expected_arg_names)
+
+ def test_model(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_model(*config_and_inputs)
+
+ def test_for_image_classification(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_for_image_classification(*config_and_inputs)
+
+ @slow
+ def test_model_from_pretrained(self):
+ for model_name in VIT_MSN_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
+ model = ViTMSNModel.from_pretrained(model_name)
+ self.assertIsNotNone(model)
+
+
+# We will verify our results on an image of cute cats
+def prepare_img():
+ image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
+ return image
+
+
+@require_torch
+@require_vision
+class ViTMSNModelIntegrationTest(unittest.TestCase):
+ @cached_property
+ def default_feature_extractor(self):
+        return ViTFeatureExtractor.from_pretrained("facebook/vit-msn-small") if is_vision_available() else None
+
+ @slow
+ def test_inference_image_classification_head(self):
+ torch.manual_seed(2)
+        model = ViTMSNForImageClassification.from_pretrained("facebook/vit-msn-small").to(torch_device)
+
+ feature_extractor = self.default_feature_extractor
+ image = prepare_img()
+ inputs = feature_extractor(images=image, return_tensors="pt").to(torch_device)
+
+ # forward pass
+ with torch.no_grad():
+ outputs = model(**inputs)
+
+ # verify the logits
+ expected_shape = torch.Size((1, 1000))
+ self.assertEqual(outputs.logits.shape, expected_shape)
+
+ expected_slice = torch.tensor([-0.0803, -0.4454, -0.2375]).to(torch_device)
+
+ self.assertTrue(torch.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4))
diff --git a/utils/documentation_tests.txt b/utils/documentation_tests.txt
index eee74d8940f595..3a7dfa55041d73 100644
--- a/utils/documentation_tests.txt
+++ b/utils/documentation_tests.txt
@@ -85,6 +85,7 @@ src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.p
src/transformers/models/vit/modeling_vit.py
src/transformers/models/vit/modeling_tf_vit.py
src/transformers/models/vit_mae/modeling_vit_mae.py
+src/transformers/models/vit_msn/modeling_vit_msn.py
src/transformers/models/wav2vec2/modeling_wav2vec2.py
src/transformers/models/wav2vec2/tokenization_wav2vec2.py
src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py
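Adding the modeling file to `utils/documentation_tests.txt` opts its docstring examples into the doc-test job. One way to exercise them locally (a hedged sketch assuming a standard pytest setup; the exact CI invocation may differ):

```python
# Hypothetical local invocation of the docstring examples via pytest's doctest support.
import subprocess

subprocess.run(
    [
        "python", "-m", "pytest", "--doctest-modules",
        "src/transformers/models/vit_msn/modeling_vit_msn.py",
    ],
    check=True,
)
```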