diff --git a/README.md b/README.md index 18d94bbc3428c..0867d15efdf89 100644 --- a/README.md +++ b/README.md @@ -303,6 +303,7 @@ Min, Patrick Lewis, Ledell Wu, Sergey Edunov, Danqi Chen, and Wen-tau Yih. 1. **[REALM](https://huggingface.co/docs/transformers/model_doc/realm.html)** (from Google Research) released with the paper [REALM: Retrieval-Augmented Language Model Pre-Training](https://arxiv.org/abs/2002.08909) by Kelvin Guu, Kenton Lee, Zora Tung, Panupong Pasupat and Ming-Wei Chang. 1. **[Reformer](https://huggingface.co/docs/transformers/model_doc/reformer)** (from Google Research) released with the paper [Reformer: The Efficient Transformer](https://arxiv.org/abs/2001.04451) by Nikita Kitaev, Łukasz Kaiser, Anselm Levskaya. 1. **[RemBERT](https://huggingface.co/docs/transformers/model_doc/rembert)** (from Google Research) released with the paper [Rethinking embedding coupling in pre-trained language models](https://arxiv.org/abs/2010.12821) by Hyung Won Chung, Thibault Févry, Henry Tsai, M. Johnson, Sebastian Ruder. +1. **[RegNet](https://huggingface.co/docs/transformers/main/model_doc/regnet)** (from META Platforms) released with the paper [Designing Network Design Space](https://arxiv.org/abs/2003.13678) by Ilija Radosavovic, Raj Prateek Kosaraju, Ross Girshick, Kaiming He, Piotr Dollár. 1. **[ResNet](https://huggingface.co/docs/transformers/main/model_doc/resnet)** (from Microsoft Research) released with the paper [Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385) by Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun. 1. **[RoBERTa](https://huggingface.co/docs/transformers/model_doc/roberta)** (from Facebook), released together with the paper [RoBERTa: A Robustly Optimized BERT Pretraining Approach](https://arxiv.org/abs/1907.11692) by Yinhan Liu, Myle Ott, Naman Goyal, Jingfei Du, Mandar Joshi, Danqi Chen, Omer Levy, Mike Lewis, Luke Zettlemoyer, Veselin Stoyanov. 1. **[RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer)** (from ZhuiyiTechnology), released together with the paper [RoFormer: Enhanced Transformer with Rotary Position Embedding](https://arxiv.org/abs/2104.09864) by Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu. diff --git a/README_ko.md b/README_ko.md index 5d813b0cc76db..3f442bcd10a16 100644 --- a/README_ko.md +++ b/README_ko.md @@ -281,6 +281,7 @@ Flax, PyTorch, TensorFlow 설치 페이지에서 이들을 conda로 설치하는 1. **[QDQBert](https://huggingface.co/docs/transformers/model_doc/qdqbert)** (from NVIDIA) released with the paper [Integer Quantization for Deep Learning Inference: Principles and Empirical Evaluation](https://arxiv.org/abs/2004.09602) by Hao Wu, Patrick Judd, Xiaojie Zhang, Mikhail Isaev and Paulius Micikevicius. 1. **[REALM](https://huggingface.co/docs/transformers/model_doc/realm.html)** (from Google Research) released with the paper [REALM: Retrieval-Augmented Language Model Pre-Training](https://arxiv.org/abs/2002.08909) by Kelvin Guu, Kenton Lee, Zora Tung, Panupong Pasupat and Ming-Wei Chang. 1. **[Reformer](https://huggingface.co/docs/transformers/model_doc/reformer)** (from Google Research) released with the paper [Reformer: The Efficient Transformer](https://arxiv.org/abs/2001.04451) by Nikita Kitaev, Łukasz Kaiser, Anselm Levskaya. +1. **[RegNet](https://huggingface.co/docs/transformers/main/model_doc/regnet)** (from META Research) released with the paper [Designing Network Design Space](https://arxiv.org/abs/2003.13678) by Ilija Radosavovic, Raj Prateek Kosaraju, Ross Girshick, Kaiming He, Piotr Dollár. 1. **[RemBERT](https://huggingface.co/docs/transformers/model_doc/rembert)** (from Google Research) released with the paper [Rethinking embedding coupling in pre-trained language models](https://arxiv.org/pdf/2010.12821.pdf) by Hyung Won Chung, Thibault Févry, Henry Tsai, M. Johnson, Sebastian Ruder. 1. **[ResNet](https://huggingface.co/docs/transformers/main/model_doc/resnet)** (from Microsoft Research) released with the paper [Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385) by Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun. 1. **[RoBERTa](https://huggingface.co/docs/transformers/model_doc/roberta)** (from Facebook), released together with the paper a [Robustly Optimized BERT Pretraining Approach](https://arxiv.org/abs/1907.11692) by Yinhan Liu, Myle Ott, Naman Goyal, Jingfei Du, Mandar Joshi, Danqi Chen, Omer Levy, Mike Lewis, Luke Zettlemoyer, Veselin Stoyanov. diff --git a/README_zh-hans.md b/README_zh-hans.md index 5570335d49f9b..b7733e3b3c175 100644 --- a/README_zh-hans.md +++ b/README_zh-hans.md @@ -305,6 +305,7 @@ conda install -c huggingface transformers 1. **[QDQBert](https://huggingface.co/docs/transformers/model_doc/qdqbert)** (来自 NVIDIA) 伴随论文 [Integer Quantization for Deep Learning Inference: Principles and Empirical Evaluation](https://arxiv.org/abs/2004.09602) 由 Hao Wu, Patrick Judd, Xiaojie Zhang, Mikhail Isaev and Paulius Micikevicius 发布。 1. **[REALM](https://huggingface.co/docs/transformers/model_doc/realm.html)** (来自 Google Research) 伴随论文 [REALM: Retrieval-Augmented Language Model Pre-Training](https://arxiv.org/abs/2002.08909) 由 Kelvin Guu, Kenton Lee, Zora Tung, Panupong Pasupat and Ming-Wei Chang 发布。 1. **[Reformer](https://huggingface.co/docs/transformers/model_doc/reformer)** (来自 Google Research) 伴随论文 [Reformer: The Efficient Transformer](https://arxiv.org/abs/2001.04451) 由 Nikita Kitaev, Łukasz Kaiser, Anselm Levskaya 发布。 +1. **[RegNet](https://huggingface.co/docs/transformers/main/model_doc/regnet)** (from META Research) released with the paper [Designing Network Design Space](https://arxiv.org/abs/2003.13678) by Ilija Radosavovic, Raj Prateek Kosaraju, Ross Girshick, Kaiming He, Piotr Dollár. 1. **[RemBERT](https://huggingface.co/docs/transformers/model_doc/rembert)** (来自 Google Research) 伴随论文 [Rethinking embedding coupling in pre-trained language models](https://arxiv.org/pdf/2010.12821.pdf) 由 Hyung Won Chung, Thibault Févry, Henry Tsai, M. Johnson, Sebastian Ruder 发布。 1. **[ResNet](https://huggingface.co/docs/transformers/main/model_doc/resnet)** (from Microsoft Research) released with the paper [Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385) by Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun. 1. **[RoBERTa](https://huggingface.co/docs/transformers/model_doc/roberta)** (来自 Facebook), 伴随论文 [Robustly Optimized BERT Pretraining Approach](https://arxiv.org/abs/1907.11692) 由 Yinhan Liu, Myle Ott, Naman Goyal, Jingfei Du, Mandar Joshi, Danqi Chen, Omer Levy, Mike Lewis, Luke Zettlemoyer, Veselin Stoyanov 发布。 diff --git a/README_zh-hant.md b/README_zh-hant.md index 88ba012d5a89c..b15096380d128 100644 --- a/README_zh-hant.md +++ b/README_zh-hant.md @@ -317,6 +317,7 @@ conda install -c huggingface transformers 1. **[QDQBert](https://huggingface.co/docs/transformers/model_doc/qdqbert)** (from NVIDIA) released with the paper [Integer Quantization for Deep Learning Inference: Principles and Empirical Evaluation](https://arxiv.org/abs/2004.09602) by Hao Wu, Patrick Judd, Xiaojie Zhang, Mikhail Isaev and Paulius Micikevicius. 1. **[REALM](https://huggingface.co/docs/transformers/model_doc/realm.html)** (from Google Research) released with the paper [REALM: Retrieval-Augmented Language Model Pre-Training](https://arxiv.org/abs/2002.08909) by Kelvin Guu, Kenton Lee, Zora Tung, Panupong Pasupat and Ming-Wei Chang. 1. **[Reformer](https://huggingface.co/docs/transformers/model_doc/reformer)** (from Google Research) released with the paper [Reformer: The Efficient Transformer](https://arxiv.org/abs/2001.04451) by Nikita Kitaev, Łukasz Kaiser, Anselm Levskaya. +1. **[RegNet](https://huggingface.co/docs/transformers/main/model_doc/regnet)** (from META Research) released with the paper [Designing Network Design Space](https://arxiv.org/abs/2003.13678) by Ilija Radosavovic, Raj Prateek Kosaraju, Ross Girshick, Kaiming He, Piotr Dollár. 1. **[RemBERT](https://huggingface.co/docs/transformers/model_doc/rembert)** (from Google Research) released with the paper [Rethinking embedding coupling in pre-trained language models](https://arxiv.org/pdf/2010.12821.pdf) by Hyung Won Chung, Thibault Févry, Henry Tsai, M. Johnson, Sebastian Ruder. 1. **[ResNet](https://huggingface.co/docs/transformers/main/model_doc/resnet)** (from Microsoft Research) released with the paper [Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385) by Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun. 1. **[RoBERTa](https://huggingface.co/docs/transformers/model_doc/roberta)** (from Facebook), released together with the paper a [Robustly Optimized BERT Pretraining Approach](https://arxiv.org/abs/1907.11692) by Yinhan Liu, Myle Ott, Naman Goyal, Jingfei Du, Mandar Joshi, Danqi Chen, Omer Levy, Mike Lewis, Luke Zettlemoyer, Veselin Stoyanov. diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 69717477e1f23..8f51c9d4e31a1 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -296,6 +296,8 @@ title: Reformer - local: model_doc/rembert title: RemBERT + - local: model_doc/regnet + title: RegNet - local: model_doc/resnet title: ResNet - local: model_doc/retribert diff --git a/docs/source/en/index.mdx b/docs/source/en/index.mdx index 281add6e5ef45..86353fa94191c 100644 --- a/docs/source/en/index.mdx +++ b/docs/source/en/index.mdx @@ -124,6 +124,7 @@ The library currently contains JAX, PyTorch and TensorFlow implementations, pret 1. **[REALM](model_doc/realm.html)** (from Google Research) released with the paper [REALM: Retrieval-Augmented Language Model Pre-Training](https://arxiv.org/abs/2002.08909) by Kelvin Guu, Kenton Lee, Zora Tung, Panupong Pasupat and Ming-Wei Chang. 1. **[Reformer](model_doc/reformer)** (from Google Research) released with the paper [Reformer: The Efficient Transformer](https://arxiv.org/abs/2001.04451) by Nikita Kitaev, Łukasz Kaiser, Anselm Levskaya. 1. **[RemBERT](model_doc/rembert)** (from Google Research) released with the paper [Rethinking embedding coupling in pre-trained language models](https://arxiv.org/abs/2010.12821) by Hyung Won Chung, Thibault Févry, Henry Tsai, M. Johnson, Sebastian Ruder. +1. **[RegNet](model_doc/regnet)** (from META Platforms) released with the paper [Designing Network Design Space](https://arxiv.org/abs/2003.13678) by Ilija Radosavovic, Raj Prateek Kosaraju, Ross Girshick, Kaiming He, Piotr Dollár. 1. **[ResNet](model_doc/resnet)** (from Microsoft Research) released with the paper [Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385) by Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun. 1. **[RoBERTa](model_doc/roberta)** (from Facebook), released together with the paper [RoBERTa: A Robustly Optimized BERT Pretraining Approach](https://arxiv.org/abs/1907.11692) by Yinhan Liu, Myle Ott, Naman Goyal, Jingfei Du, Mandar Joshi, Danqi Chen, Omer Levy, Mike Lewis, Luke Zettlemoyer, Veselin Stoyanov. 1. **[RoFormer](model_doc/roformer)** (from ZhuiyiTechnology), released together with the paper [RoFormer: Enhanced Transformer with Rotary Position Embedding](https://arxiv.org/abs/2104.09864) by Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu. @@ -234,6 +235,7 @@ Flax), PyTorch, and/or TensorFlow. | RAG | ✅ | ❌ | ✅ | ✅ | ❌ | | Realm | ✅ | ✅ | ✅ | ❌ | ❌ | | Reformer | ✅ | ✅ | ✅ | ❌ | ❌ | +| RegNet | ❌ | ❌ | ✅ | ❌ | ❌ | | RemBERT | ✅ | ✅ | ✅ | ✅ | ❌ | | ResNet | ❌ | ❌ | ✅ | ❌ | ❌ | | RetriBERT | ✅ | ✅ | ✅ | ❌ | ❌ | diff --git a/docs/source/en/model_doc/regnet.mdx b/docs/source/en/model_doc/regnet.mdx new file mode 100644 index 0000000000000..666a9ee39675d --- /dev/null +++ b/docs/source/en/model_doc/regnet.mdx @@ -0,0 +1,48 @@ + + +# RegNet + +## Overview + +The RegNet model was proposed in [Designing Network Design Spaces](https://arxiv.org/abs/2003.13678) by Ilija Radosavovic, Raj Prateek Kosaraju, Ross Girshick, Kaiming He, Piotr Dollár. + +The authors design search spaces to perform Neural Architecture Search (NAS). They first start from a high dimensional search space and iteratively reduce the search space by empirically applying constraints based on the best-performing models sampled by the current search space. + +The abstract from the paper is the following: + +*In this work, we present a new network design paradigm. Our goal is to help advance the understanding of network design and discover design principles that generalize across settings. Instead of focusing on designing individual network instances, we design network design spaces that parametrize populations of networks. The overall process is analogous to classic manual design of networks, but elevated to the design space level. Using our methodology we explore the structure aspect of network design and arrive at a low-dimensional design space consisting of simple, regular networks that we call RegNet. The core insight of the RegNet parametrization is surprisingly simple: widths and depths of good networks can be explained by a quantized linear function. We analyze the RegNet design space and arrive at interesting findings that do not match the current practice of network design. The RegNet design space provides simple and fast networks that work well across a wide range of flop regimes. Under comparable training settings and flops, the RegNet models outperform the popular EfficientNet models while being up to 5x faster on GPUs.* + +Tips: + +- One can use [`AutoFeatureExtractor`] to prepare images for the model. +- The huge 10B model from [Self-supervised Pretraining of Visual Features in the Wild](https://arxiv.org/abs/2103.01988), trained on one billion Instagram images, is available on the [hub](https://huggingface.co/facebook/regnet-y-10b-seer) + +This model was contributed by [Francesco](https://huggingface.co/Francesco). +The original code can be found [here](https://github.com/facebookresearch/pycls). + + +## RegNetConfig + +[[autodoc]] RegNetConfig + + +## RegNetModel + +[[autodoc]] RegNetModel + - forward + + +## RegNetForImageClassification + +[[autodoc]] RegNetForImageClassification + - forward \ No newline at end of file diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 650009149116b..f9ab3a153733c 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -242,6 +242,7 @@ "models.rag": ["RagConfig", "RagRetriever", "RagTokenizer"], "models.realm": ["REALM_PRETRAINED_CONFIG_ARCHIVE_MAP", "RealmConfig", "RealmTokenizer"], "models.reformer": ["REFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "ReformerConfig"], + "models.regnet": ["REGNET_PRETRAINED_CONFIG_ARCHIVE_MAP", "RegNetConfig"], "models.rembert": ["REMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "RemBertConfig"], "models.resnet": ["RESNET_PRETRAINED_CONFIG_ARCHIVE_MAP", "ResNetConfig"], "models.retribert": ["RETRIBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "RetriBertConfig", "RetriBertTokenizer"], @@ -1342,6 +1343,14 @@ "ReformerPreTrainedModel", ] ) + _import_structure["models.regnet"].extend( + [ + "REGNET_PRETRAINED_MODEL_ARCHIVE_LIST", + "RegNetForImageClassification", + "RegNetModel", + "RegNetPreTrainedModel", + ] + ) _import_structure["models.rembert"].extend( [ "REMBERT_PRETRAINED_MODEL_ARCHIVE_LIST", @@ -2612,6 +2621,7 @@ from .models.rag import RagConfig, RagRetriever, RagTokenizer from .models.realm import REALM_PRETRAINED_CONFIG_ARCHIVE_MAP, RealmConfig, RealmTokenizer from .models.reformer import REFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, ReformerConfig + from .models.regnet import REGNET_PRETRAINED_CONFIG_ARCHIVE_MAP, RegNetConfig from .models.rembert import REMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, RemBertConfig from .models.resnet import RESNET_PRETRAINED_CONFIG_ARCHIVE_MAP, ResNetConfig from .models.retribert import RETRIBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, RetriBertConfig, RetriBertTokenizer @@ -3537,6 +3547,12 @@ ReformerModelWithLMHead, ReformerPreTrainedModel, ) + from .models.regnet import ( + REGNET_PRETRAINED_MODEL_ARCHIVE_LIST, + RegNetForImageClassification, + RegNetModel, + RegNetPreTrainedModel, + ) from .models.rembert import ( REMBERT_PRETRAINED_MODEL_ARCHIVE_LIST, RemBertForCausalLM, diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index 7045f18c55615..d004742de4008 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -96,6 +96,7 @@ rag, realm, reformer, + regnet, rembert, resnet, retribert, diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 84e23f106e5ba..324b977685b19 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -38,6 +38,7 @@ ("convnext", "ConvNextConfig"), ("van", "VanConfig"), ("resnet", "ResNetConfig"), + ("regnet", "RegNetConfig"), ("yoso", "YosoConfig"), ("swin", "SwinConfig"), ("vilt", "ViltConfig"), @@ -142,6 +143,7 @@ ("convnext", "CONVNEXT_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("van", "VAN_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("resnet", "RESNET_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("regnet", "REGNET_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("yoso", "YOSO_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("swin", "SWIN_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("vilt", "VILT_PRETRAINED_CONFIG_ARCHIVE_MAP"), @@ -234,6 +236,7 @@ ("convnext", "ConvNext"), ("van", "VAN"), ("resnet", "ResNet"), + ("regnet", "RegNet"), ("yoso", "YOSO"), ("swin", "Swin"), ("vilt", "ViLT"), diff --git a/src/transformers/models/auto/feature_extraction_auto.py b/src/transformers/models/auto/feature_extraction_auto.py index dc83fc133fad9..dad7e165e8d72 100644 --- a/src/transformers/models/auto/feature_extraction_auto.py +++ b/src/transformers/models/auto/feature_extraction_auto.py @@ -54,6 +54,7 @@ ("convnext", "ConvNextFeatureExtractor"), ("van", "ConvNextFeatureExtractor"), ("resnet", "ConvNextFeatureExtractor"), + ("regnet", "ConvNextFeatureExtractor"), ("poolformer", "PoolFormerFeatureExtractor"), ("maskformer", "MaskFormerFeatureExtractor"), ] diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index b0cfb47672491..0814d10cfdda4 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -38,6 +38,7 @@ ("convnext", "ConvNextModel"), ("van", "VanModel"), ("resnet", "ResNetModel"), + ("regnet", "RegNetModel"), ("yoso", "YosoModel"), ("swin", "SwinModel"), ("vilt", "ViltModel"), @@ -303,6 +304,7 @@ ("convnext", "ConvNextForImageClassification"), ("van", "VanForImageClassification"), ("resnet", "ResNetForImageClassification"), + ("regnet", "RegNetForImageClassification"), ("poolformer", "PoolFormerForImageClassification"), ] ) diff --git a/src/transformers/models/regnet/__init__.py b/src/transformers/models/regnet/__init__.py new file mode 100644 index 0000000000000..185ead37b640e --- /dev/null +++ b/src/transformers/models/regnet/__init__.py @@ -0,0 +1,52 @@ +# flake8: noqa +# There's no way to ignore "F401 '...' imported but unused" warnings in this +# module, but to preserve other warnings. So, don't check this module at all. + +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +# rely on isort to merge the imports +from ...file_utils import _LazyModule, is_torch_available + + +_import_structure = { + "configuration_regnet": ["REGNET_PRETRAINED_CONFIG_ARCHIVE_MAP", "RegNetConfig"], +} + +if is_torch_available(): + _import_structure["modeling_regnet"] = [ + "REGNET_PRETRAINED_MODEL_ARCHIVE_LIST", + "RegNetForImageClassification", + "RegNetModel", + "RegNetPreTrainedModel", + ] + + +if TYPE_CHECKING: + from .configuration_regnet import REGNET_PRETRAINED_CONFIG_ARCHIVE_MAP, RegNetConfig + + if is_torch_available(): + from .modeling_regnet import ( + REGNET_PRETRAINED_MODEL_ARCHIVE_LIST, + RegNetForImageClassification, + RegNetModel, + RegNetPreTrainedModel, + ) + + +else: + import sys + + sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure) diff --git a/src/transformers/models/regnet/configuration_regnet.py b/src/transformers/models/regnet/configuration_regnet.py new file mode 100644 index 0000000000000..9bfe35ec9b3d4 --- /dev/null +++ b/src/transformers/models/regnet/configuration_regnet.py @@ -0,0 +1,94 @@ +# coding=utf-8 +# Copyright 2022 Meta Platforms, Inc. and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" RegNet model configuration""" + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + +REGNET_PRETRAINED_CONFIG_ARCHIVE_MAP = { + "regnety-40": "https://huggingface.co/zuppif/regnety-040/blob/main/config.json", +} + + +class RegNetConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`RegNetModel`]. It is used to instantiate a RegNet + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the + [facebook/regnet-y-40](https://huggingface.co/facebook/regnet-y-40) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + num_channels (`int`, *optional*, defaults to 3): + The number of input channels. + embedding_size (`int`, *optional*, defaults to 64): + Dimensionality (hidden size) for the embedding layer. + hidden_sizes (`List[int]`, *optional*, defaults to `[256, 512, 1024, 2048]`): + Dimensionality (hidden size) at each stage. + depths (`List[int]`, *optional*, defaults to `[3, 4, 6, 3]`): + Depth (number of layers) for each stage. + layer_type (`str`, *optional*, defaults to `"y"`): + The layer to use, it can be either `"x" or `"y"`. An `x` layer is a ResNet's BottleNeck layer with + `reduction` fixed to `1`. While a `y` layer is a `x` but with squeeze and excitation. Please refer to the + paper for a detailed explanation of how these layers were constructed. + hidden_act (`str`, *optional*, defaults to `"relu"`): + The non-linear activation function in each block. If string, `"gelu"`, `"relu"`, `"selu"` and `"gelu_new"` + are supported. + downsample_in_first_stage (`bool`, *optional*, defaults to `False`): + If `True`, the first stage will downsample the inputs using a `stride` of 2. + + Example: + ```python + >>> from transformers import RegNetConfig, RegNetModel + + >>> # Initializing a RegNet regnet-y-40 style configuration + >>> configuration = RegNetConfig() + >>> # Initializing a model from the regnet-y-40 style configuration + >>> model = RegNetModel(configuration) + >>> # Accessing the model configuration + >>> configuration = model.config + ``` + """ + model_type = "regnet" + layer_types = ["x", "y"] + + def __init__( + self, + num_channels=3, + embedding_size=32, + hidden_sizes=[128, 192, 512, 1088], + depths=[2, 6, 12, 2], + groups_width=64, + layer_type="y", + hidden_act="relu", + **kwargs + ): + super().__init__(**kwargs) + if layer_type not in self.layer_types: + raise ValueError(f"layer_type={layer_type} is not one of {','.join(self.layer_types)}") + self.num_channels = num_channels + self.embedding_size = embedding_size + self.hidden_sizes = hidden_sizes + self.depths = depths + self.groups_width = groups_width + self.layer_type = layer_type + self.hidden_act = hidden_act + # always downsample in the first stage + self.downsample_in_first_stage = True diff --git a/src/transformers/models/regnet/convert_regnet_seer_10b_to_pytorch.py b/src/transformers/models/regnet/convert_regnet_seer_10b_to_pytorch.py new file mode 100644 index 0000000000000..8024ef6792011 --- /dev/null +++ b/src/transformers/models/regnet/convert_regnet_seer_10b_to_pytorch.py @@ -0,0 +1,301 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert RegNet 10B checkpoints vissl.""" +# You need to install a specific version of classy vision +# pip install git+https://github.com/FrancescoSaverioZuppichini/ClassyVision.git@convert_weights + +import argparse +import json +import os +import re +from collections import OrderedDict +from dataclasses import dataclass, field +from functools import partial +from pathlib import Path +from pprint import pprint +from typing import Dict, List, Tuple + +import torch +import torch.nn as nn +from torch import Tensor + +from classy_vision.models.regnet import RegNet, RegNetParams +from huggingface_hub import cached_download, hf_hub_url +from transformers import AutoFeatureExtractor, RegNetConfig, RegNetForImageClassification, RegNetModel +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import logging +from vissl.models.model_helpers import get_trunk_forward_outputs + + +logging.set_verbosity_info() +logger = logging.get_logger() + + +@dataclass +class Tracker: + module: nn.Module + traced: List[nn.Module] = field(default_factory=list) + handles: list = field(default_factory=list) + name2module: Dict[str, nn.Module] = field(default_factory=OrderedDict) + + def _forward_hook(self, m, inputs: Tensor, outputs: Tensor, name: str): + has_not_submodules = len(list(m.modules())) == 1 or isinstance(m, nn.Conv2d) or isinstance(m, nn.BatchNorm2d) + if has_not_submodules: + self.traced.append(m) + self.name2module[name] = m + + def __call__(self, x: Tensor): + for name, m in self.module.named_modules(): + self.handles.append(m.register_forward_hook(partial(self._forward_hook, name=name))) + self.module(x) + list(map(lambda x: x.remove(), self.handles)) + return self + + @property + def parametrized(self): + # check the len of the state_dict keys to see if we have learnable params + return {k: v for k, v in self.name2module.items() if len(list(v.state_dict().keys())) > 0} + + +class FakeRegNetVisslWrapper(nn.Module): + """ + Fake wrapper for RegNet that mimics what vissl does without the need to pass a config file. + """ + + def __init__(self, model: nn.Module): + super().__init__() + + feature_blocks: List[Tuple[str, nn.Module]] = [] + # - get the stem + feature_blocks.append(("conv1", model.stem)) + # - get all the feature blocks + for k, v in model.trunk_output.named_children(): + assert k.startswith("block"), f"Unexpected layer name {k}" + block_index = len(feature_blocks) + 1 + feature_blocks.append((f"res{block_index}", v)) + + self._feature_blocks = nn.ModuleDict(feature_blocks) + + def forward(self, x: Tensor): + return get_trunk_forward_outputs( + x, + out_feat_keys=None, + feature_blocks=self._feature_blocks, + ) + + +class FakeRegNetParams(RegNetParams): + """ + Used to instantiace a RegNet model from classy vision with the same depth as the 10B one but with super small + parameters, so we can trace it in memory. + """ + + def get_expanded_params(self): + return [(8, 2, 2, 8, 1.0), (8, 2, 7, 8, 1.0), (8, 2, 17, 8, 1.0), (8, 2, 1, 8, 1.0)] + + +def get_from_to_our_keys(model_name: str) -> Dict[str, str]: + """ + Returns a dictionary that maps from original model's key -> our implementation's keys + """ + + # create our model (with small weights) + our_config = RegNetConfig(depths=[2, 7, 17, 1], hidden_sizes=[8, 8, 8, 8], groups_width=8) + if "in1k" in model_name: + our_model = RegNetForImageClassification(our_config) + else: + our_model = RegNetModel(our_config) + # create from model (with small weights) + from_model = FakeRegNetVisslWrapper( + RegNet(FakeRegNetParams(depth=27, group_width=1010, w_0=1744, w_a=620.83, w_m=2.52)) + ) + + with torch.no_grad(): + from_model = from_model.eval() + our_model = our_model.eval() + + x = torch.randn((1, 3, 32, 32)) + # trace both + dest_tracker = Tracker(our_model) + dest_traced = dest_tracker(x).parametrized + + pprint(dest_tracker.name2module) + src_tracker = Tracker(from_model) + src_traced = src_tracker(x).parametrized + + # convert the keys -> module dict to keys -> params + def to_params_dict(dict_with_modules): + params_dict = OrderedDict() + for name, module in dict_with_modules.items(): + for param_name, param in module.state_dict().items(): + params_dict[f"{name}.{param_name}"] = param + return params_dict + + from_to_ours_keys = {} + + src_state_dict = to_params_dict(src_traced) + dst_state_dict = to_params_dict(dest_traced) + + for (src_key, src_param), (dest_key, dest_param) in zip(src_state_dict.items(), dst_state_dict.items()): + from_to_ours_keys[src_key] = dest_key + logger.info(f"{src_key} -> {dest_key}") + # if "in1k" was in the model_name it means it must have a classification head (was finetuned) + if "in1k" in model_name: + from_to_ours_keys["0.clf.0.weight"] = "classifier.1.weight" + from_to_ours_keys["0.clf.0.bias"] = "classifier.1.bias" + + return from_to_ours_keys + + +def convert_weights_and_push(save_directory: Path, model_name: str = None, push_to_hub: bool = True): + filename = "imagenet-1k-id2label.json" + num_labels = 1000 + + repo_id = "datasets/huggingface/label-files" + num_labels = num_labels + id2label = json.load(open(cached_download(hf_hub_url(repo_id, filename)), "r")) + id2label = {int(k): v for k, v in id2label.items()} + + id2label = id2label + label2id = {v: k for k, v in id2label.items()} + + ImageNetPreTrainedConfig = partial(RegNetConfig, num_labels=num_labels, id2label=id2label, label2id=label2id) + + names_to_config = { + "regnet-y-10b-seer": ImageNetPreTrainedConfig( + depths=[2, 7, 17, 1], hidden_sizes=[2020, 4040, 11110, 28280], groups_width=1010 + ), + # finetuned on imagenet + "regnet-y-10b-seer-in1k": ImageNetPreTrainedConfig( + depths=[2, 7, 17, 1], hidden_sizes=[2020, 4040, 11110, 28280], groups_width=1010 + ), + } + + # add seer weights logic + def load_using_classy_vision(checkpoint_url: str) -> Tuple[Dict, Dict]: + files = torch.hub.load_state_dict_from_url(checkpoint_url, model_dir=str(save_directory), map_location="cpu") + # check if we have a head, if yes add it + model_state_dict = files["classy_state_dict"]["base_model"]["model"] + return model_state_dict["trunk"], model_state_dict["heads"] + + names_to_from_model = { + "regnet-y-10b-seer": partial( + load_using_classy_vision, + "https://dl.fbaipublicfiles.com/vissl/model_zoo/seer_regnet10B/model_iteration124500_conso.torch", + ), + "regnet-y-10b-seer-in1k": partial( + load_using_classy_vision, + "https://dl.fbaipublicfiles.com/vissl/model_zoo/seer_finetuned/seer_10b_finetuned_in1k_model_phase28_conso.torch", + ), + } + + from_to_ours_keys = get_from_to_our_keys(model_name) + + if not (save_directory / f"{model_name}.pth").exists(): + logger.info("Loading original state_dict.") + from_state_dict_trunk, from_state_dict_head = names_to_from_model[model_name]() + from_state_dict = from_state_dict_trunk + if "in1k" in model_name: + # add the head + from_state_dict = {**from_state_dict_trunk, **from_state_dict_head} + logger.info("Done!") + + converted_state_dict = {} + + not_used_keys = list(from_state_dict.keys()) + regex = r"\.block.-part." + # this is "interesting", so the original checkpoints have `block[0,1]-part` in each key name, we remove it + for key in from_state_dict.keys(): + # remove the weird "block[0,1]-part" from the key + src_key = re.sub(regex, "", key) + # now src_key from the model checkpoints is the one we got from the original model after tracing, so use it to get the correct destination key + dest_key = from_to_ours_keys[src_key] + # store the parameter with our key + converted_state_dict[dest_key] = from_state_dict[key] + not_used_keys.remove(key) + # check that all keys have been updated + assert len(not_used_keys) == 0, f"Some keys where not used {','.join(not_used_keys)}" + + logger.info(f"The following keys were not used: {','.join(not_used_keys)}") + + # save our state dict to disk + torch.save(converted_state_dict, save_directory / f"{model_name}.pth") + + del converted_state_dict + else: + logger.info("The state_dict was already stored on disk.") + if push_to_hub: + logger.info(f"Token is {os.environ['HF_TOKEN']}") + logger.info("Loading our model.") + # create our model + our_config = names_to_config[model_name] + our_model_func = RegNetModel + if "in1k" in model_name: + our_model_func = RegNetForImageClassification + our_model = our_model_func(our_config) + # place our model to the meta device (so remove all the weights) + our_model.to(torch.device("meta")) + logger.info("Loading state_dict in our model.") + # load state dict + state_dict_keys = our_model.state_dict().keys() + PreTrainedModel._load_pretrained_model_low_mem( + our_model, state_dict_keys, [save_directory / f"{model_name}.pth"] + ) + logger.info("Finally, pushing!") + # push it to hub + our_model.push_to_hub( + repo_path_or_name=save_directory / model_name, + commit_message="Add model", + output_dir=save_directory / model_name, + ) + size = 384 + # we can use the convnext one + feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/convnext-base-224-22k-1k", size=size) + feature_extractor.push_to_hub( + repo_path_or_name=save_directory / model_name, + commit_message="Add feature extractor", + output_dir=save_directory / model_name, + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument( + "--model_name", + default=None, + type=str, + help="The name of the model you wish to convert, it must be one of the supported regnet* architecture, currently: regnetx-*, regnety-*. If `None`, all of them will the converted.", + ) + parser.add_argument( + "--pytorch_dump_folder_path", + default=None, + type=Path, + required=True, + help="Path to the output PyTorch model directory.", + ) + parser.add_argument( + "--push_to_hub", + default=True, + type=bool, + required=False, + help="If True, push model and feature extractor to the hub.", + ) + + args = parser.parse_args() + + pytorch_dump_folder_path: Path = args.pytorch_dump_folder_path + pytorch_dump_folder_path.mkdir(exist_ok=True, parents=True) + convert_weights_and_push(pytorch_dump_folder_path, args.model_name, args.push_to_hub) diff --git a/src/transformers/models/regnet/convert_regnet_to_pytorch.py b/src/transformers/models/regnet/convert_regnet_to_pytorch.py new file mode 100644 index 0000000000000..96e4ab700ab5e --- /dev/null +++ b/src/transformers/models/regnet/convert_regnet_to_pytorch.py @@ -0,0 +1,455 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert RegNet checkpoints from timm and vissl.""" + + +import argparse +import json +from dataclasses import dataclass, field +from functools import partial +from pathlib import Path +from typing import Callable, Dict, List, Tuple + +import torch +import torch.nn as nn +from torch import Tensor + +import timm +from classy_vision.models.regnet import RegNet, RegNetParams, RegNetY32gf, RegNetY64gf, RegNetY128gf +from huggingface_hub import cached_download, hf_hub_url +from transformers import AutoFeatureExtractor, RegNetConfig, RegNetForImageClassification, RegNetModel +from transformers.utils import logging +from vissl.models.model_helpers import get_trunk_forward_outputs + + +logging.set_verbosity_info() +logger = logging.get_logger() + + +@dataclass +class Tracker: + module: nn.Module + traced: List[nn.Module] = field(default_factory=list) + handles: list = field(default_factory=list) + + def _forward_hook(self, m, inputs: Tensor, outputs: Tensor): + has_not_submodules = len(list(m.modules())) == 1 or isinstance(m, nn.Conv2d) or isinstance(m, nn.BatchNorm2d) + if has_not_submodules: + self.traced.append(m) + + def __call__(self, x: Tensor): + for m in self.module.modules(): + self.handles.append(m.register_forward_hook(self._forward_hook)) + self.module(x) + list(map(lambda x: x.remove(), self.handles)) + return self + + @property + def parametrized(self): + # check the len of the state_dict keys to see if we have learnable params + return list(filter(lambda x: len(list(x.state_dict().keys())) > 0, self.traced)) + + +@dataclass +class ModuleTransfer: + src: nn.Module + dest: nn.Module + verbose: int = 1 + src_skip: List = field(default_factory=list) + dest_skip: List = field(default_factory=list) + raise_if_mismatch: bool = True + + def __call__(self, x: Tensor): + """ + Transfer the weights of `self.src` to `self.dest` by performing a forward pass using `x` as input. Under the + hood we tracked all the operations in both modules. + """ + dest_traced = Tracker(self.dest)(x).parametrized + src_traced = Tracker(self.src)(x).parametrized + + src_traced = list(filter(lambda x: type(x) not in self.src_skip, src_traced)) + dest_traced = list(filter(lambda x: type(x) not in self.dest_skip, dest_traced)) + + if len(dest_traced) != len(src_traced) and self.raise_if_mismatch: + raise Exception( + f"Numbers of operations are different. Source module has {len(src_traced)} operations while destination module has {len(dest_traced)}." + ) + + for dest_m, src_m in zip(dest_traced, src_traced): + dest_m.load_state_dict(src_m.state_dict()) + if self.verbose == 1: + print(f"Transfered from={src_m} to={dest_m}") + + +class FakeRegNetVisslWrapper(nn.Module): + """ + Fake wrapper for RegNet that mimics what vissl does without the need to pass a config file. + """ + + def __init__(self, model: nn.Module): + super().__init__() + + feature_blocks: List[Tuple[str, nn.Module]] = [] + # - get the stem + feature_blocks.append(("conv1", model.stem)) + # - get all the feature blocks + for k, v in model.trunk_output.named_children(): + assert k.startswith("block"), f"Unexpected layer name {k}" + block_index = len(feature_blocks) + 1 + feature_blocks.append((f"res{block_index}", v)) + + self._feature_blocks = nn.ModuleDict(feature_blocks) + + def forward(self, x: Tensor): + return get_trunk_forward_outputs( + x, + out_feat_keys=None, + feature_blocks=self._feature_blocks, + ) + + +class NameToFromModelFuncMap(dict): + """ + A Dictionary with some additional logic to return a function that creates the correct original model. + """ + + def convert_name_to_timm(self, x: str) -> str: + x_split = x.split("-") + return x_split[0] + x_split[1] + "_" + "".join(x_split[2:]) + + def __getitem__(self, x: str) -> Callable[[], Tuple[nn.Module, Dict]]: + # default to timm! + if x not in self: + x = self.convert_name_to_timm(x) + val = partial(lambda: (timm.create_model(x, pretrained=True).eval(), None)) + + else: + val = super().__getitem__(x) + + return val + + +class NameToOurModelFuncMap(dict): + """ + A Dictionary with some additional logic to return the correct hugging face RegNet class reference. + """ + + def __getitem__(self, x: str) -> Callable[[], nn.Module]: + if "seer" in x and "in1k" not in x: + val = RegNetModel + else: + val = RegNetForImageClassification + return val + + +def manually_copy_vissl_head(from_state_dict, to_state_dict, keys: List[Tuple[str, str]]): + for from_key, to_key in keys: + to_state_dict[to_key] = from_state_dict[from_key].clone() + print(f"Copied key={from_key} to={to_key}") + return to_state_dict + + +def convert_weight_and_push( + name: str, + from_model_func: Callable[[], nn.Module], + our_model_func: Callable[[], nn.Module], + config: RegNetConfig, + save_directory: Path, + push_to_hub: bool = True, +): + print(f"Converting {name}...") + with torch.no_grad(): + from_model, from_state_dict = from_model_func() + our_model = our_model_func(config).eval() + module_transfer = ModuleTransfer(src=from_model, dest=our_model, raise_if_mismatch=False) + x = torch.randn((1, 3, 224, 224)) + module_transfer(x) + + if from_state_dict is not None: + keys = [] + # for seer - in1k finetuned we have to manually copy the head + if "seer" in name and "in1k" in name: + keys = [("0.clf.0.weight", "classifier.1.weight"), ("0.clf.0.bias", "classifier.1.bias")] + to_state_dict = manually_copy_vissl_head(from_state_dict, our_model.state_dict(), keys) + our_model.load_state_dict(to_state_dict) + + our_outputs = our_model(x, output_hidden_states=True) + our_output = ( + our_outputs.logits if isinstance(our_model, RegNetForImageClassification) else our_outputs.last_hidden_state + ) + + from_output = from_model(x) + from_output = from_output[-1] if type(from_output) is list else from_output + + # now since I don't want to use any config files, vissl seer model doesn't actually have an head, so let's just check the last hidden state + if "seer" in name and "in1k" in name: + our_output = our_outputs.hidden_states[-1] + + assert torch.allclose(from_output, our_output), "The model logits don't match the original one." + + if push_to_hub: + our_model.push_to_hub( + repo_path_or_name=save_directory / name, + commit_message="Add model", + use_temp_dir=True, + ) + + size = 224 if "seer" not in name else 384 + # we can use the convnext one + feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/convnext-base-224-22k-1k", size=size) + feature_extractor.push_to_hub( + repo_path_or_name=save_directory / name, + commit_message="Add feature extractor", + use_temp_dir=True, + ) + + print(f"Pushed {name}") + + +def convert_weights_and_push(save_directory: Path, model_name: str = None, push_to_hub: bool = True): + filename = "imagenet-1k-id2label.json" + num_labels = 1000 + expected_shape = (1, num_labels) + + repo_id = "datasets/huggingface/label-files" + num_labels = num_labels + id2label = json.load(open(cached_download(hf_hub_url(repo_id, filename)), "r")) + id2label = {int(k): v for k, v in id2label.items()} + + id2label = id2label + label2id = {v: k for k, v in id2label.items()} + + ImageNetPreTrainedConfig = partial(RegNetConfig, num_labels=num_labels, id2label=id2label, label2id=label2id) + + names_to_config = { + "regnet-x-002": ImageNetPreTrainedConfig( + depths=[1, 1, 4, 7], hidden_sizes=[24, 56, 152, 368], groups_width=8, layer_type="x" + ), + "regnet-x-004": ImageNetPreTrainedConfig( + depths=[1, 2, 7, 12], hidden_sizes=[32, 64, 160, 384], groups_width=16, layer_type="x" + ), + "regnet-x-006": ImageNetPreTrainedConfig( + depths=[1, 3, 5, 7], hidden_sizes=[48, 96, 240, 528], groups_width=24, layer_type="x" + ), + "regnet-x-008": ImageNetPreTrainedConfig( + depths=[1, 3, 7, 5], hidden_sizes=[64, 128, 288, 672], groups_width=16, layer_type="x" + ), + "regnet-x-016": ImageNetPreTrainedConfig( + depths=[2, 4, 10, 2], hidden_sizes=[72, 168, 408, 912], groups_width=24, layer_type="x" + ), + "regnet-x-032": ImageNetPreTrainedConfig( + depths=[2, 6, 15, 2], hidden_sizes=[96, 192, 432, 1008], groups_width=48, layer_type="x" + ), + "regnet-x-040": ImageNetPreTrainedConfig( + depths=[2, 5, 14, 2], hidden_sizes=[80, 240, 560, 1360], groups_width=40, layer_type="x" + ), + "regnet-x-064": ImageNetPreTrainedConfig( + depths=[2, 4, 10, 1], hidden_sizes=[168, 392, 784, 1624], groups_width=56, layer_type="x" + ), + "regnet-x-080": ImageNetPreTrainedConfig( + depths=[2, 5, 15, 1], hidden_sizes=[80, 240, 720, 1920], groups_width=120, layer_type="x" + ), + "regnet-x-120": ImageNetPreTrainedConfig( + depths=[2, 5, 11, 1], hidden_sizes=[224, 448, 896, 2240], groups_width=112, layer_type="x" + ), + "regnet-x-160": ImageNetPreTrainedConfig( + depths=[2, 6, 13, 1], hidden_sizes=[256, 512, 896, 2048], groups_width=128, layer_type="x" + ), + "regnet-x-320": ImageNetPreTrainedConfig( + depths=[2, 7, 13, 1], hidden_sizes=[336, 672, 1344, 2520], groups_width=168, layer_type="x" + ), + # y variant + "regnet-y-002": ImageNetPreTrainedConfig(depths=[1, 1, 4, 7], hidden_sizes=[24, 56, 152, 368], groups_width=8), + "regnet-y-004": ImageNetPreTrainedConfig( + depths=[1, 3, 6, 6], hidden_sizes=[48, 104, 208, 440], groups_width=8 + ), + "regnet-y-006": ImageNetPreTrainedConfig( + depths=[1, 3, 7, 4], hidden_sizes=[48, 112, 256, 608], groups_width=16 + ), + "regnet-y-008": ImageNetPreTrainedConfig( + depths=[1, 3, 8, 2], hidden_sizes=[64, 128, 320, 768], groups_width=16 + ), + "regnet-y-016": ImageNetPreTrainedConfig( + depths=[2, 6, 17, 2], hidden_sizes=[48, 120, 336, 888], groups_width=24 + ), + "regnet-y-032": ImageNetPreTrainedConfig( + depths=[2, 5, 13, 1], hidden_sizes=[72, 216, 576, 1512], groups_width=24 + ), + "regnet-y-040": ImageNetPreTrainedConfig( + depths=[2, 6, 12, 2], hidden_sizes=[128, 192, 512, 1088], groups_width=64 + ), + "regnet-y-064": ImageNetPreTrainedConfig( + depths=[2, 7, 14, 2], hidden_sizes=[144, 288, 576, 1296], groups_width=72 + ), + "regnet-y-080": ImageNetPreTrainedConfig( + depths=[2, 4, 10, 1], hidden_sizes=[168, 448, 896, 2016], groups_width=56 + ), + "regnet-y-120": ImageNetPreTrainedConfig( + depths=[2, 5, 11, 1], hidden_sizes=[224, 448, 896, 2240], groups_width=112 + ), + "regnet-y-160": ImageNetPreTrainedConfig( + depths=[2, 4, 11, 1], hidden_sizes=[224, 448, 1232, 3024], groups_width=112 + ), + "regnet-y-320": ImageNetPreTrainedConfig( + depths=[2, 5, 12, 1], hidden_sizes=[232, 696, 1392, 3712], groups_width=232 + ), + # models created by SEER -> https://arxiv.org/abs/2202.08360 + "regnet-y-320-seer": RegNetConfig(depths=[2, 5, 12, 1], hidden_sizes=[232, 696, 1392, 3712], groups_width=232), + "regnet-y-640-seer": RegNetConfig(depths=[2, 5, 12, 1], hidden_sizes=[328, 984, 1968, 4920], groups_width=328), + "regnet-y-1280-seer": RegNetConfig( + depths=[2, 7, 17, 1], hidden_sizes=[528, 1056, 2904, 7392], groups_width=264 + ), + "regnet-y-2560-seer": RegNetConfig( + depths=[3, 7, 16, 1], hidden_sizes=[640, 1696, 2544, 5088], groups_width=640 + ), + "regnet-y-10b-seer": ImageNetPreTrainedConfig( + depths=[2, 7, 17, 1], hidden_sizes=[2020, 4040, 11110, 28280], groups_width=1010 + ), + # finetuned on imagenet + "regnet-y-320-seer-in1k": ImageNetPreTrainedConfig( + depths=[2, 5, 12, 1], hidden_sizes=[232, 696, 1392, 3712], groups_width=232 + ), + "regnet-y-640-seer-in1k": ImageNetPreTrainedConfig( + depths=[2, 5, 12, 1], hidden_sizes=[328, 984, 1968, 4920], groups_width=328 + ), + "regnet-y-1280-seer-in1k": ImageNetPreTrainedConfig( + depths=[2, 7, 17, 1], hidden_sizes=[528, 1056, 2904, 7392], groups_width=264 + ), + "regnet-y-2560-seer-in1k": ImageNetPreTrainedConfig( + depths=[3, 7, 16, 1], hidden_sizes=[640, 1696, 2544, 5088], groups_width=640 + ), + "regnet-y-10b-seer-in1k": ImageNetPreTrainedConfig( + depths=[2, 7, 17, 1], hidden_sizes=[2020, 4040, 11110, 28280], groups_width=1010 + ), + } + + names_to_ours_model_map = NameToOurModelFuncMap() + names_to_from_model_map = NameToFromModelFuncMap() + # add seer weights logic + + def load_using_classy_vision(checkpoint_url: str, model_func: Callable[[], nn.Module]) -> Tuple[nn.Module, Dict]: + files = torch.hub.load_state_dict_from_url(checkpoint_url, model_dir=str(save_directory), map_location="cpu") + model = model_func() + # check if we have a head, if yes add it + model_state_dict = files["classy_state_dict"]["base_model"]["model"] + state_dict = model_state_dict["trunk"] + model.load_state_dict(state_dict) + return model.eval(), model_state_dict["heads"] + + # pretrained + names_to_from_model_map["regnet-y-320-seer"] = partial( + load_using_classy_vision, + "https://dl.fbaipublicfiles.com/vissl/model_zoo/seer_regnet32d/seer_regnet32gf_model_iteration244000.torch", + lambda: FakeRegNetVisslWrapper(RegNetY32gf()), + ) + + names_to_from_model_map["regnet-y-640-seer"] = partial( + load_using_classy_vision, + "https://dl.fbaipublicfiles.com/vissl/model_zoo/seer_regnet64/seer_regnet64gf_model_final_checkpoint_phase0.torch", + lambda: FakeRegNetVisslWrapper(RegNetY64gf()), + ) + + names_to_from_model_map["regnet-y-1280-seer"] = partial( + load_using_classy_vision, + "https://dl.fbaipublicfiles.com/vissl/model_zoo/swav_ig1b_regnet128Gf_cnstant_bs32_node16_sinkhorn10_proto16k_syncBN64_warmup8k/model_final_checkpoint_phase0.torch", + lambda: FakeRegNetVisslWrapper(RegNetY128gf()), + ) + + names_to_from_model_map["regnet-y-10b-seer"] = partial( + load_using_classy_vision, + "https://dl.fbaipublicfiles.com/vissl/model_zoo/seer_regnet10B/model_iteration124500_conso.torch", + lambda: FakeRegNetVisslWrapper( + RegNet(RegNetParams(depth=27, group_width=1010, w_0=1744, w_a=620.83, w_m=2.52)) + ), + ) + + # IN1K finetuned + names_to_from_model_map["regnet-y-320-seer-in1k"] = partial( + load_using_classy_vision, + "https://dl.fbaipublicfiles.com/vissl/model_zoo/seer_finetuned/seer_regnet32_finetuned_in1k_model_final_checkpoint_phase78.torch", + lambda: FakeRegNetVisslWrapper(RegNetY32gf()), + ) + + names_to_from_model_map["regnet-y-640-seer-in1k"] = partial( + load_using_classy_vision, + "https://dl.fbaipublicfiles.com/vissl/model_zoo/seer_finetuned/seer_regnet64_finetuned_in1k_model_final_checkpoint_phase78.torch", + lambda: FakeRegNetVisslWrapper(RegNetY64gf()), + ) + + names_to_from_model_map["regnet-y-1280-seer-in1k"] = partial( + load_using_classy_vision, + "https://dl.fbaipublicfiles.com/vissl/model_zoo/seer_finetuned/seer_regnet128_finetuned_in1k_model_final_checkpoint_phase78.torch", + lambda: FakeRegNetVisslWrapper(RegNetY128gf()), + ) + + names_to_from_model_map["regnet-y-10b-seer-in1k"] = partial( + load_using_classy_vision, + "https://dl.fbaipublicfiles.com/vissl/model_zoo/seer_finetuned/seer_10b_finetuned_in1k_model_phase28_conso.torch", + lambda: FakeRegNetVisslWrapper( + RegNet(RegNetParams(depth=27, group_width=1010, w_0=1744, w_a=620.83, w_m=2.52)) + ), + ) + + if model_name: + convert_weight_and_push( + model_name, + names_to_from_model_map[model_name], + names_to_ours_model_map[model_name], + names_to_config[model_name], + save_directory, + push_to_hub, + ) + else: + for model_name, config in names_to_config.items(): + convert_weight_and_push( + model_name, + names_to_from_model_map[model_name], + names_to_ours_model_map[model_name], + config, + save_directory, + push_to_hub, + ) + return config, expected_shape + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument( + "--model_name", + default=None, + type=str, + help="The name of the model you wish to convert, it must be one of the supported regnet* architecture, currently: regnetx-*, regnety-*. If `None`, all of them will the converted.", + ) + parser.add_argument( + "--pytorch_dump_folder_path", + default=None, + type=Path, + required=True, + help="Path to the output PyTorch model directory.", + ) + parser.add_argument( + "--push_to_hub", + default=True, + type=bool, + required=False, + help="If True, push model and feature extractor to the hub.", + ) + + args = parser.parse_args() + + pytorch_dump_folder_path: Path = args.pytorch_dump_folder_path + pytorch_dump_folder_path.mkdir(exist_ok=True, parents=True) + convert_weights_and_push(pytorch_dump_folder_path, args.model_name, args.push_to_hub) diff --git a/src/transformers/models/regnet/modeling_regnet.py b/src/transformers/models/regnet/modeling_regnet.py new file mode 100644 index 0000000000000..0ebd05a25ce15 --- /dev/null +++ b/src/transformers/models/regnet/modeling_regnet.py @@ -0,0 +1,444 @@ +# coding=utf-8 +# Copyright 2022 Meta Platforms, Inc. and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch RegNet model.""" + +from typing import Optional + +import torch +import torch.utils.checkpoint +from torch import Tensor, nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from ...activations import ACT2FN +from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward +from ...modeling_outputs import ( + BaseModelOutputWithNoAttention, + BaseModelOutputWithPoolingAndNoAttention, + ImageClassifierOutputWithNoAttention, +) +from ...modeling_utils import PreTrainedModel +from ...utils import logging +from .configuration_regnet import RegNetConfig + + +logger = logging.get_logger(__name__) + +# General docstring +_CONFIG_FOR_DOC = "RegNetConfig" +_FEAT_EXTRACTOR_FOR_DOC = "AutoFeatureExtractor" + +# Base docstring +_CHECKPOINT_FOR_DOC = "facebook/regnet-y-040" +_EXPECTED_OUTPUT_SHAPE = [1, 1088, 7, 7] + +# Image classification docstring +_IMAGE_CLASS_CHECKPOINT = "facebook/regnet-y-040" +_IMAGE_CLASS_EXPECTED_OUTPUT = "'tabby, tabby cat'" + +REGNET_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "facebook/regnet-y-040", + # See all regnet models at https://huggingface.co/models?filter=regnet +] + + +class RegNetConvLayer(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int = 3, + stride: int = 1, + groups: int = 1, + activation: Optional[str] = "relu", + ): + super().__init__() + self.convolution = nn.Conv2d( + in_channels, + out_channels, + kernel_size=kernel_size, + stride=stride, + padding=kernel_size // 2, + groups=groups, + bias=False, + ) + self.normalization = nn.BatchNorm2d(out_channels) + self.activation = ACT2FN[activation] if activation is not None else nn.Identity() + + def forward(self, hidden_state): + hidden_state = self.convolution(hidden_state) + hidden_state = self.normalization(hidden_state) + hidden_state = self.activation(hidden_state) + return hidden_state + + +class RegNetEmbeddings(nn.Module): + """ + RegNet Embedddings (stem) composed of a single aggressive convolution. + """ + + def __init__(self, config: RegNetConfig): + super().__init__() + self.embedder = RegNetConvLayer( + config.num_channels, config.embedding_size, kernel_size=3, stride=2, activation=config.hidden_act + ) + + def forward(self, hidden_state): + hidden_state = self.embedder(hidden_state) + return hidden_state + + +# Copied from transformers.models.resnet.modeling_resnet.ResNetShortCut with ResNet->RegNet +class RegNetShortCut(nn.Sequential): + """ + RegNet shortcut, used to project the residual features to the correct size. If needed, it is also used to + downsample the input using `stride=2`. + """ + + def __init__(self, in_channels: int, out_channels: int, stride: int = 2): + super().__init__() + self.convolution = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False) + self.normalization = nn.BatchNorm2d(out_channels) + + +class RegNetSELayer(nn.Module): + """ + Squeeze and Excitation layer (SE) proposed in [Squeeze-and-Excitation Networks](https://arxiv.org/abs/1709.01507). + """ + + def __init__(self, in_channels: int, reduced_channels: int): + super().__init__() + + self.pooler = nn.AdaptiveAvgPool2d((1, 1)) + self.attention = nn.Sequential( + nn.Conv2d(in_channels, reduced_channels, kernel_size=1), + nn.ReLU(), + nn.Conv2d(reduced_channels, in_channels, kernel_size=1), + nn.Sigmoid(), + ) + + def forward(self, hidden_state): + # b c h w -> b c 1 1 + pooled = self.pooler(hidden_state) + attention = self.attention(pooled) + hidden_state = hidden_state * attention + return hidden_state + + +class RegNetXLayer(nn.Module): + """ + RegNet's layer composed by three `3x3` convolutions, same as a ResNet bottleneck layer with reduction = 1. + """ + + def __init__(self, config: RegNetConfig, in_channels: int, out_channels: int, stride: int = 1): + super().__init__() + should_apply_shortcut = in_channels != out_channels or stride != 1 + groups = max(1, out_channels // config.groups_width) + self.shortcut = ( + RegNetShortCut(in_channels, out_channels, stride=stride) if should_apply_shortcut else nn.Identity() + ) + self.layer = nn.Sequential( + RegNetConvLayer(in_channels, out_channels, kernel_size=1, activation=config.hidden_act), + RegNetConvLayer(out_channels, out_channels, stride=stride, groups=groups, activation=config.hidden_act), + RegNetConvLayer(out_channels, out_channels, kernel_size=1, activation=None), + ) + self.activation = ACT2FN[config.hidden_act] + + def forward(self, hidden_state): + residual = hidden_state + hidden_state = self.layer(hidden_state) + residual = self.shortcut(residual) + hidden_state += residual + hidden_state = self.activation(hidden_state) + return hidden_state + + +class RegNetYLayer(nn.Module): + """ + RegNet's Y layer: an X layer with Squeeze and Excitation. + """ + + def __init__(self, config: RegNetConfig, in_channels: int, out_channels: int, stride: int = 1): + super().__init__() + should_apply_shortcut = in_channels != out_channels or stride != 1 + groups = max(1, out_channels // config.groups_width) + self.shortcut = ( + RegNetShortCut(in_channels, out_channels, stride=stride) if should_apply_shortcut else nn.Identity() + ) + self.layer = nn.Sequential( + RegNetConvLayer(in_channels, out_channels, kernel_size=1, activation=config.hidden_act), + RegNetConvLayer(out_channels, out_channels, stride=stride, groups=groups, activation=config.hidden_act), + RegNetSELayer(out_channels, reduced_channels=int(round(in_channels / 4))), + RegNetConvLayer(out_channels, out_channels, kernel_size=1, activation=None), + ) + self.activation = ACT2FN[config.hidden_act] + + def forward(self, hidden_state): + residual = hidden_state + hidden_state = self.layer(hidden_state) + residual = self.shortcut(residual) + hidden_state += residual + hidden_state = self.activation(hidden_state) + return hidden_state + + +class RegNetStage(nn.Module): + """ + A RegNet stage composed by stacked layers. + """ + + def __init__( + self, + config: RegNetConfig, + in_channels: int, + out_channels: int, + stride: int = 2, + depth: int = 2, + ): + super().__init__() + + layer = RegNetXLayer if config.layer_type == "x" else RegNetYLayer + + self.layers = nn.Sequential( + # downsampling is done in the first layer with stride of 2 + layer( + config, + in_channels, + out_channels, + stride=stride, + ), + *[layer(config, out_channels, out_channels) for _ in range(depth - 1)], + ) + + def forward(self, hidden_state): + hidden_state = self.layers(hidden_state) + return hidden_state + + +class RegNetEncoder(nn.Module): + def __init__(self, config: RegNetConfig): + super().__init__() + self.stages = nn.ModuleList([]) + # based on `downsample_in_first_stage`, the first layer of the first stage may or may not downsample the input + self.stages.append( + RegNetStage( + config, + config.embedding_size, + config.hidden_sizes[0], + stride=2 if config.downsample_in_first_stage else 1, + depth=config.depths[0], + ) + ) + in_out_channels = zip(config.hidden_sizes, config.hidden_sizes[1:]) + for (in_channels, out_channels), depth in zip(in_out_channels, config.depths[1:]): + self.stages.append(RegNetStage(config, in_channels, out_channels, depth=depth)) + + def forward( + self, hidden_state: Tensor, output_hidden_states: bool = False, return_dict: bool = True + ) -> BaseModelOutputWithNoAttention: + hidden_states = () if output_hidden_states else None + + for stage_module in self.stages: + if output_hidden_states: + hidden_states = hidden_states + (hidden_state,) + + hidden_state = stage_module(hidden_state) + + if output_hidden_states: + hidden_states = hidden_states + (hidden_state,) + + if not return_dict: + return tuple(v for v in [hidden_state, hidden_states] if v is not None) + + return BaseModelOutputWithNoAttention(last_hidden_state=hidden_state, hidden_states=hidden_states) + + +# Copied from transformers.models.resnet.modeling_resnet.ResNetPreTrainedModel with ResNet->RegNet,resnet->regnet +class RegNetPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = RegNetConfig + base_model_prefix = "regnet" + main_input_name = "pixel_values" + supports_gradient_checkpointing = True + + def _init_weights(self, module): + if isinstance(module, nn.Conv2d): + nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu") + elif isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(module.weight, 1) + nn.init.constant_(module.bias, 0) + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, RegNetModel): + module.gradient_checkpointing = value + + +REGNET_START_DOCSTRING = r""" + This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it + as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and + behavior. + + Parameters: + config ([`RegNetConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +REGNET_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Pixel values can be obtained using [`AutoFeatureExtractor`]. See + [`AutoFeatureExtractor.__call__`] for details. + + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare RegNet model outputting raw features without any specific head on top.", + REGNET_START_DOCSTRING, +) +# Copied from transformers.models.resnet.modeling_resnet.ResNetModel with RESNET->REGNET,ResNet->RegNet +class RegNetModel(RegNetPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.config = config + self.embedder = RegNetEmbeddings(config) + self.encoder = RegNetEncoder(config) + self.pooler = nn.AdaptiveAvgPool2d((1, 1)) + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(REGNET_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + processor_class=_FEAT_EXTRACTOR_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=BaseModelOutputWithPoolingAndNoAttention, + config_class=_CONFIG_FOR_DOC, + modality="vision", + expected_output=_EXPECTED_OUTPUT_SHAPE, + ) + def forward( + self, pixel_values: Tensor, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None + ) -> BaseModelOutputWithPoolingAndNoAttention: + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + embedding_output = self.embedder(pixel_values) + + encoder_outputs = self.encoder( + embedding_output, output_hidden_states=output_hidden_states, return_dict=return_dict + ) + + last_hidden_state = encoder_outputs[0] + + pooled_output = self.pooler(last_hidden_state) + + if not return_dict: + return (last_hidden_state, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPoolingAndNoAttention( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + ) + + +@add_start_docstrings( + """ + RegNet Model with an image classification head on top (a linear layer on top of the pooled features), e.g. for + ImageNet. + """, + REGNET_START_DOCSTRING, +) +# Copied from transformers.models.resnet.modeling_resnet.ResNetForImageClassification with RESNET->REGNET,ResNet->RegNet,resnet->regnet +class RegNetForImageClassification(RegNetPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.regnet = RegNetModel(config) + # classification head + self.classifier = nn.Sequential( + nn.Flatten(), + nn.Linear(config.hidden_sizes[-1], config.num_labels) if config.num_labels > 0 else nn.Identity(), + ) + # initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(REGNET_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + processor_class=_FEAT_EXTRACTOR_FOR_DOC, + checkpoint=_IMAGE_CLASS_CHECKPOINT, + output_type=ImageClassifierOutputWithNoAttention, + config_class=_CONFIG_FOR_DOC, + expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT, + ) + def forward( + self, + pixel_values: Tensor = None, + labels: Tensor = None, + output_hidden_states: bool = None, + return_dict: bool = None, + ) -> ImageClassifierOutputWithNoAttention: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the image classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.regnet(pixel_values, output_hidden_states=output_hidden_states, return_dict=return_dict) + + pooled_output = outputs.pooler_output if return_dict else outputs[1] + + logits = self.classifier(pooled_output) + + loss = None + + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + + if not return_dict: + output = (logits,) + outputs[2:] + return (loss,) + output if loss is not None else output + + return ImageClassifierOutputWithNoAttention(loss=loss, logits=logits, hidden_states=outputs.hidden_states) diff --git a/src/transformers/models/resnet/modeling_resnet.py b/src/transformers/models/resnet/modeling_resnet.py index 00c589d68f7bf..f2f555d7f5196 100644 --- a/src/transformers/models/resnet/modeling_resnet.py +++ b/src/transformers/models/resnet/modeling_resnet.py @@ -66,12 +66,14 @@ def __init__( class ResNetEmbeddings(nn.Sequential): """ - ResNet Embedddings (stem) composed of a single aggressive convolution. + ResNet Embeddings (stem) composed of a single aggressive convolution. """ - def __init__(self, num_channels: int, out_channels: int, activation: str = "relu"): + def __init__(self, config: ResNetConfig): super().__init__() - self.embedder = ResNetConvLayer(num_channels, out_channels, kernel_size=7, stride=2, activation=activation) + self.embedder = ResNetConvLayer( + config.num_channels, config.embedding_size, kernel_size=7, stride=2, activation=config.hidden_act + ) self.pooler = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) @@ -267,7 +269,7 @@ class ResNetModel(ResNetPreTrainedModel): def __init__(self, config): super().__init__(config) self.config = config - self.embedder = ResNetEmbeddings(config.num_channels, config.embedding_size, config.hidden_act) + self.embedder = ResNetEmbeddings(config) self.encoder = ResNetEncoder(config) self.pooler = nn.AdaptiveAvgPool2d((1, 1)) # Initialize weights and apply final processing diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 2bd71277b5202..3233b44d896c6 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -3350,6 +3350,30 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) +REGNET_PRETRAINED_MODEL_ARCHIVE_LIST = None + + +class RegNetForImageClassification(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class RegNetModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class RegNetPreTrainedModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + REMBERT_PRETRAINED_MODEL_ARCHIVE_LIST = None diff --git a/tests/regnet/__init__.py b/tests/regnet/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/regnet/test_modeling_regnet.py b/tests/regnet/test_modeling_regnet.py new file mode 100644 index 0000000000000..331e45296bc3f --- /dev/null +++ b/tests/regnet/test_modeling_regnet.py @@ -0,0 +1,271 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Testing suite for the PyTorch RegNet model. """ + + +import inspect +import unittest + +from transformers import RegNetConfig +from transformers.file_utils import cached_property, is_torch_available, is_vision_available +from transformers.testing_utils import require_torch, require_vision, slow, torch_device + +from ..test_configuration_common import ConfigTester +from ..test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor + + +if is_torch_available(): + import torch + from torch import nn + + from transformers import RegNetForImageClassification, RegNetModel + from transformers.models.regnet.modeling_regnet import REGNET_PRETRAINED_MODEL_ARCHIVE_LIST + + +if is_vision_available(): + from PIL import Image + + from transformers import AutoFeatureExtractor + + +class RegNetModelTester: + def __init__( + self, + parent, + batch_size=3, + image_size=32, + num_channels=3, + embeddings_size=10, + hidden_sizes=[10, 20, 30, 40], + depths=[1, 1, 2, 1], + is_training=True, + use_labels=True, + hidden_act="relu", + num_labels=3, + scope=None, + ): + self.parent = parent + self.batch_size = batch_size + self.image_size = image_size + self.num_channels = num_channels + self.embeddings_size = embeddings_size + self.hidden_sizes = hidden_sizes + self.depths = depths + self.is_training = is_training + self.use_labels = use_labels + self.hidden_act = hidden_act + self.num_labels = num_labels + self.scope = scope + self.num_stages = len(hidden_sizes) + + def prepare_config_and_inputs(self): + pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size]) + + labels = None + if self.use_labels: + labels = ids_tensor([self.batch_size], self.num_labels) + + config = self.get_config() + + return config, pixel_values, labels + + def get_config(self): + return RegNetConfig( + num_channels=self.num_channels, + embeddings_size=self.embeddings_size, + hidden_sizes=self.hidden_sizes, + depths=self.depths, + hidden_act=self.hidden_act, + num_labels=self.num_labels, + ) + + def create_and_check_model(self, config, pixel_values, labels): + model = RegNetModel(config=config) + model.to(torch_device) + model.eval() + result = model(pixel_values) + # expected last hidden states: B, C, H // 32, W // 32 + self.parent.assertEqual( + result.last_hidden_state.shape, + (self.batch_size, self.hidden_sizes[-1], self.image_size // 32, self.image_size // 32), + ) + + def create_and_check_for_image_classification(self, config, pixel_values, labels): + config.num_labels = self.num_labels + model = RegNetForImageClassification(config) + model.to(torch_device) + model.eval() + result = model(pixel_values, labels=labels) + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels)) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + config, pixel_values, labels = config_and_inputs + inputs_dict = {"pixel_values": pixel_values} + return config, inputs_dict + + +@require_torch +class RegNetModelTest(ModelTesterMixin, unittest.TestCase): + """ + Here we also overwrite some of the tests of test_modeling_common.py, as RegNet does not use input_ids, inputs_embeds, + attention_mask and seq_length. + """ + + all_model_classes = (RegNetModel, RegNetForImageClassification) if is_torch_available() else () + + test_pruning = False + test_torchscript = False + test_resize_embeddings = False + test_head_masking = False + has_attentions = False + + def setUp(self): + self.model_tester = RegNetModelTester(self) + self.config_tester = ConfigTester(self, config_class=RegNetConfig, has_text_modality=False) + + def test_config(self): + self.create_and_test_config_common_properties() + self.config_tester.create_and_test_config_to_json_string() + self.config_tester.create_and_test_config_to_json_file() + self.config_tester.create_and_test_config_from_and_save_pretrained() + self.config_tester.create_and_test_config_with_num_labels() + self.config_tester.check_config_can_be_init_without_params() + self.config_tester.check_config_arguments_init() + + def create_and_test_config_common_properties(self): + return + + @unittest.skip(reason="RegNet does not use inputs_embeds") + def test_inputs_embeds(self): + pass + + @unittest.skip(reason="RegNet does not support input and output embeddings") + def test_model_common_attributes(self): + pass + + def test_forward_signature(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + signature = inspect.signature(model.forward) + # signature.parameters is an OrderedDict => so arg_names order is deterministic + arg_names = [*signature.parameters.keys()] + + expected_arg_names = ["pixel_values"] + self.assertListEqual(arg_names[:1], expected_arg_names) + + def test_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model(*config_and_inputs) + + def test_initialization(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config=config) + for name, module in model.named_modules(): + if isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)): + self.assertTrue( + torch.all(module.weight == 1), + msg=f"Parameter {name} of model {model_class} seems not properly initialized", + ) + self.assertTrue( + torch.all(module.bias == 0), + msg=f"Parameter {name} of model {model_class} seems not properly initialized", + ) + + def test_hidden_states_output(self): + def check_hidden_states_output(inputs_dict, config, model_class): + model = model_class(config) + model.to(torch_device) + model.eval() + + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + + hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states + + expected_num_stages = self.model_tester.num_stages + self.assertEqual(len(hidden_states), expected_num_stages + 1) + + # RegNet's feature maps are of shape (batch_size, num_channels, height, width) + self.assertListEqual( + list(hidden_states[0].shape[-2:]), + [self.model_tester.image_size // 2, self.model_tester.image_size // 2], + ) + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + layers_type = ["basic", "bottleneck"] + for model_class in self.all_model_classes: + for layer_type in layers_type: + config.layer_type = layer_type + inputs_dict["output_hidden_states"] = True + check_hidden_states_output(inputs_dict, config, model_class) + + # check that output_hidden_states also work using config + del inputs_dict["output_hidden_states"] + config.output_hidden_states = True + + check_hidden_states_output(inputs_dict, config, model_class) + + def test_for_image_classification(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_for_image_classification(*config_and_inputs) + + @slow + def test_model_from_pretrained(self): + for model_name in REGNET_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: + model = RegNetModel.from_pretrained(model_name) + self.assertIsNotNone(model) + + +# We will verify our results on an image of cute cats +def prepare_img(): + image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png") + return image + + +@require_torch +@require_vision +class RegNetModelIntegrationTest(unittest.TestCase): + @cached_property + def default_feature_extractor(self): + return ( + AutoFeatureExtractor.from_pretrained(REGNET_PRETRAINED_MODEL_ARCHIVE_LIST[0]) + if is_vision_available() + else None + ) + + @slow + def test_inference_image_classification_head(self): + model = RegNetForImageClassification.from_pretrained(REGNET_PRETRAINED_MODEL_ARCHIVE_LIST[0]).to(torch_device) + + feature_extractor = self.default_feature_extractor + image = prepare_img() + inputs = feature_extractor(images=image, return_tensors="pt").to(torch_device) + + # forward pass + with torch.no_grad(): + outputs = model(**inputs) + + # verify the logits + expected_shape = torch.Size((1, 1000)) + self.assertEqual(outputs.logits.shape, expected_shape) + + expected_slice = torch.tensor([-0.4180, -1.5051, -3.4836]).to(torch_device) + + self.assertTrue(torch.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4))