From b738b6fbc687dec05fcfc7b26003f93a1ecd219b Mon Sep 17 00:00:00 2001 From: Ritik Nandwal <48522685+nandwalritik@users.noreply.github.com> Date: Wed, 27 Jul 2022 20:44:47 +0530 Subject: [PATCH] Add swin transformer v2 (#17469) * Add files generated using transformer-cli add-new-model-like command * Add changes for swinv2 attention and forward method * Add fixes * Add modifications for weight conversion and remaining args in swin model * Add changes for patchmerging * Add changes for SwinV2selfattention * Update conversion script * Add final fixes for the swin_v2 model * Add changes for conversion script for pretrained window size case * Add pretrained window size value from config in SwinV2Encoder class * Make fixup * Add swinv2 to models_not_in_readme to utils/check_copies.py * Modify Swinv2v2 to Swin Transformer V2 * Remove copied from, to run make fixup command * Add updates to swinv2tf from main branch * Add pretrained_window_size to config, to make tests pass * Add modified weights from nandwalritik profile for swinv2 * Update model weights from swinv2 from nandwalritik profile * Add fix for build_pr_documentation CI fix * Add fixes for weight conversion * Add change to make input with padding work * Add fixes for test cases * Add few changes from swin to swinv2 to pass test cases * Remove tests for tensorflow as swinv2 for TF is not added yet * Overide test_pt_tf_model_equivalence function as TF implementation for swinv2 is not added yet * Add modeling_tf_swinv2 to _ignore_modules as test file is removed for this one right now. * Update docs url for swinv2 in README.md Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * Undo changes for check_repo * Update url in readme.md * Remove overrided function to test pt_tf_model_equivalence * Remove TF model imports for Swinv2 as its not implemented in this PR * Add changes for index.mdx * Add swinv2 papers link,abstract and contributors details * Rename cpb_mlp to continous_position_bias_mlp * Add tips for swinv2 model * Update src/transformers/models/swinv2/configuration_swinv2.py Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * Update src/transformers/models/swinv2/configuration_swinv2.py Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * Fix indentation for docstring example in src/transformers/models/swinv2/configuration_swinv2.py Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * Update import order in src/transformers/models/swinv2/configuration_swinv2.py Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * Add copyright statements in weights conversion script. Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * Remove Swinv2 from models_not_in_readme * Reformat code * Remove TF implementation file for swinv2 * Update start docstring. Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * Add changes for docstring * Update orgname for weights to microsoft * Remove to_2tuple function * Add copied from statements wherever applicable * Add copied from to Swinv2ForMaskedImageModelling class * Reformat code. Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * Add unittest.skip(with reason.) for test_inputs_embeds test case. Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * Add updates for test_modeling_swinv2.py * Add @unittest.skip() annotation for clarity to create_and_test_config_common_properties function * Add continuous_position_bias_mlp parameter to conversion script * Add test for testing masked_image_modelling for swinv2 * Update Swinv2 to Swin Transformer v2 in docs/source/en/model_doc/swinv2.mdx Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * Update Swinv2 to Swin Transformer v2 in docs/source/en/model_doc/swinv2.mdx Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * Update docs/source/en/model_doc/swinv2.mdx Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * Update docs/source/en/model_doc/swinv2.mdx Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * Add suggested changes * Add copied from to forward methods of Swinv2Stage and Swinv2Encoder * Add push_to_hub flag to weight conversion script * Change order or Swinv2DropPath class * Add id2label mapping for imagenet 21k * Add updated url for SwinV2 functions and classes used in implementation * Update input_feature dimensions format, mentioned in comments. Co-authored-by: Alara Dirik <8944735+alaradirik@users.noreply.github.com> * Add suggested changes for modeling_swin2.py * Update docs * Remove create_and_test_config_common_properties function, as test_model_common_attributes is sufficient. * Fix indentation. Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Add changes for making Nit objects in code style * Add suggested changes * Add suggested changes for test_modelling_swinv2 * make fix-copies * Update docs/source/en/model_doc/swinv2.mdx Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> Co-authored-by: Alara Dirik <8944735+alaradirik@users.noreply.github.com> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> --- README.md | 1 + README_ko.md | 1 + README_zh-hans.md | 1 + README_zh-hant.md | 1 + docs/source/en/_toctree.yml | 2 + docs/source/en/index.mdx | 2 + docs/source/en/model_doc/swinv2.mdx | 47 + src/transformers/__init__.py | 18 + src/transformers/models/__init__.py | 1 + .../models/auto/configuration_auto.py | 3 + .../models/auto/feature_extraction_auto.py | 1 + src/transformers/models/auto/modeling_auto.py | 3 + src/transformers/models/swinv2/__init__.py | 65 + .../models/swinv2/configuration_swinv2.py | 147 ++ .../swinv2/convert_swinv2_timm_to_pytorch.py | 219 +++ .../models/swinv2/modeling_swinv2.py | 1292 +++++++++++++++++ src/transformers/utils/dummy_pt_objects.py | 31 + tests/models/swinv2/__init__.py | 0 tests/models/swinv2/test_modeling_swinv2.py | 430 ++++++ 19 files changed, 2265 insertions(+) create mode 100644 docs/source/en/model_doc/swinv2.mdx create mode 100644 src/transformers/models/swinv2/__init__.py create mode 100644 src/transformers/models/swinv2/configuration_swinv2.py create mode 100644 src/transformers/models/swinv2/convert_swinv2_timm_to_pytorch.py create mode 100644 src/transformers/models/swinv2/modeling_swinv2.py create mode 100644 tests/models/swinv2/__init__.py create mode 100644 tests/models/swinv2/test_modeling_swinv2.py diff --git a/README.md b/README.md index 6e220fb86bbf52..99c90a3916ddda 100644 --- a/README.md +++ b/README.md @@ -356,6 +356,7 @@ Current number of checkpoints: ![](https://img.shields.io/endpoint?url=https://h 1. **[Splinter](https://huggingface.co/docs/transformers/model_doc/splinter)** (from Tel Aviv University), released together with the paper [Few-Shot Question Answering by Pretraining Span Selection](https://arxiv.org/abs/2101.00438) by Ori Ram, Yuval Kirstain, Jonathan Berant, Amir Globerson, Omer Levy. 1. **[SqueezeBERT](https://huggingface.co/docs/transformers/model_doc/squeezebert)** (from Berkeley) released with the paper [SqueezeBERT: What can computer vision teach NLP about efficient neural networks?](https://arxiv.org/abs/2006.11316) by Forrest N. Iandola, Albert E. Shaw, Ravi Krishna, and Kurt W. Keutzer. 1. **[Swin Transformer](https://huggingface.co/docs/transformers/model_doc/swin)** (from Microsoft) released with the paper [Swin Transformer: Hierarchical Vision Transformer using Shifted Windows](https://arxiv.org/abs/2103.14030) by Ze Liu, Yutong Lin, Yue Cao, Han Hu, Yixuan Wei, Zheng Zhang, Stephen Lin, Baining Guo. +1. **[Swin Transformer V2](https://huggingface.co/docs/transformers/main/model_doc/swinv2)** (from Microsoft) released with the paper [Swin Transformer V2: Scaling Up Capacity and Resolution](https://arxiv.org/abs/2111.09883) by Ze Liu, Han Hu, Yutong Lin, Zhuliang Yao, Zhenda Xie, Yixuan Wei, Jia Ning, Yue Cao, Zheng Zhang, Li Dong, Furu Wei, Baining Guo. 1. **[T5](https://huggingface.co/docs/transformers/model_doc/t5)** (from Google AI) released with the paper [Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer](https://arxiv.org/abs/1910.10683) by Colin Raffel and Noam Shazeer and Adam Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu. 1. **[T5v1.1](https://huggingface.co/docs/transformers/model_doc/t5v1.1)** (from Google AI) released in the repository [google-research/text-to-text-transfer-transformer](https://github.com/google-research/text-to-text-transfer-transformer/blob/main/released_checkpoints.md#t511) by Colin Raffel and Noam Shazeer and Adam Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu. 1. **[TAPAS](https://huggingface.co/docs/transformers/model_doc/tapas)** (from Google AI) released with the paper [TAPAS: Weakly Supervised Table Parsing via Pre-training](https://arxiv.org/abs/2004.02349) by Jonathan Herzig, Paweł Krzysztof Nowak, Thomas Müller, Francesco Piccinno and Julian Martin Eisenschlos. diff --git a/README_ko.md b/README_ko.md index 74c6b2d2bbd19e..adfaefddf628ca 100644 --- a/README_ko.md +++ b/README_ko.md @@ -312,6 +312,7 @@ Flax, PyTorch, TensorFlow 설치 페이지에서 이들을 conda로 설치하는 1. **[Splinter](https://huggingface.co/docs/transformers/model_doc/splinter)** (from Tel Aviv University), released together with the paper [Few-Shot Question Answering by Pretraining Span Selection](https://arxiv.org/abs/2101.00438) by Ori Ram, Yuval Kirstain, Jonathan Berant, Amir Globerson, Omer Levy. 1. **[SqueezeBERT](https://huggingface.co/docs/transformers/model_doc/squeezebert)** (from Berkeley) released with the paper [SqueezeBERT: What can computer vision teach NLP about efficient neural networks?](https://arxiv.org/abs/2006.11316) by Forrest N. Iandola, Albert E. Shaw, Ravi Krishna, and Kurt W. Keutzer. 1. **[Swin Transformer](https://huggingface.co/docs/transformers/model_doc/swin)** (from Microsoft) released with the paper [Swin Transformer: Hierarchical Vision Transformer using Shifted Windows](https://arxiv.org/abs/2103.14030) by Ze Liu, Yutong Lin, Yue Cao, Han Hu, Yixuan Wei, Zheng Zhang, Stephen Lin, Baining Guo. +1. **[Swin Transformer V2](https://huggingface.co/docs/transformers/main/model_doc/swinv2)** (from Microsoft) released with the paper [Swin Transformer V2: Scaling Up Capacity and Resolution](https://arxiv.org/abs/2111.09883) by Ze Liu, Han Hu, Yutong Lin, Zhuliang Yao, Zhenda Xie, Yixuan Wei, Jia Ning, Yue Cao, Zheng Zhang, Li Dong, Furu Wei, Baining Guo. 1. **[T5](https://huggingface.co/docs/transformers/model_doc/t5)** (from Google AI) released with the paper [Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer](https://arxiv.org/abs/1910.10683) by Colin Raffel and Noam Shazeer and Adam Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu. 1. **[T5v1.1](https://huggingface.co/docs/transformers/model_doc/t5v1.1)** (from Google AI) released in the repository [google-research/text-to-text-transfer-transformer](https://github.com/google-research/text-to-text-transfer-transformer/blob/main/released_checkpoints.md#t511) by Colin Raffel and Noam Shazeer and Adam Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu. 1. **[TAPAS](https://huggingface.co/docs/transformers/model_doc/tapas)** (from Google AI) released with the paper [TAPAS: Weakly Supervised Table Parsing via Pre-training](https://arxiv.org/abs/2004.02349) by Jonathan Herzig, Paweł Krzysztof Nowak, Thomas Müller, Francesco Piccinno and Julian Martin Eisenschlos. diff --git a/README_zh-hans.md b/README_zh-hans.md index 5088b8ca63aaab..0e51441b407b44 100644 --- a/README_zh-hans.md +++ b/README_zh-hans.md @@ -336,6 +336,7 @@ conda install -c huggingface transformers 1. **[Splinter](https://huggingface.co/docs/transformers/model_doc/splinter)** (来自 Tel Aviv University) 伴随论文 [Few-Shot Question Answering by Pretraining Span Selection](https://arxiv.org/abs/2101.00438) 由 Ori Ram, Yuval Kirstain, Jonathan Berant, Amir Globerson, Omer Levy 发布。 1. **[SqueezeBERT](https://huggingface.co/docs/transformers/model_doc/squeezebert)** (来自 Berkeley) 伴随论文 [SqueezeBERT: What can computer vision teach NLP about efficient neural networks?](https://arxiv.org/abs/2006.11316) 由 Forrest N. Iandola, Albert E. Shaw, Ravi Krishna, and Kurt W. Keutzer 发布。 1. **[Swin Transformer](https://huggingface.co/docs/transformers/model_doc/swin)** (来自 Microsoft) 伴随论文 [Swin Transformer: Hierarchical Vision Transformer using Shifted Windows](https://arxiv.org/abs/2103.14030) 由 Ze Liu, Yutong Lin, Yue Cao, Han Hu, Yixuan Wei, Zheng Zhang, Stephen Lin, Baining Guo 发布。 +1. **[Swin Transformer V2](https://huggingface.co/docs/transformers/main/model_doc/swinv2)** (来自 Microsoft) 伴随论文 [Swin Transformer V2: Scaling Up Capacity and Resolution](https://arxiv.org/abs/2111.09883) 由 Ze Liu, Han Hu, Yutong Lin, Zhuliang Yao, Zhenda Xie, Yixuan Wei, Jia Ning, Yue Cao, Zheng Zhang, Li Dong, Furu Wei, Baining Guo 发布。 1. **[T5](https://huggingface.co/docs/transformers/model_doc/t5)** (来自 Google AI) 伴随论文 [Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer](https://arxiv.org/abs/1910.10683) 由 Colin Raffel and Noam Shazeer and Adam Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu 发布。 1. **[T5v1.1](https://huggingface.co/docs/transformers/model_doc/t5v1.1)** (来自 Google AI) 伴随论文 [google-research/text-to-text-transfer-transformer](https://github.com/google-research/text-to-text-transfer-transformer/blob/main/released_checkpoints.md#t511) 由 Colin Raffel and Noam Shazeer and Adam Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu 发布。 1. **[TAPAS](https://huggingface.co/docs/transformers/model_doc/tapas)** (来自 Google AI) 伴随论文 [TAPAS: Weakly Supervised Table Parsing via Pre-training](https://arxiv.org/abs/2004.02349) 由 Jonathan Herzig, Paweł Krzysztof Nowak, Thomas Müller, Francesco Piccinno and Julian Martin Eisenschlos 发布。 diff --git a/README_zh-hant.md b/README_zh-hant.md index 67b5c4c076640a..1fbff9fa1741bd 100644 --- a/README_zh-hant.md +++ b/README_zh-hant.md @@ -348,6 +348,7 @@ conda install -c huggingface transformers 1. **[Splinter](https://huggingface.co/docs/transformers/model_doc/splinter)** (from Tel Aviv University) released with the paper [Few-Shot Question Answering by Pretraining Span Selection](https://arxiv.org/abs/2101.00438) by Ori Ram, Yuval Kirstain, Jonathan Berant, Amir Globerson, Omer Levy. 1. **[SqueezeBERT](https://huggingface.co/docs/transformers/model_doc/squeezebert)** (from Berkeley) released with the paper [SqueezeBERT: What can computer vision teach NLP about efficient neural networks?](https://arxiv.org/abs/2006.11316) by Forrest N. Iandola, Albert E. Shaw, Ravi Krishna, and Kurt W. Keutzer. 1. **[Swin Transformer](https://huggingface.co/docs/transformers/model_doc/swin)** (from Microsoft) released with the paper [Swin Transformer: Hierarchical Vision Transformer using Shifted Windows](https://arxiv.org/abs/2103.14030) by Ze Liu, Yutong Lin, Yue Cao, Han Hu, Yixuan Wei, Zheng Zhang, Stephen Lin, Baining Guo. +1. **[Swin Transformer V2](https://huggingface.co/docs/transformers/main/model_doc/swinv2)** (from Microsoft) released with the paper [Swin Transformer V2: Scaling Up Capacity and Resolution](https://arxiv.org/abs/2111.09883) by Ze Liu, Han Hu, Yutong Lin, Zhuliang Yao, Zhenda Xie, Yixuan Wei, Jia Ning, Yue Cao, Zheng Zhang, Li Dong, Furu Wei, Baining Guo. 1. **[T5](https://huggingface.co/docs/transformers/model_doc/t5)** (from Google AI) released with the paper [Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer](https://arxiv.org/abs/1910.10683) by Colin Raffel and Noam Shazeer and Adam Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu. 1. **[T5v1.1](https://huggingface.co/docs/transformers/model_doc/t5v1.1)** (from Google AI) released with the paper [google-research/text-to-text-transfer-transformer](https://github.com/google-research/text-to-text-transfer-transformer/blob/main/released_checkpoints.md#t511) by Colin Raffel and Noam Shazeer and Adam Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu. 1. **[TAPAS](https://huggingface.co/docs/transformers/model_doc/tapas)** (from Google AI) released with the paper [TAPAS: Weakly Supervised Table Parsing via Pre-training](https://arxiv.org/abs/2004.02349) by Jonathan Herzig, Paweł Krzysztof Nowak, Thomas Müller, Francesco Piccinno and Julian Martin Eisenschlos. diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 72bc4f2093cf8d..0e552d9dcc43db 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -378,6 +378,8 @@ title: SqueezeBERT - local: model_doc/swin title: Swin Transformer + - local: model_doc/swinv2 + title: Swin Transformer V2 - local: model_doc/t5 title: T5 - local: model_doc/t5v1.1 diff --git a/docs/source/en/index.mdx b/docs/source/en/index.mdx index c1e174c3f537cf..e8c3ed2928a7c6 100644 --- a/docs/source/en/index.mdx +++ b/docs/source/en/index.mdx @@ -154,6 +154,7 @@ The library currently contains JAX, PyTorch and TensorFlow implementations, pret 1. **[Splinter](model_doc/splinter)** (from Tel Aviv University), released together with the paper [Few-Shot Question Answering by Pretraining Span Selection](https://arxiv.org/abs/2101.00438) by Ori Ram, Yuval Kirstain, Jonathan Berant, Amir Globerson, Omer Levy. 1. **[SqueezeBERT](model_doc/squeezebert)** (from Berkeley) released with the paper [SqueezeBERT: What can computer vision teach NLP about efficient neural networks?](https://arxiv.org/abs/2006.11316) by Forrest N. Iandola, Albert E. Shaw, Ravi Krishna, and Kurt W. Keutzer. 1. **[Swin Transformer](model_doc/swin)** (from Microsoft) released with the paper [Swin Transformer: Hierarchical Vision Transformer using Shifted Windows](https://arxiv.org/abs/2103.14030) by Ze Liu, Yutong Lin, Yue Cao, Han Hu, Yixuan Wei, Zheng Zhang, Stephen Lin, Baining Guo. +1. **[Swin Transformer V2](model_doc/swinv2)** (from Microsoft) released with the paper [Swin Transformer V2: Scaling Up Capacity and Resolution](https://arxiv.org/abs/2111.09883) by Ze Liu, Han Hu, Yutong Lin, Zhuliang Yao, Zhenda Xie, Yixuan Wei, Jia Ning, Yue Cao, Zheng Zhang, Li Dong, Furu Wei, Baining Guo. 1. **[T5](model_doc/t5)** (from Google AI) released with the paper [Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer](https://arxiv.org/abs/1910.10683) by Colin Raffel and Noam Shazeer and Adam Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu. 1. **[T5v1.1](model_doc/t5v1.1)** (from Google AI) released in the repository [google-research/text-to-text-transfer-transformer](https://github.com/google-research/text-to-text-transfer-transformer/blob/main/released_checkpoints.md#t511) by Colin Raffel and Noam Shazeer and Adam Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu. 1. **[TAPAS](model_doc/tapas)** (from Google AI) released with the paper [TAPAS: Weakly Supervised Table Parsing via Pre-training](https://arxiv.org/abs/2004.02349) by Jonathan Herzig, Paweł Krzysztof Nowak, Thomas Müller, Francesco Piccinno and Julian Martin Eisenschlos. @@ -289,6 +290,7 @@ Flax), PyTorch, and/or TensorFlow. | Splinter | ✅ | ✅ | ✅ | ❌ | ❌ | | SqueezeBERT | ✅ | ✅ | ✅ | ❌ | ❌ | | Swin Transformer | ❌ | ❌ | ✅ | ✅ | ❌ | +| Swin Transformer V2 | ❌ | ❌ | ✅ | ❌ | ❌ | | T5 | ✅ | ✅ | ✅ | ✅ | ✅ | | TAPAS | ✅ | ❌ | ✅ | ✅ | ❌ | | Trajectory Transformer | ❌ | ❌ | ✅ | ❌ | ❌ | diff --git a/docs/source/en/model_doc/swinv2.mdx b/docs/source/en/model_doc/swinv2.mdx new file mode 100644 index 00000000000000..9f91a265ed100f --- /dev/null +++ b/docs/source/en/model_doc/swinv2.mdx @@ -0,0 +1,47 @@ + + +# Swin Transformer V2 + +## Overview + +The Swin Transformer V2 model was proposed in [Swin Transformer V2: Scaling Up Capacity and Resolution](https://arxiv.org/abs/2111.09883) by Ze Liu, Han Hu, Yutong Lin, Zhuliang Yao, Zhenda Xie, Yixuan Wei, Jia Ning, Yue Cao, Zheng Zhang, Li Dong, Furu Wei, Baining Guo. + +The abstract from the paper is the following: + +*Large-scale NLP models have been shown to significantly improve the performance on language tasks with no signs of saturation. They also demonstrate amazing few-shot capabilities like that of human beings. This paper aims to explore large-scale models in computer vision. We tackle three major issues in training and application of large vision models, including training instability, resolution gaps between pre-training and fine-tuning, and hunger on labelled data. Three main techniques are proposed: 1) a residual-post-norm method combined with cosine attention to improve training stability; 2) A log-spaced continuous position bias method to effectively transfer models pre-trained using low-resolution images to downstream tasks with high-resolution inputs; 3) A self-supervised pre-training method, SimMIM, to reduce the needs of vast labeled images. Through these techniques, this paper successfully trained a 3 billion-parameter Swin Transformer V2 model, which is the largest dense vision model to date, and makes it capable of training with images of up to 1,536×1,536 resolution. It set new performance records on 4 representative vision tasks, including ImageNet-V2 image classification, COCO object detection, ADE20K semantic segmentation, and Kinetics-400 video action classification. Also note our training is much more efficient than that in Google's billion-level visual models, which consumes 40 times less labelled data and 40 times less training time.* + +Tips: +- One can use the [`AutoFeatureExtractor`] API to prepare images for the model. + +This model was contributed by [nandwalritik](https://huggingface.co/nandwalritik). +The original code can be found [here](https://github.com/microsoft/Swin-Transformer). + + +## Swinv2Config + +[[autodoc]] Swinv2Config + +## Swinv2Model + +[[autodoc]] Swinv2Model + - forward + +## Swinv2ForMaskedImageModeling + +[[autodoc]] Swinv2ForMaskedImageModeling + - forward + +## Swinv2ForImageClassification + +[[autodoc]] transformers.Swinv2ForImageClassification + - forward diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 6b2e7a3847e559..1c0d86e567617b 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -313,6 +313,7 @@ "models.splinter": ["SPLINTER_PRETRAINED_CONFIG_ARCHIVE_MAP", "SplinterConfig", "SplinterTokenizer"], "models.squeezebert": ["SQUEEZEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "SqueezeBertConfig", "SqueezeBertTokenizer"], "models.swin": ["SWIN_PRETRAINED_CONFIG_ARCHIVE_MAP", "SwinConfig"], + "models.swinv2": ["SWINV2_PRETRAINED_CONFIG_ARCHIVE_MAP", "Swinv2Config"], "models.t5": ["T5_PRETRAINED_CONFIG_ARCHIVE_MAP", "T5Config"], "models.tapas": ["TAPAS_PRETRAINED_CONFIG_ARCHIVE_MAP", "TapasConfig", "TapasTokenizer"], "models.tapex": ["TapexTokenizer"], @@ -1750,6 +1751,15 @@ "SwinPreTrainedModel", ] ) + _import_structure["models.swinv2"].extend( + [ + "SWINV2_PRETRAINED_MODEL_ARCHIVE_LIST", + "Swinv2ForImageClassification", + "Swinv2ForMaskedImageModeling", + "Swinv2Model", + "Swinv2PreTrainedModel", + ] + ) _import_structure["models.t5"].extend( [ "T5_PRETRAINED_MODEL_ARCHIVE_LIST", @@ -3068,6 +3078,7 @@ from .models.splinter import SPLINTER_PRETRAINED_CONFIG_ARCHIVE_MAP, SplinterConfig, SplinterTokenizer from .models.squeezebert import SQUEEZEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, SqueezeBertConfig, SqueezeBertTokenizer from .models.swin import SWIN_PRETRAINED_CONFIG_ARCHIVE_MAP, SwinConfig + from .models.swinv2 import SWINV2_PRETRAINED_CONFIG_ARCHIVE_MAP, Swinv2Config from .models.t5 import T5_PRETRAINED_CONFIG_ARCHIVE_MAP, T5Config from .models.tapas import TAPAS_PRETRAINED_CONFIG_ARCHIVE_MAP, TapasConfig, TapasTokenizer from .models.tapex import TapexTokenizer @@ -4265,6 +4276,13 @@ SwinModel, SwinPreTrainedModel, ) + from .models.swinv2 import ( + SWINV2_PRETRAINED_MODEL_ARCHIVE_LIST, + Swinv2ForImageClassification, + Swinv2ForMaskedImageModeling, + Swinv2Model, + Swinv2PreTrainedModel, + ) from .models.t5 import ( T5_PRETRAINED_MODEL_ARCHIVE_LIST, T5EncoderModel, diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index 663db063d68e78..1b81ce7d8fab1d 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -127,6 +127,7 @@ splinter, squeezebert, swin, + swinv2, t5, tapas, tapex, diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index b8d5bbf223918e..13f69c024a9a0d 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -123,6 +123,7 @@ ("splinter", "SplinterConfig"), ("squeezebert", "SqueezeBertConfig"), ("swin", "SwinConfig"), + ("swinv2", "Swinv2Config"), ("t5", "T5Config"), ("tapas", "TapasConfig"), ("trajectory_transformer", "TrajectoryTransformerConfig"), @@ -239,6 +240,7 @@ ("splinter", "SPLINTER_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("squeezebert", "SQUEEZEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("swin", "SWIN_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("swinv2", "SWINV2_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("t5", "T5_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("tapas", "TAPAS_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("transfo-xl", "TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP"), @@ -374,6 +376,7 @@ ("splinter", "Splinter"), ("squeezebert", "SqueezeBERT"), ("swin", "Swin Transformer"), + ("swinv2", "Swin Transformer V2"), ("t5", "T5"), ("t5v1.1", "T5v1.1"), ("tapas", "TAPAS"), diff --git a/src/transformers/models/auto/feature_extraction_auto.py b/src/transformers/models/auto/feature_extraction_auto.py index 57def4b27396aa..8c4564e261c616 100644 --- a/src/transformers/models/auto/feature_extraction_auto.py +++ b/src/transformers/models/auto/feature_extraction_auto.py @@ -66,6 +66,7 @@ ("segformer", "SegformerFeatureExtractor"), ("speech_to_text", "Speech2TextFeatureExtractor"), ("swin", "ViTFeatureExtractor"), + ("swinv2", "ViTFeatureExtractor"), ("van", "ConvNextFeatureExtractor"), ("vilt", "ViltFeatureExtractor"), ("vit", "ViTFeatureExtractor"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index d443d71ca5f72a..ce2c4f9457b3d0 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -119,6 +119,7 @@ ("splinter", "SplinterModel"), ("squeezebert", "SqueezeBertModel"), ("swin", "SwinModel"), + ("swinv2", "Swinv2Model"), ("t5", "T5Model"), ("tapas", "TapasModel"), ("trajectory_transformer", "TrajectoryTransformerModel"), @@ -309,6 +310,7 @@ [ ("deit", "DeiTForMaskedImageModeling"), ("swin", "SwinForMaskedImageModeling"), + ("swinv2", "Swinv2ForMaskedImageModeling"), ("vit", "ViTForMaskedImageModeling"), ] ) @@ -345,6 +347,7 @@ ("resnet", "ResNetForImageClassification"), ("segformer", "SegformerForImageClassification"), ("swin", "SwinForImageClassification"), + ("swinv2", "Swinv2ForImageClassification"), ("van", "VanForImageClassification"), ("vit", "ViTForImageClassification"), ] diff --git a/src/transformers/models/swinv2/__init__.py b/src/transformers/models/swinv2/__init__.py new file mode 100644 index 00000000000000..1cf259b8303e99 --- /dev/null +++ b/src/transformers/models/swinv2/__init__.py @@ -0,0 +1,65 @@ +# flake8: noqa +# There's no way to ignore "F401 '...' imported but unused" warnings in this +# module, but to preserve other warnings. So, don't check this module at all. + +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +# rely on isort to merge the imports +from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available + + +_import_structure = { + "configuration_swinv2": ["SWINV2_PRETRAINED_CONFIG_ARCHIVE_MAP", "Swinv2Config"], +} + + +try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["modeling_swinv2"] = [ + "SWINV2_PRETRAINED_MODEL_ARCHIVE_LIST", + "Swinv2ForImageClassification", + "Swinv2ForMaskedImageModeling", + "Swinv2Model", + "Swinv2PreTrainedModel", + ] + + +if TYPE_CHECKING: + from .configuration_swinv2 import SWINV2_PRETRAINED_CONFIG_ARCHIVE_MAP, Swinv2Config + + try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .modeling_swinv2 import ( + SWINV2_PRETRAINED_MODEL_ARCHIVE_LIST, + Swinv2ForImageClassification, + Swinv2ForMaskedImageModeling, + Swinv2Model, + Swinv2PreTrainedModel, + ) + + +else: + import sys + + sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) diff --git a/src/transformers/models/swinv2/configuration_swinv2.py b/src/transformers/models/swinv2/configuration_swinv2.py new file mode 100644 index 00000000000000..f861be05fe1f85 --- /dev/null +++ b/src/transformers/models/swinv2/configuration_swinv2.py @@ -0,0 +1,147 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Swinv2 Transformer model configuration""" + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + +SWINV2_PRETRAINED_CONFIG_ARCHIVE_MAP = { + "microsoft/swinv2_tiny_patch4_windows8_256": ( + "https://huggingface.co/microsoft/swinv2_tiny_patch4_windows8_256/resolve/main/config.json" + ), +} + + +class Swinv2Config(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Swinv2Model`]. It is used to instantiate a Swin + Transformer v2 model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the Swin Transformer v2 + [microsoft/swinv2_tiny_patch4_windows8_256](https://huggingface.co/microsoft/swinv2_tiny_patch4_windows8_256) + architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + image_size (`int`, *optional*, defaults to 224): + The size (resolution) of each image. + patch_size (`int`, *optional*, defaults to 4): + The size (resolution) of each patch. + num_channels (`int`, *optional*, defaults to 3): + The number of input channels. + embed_dim (`int`, *optional*, defaults to 96): + Dimensionality of patch embedding. + depths (`list(int)`, *optional*, defaults to `[2, 2, 6, 2]`): + Depth of each layer in the Transformer encoder. + num_heads (`list(int)`, *optional*, defaults to `[3, 6, 12, 24]`): + Number of attention heads in each layer of the Transformer encoder. + window_size (`int`, *optional*, defaults to 7): + Size of windows. + mlp_ratio (`float`, *optional*, defaults to 4.0): + Ratio of MLP hidden dimensionality to embedding dimensionality. + qkv_bias (`bool`, *optional*, defaults to `True`): + Whether or not a learnable bias should be added to the queries, keys and values. + hidden_dropout_prob (`float`, *optional*, defaults to 0.0): + The dropout probability for all fully connected layers in the embeddings and encoder. + attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + drop_path_rate (`float`, *optional*, defaults to 0.1): + Stochastic depth rate. + hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder. If string, `"gelu"`, `"relu"`, + `"selu"` and `"gelu_new"` are supported. + use_absolute_embeddings (`bool`, *optional*, defaults to `False`): + Whether or not to add absolute position embeddings to the patch embeddings. + patch_norm (`bool`, *optional*, defaults to `True`): + Whether or not to add layer normalization after patch embedding. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layer_norm_eps (`float`, *optional*, defaults to 1e-12): + The epsilon used by the layer normalization layers. + encoder_stride (`int`, `optional`, defaults to 32): + Factor to increase the spatial resolution by in the decoder head for masked image modeling. + + Example: + + ```python + >>> from transformers import Swinv2Config, Swinv2Model + + >>> # Initializing a Swinv2 microsoft/swinv2_tiny_patch4_windows8_256 style configuration + >>> configuration = Swinv2Config() + + >>> # Initializing a model from the microsoft/swinv2_tiny_patch4_windows8_256 style configuration + >>> model = Swinv2Model(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + model_type = "swinv2" + + attribute_map = { + "num_attention_heads": "num_heads", + "num_hidden_layers": "num_layers", + } + + def __init__( + self, + image_size=224, + patch_size=4, + num_channels=3, + embed_dim=96, + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 24], + window_size=7, + mlp_ratio=4.0, + qkv_bias=True, + hidden_dropout_prob=0.0, + attention_probs_dropout_prob=0.0, + drop_path_rate=0.1, + hidden_act="gelu", + use_absolute_embeddings=False, + patch_norm=True, + initializer_range=0.02, + layer_norm_eps=1e-5, + encoder_stride=32, + **kwargs + ): + super().__init__(**kwargs) + + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.embed_dim = embed_dim + self.depths = depths + self.num_layers = len(depths) + self.num_heads = num_heads + self.window_size = window_size + self.mlp_ratio = mlp_ratio + self.qkv_bias = qkv_bias + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.drop_path_rate = drop_path_rate + self.hidden_act = hidden_act + self.use_absolute_embeddings = use_absolute_embeddings + self.path_norm = patch_norm + self.layer_norm_eps = layer_norm_eps + self.initializer_range = initializer_range + self.encoder_stride = encoder_stride + # we set the hidden_size attribute in order to make Swinv2 work with VisionEncoderDecoderModel + # this indicates the channel dimension after the last stage of the model + self.hidden_size = int(embed_dim * 2 ** (len(depths) - 1)) + self.pretrained_window_sizes = (0, 0, 0, 0) diff --git a/src/transformers/models/swinv2/convert_swinv2_timm_to_pytorch.py b/src/transformers/models/swinv2/convert_swinv2_timm_to_pytorch.py new file mode 100644 index 00000000000000..148793e3043b84 --- /dev/null +++ b/src/transformers/models/swinv2/convert_swinv2_timm_to_pytorch.py @@ -0,0 +1,219 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert Swinv2 checkpoints from the timm library.""" + +import argparse +import json +from pathlib import Path + +import torch +from PIL import Image + +import requests +import timm +from huggingface_hub import hf_hub_download +from transformers import AutoFeatureExtractor, Swinv2Config, Swinv2ForImageClassification + + +def get_swinv2_config(swinv2_name): + config = Swinv2Config() + name_split = swinv2_name.split("_") + + model_size = name_split[1] + if "to" in name_split[3]: + img_size = int(name_split[3][-3:]) + else: + img_size = int(name_split[3]) + if "to" in name_split[2]: + window_size = int(name_split[2][-2:]) + else: + window_size = int(name_split[2][6:]) + + if model_size == "tiny": + embed_dim = 96 + depths = (2, 2, 6, 2) + num_heads = (3, 6, 12, 24) + elif model_size == "small": + embed_dim = 96 + depths = (2, 2, 18, 2) + num_heads = (3, 6, 12, 24) + elif model_size == "base": + embed_dim = 128 + depths = (2, 2, 18, 2) + num_heads = (4, 8, 16, 32) + else: + embed_dim = 192 + depths = (2, 2, 18, 2) + num_heads = (6, 12, 24, 48) + + if "to" in swinv2_name: + config.pretrained_window_sizes = (12, 12, 12, 6) + + if ("22k" in swinv2_name) and ("to" not in swinv2_name): + num_classes = 21841 + repo_id = "datasets/huggingface/label-files" + filename = "imagenet-22k-id2label.json" + id2label = json.load(open(hf_hub_download(repo_id, filename), "r")) + id2label = {int(k): v for k, v in id2label.items()} + config.id2label = id2label + config.label2id = {v: k for k, v in id2label.items()} + + else: + num_classes = 1000 + repo_id = "datasets/huggingface/label-files" + filename = "imagenet-1k-id2label.json" + id2label = json.load(open(hf_hub_download(repo_id, filename), "r")) + id2label = {int(k): v for k, v in id2label.items()} + config.id2label = id2label + config.label2id = {v: k for k, v in id2label.items()} + + config.image_size = img_size + config.num_labels = num_classes + config.embed_dim = embed_dim + config.depths = depths + config.num_heads = num_heads + config.window_size = window_size + + return config + + +def rename_key(name): + if "patch_embed.proj" in name: + name = name.replace("patch_embed.proj", "embeddings.patch_embeddings.projection") + if "patch_embed.norm" in name: + name = name.replace("patch_embed.norm", "embeddings.norm") + if "layers" in name: + name = "encoder." + name + if "attn.proj" in name: + name = name.replace("attn.proj", "attention.output.dense") + if "attn" in name: + name = name.replace("attn", "attention.self") + if "norm1" in name: + name = name.replace("norm1", "layernorm_before") + if "norm2" in name: + name = name.replace("norm2", "layernorm_after") + if "mlp.fc1" in name: + name = name.replace("mlp.fc1", "intermediate.dense") + if "mlp.fc2" in name: + name = name.replace("mlp.fc2", "output.dense") + if "q_bias" in name: + name = name.replace("q_bias", "query.bias") + if "k_bias" in name: + name = name.replace("k_bias", "key.bias") + if "v_bias" in name: + name = name.replace("v_bias", "value.bias") + if "cpb_mlp" in name: + name = name.replace("cpb_mlp", "continuous_position_bias_mlp") + if name == "norm.weight": + name = "layernorm.weight" + if name == "norm.bias": + name = "layernorm.bias" + + if "head" in name: + name = name.replace("head", "classifier") + else: + name = "swinv2." + name + + return name + + +def convert_state_dict(orig_state_dict, model): + for key in orig_state_dict.copy().keys(): + val = orig_state_dict.pop(key) + + if "mask" in key: + continue + elif "qkv" in key: + key_split = key.split(".") + layer_num = int(key_split[1]) + block_num = int(key_split[3]) + dim = model.swinv2.encoder.layers[layer_num].blocks[block_num].attention.self.all_head_size + + if "weight" in key: + orig_state_dict[ + f"swinv2.encoder.layers.{layer_num}.blocks.{block_num}.attention.self.query.weight" + ] = val[:dim, :] + orig_state_dict[ + f"swinv2.encoder.layers.{layer_num}.blocks.{block_num}.attention.self.key.weight" + ] = val[dim : dim * 2, :] + orig_state_dict[ + f"swinv2.encoder.layers.{layer_num}.blocks.{block_num}.attention.self.value.weight" + ] = val[-dim:, :] + else: + orig_state_dict[ + f"swinv2.encoder.layers.{layer_num}.blocks.{block_num}.attention.self.query.bias" + ] = val[:dim] + orig_state_dict[f"swinv2.encoder.layers.{layer_num}.blocks.{block_num}.attention.self.key.bias"] = val[ + dim : dim * 2 + ] + orig_state_dict[ + f"swinv2.encoder.layers.{layer_num}.blocks.{block_num}.attention.self.value.bias" + ] = val[-dim:] + else: + orig_state_dict[rename_key(key)] = val + + return orig_state_dict + + +def convert_swinv2_checkpoint(swinv2_name, pytorch_dump_folder_path): + timm_model = timm.create_model(swinv2_name, pretrained=True) + timm_model.eval() + + config = get_swinv2_config(swinv2_name) + model = Swinv2ForImageClassification(config) + model.eval() + + new_state_dict = convert_state_dict(timm_model.state_dict(), model) + model.load_state_dict(new_state_dict) + + url = "http://images.cocodataset.org/val2017/000000039769.jpg" + + feature_extractor = AutoFeatureExtractor.from_pretrained("microsoft/{}".format(swinv2_name.replace("_", "-"))) + image = Image.open(requests.get(url, stream=True).raw) + inputs = feature_extractor(images=image, return_tensors="pt") + + timm_outs = timm_model(inputs["pixel_values"]) + hf_outs = model(**inputs).logits + + assert torch.allclose(timm_outs, hf_outs, atol=1e-3) + + print(f"Saving model {swinv2_name} to {pytorch_dump_folder_path}") + model.save_pretrained(pytorch_dump_folder_path) + + print(f"Saving feature extractor to {pytorch_dump_folder_path}") + feature_extractor.save_pretrained(pytorch_dump_folder_path) + + model.push_to_hub( + repo_path_or_name=Path(pytorch_dump_folder_path, swinv2_name), + organization="nandwalritik", + commit_message="Add model", + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument( + "--swinv2_name", + default="swinv2_tiny_patch4_window8_256", + type=str, + help="Name of the Swinv2 timm model you'd like to convert.", + ) + parser.add_argument( + "--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory." + ) + + args = parser.parse_args() + convert_swinv2_checkpoint(args.swinv2_name, args.pytorch_dump_folder_path) diff --git a/src/transformers/models/swinv2/modeling_swinv2.py b/src/transformers/models/swinv2/modeling_swinv2.py new file mode 100644 index 00000000000000..52f836d5b91d38 --- /dev/null +++ b/src/transformers/models/swinv2/modeling_swinv2.py @@ -0,0 +1,1292 @@ +# coding=utf-8 +# Copyright 2022 Microsoft Research and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch Swinv2 Transformer model.""" + + +import collections.abc +import math +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from ...activations import ACT2FN +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer +from ...utils import ( + ModelOutput, + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) +from .configuration_swinv2 import Swinv2Config + + +logger = logging.get_logger(__name__) + +# General docstring +_CONFIG_FOR_DOC = "Swinv2Config" +_FEAT_EXTRACTOR_FOR_DOC = "AutoFeatureExtractor" + +# Base docstring +_CHECKPOINT_FOR_DOC = "microsoft/swinv2-tiny-patch4-window8-256" +_EXPECTED_OUTPUT_SHAPE = [1, 64, 768] + +# Image classification docstring +_IMAGE_CLASS_CHECKPOINT = "microsoft/swinv2-tiny-patch4-window8-256" +_IMAGE_CLASS_EXPECTED_OUTPUT = "Egyptian cat" + + +SWINV2_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "microsoft/swinv2-tiny-patch4-window8-256", + # See all Swinv2 models at https://huggingface.co/models?filter=swinv2 +] + + +# drop_path, Swinv2PatchEmbeddings, Swinv2PatchMerging and Swinv2DropPath are from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/swin_transformer_v2.py. + + +@dataclass +# Copied from transformers.models.swin.modeling_swin.SwinEncoderOutput with Swin->Swinv2 +class Swinv2EncoderOutput(ModelOutput): + """ + Swinv2 encoder's outputs, with potential hidden states and attentions. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of + shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of + shape `(batch_size, hidden_size, height, width)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to + include the spatial dimensions. + """ + + last_hidden_state: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + reshaped_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + + +@dataclass +# Copied from transformers.models.swin.modeling_swin.SwinModelOutput with Swin->Swinv2 +class Swinv2ModelOutput(ModelOutput): + """ + Swinv2 model's outputs that also contains a pooling of the last hidden states. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`, *optional*, returned when `add_pooling_layer=True` is passed): + Average pooling of the last layer hidden-state. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of + shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of + shape `(batch_size, hidden_size, height, width)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to + include the spatial dimensions. + """ + + last_hidden_state: torch.FloatTensor = None + pooler_output: Optional[torch.FloatTensor] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + reshaped_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + + +@dataclass +# Copied from transformers.models.swin.modeling_swin.SwinMaskedImageModelingOutput with Swin->Swinv2 +class Swinv2MaskedImageModelingOutput(ModelOutput): + """ + Swinv2 masked image model outputs. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `bool_masked_pos` is provided): + Masked image modeling (MLM) loss. + logits (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Reconstructed pixel values. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of + shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of + shape `(batch_size, hidden_size, height, width)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to + include the spatial dimensions. + """ + + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + reshaped_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + + +@dataclass +# Copied from transformers.models.swin.modeling_swin.SwinImageClassifierOutput with Swin->Swinv2 +class Swinv2ImageClassifierOutput(ModelOutput): + """ + Swinv2 outputs for image classification. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Classification (or regression if config.num_labels==1) loss. + logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`): + Classification (or regression if config.num_labels==1) scores (before SoftMax). + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of + shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of + shape `(batch_size, hidden_size, height, width)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to + include the spatial dimensions. + """ + + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + reshaped_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + + +# Copied from transformers.models.swin.modeling_swin.window_partition +def window_partition(input_feature, window_size): + """ + Partitions the given input into windows. + """ + batch_size, height, width, num_channels = input_feature.shape + input_feature = input_feature.view( + batch_size, height // window_size, window_size, width // window_size, window_size, num_channels + ) + windows = input_feature.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, num_channels) + return windows + + +# Copied from transformers.models.swin.modeling_swin.window_reverse +def window_reverse(windows, window_size, height, width): + """ + Merges windows to produce higher resolution features. + """ + batch_size = math.floor(windows.shape[0] / (height * width / window_size / window_size)) + windows = windows.view(batch_size, height // window_size, width // window_size, window_size, window_size, -1) + windows = windows.permute(0, 1, 3, 2, 4, 5).contiguous().view(batch_size, height, width, -1) + return windows + + +# Copied from transformers.models.swin.modeling_swin.drop_path +def drop_path(input, drop_prob=0.0, training=False, scale_by_keep=True): + """ + Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + + Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks, + however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the + layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the + argument. + """ + if drop_prob == 0.0 or not training: + return input + keep_prob = 1 - drop_prob + shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device) + random_tensor.floor_() # binarize + output = input.div(keep_prob) * random_tensor + return output + + +# Copied from transformers.models.swin.modeling_swin.SwinDropPath with Swin->Swinv2 +class Swinv2DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob: Optional[float] = None) -> None: + super().__init__() + self.drop_prob = drop_prob + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return drop_path(x, self.drop_prob, self.training) + + def extra_repr(self) -> str: + return "p={}".format(self.drop_prob) + + +# Copied from transformers.models.swin.modeling_swin.SwinEmbeddings with Swin->Swinv2 +class Swinv2Embeddings(nn.Module): + """ + Construct the patch and position embeddings. Optionally, also the mask token. + """ + + def __init__(self, config, use_mask_token=False): + super().__init__() + + self.patch_embeddings = Swinv2PatchEmbeddings(config) + num_patches = self.patch_embeddings.num_patches + self.patch_grid = self.patch_embeddings.grid_size + self.mask_token = nn.Parameter(torch.zeros(1, 1, config.embed_dim)) if use_mask_token else None + + if config.use_absolute_embeddings: + self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.embed_dim)) + else: + self.position_embeddings = None + + self.norm = nn.LayerNorm(config.embed_dim) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward( + self, pixel_values: Optional[torch.FloatTensor], bool_masked_pos: Optional[torch.BoolTensor] = None + ) -> Tuple[torch.Tensor]: + embeddings, output_dimensions = self.patch_embeddings(pixel_values) + embeddings = self.norm(embeddings) + batch_size, seq_len, _ = embeddings.size() + + if bool_masked_pos is not None: + mask_tokens = self.mask_token.expand(batch_size, seq_len, -1) + # replace the masked visual tokens by mask_tokens + mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens) + embeddings = embeddings * (1.0 - mask) + mask_tokens * mask + + if self.position_embeddings is not None: + embeddings = embeddings + self.position_embeddings + + embeddings = self.dropout(embeddings) + + return embeddings, output_dimensions + + +# Copied from transformers.models.swin.modeling_swin.SwinPatchEmbeddings with Swin->Swinv2 +class Swinv2PatchEmbeddings(nn.Module): + """ + This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial + `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a + Transformer. + """ + + def __init__(self, config): + super().__init__() + image_size, patch_size = config.image_size, config.patch_size + num_channels, hidden_size = config.num_channels, config.embed_dim + image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size) + patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size) + num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.num_patches = num_patches + self.grid_size = (image_size[0] // patch_size[0], image_size[1] // patch_size[1]) + + self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size) + + def maybe_pad(self, pixel_values, height, width): + if width % self.patch_size[1] != 0: + pad_values = (0, self.patch_size[1] - width % self.patch_size[1]) + pixel_values = nn.functional.pad(pixel_values, pad_values) + if height % self.patch_size[0] != 0: + pad_values = (0, 0, 0, self.patch_size[0] - height % self.patch_size[0]) + pixel_values = nn.functional.pad(pixel_values, pad_values) + return pixel_values + + def forward(self, pixel_values: Optional[torch.FloatTensor]) -> Tuple[torch.Tensor, Tuple[int]]: + _, num_channels, height, width = pixel_values.shape + if num_channels != self.num_channels: + raise ValueError( + "Make sure that the channel dimension of the pixel values match with the one set in the configuration." + ) + # pad the input to be divisible by self.patch_size, if needed + pixel_values = self.maybe_pad(pixel_values, height, width) + embeddings = self.projection(pixel_values) + _, _, height, width = embeddings.shape + output_dimensions = (height, width) + embeddings = embeddings.flatten(2).transpose(1, 2) + + return embeddings, output_dimensions + + +class Swinv2PatchMerging(nn.Module): + """ + Patch Merging Layer. + + Args: + input_resolution (`Tuple[int]`): + Resolution of input feature. + dim (`int`): + Number of input channels. + norm_layer (`nn.Module`, *optional*, defaults to `nn.LayerNorm`): + Normalization layer class. + """ + + def __init__(self, input_resolution: Tuple[int], dim: int, norm_layer: nn.Module = nn.LayerNorm) -> None: + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(2 * dim) + + def maybe_pad(self, input_feature, height, width): + should_pad = (height % 2 == 1) or (width % 2 == 1) + if should_pad: + pad_values = (0, 0, 0, width % 2, 0, height % 2) + input_feature = nn.functional.pad(input_feature, pad_values) + + return input_feature + + def forward(self, input_feature: torch.Tensor, input_dimensions: Tuple[int, int]) -> torch.Tensor: + height, width = input_dimensions + # `dim` is height * width + batch_size, dim, num_channels = input_feature.shape + + input_feature = input_feature.view(batch_size, height, width, num_channels) + # pad input to be disible by width and height, if needed + input_feature = self.maybe_pad(input_feature, height, width) + # [batch_size, height/2, width/2, num_channels] + input_feature_0 = input_feature[:, 0::2, 0::2, :] + # [batch_size, height/2, width/2, num_channels] + input_feature_1 = input_feature[:, 1::2, 0::2, :] + # [batch_size, height/2, width/2, num_channels] + input_feature_2 = input_feature[:, 0::2, 1::2, :] + # [batch_size, height/2, width/2, num_channels] + input_feature_3 = input_feature[:, 1::2, 1::2, :] + # [batch_size, height/2 * width/2, 4*num_channels] + input_feature = torch.cat([input_feature_0, input_feature_1, input_feature_2, input_feature_3], -1) + input_feature = input_feature.view(batch_size, -1, 4 * num_channels) # [batch_size, height/2 * width/2, 4*C] + + input_feature = self.reduction(input_feature) + input_feature = self.norm(input_feature) + + return input_feature + + +class Swinv2SelfAttention(nn.Module): + def __init__(self, config, dim, num_heads, window_size, pretrained_window_size=[0, 0]): + super().__init__() + if dim % num_heads != 0: + raise ValueError( + f"The hidden size ({dim}) is not a multiple of the number of attention heads ({num_heads})" + ) + + self.num_attention_heads = num_heads + self.attention_head_size = int(dim / num_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + self.window_size = ( + window_size if isinstance(window_size, collections.abc.Iterable) else (window_size, window_size) + ) + self.pretrained_window_size = pretrained_window_size + self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1)))) + # mlp to generate continuous relative position bias + self.continuous_position_bias_mlp = nn.Sequential( + nn.Linear(2, 512, bias=True), nn.ReLU(inplace=True), nn.Linear(512, num_heads, bias=False) + ) + + # get relative_coords_table + relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32) + relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32) + relative_coords_table = ( + torch.stack(torch.meshgrid([relative_coords_h, relative_coords_w], indexing="ij")) + .permute(1, 2, 0) + .contiguous() + .unsqueeze(0) + ) # [1, 2*window_height - 1, 2*window_width - 1, 2] + if pretrained_window_size[0] > 0: + relative_coords_table[:, :, :, 0] /= pretrained_window_size[0] - 1 + relative_coords_table[:, :, :, 1] /= pretrained_window_size[1] - 1 + else: + relative_coords_table[:, :, :, 0] /= self.window_size[0] - 1 + relative_coords_table[:, :, :, 1] /= self.window_size[1] - 1 + relative_coords_table *= 8 # normalize to -8, 8 + relative_coords_table = ( + torch.sign(relative_coords_table) * torch.log2(torch.abs(relative_coords_table) + 1.0) / math.log2(8) + ) + self.register_buffer("relative_coords_table", relative_coords_table, persistent=False) + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) + coords_flatten = torch.flatten(coords, 1) + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] + relative_coords = relative_coords.permute(1, 2, 0).contiguous() + relative_coords[:, :, 0] += self.window_size[0] - 1 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) + self.register_buffer("relative_position_index", relative_position_index, persistent=False) + + self.query = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias) + self.key = nn.Linear(self.all_head_size, self.all_head_size, bias=False) + self.value = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias) + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + batch_size, dim, num_channels = hidden_states.shape + mixed_query_layer = self.query(hidden_states) + + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + query_layer = self.transpose_for_scores(mixed_query_layer) + + # cosine attention + attention_scores = F.normalize(query_layer, dim=-1) @ F.normalize(key_layer, dim=-1).transpose(-2, -1) + logit_scale = torch.clamp(self.logit_scale, max=math.log(1.0 / 0.01)).exp() + attention_scores = attention_scores * logit_scale + relative_position_bias_table = self.continuous_position_bias_mlp(self.relative_coords_table).view( + -1, self.num_attention_heads + ) + # [window_height*window_width,window_height*window_width,num_attention_heads] + relative_position_bias = relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1 + ) + # [num_attention_heads,window_height*window_width,window_height*window_width] + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + relative_position_bias = 16 * torch.sigmoid(relative_position_bias) + attention_scores = attention_scores + relative_position_bias.unsqueeze(0) + + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in Swinv2Model forward() function) + mask_shape = attention_mask.shape[0] + attention_scores = attention_scores.view( + batch_size // mask_shape, mask_shape, self.num_attention_heads, dim, dim + ) + attention_mask.unsqueeze(1).unsqueeze(0) + attention_scores = attention_scores + attention_mask.unsqueeze(1).unsqueeze(0) + attention_scores = attention_scores.view(-1, self.num_attention_heads, dim, dim) + + # Normalize the attention scores to probabilities. + attention_probs = nn.functional.softmax(attention_scores, dim=-1) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = attention_probs * head_mask + + context_layer = torch.matmul(attention_probs, value_layer) + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(new_context_layer_shape) + + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + + return outputs + + +# Copied from transformers.models.swin.modeling_swin.SwinSelfOutput with Swin->Swinv2 +class Swinv2SelfOutput(nn.Module): + def __init__(self, config, dim): + super().__init__() + self.dense = nn.Linear(dim, dim) + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + + return hidden_states + + +class Swinv2Attention(nn.Module): + def __init__(self, config, dim, num_heads, window_size, pretrained_window_size=0): + super().__init__() + self.self = Swinv2SelfAttention( + config=config, + dim=dim, + num_heads=num_heads, + window_size=window_size, + pretrained_window_size=pretrained_window_size + if isinstance(pretrained_window_size, collections.abc.Iterable) + else (pretrained_window_size, pretrained_window_size), + ) + self.output = Swinv2SelfOutput(config, dim) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads + ) + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len(heads) + self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + self_outputs = self.self(hidden_states, attention_mask, head_mask, output_attentions) + attention_output = self.output(self_outputs[0], hidden_states) + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + +# Copied from transformers.models.swin.modeling_swin.SwinIntermediate with Swin->Swinv2 +class Swinv2Intermediate(nn.Module): + def __init__(self, config, dim): + super().__init__() + self.dense = nn.Linear(dim, int(config.mlp_ratio * dim)) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +# Copied from transformers.models.swin.modeling_swin.SwinOutput with Swin->Swinv2 +class Swinv2Output(nn.Module): + def __init__(self, config, dim): + super().__init__() + self.dense = nn.Linear(int(config.mlp_ratio * dim), dim) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + return hidden_states + + +class Swinv2Layer(nn.Module): + def __init__(self, config, dim, input_resolution, num_heads, shift_size=0, pretrained_window_size=0): + super().__init__() + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.shift_size = shift_size + self.window_size = config.window_size + self.input_resolution = input_resolution + self.set_shift_and_window_size(input_resolution) + self.attention = Swinv2Attention( + config=config, + dim=dim, + num_heads=num_heads, + window_size=self.window_size, + pretrained_window_size=pretrained_window_size + if isinstance(pretrained_window_size, collections.abc.Iterable) + else (pretrained_window_size, pretrained_window_size), + ) + self.layernorm_before = nn.LayerNorm(dim, eps=config.layer_norm_eps) + self.drop_path = Swinv2DropPath(config.drop_path_rate) if config.drop_path_rate > 0.0 else nn.Identity() + self.intermediate = Swinv2Intermediate(config, dim) + self.output = Swinv2Output(config, dim) + self.layernorm_after = nn.LayerNorm(dim, eps=config.layer_norm_eps) + + def set_shift_and_window_size(self, input_resolution): + target_window_size = ( + self.window_size + if isinstance(self.window_size, collections.abc.Iterable) + else (self.window_size, self.window_size) + ) + target_shift_size = ( + self.shift_size + if isinstance(self.shift_size, collections.abc.Iterable) + else (self.shift_size, self.shift_size) + ) + self.window_size = ( + input_resolution[0] if input_resolution[0] <= target_window_size[0] else target_window_size[0] + ) + self.shift_size = ( + 0 + if input_resolution + <= ( + self.window_size + if isinstance(self.window_size, collections.abc.Iterable) + else (self.window_size, self.window_size) + ) + else target_shift_size[0] + ) + + def get_attn_mask(self, height, width): + if self.shift_size > 0: + # calculate attention mask for shifted window multihead self attention + img_mask = torch.zeros((1, height, width, 1)) + height_slices = ( + slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None), + ) + width_slices = ( + slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None), + ) + count = 0 + for height_slice in height_slices: + for width_slice in width_slices: + img_mask[:, height_slice, width_slice, :] = count + count += 1 + + mask_windows = window_partition(img_mask, self.window_size) + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + else: + attn_mask = None + return attn_mask + + def maybe_pad(self, hidden_states, height, width): + pad_right = (self.window_size - width % self.window_size) % self.window_size + pad_bottom = (self.window_size - height % self.window_size) % self.window_size + pad_values = (0, 0, 0, pad_right, 0, pad_bottom) + hidden_states = nn.functional.pad(hidden_states, pad_values) + return hidden_states, pad_values + + def forward( + self, + hidden_states: torch.Tensor, + input_dimensions: Tuple[int, int], + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor, torch.Tensor]: + self.set_shift_and_window_size(input_dimensions) + height, width = input_dimensions + batch_size, _, channels = hidden_states.size() + shortcut = hidden_states + + # pad hidden_states to multiples of window size + hidden_states = hidden_states.view(batch_size, height, width, channels) + hidden_states, pad_values = self.maybe_pad(hidden_states, height, width) + _, height_pad, width_pad, _ = hidden_states.shape + # cyclic shift + if self.shift_size > 0: + shifted_hidden_states = torch.roll(hidden_states, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_hidden_states = hidden_states + + # partition windows + hidden_states_windows = window_partition(shifted_hidden_states, self.window_size) + hidden_states_windows = hidden_states_windows.view(-1, self.window_size * self.window_size, channels) + attn_mask = self.get_attn_mask(height_pad, width_pad) + if attn_mask is not None: + attn_mask = attn_mask.to(hidden_states_windows.device) + + attention_outputs = self.attention( + hidden_states_windows, attn_mask, head_mask, output_attentions=output_attentions + ) + + attention_output = attention_outputs[0] + + attention_windows = attention_output.view(-1, self.window_size, self.window_size, channels) + shifted_windows = window_reverse(attention_windows, self.window_size, height_pad, width_pad) + + # reverse cyclic shift + if self.shift_size > 0: + attention_windows = torch.roll(shifted_windows, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + attention_windows = shifted_windows + + was_padded = pad_values[3] > 0 or pad_values[5] > 0 + if was_padded: + attention_windows = attention_windows[:, :height, :width, :].contiguous() + + attention_windows = attention_windows.view(batch_size, height * width, channels) + hidden_states = self.layernorm_before(attention_windows) + hidden_states = shortcut + self.drop_path(hidden_states) + + layer_output = self.intermediate(hidden_states) + layer_output = self.output(layer_output) + layer_output = hidden_states + self.drop_path(self.layernorm_after(layer_output)) + + layer_outputs = (layer_output, attention_outputs[1]) if output_attentions else (layer_output,) + return layer_outputs + + +class Swinv2Stage(nn.Module): + def __init__( + self, config, dim, input_resolution, depth, num_heads, drop_path, downsample, pretrained_window_size=0 + ): + super().__init__() + self.config = config + self.dim = dim + self.blocks = nn.ModuleList( + [ + Swinv2Layer( + config=config, + dim=dim, + input_resolution=input_resolution, + num_heads=num_heads, + shift_size=0 if (i % 2 == 0) else config.window_size // 2, + pretrained_window_size=pretrained_window_size, + ) + for i in range(depth) + ] + ) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(input_resolution, dim=dim, norm_layer=nn.LayerNorm) + else: + self.downsample = None + + self.pointing = False + + # Copied from transformers.models.swin.modeling_swin.SwinStage.forward + def forward( + self, + hidden_states: torch.Tensor, + input_dimensions: Tuple[int, int], + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + height, width = input_dimensions + for i, layer_module in enumerate(self.blocks): + + layer_head_mask = head_mask[i] if head_mask is not None else None + + layer_outputs = layer_module(hidden_states, input_dimensions, layer_head_mask, output_attentions) + + hidden_states = layer_outputs[0] + + if self.downsample is not None: + height_downsampled, width_downsampled = (height + 1) // 2, (width + 1) // 2 + output_dimensions = (height, width, height_downsampled, width_downsampled) + hidden_states = self.downsample(layer_outputs[0], input_dimensions) + else: + output_dimensions = (height, width, height, width) + + stage_outputs = (hidden_states, output_dimensions) + + if output_attentions: + stage_outputs += layer_outputs[1:] + return stage_outputs + + +class Swinv2Encoder(nn.Module): + def __init__(self, config, grid_size, pretrained_window_sizes=(0, 0, 0, 0)): + super().__init__() + self.num_layers = len(config.depths) + self.config = config + if self.config.pretrained_window_sizes is not None: + pretrained_window_sizes = config.pretrained_window_sizes + dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths))] + self.layers = nn.ModuleList( + [ + Swinv2Stage( + config=config, + dim=int(config.embed_dim * 2**i_layer), + input_resolution=(grid_size[0] // (2**i_layer), grid_size[1] // (2**i_layer)), + depth=config.depths[i_layer], + num_heads=config.num_heads[i_layer], + drop_path=dpr[sum(config.depths[:i_layer]) : sum(config.depths[: i_layer + 1])], + downsample=Swinv2PatchMerging if (i_layer < self.num_layers - 1) else None, + pretrained_window_size=pretrained_window_sizes[i_layer], + ) + for i_layer in range(self.num_layers) + ] + ) + + self.gradient_checkpointing = False + + # Copied from transformers.models.swin.modeling_swin.SwinEncoder.forward with SwinEncoderOutput->Swinv2EncoderOutput + def forward( + self, + hidden_states: torch.Tensor, + input_dimensions: Tuple[int, int], + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + output_hidden_states: Optional[bool] = False, + return_dict: Optional[bool] = True, + ) -> Union[Tuple, Swinv2EncoderOutput]: + all_input_dimensions = () + all_hidden_states = () if output_hidden_states else None + all_reshaped_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + if output_hidden_states: + batch_size, _, hidden_size = hidden_states.shape + # rearrange b (h w) c -> b c h w + reshaped_hidden_state = hidden_states.view(batch_size, *input_dimensions, hidden_size) + reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2) + all_hidden_states += (hidden_states,) + all_reshaped_hidden_states += (reshaped_hidden_state,) + + for i, layer_module in enumerate(self.layers): + layer_head_mask = head_mask[i] if head_mask is not None else None + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs, output_attentions) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(layer_module), hidden_states, input_dimensions, layer_head_mask + ) + else: + layer_outputs = layer_module(hidden_states, input_dimensions, layer_head_mask, output_attentions) + + hidden_states = layer_outputs[0] + output_dimensions = layer_outputs[1] + + input_dimensions = (output_dimensions[-2], output_dimensions[-1]) + all_input_dimensions += (input_dimensions,) + + if output_hidden_states: + batch_size, _, hidden_size = hidden_states.shape + # rearrange b (h w) c -> b c h w + reshaped_hidden_state = hidden_states.view(batch_size, *input_dimensions, hidden_size) + reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2) + all_hidden_states += (hidden_states,) + all_reshaped_hidden_states += (reshaped_hidden_state,) + + if output_attentions: + all_self_attentions += layer_outputs[2:] + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) + + return Swinv2EncoderOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + reshaped_hidden_states=all_reshaped_hidden_states, + ) + + +# Copied from transformers.models.swin.modeling_swin.SwinPreTrainedModel with Swin->Swinv2,swin->swinv2 +class Swinv2PreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = Swinv2Config + base_model_prefix = "swinv2" + main_input_name = "pixel_values" + supports_gradient_checkpointing = True + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, (nn.Linear, nn.Conv2d)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, Swinv2Encoder): + module.gradient_checkpointing = value + + +SWINV2_START_DOCSTRING = r""" + This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use + it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and + behavior. + + Parameters: + config ([`Swinv2Config`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +SWINV2_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Pixel values can be obtained using [`AutoFeatureExtractor`]. See + [`AutoFeatureExtractor.__call__`] for details. + head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare Swinv2 Model transformer outputting raw hidden-states without any specific head on top.", + SWINV2_START_DOCSTRING, +) +# Copied from transformers.models.swin.modeling_swin.SwinModel with SWIN->SWINV2,Swin->Swinv2 +class Swinv2Model(Swinv2PreTrainedModel): + def __init__(self, config, add_pooling_layer=True, use_mask_token=False): + super().__init__(config) + self.config = config + self.num_layers = len(config.depths) + self.num_features = int(config.embed_dim * 2 ** (self.num_layers - 1)) + + self.embeddings = Swinv2Embeddings(config, use_mask_token=use_mask_token) + self.encoder = Swinv2Encoder(config, self.embeddings.patch_grid) + + self.layernorm = nn.LayerNorm(self.num_features, eps=config.layer_norm_eps) + self.pooler = nn.AdaptiveAvgPool1d(1) if add_pooling_layer else None + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embeddings.patch_embeddings + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + @add_start_docstrings_to_model_forward(SWINV2_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + processor_class=_FEAT_EXTRACTOR_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=Swinv2ModelOutput, + config_class=_CONFIG_FOR_DOC, + modality="vision", + expected_output=_EXPECTED_OUTPUT_SHAPE, + ) + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + bool_masked_pos: Optional[torch.BoolTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, Swinv2ModelOutput]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, len(self.config.depths)) + + embedding_output, input_dimensions = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos) + + encoder_outputs = self.encoder( + embedding_output, + input_dimensions, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = encoder_outputs[0] + sequence_output = self.layernorm(sequence_output) + + pooled_output = None + if self.pooler is not None: + pooled_output = self.pooler(sequence_output.transpose(1, 2)) + pooled_output = torch.flatten(pooled_output, 1) + + if not return_dict: + output = (sequence_output, pooled_output) + encoder_outputs[1:] + + return output + + return Swinv2ModelOutput( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + reshaped_hidden_states=encoder_outputs.reshaped_hidden_states, + ) + + +@add_start_docstrings( + "Swinv2 Model with a decoder on top for masked image modeling, as proposed in" + " [SimMIM](https://arxiv.org/abs/2111.09886).", + SWINV2_START_DOCSTRING, +) +# Copied from transformers.models.swin.modeling_swin.SwinForMaskedImageModeling with SWIN->SWINV2,Swin->Swinv2,swin->swinv2,224->256,window7->window8 +class Swinv2ForMaskedImageModeling(Swinv2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.swinv2 = Swinv2Model(config, add_pooling_layer=False, use_mask_token=True) + + num_features = int(config.embed_dim * 2 ** (config.num_layers - 1)) + self.decoder = nn.Sequential( + nn.Conv2d( + in_channels=num_features, out_channels=config.encoder_stride**2 * config.num_channels, kernel_size=1 + ), + nn.PixelShuffle(config.encoder_stride), + ) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(SWINV2_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=Swinv2MaskedImageModelingOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + bool_masked_pos: Optional[torch.BoolTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, Swinv2MaskedImageModelingOutput]: + r""" + bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`): + Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). + + Returns: + + Examples: + ```python + >>> from transformers import AutoFeatureExtractor, Swinv2ForMaskedImageModeling + >>> import torch + >>> from PIL import Image + >>> import requests + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> feature_extractor = AutoFeatureExtractor.from_pretrained("microsoft/swinv2-tiny-patch4-window8-256") + >>> model = Swinv2ForMaskedImageModeling.from_pretrained("microsoft/swinv2-tiny-patch4-window8-256") + + >>> num_patches = (model.config.image_size // model.config.patch_size) ** 2 + >>> pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values + >>> # create random boolean mask of shape (batch_size, num_patches) + >>> bool_masked_pos = torch.randint(low=0, high=2, size=(1, num_patches)).bool() + + >>> outputs = model(pixel_values, bool_masked_pos=bool_masked_pos) + >>> loss, reconstructed_pixel_values = outputs.loss, outputs.logits + >>> list(reconstructed_pixel_values.shape) + [1, 3, 256, 256] + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.swinv2( + pixel_values, + bool_masked_pos=bool_masked_pos, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + # Reshape to (batch_size, num_channels, height, width) + sequence_output = sequence_output.transpose(1, 2) + batch_size, num_channels, sequence_length = sequence_output.shape + height = width = math.floor(sequence_length**0.5) + sequence_output = sequence_output.reshape(batch_size, num_channels, height, width) + + # Reconstruct pixel values + reconstructed_pixel_values = self.decoder(sequence_output) + + masked_im_loss = None + if bool_masked_pos is not None: + size = self.config.image_size // self.config.patch_size + bool_masked_pos = bool_masked_pos.reshape(-1, size, size) + mask = ( + bool_masked_pos.repeat_interleave(self.config.patch_size, 1) + .repeat_interleave(self.config.patch_size, 2) + .unsqueeze(1) + .contiguous() + ) + reconstruction_loss = nn.functional.l1_loss(pixel_values, reconstructed_pixel_values, reduction="none") + masked_im_loss = (reconstruction_loss * mask).sum() / (mask.sum() + 1e-5) / self.config.num_channels + + if not return_dict: + output = (reconstructed_pixel_values,) + outputs[2:] + return ((masked_im_loss,) + output) if masked_im_loss is not None else output + + return Swinv2MaskedImageModelingOutput( + loss=masked_im_loss, + logits=reconstructed_pixel_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + reshaped_hidden_states=outputs.reshaped_hidden_states, + ) + + +@add_start_docstrings( + """ + Swinv2 Model transformer with an image classification head on top (a linear layer on top of the final hidden state + of the [CLS] token) e.g. for ImageNet. + """, + SWINV2_START_DOCSTRING, +) +# Copied from transformers.models.swin.modeling_swin.SwinForImageClassification with SWIN->SWINV2,Swin->Swinv2,swin->swinv2 +class Swinv2ForImageClassification(Swinv2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.num_labels = config.num_labels + self.swinv2 = Swinv2Model(config) + + # Classifier head + self.classifier = ( + nn.Linear(self.swinv2.num_features, config.num_labels) if config.num_labels > 0 else nn.Identity() + ) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(SWINV2_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + processor_class=_FEAT_EXTRACTOR_FOR_DOC, + checkpoint=_IMAGE_CLASS_CHECKPOINT, + output_type=Swinv2ImageClassifierOutput, + config_class=_CONFIG_FOR_DOC, + expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT, + ) + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, Swinv2ImageClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the image classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.swinv2( + pixel_values, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = outputs[1] + + logits = self.classifier(pooled_output) + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return Swinv2ImageClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + reshaped_hidden_states=outputs.reshaped_hidden_states, + ) diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index fb32886659a975..c08610921f936d 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -4469,6 +4469,37 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) +SWINV2_PRETRAINED_MODEL_ARCHIVE_LIST = None + + +class Swinv2ForImageClassification(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class Swinv2ForMaskedImageModeling(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class Swinv2Model(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class Swinv2PreTrainedModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + T5_PRETRAINED_MODEL_ARCHIVE_LIST = None diff --git a/tests/models/swinv2/__init__.py b/tests/models/swinv2/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/tests/models/swinv2/test_modeling_swinv2.py b/tests/models/swinv2/test_modeling_swinv2.py new file mode 100644 index 00000000000000..13a39b139c81aa --- /dev/null +++ b/tests/models/swinv2/test_modeling_swinv2.py @@ -0,0 +1,430 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Testing suite for the PyTorch Swinv2 model. """ +import collections +import inspect +import unittest + +from transformers import Swinv2Config +from transformers.testing_utils import require_torch, require_vision, slow, torch_device +from transformers.utils import cached_property, is_torch_available, is_vision_available + +from ...test_configuration_common import ConfigTester +from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor + + +if is_torch_available(): + import torch + from torch import nn + + from transformers import Swinv2ForImageClassification, Swinv2ForMaskedImageModeling, Swinv2Model + from transformers.models.swinv2.modeling_swinv2 import SWINV2_PRETRAINED_MODEL_ARCHIVE_LIST + +if is_vision_available(): + from PIL import Image + + from transformers import AutoFeatureExtractor + + +class Swinv2ModelTester: + def __init__( + self, + parent, + batch_size=13, + image_size=32, + patch_size=2, + num_channels=3, + embed_dim=16, + depths=[1, 2, 1], + num_heads=[2, 2, 4], + window_size=2, + mlp_ratio=2.0, + qkv_bias=True, + hidden_dropout_prob=0.0, + attention_probs_dropout_prob=0.0, + drop_path_rate=0.1, + hidden_act="gelu", + use_absolute_embeddings=False, + patch_norm=True, + initializer_range=0.02, + layer_norm_eps=1e-5, + is_training=True, + scope=None, + use_labels=True, + type_sequence_label_size=10, + encoder_stride=8, + ): + self.parent = parent + self.batch_size = batch_size + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.embed_dim = embed_dim + self.depths = depths + self.num_heads = num_heads + self.window_size = window_size + self.mlp_ratio = mlp_ratio + self.qkv_bias = qkv_bias + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.drop_path_rate = drop_path_rate + self.hidden_act = hidden_act + self.use_absolute_embeddings = use_absolute_embeddings + self.patch_norm = patch_norm + self.layer_norm_eps = layer_norm_eps + self.initializer_range = initializer_range + self.is_training = is_training + self.scope = scope + self.use_labels = use_labels + self.type_sequence_label_size = type_sequence_label_size + self.encoder_stride = encoder_stride + + def prepare_config_and_inputs(self): + pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size]) + + labels = None + if self.use_labels: + labels = ids_tensor([self.batch_size], self.type_sequence_label_size) + + config = self.get_config() + + return config, pixel_values, labels + + def get_config(self): + return Swinv2Config( + image_size=self.image_size, + patch_size=self.patch_size, + num_channels=self.num_channels, + embed_dim=self.embed_dim, + depths=self.depths, + num_heads=self.num_heads, + window_size=self.window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=self.qkv_bias, + hidden_dropout_prob=self.hidden_dropout_prob, + attention_probs_dropout_prob=self.attention_probs_dropout_prob, + drop_path_rate=self.drop_path_rate, + hidden_act=self.hidden_act, + use_absolute_embeddings=self.use_absolute_embeddings, + path_norm=self.patch_norm, + layer_norm_eps=self.layer_norm_eps, + initializer_range=self.initializer_range, + encoder_stride=self.encoder_stride, + ) + + def create_and_check_model(self, config, pixel_values, labels): + model = Swinv2Model(config=config) + model.to(torch_device) + model.eval() + result = model(pixel_values) + + expected_seq_len = ((config.image_size // config.patch_size) ** 2) // (4 ** (len(config.depths) - 1)) + expected_dim = int(config.embed_dim * 2 ** (len(config.depths) - 1)) + + self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, expected_seq_len, expected_dim)) + + def create_and_check_for_masked_image_modeling(self, config, pixel_values, labels): + model = Swinv2ForMaskedImageModeling(config=config) + model.to(torch_device) + model.eval() + result = model(pixel_values) + self.parent.assertEqual( + result.logits.shape, (self.batch_size, self.num_channels, self.image_size, self.image_size) + ) + + # test greyscale images + config.num_channels = 1 + model = Swinv2ForMaskedImageModeling(config) + model.to(torch_device) + model.eval() + + pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size]) + result = model(pixel_values) + self.parent.assertEqual(result.logits.shape, (self.batch_size, 1, self.image_size, self.image_size)) + + def create_and_check_for_image_classification(self, config, pixel_values, labels): + config.num_labels = self.type_sequence_label_size + model = Swinv2ForImageClassification(config) + model.to(torch_device) + model.eval() + result = model(pixel_values, labels=labels) + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size)) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + config, pixel_values, labels = config_and_inputs + inputs_dict = {"pixel_values": pixel_values} + return config, inputs_dict + + +@require_torch +class Swinv2ModelTest(ModelTesterMixin, unittest.TestCase): + + all_model_classes = ( + (Swinv2Model, Swinv2ForImageClassification, Swinv2ForMaskedImageModeling) if is_torch_available() else () + ) + + fx_compatible = False + test_pruning = False + test_resize_embeddings = False + test_head_masking = False + + def setUp(self): + self.model_tester = Swinv2ModelTester(self) + self.config_tester = ConfigTester(self, config_class=Swinv2Config, embed_dim=37) + + def test_config(self): + self.config_tester.create_and_test_config_to_json_string() + self.config_tester.create_and_test_config_to_json_file() + self.config_tester.create_and_test_config_from_and_save_pretrained() + self.config_tester.create_and_test_config_with_num_labels() + self.config_tester.check_config_can_be_init_without_params() + self.config_tester.check_config_arguments_init() + + def test_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model(*config_and_inputs) + + @unittest.skip(reason="Swinv2 does not use inputs_embeds") + def test_inputs_embeds(self): + pass + + def test_model_common_attributes(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + self.assertIsInstance(model.get_input_embeddings(), (nn.Module)) + x = model.get_output_embeddings() + self.assertTrue(x is None or isinstance(x, nn.Linear)) + + def test_forward_signature(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + signature = inspect.signature(model.forward) + # signature.parameters is an OrderedDict => so arg_names order is deterministic + arg_names = [*signature.parameters.keys()] + + expected_arg_names = ["pixel_values"] + self.assertListEqual(arg_names[:1], expected_arg_names) + + def test_attention_outputs(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.return_dict = True + + for model_class in self.all_model_classes: + inputs_dict["output_attentions"] = True + inputs_dict["output_hidden_states"] = False + config.return_dict = True + model = model_class(config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + attentions = outputs.attentions + expected_num_attentions = len(self.model_tester.depths) + self.assertEqual(len(attentions), expected_num_attentions) + + # check that output_attentions also work using config + del inputs_dict["output_attentions"] + config.output_attentions = True + window_size_squared = config.window_size**2 + model = model_class(config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + attentions = outputs.attentions + self.assertEqual(len(attentions), expected_num_attentions) + + self.assertListEqual( + list(attentions[0].shape[-3:]), + [self.model_tester.num_heads[0], window_size_squared, window_size_squared], + ) + out_len = len(outputs) + + # Check attention is always last and order is fine + inputs_dict["output_attentions"] = True + inputs_dict["output_hidden_states"] = True + model = model_class(config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + + if hasattr(self.model_tester, "num_hidden_states_types"): + added_hidden_states = self.model_tester.num_hidden_states_types + else: + # also another +1 for reshaped_hidden_states + added_hidden_states = 2 + self.assertEqual(out_len + added_hidden_states, len(outputs)) + + self_attentions = outputs.attentions + + self.assertEqual(len(self_attentions), expected_num_attentions) + + self.assertListEqual( + list(self_attentions[0].shape[-3:]), + [self.model_tester.num_heads[0], window_size_squared, window_size_squared], + ) + + def check_hidden_states_output(self, inputs_dict, config, model_class, image_size): + model = model_class(config) + model.to(torch_device) + model.eval() + + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + + hidden_states = outputs.hidden_states + + expected_num_layers = getattr( + self.model_tester, "expected_num_hidden_layers", len(self.model_tester.depths) + 1 + ) + self.assertEqual(len(hidden_states), expected_num_layers) + + # Swinv2 has a different seq_length + patch_size = ( + config.patch_size + if isinstance(config.patch_size, collections.abc.Iterable) + else (config.patch_size, config.patch_size) + ) + + num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) + + self.assertListEqual( + list(hidden_states[0].shape[-2:]), + [num_patches, self.model_tester.embed_dim], + ) + + reshaped_hidden_states = outputs.reshaped_hidden_states + self.assertEqual(len(reshaped_hidden_states), expected_num_layers) + + batch_size, num_channels, height, width = reshaped_hidden_states[0].shape + reshaped_hidden_states = ( + reshaped_hidden_states[0].view(batch_size, num_channels, height * width).permute(0, 2, 1) + ) + self.assertListEqual( + list(reshaped_hidden_states.shape[-2:]), + [num_patches, self.model_tester.embed_dim], + ) + + def test_hidden_states_output(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + image_size = ( + self.model_tester.image_size + if isinstance(self.model_tester.image_size, collections.abc.Iterable) + else (self.model_tester.image_size, self.model_tester.image_size) + ) + + for model_class in self.all_model_classes: + inputs_dict["output_hidden_states"] = True + self.check_hidden_states_output(inputs_dict, config, model_class, image_size) + + # check that output_hidden_states also work using config + del inputs_dict["output_hidden_states"] + config.output_hidden_states = True + + self.check_hidden_states_output(inputs_dict, config, model_class, image_size) + + def test_hidden_states_output_with_padding(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.patch_size = 3 + + image_size = ( + self.model_tester.image_size + if isinstance(self.model_tester.image_size, collections.abc.Iterable) + else (self.model_tester.image_size, self.model_tester.image_size) + ) + patch_size = ( + config.patch_size + if isinstance(config.patch_size, collections.abc.Iterable) + else (config.patch_size, config.patch_size) + ) + + padded_height = image_size[0] + patch_size[0] - (image_size[0] % patch_size[0]) + padded_width = image_size[1] + patch_size[1] - (image_size[1] % patch_size[1]) + + for model_class in self.all_model_classes: + inputs_dict["output_hidden_states"] = True + self.check_hidden_states_output(inputs_dict, config, model_class, (padded_height, padded_width)) + + # check that output_hidden_states also work using config + del inputs_dict["output_hidden_states"] + config.output_hidden_states = True + self.check_hidden_states_output(inputs_dict, config, model_class, (padded_height, padded_width)) + + def test_for_masked_image_modeling(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_for_masked_image_modeling(*config_and_inputs) + + def test_for_image_classification(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_for_image_classification(*config_and_inputs) + + @slow + def test_model_from_pretrained(self): + for model_name in SWINV2_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: + model = Swinv2Model.from_pretrained(model_name) + self.assertIsNotNone(model) + + def test_initialization(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + configs_no_init = _config_zero_init(config) + for model_class in self.all_model_classes: + model = model_class(config=configs_no_init) + for name, param in model.named_parameters(): + if "embeddings" not in name and "logit_scale" not in name and param.requires_grad: + self.assertIn( + ((param.data.mean() * 1e9).round() / 1e9).item(), + [0.0, 1.0], + msg=f"Parameter {name} of model {model_class} seems not properly initialized", + ) + + +@require_vision +@require_torch +class Swinv2ModelIntegrationTest(unittest.TestCase): + @cached_property + def default_feature_extractor(self): + return ( + AutoFeatureExtractor.from_pretrained("microsoft/swinv2-tiny-patch4-window8-256") + if is_vision_available() + else None + ) + + @slow + def test_inference_image_classification_head(self): + model = Swinv2ForImageClassification.from_pretrained("microsoft/swinv2-tiny-patch4-window8-256").to( + torch_device + ) + feature_extractor = self.default_feature_extractor + + image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png") + inputs = feature_extractor(images=image, return_tensors="pt").to(torch_device) + + # forward pass + with torch.no_grad(): + outputs = model(**inputs) + + # verify the logits + expected_shape = torch.Size((1, 1000)) + self.assertEqual(outputs.logits.shape, expected_shape) + expected_slice = torch.tensor([-0.3947, -0.4306, 0.0026]).to(torch_device) + self.assertTrue(torch.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4))