forked from huggingface/transformers
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
MSN (Masked Siamese Networks) for ViT (huggingface#18815)
* feat: modeling and conversion scripts for msn. * chore: change license year. * chore: remove unneeded modules. * feat: direct loading of state_dict from remote url. * fix: import paths. * add: rest of the files. * add and fix rest of the files. Co-authored-by: Niels <niels.rogge1@gmail.com> * chore: formatting. * code quality fix. * chore: remove pooler. * feat: add classification top. * fix: configuration object. * add: initial test cases (one failing). * fix: basemodeloutput. * add: caution on using the classification head. * add: rest of the model related files. * add: vit msn readme. * fix: copied from statement. * fix: dummy objects. * add: ViTMSNPreTrainedModel to inits. * fix: repo consistency. * minor change in the model doc. * fix: tests. * Empty-Commit * Update src/transformers/models/vit_msn/configuration_vit_msn.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * address PR comments. * Update src/transformers/models/vit_msn/modeling_vit_msn.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * chore: put model in no_grad() and formatting. Co-authored-by: Niels <niels.rogge1@gmail.com> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
- Loading branch information
1 parent
3d257ce
commit c43afb4
Showing
20 changed files
with
1,464 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
<!--Copyright 2022 The HuggingFace Team. All rights reserved. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with | ||
the License. You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on | ||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the | ||
specific language governing permissions and limitations under the License. | ||
--> | ||
|
||
# ViTMSN | ||
|
||
## Overview | ||
|
||
The ViTMSN model was proposed in [Masked Siamese Networks for Label-Efficient Learning](https://arxiv.org/abs/2204.07141) by Mahmoud Assran, Mathilde Caron, Ishan Misra, Piotr Bojanowski, Florian Bordes, | ||
Pascal Vincent, Armand Joulin, Michael Rabbat, Nicolas Ballas. The paper presents a joint-embedding architecture to match the prototypes | ||
of masked patches with that of the unmasked patches. With this setup, their method yields excellent performance in the low-shot and extreme low-shot | ||
regimes. | ||
|
||
The abstract from the paper is the following: | ||
|
||
*We propose Masked Siamese Networks (MSN), a self-supervised learning framework for learning image representations. Our | ||
approach matches the representation of an image view containing randomly masked patches to the representation of the original | ||
unmasked image. This self-supervised pre-training strategy is particularly scalable when applied to Vision Transformers since only the | ||
unmasked patches are processed by the network. As a result, MSNs improve the scalability of joint-embedding architectures, | ||
while producing representations of a high semantic level that perform competitively on low-shot image classification. For instance, | ||
on ImageNet-1K, with only 5,000 annotated images, our base MSN model achieves 72.4% top-1 accuracy, | ||
and with 1% of ImageNet-1K labels, we achieve 75.7% top-1 accuracy, setting a new state-of-the-art for self-supervised learning on this benchmark.* | ||
|
||
Tips: | ||
|
||
- MSN (masked siamese networks) is a method for self-supervised pre-training of Vision Transformers (ViTs). The pre-training | ||
objective is to match the prototypes assigned to the unmasked views of the images to that of the masked views of the same images. | ||
- The authors have only released pre-trained weights of the backbone (ImageNet-1k pre-training). So, to use that on your own image classification dataset, | ||
use the [`ViTMSNForImageClassification`] class which is initialized from [`ViTMSNModel`]. Follow | ||
[this notebook](https://github.com/huggingface/notebooks/blob/main/examples/image_classification.ipynb) for a detailed tutorial on fine-tuning. | ||
- MSN is particularly useful in the low-shot and extreme low-shot regimes. Notably, it achieves 75.7% top-1 accuracy with only 1% of ImageNet-1K | ||
labels when fine-tuned. | ||
|
||
|
||
<img src="https://i.ibb.co/W6PQMdC/Screenshot-2022-09-13-at-9-08-40-AM.png" alt="drawing" width="600"/> | ||
|
||
<small> MSN architecture. Taken from the <a href="https://arxiv.org/abs/2204.07141">original paper.</a> </small> | ||
|
||
This model was contributed by [sayakpaul](https://huggingface.co/sayakpaul). The original code can be found [here](https://github.com/facebookresearch/msn). | ||
|
||
|
||
## ViTMSNConfig | ||
|
||
[[autodoc]] ViTMSNConfig | ||
|
||
|
||
## ViTMSNModel | ||
|
||
[[autodoc]] ViTMSNModel | ||
- forward | ||
|
||
|
||
## ViTMSNForImageClassification | ||
|
||
[[autodoc]] ViTMSNForImageClassification | ||
- forward |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -150,6 +150,7 @@ | |
visual_bert, | ||
vit, | ||
vit_mae, | ||
vit_msn, | ||
wav2vec2, | ||
wav2vec2_conformer, | ||
wav2vec2_phoneme, | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.