From b109bc615f345a3ea4172e284e29628065c74d4f Mon Sep 17 00:00:00 2001 From: NielsRogge <48327001+NielsRogge@users.noreply.github.com> Date: Wed, 18 May 2022 17:47:18 +0200 Subject: [PATCH] Add CvT (#17299) * Adding cvt files * Adding cvt files * changes in init file * Adding cvt files * changes in init file * Style fixes * Address comments from code review * Apply suggestions from code review Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Format lists in docstring * Fix copies * Apply suggestion from code review Co-authored-by: AnugunjNaman Co-authored-by: Ayushman Singh Co-authored-by: Niels Rogge Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> --- README.md | 1 + README_ko.md | 1 + README_zh-hans.md | 1 + README_zh-hant.md | 1 + docs/source/en/_toctree.yml | 2 + docs/source/en/index.mdx | 4 +- docs/source/en/model_doc/cvt.mdx | 53 ++ docs/source/en/serialization.mdx | 2 +- src/transformers/__init__.py | 16 + src/transformers/models/__init__.py | 1 + .../models/auto/configuration_auto.py | 3 + .../models/auto/feature_extraction_auto.py | 1 + src/transformers/models/auto/modeling_auto.py | 2 + src/transformers/models/cvt/__init__.py | 61 ++ .../models/cvt/configuration_cvt.py | 147 ++++ ..._original_pytorch_checkpoint_to_pytorch.py | 349 +++++++++ src/transformers/models/cvt/modeling_cvt.py | 735 ++++++++++++++++++ src/transformers/utils/dummy_pt_objects.py | 24 + tests/models/cvt/__init__.py | 0 tests/models/cvt/test_modeling_cvt.py | 278 +++++++ utils/documentation_tests.txt | 1 + 21 files changed, 1681 insertions(+), 2 deletions(-) create mode 100644 docs/source/en/model_doc/cvt.mdx create mode 100644 src/transformers/models/cvt/__init__.py create mode 100644 src/transformers/models/cvt/configuration_cvt.py create mode 100644 src/transformers/models/cvt/convert_cvt_original_pytorch_checkpoint_to_pytorch.py create mode 100644 src/transformers/models/cvt/modeling_cvt.py create mode 100644 tests/models/cvt/__init__.py create mode 100644 tests/models/cvt/test_modeling_cvt.py diff --git a/README.md b/README.md index 12453a52bc5a4..c046393239091 100644 --- a/README.md +++ b/README.md @@ -249,6 +249,7 @@ Current number of checkpoints: ![](https://img.shields.io/endpoint?url=https://h 1. **[ConvBERT](https://huggingface.co/docs/transformers/model_doc/convbert)** (from YituTech) released with the paper [ConvBERT: Improving BERT with Span-based Dynamic Convolution](https://arxiv.org/abs/2008.02496) by Zihang Jiang, Weihao Yu, Daquan Zhou, Yunpeng Chen, Jiashi Feng, Shuicheng Yan. 1. **[CPM](https://huggingface.co/docs/transformers/model_doc/cpm)** (from Tsinghua University) released with the paper [CPM: A Large-scale Generative Chinese Pre-trained Language Model](https://arxiv.org/abs/2012.00413) by Zhengyan Zhang, Xu Han, Hao Zhou, Pei Ke, Yuxian Gu, Deming Ye, Yujia Qin, Yusheng Su, Haozhe Ji, Jian Guan, Fanchao Qi, Xiaozhi Wang, Yanan Zheng, Guoyang Zeng, Huanqi Cao, Shengqi Chen, Daixuan Li, Zhenbo Sun, Zhiyuan Liu, Minlie Huang, Wentao Han, Jie Tang, Juanzi Li, Xiaoyan Zhu, Maosong Sun. 1. **[CTRL](https://huggingface.co/docs/transformers/model_doc/ctrl)** (from Salesforce) released with the paper [CTRL: A Conditional Transformer Language Model for Controllable Generation](https://arxiv.org/abs/1909.05858) by Nitish Shirish Keskar*, Bryan McCann*, Lav R. Varshney, Caiming Xiong and Richard Socher. +1. 
**[CvT](https://huggingface.co/docs/transformers/main/model_doc/cvt)** (from Microsoft) released with the paper [CvT: Introducing Convolutions to Vision Transformers](https://arxiv.org/abs/1909.05858) by Nitish Shirish Keskar*, Bryan McCann*, Lav R. Varshney, Caiming Xiong and Richard Socher. 1. **[Data2Vec](https://huggingface.co/docs/transformers/main/model_doc/data2vec)** (from Facebook) released with the paper [Data2Vec: A General Framework for Self-supervised Learning in Speech, Vision and Language](https://arxiv.org/abs/2202.03555) by Alexei Baevski, Wei-Ning Hsu, Qiantong Xu, Arun Babu, Jiatao Gu, Michael Auli. 1. **[DeBERTa](https://huggingface.co/docs/transformers/model_doc/deberta)** (from Microsoft) released with the paper [DeBERTa: Decoding-enhanced BERT with Disentangled Attention](https://arxiv.org/abs/2006.03654) by Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen. 1. **[DeBERTa-v2](https://huggingface.co/docs/transformers/model_doc/deberta-v2)** (from Microsoft) released with the paper [DeBERTa: Decoding-enhanced BERT with Disentangled Attention](https://arxiv.org/abs/2006.03654) by Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen. diff --git a/README_ko.md b/README_ko.md index 198c806f36db5..aab7a7c4bc2d6 100644 --- a/README_ko.md +++ b/README_ko.md @@ -230,6 +230,7 @@ Flax, PyTorch, TensorFlow 설치 페이지에서 이들을 conda로 설치하는 1. **[ConvNeXT](https://huggingface.co/docs/transformers/main/model_doc/convnext)** (from Facebook AI) released with the paper [A ConvNet for the 2020s](https://arxiv.org/abs/2201.03545) by Zhuang Liu, Hanzi Mao, Chao-Yuan Wu, Christoph Feichtenhofer, Trevor Darrell, Saining Xie. 1. **[CPM](https://huggingface.co/docs/transformers/model_doc/cpm)** (from Tsinghua University) released with the paper [CPM: A Large-scale Generative Chinese Pre-trained Language Model](https://arxiv.org/abs/2012.00413) by Zhengyan Zhang, Xu Han, Hao Zhou, Pei Ke, Yuxian Gu, Deming Ye, Yujia Qin, Yusheng Su, Haozhe Ji, Jian Guan, Fanchao Qi, Xiaozhi Wang, Yanan Zheng, Guoyang Zeng, Huanqi Cao, Shengqi Chen, Daixuan Li, Zhenbo Sun, Zhiyuan Liu, Minlie Huang, Wentao Han, Jie Tang, Juanzi Li, Xiaoyan Zhu, Maosong Sun. 1. **[CTRL](https://huggingface.co/docs/transformers/model_doc/ctrl)** (from Salesforce) released with the paper [CTRL: A Conditional Transformer Language Model for Controllable Generation](https://arxiv.org/abs/1909.05858) by Nitish Shirish Keskar*, Bryan McCann*, Lav R. Varshney, Caiming Xiong and Richard Socher. +1. **[CvT](https://huggingface.co/docs/transformers/main/model_doc/cvt)** (from Microsoft) released with the paper [CvT: Introducing Convolutions to Vision Transformers](https://arxiv.org/abs/1909.05858) by Nitish Shirish Keskar*, Bryan McCann*, Lav R. Varshney, Caiming Xiong and Richard Socher. 1. **[Data2Vec](https://huggingface.co/docs/transformers/main/model_doc/data2vec)** (from Facebook) released with the paper [Data2Vec: A General Framework for Self-supervised Learning in Speech, Vision and Language](https://arxiv.org/abs/2202.03555) by Alexei Baevski, Wei-Ning Hsu, Qiantong Xu, Arun Babu, Jiatao Gu, Michael Auli. 1. **[DeBERTa](https://huggingface.co/docs/transformers/model_doc/deberta)** (from Microsoft) released with the paper [DeBERTa: Decoding-enhanced BERT with Disentangled Attention](https://arxiv.org/abs/2006.03654) by Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen. 1. 
**[DeBERTa-v2](https://huggingface.co/docs/transformers/model_doc/deberta-v2)** (from Microsoft) released with the paper [DeBERTa: Decoding-enhanced BERT with Disentangled Attention](https://arxiv.org/abs/2006.03654) by Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen. diff --git a/README_zh-hans.md b/README_zh-hans.md index 4a697fbcf25c9..7031fd3570bf1 100644 --- a/README_zh-hans.md +++ b/README_zh-hans.md @@ -254,6 +254,7 @@ conda install -c huggingface transformers 1. **[ConvNeXT](https://huggingface.co/docs/transformers/main/model_doc/convnext)** (来自 Facebook AI) 伴随论文 [A ConvNet for the 2020s](https://arxiv.org/abs/2201.03545) 由 Zhuang Liu, Hanzi Mao, Chao-Yuan Wu, Christoph Feichtenhofer, Trevor Darrell, Saining Xie 发布。 1. **[CPM](https://huggingface.co/docs/transformers/model_doc/cpm)** (来自 Tsinghua University) 伴随论文 [CPM: A Large-scale Generative Chinese Pre-trained Language Model](https://arxiv.org/abs/2012.00413) 由 Zhengyan Zhang, Xu Han, Hao Zhou, Pei Ke, Yuxian Gu, Deming Ye, Yujia Qin, Yusheng Su, Haozhe Ji, Jian Guan, Fanchao Qi, Xiaozhi Wang, Yanan Zheng, Guoyang Zeng, Huanqi Cao, Shengqi Chen, Daixuan Li, Zhenbo Sun, Zhiyuan Liu, Minlie Huang, Wentao Han, Jie Tang, Juanzi Li, Xiaoyan Zhu, Maosong Sun 发布。 1. **[CTRL](https://huggingface.co/docs/transformers/model_doc/ctrl)** (来自 Salesforce) 伴随论文 [CTRL: A Conditional Transformer Language Model for Controllable Generation](https://arxiv.org/abs/1909.05858) 由 Nitish Shirish Keskar*, Bryan McCann*, Lav R. Varshney, Caiming Xiong and Richard Socher 发布。 +1. **[CvT](https://huggingface.co/docs/transformers/main/model_doc/cvt)** (来自 Microsoft) 伴随论文 [CvT: Introducing Convolutions to Vision Transformers](https://arxiv.org/abs/1909.05858) 由 Nitish Shirish Keskar*, Bryan McCann*, Lav R. Varshney, Caiming Xiong and Richard Socher 发布。 1. **[Data2Vec](https://huggingface.co/docs/transformers/main/model_doc/data2vec)** (来自 Facebook) 伴随论文 [Data2Vec: A General Framework for Self-supervised Learning in Speech, Vision and Language](https://arxiv.org/abs/2202.03555) 由 Alexei Baevski, Wei-Ning Hsu, Qiantong Xu, Arun Babu, Jiatao Gu, Michael Auli 发布。 1. **[DeBERTa](https://huggingface.co/docs/transformers/model_doc/deberta)** (来自 Microsoft) 伴随论文 [DeBERTa: Decoding-enhanced BERT with Disentangled Attention](https://arxiv.org/abs/2006.03654) 由 Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen 发布。 1. **[DeBERTa-v2](https://huggingface.co/docs/transformers/model_doc/deberta-v2)** (来自 Microsoft) 伴随论文 [DeBERTa: Decoding-enhanced BERT with Disentangled Attention](https://arxiv.org/abs/2006.03654) 由 Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen 发布。 diff --git a/README_zh-hant.md b/README_zh-hant.md index 5cce396ff53f8..5971ab404917f 100644 --- a/README_zh-hant.md +++ b/README_zh-hant.md @@ -266,6 +266,7 @@ conda install -c huggingface transformers 1. **[ConvNeXT](https://huggingface.co/docs/transformers/main/model_doc/convnext)** (from Facebook AI) released with the paper [A ConvNet for the 2020s](https://arxiv.org/abs/2201.03545) by Zhuang Liu, Hanzi Mao, Chao-Yuan Wu, Christoph Feichtenhofer, Trevor Darrell, Saining Xie. 1. 
**[CPM](https://huggingface.co/docs/transformers/model_doc/cpm)** (from Tsinghua University) released with the paper [CPM: A Large-scale Generative Chinese Pre-trained Language Model](https://arxiv.org/abs/2012.00413) by Zhengyan Zhang, Xu Han, Hao Zhou, Pei Ke, Yuxian Gu, Deming Ye, Yujia Qin, Yusheng Su, Haozhe Ji, Jian Guan, Fanchao Qi, Xiaozhi Wang, Yanan Zheng, Guoyang Zeng, Huanqi Cao, Shengqi Chen, Daixuan Li, Zhenbo Sun, Zhiyuan Liu, Minlie Huang, Wentao Han, Jie Tang, Juanzi Li, Xiaoyan Zhu, Maosong Sun. 1. **[CTRL](https://huggingface.co/docs/transformers/model_doc/ctrl)** (from Salesforce) released with the paper [CTRL: A Conditional Transformer Language Model for Controllable Generation](https://arxiv.org/abs/1909.05858) by Nitish Shirish Keskar*, Bryan McCann*, Lav R. Varshney, Caiming Xiong and Richard Socher. +1. **[CvT](https://huggingface.co/docs/transformers/main/model_doc/cvt)** (from Microsoft) released with the paper [CvT: Introducing Convolutions to Vision Transformers](https://arxiv.org/abs/1909.05858) by Nitish Shirish Keskar*, Bryan McCann*, Lav R. Varshney, Caiming Xiong and Richard Socher. 1. **[Data2Vec](https://huggingface.co/docs/transformers/main/model_doc/data2vec)** (from Facebook) released with the paper [Data2Vec: A General Framework for Self-supervised Learning in Speech, Vision and Language](https://arxiv.org/abs/2202.03555) by Alexei Baevski, Wei-Ning Hsu, Qiantong Xu, Arun Babu, Jiatao Gu, Michael Auli. 1. **[DeBERTa](https://huggingface.co/docs/transformers/model_doc/deberta)** (from Microsoft) released with the paper [DeBERTa: Decoding-enhanced BERT with Disentangled Attention](https://arxiv.org/abs/2006.03654) by Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen. 1. **[DeBERTa-v2](https://huggingface.co/docs/transformers/model_doc/deberta-v2)** (from Microsoft) released with the paper [DeBERTa: Decoding-enhanced BERT with Disentangled Attention](https://arxiv.org/abs/2006.03654) by Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen. diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 36580ecb2fcf3..cb67299cff4d4 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -192,6 +192,8 @@ title: CPM - local: model_doc/ctrl title: CTRL + - local: model_doc/cvt + title: CvT - local: model_doc/data2vec title: Data2Vec - local: model_doc/deberta diff --git a/docs/source/en/index.mdx b/docs/source/en/index.mdx index 176e0f2de889e..35d2a99440e79 100644 --- a/docs/source/en/index.mdx +++ b/docs/source/en/index.mdx @@ -72,6 +72,7 @@ The library currently contains JAX, PyTorch and TensorFlow implementations, pret 1. **[ConvBERT](model_doc/convbert)** (from YituTech) released with the paper [ConvBERT: Improving BERT with Span-based Dynamic Convolution](https://arxiv.org/abs/2008.02496) by Zihang Jiang, Weihao Yu, Daquan Zhou, Yunpeng Chen, Jiashi Feng, Shuicheng Yan. 1. **[CPM](model_doc/cpm)** (from Tsinghua University) released with the paper [CPM: A Large-scale Generative Chinese Pre-trained Language Model](https://arxiv.org/abs/2012.00413) by Zhengyan Zhang, Xu Han, Hao Zhou, Pei Ke, Yuxian Gu, Deming Ye, Yujia Qin, Yusheng Su, Haozhe Ji, Jian Guan, Fanchao Qi, Xiaozhi Wang, Yanan Zheng, Guoyang Zeng, Huanqi Cao, Shengqi Chen, Daixuan Li, Zhenbo Sun, Zhiyuan Liu, Minlie Huang, Wentao Han, Jie Tang, Juanzi Li, Xiaoyan Zhu, Maosong Sun. 1. 
**[CTRL](model_doc/ctrl)** (from Salesforce) released with the paper [CTRL: A Conditional Transformer Language Model for Controllable Generation](https://arxiv.org/abs/1909.05858) by Nitish Shirish Keskar*, Bryan McCann*, Lav R. Varshney, Caiming Xiong and Richard Socher. +1. **[CvT](model_doc/cvt)** (from Microsoft) released with the paper [CvT: Introducing Convolutions to Vision Transformers](https://arxiv.org/abs/1909.05858) by Nitish Shirish Keskar*, Bryan McCann*, Lav R. Varshney, Caiming Xiong and Richard Socher. 1. **[Data2Vec](model_doc/data2vec)** (from Facebook) released with the paper [Data2Vec: A General Framework for Self-supervised Learning in Speech, Vision and Language](https://arxiv.org/abs/2202.03555) by Alexei Baevski, Wei-Ning Hsu, Qiantong Xu, Arun Babu, Jiatao Gu, Michael Auli. 1. **[DeBERTa](model_doc/deberta)** (from Microsoft) released with the paper [DeBERTa: Decoding-enhanced BERT with Disentangled Attention](https://arxiv.org/abs/2006.03654) by Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen. 1. **[DeBERTa-v2](model_doc/deberta-v2)** (from Microsoft) released with the paper [DeBERTa: Decoding-enhanced BERT with Disentangled Attention](https://arxiv.org/abs/2006.03654) by Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen. @@ -192,6 +193,7 @@ Flax), PyTorch, and/or TensorFlow. | ConvBERT | ✅ | ✅ | ✅ | ✅ | ❌ | | ConvNext | ❌ | ❌ | ✅ | ✅ | ❌ | | CTRL | ✅ | ❌ | ✅ | ✅ | ❌ | +| CvT | ❌ | ❌ | ✅ | ❌ | ❌ | | Data2VecAudio | ❌ | ❌ | ✅ | ❌ | ❌ | | Data2VecText | ❌ | ❌ | ✅ | ❌ | ❌ | | Data2VecVision | ❌ | ❌ | ✅ | ✅ | ❌ | @@ -284,4 +286,4 @@ Flax), PyTorch, and/or TensorFlow. | YOLOS | ❌ | ❌ | ✅ | ❌ | ❌ | | YOSO | ❌ | ❌ | ✅ | ❌ | ❌ | - + \ No newline at end of file diff --git a/docs/source/en/model_doc/cvt.mdx b/docs/source/en/model_doc/cvt.mdx new file mode 100644 index 0000000000000..84be7e39a5507 --- /dev/null +++ b/docs/source/en/model_doc/cvt.mdx @@ -0,0 +1,53 @@ + + +# Convolutional Vision Transformer (CvT) + +## Overview + +The CvT model was proposed in [CvT: Introducing Convolutions to Vision Transformers](https://arxiv.org/abs/2103.15808) by Haiping Wu, Bin Xiao, Noel Codella, Mengchen Liu, Xiyang Dai, Lu Yuan and Lei Zhang. The Convolutional vision Transformer (CvT) improves the [Vision Transformer (ViT)](vit) in performance and efficiency by introducing convolutions into ViT to yield the best of both designs. + +The abstract from the paper is the following: + +*We present in this paper a new architecture, named Convolutional vision Transformer (CvT), that improves Vision Transformer (ViT) +in performance and efficiency by introducing convolutions into ViT to yield the best of both designs. This is accomplished through +two primary modifications: a hierarchy of Transformers containing a new convolutional token embedding, and a convolutional Transformer +block leveraging a convolutional projection. These changes introduce desirable properties of convolutional neural networks (CNNs) +to the ViT architecture (\ie shift, scale, and distortion invariance) while maintaining the merits of Transformers (\ie dynamic attention, +global context, and better generalization). We validate CvT by conducting extensive experiments, showing that this approach achieves +state-of-the-art performance over other Vision Transformers and ResNets on ImageNet-1k, with fewer parameters and lower FLOPs. In addition, +performance gains are maintained when pretrained on larger datasets (\eg ImageNet-22k) and fine-tuned to downstream tasks. 
Pre-trained on +ImageNet-22k, our CvT-W24 obtains a top-1 accuracy of 87.7\% on the ImageNet-1k val set. Finally, our results show that the positional encoding, +a crucial component in existing Vision Transformers, can be safely removed in our model, simplifying the design for higher resolution vision tasks.* + +Tips: + +- CvT models are regular Vision Transformers, but trained with convolutions. They outperform the [original model (ViT)](vit) when fine-tuned on ImageNet-1K and CIFAR-100. +- You can check out demo notebooks regarding inference as well as fine-tuning on custom data [here](https://github.com/NielsRogge/Transformers-Tutorials/tree/master/VisionTransformer) (you can just replace [`ViTFeatureExtractor`] by [`AutoFeatureExtractor`] and [`ViTForImageClassification`] by [`CvtForImageClassification`]). +- The available checkpoints are either (1) pre-trained on [ImageNet-22k](http://www.image-net.org/) (a collection of 14 million images and 22k classes) only, (2) also fine-tuned on ImageNet-22k or (3) also fine-tuned on [ImageNet-1k](http://www.image-net.org/challenges/LSVRC/2012/) (also referred to as ILSVRC 2012, a collection of 1.3 million + images and 1,000 classes). + +This model was contributed by [anugunj](https://huggingface.co/anugunj). The original code can be found [here](https://github.com/microsoft/CvT). + +## CvtConfig + +[[autodoc]] CvtConfig + +## CvtModel + +[[autodoc]] CvtModel + - forward + +## CvtForImageClassification + +[[autodoc]] CvtForImageClassification + - forward diff --git a/docs/source/en/serialization.mdx b/docs/source/en/serialization.mdx index 4ae35a96aebc0..2bb449240bb05 100644 --- a/docs/source/en/serialization.mdx +++ b/docs/source/en/serialization.mdx @@ -667,4 +667,4 @@ torch.neuron.trace(model, [token_tensor, segments_tensors]) This change enables Neuron SDK to trace the model and optimize it to run in Inf1 instances. To learn more about AWS Neuron SDK features, tools, example tutorials and latest updates, -please see the [AWS NeuronSDK documentation](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/index.html). +please see the [AWS NeuronSDK documentation](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/index.html). 
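The tips above boil down to reusing the standard image-classification recipe with [`AutoFeatureExtractor`] and [`CvtForImageClassification`]. A minimal inference sketch along those lines, assuming the `microsoft/cvt-13` checkpoint ships with a preprocessor config (the cats image URL is the one other vision examples in the docs rely on):

```python
import requests
import torch
from PIL import Image

from transformers import AutoFeatureExtractor, CvtForImageClassification

# an image of two cats from COCO; any RGB image works
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

feature_extractor = AutoFeatureExtractor.from_pretrained("microsoft/cvt-13")
model = CvtForImageClassification.from_pretrained("microsoft/cvt-13")

inputs = feature_extractor(images=image, return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits

# the checkpoint is fine-tuned on ImageNet-1k, so there are 1,000 classes
predicted_label = logits.argmax(-1).item()
print(model.config.id2label[predicted_label])  # should print something like "tabby, tabby cat"
```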
\ No newline at end of file diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 9f28d18fdf6f2..aff37abbec6aa 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -171,6 +171,7 @@ "models.convnext": ["CONVNEXT_PRETRAINED_CONFIG_ARCHIVE_MAP", "ConvNextConfig"], "models.cpm": [], "models.ctrl": ["CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP", "CTRLConfig", "CTRLTokenizer"], + "models.cvt": ["CVT_PRETRAINED_CONFIG_ARCHIVE_MAP", "CvtConfig"], "models.data2vec": [ "DATA2VEC_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP", "DATA2VEC_VISION_PRETRAINED_CONFIG_ARCHIVE_MAP", @@ -926,6 +927,14 @@ "CTRLPreTrainedModel", ] ) + _import_structure["models.cvt"].extend( + [ + "CVT_PRETRAINED_MODEL_ARCHIVE_LIST", + "CvtForImageClassification", + "CvtModel", + "CvtPreTrainedModel", + ] + ) _import_structure["models.data2vec"].extend( [ "DATA2VEC_AUDIO_PRETRAINED_MODEL_ARCHIVE_LIST", @@ -2694,6 +2703,7 @@ from .models.convbert import CONVBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, ConvBertConfig, ConvBertTokenizer from .models.convnext import CONVNEXT_PRETRAINED_CONFIG_ARCHIVE_MAP, ConvNextConfig from .models.ctrl import CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP, CTRLConfig, CTRLTokenizer + from .models.cvt import CVT_PRETRAINED_CONFIG_ARCHIVE_MAP, CvtConfig from .models.data2vec import ( DATA2VEC_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP, DATA2VEC_VISION_PRETRAINED_CONFIG_ARCHIVE_MAP, @@ -3345,6 +3355,12 @@ CTRLModel, CTRLPreTrainedModel, ) + from .models.cvt import ( + CVT_PRETRAINED_MODEL_ARCHIVE_LIST, + CvtForImageClassification, + CvtModel, + CvtPreTrainedModel, + ) from .models.data2vec import ( DATA2VEC_AUDIO_PRETRAINED_MODEL_ARCHIVE_LIST, DATA2VEC_TEXT_PRETRAINED_MODEL_ARCHIVE_LIST, diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index 90ce4b1e9b3c4..66910e3e0a53b 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -40,6 +40,7 @@ convnext, cpm, ctrl, + cvt, data2vec, deberta, deberta_v2, diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 3013f8d87d782..aa4b64fa7015c 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -44,6 +44,7 @@ ("convbert", "ConvBertConfig"), ("convnext", "ConvNextConfig"), ("ctrl", "CTRLConfig"), + ("cvt", "CvtConfig"), ("data2vec-audio", "Data2VecAudioConfig"), ("data2vec-text", "Data2VecTextConfig"), ("data2vec-vision", "Data2VecVisionConfig"), @@ -156,6 +157,7 @@ ("convbert", "CONVBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("convnext", "CONVNEXT_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("ctrl", "CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("cvt", "CVT_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("data2vec-audio", "DATA2VEC_AUDIO_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("data2vec-text", "DATA2VEC_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("data2vec-vision", "DATA2VEC_VISION_PRETRAINED_CONFIG_ARCHIVE_MAP"), @@ -261,6 +263,7 @@ ("convnext", "ConvNext"), ("cpm", "CPM"), ("ctrl", "CTRL"), + ("cvt", "CvT"), ("data2vec-audio", "Data2VecAudio"), ("data2vec-text", "Data2VecText"), ("data2vec-vision", "Data2VecVision"), diff --git a/src/transformers/models/auto/feature_extraction_auto.py b/src/transformers/models/auto/feature_extraction_auto.py index e133a3ada7d8a..5ba7b1228544d 100644 --- a/src/transformers/models/auto/feature_extraction_auto.py +++ b/src/transformers/models/auto/feature_extraction_auto.py @@ -40,6 +40,7 @@ ("beit", "BeitFeatureExtractor"), ("clip", "CLIPFeatureExtractor"), 
("convnext", "ConvNextFeatureExtractor"), + ("cvt", "ConvNextFeatureExtractor"), ("data2vec-audio", "Wav2Vec2FeatureExtractor"), ("data2vec-vision", "BeitFeatureExtractor"), ("deit", "DeiTFeatureExtractor"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index dd8324400f5f8..1e62a4ab8ed20 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -43,6 +43,7 @@ ("convbert", "ConvBertModel"), ("convnext", "ConvNextModel"), ("ctrl", "CTRLModel"), + ("cvt", "CvtModel"), ("data2vec-audio", "Data2VecAudioModel"), ("data2vec-text", "Data2VecTextModel"), ("data2vec-vision", "Data2VecVisionModel"), @@ -299,6 +300,7 @@ # Model for Image Classification mapping ("beit", "BeitForImageClassification"), ("convnext", "ConvNextForImageClassification"), + ("cvt", "CvtForImageClassification"), ("data2vec-vision", "Data2VecVisionForImageClassification"), ("deit", ("DeiTForImageClassification", "DeiTForImageClassificationWithTeacher")), ("imagegpt", "ImageGPTForImageClassification"), diff --git a/src/transformers/models/cvt/__init__.py b/src/transformers/models/cvt/__init__.py new file mode 100644 index 0000000000000..5279f89f21584 --- /dev/null +++ b/src/transformers/models/cvt/__init__.py @@ -0,0 +1,61 @@ +# flake8: noqa +# There's no way to ignore "F401 '...' imported but unused" warnings in this +# module, but to preserve other warnings. So, don't check this module at all. + +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available + + +_import_structure = { + "configuration_cvt": ["CVT_PRETRAINED_CONFIG_ARCHIVE_MAP", "CvtConfig"], +} + + +try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["modeling_cvt"] = [ + "CVT_PRETRAINED_MODEL_ARCHIVE_LIST", + "CvtForImageClassification", + "CvtModel", + "CvtPreTrainedModel", + ] + + +if TYPE_CHECKING: + from .configuration_cvt import CVT_PRETRAINED_CONFIG_ARCHIVE_MAP, CvtConfig + + try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .modeling_cvt import ( + CVT_PRETRAINED_MODEL_ARCHIVE_LIST, + CvtForImageClassification, + CvtModel, + CvtPreTrainedModel, + ) + +else: + import sys + + sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) diff --git a/src/transformers/models/cvt/configuration_cvt.py b/src/transformers/models/cvt/configuration_cvt.py new file mode 100644 index 0000000000000..e1e633e73b57b --- /dev/null +++ b/src/transformers/models/cvt/configuration_cvt.py @@ -0,0 +1,147 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" CvT model configuration""" + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + +CVT_PRETRAINED_CONFIG_ARCHIVE_MAP = { + "microsoft/cvt-13": "https://huggingface.co/microsoft/cvt-13/resolve/main/config.json", + # See all Cvt models at https://huggingface.co/models?filter=cvt +} + + +class CvtConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`CvtModel`]. It is used to instantiate a CvT model + according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the CvT + [microsoft/cvt-13](https://huggingface.co/microsoft/cvt-13) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + num_channels (`int`, *optional*, defaults to 3): + The number of input channels. + patch_sizes (`List[int]`, *optional*, defaults to `[7, 3, 3]`): + The kernel size of each encoder's patch embedding. + patch_stride (`List[int]`, *optional*, defaults to `[4, 2, 2]`): + The stride size of each encoder's patch embedding. + patch_padding (`List[int]`, *optional*, defaults to `[2, 1, 1]`): + The padding size of each encoder's patch embedding. + embed_dim (`List[int]`, *optional*, defaults to `[64, 192, 384]`): + Dimension of each of the encoder blocks. + num_heads (`List[int]`, *optional*, defaults to `[1, 3, 6]`): + Number of attention heads for each attention layer in each block of the Transformer encoder. + depth (`List[int]`, *optional*, defaults to `[1, 2, 10]`): + The number of layers in each encoder block. + mlp_ratios (`List[float]`, *optional*, defaults to `[4.0, 4.0, 4.0, 4.0]`): + Ratio of the size of the hidden layer compared to the size of the input layer of the Mix FFNs in the + encoder blocks. + attention_drop_rate (`List[float]`, *optional*, defaults to `[0.0, 0.0, 0.0]`): + The dropout ratio for the attention probabilities. + drop_rate (`List[float]`, *optional*, defaults to `[0.0, 0.0, 0.0]`): + The dropout ratio for the patch embeddings probabilities. + drop_path_rate (`List[float]`, *optional*, defaults to `[0.0, 0.0, 0.1]`): + The dropout probability for stochastic depth, used in the blocks of the Transformer encoder. + qkv_bias (`List[bool]`, *optional*, defaults to `[True, True, True]`): + The bias bool for query, key and value in attentions + cls_token (`List[bool]`, *optional*, defaults to `[False, False, True]`): + Whether or not to add a classification token to the output of each of the last 3 stages. + qkv_projection_method (`List[string]`, *optional*, defaults to ["dw_bn", "dw_bn", "dw_bn"]`): + The projection method for query, key and value Default is depth-wise convolutions with batch norm. For + Linear projection use "avg". 
+ kernel_qkv (`List[int]`, *optional*, defaults to `[3, 3, 3]`): + The kernel size for query, key and value in attention layer + padding_kv (`List[int]`, *optional*, defaults to `[1, 1, 1]`): + The padding size for key and value in attention layer + stride_kv (`List[int]`, *optional*, defaults to `[2, 2, 2]`): + The stride size for key and value in attention layer + padding_q (`List[int]`, *optional*, defaults to `[1, 1, 1]`): + The padding size for query in attention layer + stride_q (`List[int]`, *optional*, defaults to `[1, 1, 1]`): + The stride size for query in attention layer + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layer_norm_eps (`float`, *optional*, defaults to 1e-6): + The epsilon used by the layer normalization layers. + + Example: + + ```python + >>> from transformers import CvtModel, CvtConfig + + >>> # Initializing a Cvt msft/cvt style configuration + >>> configuration = CvtConfig() + + >>> # Initializing a model from the msft/cvt style configuration + >>> model = CvtModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + model_type = "cvt" + + def __init__( + self, + num_channels=3, + patch_sizes=[7, 3, 3], + patch_stride=[4, 2, 2], + patch_padding=[2, 1, 1], + embed_dim=[64, 192, 384], + num_heads=[1, 3, 6], + depth=[1, 2, 10], + mlp_ratio=[4.0, 4.0, 4.0], + attention_drop_rate=[0.0, 0.0, 0.0], + drop_rate=[0.0, 0.0, 0.0], + drop_path_rate=[0.0, 0.0, 0.1], + qkv_bias=[True, True, True], + cls_token=[False, False, True], + qkv_projection_method=["dw_bn", "dw_bn", "dw_bn"], + kernel_qkv=[3, 3, 3], + padding_kv=[1, 1, 1], + stride_kv=[2, 2, 2], + padding_q=[1, 1, 1], + stride_q=[1, 1, 1], + initializer_range=0.02, + layer_norm_eps=1e-12, + **kwargs + ): + super().__init__(**kwargs) + self.num_channels = num_channels + self.patch_sizes = patch_sizes + self.patch_stride = patch_stride + self.patch_padding = patch_padding + self.embed_dim = embed_dim + self.num_heads = num_heads + self.depth = depth + self.mlp_ratio = mlp_ratio + self.attention_drop_rate = attention_drop_rate + self.drop_rate = drop_rate + self.drop_path_rate = drop_path_rate + self.qkv_bias = qkv_bias + self.cls_token = cls_token + self.qkv_projection_method = qkv_projection_method + self.kernel_qkv = kernel_qkv + self.padding_kv = padding_kv + self.stride_kv = stride_kv + self.padding_q = padding_q + self.stride_q = stride_q + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps diff --git a/src/transformers/models/cvt/convert_cvt_original_pytorch_checkpoint_to_pytorch.py b/src/transformers/models/cvt/convert_cvt_original_pytorch_checkpoint_to_pytorch.py new file mode 100644 index 0000000000000..ae0112ec12588 --- /dev/null +++ b/src/transformers/models/cvt/convert_cvt_original_pytorch_checkpoint_to_pytorch.py @@ -0,0 +1,349 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
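The conversion script further down maps the three released model sizes onto `depth`, `num_heads` and `embed_dim`. A short sketch of the corresponding `CvtConfig` overrides (the values are taken from that script; everything else keeps the defaults listed above):

```python
from transformers import CvtConfig, CvtForImageClassification

# CvT-13 is the default configuration (1 + 2 + 10 = 13 layers)
cvt13_config = CvtConfig()

# CvT-21 only changes the number of layers per stage (1 + 4 + 16 = 21)
cvt21_config = CvtConfig(depth=[1, 4, 16])

# CvT-W24 is the wide variant: more layers and wider embeddings (2 + 2 + 20 = 24)
cvt_w24_config = CvtConfig(
    depth=[2, 2, 20],
    num_heads=[3, 12, 16],
    embed_dim=[192, 768, 1024],
)

model = CvtForImageClassification(cvt_w24_config)  # randomly initialized
```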
+# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert CvT checkpoints from the original repository. + +URL: https://github.com/microsoft/CvT""" + + +import argparse +import json +from collections import OrderedDict + +import torch + +from huggingface_hub import cached_download, hf_hub_url +from transformers import AutoFeatureExtractor, CvtConfig, CvtForImageClassification + + +def embeddings(idx): + """ + The function helps in renaming embedding layer weights. + + Args: + idx: stage number in original model + """ + embed = [] + embed.append( + ( + f"cvt.encoder.stages.{idx}.embedding.convolution_embeddings.projection.weight", + f"stage{idx}.patch_embed.proj.weight", + ) + ) + embed.append( + ( + f"cvt.encoder.stages.{idx}.embedding.convolution_embeddings.projection.bias", + f"stage{idx}.patch_embed.proj.bias", + ) + ) + embed.append( + ( + f"cvt.encoder.stages.{idx}.embedding.convolution_embeddings.normalization.weight", + f"stage{idx}.patch_embed.norm.weight", + ) + ) + embed.append( + ( + f"cvt.encoder.stages.{idx}.embedding.convolution_embeddings.normalization.bias", + f"stage{idx}.patch_embed.norm.bias", + ) + ) + return embed + + +def attention(idx, cnt): + """ + The function helps in renaming attention block layers weights. + + Args: + idx: stage number in original model + cnt: count of blocks in each stage + """ + attention_weights = [] + attention_weights.append( + ( + f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_query.convolution_projection.convolution.weight", + f"stage{idx}.blocks.{cnt}.attn.conv_proj_q.conv.weight", + ) + ) + attention_weights.append( + ( + f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_query.convolution_projection.normalization.weight", + f"stage{idx}.blocks.{cnt}.attn.conv_proj_q.bn.weight", + ) + ) + attention_weights.append( + ( + f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_query.convolution_projection.normalization.bias", + f"stage{idx}.blocks.{cnt}.attn.conv_proj_q.bn.bias", + ) + ) + attention_weights.append( + ( + f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_query.convolution_projection.normalization.running_mean", + f"stage{idx}.blocks.{cnt}.attn.conv_proj_q.bn.running_mean", + ) + ) + attention_weights.append( + ( + f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_query.convolution_projection.normalization.running_var", + f"stage{idx}.blocks.{cnt}.attn.conv_proj_q.bn.running_var", + ) + ) + attention_weights.append( + ( + f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_query.convolution_projection.normalization.num_batches_tracked", + f"stage{idx}.blocks.{cnt}.attn.conv_proj_q.bn.num_batches_tracked", + ) + ) + attention_weights.append( + ( + f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_key.convolution_projection.convolution.weight", + f"stage{idx}.blocks.{cnt}.attn.conv_proj_k.conv.weight", + ) + ) + attention_weights.append( + ( + f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_key.convolution_projection.normalization.weight", + f"stage{idx}.blocks.{cnt}.attn.conv_proj_k.bn.weight", + ) + ) + attention_weights.append( + ( + f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_key.convolution_projection.normalization.bias", + 
f"stage{idx}.blocks.{cnt}.attn.conv_proj_k.bn.bias", + ) + ) + attention_weights.append( + ( + f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_key.convolution_projection.normalization.running_mean", + f"stage{idx}.blocks.{cnt}.attn.conv_proj_k.bn.running_mean", + ) + ) + attention_weights.append( + ( + f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_key.convolution_projection.normalization.running_var", + f"stage{idx}.blocks.{cnt}.attn.conv_proj_k.bn.running_var", + ) + ) + attention_weights.append( + ( + f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_key.convolution_projection.normalization.num_batches_tracked", + f"stage{idx}.blocks.{cnt}.attn.conv_proj_k.bn.num_batches_tracked", + ) + ) + attention_weights.append( + ( + f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_value.convolution_projection.convolution.weight", + f"stage{idx}.blocks.{cnt}.attn.conv_proj_v.conv.weight", + ) + ) + attention_weights.append( + ( + f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_value.convolution_projection.normalization.weight", + f"stage{idx}.blocks.{cnt}.attn.conv_proj_v.bn.weight", + ) + ) + attention_weights.append( + ( + f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_value.convolution_projection.normalization.bias", + f"stage{idx}.blocks.{cnt}.attn.conv_proj_v.bn.bias", + ) + ) + attention_weights.append( + ( + f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_value.convolution_projection.normalization.running_mean", + f"stage{idx}.blocks.{cnt}.attn.conv_proj_v.bn.running_mean", + ) + ) + attention_weights.append( + ( + f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_value.convolution_projection.normalization.running_var", + f"stage{idx}.blocks.{cnt}.attn.conv_proj_v.bn.running_var", + ) + ) + attention_weights.append( + ( + f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_value.convolution_projection.normalization.num_batches_tracked", + f"stage{idx}.blocks.{cnt}.attn.conv_proj_v.bn.num_batches_tracked", + ) + ) + attention_weights.append( + ( + f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.projection_query.weight", + f"stage{idx}.blocks.{cnt}.attn.proj_q.weight", + ) + ) + attention_weights.append( + ( + f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.projection_query.bias", + f"stage{idx}.blocks.{cnt}.attn.proj_q.bias", + ) + ) + attention_weights.append( + ( + f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.projection_key.weight", + f"stage{idx}.blocks.{cnt}.attn.proj_k.weight", + ) + ) + attention_weights.append( + ( + f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.projection_key.bias", + f"stage{idx}.blocks.{cnt}.attn.proj_k.bias", + ) + ) + attention_weights.append( + ( + f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.projection_value.weight", + f"stage{idx}.blocks.{cnt}.attn.proj_v.weight", + ) + ) + attention_weights.append( + ( + f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.projection_value.bias", + f"stage{idx}.blocks.{cnt}.attn.proj_v.bias", + ) + ) + attention_weights.append( + ( + f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.output.dense.weight", + f"stage{idx}.blocks.{cnt}.attn.proj.weight", + ) + ) + attention_weights.append( + ( + 
f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.output.dense.bias", + f"stage{idx}.blocks.{cnt}.attn.proj.bias", + ) + ) + attention_weights.append( + (f"cvt.encoder.stages.{idx}.layers.{cnt}.intermediate.dense.weight", f"stage{idx}.blocks.{cnt}.mlp.fc1.weight") + ) + attention_weights.append( + (f"cvt.encoder.stages.{idx}.layers.{cnt}.intermediate.dense.bias", f"stage{idx}.blocks.{cnt}.mlp.fc1.bias") + ) + attention_weights.append( + (f"cvt.encoder.stages.{idx}.layers.{cnt}.output.dense.weight", f"stage{idx}.blocks.{cnt}.mlp.fc2.weight") + ) + attention_weights.append( + (f"cvt.encoder.stages.{idx}.layers.{cnt}.output.dense.bias", f"stage{idx}.blocks.{cnt}.mlp.fc2.bias") + ) + attention_weights.append( + (f"cvt.encoder.stages.{idx}.layers.{cnt}.layernorm_before.weight", f"stage{idx}.blocks.{cnt}.norm1.weight") + ) + attention_weights.append( + (f"cvt.encoder.stages.{idx}.layers.{cnt}.layernorm_before.bias", f"stage{idx}.blocks.{cnt}.norm1.bias") + ) + attention_weights.append( + (f"cvt.encoder.stages.{idx}.layers.{cnt}.layernorm_after.weight", f"stage{idx}.blocks.{cnt}.norm2.weight") + ) + attention_weights.append( + (f"cvt.encoder.stages.{idx}.layers.{cnt}.layernorm_after.bias", f"stage{idx}.blocks.{cnt}.norm2.bias") + ) + return attention_weights + + +def cls_token(idx): + """ + Function helps in renaming cls_token weights + """ + token = [] + token.append((f"cvt.encoder.stages.{idx}.cls_token", "stage2.cls_token")) + return token + + +def final(): + """ + Function helps in renaming final classification layer + """ + head = [] + head.append(("layernorm.weight", "norm.weight")) + head.append(("layernorm.bias", "norm.bias")) + head.append(("classifier.weight", "head.weight")) + head.append(("classifier.bias", "head.bias")) + return head + + +def convert_cvt_checkpoint(cvt_file, pytorch_dump_folder): + """ + Fucntion to convert the microsoft cvt checkpoint to huggingface checkpoint + """ + img_labels_file = "imagenet-1k-id2label.json" + num_labels = 1000 + + repo_id = "datasets/huggingface/label-files" + num_labels = num_labels + id2label = json.load(open(cached_download(hf_hub_url(repo_id, img_labels_file)), "r")) + id2label = {int(k): v for k, v in id2label.items()} + + id2label = id2label + label2id = {v: k for k, v in id2label.items()} + + config = config = CvtConfig(num_labels=num_labels, id2label=id2label, label2id=label2id) + + # For depth size 13 (13 = 1+2+10) + if cvt_file.rsplit("/", 1)[-1][4:6] == "13": + config.depth = [1, 2, 10] + + # For depth size 21 (21 = 1+4+16) + elif cvt_file.rsplit("/", 1)[-1][4:6] == "21": + config.depth = [1, 4, 16] + + # For wide cvt (similar to wide-resnet) depth size 24 (w24 = 2 + 2 20) + else: + config.depth = [2, 2, 20] + config.num_heads = [3, 12, 16] + config.embed_dim = [192, 768, 1024] + + model = CvtForImageClassification(config) + feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/convnext-base-224-22k-1k") + original_weights = torch.load(cvt_file, map_location=torch.device("cpu")) + + huggingface_weights = OrderedDict() + list_of_state_dict = [] + + for idx in range(config.num_stages): + if config.cls_token[idx]: + list_of_state_dict = list_of_state_dict + cls_token(idx) + list_of_state_dict = list_of_state_dict + embeddings(idx) + for cnt in range(config.depth[idx]): + list_of_state_dict = list_of_state_dict + attention(idx, cnt) + + list_of_state_dict = list_of_state_dict + final() + for gg in list_of_state_dict: + print(gg) + for i in range(len(list_of_state_dict)): + huggingface_weights[list_of_state_dict[i][0]] = 
original_weights[list_of_state_dict[i][1]] + + model.load_state_dict(huggingface_weights) + model.save_pretrained(pytorch_dump_folder) + feature_extractor.save_pretrained(pytorch_dump_folder) + + +# Download the weights from zoo: https://1drv.ms/u/s!AhIXJn_J-blW9RzF3rMW7SsLHa8h?e=blQ0Al + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--cvt_name", + default="cvt-13", + type=str, + help="Name of the cvt model you'd like to convert.", + ) + parser.add_argument( + "--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory." + ) + + args = parser.parse_args() + convert_cvt_checkpoint(args.cvt_name, args.pytorch_dump_folder_path) diff --git a/src/transformers/models/cvt/modeling_cvt.py b/src/transformers/models/cvt/modeling_cvt.py new file mode 100644 index 0000000000000..154ad52faa1af --- /dev/null +++ b/src/transformers/models/cvt/modeling_cvt.py @@ -0,0 +1,735 @@ +# coding=utf-8 +# Copyright 2022 Microsoft Research and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch CvT model.""" + + +import collections.abc +from dataclasses import dataclass +from typing import Optional, Tuple + +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward +from ...modeling_outputs import ImageClassifierOutput, ModelOutput +from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer +from ...utils import logging +from .configuration_cvt import CvtConfig + + +logger = logging.get_logger(__name__) + +# General docstring +_CONFIG_FOR_DOC = "CvtConfig" +_FEAT_EXTRACTOR_FOR_DOC = "AutoFeatureExtractor" + +# Base docstring +_CHECKPOINT_FOR_DOC = "microsoft/cvt-13" +_EXPECTED_OUTPUT_SHAPE = [1, 384, 14, 14] + +# Image classification docstring +_IMAGE_CLASS_CHECKPOINT = "microsoft/cvt-13" +_IMAGE_CLASS_EXPECTED_OUTPUT = "tabby, tabby cat" + + +CVT_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "microsoft/cvt-13", + "microsoft/cvt-13-384-1k", + "microsoft/cvt-13-384-22k", + "microsoft/cvt-21", + "microsoft/cvt-21-384-1k", + "microsoft/cvt-21-384-22k", + # See all Cvt models at https://huggingface.co/models?filter=cvt +] + + +@dataclass +class BaseModelOutputWithCLSToken(ModelOutput): + """ + Base class for model's outputs, with potential hidden states and attentions. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + cls_token_value (`torch.FloatTensor` of shape `(batch_size, 1, hidden_size)`): + Classification token at the output of the last layer of the model. 
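To make the docstring constants above concrete: a small sketch of what the base model is expected to return for a 224x224 input, assuming the `microsoft/cvt-13` checkpoint (the commented shapes follow `_EXPECTED_OUTPUT_SHAPE` and the default `cls_token=[False, False, True]` setting):

```python
import torch

from transformers import CvtModel

model = CvtModel.from_pretrained("microsoft/cvt-13")

# dummy batch with a single 224x224 RGB image
pixel_values = torch.randn(1, 3, 224, 224)

with torch.no_grad():
    outputs = model(pixel_values)

# the last stage keeps a 14x14 spatial grid with 384 channels
print(outputs.last_hidden_state.shape)  # expected: torch.Size([1, 384, 14, 14])

# only the last stage prepends a classification token, returned separately
print(outputs.cls_token_value.shape)  # expected: torch.Size([1, 1, 384])
```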
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer + plus the initial embedding outputs. + """ + + last_hidden_state: torch.FloatTensor = None + cls_token_value: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + + +# Copied from transformers.models.convnext.modeling_convnext.drop_path +def drop_path(x, drop_prob: float = 0.0, training: bool = False): + """ + Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). This is the same as the + DropConnect impl I created for EfficientNet, etc networks, however, the original name is misleading as 'Drop + Connect' is a different form of dropout in a separate paper... See discussion: + https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the layer and + argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the argument. + """ + if drop_prob == 0.0 or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) + random_tensor.floor_() # binarize + output = x.div(keep_prob) * random_tensor + return output + + +# Copied from transformers.models.convnext.modeling_convnext.ConvNextDropPath +class CvtDropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob=None): + super().__init__() + self.drop_prob = drop_prob + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return drop_path(x, self.drop_prob, self.training) + + +class CvtEmbeddings(nn.Module): + """ + Construct the CvT embeddings. + """ + + def __init__(self, patch_size, num_channels, embed_dim, stride, padding, dropout_rate): + super().__init__() + self.convolution_embeddings = CvtConvEmbeddings( + patch_size=patch_size, num_channels=num_channels, embed_dim=embed_dim, stride=stride, padding=padding + ) + self.dropout = nn.Dropout(dropout_rate) + + def forward(self, pixel_values): + hidden_state = self.convolution_embeddings(pixel_values) + hidden_state = self.dropout(hidden_state) + return hidden_state + + +class CvtConvEmbeddings(nn.Module): + """ + Image to Conv Embedding. 
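Each stage's convolutional token embedding downsamples the spatial grid with a strided convolution. With the defaults `patch_sizes=[7, 3, 3]`, `patch_stride=[4, 2, 2]` and `patch_padding=[2, 1, 1]`, a 224x224 input shrinks to 56x56, then 28x28, then 14x14 tokens, which is where the `[1, 384, 14, 14]` expected output shape comes from. A quick arithmetic check using the standard convolution output-size formula:

```python
# out = (in + 2 * padding - kernel) // stride + 1 for each CvtConvEmbeddings layer
patch_sizes = [7, 3, 3]
patch_stride = [4, 2, 2]
patch_padding = [2, 1, 1]
embed_dim = [64, 192, 384]

size = 224
for stage, (kernel, stride, padding) in enumerate(zip(patch_sizes, patch_stride, patch_padding)):
    size = (size + 2 * padding - kernel) // stride + 1
    print(f"stage {stage}: {size}x{size} tokens, hidden size {embed_dim[stage]}")

# stage 0: 56x56 tokens, hidden size 64
# stage 1: 28x28 tokens, hidden size 192
# stage 2: 14x14 tokens, hidden size 384
```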
+ """ + + def __init__(self, patch_size, num_channels, embed_dim, stride, padding): + super().__init__() + patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size) + self.patch_size = patch_size + self.projection = nn.Conv2d(num_channels, embed_dim, kernel_size=patch_size, stride=stride, padding=padding) + self.normalization = nn.LayerNorm(embed_dim) + + def forward(self, pixel_values): + pixel_values = self.projection(pixel_values) + batch_size, num_channels, height, width = pixel_values.shape + hidden_size = height * width + # rearrange "b c h w -> b (h w) c" + pixel_values = pixel_values.view(batch_size, num_channels, hidden_size).permute(0, 2, 1) + if self.normalization: + pixel_values = self.normalization(pixel_values) + # rearrange "b (h w) c" -> b c h w" + pixel_values = pixel_values.permute(0, 2, 1).view(batch_size, num_channels, height, width) + return pixel_values + + +class CvtSelfAttentionConvProjection(nn.Module): + def __init__(self, embed_dim, kernel_size, padding, stride): + super().__init__() + self.convolution = nn.Conv2d( + embed_dim, + embed_dim, + kernel_size=kernel_size, + padding=padding, + stride=stride, + bias=False, + groups=embed_dim, + ) + self.normalization = nn.BatchNorm2d(embed_dim) + + def forward(self, hidden_state): + hidden_state = self.convolution(hidden_state) + hidden_state = self.normalization(hidden_state) + return hidden_state + + +class CvtSelfAttentionLinearProjection(nn.Module): + def forward(self, hidden_state): + batch_size, num_channels, height, width = hidden_state.shape + hidden_size = height * width + # rearrange " b c h w -> b (h w) c" + hidden_state = hidden_state.view(batch_size, num_channels, hidden_size).permute(0, 2, 1) + return hidden_state + + +class CvtSelfAttentionProjection(nn.Module): + def __init__(self, embed_dim, kernel_size, padding, stride, projection_method="dw_bn"): + super().__init__() + if projection_method == "dw_bn": + self.convolution_projection = CvtSelfAttentionConvProjection(embed_dim, kernel_size, padding, stride) + self.linear_projection = CvtSelfAttentionLinearProjection() + + def forward(self, hidden_state): + hidden_state = self.convolution_projection(hidden_state) + hidden_state = self.linear_projection(hidden_state) + return hidden_state + + +class CvtSelfAttention(nn.Module): + def __init__( + self, + num_heads, + embed_dim, + kernel_size, + padding_q, + padding_kv, + stride_q, + stride_kv, + qkv_projection_method, + qkv_bias, + attention_drop_rate, + with_cls_token=True, + **kwargs + ): + super().__init__() + self.scale = embed_dim**-0.5 + self.with_cls_token = with_cls_token + self.embed_dim = embed_dim + self.num_heads = num_heads + + self.convolution_projection_query = CvtSelfAttentionProjection( + embed_dim, + kernel_size, + padding_q, + stride_q, + projection_method="linear" if qkv_projection_method == "avg" else qkv_projection_method, + ) + self.convolution_projection_key = CvtSelfAttentionProjection( + embed_dim, kernel_size, padding_kv, stride_kv, projection_method=qkv_projection_method + ) + self.convolution_projection_value = CvtSelfAttentionProjection( + embed_dim, kernel_size, padding_kv, stride_kv, projection_method=qkv_projection_method + ) + + self.projection_query = nn.Linear(embed_dim, embed_dim, bias=qkv_bias) + self.projection_key = nn.Linear(embed_dim, embed_dim, bias=qkv_bias) + self.projection_value = nn.Linear(embed_dim, embed_dim, bias=qkv_bias) + + self.dropout = nn.Dropout(attention_drop_rate) + + def 
rearrange_for_multi_head_attention(self, hidden_state): + batch_size, hidden_size, _ = hidden_state.shape + head_dim = self.embed_dim // self.num_heads + # rearrange 'b t (h d) -> b h t d' + return hidden_state.view(batch_size, hidden_size, self.num_heads, head_dim).permute(0, 2, 1, 3) + + def forward(self, hidden_state, height, width): + if self.with_cls_token: + cls_token, hidden_state = torch.split(hidden_state, [1, height * width], 1) + batch_size, hidden_size, num_channels = hidden_state.shape + # rearrange "b (h w) c -> b c h w" + hidden_state = hidden_state.permute(0, 2, 1).view(batch_size, num_channels, height, width) + + key = self.convolution_projection_key(hidden_state) + query = self.convolution_projection_query(hidden_state) + value = self.convolution_projection_value(hidden_state) + + if self.with_cls_token: + query = torch.cat((cls_token, query), dim=1) + key = torch.cat((cls_token, key), dim=1) + value = torch.cat((cls_token, value), dim=1) + + head_dim = self.embed_dim // self.num_heads + + query = self.rearrange_for_multi_head_attention(self.projection_query(query)) + key = self.rearrange_for_multi_head_attention(self.projection_key(key)) + value = self.rearrange_for_multi_head_attention(self.projection_value(value)) + + attention_score = torch.einsum("bhlk,bhtk->bhlt", [query, key]) * self.scale + attention_probs = torch.nn.functional.softmax(attention_score, dim=-1) + attention_probs = self.dropout(attention_probs) + + context = torch.einsum("bhlt,bhtv->bhlv", [attention_probs, value]) + # rearrange"b h t d -> b t (h d)" + _, _, hidden_size, _ = context.shape + context = context.permute(0, 2, 1, 3).contiguous().view(batch_size, hidden_size, self.num_heads * head_dim) + return context + + +class CvtSelfOutput(nn.Module): + """ + The residual connection is defined in CvtLayer instead of here (as is the case with other models), due to the + layernorm applied before each block. 
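Because the key/value projections use `stride_kv=2` while the query projection keeps stride 1, each layer attends from a full-resolution query sequence to a roughly 4x shorter key/value sequence. A worked shape example for the last stage of the default configuration (14x14 grid plus one cls token); the numbers are derived with the same convolution arithmetic as above, not printed by the module itself:

```python
# key/value sequence reduction in CvtSelfAttention, last stage of CvT-13
# (kernel_qkv=3, padding=1, stride_q=1, stride_kv=2, with a cls token)
height = width = 14


def conv_out(size, kernel=3, padding=1, stride=1):
    return (size + 2 * padding - kernel) // stride + 1


q_tokens = conv_out(height, stride=1) * conv_out(width, stride=1)  # 14 * 14 = 196
kv_tokens = conv_out(height, stride=2) * conv_out(width, stride=2)  # 7 * 7 = 49

# the cls token is split off before the convolutional projections and concatenated back
print("query length:", q_tokens + 1)  # 197
print("key/value length:", kv_tokens + 1)  # 50
# attention scores therefore have shape (batch, num_heads, 197, 50) instead of (batch, num_heads, 197, 197)
```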
+ """ + + def __init__(self, embed_dim, drop_rate): + super().__init__() + self.dense = nn.Linear(embed_dim, embed_dim) + self.dropout = nn.Dropout(drop_rate) + + def forward(self, hidden_state, input_tensor): + hidden_state = self.dense(hidden_state) + hidden_state = self.dropout(hidden_state) + return hidden_state + + +class CvtAttention(nn.Module): + def __init__( + self, + num_heads, + embed_dim, + kernel_size, + padding_q, + padding_kv, + stride_q, + stride_kv, + qkv_projection_method, + qkv_bias, + attention_drop_rate, + drop_rate, + with_cls_token=True, + ): + super().__init__() + self.attention = CvtSelfAttention( + num_heads, + embed_dim, + kernel_size, + padding_q, + padding_kv, + stride_q, + stride_kv, + qkv_projection_method, + qkv_bias, + attention_drop_rate, + with_cls_token, + ) + self.output = CvtSelfOutput(embed_dim, drop_rate) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads + ) + + # Prune linear layers + self.attention.query = prune_linear_layer(self.attention.query, index) + self.attention.key = prune_linear_layer(self.attention.key, index) + self.attention.value = prune_linear_layer(self.attention.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads) + self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward(self, hidden_state, height, width): + self_output = self.attention(hidden_state, height, width) + attention_output = self.output(self_output, hidden_state) + return attention_output + + +class CvtIntermediate(nn.Module): + def __init__(self, embed_dim, mlp_ratio): + super().__init__() + self.dense = nn.Linear(embed_dim, int(embed_dim * mlp_ratio)) + self.activation = nn.GELU() + + def forward(self, hidden_state): + hidden_state = self.dense(hidden_state) + hidden_state = self.activation(hidden_state) + return hidden_state + + +class CvtOutput(nn.Module): + def __init__(self, embed_dim, mlp_ratio, drop_rate): + super().__init__() + self.dense = nn.Linear(int(embed_dim * mlp_ratio), embed_dim) + self.dropout = nn.Dropout(drop_rate) + + def forward(self, hidden_state, input_tensor): + hidden_state = self.dense(hidden_state) + hidden_state = self.dropout(hidden_state) + hidden_state = hidden_state + input_tensor + return hidden_state + + +class CvtLayer(nn.Module): + """ + CvtLayer composed by attention layers, normalization and multi-layer perceptrons (mlps). 
+ """ + + def __init__( + self, + num_heads, + embed_dim, + kernel_size, + padding_q, + padding_kv, + stride_q, + stride_kv, + qkv_projection_method, + qkv_bias, + attention_drop_rate, + drop_rate, + mlp_ratio, + drop_path_rate, + with_cls_token=True, + ): + super().__init__() + self.attention = CvtAttention( + num_heads, + embed_dim, + kernel_size, + padding_q, + padding_kv, + stride_q, + stride_kv, + qkv_projection_method, + qkv_bias, + attention_drop_rate, + drop_rate, + with_cls_token, + ) + + self.intermediate = CvtIntermediate(embed_dim, mlp_ratio) + self.output = CvtOutput(embed_dim, mlp_ratio, drop_rate) + self.drop_path = CvtDropPath(drop_prob=drop_path_rate) if drop_path_rate > 0.0 else nn.Identity() + self.layernorm_before = nn.LayerNorm(embed_dim) + self.layernorm_after = nn.LayerNorm(embed_dim) + + def forward(self, hidden_state, height, width): + self_attention_output = self.attention( + self.layernorm_before(hidden_state), # in Cvt, layernorm is applied before self-attention + height, + width, + ) + attention_output = self_attention_output + attention_output = self.drop_path(attention_output) + + # first residual connection + hidden_state = attention_output + hidden_state + + # in Cvt, layernorm is also applied after self-attention + layer_output = self.layernorm_after(hidden_state) + layer_output = self.intermediate(layer_output) + + # second residual connection is done here + layer_output = self.output(layer_output, hidden_state) + layer_output = self.drop_path(layer_output) + return layer_output + + +class CvtStage(nn.Module): + def __init__(self, config, stage): + super().__init__() + self.config = config + self.stage = stage + if self.config.cls_token[self.stage]: + self.cls_token = nn.Parameter(torch.zeros(1, 1, self.config.embed_dim[-1])) + + self.embedding = CvtEmbeddings( + patch_size=config.patch_sizes[self.stage], + stride=config.patch_stride[self.stage], + num_channels=config.num_channels if self.stage == 0 else config.embed_dim[self.stage - 1], + embed_dim=config.embed_dim[self.stage], + padding=config.patch_padding[self.stage], + dropout_rate=config.drop_rate[self.stage], + ) + + drop_path_rates = [x.item() for x in torch.linspace(0, config.drop_path_rate[self.stage], config.depth[stage])] + + self.layers = nn.Sequential( + *[ + CvtLayer( + num_heads=config.num_heads[self.stage], + embed_dim=config.embed_dim[self.stage], + kernel_size=config.kernel_qkv[self.stage], + padding_q=config.padding_q[self.stage], + padding_kv=config.padding_kv[self.stage], + stride_kv=config.stride_kv[self.stage], + stride_q=config.stride_q[self.stage], + qkv_projection_method=config.qkv_projection_method[self.stage], + qkv_bias=config.qkv_bias[self.stage], + attention_drop_rate=config.attention_drop_rate[self.stage], + drop_rate=config.drop_rate[self.stage], + drop_path_rate=drop_path_rates[self.stage], + mlp_ratio=config.mlp_ratio[self.stage], + with_cls_token=config.cls_token[self.stage], + ) + for _ in range(config.depth[self.stage]) + ] + ) + + def forward(self, hidden_state): + cls_token = None + hidden_state = self.embedding(hidden_state) + batch_size, num_channels, height, width = hidden_state.shape + # rearrange b c h w -> b (h w) c" + hidden_state = hidden_state.view(batch_size, num_channels, height * width).permute(0, 2, 1) + if self.config.cls_token[self.stage]: + cls_token = self.cls_token.expand(batch_size, -1, -1) + hidden_state = torch.cat((cls_token, hidden_state), dim=1) + + for layer in self.layers: + layer_outputs = layer(hidden_state, height, width) + hidden_state 
+
+        if self.config.cls_token[self.stage]:
+            cls_token, hidden_state = torch.split(hidden_state, [1, height * width], 1)
+        hidden_state = hidden_state.permute(0, 2, 1).view(batch_size, num_channels, height, width)
+        return hidden_state, cls_token
+
+
+class CvtEncoder(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.config = config
+        self.stages = nn.ModuleList([])
+        for stage_idx in range(len(config.depth)):
+            self.stages.append(CvtStage(config, stage_idx))
+
+    def forward(self, pixel_values, output_hidden_states=False, return_dict=True):
+        all_hidden_states = () if output_hidden_states else None
+        hidden_state = pixel_values
+
+        cls_token = None
+        for stage_module in self.stages:
+            hidden_state, cls_token = stage_module(hidden_state)
+            if output_hidden_states:
+                all_hidden_states = all_hidden_states + (hidden_state,)
+
+        if not return_dict:
+            return tuple(v for v in [hidden_state, cls_token, all_hidden_states] if v is not None)
+
+        return BaseModelOutputWithCLSToken(
+            last_hidden_state=hidden_state,
+            cls_token_value=cls_token,
+            hidden_states=all_hidden_states,
+        )
+
+
+class CvtPreTrainedModel(PreTrainedModel):
+    """
+    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+    models.
+    """
+
+    config_class = CvtConfig
+    base_model_prefix = "cvt"
+    main_input_name = "pixel_values"
+
+    def _init_weights(self, module):
+        """Initialize the weights"""
+        if isinstance(module, (nn.Linear, nn.Conv2d)):
+            # Slightly different from the TF version which uses truncated_normal for initialization
+            # cf https://github.com/pytorch/pytorch/pull/5617
+            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.LayerNorm):
+            module.bias.data.zero_()
+            module.weight.data.fill_(1.0)
+
+
+CVT_START_DOCSTRING = r"""
+    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
+    as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
+    behavior.
+
+    Parameters:
+        config ([`CvtConfig`]): Model configuration class with all the parameters of the model.
+            Initializing with a config file does not load the weights associated with the model, only the
+            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+CVT_INPUTS_DOCSTRING = r"""
+    Args:
+        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+            Pixel values. Pixel values can be obtained using [`CvtFeatureExtractor`]. See
+            [`CvtFeatureExtractor.__call__`] for details.
+        output_hidden_states (`bool`, *optional*):
+            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+            more detail.
+ return_dict (`bool`, *optional*): + Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare Cvt Model transformer outputting raw hidden-states without any specific head on top.", + CVT_START_DOCSTRING, +) +class CvtModel(CvtPreTrainedModel): + def __init__(self, config, add_pooling_layer=True): + super().__init__(config) + self.config = config + self.encoder = CvtEncoder(config) + self.post_init() + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + @add_start_docstrings_to_model_forward(CVT_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + processor_class=_FEAT_EXTRACTOR_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=BaseModelOutputWithCLSToken, + config_class=_CONFIG_FOR_DOC, + modality="vision", + expected_output=_EXPECTED_OUTPUT_SHAPE, + ) + def forward(self, pixel_values=None, output_hidden_states=None, return_dict=None): + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + encoder_outputs = self.encoder( + pixel_values, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = encoder_outputs[0] + + if not return_dict: + return (sequence_output,) + encoder_outputs[1:] + + return BaseModelOutputWithCLSToken( + last_hidden_state=sequence_output, + cls_token_value=encoder_outputs.cls_token_value, + hidden_states=encoder_outputs.hidden_states, + ) + + +@add_start_docstrings( + """ + Cvt Model transformer with an image classification head on top (a linear layer on top of the final hidden state of + the [CLS] token) e.g. for ImageNet. + """, + CVT_START_DOCSTRING, +) +class CvtForImageClassification(CvtPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.num_labels = config.num_labels + self.cvt = CvtModel(config, add_pooling_layer=False) + self.layernorm = nn.LayerNorm(config.embed_dim[-1]) + # Classifier head + self.classifier = ( + nn.Linear(config.embed_dim[-1], config.num_labels) if config.num_labels > 0 else nn.Identity() + ) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(CVT_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + processor_class=_FEAT_EXTRACTOR_FOR_DOC, + checkpoint=_IMAGE_CLASS_CHECKPOINT, + output_type=ImageClassifierOutput, + config_class=_CONFIG_FOR_DOC, + expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT, + ) + def forward( + self, + pixel_values=None, + labels=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the image classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
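+
+        Example (a minimal usage sketch; "microsoft/cvt-13" is assumed to be an available CvT checkpoint on the Hub,
+        fine-tuned on ImageNet-1k):
+
+        ```python
+        >>> from PIL import Image
+        >>> import requests
+
+        >>> from transformers import AutoFeatureExtractor, CvtForImageClassification
+
+        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+        >>> image = Image.open(requests.get(url, stream=True).raw)
+
+        >>> feature_extractor = AutoFeatureExtractor.from_pretrained("microsoft/cvt-13")
+        >>> model = CvtForImageClassification.from_pretrained("microsoft/cvt-13")
+
+        >>> inputs = feature_extractor(images=image, return_tensors="pt")
+        >>> logits = model(**inputs).logits
+        >>> # pick the class with the highest logit and map it to its label
+        >>> predicted_class_idx = logits.argmax(-1).item()
+        >>> print(model.config.id2label[predicted_class_idx])
+        ```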
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + outputs = self.cvt( + pixel_values, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + cls_token = outputs[1] + if self.config.cls_token[-1]: + sequence_output = self.layernorm(cls_token) + else: + batch_size, num_channels, height, width = sequence_output.shape + # rearrange "b c h w -> b (h w) c" + sequence_output = sequence_output.view(batch_size, num_channels, height * width).permute(0, 2, 1) + sequence_output = self.layernorm(sequence_output) + + sequence_output_mean = sequence_output.mean(dim=1) + logits = self.classifier(sequence_output_mean) + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.config.num_labels == 1: + self.config.problem_type = "regression" + elif self.config.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.config.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return ImageClassifierOutput(loss=loss, logits=logits, hidden_states=outputs.hidden_states) diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index fc8f7448d4fff..8f7e291beac6a 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -1216,6 +1216,30 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) +CVT_PRETRAINED_MODEL_ARCHIVE_LIST = None + + +class CvtForImageClassification(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class CvtModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class CvtPreTrainedModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + DATA2VEC_AUDIO_PRETRAINED_MODEL_ARCHIVE_LIST = None diff --git a/tests/models/cvt/__init__.py b/tests/models/cvt/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/models/cvt/test_modeling_cvt.py b/tests/models/cvt/test_modeling_cvt.py new file mode 100644 index 0000000000000..3791c75e8c9ec --- /dev/null +++ b/tests/models/cvt/test_modeling_cvt.py @@ -0,0 +1,278 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Testing suite for the PyTorch CvT model. """ + + +import inspect +import unittest +from math import floor + +from transformers import CvtConfig +from transformers.file_utils import cached_property, is_torch_available, is_vision_available +from transformers.testing_utils import require_torch, require_vision, slow, torch_device + +from ...test_configuration_common import ConfigTester +from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor + + +if is_torch_available(): + import torch + + from transformers import CvtForImageClassification, CvtModel + from transformers.models.cvt.modeling_cvt import CVT_PRETRAINED_MODEL_ARCHIVE_LIST + + +if is_vision_available(): + from PIL import Image + + from transformers import AutoFeatureExtractor + + +class CvtConfigTester(ConfigTester): + def create_and_test_config_common_properties(self): + config = self.config_class(**self.inputs_dict) + self.parent.assertTrue(hasattr(config, "embed_dim")) + self.parent.assertTrue(hasattr(config, "num_heads")) + + +class CvtModelTester: + def __init__( + self, + parent, + batch_size=13, + image_size=64, + num_channels=3, + embed_dim=[16, 48, 96], + num_heads=[1, 3, 6], + depth=[1, 2, 10], + patch_sizes=[7, 3, 3], + patch_stride=[4, 2, 2], + patch_padding=[2, 1, 1], + stride_kv=[2, 2, 2], + cls_token=[False, False, True], + attention_drop_rate=[0.0, 0.0, 0.0], + initializer_range=0.02, + layer_norm_eps=1e-12, + is_training=True, + use_labels=True, + num_labels=2, # Check + ): + self.parent = parent + self.batch_size = batch_size + self.image_size = image_size + self.patch_sizes = patch_sizes + self.patch_stride = patch_stride + self.patch_padding = patch_padding + self.is_training = is_training + self.use_labels = use_labels + self.num_labels = num_labels + self.num_channels = num_channels + self.embed_dim = embed_dim + self.num_heads = num_heads + self.stride_kv = stride_kv + self.depth = depth + self.cls_token = cls_token + self.attention_drop_rate = attention_drop_rate + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + + def prepare_config_and_inputs(self): + pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size]) + + labels = None + if self.use_labels: + labels = ids_tensor([self.batch_size], self.num_labels) + + config = self.get_config() + return config, pixel_values, labels + + def get_config(self): + return CvtConfig( + image_size=self.image_size, + num_labels=self.num_labels, + num_channels=self.num_channels, + embed_dim=self.embed_dim, + num_heads=self.num_heads, + patch_sizes=self.patch_sizes, + patch_padding=self.patch_padding, + patch_stride=self.patch_stride, + stride_kv=self.stride_kv, + depth=self.depth, + cls_token=self.cls_token, + attention_drop_rate=self.attention_drop_rate, + initializer_range=self.initializer_range, + ) + + def create_and_check_model(self, config, pixel_values, labels): + model = CvtModel(config=config) + model.to(torch_device) + model.eval() + result = model(pixel_values) + image_size = (self.image_size, self.image_size) + height, width = image_size[0], 
image_size[1] + for i in range(len(self.depth)): + height = floor(((height + 2 * self.patch_padding[i] - self.patch_sizes[i]) / self.patch_stride[i]) + 1) + width = floor(((width + 2 * self.patch_padding[i] - self.patch_sizes[i]) / self.patch_stride[i]) + 1) + self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.embed_dim[-1], height, width)) + + def create_and_check_for_image_classification(self, config, pixel_values, labels): + config.num_labels = self.num_labels + model = CvtForImageClassification(config) + model.to(torch_device) + model.eval() + result = model(pixel_values, labels=labels) + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels)) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + config, pixel_values, labels = config_and_inputs + inputs_dict = {"pixel_values": pixel_values} + return config, inputs_dict + + +@require_torch +class CvtModelTest(ModelTesterMixin, unittest.TestCase): + """ + Here we also overwrite some of the tests of test_modeling_common.py, as Cvt does not use input_ids, inputs_embeds, + attention_mask and seq_length. + """ + + all_model_classes = (CvtModel, CvtForImageClassification) if is_torch_available() else () + + test_pruning = False + test_torchscript = False + test_resize_embeddings = False + test_head_masking = False + has_attentions = False + + def setUp(self): + self.model_tester = CvtModelTester(self) + self.config_tester = ConfigTester(self, config_class=CvtConfig, has_text_modality=False, hidden_size=37) + + def test_config(self): + self.create_and_test_config_common_properties() + self.config_tester.create_and_test_config_to_json_string() + self.config_tester.create_and_test_config_to_json_file() + self.config_tester.create_and_test_config_from_and_save_pretrained() + self.config_tester.create_and_test_config_with_num_labels() + self.config_tester.check_config_can_be_init_without_params() + self.config_tester.check_config_arguments_init() + + def create_and_test_config_common_properties(self): + return + + @unittest.skip(reason="Cvt does not use inputs_embeds") + def test_inputs_embeds(self): + pass + + @unittest.skip(reason="Cvt does not support input and output embeddings") + def test_model_common_attributes(self): + pass + + def test_forward_signature(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + signature = inspect.signature(model.forward) + # signature.parameters is an OrderedDict => so arg_names order is deterministic + arg_names = [*signature.parameters.keys()] + + expected_arg_names = ["pixel_values"] + self.assertListEqual(arg_names[:1], expected_arg_names) + + def test_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model(*config_and_inputs) + + def test_hidden_states_output(self): + def check_hidden_states_output(inputs_dict, config, model_class): + model = model_class(config) + model.to(torch_device) + model.eval() + + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + + hidden_states = outputs.hidden_states + + expected_num_layers = len(self.model_tester.depth) + self.assertEqual(len(hidden_states), expected_num_layers) + + # verify the first hidden states (first block) + self.assertListEqual( + list(hidden_states[0].shape[-3:]), + [ + self.model_tester.embed_dim[0], + self.model_tester.image_size // 
4, + self.model_tester.image_size // 4, + ], + ) + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + inputs_dict["output_hidden_states"] = True + check_hidden_states_output(inputs_dict, config, model_class) + + # check that output_hidden_states also work using config + del inputs_dict["output_hidden_states"] + config.output_hidden_states = True + + check_hidden_states_output(inputs_dict, config, model_class) + + def test_for_image_classification(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_for_image_classification(*config_and_inputs) + + @slow + def test_model_from_pretrained(self): + for model_name in CVT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: + model = CvtModel.from_pretrained(model_name) + self.assertIsNotNone(model) + + +# We will verify our results on an image of cute cats +def prepare_img(): + image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png") + return image + + +@require_torch +@require_vision +class CvtModelIntegrationTest(unittest.TestCase): + @cached_property + def default_feature_extractor(self): + return AutoFeatureExtractor.from_pretrained(CVT_PRETRAINED_MODEL_ARCHIVE_LIST[0]) + + @slow + def test_inference_image_classification_head(self): + model = CvtForImageClassification.from_pretrained(CVT_PRETRAINED_MODEL_ARCHIVE_LIST[0]).to(torch_device) + + feature_extractor = self.default_feature_extractor + image = prepare_img() + inputs = feature_extractor(images=image, return_tensors="pt").to(torch_device) + + # forward pass + with torch.no_grad(): + outputs = model(**inputs) + + # verify the logits + expected_shape = torch.Size((1, 1000)) + self.assertEqual(outputs.logits.shape, expected_shape) + + expected_slice = torch.tensor([0.9285, 0.9015, -0.3150]).to(torch_device) + + self.assertTrue(torch.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4)) diff --git a/utils/documentation_tests.txt b/utils/documentation_tests.txt index 3393dd49df1d2..45a9eae973486 100644 --- a/utils/documentation_tests.txt +++ b/utils/documentation_tests.txt @@ -21,6 +21,7 @@ src/transformers/models/blenderbot/modeling_blenderbot.py src/transformers/models/blenderbot_small/modeling_blenderbot_small.py src/transformers/models/convnext/modeling_convnext.py src/transformers/models/ctrl/modeling_ctrl.py +src/transformers/models/cvt/modeling_cvt.py src/transformers/models/data2vec/modeling_data2vec_audio.py src/transformers/models/data2vec/modeling_data2vec_vision.py src/transformers/models/deit/modeling_deit.py