diff --git a/README.md b/README.md index 1fc2e930f2a96..f727a274bd41d 100644 --- a/README.md +++ b/README.md @@ -281,6 +281,7 @@ Min, Patrick Lewis, Ledell Wu, Sergey Edunov, Danqi Chen, and Wen-tau Yih. 1. **[LXMERT](https://huggingface.co/docs/transformers/model_doc/lxmert)** (from UNC Chapel Hill) released with the paper [LXMERT: Learning Cross-Modality Encoder Representations from Transformers for Open-Domain Question Answering](https://arxiv.org/abs/1908.07490) by Hao Tan and Mohit Bansal. 1. **[M2M100](https://huggingface.co/docs/transformers/model_doc/m2m_100)** (from Facebook) released with the paper [Beyond English-Centric Multilingual Machine Translation](https://arxiv.org/abs/2010.11125) by Angela Fan, Shruti Bhosale, Holger Schwenk, Zhiyi Ma, Ahmed El-Kishky, Siddharth Goyal, Mandeep Baines, Onur Celebi, Guillaume Wenzek, Vishrav Chaudhary, Naman Goyal, Tom Birch, Vitaliy Liptchinsky, Sergey Edunov, Edouard Grave, Michael Auli, Armand Joulin. 1. **[MarianMT](https://huggingface.co/docs/transformers/model_doc/marian)** Machine translation models trained using [OPUS](http://opus.nlpl.eu/) data by Jörg Tiedemann. The [Marian Framework](https://marian-nmt.github.io/) is being developed by the Microsoft Translator Team. +1. **[MaskFormer](https://huggingface.co/docs/transformers/master/model_doc/maskformer)** (from Meta and UIUC) released with the paper [Per-Pixel Classification is Not All You Need for Semantic Segmentation](https://arxiv.org/abs/2107.06278) by Bowen Cheng, Alexander G. Schwing, Alexander Kirillov. 1. **[MBart](https://huggingface.co/docs/transformers/model_doc/mbart)** (from Facebook) released with the paper [Multilingual Denoising Pre-training for Neural Machine Translation](https://arxiv.org/abs/2001.08210) by Yinhan Liu, Jiatao Gu, Naman Goyal, Xian Li, Sergey Edunov, Marjan Ghazvininejad, Mike Lewis, Luke Zettlemoyer. 1. **[MBart-50](https://huggingface.co/docs/transformers/model_doc/mbart)** (from Facebook) released with the paper [Multilingual Translation with Extensible Multilingual Pretraining and Finetuning](https://arxiv.org/abs/2008.00401) by Yuqing Tang, Chau Tran, Xian Li, Peng-Jen Chen, Naman Goyal, Vishrav Chaudhary, Jiatao Gu, Angela Fan. 1. **[Megatron-BERT](https://huggingface.co/docs/transformers/model_doc/megatron-bert)** (from NVIDIA) released with the paper [Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism](https://arxiv.org/abs/1909.08053) by Mohammad Shoeybi, Mostofa Patwary, Raul Puri, Patrick LeGresley, Jared Casper and Bryan Catanzaro. diff --git a/README_ko.md b/README_ko.md index 26f4d6b1f7b18..067403a3c9831 100644 --- a/README_ko.md +++ b/README_ko.md @@ -259,6 +259,7 @@ Flax, PyTorch, TensorFlow 설치 페이지에서 이들을 conda로 설치하는 1. **[LXMERT](https://huggingface.co/docs/transformers/model_doc/lxmert)** (from UNC Chapel Hill) released with the paper [LXMERT: Learning Cross-Modality Encoder Representations from Transformers for Open-Domain Question Answering](https://arxiv.org/abs/1908.07490) by Hao Tan and Mohit Bansal. 1. **[M2M100](https://huggingface.co/docs/transformers/model_doc/m2m_100)** (from Facebook) released with the paper [Beyond English-Centric Multilingual Machine Translation](https://arxiv.org/abs/2010.11125) by Angela Fan, Shruti Bhosale, Holger Schwenk, Zhiyi Ma, Ahmed El-Kishky, Siddharth Goyal, Mandeep Baines, Onur Celebi, Guillaume Wenzek, Vishrav Chaudhary, Naman Goyal, Tom Birch, Vitaliy Liptchinsky, Sergey Edunov, Edouard Grave, Michael Auli, Armand Joulin. 1. **[MarianMT](https://huggingface.co/docs/transformers/model_doc/marian)** Machine translation models trained using [OPUS](http://opus.nlpl.eu/) data by Jörg Tiedemann. The [Marian Framework](https://marian-nmt.github.io/) is being developed by the Microsoft Translator Team. +1. **[MaskFormer](https://huggingface.co/docs/transformers/master/model_doc/maskformer)** (from Meta and UIUC) released with the paper [Per-Pixel Classification is Not All You Need for Semantic Segmentation](https://arxiv.org/abs/2107.06278) by Bowen Cheng, Alexander G. Schwing, Alexander Kirillov. 1. **[MBart](https://huggingface.co/docs/transformers/model_doc/mbart)** (from Facebook) released with the paper [Multilingual Denoising Pre-training for Neural Machine Translation](https://arxiv.org/abs/2001.08210) by Yinhan Liu, Jiatao Gu, Naman Goyal, Xian Li, Sergey Edunov, Marjan Ghazvininejad, Mike Lewis, Luke Zettlemoyer. 1. **[MBart-50](https://huggingface.co/docs/transformers/model_doc/mbart)** (from Facebook) released with the paper [Multilingual Translation with Extensible Multilingual Pretraining and Finetuning](https://arxiv.org/abs/2008.00401) by Yuqing Tang, Chau Tran, Xian Li, Peng-Jen Chen, Naman Goyal, Vishrav Chaudhary, Jiatao Gu, Angela Fan. 1. **[Megatron-BERT](https://huggingface.co/docs/transformers/model_doc/megatron-bert)** (from NVIDIA) released with the paper [Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism](https://arxiv.org/abs/1909.08053) by Mohammad Shoeybi, Mostofa Patwary, Raul Puri, Patrick LeGresley, Jared Casper and Bryan Catanzaro. diff --git a/README_zh-hans.md b/README_zh-hans.md index aac9c77e6e3da..a55cc0948a642 100644 --- a/README_zh-hans.md +++ b/README_zh-hans.md @@ -283,6 +283,7 @@ conda install -c huggingface transformers 1. **[LXMERT](https://huggingface.co/docs/transformers/model_doc/lxmert)** (来自 UNC Chapel Hill) 伴随论文 [LXMERT: Learning Cross-Modality Encoder Representations from Transformers for Open-Domain Question Answering](https://arxiv.org/abs/1908.07490) 由 Hao Tan and Mohit Bansal 发布。 1. **[M2M100](https://huggingface.co/docs/transformers/model_doc/m2m_100)** (来自 Facebook) 伴随论文 [Beyond English-Centric Multilingual Machine Translation](https://arxiv.org/abs/2010.11125) 由 Angela Fan, Shruti Bhosale, Holger Schwenk, Zhiyi Ma, Ahmed El-Kishky, Siddharth Goyal, Mandeep Baines, Onur Celebi, Guillaume Wenzek, Vishrav Chaudhary, Naman Goyal, Tom Birch, Vitaliy Liptchinsky, Sergey Edunov, Edouard Grave, Michael Auli, Armand Joulin 发布。 1. **[MarianMT](https://huggingface.co/docs/transformers/model_doc/marian)** 用 [OPUS](http://opus.nlpl.eu/) 数据训练的机器翻译模型由 Jörg Tiedemann 发布。[Marian Framework](https://marian-nmt.github.io/) 由微软翻译团队开发。 +1. **[MaskFormer](https://huggingface.co/docs/transformers/master/model_doc/maskformer)** (from Meta and UIUC) released with the paper [Per-Pixel Classification is Not All You Need for Semantic Segmentation](https://arxiv.org/abs/2107.06278) by Bowen Cheng, Alexander G. Schwing, Alexander Kirillov 1. **[MBart](https://huggingface.co/docs/transformers/model_doc/mbart)** (来自 Facebook) 伴随论文 [Multilingual Denoising Pre-training for Neural Machine Translation](https://arxiv.org/abs/2001.08210) 由 Yinhan Liu, Jiatao Gu, Naman Goyal, Xian Li, Sergey Edunov, Marjan Ghazvininejad, Mike Lewis, Luke Zettlemoyer 发布。 1. **[MBart-50](https://huggingface.co/docs/transformers/model_doc/mbart)** (来自 Facebook) 伴随论文 [Multilingual Translation with Extensible Multilingual Pretraining and Finetuning](https://arxiv.org/abs/2008.00401) 由 Yuqing Tang, Chau Tran, Xian Li, Peng-Jen Chen, Naman Goyal, Vishrav Chaudhary, Jiatao Gu, Angela Fan 发布。 1. **[Megatron-BERT](https://huggingface.co/docs/transformers/model_doc/megatron-bert)** (来自 NVIDIA) 伴随论文 [Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism](https://arxiv.org/abs/1909.08053) 由 Mohammad Shoeybi, Mostofa Patwary, Raul Puri, Patrick LeGresley, Jared Casper and Bryan Catanzaro 发布。 diff --git a/README_zh-hant.md b/README_zh-hant.md index 0d52ffa69e872..2dd3281f2a2d3 100644 --- a/README_zh-hant.md +++ b/README_zh-hant.md @@ -295,6 +295,7 @@ conda install -c huggingface transformers 1. **[LXMERT](https://huggingface.co/docs/transformers/model_doc/lxmert)** (from UNC Chapel Hill) released with the paper [LXMERT: Learning Cross-Modality Encoder Representations from Transformers for Open-Domain Question Answering](https://arxiv.org/abs/1908.07490) by Hao Tan and Mohit Bansal. 1. **[M2M100](https://huggingface.co/docs/transformers/model_doc/m2m_100)** (from Facebook) released with the paper [Beyond English-Centric Multilingual Machine Translation](https://arxiv.org/abs/2010.11125) by Angela Fan, Shruti Bhosale, Holger Schwenk, Zhiyi Ma, Ahmed El-Kishky, Siddharth Goyal, Mandeep Baines, Onur Celebi, Guillaume Wenzek, Vishrav Chaudhary, Naman Goyal, Tom Birch, Vitaliy Liptchinsky, Sergey Edunov, Edouard Grave, Michael Auli, Armand Joulin. 1. **[MarianMT](https://huggingface.co/docs/transformers/model_doc/marian)** Machine translation models trained using [OPUS](http://opus.nlpl.eu/) data by Jörg Tiedemann. The [Marian Framework](https://marian-nmt.github.io/) is being developed by the Microsoft Translator Team. +1. **[MaskFormer](https://huggingface.co/docs/transformers/master/model_doc/maskformer)** (from Meta and UIUC) released with the paper [Per-Pixel Classification is Not All You Need for Semantic Segmentation](https://arxiv.org/abs/2107.06278) by Bowen Cheng, Alexander G. Schwing, Alexander Kirillov 1. **[MBart](https://huggingface.co/docs/transformers/model_doc/mbart)** (from Facebook) released with the paper [Multilingual Denoising Pre-training for Neural Machine Translation](https://arxiv.org/abs/2001.08210) by Yinhan Liu, Jiatao Gu, Naman Goyal, Xian Li, Sergey Edunov, Marjan Ghazvininejad, Mike Lewis, Luke Zettlemoyer. 1. **[MBart-50](https://huggingface.co/docs/transformers/model_doc/mbart)** (from Facebook) released with the paper [Multilingual Translation with Extensible Multilingual Pretraining and Finetuning](https://arxiv.org/abs/2008.00401) by Yuqing Tang, Chau Tran, Xian Li, Peng-Jen Chen, Naman Goyal, Vishrav Chaudhary, Jiatao Gu, Angela Fan. 1. **[Megatron-BERT](https://huggingface.co/docs/transformers/model_doc/megatron-bert)** (from NVIDIA) released with the paper [Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism](https://arxiv.org/abs/1909.08053) by Mohammad Shoeybi, Mostofa Patwary, Raul Puri, Patrick LeGresley, Jared Casper and Bryan Catanzaro. diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index 214b203cf9a9a..141cc37593000 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -230,6 +230,8 @@ title: LXMERT - local: model_doc/marian title: MarianMT + - local: model_doc/maskformer + title: MaskFormer - local: model_doc/m2m_100 title: M2M100 - local: model_doc/mbart diff --git a/docs/source/index.mdx b/docs/source/index.mdx index f735cd832293c..d66ece29e32b2 100644 --- a/docs/source/index.mdx +++ b/docs/source/index.mdx @@ -105,6 +105,7 @@ conversion utilities for the following models. 1. **[LXMERT](model_doc/lxmert)** (from UNC Chapel Hill) released with the paper [LXMERT: Learning Cross-Modality Encoder Representations from Transformers for Open-Domain Question Answering](https://arxiv.org/abs/1908.07490) by Hao Tan and Mohit Bansal. 1. **[M2M100](model_doc/m2m_100)** (from Facebook) released with the paper [Beyond English-Centric Multilingual Machine Translation](https://arxiv.org/abs/2010.11125) by Angela Fan, Shruti Bhosale, Holger Schwenk, Zhiyi Ma, Ahmed El-Kishky, Siddharth Goyal, Mandeep Baines, Onur Celebi, Guillaume Wenzek, Vishrav Chaudhary, Naman Goyal, Tom Birch, Vitaliy Liptchinsky, Sergey Edunov, Edouard Grave, Michael Auli, Armand Joulin. 1. **[MarianMT](model_doc/marian)** Machine translation models trained using [OPUS](http://opus.nlpl.eu/) data by Jörg Tiedemann. The [Marian Framework](https://marian-nmt.github.io/) is being developed by the Microsoft Translator Team. +1. **[MaskFormer](model_doc/maskformer)** (from Meta and UIUC) released with the paper [Per-Pixel Classification is Not All You Need for Semantic Segmentation](https://arxiv.org/abs/2107.06278) by Bowen Cheng, Alexander G. Schwing, Alexander Kirillov. 1. **[MBart](model_doc/mbart)** (from Facebook) released with the paper [Multilingual Denoising Pre-training for Neural Machine Translation](https://arxiv.org/abs/2001.08210) by Yinhan Liu, Jiatao Gu, Naman Goyal, Xian Li, Sergey Edunov, Marjan Ghazvininejad, Mike Lewis, Luke Zettlemoyer. 1. **[MBart-50](model_doc/mbart)** (from Facebook) released with the paper [Multilingual Translation with Extensible Multilingual Pretraining and Finetuning](https://arxiv.org/abs/2008.00401) by Yuqing Tang, Chau Tran, Xian Li, Peng-Jen Chen, Naman Goyal, Vishrav Chaudhary, Jiatao Gu, Angela Fan. 1. **[Megatron-BERT](model_doc/megatron-bert)** (from NVIDIA) released with the paper [Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism](https://arxiv.org/abs/1909.08053) by Mohammad Shoeybi, Mostofa Patwary, Raul Puri, Patrick LeGresley, Jared Casper and Bryan Catanzaro. @@ -209,6 +210,7 @@ Flax), PyTorch, and/or TensorFlow. | LXMERT | ✅ | ✅ | ✅ | ✅ | ❌ | | M2M100 | ✅ | ❌ | ✅ | ❌ | ❌ | | Marian | ✅ | ❌ | ✅ | ✅ | ✅ | +| MaskFormer | ❌ | ❌ | ✅ | ❌ | ❌ | | mBART | ✅ | ✅ | ✅ | ✅ | ✅ | | MegatronBert | ❌ | ❌ | ✅ | ❌ | ❌ | | MobileBERT | ✅ | ✅ | ✅ | ✅ | ❌ | diff --git a/docs/source/model_doc/maskformer.mdx b/docs/source/model_doc/maskformer.mdx new file mode 100644 index 0000000000000..886b0a587c5ac --- /dev/null +++ b/docs/source/model_doc/maskformer.mdx @@ -0,0 +1,71 @@ + + +# MaskFormer + + + +This is a recently introduced model so the API hasn't been tested extensively. There may be some bugs or slight +breaking changes to fix it in the future. If you see something strange, file a [Github Issue](https://github.com/huggingface/transformers/issues/new?assignees=&labels=&template=bug-report.md&title). + + + +## Overview + +The MaskFormer model was proposed in [Per-Pixel Classification is Not All You Need for Semantic Segmentation](https://arxiv.org/abs/2107.06278) by Bowen Cheng, Alexander G. Schwing, Alexander Kirillov. MaskFormer addresses semantic segmentation with a mask classification paradigm instead of performing classic pixel-level classification. + +The abstract from the paper is the following: + +*Modern approaches typically formulate semantic segmentation as a per-pixel classification task, while instance-level segmentation is handled with an alternative mask classification. Our key insight: mask classification is sufficiently general to solve both semantic- and instance-level segmentation tasks in a unified manner using the exact same model, loss, and training procedure. Following this observation, we propose MaskFormer, a simple mask classification model which predicts a set of binary masks, each associated with a single global class label prediction. Overall, the proposed mask classification-based method simplifies the landscape of effective approaches to semantic and panoptic segmentation tasks and shows excellent empirical results. In particular, we observe that MaskFormer outperforms per-pixel classification baselines when the number of classes is large. Our mask classification-based method outperforms both current state-of-the-art semantic (55.6 mIoU on ADE20K) and panoptic segmentation (52.7 PQ on COCO) models.* + +Tips: +- MaskFormer's Transformer decoder is identical to the decoder of [DETR](detr). During training, the authors of DETR did find it helpful to use auxiliary losses in the decoder, especially to help the model output the correct number of objects of each class. If you set the parameter `use_auxilary_loss` of [`MaskFormerConfig`] to `True`, then prediction feedforward neural networks and Hungarian losses are added after each decoder layer (with the FFNs sharing parameters). +- If you want to train the model in a distributed environment across multiple nodes, then one should update the + `get_num_masks` function inside in the `MaskFormerLoss` class of `modeling_maskformer.py`. When training on multiple nodes, this should be + set to the average number of target masks across all nodes, as can be seen in the original implementation [here](https://github.com/facebookresearch/MaskFormer/blob/da3e60d85fdeedcb31476b5edd7d328826ce56cc/mask_former/modeling/criterion.py#L169). +- One can use [`MaskFormerFeatureExtractor`] to prepare images for the model and optional targets for the model. +- To get the final segmentation, depending on the task, you can call [`~MaskFormerFeatureExtractor.post_process_semantic_segmentation`] or [`~MaskFormerFeatureExtractor.post_process_panoptic_segmentation`]. Both tasks can be solved using [`MaskFormerForInstanceSegmentation`] output, the latter needs an additional `is_thing_map` to know which instances must be merged together.. + +The figure below illustrates the architecture of MaskFormer. Taken from the [original paper](https://arxiv.org/abs/2107.06278). + + + +This model was contributed by [francesco](https://huggingface.co/francesco). The original code can be found [here](https://github.com/facebookresearch/MaskFormer). + +## MaskFormer specific outputs + +[[autodoc]] models.maskformer.modeling_maskformer.MaskFormerModelOutput + +[[autodoc]] models.maskformer.modeling_maskformer.MaskFormerForInstanceSegmentationOutput + +## MaskFormerConfig + +[[autodoc]] MaskFormerConfig + +## MaskFormerFeatureExtractor + +[[autodoc]] MaskFormerFeatureExtractor + - __call__ + - encode_inputs + - post_process_segmentation + - post_process_semantic_segmentation + - post_process_panoptic_segmentation + +## MaskFormerModel + +[[autodoc]] MaskFormerModel + - forward + +## MaskFormerForInstanceSegmentation + +[[autodoc]] MaskFormerForInstanceSegmentation + - forward diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index b1e8d0c97df80..29bb7243df65e 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -247,6 +247,7 @@ "models.lxmert": ["LXMERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "LxmertConfig", "LxmertTokenizer"], "models.m2m_100": ["M2M_100_PRETRAINED_CONFIG_ARCHIVE_MAP", "M2M100Config"], "models.marian": ["MarianConfig"], + "models.maskformer": ["MASKFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "MaskFormerConfig"], "models.mbart": ["MBartConfig"], "models.mbart50": [], "models.megatron_bert": ["MEGATRON_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "MegatronBertConfig"], @@ -527,6 +528,7 @@ _import_structure["models.layoutlmv2"].append("LayoutLMv2FeatureExtractor") _import_structure["models.layoutlmv2"].append("LayoutLMv2Processor") _import_structure["models.layoutxlm"].append("LayoutXLMProcessor") + _import_structure["models.maskformer"].append("MaskFormerFeatureExtractor") _import_structure["models.perceiver"].append("PerceiverFeatureExtractor") _import_structure["models.poolformer"].append("PoolFormerFeatureExtractor") _import_structure["models.segformer"].append("SegformerFeatureExtractor") @@ -1147,6 +1149,14 @@ ] ) _import_structure["models.marian"].extend(["MarianForCausalLM", "MarianModel", "MarianMTModel"]) + _import_structure["models.maskformer"].extend( + [ + "MASKFORMER_PRETRAINED_MODEL_ARCHIVE_LIST", + "MaskFormerForInstanceSegmentation", + "MaskFormerModel", + "MaskFormerPreTrainedModel", + ] + ) _import_structure["models.mbart"].extend( [ "MBartForCausalLM", @@ -2532,6 +2542,7 @@ from .models.lxmert import LXMERT_PRETRAINED_CONFIG_ARCHIVE_MAP, LxmertConfig, LxmertTokenizer from .models.m2m_100 import M2M_100_PRETRAINED_CONFIG_ARCHIVE_MAP, M2M100Config from .models.marian import MarianConfig + from .models.maskformer import MASKFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, MaskFormerConfig from .models.mbart import MBartConfig from .models.megatron_bert import MEGATRON_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, MegatronBertConfig from .models.mmbt import MMBTConfig @@ -2763,6 +2774,7 @@ from .models.imagegpt import ImageGPTFeatureExtractor from .models.layoutlmv2 import LayoutLMv2FeatureExtractor, LayoutLMv2Processor from .models.layoutxlm import LayoutXLMProcessor + from .models.maskformer import MaskFormerFeatureExtractor from .models.perceiver import PerceiverFeatureExtractor from .models.poolformer import PoolFormerFeatureExtractor from .models.segformer import SegformerFeatureExtractor @@ -3273,6 +3285,12 @@ M2M100PreTrainedModel, ) from .models.marian import MarianForCausalLM, MarianModel, MarianMTModel + from .models.maskformer import ( + MASKFORMER_PRETRAINED_MODEL_ARCHIVE_LIST, + MaskFormerForInstanceSegmentation, + MaskFormerModel, + MaskFormerPreTrainedModel, + ) from .models.mbart import ( MBartForCausalLM, MBartForConditionalGeneration, diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index b7f654516b96c..e3092afba5d46 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -70,6 +70,7 @@ lxmert, m2m_100, marian, + maskformer, mbart, mbart50, megatron_bert, diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 5c42cad65f2d2..4e2a55f76610a 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -30,6 +30,7 @@ CONFIG_MAPPING_NAMES = OrderedDict( [ # Add configs here + ("maskformer", "MaskFormerConfig"), ("poolformer", "PoolFormerConfig"), ("convnext", "ConvNextConfig"), ("yoso", "YosoConfig"), @@ -129,6 +130,7 @@ CONFIG_ARCHIVE_MAP_MAPPING_NAMES = OrderedDict( [ # Add archive maps here + ("maskformer", "MASKFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("poolformer", "POOLFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("convnext", "CONVNEXT_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("yoso", "YOSO_PRETRAINED_CONFIG_ARCHIVE_MAP"), @@ -215,6 +217,7 @@ MODEL_NAMES_MAPPING = OrderedDict( [ # Add full (and cased) model names here + ("maskformer", "MaskFormer"), ("poolformer", "PoolFormer"), ("convnext", "ConvNext"), ("yoso", "YOSO"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 25b49b309799a..d1e4727c0f3a9 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -28,6 +28,7 @@ MODEL_MAPPING_NAMES = OrderedDict( [ # Base model mapping + ("maskformer", "MaskFormerModel"), ("poolformer", "PoolFormerModel"), ("convnext", "ConvNextModel"), ("yoso", "YosoModel"), diff --git a/src/transformers/models/maskformer/__init__.py b/src/transformers/models/maskformer/__init__.py new file mode 100644 index 0000000000000..547383672fc4d --- /dev/null +++ b/src/transformers/models/maskformer/__init__.py @@ -0,0 +1,56 @@ +# flake8: noqa +# There's no way to ignore "F401 '...' imported but unused" warnings in this +# module, but to preserve other warnings. So, don't check this module at all. + +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...file_utils import _LazyModule, is_torch_available, is_vision_available + + +_import_structure = { + "configuration_maskformer": ["MASKFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "MaskFormerConfig"], +} + +if is_vision_available(): + _import_structure["feature_extraction_maskformer"] = ["MaskFormerFeatureExtractor"] + + +if is_torch_available(): + _import_structure["modeling_maskformer"] = [ + "MASKFORMER_PRETRAINED_MODEL_ARCHIVE_LIST", + "MaskFormerForInstanceSegmentation", + "MaskFormerModel", + "MaskFormerPreTrainedModel", + ] + +if TYPE_CHECKING: + from .configuration_maskformer import MASKFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, MaskFormerConfig + + if is_vision_available(): + from .feature_extraction_maskformer import MaskFormerFeatureExtractor + if is_torch_available(): + from .modeling_maskformer import ( + MASKFORMER_PRETRAINED_MODEL_ARCHIVE_LIST, + MaskFormerForInstanceSegmentation, + MaskFormerModel, + MaskFormerPreTrainedModel, + ) + + +else: + import sys + + sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure) diff --git a/src/transformers/models/maskformer/configuration_maskformer.py b/src/transformers/models/maskformer/configuration_maskformer.py new file mode 100644 index 0000000000000..dfaf067e24e4d --- /dev/null +++ b/src/transformers/models/maskformer/configuration_maskformer.py @@ -0,0 +1,197 @@ +# coding=utf-8 +# Copyright 2022 Meta Platforms, Inc.and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" MaskFormer model configuration""" +import copy +from typing import Dict, Optional + +from ...configuration_utils import PretrainedConfig +from ...utils import logging +from ..auto.configuration_auto import AutoConfig +from ..detr import DetrConfig +from ..swin import SwinConfig + + +MASKFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP = { + "facebook/maskformer-swin-base-ade": "https://huggingface.co/facebook/maskformer-swin-base-ade/blob/main/config.json" + # See all MaskFormer models at https://huggingface.co/models?filter=maskformer +} + +logger = logging.get_logger(__name__) + + +class MaskFormerConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`MaskFormerModel`]. It is used to instantiate a + MaskFormer model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the + "facebook/maskformer-swin-base-ade" architecture trained on + [ADE20k-150](https://huggingface.co/datasets/scene_parse_150). + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Currently, MaskFormer only supports the [Swin Transformer](swin) as backbone. + + Args: + mask_feature_size (`int`, *optional*, defaults to 256): + The masks' features size, this value will also be used to specify the Feature Pyramid Network features' + size. + no_object_weight (`float`, *optional*, defaults to 0.1): + Weight to apply to the null (no object) class. + use_auxiliary_loss(`bool`, *optional*, defaults to `False`): + If `True` [`MaskFormerForInstanceSegmentationOutput`] will contain the auxiliary losses computed using the + logits from each decoder's stage. + backbone_config (`Dict`, *optional*): + The configuration passed to the backbone, if unset, the configuration corresponding to + `swin-base-patch4-window12-384` will be used. + decoder_config (`Dict`, *optional*): + The configuration passed to the transformer decoder model, if unset the base config for `detr-resnet-50` + will be used. + init_std (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + init_xavier_std (`float`, *optional*, defaults to 1): + The scaling factor used for the Xavier initialization gain in the HM Attention map module. + dice_weight (`float`, *optional*, defaults to 1.0): + The weight for the dice loss. + cross_entropy_weight (`float`, *optional*, defaults to 1.0): + The weight for the cross entropy loss. + mask_weight (`float`, *optional*, defaults to 20.0): + The weight for the mask loss. + + Raises: + `ValueError`: + Raised if the backbone model type selected is not in `["swin"]` or the decoder model type selected is not + in `["detr"]` + + Examples: + + ```python + >>> from transformers import MaskFormerConfig, MaskFormerModel + + >>> # Initializing a MaskFormer facebook/maskformer-swin-base-ade configuration + >>> configuration = MaskFormerConfig() + + >>> # Initializing a model from the facebook/maskformer-swin-base-ade style configuration + >>> model = MaskFormerModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ``` + + """ + model_type = "maskformer" + attribute_map = {"hidden_size": "mask_feature_size"} + backbones_supported = ["swin"] + decoders_supported = ["detr"] + + def __init__( + self, + fpn_feature_size: int = 256, + mask_feature_size: int = 256, + no_object_weight: float = 0.1, + use_auxiliary_loss: bool = False, + backbone_config: Optional[Dict] = None, + decoder_config: Optional[Dict] = None, + init_std: float = 0.02, + init_xavier_std: float = 1.0, + dice_weight: float = 1.0, + cross_entropy_weight: float = 1.0, + mask_weight: float = 20.0, + **kwargs, + ): + if backbone_config is None: + # fall back to https://huggingface.co/microsoft/swin-base-patch4-window12-384-in22k + backbone_config = SwinConfig( + image_size=384, + in_channels=3, + patch_size=4, + embed_dim=128, + depths=[2, 2, 18, 2], + num_heads=[4, 8, 16, 32], + window_size=12, + drop_path_rate=0.3, + ) + else: + backbone_model_type = backbone_config.pop("model_type") + if backbone_model_type not in self.backbones_supported: + raise ValueError( + f"Backbone {backbone_model_type} not supported, please use one of {','.join(self.backbones_supported)}" + ) + backbone_config = AutoConfig.for_model(backbone_model_type, **backbone_config) + + if decoder_config is None: + # fall back to https://huggingface.co/facebook/detr-resnet-50 + decoder_config = DetrConfig() + else: + decoder_type = decoder_config.pop("model_type") + if decoder_type not in self.decoders_supported: + raise ValueError( + f"Transformer Decoder {decoder_type} not supported, please use one of {','.join(self.decoders_supported)}" + ) + decoder_config = AutoConfig.for_model(decoder_type, **decoder_config) + + self.backbone_config = backbone_config + self.decoder_config = decoder_config + # main feature dimension for the model + self.fpn_feature_size = fpn_feature_size + self.mask_feature_size = mask_feature_size + # initializer + self.init_std = init_std + self.init_xavier_std = init_xavier_std + # Hungarian matcher && loss + self.cross_entropy_weight = cross_entropy_weight + self.dice_weight = dice_weight + self.mask_weight = mask_weight + self.use_auxiliary_loss = use_auxiliary_loss + self.no_object_weight = no_object_weight + + self.num_attention_heads = self.decoder_config.encoder_attention_heads + self.num_hidden_layers = self.decoder_config.num_hidden_layers + super().__init__(**kwargs) + + @classmethod + def from_backbone_and_decoder_configs( + cls, backbone_config: PretrainedConfig, decoder_config: PretrainedConfig, **kwargs + ): + """Instantiate a [`MaskFormerConfig`] (or a derived class) from a pre-trained backbone model configuration and DETR model + configuration. + + Args: + backbone_config ([`PretrainedConfig`]): + The backbone configuration. + decoder_config ([`PretrainedConfig`]): + The transformer decoder configuration to use. + + Returns: + [`MaskFormerConfig`]: An instance of a configuration object + """ + return cls( + backbone_config=backbone_config.to_dict(), + decoder_config=decoder_config.to_dict(), + **kwargs, + ) + + def to_dict(self) -> Dict[str, any]: + """ + Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`]. + + Returns: + `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance, + """ + output = copy.deepcopy(self.__dict__) + output["backbone_config"] = self.backbone_config.to_dict() + output["decoder_config"] = self.decoder_config.to_dict() + output["model_type"] = self.__class__.model_type + return output diff --git a/src/transformers/models/maskformer/convert_maskformer_original_pytorch_checkpoint_to_pytorch.py b/src/transformers/models/maskformer/convert_maskformer_original_pytorch_checkpoint_to_pytorch.py new file mode 100644 index 0000000000000..d4041ed59aadf --- /dev/null +++ b/src/transformers/models/maskformer/convert_maskformer_original_pytorch_checkpoint_to_pytorch.py @@ -0,0 +1,727 @@ +# coding=utf-8 +# Copyright 2022 Meta Platforms, Inc. and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import sys +from argparse import ArgumentParser +from dataclasses import dataclass +from pathlib import Path +from pprint import pformat +from typing import Any, Dict, Iterator, List, Set, Tuple + +import torch +import torchvision.transforms as T +from PIL import Image +from torch import Tensor, nn + +import requests +from detectron2.checkpoint import DetectionCheckpointer +from detectron2.config import get_cfg +from detectron2.data import MetadataCatalog +from detectron2.projects.deeplab import add_deeplab_config +from transformers.models.maskformer.feature_extraction_maskformer import MaskFormerFeatureExtractor +from transformers.models.maskformer.modeling_maskformer import ( + MaskFormerConfig, + MaskFormerForInstanceSegmentation, + MaskFormerForInstanceSegmentationOutput, + MaskFormerModel, + MaskFormerModelOutput, +) +from transformers.utils import logging + + +StateDict = Dict[str, Tensor] + +logging.set_verbosity_info() +logger = logging.get_logger() + +torch.manual_seed(0) + + +class TrackedStateDict: + def __init__(self, to_track: Dict): + """This class "tracks" a python dictionary by keeping track of which item is accessed. + + Args: + to_track (Dict): The dictionary we wish to track + """ + self.to_track = to_track + self._seen: Set[str] = set() + + def __getitem__(self, key: str) -> Any: + return self.to_track[key] + + def __setitem__(self, key: str, item: Any): + self._seen.add(key) + self.to_track[key] = item + + def diff(self) -> List[str]: + """This method returns a set difference between the keys in the tracked state dict and the one we have access so far. + This is an effective method to check if we have update all the keys + + Returns: + List[str]: List of keys not yet updated + """ + return set(list(self.to_track.keys())) - self._seen + + def copy(self) -> Dict: + # proxy the call to the internal dictionary + return self.to_track.copy() + + +# We will verify our results on an image of cute cats +def prepare_img(): + url = "http://images.cocodataset.org/val2017/000000039769.jpg" + img_data = requests.get(url, stream=True).raw + im = Image.open(img_data) + return im + + +@dataclass +class Args: + """Fake command line arguments needed by maskformer/detectron implementation""" + + config_file: str + + +def setup_cfg(args: Args): + # load config from file and command-line arguments + cfg = get_cfg() + add_deeplab_config(cfg) + add_mask_former_config(cfg) + cfg.merge_from_file(args.config_file) + cfg.freeze() + return cfg + + +class OriginalMaskFormerConfigToOursConverter: + def __call__(self, original_config: object) -> MaskFormerConfig: + + model = original_config.MODEL + mask_former = model.MASK_FORMER + swin = model.SWIN + + dataset_catalog = MetadataCatalog.get(original_config.DATASETS.TEST[0]) + id2label = {idx: label for idx, label in enumerate(dataset_catalog.stuff_classes)} + label2id = {label: idx for idx, label in id2label.items()} + + config: MaskFormerConfig = MaskFormerConfig( + fpn_feature_size=model.SEM_SEG_HEAD.CONVS_DIM, + mask_feature_size=model.SEM_SEG_HEAD.MASK_DIM, + num_labels=model.SEM_SEG_HEAD.NUM_CLASSES, + no_object_weight=mask_former.NO_OBJECT_WEIGHT, + num_queries=mask_former.NUM_OBJECT_QUERIES, + backbone_config=dict( + pretrain_img_size=swin.PRETRAIN_IMG_SIZE, + image_size=swin.PRETRAIN_IMG_SIZE, + in_channels=3, + patch_size=swin.PATCH_SIZE, + embed_dim=swin.EMBED_DIM, + depths=swin.DEPTHS, + num_heads=swin.NUM_HEADS, + window_size=swin.WINDOW_SIZE, + drop_path_rate=swin.DROP_PATH_RATE, + model_type="swin", + ), + dice_weight=mask_former.DICE_WEIGHT, + ce_weight=1.0, + mask_weight=mask_former.MASK_WEIGHT, + decoder_config=dict( + model_type="detr", + max_position_embeddings=1024, + encoder_layers=6, + encoder_ffn_dim=2048, + encoder_attention_heads=8, + decoder_layers=mask_former.DEC_LAYERS, + decoder_ffn_dim=mask_former.DIM_FEEDFORWARD, + decoder_attention_heads=mask_former.NHEADS, + encoder_layerdrop=0.0, + decoder_layerdrop=0.0, + d_model=mask_former.HIDDEN_DIM, + dropout=mask_former.DROPOUT, + attention_dropout=0.0, + activation_dropout=0.0, + init_std=0.02, + init_xavier_std=1.0, + scale_embedding=False, + auxiliary_loss=False, + dilation=False, + # default pretrained config values + ), + id2label=id2label, + label2id=label2id, + ) + + return config + + +class OriginalMaskFormerConfigToFeatureExtractorConverter: + def __call__(self, original_config: object) -> MaskFormerFeatureExtractor: + model = original_config.MODEL + model_input = original_config.INPUT + + return MaskFormerFeatureExtractor( + image_mean=(torch.tensor(model.PIXEL_MEAN) / 255).tolist(), + image_std=(torch.tensor(model.PIXEL_STD) / 255).tolist(), + size=model_input.MIN_SIZE_TEST, + max_size=model_input.MAX_SIZE_TEST, + size_divisibility=32, # 32 is required by swin + ) + + +class OriginalMaskFormerCheckpointToOursConverter: + def __init__(self, original_model: nn.Module, config: MaskFormerConfig): + self.original_model = original_model + self.config = config + + def pop_all(self, renamed_keys: List[Tuple[str, str]], dst_state_dict: StateDict, src_state_dict: StateDict): + for (src_key, dst_key) in renamed_keys: + dst_state_dict[dst_key] = src_state_dict.pop(src_key) + + def replace_backbone(self, dst_state_dict: StateDict, src_state_dict: StateDict, config: MaskFormerConfig): + dst_prefix: str = "pixel_level_module.encoder" + src_prefix: str = "backbone" + + renamed_keys = [ + ( + f"{src_prefix}.patch_embed.proj.weight", + f"{dst_prefix}.model.embeddings.patch_embeddings.projection.weight", + ), + (f"{src_prefix}.patch_embed.proj.bias", f"{dst_prefix}.model.embeddings.patch_embeddings.projection.bias"), + (f"{src_prefix}.patch_embed.norm.weight", f"{dst_prefix}.model.embeddings.norm.weight"), + (f"{src_prefix}.patch_embed.norm.bias", f"{dst_prefix}.model.embeddings.norm.bias"), + ] + num_layers = len(config.backbone_config.depths) + for layer_idx in range(num_layers): + for block_idx in range(config.backbone_config.depths[layer_idx]): + renamed_keys.extend( + [ # src, dst + ( + f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.norm1.weight", + f"{dst_prefix}.model.encoder.layers.{layer_idx}.blocks.{block_idx}.layernorm_before.weight", + ), + ( + f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.norm1.bias", + f"{dst_prefix}.model.encoder.layers.{layer_idx}.blocks.{block_idx}.layernorm_before.bias", + ), + ( + f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.attn.relative_position_bias_table", + f"{dst_prefix}.model.encoder.layers.{layer_idx}.blocks.{block_idx}.attention.self.relative_position_bias_table", + ), + ] + ) + # now we need to handle the attentions + # read in weights + bias of input projection layer of cross-attention + + src_att_weight = src_state_dict[f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.attn.qkv.weight"] + src_att_bias = src_state_dict[f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.attn.qkv.bias"] + + size = src_att_weight.shape[0] + offset = size // 3 + dst_state_dict[ + f"{dst_prefix}.model.encoder.layers.{layer_idx}.blocks.{block_idx}.attention.self.query.weight" + ] = src_att_weight[:offset, :] + dst_state_dict[ + f"{dst_prefix}.model.encoder.layers.{layer_idx}.blocks.{block_idx}.attention.self.query.bias" + ] = src_att_bias[:offset] + + dst_state_dict[ + f"{dst_prefix}.model.encoder.layers.{layer_idx}.blocks.{block_idx}.attention.self.key.weight" + ] = src_att_weight[offset : offset * 2, :] + dst_state_dict[ + f"{dst_prefix}.model.encoder.layers.{layer_idx}.blocks.{block_idx}.attention.self.key.bias" + ] = src_att_bias[offset : offset * 2] + + dst_state_dict[ + f"{dst_prefix}.model.encoder.layers.{layer_idx}.blocks.{block_idx}.attention.self.value.weight" + ] = src_att_weight[-offset:, :] + dst_state_dict[ + f"{dst_prefix}.model.encoder.layers.{layer_idx}.blocks.{block_idx}.attention.self.value.bias" + ] = src_att_bias[-offset:] + + # let's pop them + src_state_dict.pop(f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.attn.qkv.weight") + src_state_dict.pop(f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.attn.qkv.bias") + # proj + renamed_keys.extend( + [ + ( + f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.attn.proj.weight", + f"{dst_prefix}.model.encoder.layers.{layer_idx}.blocks.{block_idx}.attention.output.dense.weight", + ), + ( + f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.attn.proj.bias", + f"{dst_prefix}.model.encoder.layers.{layer_idx}.blocks.{block_idx}.attention.output.dense.bias", + ), + ] + ) + + # second norm + renamed_keys.extend( + [ + ( + f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.norm2.weight", + f"{dst_prefix}.model.encoder.layers.{layer_idx}.blocks.{block_idx}.layernorm_after.weight", + ), + ( + f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.norm2.bias", + f"{dst_prefix}.model.encoder.layers.{layer_idx}.blocks.{block_idx}.layernorm_after.bias", + ), + ] + ) + + # mlp + renamed_keys.extend( + [ + ( + f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.mlp.fc1.weight", + f"{dst_prefix}.model.encoder.layers.{layer_idx}.blocks.{block_idx}.intermediate.dense.weight", + ), + ( + f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.mlp.fc1.bias", + f"{dst_prefix}.model.encoder.layers.{layer_idx}.blocks.{block_idx}.intermediate.dense.bias", + ), + ( + f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.mlp.fc2.weight", + f"{dst_prefix}.model.encoder.layers.{layer_idx}.blocks.{block_idx}.output.dense.weight", + ), + ( + f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.mlp.fc2.bias", + f"{dst_prefix}.model.encoder.layers.{layer_idx}.blocks.{block_idx}.output.dense.bias", + ), + ] + ) + + renamed_keys.extend( + [ + ( + f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.attn.relative_position_index", + f"{dst_prefix}.model.encoder.layers.{layer_idx}.blocks.{block_idx}.attention.self.relative_position_index", + ) + ] + ) + + if layer_idx < num_layers - 1: + # patch merging + renamed_keys.extend( + [ + ( + f"{src_prefix}.layers.{layer_idx}.downsample.reduction.weight", + f"{dst_prefix}.model.encoder.layers.{layer_idx}.downsample.reduction.weight", + ), + ( + f"{src_prefix}.layers.{layer_idx}.downsample.norm.weight", + f"{dst_prefix}.model.encoder.layers.{layer_idx}.downsample.norm.weight", + ), + ( + f"{src_prefix}.layers.{layer_idx}.downsample.norm.bias", + f"{dst_prefix}.model.encoder.layers.{layer_idx}.downsample.norm.bias", + ), + ] + ) + + # hidden states norms + renamed_keys.extend( + [ + ( + f"{src_prefix}.norm{layer_idx}.weight", + f"{dst_prefix}.hidden_states_norms.{layer_idx}.weight", + ), + ( + f"{src_prefix}.norm{layer_idx}.bias", + f"{dst_prefix}.hidden_states_norms.{layer_idx}.bias", + ), + ] + ) + self.pop_all(renamed_keys, dst_state_dict, src_state_dict) + + def replace_pixel_module(self, dst_state_dict: StateDict, src_state_dict: StateDict): + dst_prefix: str = "pixel_level_module.decoder" + src_prefix: str = "sem_seg_head.pixel_decoder" + + self.replace_backbone(dst_state_dict, src_state_dict, self.config) + + def rename_keys_for_conv(detectron_conv: str, mine_conv: str): + return [ + (f"{detectron_conv}.weight", f"{mine_conv}.0.weight"), + # 2 cuz the have act in the middle -> rename it + (f"{detectron_conv}.norm.weight", f"{mine_conv}.1.weight"), + (f"{detectron_conv}.norm.bias", f"{mine_conv}.1.bias"), + ] + + renamed_keys = [ + (f"{src_prefix}.mask_features.weight", f"{dst_prefix}.mask_projection.weight"), + (f"{src_prefix}.mask_features.bias", f"{dst_prefix}.mask_projection.bias"), + # the layers in the original one are in reverse order, stem is the last one! + ] + + renamed_keys.extend(rename_keys_for_conv(f"{src_prefix}.layer_4", f"{dst_prefix}.fpn.stem")) + + # add all the fpn layers (here we need some config parameters to know the size in advance) + for src_i, dst_i in zip(range(3, 0, -1), range(0, 3)): + renamed_keys.extend( + rename_keys_for_conv(f"{src_prefix}.adapter_{src_i}", f"{dst_prefix}.fpn.layers.{dst_i}.proj") + ) + renamed_keys.extend( + rename_keys_for_conv(f"{src_prefix}.layer_{src_i}", f"{dst_prefix}.fpn.layers.{dst_i}.block") + ) + + self.pop_all(renamed_keys, dst_state_dict, src_state_dict) + + def rename_keys_in_detr_decoder(self, dst_state_dict: StateDict, src_state_dict: StateDict): + dst_prefix: str = "transformer_module.decoder" + src_prefix: str = "sem_seg_head.predictor.transformer.decoder" + # not sure why we are not popping direcetly here! + # here we list all keys to be renamed (original name on the left, our name on the right) + rename_keys = [] + for i in range(self.config.decoder_config.decoder_layers): + # decoder layers: 2 times output projection, 2 feedforward neural networks and 3 layernorms + rename_keys.append( + ( + f"{src_prefix}.layers.{i}.self_attn.out_proj.weight", + f"{dst_prefix}.layers.{i}.self_attn.out_proj.weight", + ) + ) + rename_keys.append( + ( + f"{src_prefix}.layers.{i}.self_attn.out_proj.bias", + f"{dst_prefix}.layers.{i}.self_attn.out_proj.bias", + ) + ) + rename_keys.append( + ( + f"{src_prefix}.layers.{i}.multihead_attn.out_proj.weight", + f"{dst_prefix}.layers.{i}.encoder_attn.out_proj.weight", + ) + ) + rename_keys.append( + ( + f"{src_prefix}.layers.{i}.multihead_attn.out_proj.bias", + f"{dst_prefix}.layers.{i}.encoder_attn.out_proj.bias", + ) + ) + rename_keys.append((f"{src_prefix}.layers.{i}.linear1.weight", f"{dst_prefix}.layers.{i}.fc1.weight")) + rename_keys.append((f"{src_prefix}.layers.{i}.linear1.bias", f"{dst_prefix}.layers.{i}.fc1.bias")) + rename_keys.append((f"{src_prefix}.layers.{i}.linear2.weight", f"{dst_prefix}.layers.{i}.fc2.weight")) + rename_keys.append((f"{src_prefix}.layers.{i}.linear2.bias", f"{dst_prefix}.layers.{i}.fc2.bias")) + rename_keys.append( + (f"{src_prefix}.layers.{i}.norm1.weight", f"{dst_prefix}.layers.{i}.self_attn_layer_norm.weight") + ) + rename_keys.append( + (f"{src_prefix}.layers.{i}.norm1.bias", f"{dst_prefix}.layers.{i}.self_attn_layer_norm.bias") + ) + rename_keys.append( + (f"{src_prefix}.layers.{i}.norm2.weight", f"{dst_prefix}.layers.{i}.encoder_attn_layer_norm.weight") + ) + rename_keys.append( + (f"{src_prefix}.layers.{i}.norm2.bias", f"{dst_prefix}.layers.{i}.encoder_attn_layer_norm.bias") + ) + rename_keys.append( + (f"{src_prefix}.layers.{i}.norm3.weight", f"{dst_prefix}.layers.{i}.final_layer_norm.weight") + ) + rename_keys.append( + (f"{src_prefix}.layers.{i}.norm3.bias", f"{dst_prefix}.layers.{i}.final_layer_norm.bias") + ) + + return rename_keys + + def replace_q_k_v_in_detr_decoder(self, dst_state_dict: StateDict, src_state_dict: StateDict): + dst_prefix: str = "transformer_module.decoder" + src_prefix: str = "sem_seg_head.predictor.transformer.decoder" + for i in range(self.config.decoder_config.decoder_layers): + # read in weights + bias of input projection layer of self-attention + in_proj_weight = src_state_dict.pop(f"{src_prefix}.layers.{i}.self_attn.in_proj_weight") + in_proj_bias = src_state_dict.pop(f"{src_prefix}.layers.{i}.self_attn.in_proj_bias") + # next, add query, keys and values (in that order) to the state dict + dst_state_dict[f"{dst_prefix}.layers.{i}.self_attn.q_proj.weight"] = in_proj_weight[:256, :] + dst_state_dict[f"{dst_prefix}.layers.{i}.self_attn.q_proj.bias"] = in_proj_bias[:256] + dst_state_dict[f"{dst_prefix}.layers.{i}.self_attn.k_proj.weight"] = in_proj_weight[256:512, :] + dst_state_dict[f"{dst_prefix}.layers.{i}.self_attn.k_proj.bias"] = in_proj_bias[256:512] + dst_state_dict[f"{dst_prefix}.layers.{i}.self_attn.v_proj.weight"] = in_proj_weight[-256:, :] + dst_state_dict[f"{dst_prefix}.layers.{i}.self_attn.v_proj.bias"] = in_proj_bias[-256:] + # read in weights + bias of input projection layer of cross-attention + in_proj_weight_cross_attn = src_state_dict.pop(f"{src_prefix}.layers.{i}.multihead_attn.in_proj_weight") + in_proj_bias_cross_attn = src_state_dict.pop(f"{src_prefix}.layers.{i}.multihead_attn.in_proj_bias") + # next, add query, keys and values (in that order) of cross-attention to the state dict + dst_state_dict[f"{dst_prefix}.layers.{i}.encoder_attn.q_proj.weight"] = in_proj_weight_cross_attn[:256, :] + dst_state_dict[f"{dst_prefix}.layers.{i}.encoder_attn.q_proj.bias"] = in_proj_bias_cross_attn[:256] + dst_state_dict[f"{dst_prefix}.layers.{i}.encoder_attn.k_proj.weight"] = in_proj_weight_cross_attn[ + 256:512, : + ] + dst_state_dict[f"{dst_prefix}.layers.{i}.encoder_attn.k_proj.bias"] = in_proj_bias_cross_attn[256:512] + dst_state_dict[f"{dst_prefix}.layers.{i}.encoder_attn.v_proj.weight"] = in_proj_weight_cross_attn[-256:, :] + dst_state_dict[f"{dst_prefix}.layers.{i}.encoder_attn.v_proj.bias"] = in_proj_bias_cross_attn[-256:] + + def replace_detr_decoder(self, dst_state_dict: StateDict, src_state_dict: StateDict): + dst_prefix: str = "transformer_module.decoder" + src_prefix: str = "sem_seg_head.predictor.transformer.decoder" + renamed_keys = self.rename_keys_in_detr_decoder(dst_state_dict, src_state_dict) + # add more + renamed_keys.extend( + [ + (f"{src_prefix}.norm.weight", f"{dst_prefix}.layernorm.weight"), + (f"{src_prefix}.norm.bias", f"{dst_prefix}.layernorm.bias"), + ] + ) + + self.pop_all(renamed_keys, dst_state_dict, src_state_dict) + + self.replace_q_k_v_in_detr_decoder(dst_state_dict, src_state_dict) + + def replace_transformer_module(self, dst_state_dict: StateDict, src_state_dict: StateDict): + dst_prefix: str = "transformer_module" + src_prefix: str = "sem_seg_head.predictor" + + self.replace_detr_decoder(dst_state_dict, src_state_dict) + + renamed_keys = [ + (f"{src_prefix}.query_embed.weight", f"{dst_prefix}.queries_embedder.weight"), + (f"{src_prefix}.input_proj.weight", f"{dst_prefix}.input_projection.weight"), + (f"{src_prefix}.input_proj.bias", f"{dst_prefix}.input_projection.bias"), + ] + + self.pop_all(renamed_keys, dst_state_dict, src_state_dict) + + def replace_instance_segmentation_module(self, dst_state_dict: StateDict, src_state_dict: StateDict): + # NOTE in our case we don't have a prefix, thus we removed the "." from the keys later on! + dst_prefix: str = "" + src_prefix: str = "sem_seg_head.predictor" + + renamed_keys = [ + (f"{src_prefix}.class_embed.weight", f"{dst_prefix}class_predictor.weight"), + (f"{src_prefix}.class_embed.bias", f"{dst_prefix}class_predictor.bias"), + ] + + mlp_len = 3 + for i in range(mlp_len): + renamed_keys.extend( + [ + (f"{src_prefix}.mask_embed.layers.{i}.weight", f"{dst_prefix}mask_embedder.{i}.0.weight"), + (f"{src_prefix}.mask_embed.layers.{i}.bias", f"{dst_prefix}mask_embedder.{i}.0.bias"), + ] + ) + logger.info(f"Replacing keys {pformat(renamed_keys)}") + self.pop_all(renamed_keys, dst_state_dict, src_state_dict) + + def convert(self, mask_former: MaskFormerModel) -> MaskFormerModel: + dst_state_dict = TrackedStateDict(mask_former.state_dict()) + src_state_dict = self.original_model.state_dict() + + self.replace_pixel_module(dst_state_dict, src_state_dict) + self.replace_transformer_module(dst_state_dict, src_state_dict) + + logger.info(f"Missed keys are {pformat(dst_state_dict.diff())}") + logger.info(f"Not copied keys are {pformat(src_state_dict.keys())}") + logger.info("🙌 Done") + + mask_former.load_state_dict(dst_state_dict) + + return mask_former + + def convert_instance_segmentation( + self, mask_former: MaskFormerForInstanceSegmentation + ) -> MaskFormerForInstanceSegmentation: + dst_state_dict = TrackedStateDict(mask_former.state_dict()) + src_state_dict = self.original_model.state_dict() + + self.replace_instance_segmentation_module(dst_state_dict, src_state_dict) + + mask_former.load_state_dict(dst_state_dict) + + return mask_former + + @staticmethod + def using_dirs(checkpoints_dir: Path, config_dir: Path) -> Iterator[Tuple[object, Path, Path]]: + checkpoints: List[Path] = checkpoints_dir.glob("**/*.pkl") + + for checkpoint in checkpoints: + logger.info(f"💪 Converting {checkpoint.stem}") + # find associated config file + config: Path = config_dir / checkpoint.parents[0].stem / "swin" / f"{checkpoint.stem}.yaml" + + yield config, checkpoint + + +def test(original_model, our_model: MaskFormerForInstanceSegmentation): + with torch.no_grad(): + + original_model = original_model.eval() + our_model = our_model.eval() + + im = prepare_img() + + tr = T.Compose( + [ + T.Resize((384, 384)), + T.ToTensor(), + T.Normalize( + mean=torch.tensor([123.675, 116.280, 103.530]) / 255.0, + std=torch.tensor([58.395, 57.120, 57.375]) / 255.0, + ), + ], + ) + + x = tr(im).unsqueeze(0) + + original_model_backbone_features = original_model.backbone(x.clone()) + + our_model_output: MaskFormerModelOutput = our_model.model(x.clone(), output_hidden_states=True) + + for original_model_feature, our_model_feature in zip( + original_model_backbone_features.values(), our_model_output.encoder_hidden_states + ): + + assert torch.allclose( + original_model_feature, our_model_feature, atol=1e-3 + ), "The backbone features are not the same." + + original_model_pixel_out = original_model.sem_seg_head.pixel_decoder.forward_features( + original_model_backbone_features + ) + + assert torch.allclose( + original_model_pixel_out[0], our_model_output.pixel_decoder_last_hidden_state, atol=1e-4 + ), "The pixel decoder feature are not the same" + + # let's test the full model + original_model_out = original_model([{"image": x.squeeze(0)}]) + + original_segmentation = original_model_out[0]["sem_seg"] + + our_model_out: MaskFormerForInstanceSegmentationOutput = our_model(x) + + feature_extractor = MaskFormerFeatureExtractor() + + our_segmentation = feature_extractor.post_process_segmentation(our_model_out, target_size=(384, 384)) + + assert torch.allclose( + original_segmentation, our_segmentation, atol=1e-3 + ), "The segmentation image is not the same." + + logger.info("✅ Test passed!") + + +def get_name(checkpoint_file: Path): + model_name_raw: str = checkpoint_file.stem + # model_name_raw is something like maskformer_panoptic_swin_base_IN21k_384_bs64_554k + parent_name: str = checkpoint_file.parents[0].stem + backbone = "swin" + dataset = "" + if "coco" in parent_name: + dataset = "coco" + elif "ade" in parent_name: + dataset = "ade" + else: + raise ValueError(f"{parent_name} must be wrong since we didn't find 'coco' or 'ade' in it ") + + backbone_types = ["tiny", "small", "base", "large"] + + backbone_type = list(filter(lambda x: x in model_name_raw, backbone_types))[0] + + model_name = f"maskformer-{backbone}-{backbone_type}-{dataset}" + + return model_name + + +if __name__ == "__main__": + + parser = ArgumentParser( + description="Command line to convert the original maskformers (with swin backbone) to our implementations." + ) + + parser.add_argument( + "--checkpoints_dir", + type=Path, + help="A directory containing the model's checkpoints. The directory has to have the following structure: //.pkl", + ) + parser.add_argument( + "--configs_dir", + type=Path, + help="A directory containing the model's configs, see detectron2 doc. The directory has to have the following structure: //.yaml", + ) + parser.add_argument( + "--pytorch_dump_folder_path", + required=True, + type=Path, + help="Path to the folder to output PyTorch models.", + ) + parser.add_argument( + "--maskformer_dir", + required=True, + type=Path, + help="A path to MaskFormer's original implementation directory. You can download from here: https://github.com/facebookresearch/MaskFormer", + ) + + args = parser.parse_args() + + checkpoints_dir: Path = args.checkpoints_dir + config_dir: Path = args.configs_dir + save_directory: Path = args.pytorch_dump_folder_path + maskformer_dir: Path = args.maskformer_dir + # append the path to the parents to maskformer dir + sys.path.append(str(maskformer_dir.parent)) + # and import what's needed + from MaskFormer.mask_former import add_mask_former_config + from MaskFormer.mask_former.mask_former_model import MaskFormer as OriginalMaskFormer + + if not save_directory.exists(): + save_directory.mkdir(parents=True) + + for config_file, checkpoint_file in OriginalMaskFormerCheckpointToOursConverter.using_dirs( + checkpoints_dir, config_dir + ): + + feature_extractor = OriginalMaskFormerConfigToFeatureExtractorConverter()( + setup_cfg(Args(config_file=config_file)) + ) + + original_config = setup_cfg(Args(config_file=config_file)) + mask_former_kwargs = OriginalMaskFormer.from_config(original_config) + + original_model = OriginalMaskFormer(**mask_former_kwargs).eval() + + DetectionCheckpointer(original_model).load(str(checkpoint_file)) + + config: MaskFormerConfig = OriginalMaskFormerConfigToOursConverter()(original_config) + + mask_former = MaskFormerModel(config=config).eval() + + converter = OriginalMaskFormerCheckpointToOursConverter(original_model, config) + + maskformer = converter.convert(mask_former) + + mask_former_for_instance_segmentation = MaskFormerForInstanceSegmentation(config=config).eval() + + mask_former_for_instance_segmentation.model = mask_former + mask_former_for_instance_segmentation = converter.convert_instance_segmentation( + mask_former_for_instance_segmentation + ) + + test(original_model, mask_former_for_instance_segmentation) + + model_name = get_name(checkpoint_file) + logger.info(f"🪄 Saving {model_name}") + + feature_extractor.save_pretrained(save_directory / model_name) + mask_former_for_instance_segmentation.save_pretrained(save_directory / model_name) + + feature_extractor.push_to_hub( + repo_path_or_name=save_directory / model_name, + commit_message="Add model", + use_temp_dir=True, + ) + mask_former_for_instance_segmentation.push_to_hub( + repo_path_or_name=save_directory / model_name, + commit_message="Add model", + use_temp_dir=True, + ) diff --git a/src/transformers/models/maskformer/feature_extraction_maskformer.py b/src/transformers/models/maskformer/feature_extraction_maskformer.py new file mode 100644 index 0000000000000..7a0e21a93a023 --- /dev/null +++ b/src/transformers/models/maskformer/feature_extraction_maskformer.py @@ -0,0 +1,569 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Feature extractor class for MaskFormer.""" + +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union + +import numpy as np +from PIL import Image + +from ...feature_extraction_utils import BatchFeature, FeatureExtractionMixin +from ...file_utils import TensorType, is_torch_available +from ...image_utils import ImageFeatureExtractionMixin, ImageInput, is_torch_tensor +from ...utils import logging + + +if is_torch_available(): + import torch + from torch import Tensor, nn + from torch.nn.functional import interpolate + + if TYPE_CHECKING: + from transformers.models.maskformer.modeling_maskformer import MaskFormerForInstanceSegmentationOutput + +logger = logging.get_logger(__name__) + + +class MaskFormerFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin): + r""" + Constructs a MaskFormer feature extractor. The feature extractor can be used to prepare image(s) and optional + targets for the model. + + This feature extractor inherits from [`FeatureExtractionMixin`] which contains most of the main methods. Users + should refer to this superclass for more information regarding those methods. + + Args: + do_resize (`bool`, *optional*, defaults to `True`): + Whether to resize the input to a certain `size`. + size (`int`, *optional*, defaults to 800): + Resize the input to the given size. Only has an effect if `do_resize` is set to `True`. If size is a + sequence like `(width, height)`, output size will be matched to this. If size is an int, smaller edge of + the image will be matched to this number. i.e, if `height > width`, then image will be rescaled to `(size * + height / width, size)`. + max_size (`int`, *optional*, defaults to 1333): + The largest size an image dimension can have (otherwise it's capped). Only has an effect if `do_resize` is + set to `True`. + size_divisibility (`int`, *optional*, defaults to 32): + Some backbones need images divisible by a certain number. If not passed, it defaults to the value used in + Swin Transformer. + do_normalize (`bool`, *optional*, defaults to `True`): + Whether or not to normalize the input with mean and standard deviation. + image_mean (`int`, *optional*, defaults to `[0.485, 0.456, 0.406]`): + The sequence of means for each channel, to be used when normalizing images. Defaults to the ImageNet mean. + image_std (`int`, *optional*, defaults to `[0.229, 0.224, 0.225]`): + The sequence of standard deviations for each channel, to be used when normalizing images. Defaults to the + ImageNet std. + ignore_index (`int`, *optional*, default to 255): + Value of the index (label) to ignore. + + """ + + model_input_names = ["pixel_values", "pixel_mask"] + + def __init__( + self, + do_resize=True, + size=800, + max_size=1333, + size_divisibility=32, + do_normalize=True, + image_mean=None, + image_std=None, + ignore_index=255, + **kwargs + ): + super().__init__(**kwargs) + self.do_resize = do_resize + self.size = size + self.max_size = max_size + self.size_divisibility = size_divisibility + self.ignore_index = ignore_index + self.do_normalize = do_normalize + self.image_mean = image_mean if image_mean is not None else [0.485, 0.456, 0.406] # ImageNet mean + self.image_std = image_std if image_std is not None else [0.229, 0.224, 0.225] # ImageNet std + + def _resize(self, image, size, target=None, max_size=None): + """ + Resize the image to the given size. Size can be min_size (scalar) or (width, height) tuple. If size is an int, + smaller edge of the image will be matched to this number. + + If given, also resize the target accordingly. + """ + if not isinstance(image, Image.Image): + image = self.to_pil_image(image) + + def get_size_with_aspect_ratio(image_size, size, max_size=None): + width, height = image_size + if max_size is not None: + min_original_size = float(min((width, height))) + max_original_size = float(max((width, height))) + if max_original_size / min_original_size * size > max_size: + size = int(round(max_size * min_original_size / max_original_size)) + + if (width <= height and width == size) or (height <= width and height == size): + return (height, width) + + if width < height: + output_width = size + output_height = int(size * height / width) + else: + output_height = size + output_width = int(size * width / height) + + return (output_height, output_width) + + def get_size(image_size, size, max_size=None): + if isinstance(size, (list, tuple)): + return size + else: + # size returned must be (width, height) since we use PIL to resize images + # so we revert the tuple + return get_size_with_aspect_ratio(image_size, size, max_size)[::-1] + + width, height = get_size(image.size, size, max_size) + + if self.size_divisibility > 0: + height = int(np.ceil(height / self.size_divisibility)) * self.size_divisibility + width = int(np.ceil(width / self.size_divisibility)) * self.size_divisibility + + size = (width, height) + rescaled_image = self.resize(image, size=size) + + has_target = target is not None + + if has_target: + target = target.copy() + # store original_size + target["original_size"] = image.size + if "masks" in target: + masks = torch.from_numpy(target["masks"])[:, None].float() + # use PyTorch as current workaround + # TODO replace by self.resize + interpolated_masks = ( + nn.functional.interpolate(masks, size=(height, width), mode="nearest")[:, 0] > 0.5 + ).float() + target["masks"] = interpolated_masks.numpy() + + return rescaled_image, target + + def __call__( + self, + images: ImageInput, + annotations: Union[List[Dict], List[List[Dict]]] = None, + pad_and_return_pixel_mask: Optional[bool] = True, + return_tensors: Optional[Union[str, TensorType]] = None, + **kwargs, + ) -> BatchFeature: + """ + Main method to prepare for the model one or several image(s) and optional annotations. Images are by default + padded up to the largest image in a batch, and a pixel mask is created that indicates which pixels are + real/which are padding. + + + + NumPy arrays and PyTorch tensors are converted to PIL images when resizing, so the most efficient is to pass + PIL images. + + + + Args: + images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`): + The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch + tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a + number of channels, H and W are image height and width. + + annotations (`Dict`, `List[Dict]`, *optional*): + The corresponding annotations as dictionary of numpy arrays with the following keys: + - **masks** (`np.ndarray`) The target mask of shape `(num_classes, height, width)`. + - **labels** (`np.ndarray`) The target labels of shape `(num_classes)`. + + pad_and_return_pixel_mask (`bool`, *optional*, defaults to `True`): + Whether or not to pad images up to the largest image in a batch and create a pixel mask. + + If left to the default, will return a pixel mask that is: + + - 1 for pixels that are real (i.e. **not masked**), + - 0 for pixels that are padding (i.e. **masked**). + + return_tensors (`str` or [`~file_utils.TensorType`], *optional*): + If set, will return tensors instead of NumPy arrays. If set to `'pt'`, return PyTorch `torch.Tensor` + objects. + + Returns: + [`BatchFeature`]: A [`BatchFeature`] with the following fields: + + - **pixel_values** -- Pixel values to be fed to a model. + - **pixel_mask** -- Pixel mask to be fed to a model (when `pad_and_return_pixel_mask=True` or if + *"pixel_mask"* is in `self.model_input_names`). + - **mask_labels** -- Optional mask labels of shape `(batch_size, num_classes, height, width) to be fed to a + model (when `annotations` are provided). + - **class_labels** -- Optional class labels of shape `(batch_size, num_classes) to be fed to a model (when + `annotations` are provided). + """ + # Input type checking for clearer error + + valid_images = False + valid_annotations = False + + # Check that images has a valid type + if isinstance(images, (Image.Image, np.ndarray)) or is_torch_tensor(images): + valid_images = True + elif isinstance(images, (list, tuple)): + if len(images) == 0 or isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]): + valid_images = True + + if not valid_images: + raise ValueError( + "Images must of type `PIL.Image.Image`, `np.ndarray` or `torch.Tensor` (single example), " + "`List[PIL.Image.Image]`, `List[np.ndarray]` or `List[torch.Tensor]` (batch of examples)." + ) + + is_batched = bool( + isinstance(images, (list, tuple)) + and (isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0])) + ) + + if not is_batched: + images = [images] + if annotations is not None: + annotations = [annotations] + + # Check that annotations has a valid type + if annotations is not None: + valid_annotations = type(annotations) is list and "masks" in annotations[0] and "labels" in annotations[0] + if not valid_annotations: + raise ValueError( + "Annotations must of type `Dict` (single image) or `List[Dict]` (batch of images)." + "The annotations must be numpy arrays in the following format:" + "{ 'masks' : the target mask, with shape [C,H,W], 'labels' : the target labels, with shape [C]}" + ) + + # transformations (resizing + normalization) + if self.do_resize and self.size is not None: + if annotations is not None: + for idx, (image, target) in enumerate(zip(images, annotations)): + image, target = self._resize(image=image, target=target, size=self.size, max_size=self.max_size) + images[idx] = image + annotations[idx] = target + else: + for idx, image in enumerate(images): + images[idx] = self._resize(image=image, target=None, size=self.size, max_size=self.max_size)[0] + + if self.do_normalize: + images = [self.normalize(image=image, mean=self.image_mean, std=self.image_std) for image in images] + # NOTE I will be always forced to pad them them since they have to be stacked in the batch dim + encoded_inputs = self.encode_inputs( + images, annotations, pad_and_return_pixel_mask, return_tensors=return_tensors + ) + + # Convert to TensorType + tensor_type = return_tensors + if not isinstance(tensor_type, TensorType): + tensor_type = TensorType(tensor_type) + + if not tensor_type == TensorType.PYTORCH: + raise ValueError("Only PyTorch is supported for the moment.") + else: + if not is_torch_available(): + raise ImportError("Unable to convert output to PyTorch tensors format, PyTorch is not installed.") + + return encoded_inputs + + def _max_by_axis(self, the_list: List[List[int]]) -> List[int]: + maxes = the_list[0] + for sublist in the_list[1:]: + for index, item in enumerate(sublist): + maxes[index] = max(maxes[index], item) + return maxes + + def encode_inputs( + self, + pixel_values_list: List["torch.Tensor"], + annotations: Optional[List[Dict]] = None, + pad_and_return_pixel_mask: Optional[bool] = True, + return_tensors: Optional[Union[str, TensorType]] = None, + ): + """ + Pad images up to the largest image in a batch and create a corresponding `pixel_mask`. + + Args: + pixel_values_list (`List[torch.Tensor]`): + List of images (pixel values) to be padded. Each image should be a tensor of shape `(channels, height, + width)`. + + annotations (`Dict`, `List[Dict]`, *optional*): + The corresponding annotations as dictionary of numpy arrays with the following keys: + - **masks** (`np.ndarray`) The target mask of shape `(num_classes, height, width)`. + - **labels** (`np.ndarray`) The target labels of shape `(num_classes)`. + + pad_and_return_pixel_mask (`bool`, *optional*, defaults to `True`): + Whether or not to pad images up to the largest image in a batch and create a pixel mask. + + If left to the default, will return a pixel mask that is: + + - 1 for pixels that are real (i.e. **not masked**), + - 0 for pixels that are padding (i.e. **masked**). + + return_tensors (`str` or [`~file_utils.TensorType`], *optional*): + If set, will return tensors instead of NumPy arrays. If set to `'pt'`, return PyTorch `torch.Tensor` + objects. + + Returns: + [`BatchFeature`]: A [`BatchFeature`] with the following fields: + + - **pixel_values** -- Pixel values to be fed to a model. + - **pixel_mask** -- Pixel mask to be fed to a model (when `pad_and_return_pixel_mask=True` or if + *"pixel_mask"* is in `self.model_input_names`). + - **mask_labels** -- Optional mask labels of shape `(batch_size, num_classes, height, width) to be fed to a + model (when `annotations` are provided). + - **class_labels** -- Optional class labels of shape `(batch_size, num_classes) to be fed to a model (when + `annotations` are provided). + """ + + max_size = self._max_by_axis([list(image.shape) for image in pixel_values_list]) + channels, height, width = max_size + pixel_values = [] + pixel_mask = [] + mask_labels = [] + class_labels = [] + for idx, image in enumerate(pixel_values_list): + # create padded image + if pad_and_return_pixel_mask: + padded_image = np.zeros((channels, height, width), dtype=np.float32) + padded_image[: image.shape[0], : image.shape[1], : image.shape[2]] = np.copy(image) + image = padded_image + pixel_values.append(image) + # if we have a target, pad it + if annotations: + annotation = annotations[idx] + masks = annotation["masks"] + if pad_and_return_pixel_mask: + padded_masks = np.zeros((masks.shape[0], height, width), dtype=masks.dtype) + padded_masks[:, : masks.shape[1], : masks.shape[2]] = np.copy(masks) + masks = padded_masks + mask_labels.append(masks) + class_labels.append(annotation["labels"]) + if pad_and_return_pixel_mask: + # create pixel mask + mask = np.zeros((height, width), dtype=np.int64) + mask[: image.shape[1], : image.shape[2]] = True + pixel_mask.append(mask) + + # return as BatchFeature + data = {"pixel_values": pixel_values, "pixel_mask": pixel_mask} + + if annotations: + data["mask_labels"] = mask_labels + data["class_labels"] = class_labels + + encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors) + + return encoded_inputs + + def post_process_segmentation( + self, outputs: "MaskFormerForInstanceSegmentationOutput", target_size: Tuple[int, int] = None + ) -> "torch.Tensor": + """ + Converts the output of [`MaskFormerForInstanceSegmentationOutput`] into image segmentation predictions. Only + supports PyTorch. + + Args: + outputs ([`MaskFormerForInstanceSegmentationOutput`]): + The outputs from [`MaskFormerForInstanceSegmentation`]. + + target_size (`Tuple[int, int]`, *optional*): + If set, the `masks_queries_logits` will be resized to `target_size`. + + Returns: + `torch.Tensor`: + A tensor of shape (`batch_size, num_labels, height, width`). + """ + # class_queries_logits has shape [BATCH, QUERIES, CLASSES + 1] + class_queries_logits = outputs.class_queries_logits + # masks_queries_logits has shape [BATCH, QUERIES, HEIGHT, WIDTH] + masks_queries_logits = outputs.masks_queries_logits + if target_size is not None: + masks_queries_logits = interpolate( + masks_queries_logits, + size=target_size, + mode="bilinear", + align_corners=False, + ) + # remove the null class `[..., :-1]` + masks_classes = class_queries_logits.softmax(dim=-1)[..., :-1] + # mask probs has shape [BATCH, QUERIES, HEIGHT, WIDTH] + masks_probs = masks_queries_logits.sigmoid() + # now we want to sum over the queries, + # $ out_{c,h,w} = \sum_q p_{q,c} * m_{q,h,w} $ + # where $ softmax(p) \in R^{q, c} $ is the mask classes + # and $ sigmoid(m) \in R^{q, h, w}$ is the mask probabilities + # b(atch)q(uery)c(lasses), b(atch)q(uery)h(eight)w(idth) + segmentation = torch.einsum("bqc, bqhw -> bchw", masks_classes, masks_probs) + + return segmentation + + def remove_low_and_no_objects(self, masks, scores, labels, object_mask_threshold, num_labels): + """ + Binarize the given masks using `object_mask_threshold`, it returns the associated values of `masks`, `scores` + and `labels`. + + Args: + masks (`torch.Tensor`): + A tensor of shape `(num_queries, height, width)`. + scores (`torch.Tensor`): + A tensor of shape `(num_queries)`. + labels (`torch.Tensor`): + A tensor of shape `(num_queries)`. + object_mask_threshold (`float`): + A number between 0 and 1 used to binarize the masks. + + Raises: + `ValueError`: Raised when the first dimension doesn't match in all input tensors. + + Returns: + `Tuple[`torch.Tensor`, `torch.Tensor`, `torch.Tensor`]`: The `masks`, `scores` and `labels` without the + region < `object_mask_threshold`. + """ + if not (masks.shape[0] == scores.shape[0] == labels.shape[0]): + raise ValueError("mask, scores and labels must have the same shape!") + + to_keep = labels.ne(num_labels) & (scores > object_mask_threshold) + + return masks[to_keep], scores[to_keep], labels[to_keep] + + def post_process_semantic_segmentation( + self, outputs: "MaskFormerForInstanceSegmentationOutput", target_size: Tuple[int, int] = None + ) -> "torch.Tensor": + """ + Converts the output of [`MaskFormerForInstanceSegmentationOutput`] into semantic segmentation predictions. Only + supports PyTorch. + + Args: + outputs ([`MaskFormerForInstanceSegmentationOutput`]): + The outputs from [`MaskFormerForInstanceSegmentation`]. + + Returns: + `torch.Tensor`: A tensor of shape `batch_size, height, width`. + """ + segmentation = self.post_process_segmentation(outputs, target_size) + semantic_segmentation = segmentation.argmax(dim=1) + return semantic_segmentation + + def post_process_panoptic_segmentation( + self, + outputs: "MaskFormerForInstanceSegmentationOutput", + object_mask_threshold: float = 0.8, + overlap_mask_area_threshold: float = 0.8, + is_thing_map: Optional[Dict[int, bool]] = None, + ) -> List[Dict]: + """ + Converts the output of [`MaskFormerForInstanceSegmentationOutput`] into image panoptic segmentation + predictions. Only supports PyTorch. + + Args: + outputs ([`MaskFormerForInstanceSegmentationOutput`]): + The outputs from [`MaskFormerForInstanceSegmentation`]. + object_mask_threshold (`float`, *optional*, defaults to 0.8): + The object mask threshold. + overlap_mask_area_threshold (`float`, *optional*, defaults to 0.8): + The overlap mask area threshold to use. + is_thing_map (`Dict[int, bool]`, *optional*): + Dictionary mapping class indices to either `True` or `False`, depending on whether or not they are a + thing. If not set, defaults to the `is_thing_map` of COCO panoptic. + + Returns: + `List[Dict]`: A list of dictionaries, one per image, each dictionary containing two keys: + - **segmentation** -- a tensor of shape `(height, width)` where each pixel represents a `segment_id`. + - **segments** -- a dictionary with the following keys + - **id** -- an integer representing the `segment_id`. + - **category_id** -- an integer representing the segment's label. + - **is_thing** -- a boolean, `True` if `category_id` was in `is_thing_map`, `False` otherwise. + """ + + if is_thing_map is None: + logger.warning("`is_thing_map` unset. Default to COCO.") + # default to is_thing_map of COCO panoptic + is_thing_map = {i: i <= 90 for i in range(201)} + # class_queries_logits has shape [BATCH, QUERIES, CLASSES + 1] + class_queries_logits = outputs.class_queries_logits + # keep track of the number of labels, subtract -1 for null class + num_labels = class_queries_logits.shape[-1] - 1 + # masks_queries_logits has shape [BATCH, QUERIES, HEIGHT, WIDTH] + masks_queries_logits = outputs.masks_queries_logits + # since all images are padded, they all have the same spatial dimensions + _, _, height, width = masks_queries_logits.shape + # for each query, the best scores and their indeces + pred_scores, pred_labels = nn.functional.softmax(class_queries_logits, dim=-1).max(-1) + # pred_scores and pred_labels shape = [BATH,NUM_QUERIES] + mask_probs = masks_queries_logits.sigmoid() + # mask probs has shape [BATCH, QUERIES, HEIGHT, WIDTH] + # now, we need to iterate over the batch size to correctly process the segmentation we got from the queries using our thresholds. Even if the original predicted masks have the same shape across the batch, they won't after thresholding so batch-wise operations are impossible + results: List[Dict[str, Tensor]] = [] + for (mask_probs, pred_scores, pred_labels) in zip(mask_probs, pred_scores, pred_labels): + mask_probs, pred_scores, pred_labels = self.remove_low_and_no_objects( + mask_probs, pred_scores, pred_labels, object_mask_threshold, num_labels + ) + we_detect_something = mask_probs.shape[0] > 0 + + segmentation = torch.zeros((height, width), dtype=torch.int32, device=mask_probs.device) + segments: List[Dict] = [] + + if we_detect_something: + current_segment_id = 0 + # weight each mask by its score + mask_probs *= pred_scores.view(-1, 1, 1) + # find out for each pixel what is the most likely class to be there + mask_labels = mask_probs.argmax(0) + # mask_labels shape = [H,W] where each pixel has a class label + stuff_memory_list: Dict[str, int] = {} + # this is a map between stuff and segments id, the used it to keep track of the instances of one class + for k in range(pred_labels.shape[0]): + pred_class = pred_labels[k].item() + # check if pred_class is not a "thing", so it can be merged with other instance. For example, class "sky" cannot have more then one instance + is_stuff = not is_thing_map[pred_class] + # get the mask associated with the k class + mask_k = mask_labels == k + # create the area, since bool we just need to sum :) + mask_k_area = mask_k.sum() + # this is the area of all the stuff in query k + # TODO not 100%, why are the taking the k query here???? + original_area = (mask_probs[k] >= 0.5).sum() + + mask_does_exist = mask_k_area > 0 and original_area > 0 + + if mask_does_exist: + # find out how much of the all area mask_k is using + area_ratio = mask_k_area / original_area + mask_k_is_overlapping_enough = area_ratio.item() > overlap_mask_area_threshold + + if mask_k_is_overlapping_enough: + # merge stuff regions + if pred_class in stuff_memory_list: + current_segment_id = stuff_memory_list[pred_class] + else: + current_segment_id += 1 + # then we update out mask with the current segment + segmentation[mask_k] = current_segment_id + segments.append( + { + "id": current_segment_id, + "category_id": pred_class, + "is_thing": not is_stuff, + } + ) + if is_stuff: + stuff_memory_list[pred_class] = current_segment_id + results.append({"segmentation": segmentation, "segments": segments}) + return results diff --git a/src/transformers/models/maskformer/modeling_maskformer.py b/src/transformers/models/maskformer/modeling_maskformer.py new file mode 100644 index 0000000000000..66b34c833fafc --- /dev/null +++ b/src/transformers/models/maskformer/modeling_maskformer.py @@ -0,0 +1,2499 @@ +# coding=utf-8 +# Copyright 2022 Meta Platforms, Inc.s and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch MaskFormer model.""" + +import collections.abc +import math +import random +from dataclasses import dataclass +from numbers import Number +from typing import Dict, List, Optional, Tuple + +import numpy as np +import torch +from torch import Tensor, nn + +from transformers.utils import logging + +from ...activations import ACT2FN +from ...file_utils import ( + ModelOutput, + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_scipy_available, + replace_return_docstrings, + requires_backends, +) +from ...modeling_outputs import BaseModelOutputWithCrossAttentions +from ...modeling_utils import ModuleUtilsMixin, PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer +from ..detr import DetrConfig +from ..swin import SwinConfig +from .configuration_maskformer import MaskFormerConfig + + +if is_scipy_available(): + from scipy.optimize import linear_sum_assignment + +logger = logging.get_logger(__name__) + + +_CONFIG_FOR_DOC = "MaskFormerConfig" +_CHECKPOINT_FOR_DOC = "facebook/maskformer-swin-base-ade" +_FEAT_EXTRACTOR_FOR_DOC = "MaskFormerFeatureExtractor" + +MASKFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "facebook/maskformer-swin-base-ade", + # See all MaskFormer models at https://huggingface.co/models?filter=maskformer +] + + +@dataclass +class MaskFormerSwinModelOutputWithPooling(ModelOutput): + """ + Class for MaskFormerSwinModel's outputs that also contains the spatial dimensions of the hidden states. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`): + Last layer hidden-state after a mean pooling operation. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + hidden_states_spatial_dimensions (`tuple(tuple(int, int))`, *optional*): + A tuple containing the spatial dimension of each `hidden_state` needed to reshape the `hidden_states` to + `batch, channels, height, width`. Due to padding, their spatial size cannot be inferred before the + `forward` method. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + last_hidden_state: torch.FloatTensor = None + pooler_output: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + hidden_states_spatial_dimensions: Tuple[Tuple[int, int]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +@dataclass +class MaskFormerSwinBaseModelOutput(ModelOutput): + """ + Class for SwinEncoder's outputs. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + hidden_states_spatial_dimensions (`tuple(tuple(int, int))`, *optional*): + A tuple containing the spatial dimension of each `hidden_state` needed to reshape the `hidden_states` to + `batch, channels, height, width`. Due to padding, their spatial size cannot inferred before the `forward` + method. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + last_hidden_state: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + hidden_states_spatial_dimensions: Tuple[Tuple[int, int]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +@dataclass +# Copied from transformers.models.detr.modeling_detr.DetrDecoderOutput +class DetrDecoderOutput(BaseModelOutputWithCrossAttentions): + """ + Base class for outputs of the DETR decoder. This class adds one attribute to BaseModelOutputWithCrossAttentions, + namely an optional stack of intermediate decoder activations, i.e. the output of each decoder layer, each of them + gone through a layernorm. This is useful when training the model with auxiliary decoding losses. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer + plus the initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in + the self-attention heads. + cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. Attentions weights of the decoder's cross-attention layer, after the attention softmax, + used to compute the weighted average in the cross-attention heads. + intermediate_hidden_states (`torch.FloatTensor` of shape `(config.decoder_layers, batch_size, num_queries, hidden_size)`, *optional*, returned when `config.auxiliary_loss=True`): + Intermediate decoder activations, i.e. the output of each decoder layer, each of them gone through a + layernorm. + """ + + intermediate_hidden_states: Optional[torch.FloatTensor] = None + + +@dataclass +class MaskFormerPixelLevelModuleOutput(ModelOutput): + """ + MaskFormer's pixel level module output. It returns both the last and (optionally) the hidden states from the + `encoder` and `decoder`. By default, the `encoder` is a MaskFormerSwin Transformer and the `decoder` is a Feature + Pyramid Network (FPN). + + The `encoder_last_hidden_state` are referred on the paper as **images features**, while `decoder_last_hidden_state` + as **pixel embeddings** + + Args: + encoder_last_hidden_state (`torch.FloatTensor` of shape`(batch_size, num_channels, height, width)`): + Last hidden states (final feature map) of the last stage of the encoder. + encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of + shape `(batch_size, num_channels, height, width)`. Hidden-states (also called feature maps) of the model at + the output of each stage. + decoder_last_hidden_state (`torch.FloatTensor` of shape`(batch_size, num_channels, height, width)`): + Last hidden states (final feature map) of the last stage of the decoder. + decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of + shape `(batch_size, num_channels, height, width)`. Hidden-states (also called feature maps) of the model at + the output of each stage. + """ + + encoder_last_hidden_state: Optional[torch.FloatTensor] = None + decoder_last_hidden_state: Optional[torch.FloatTensor] = None + encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + + +class MaskFormerPixelDecoderOutput(ModelOutput): + """ + MaskFormer's pixel decoder module output, practically a Feature Pyramid Network. It returns the last hidden state + and (optionally) the hidden states. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Last hidden states (final feature map) of the last stage of the model. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, num_channels, height, width)`. Hidden-states of the model at the output of each layer + plus the initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. Attentions weights from Detr's decoder after the attention softmax, used to compute the + weighted average in the self-attention heads. + """ + + last_hidden_state: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +@dataclass +class MaskFormerModelOutput(ModelOutput): + """ + Class for outputs of [`MaskFormerModel`]. This class returns all the needed hidden states to compute the logits. + + Args: + encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Last hidden states (final feature map) of the last stage of the encoder model (backbone). + pixel_decoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Last hidden states (final feature map) of the last stage of the pixel decoder model (FPN). + transformer_decoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Last hidden states (final feature map) of the last stage of the transformer decoder model. + encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of + shape `(batch_size, num_channels, height, width)`. Hidden-states (also called feature maps) of the encoder + model at the output of each stage. + pixel_decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of + shape `(batch_size, num_channels, height, width)`. Hidden-states (also called feature maps) of the pixel + decoder model at the output of each stage. + transformer_decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of + shape `(batch_size, sequence_length, hidden_size)`. Hidden-states (also called feature maps) of the + transformer decoder at the output of each stage. + hidden_states `tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` containing `encoder_hidden_states`, `pixel_decoder_hidden_states` and + `decoder_hidden_states` + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. Attentions weights from Detr's decoder after the attention softmax, used to compute the + weighted average in the self-attention heads. + """ + + encoder_last_hidden_state: Optional[torch.FloatTensor] = None + pixel_decoder_last_hidden_state: Optional[torch.FloatTensor] = None + transformer_decoder_last_hidden_state: Optional[torch.FloatTensor] = None + encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + pixel_decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + transformer_decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +@dataclass +class MaskFormerForInstanceSegmentationOutput(ModelOutput): + """ + Class for outputs of [`MaskFormerForInstanceSegmentation`]. + + This output can be directly passed to [`~MaskFormerFeatureExtractor.post_process_segmentation`] or + [`~MaskFormerFeatureExtractor.post_process_panoptic_segmentation`] depending on the task. Please, see + [`~MaskFormerFeatureExtractor] for details regarding usage. + + Args: + loss (`torch.Tensor`, *optional*): + The computed loss, returned when labels are present. + class_queries_logits (`torch.FloatTensor`): + A tensor of shape `(batch_size, num_queries, height, width)` representing the proposed masks for each + query. + masks_queries_logits (`torch.FloatTensor`): + A tensor of shape `(batch_size, num_queries, num_classes + 1)` representing the proposed classes for each + query. Note the `+ 1` is needed because we incorporate the null class. + encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Last hidden states (final feature map) of the last stage of the encoder model (backbone). + pixel_decoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Last hidden states (final feature map) of the last stage of the pixel decoder model (FPN). + transformer_decoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Last hidden states (final feature map) of the last stage of the transformer decoder model. + encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of + shape `(batch_size, num_channels, height, width)`. Hidden-states (also called feature maps) of the encoder + model at the output of each stage. + pixel_decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of + shape `(batch_size, num_channels, height, width)`. Hidden-states (also called feature maps) of the pixel + decoder model at the output of each stage. + transformer_decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of + shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the transformer decoder at the output + of each stage. + hidden_states `tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` containing `encoder_hidden_states`, `pixel_decoder_hidden_states` and + `decoder_hidden_states`. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. Attentions weights from Detr's decoder after the attention softmax, used to compute the + weighted average in the self-attention heads. + """ + + loss: Optional[torch.FloatTensor] = None + class_queries_logits: torch.FloatTensor = None + masks_queries_logits: torch.FloatTensor = None + auxiliary_logits: torch.FloatTensor = None + encoder_last_hidden_state: Optional[torch.FloatTensor] = None + pixel_decoder_last_hidden_state: Optional[torch.FloatTensor] = None + transformer_decoder_last_hidden_state: Optional[torch.FloatTensor] = None + encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + pixel_decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + transformer_decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +def upsample_like(pixel_values: Tensor, like: Tensor, mode: str = "bilinear") -> Tensor: + """ + An utility function that upsamples `pixel_values` to match the dimension of `like`. + + Args: + pixel_values (`torch.Tensor`): + The tensor we wish to upsample. + like (`torch.Tensor`): + The tensor we wish to use as size target. + mode (str, *optional*, defaults to `"bilinear"`): + The interpolation mode. + + Returns: + `torch.Tensor`: The upsampled tensor + """ + _, _, height, width = like.shape + upsampled = nn.functional.interpolate(pixel_values, size=(height, width), mode=mode, align_corners=False) + return upsampled + + +# refactored from original implementation +def dice_loss(inputs: Tensor, labels: Tensor, num_masks: int) -> Tensor: + r""" + Compute the DICE loss, similar to generalized IOU for masks as follows: + + $$ \mathcal{L}_{\text{dice}(x, y) = 1 - \frac{2 * x \cap y }{x \cup y + 1}} $$ + + In practice, since `labels` is a binary mask, (only 0s and 1s), dice can be computed as follow + + $$ \mathcal{L}_{\text{dice}(x, y) = 1 - \frac{2 * x * y }{x + y + 1}} $$ + + Args: + inputs (`torch.Tensor`): + A tensor representing a mask. + labels (`torch.Tensor`): + A tensor with the same shape as inputs. Stores the binary classification labels for each element in inputs + (0 for the negative class and 1 for the positive class). + num_masks (`int`): + The number of masks present in the current batch, used for normalization. + + Returns: + `torch.Tensor`: The computed loss. + """ + probs = inputs.sigmoid().flatten(1) + numerator = 2 * (probs * labels).sum(-1) + denominator = probs.sum(-1) + labels.sum(-1) + loss = 1 - (numerator + 1) / (denominator + 1) + loss = loss.sum() / num_masks + return loss + + +# refactored from original implementation +def sigmoid_focal_loss( + inputs: Tensor, labels: Tensor, num_masks: int, alpha: float = 0.25, gamma: float = 2 +) -> Tensor: + r""" + Focal loss proposed in [Focal Loss for Dense Object Detection](https://arxiv.org/abs/1708.02002) originally used in + RetinaNet. The loss is computed as follows: + + $$ \mathcal{L}_{\text{focal loss} = -(1 - p_t)^{\gamma}\log{(p_t)} $$ + + where \\(CE(p_t) = -\log{(p_t)}}\\), CE is the standard Cross Entropy Loss + + Please refer to equation (1,2,3) of the paper for a better understanding. + + Args: + inputs (`torch.Tensor`): + A float tensor of arbitrary shape. + labels (`torch.Tensor`): + A tensor with the same shape as inputs. Stores the binary classification labels for each element in inputs + (0 for the negative class and 1 for the positive class). + num_masks (`int`): + The number of masks present in the current batch, used for normalization. + alpha (float, *optional*, defaults to 0.25): + Weighting factor in range (0,1) to balance positive vs negative examples. + gamma (float, *optional*, defaults to 2.0): + Exponent of the modulating factor \\(1 - p_t\\) to balance easy vs hard examples. + + Returns: + `torch.Tensor`: The computed loss. + """ + criterion = nn.BCEWithLogitsLoss(reduction="none") + probs = inputs.sigmoid() + cross_entropy_loss = criterion(inputs, labels) + p_t = probs * labels + (1 - probs) * (1 - labels) + loss = cross_entropy_loss * ((1 - p_t) ** gamma) + + if alpha >= 0: + alpha_t = alpha * labels + (1 - alpha) * (1 - labels) + loss = alpha_t * loss + + loss = loss.mean(1).sum() / num_masks + return loss + + +# refactored from original implementation +def pair_wise_dice_loss(inputs: Tensor, labels: Tensor) -> Tensor: + """ + A pair wise version of the dice loss, see `dice_loss` for usage. + + Args: + inputs (`torch.Tensor`): + A tensor representing a mask + labels (`torch.Tensor`): + A tensor with the same shape as inputs. Stores the binary classification labels for each element in inputs + (0 for the negative class and 1 for the positive class). + + Returns: + `torch.Tensor`: The computed loss between each pairs. + """ + inputs = inputs.sigmoid().flatten(1) + numerator = 2 * torch.einsum("nc,mc->nm", inputs, labels) + # using broadcasting to get a [NUM_QUERIES, NUM_CLASSES] matrix + denominator = inputs.sum(-1)[:, None] + labels.sum(-1)[None, :] + loss = 1 - (numerator + 1) / (denominator + 1) + return loss + + +# refactored from original implementation +def pair_wise_sigmoid_focal_loss(inputs: Tensor, labels: Tensor, alpha: float = 0.25, gamma: float = 2.0) -> Tensor: + r""" + A pair wise version of the focal loss, see `sigmoid_focal_loss` for usage. + + Args: + inputs (`torch.Tensor`): + A tensor representing a mask. + labels (`torch.Tensor`): + A tensor with the same shape as inputs. Stores the binary classification labels for each element in inputs + (0 for the negative class and 1 for the positive class). + alpha (float, *optional*, defaults to 0.25): + Weighting factor in range (0,1) to balance positive vs negative examples. + gamma (float, *optional*, defaults to 2.0): + Exponent of the modulating factor \\(1 - p_t\\) to balance easy vs hard examples. + + Returns: + `torch.Tensor`: The computed loss between each pairs. + """ + if alpha < 0: + raise ValueError("alpha must be positive") + + height_and_width = inputs.shape[1] + + criterion = nn.BCEWithLogitsLoss(reduction="none") + prob = inputs.sigmoid() + cross_entropy_loss_pos = criterion(inputs, torch.ones_like(inputs)) + focal_pos = ((1 - prob) ** gamma) * cross_entropy_loss_pos + focal_pos *= alpha + + cross_entropy_loss_neg = criterion(inputs, torch.zeros_like(inputs)) + + focal_neg = (prob**gamma) * cross_entropy_loss_neg + focal_neg *= 1 - alpha + + loss = torch.einsum("nc,mc->nm", focal_pos, labels) + torch.einsum("nc,mc->nm", focal_neg, (1 - labels)) + + return loss / height_and_width + + +# Copied from transformers.models.vit.modeling_vit.to_2tuple +def to_2tuple(x): + if isinstance(x, collections.abc.Iterable): + return x + return (x, x) + + +# Copied from transformers.models.swin.modeling_swin.window_partition +def window_partition(input_feature, window_size): + """ + Partitions the given input into windows. + """ + batch_size, height, width, num_channels = input_feature.shape + input_feature = input_feature.view( + batch_size, height // window_size, window_size, width // window_size, window_size, num_channels + ) + windows = input_feature.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, num_channels) + return windows + + +# Copied from transformers.models.swin.modeling_swin.window_reverse +def window_reverse(windows, window_size, height, width): + """ + Merges windows to produce higher resolution features. + """ + batch_size = int(windows.shape[0] / (height * width / window_size / window_size)) + windows = windows.view(batch_size, height // window_size, width // window_size, window_size, window_size, -1) + windows = windows.permute(0, 1, 3, 2, 4, 5).contiguous().view(batch_size, height, width, -1) + return windows + + +# Copied from transformers.models.swin.modeling_swin.drop_path +def drop_path(input, drop_prob=0.0, training=False, scale_by_keep=True): + """ + Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + """ + if drop_prob == 0.0 or not training: + return input + keep_prob = 1 - drop_prob + shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = input.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0 and scale_by_keep: + random_tensor.div_(keep_prob) + return input * random_tensor + + +class MaskFormerSwinEmbeddings(nn.Module): + """ + Construct the patch and position embeddings. + """ + + def __init__(self, config): + super().__init__() + + self.patch_embeddings = MaskFormerSwinPatchEmbeddings( + image_size=config.image_size, + patch_size=config.patch_size, + num_channels=config.num_channels, + embed_dim=config.embed_dim, + ) + num_patches = self.patch_embeddings.num_patches + self.patch_grid = self.patch_embeddings.grid_size + + if config.use_absolute_embeddings: + self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.embed_dim)) + else: + self.position_embeddings = None + + self.norm = nn.LayerNorm(config.embed_dim) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, pixel_values): + embeddings, output_dimensions = self.patch_embeddings(pixel_values) + embeddings = self.norm(embeddings) + + if self.position_embeddings is not None: + embeddings = embeddings + self.position_embeddings + + embeddings = self.dropout(embeddings) + + return embeddings, output_dimensions + + +class MaskFormerSwinPatchEmbeddings(nn.Module): + """ + Image to Patch Embedding, including padding. + """ + + def __init__(self, image_size=224, patch_size=16, num_channels=3, embed_dim=768): + super().__init__() + image_size = to_2tuple(image_size) + patch_size = to_2tuple(patch_size) + num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) + self.image_size = image_size + self.patch_size = patch_size + self.num_patches = num_patches + self.grid_size = (image_size[0] // patch_size[0], image_size[1] // patch_size[1]) + + self.projection = nn.Conv2d(num_channels, embed_dim, kernel_size=patch_size, stride=patch_size) + + def maybe_pad(self, pixel_values, height, width): + if width % self.patch_size[1] != 0: + pad_values = (0, self.patch_size[1] - width % self.patch_size[1]) + pixel_values = nn.functional.pad(pixel_values, pad_values) + if height % self.patch_size[0] != 0: + pad_values = (0, 0, 0, self.patch_size[0] - height % self.patch_size[0]) + pixel_values = nn.functional.pad(pixel_values, pad_values) + return pixel_values + + def forward(self, pixel_values): + _, _, height, width = pixel_values.shape + # pad the input to be divisible by self.patch_size, if needed + pixel_values = self.maybe_pad(pixel_values, height, width) + embeddings = self.projection(pixel_values) + _, _, height, width = embeddings.shape + output_dimensions = (height, width) + embeddings_flat = embeddings.flatten(2).transpose(1, 2) + + return embeddings_flat, output_dimensions + + +class MaskFormerSwinPatchMerging(nn.Module): + """ + Patch Merging Layer for maskformer model. + + Args: + input_resolution (`Tuple[int]`): + Resolution of input feature. + dim (`int`): + Number of input channels. + norm_layer (`nn.Module`, *optional*, defaults to `nn.LayerNorm`): + Normalization layer class. + """ + + def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def maybe_pad(self, input_feature, width, height): + should_pad = (height % 2 == 1) or (width % 2 == 1) + if should_pad: + pad_values = (0, 0, 0, width % 2, 0, height % 2) + input_feature = nn.functional.pad(input_feature, pad_values) + + return input_feature + + def forward(self, input_feature, input_dimensions): + height, width = input_dimensions + # `dim` is height * width + batch_size, dim, num_channels = input_feature.shape + + input_feature = input_feature.view(batch_size, height, width, num_channels) + # pad input to be disible by width and height, if needed + input_feature = self.maybe_pad(input_feature, height, width) + # [batch_size, height/2, width/2, num_channels] + input_feature_0 = input_feature[:, 0::2, 0::2, :] + # [batch_size, height/2, width/2, num_channels] + input_feature_1 = input_feature[:, 1::2, 0::2, :] + # [batch_size, height/2, width/2, num_channels] + input_feature_2 = input_feature[:, 0::2, 1::2, :] + # [batch_size, height/2, width/2, num_channels] + input_feature_3 = input_feature[:, 1::2, 1::2, :] + # batch_size height/2 width/2 4*num_channels + input_feature = torch.cat([input_feature_0, input_feature_1, input_feature_2, input_feature_3], -1) + input_feature = input_feature.view(batch_size, -1, 4 * num_channels) # batch_size height/2*width/2 4*C + + input_feature = self.norm(input_feature) + input_feature = self.reduction(input_feature) + + return input_feature + + +# Copied from transformers.models.swin.modeling_swin.SwinDropPath with Swin->MaskFormerSwin +class MaskFormerSwinDropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob=None, scale_by_keep=True): + super(MaskFormerSwinDropPath, self).__init__() + self.drop_prob = drop_prob + self.scale_by_keep = scale_by_keep + + def forward(self, input): + return drop_path(input, self.drop_prob, self.training, self.scale_by_keep) + + +# Copied from transformers.models.swin.modeling_swin.SwinSelfAttention with Swin->MaskFormerSwin +class MaskFormerSwinSelfAttention(nn.Module): + def __init__(self, config, dim, num_heads): + super().__init__() + if dim % num_heads != 0: + raise ValueError( + f"The hidden size ({dim}) is not a multiple of the number of attention " f"heads ({num_heads})" + ) + + self.num_attention_heads = num_heads + self.attention_head_size = int(dim / num_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + self.window_size = to_2tuple(config.window_size) + + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1), num_heads) + ) + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) + coords_flatten = torch.flatten(coords, 1) + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] + relative_coords = relative_coords.permute(1, 2, 0).contiguous() + relative_coords[:, :, 0] += self.window_size[0] - 1 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) + self.register_buffer("relative_position_index", relative_position_index) + + self.query = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias) + self.key = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias) + self.value = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + output_attentions=False, + ): + batch_size, dim, num_channels = hidden_states.shape + mixed_query_layer = self.query(hidden_states) + + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + query_layer = self.transpose_for_scores(mixed_query_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)] + relative_position_bias = relative_position_bias.view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1 + ) + + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() + attention_scores = attention_scores + relative_position_bias.unsqueeze(0) + + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in MaskFormerSwinModel forward() function) + mask_shape = attention_mask.shape[0] + attention_scores = attention_scores.view( + batch_size // mask_shape, mask_shape, self.num_attention_heads, dim, dim + ) + attention_scores = attention_scores + attention_mask.unsqueeze(1).unsqueeze(0) + attention_scores = attention_scores.view(-1, self.num_attention_heads, dim, dim) + + # Normalize the attention scores to probabilities. + attention_probs = nn.functional.softmax(attention_scores, dim=-1) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = attention_probs * head_mask + + context_layer = torch.matmul(attention_probs, value_layer) + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(*new_context_layer_shape) + + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + + return outputs + + +# Copied from transformers.models.swin.modeling_swin.SwinSelfOutput with Swin->MaskFormerSwin +class MaskFormerSwinSelfOutput(nn.Module): + def __init__(self, config, dim): + super().__init__() + self.dense = nn.Linear(dim, dim) + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + + return hidden_states + + +# Copied from transformers.models.swin.modeling_swin.SwinAttention with Swin->MaskFormerSwin +class MaskFormerSwinAttention(nn.Module): + def __init__(self, config, dim, num_heads): + super().__init__() + self.self = MaskFormerSwinSelfAttention(config, dim, num_heads) + self.output = MaskFormerSwinSelfOutput(config, dim) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads + ) + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len(heads) + self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward(self, hidden_states, attention_mask=None, head_mask=None, output_attentions=False): + self_outputs = self.self(hidden_states, attention_mask, head_mask, output_attentions) + attention_output = self.output(self_outputs[0], hidden_states) + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + +# Copied from transformers.models.swin.modeling_swin.SwinIntermediate with Swin->MaskFormerSwin +class MaskFormerSwinIntermediate(nn.Module): + def __init__(self, config, dim): + super().__init__() + self.dense = nn.Linear(dim, int(config.mlp_ratio * dim)) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +# Copied from transformers.models.swin.modeling_swin.SwinOutput with Swin->MaskFormerSwin +class MaskFormerSwinOutput(nn.Module): + def __init__(self, config, dim): + super().__init__() + self.dense = nn.Linear(int(config.mlp_ratio * dim), dim) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + return hidden_states + + +class MaskFormerSwinBlock(nn.Module): + def __init__(self, config, dim, input_resolution, num_heads, shift_size=0): + super().__init__() + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.shift_size = shift_size + self.window_size = config.window_size + self.input_resolution = input_resolution + self.layernorm_before = nn.LayerNorm(dim, eps=config.layer_norm_eps) + self.attention = MaskFormerSwinAttention(config, dim, num_heads) + self.drop_path = ( + MaskFormerSwinDropPath(config.drop_path_rate) if config.drop_path_rate > 0.0 else nn.Identity() + ) + self.layernorm_after = nn.LayerNorm(dim, eps=config.layer_norm_eps) + self.intermediate = MaskFormerSwinIntermediate(config, dim) + self.output = MaskFormerSwinOutput(config, dim) + + def get_attn_mask(self, input_resolution): + if self.shift_size > 0: + # calculate attention mask for SW-MSA + height, width = input_resolution + img_mask = torch.zeros((1, height, width, 1)) + height_slices = ( + slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None), + ) + width_slices = ( + slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None), + ) + count = 0 + for height_slice in height_slices: + for width_slice in width_slices: + img_mask[:, height_slice, width_slice, :] = count + count += 1 + + mask_windows = window_partition(img_mask, self.window_size) + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + else: + attn_mask = None + return attn_mask + + def maybe_pad(self, hidden_states, height, width): + pad_left = pad_top = 0 + pad_rigth = (self.window_size - width % self.window_size) % self.window_size + pad_bottom = (self.window_size - height % self.window_size) % self.window_size + pad_values = (0, 0, pad_left, pad_rigth, pad_top, pad_bottom) + hidden_states = nn.functional.pad(hidden_states, pad_values) + return hidden_states, pad_values + + def forward(self, hidden_states, input_dimensions, head_mask=None, output_attentions=False): + height, width = input_dimensions + batch_size, dim, channels = hidden_states.size() + shortcut = hidden_states + + hidden_states = self.layernorm_before(hidden_states) + hidden_states = hidden_states.view(batch_size, height, width, channels) + # pad hidden_states to multiples of window size + hidden_states, pad_values = self.maybe_pad(hidden_states, height, width) + + _, height_pad, width_pad, _ = hidden_states.shape + # cyclic shift + if self.shift_size > 0: + shifted_hidden_states = torch.roll(hidden_states, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_hidden_states = hidden_states + + # partition windows + hidden_states_windows = window_partition(shifted_hidden_states, self.window_size) + hidden_states_windows = hidden_states_windows.view(-1, self.window_size * self.window_size, channels) + attn_mask = self.get_attn_mask((height_pad, width_pad)) + if attn_mask is not None: + attn_mask = attn_mask.to(hidden_states_windows.device) + + self_attention_outputs = self.attention( + hidden_states_windows, attn_mask, head_mask, output_attentions=output_attentions + ) + + attention_output = self_attention_outputs[0] + + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + + attention_windows = attention_output.view(-1, self.window_size, self.window_size, channels) + shifted_windows = window_reverse(attention_windows, self.window_size, height_pad, width_pad) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + attention_windows = torch.roll(shifted_windows, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + attention_windows = shifted_windows + + was_padded = pad_values[3] > 0 or pad_values[5] > 0 + if was_padded: + attention_windows = attention_windows[:, :height, :width, :].contiguous() + + attention_windows = attention_windows.view(batch_size, height * width, channels) + + hidden_states = shortcut + self.drop_path(attention_windows) + + layer_output = self.layernorm_after(hidden_states) + layer_output = self.intermediate(layer_output) + layer_output = hidden_states + self.output(layer_output) + + outputs = (layer_output,) + outputs + + return outputs + + +class MaskFormerSwinLayer(nn.Module): + def __init__(self, config, dim, input_resolution, depth, num_heads, drop_path, downsample): + super().__init__() + self.config = config + self.dim = dim + self.blocks = nn.ModuleList( + [ + MaskFormerSwinBlock( + config=config, + dim=dim, + input_resolution=input_resolution, + num_heads=num_heads, + shift_size=0 if (i % 2 == 0) else config.window_size // 2, + ) + for i in range(depth) + ] + ) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(input_resolution, dim=dim, norm_layer=nn.LayerNorm) + else: + self.downsample = None + + self.pointing = False + + def forward( + self, hidden_states, input_dimensions, head_mask=None, output_attentions=False, output_hidden_states=False + ): + all_hidden_states = () if output_hidden_states else None + + height, width = input_dimensions + for i, block_module in enumerate(self.blocks): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + + block_hidden_states = block_module(hidden_states, input_dimensions, layer_head_mask, output_attentions) + + hidden_states = block_hidden_states[0] + + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.downsample is not None: + height_downsampled, width_downsampled = (height + 1) // 2, (width + 1) // 2 + output_dimensions = (height, width, height_downsampled, width_downsampled) + hidden_states = self.downsample(hidden_states, input_dimensions) + else: + output_dimensions = (height, width, height, width) + + return hidden_states, output_dimensions, all_hidden_states + + +class MaskFormerSwinEncoder(nn.Module): + def __init__(self, config, grid_size): + super().__init__() + self.num_layers = len(config.depths) + self.config = config + dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths))] + self.layers = nn.ModuleList( + [ + MaskFormerSwinLayer( + config=config, + dim=int(config.embed_dim * 2**i_layer), + input_resolution=(grid_size[0] // (2**i_layer), grid_size[1] // (2**i_layer)), + depth=config.depths[i_layer], + num_heads=config.num_heads[i_layer], + drop_path=dpr[sum(config.depths[:i_layer]) : sum(config.depths[: i_layer + 1])], + downsample=MaskFormerSwinPatchMerging if (i_layer < self.num_layers - 1) else None, + ) + for i_layer in range(self.num_layers) + ] + ) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states, + input_dimensions, + head_mask=None, + output_attentions=False, + output_hidden_states=False, + return_dict=True, + ): + all_hidden_states = () if output_hidden_states else None + all_input_dimensions = () + all_self_attentions = () if output_attentions else None + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + for i, layer_module in enumerate(self.layers): + layer_head_mask = head_mask[i] if head_mask is not None else None + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs, output_attentions) + + return custom_forward + + layer_hidden_states, output_dimensions, layer_all_hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(layer_module), hidden_states, layer_head_mask + ) + else: + layer_hidden_states, output_dimensions, layer_all_hidden_states = layer_module( + hidden_states, + input_dimensions, + layer_head_mask, + output_attentions, + output_hidden_states, + ) + + input_dimensions = (output_dimensions[-2], output_dimensions[-1]) + all_input_dimensions += (input_dimensions,) + if output_hidden_states: + all_hidden_states += (layer_all_hidden_states,) + + hidden_states = layer_hidden_states + + if output_attentions: + all_self_attentions = all_self_attentions + (layer_all_hidden_states[1],) + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) + + return MaskFormerSwinBaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + hidden_states_spatial_dimensions=all_input_dimensions, + attentions=all_self_attentions, + ) + + +class MaskFormerSwinModel(nn.Module, ModuleUtilsMixin): + def __init__(self, config, add_pooling_layer=True): + super().__init__() + self.config = config + self.num_layers = len(config.depths) + self.num_features = int(config.embed_dim * 2 ** (self.num_layers - 1)) + + self.embeddings = MaskFormerSwinEmbeddings(config) + self.encoder = MaskFormerSwinEncoder(config, self.embeddings.patch_grid) + + self.layernorm = nn.LayerNorm(self.num_features, eps=config.layer_norm_eps) + self.pooler = nn.AdaptiveAvgPool1d(1) if add_pooling_layer else None + + def get_input_embeddings(self): + return self.embeddings.patch_embeddings + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + def forward( + self, + pixel_values=None, + head_mask=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, len(self.config.depths)) + + embedding_output, input_dimensions = self.embeddings(pixel_values) + + encoder_outputs = self.encoder( + embedding_output, + input_dimensions, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = encoder_outputs.last_hidden_state + sequence_output = self.layernorm(sequence_output) + + pooled_output = None + if self.pooler is not None: + pooled_output = self.pooler(sequence_output.transpose(1, 2)) + pooled_output = torch.flatten(pooled_output, 1) + + if not return_dict: + return (sequence_output, pooled_output) + encoder_outputs[1:] + + hidden_states_spatial_dimensions = (input_dimensions,) + encoder_outputs.hidden_states_spatial_dimensions + + return MaskFormerSwinModelOutputWithPooling( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + hidden_states_spatial_dimensions=hidden_states_spatial_dimensions, + attentions=encoder_outputs.attentions, + ) + + +# Copied from transformers.models.detr.modeling_detr.DetrAttention +class DetrAttention(nn.Module): + """ + Multi-headed attention from 'Attention Is All You Need' paper. + + Here, we add position embeddings to the queries and keys (as explained in the DETR paper). + """ + + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + is_decoder: bool = False, + bias: bool = True, + ): + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + assert ( + self.head_dim * num_heads == self.embed_dim + ), f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {num_heads})." + self.scaling = self.head_dim**-0.5 + + self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def with_pos_embed(self, tensor: torch.Tensor, position_embeddings: Optional[Tensor]): + return tensor if position_embeddings is None else tensor + position_embeddings + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_embeddings: Optional[torch.Tensor] = None, + key_value_states: Optional[torch.Tensor] = None, + key_value_position_embeddings: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + bsz, tgt_len, embed_dim = hidden_states.size() + + # add position embeddings to the hidden states before projecting to queries and keys + if position_embeddings is not None: + hidden_states_original = hidden_states + hidden_states = self.with_pos_embed(hidden_states, position_embeddings) + + # add key-value position embeddings to the key value states + if key_value_position_embeddings is not None: + key_value_states_original = key_value_states + key_value_states = self.with_pos_embed(key_value_states, key_value_position_embeddings) + + # get query proj + query_states = self.q_proj(hidden_states) * self.scaling + # get key, value proj + if is_cross_attention: + # cross_attentions + key_states = self._shape(self.k_proj(key_value_states), -1, bsz) + value_states = self._shape(self.v_proj(key_value_states_original), -1, bsz) + else: + # self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states_original), -1, bsz) + + proj_shape = (bsz * self.num_heads, -1, self.head_dim) + query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) + key_states = key_states.view(*proj_shape) + value_states = value_states.view(*proj_shape) + + src_len = key_states.size(1) + + attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) + + if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): + raise ValueError( + f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, tgt_len, src_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if output_attentions: + # this operation is a bit awkward, but it's required to + # make sure that attn_weights keeps its gradient. + # In order to do so, attn_weights have to reshaped + # twice and have to be reused in the following + attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) + else: + attn_weights_reshaped = None + + attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + + attn_output = torch.bmm(attn_probs, value_states) + + if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}" + ) + + attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.reshape(bsz, tgt_len, embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights_reshaped + + +# Copied from transformers.models.detr.modeling_detr.DetrDecoderLayer +class DetrDecoderLayer(nn.Module): + def __init__(self, config: DetrConfig): + super().__init__() + self.embed_dim = config.d_model + + self.self_attn = DetrAttention( + embed_dim=self.embed_dim, + num_heads=config.decoder_attention_heads, + dropout=config.attention_dropout, + is_decoder=True, + ) + self.dropout = config.dropout + self.activation_fn = ACT2FN[config.activation_function] + self.activation_dropout = config.activation_dropout + + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.encoder_attn = DetrAttention( + self.embed_dim, + config.decoder_attention_heads, + dropout=config.attention_dropout, + is_decoder=True, + ) + self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim) + self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim) + self.final_layer_norm = nn.LayerNorm(self.embed_dim) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_embeddings: Optional[torch.Tensor] = None, + query_position_embeddings: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = False, + ): + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)` + attention_mask (`torch.FloatTensor`): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + position_embeddings (`torch.FloatTensor`, *optional*): + position embeddings that are added to the queries and keys + in the cross-attention layer. + query_position_embeddings (`torch.FloatTensor`, *optional*): + position embeddings that are added to the queries and keys + in the self-attention layer. + encoder_hidden_states (`torch.FloatTensor`): + cross attention input to the layer of shape `(seq_len, batch, embed_dim)` + encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + """ + residual = hidden_states + + # Self Attention + hidden_states, self_attn_weights = self.self_attn( + hidden_states=hidden_states, + position_embeddings=query_position_embeddings, + attention_mask=attention_mask, + output_attentions=output_attentions, + ) + + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + # Cross-Attention Block + cross_attn_weights = None + if encoder_hidden_states is not None: + residual = hidden_states + + hidden_states, cross_attn_weights = self.encoder_attn( + hidden_states=hidden_states, + position_embeddings=query_position_embeddings, + key_value_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + key_value_position_embeddings=position_embeddings, + output_attentions=output_attentions, + ) + + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.encoder_attn_layer_norm(hidden_states) + + # Fully Connected + residual = hidden_states + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.final_layer_norm(hidden_states) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights, cross_attn_weights) + + return outputs + + +# Copied from transformers.models.detr.modeling_detr._expand_mask +def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): + """ + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. + """ + bsz, src_len = mask.size() + tgt_len = tgt_len if tgt_len is not None else src_len + + expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) + + inverted_mask = 1.0 - expanded_mask + + return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min) + + +class DetrDecoder(nn.Module): + """ + Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`DetrDecoderLayer`]. + + The decoder updates the query embeddings through multiple self-attention and cross-attention layers. + + Some small tweaks for DETR: + + - position_embeddings and query_position_embeddings are added to the forward pass. + - if self.config.auxiliary_loss is set to True, also returns a stack of activations from all decoding layers. + + Args: + config: DetrConfig + """ + + def __init__(self, config: DetrConfig): + super().__init__() + self.config = config + self.dropout = config.dropout + self.layerdrop = config.decoder_layerdrop + + self.layers = nn.ModuleList([DetrDecoderLayer(config) for _ in range(config.decoder_layers)]) + # in DETR, the decoder uses layernorm after the last decoder layer output + self.layernorm = nn.LayerNorm(config.d_model) + + self.gradient_checkpointing = False + + def forward( + self, + inputs_embeds=None, + attention_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + position_embeddings=None, + query_position_embeddings=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + Args: + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + The query embeddings that are passed into the decoder. + + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on certain queries. Mask values selected in `[0, 1]`: + + - 1 for queries that are **not masked**, + - 0 for queries that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention + of the decoder. + encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*): + Mask to avoid performing cross-attention on padding pixel_values of the encoder. Mask values selected + in `[0, 1]`: + + - 1 for pixels that are real (i.e. **not masked**), + - 0 for pixels that are padding (i.e. **masked**). + + position_embeddings (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Position embeddings that are added to the queries and keys in each cross-attention layer. + query_position_embeddings (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`): + , *optional*): Position embeddings that are added to the queries and keys in each self-attention layer. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if inputs_embeds is not None: + hidden_states = inputs_embeds + input_shape = inputs_embeds.size()[:-1] + + combined_attention_mask = None + + if attention_mask is not None and combined_attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + combined_attention_mask = combined_attention_mask + _expand_mask( + attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ) + + # expand encoder attention mask + if encoder_hidden_states is not None and encoder_attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]) + + # optional intermediate hidden states + intermediate = () if self.config.auxiliary_loss else None + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None + + for idx, decoder_layer in enumerate(self.layers): + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + if output_hidden_states: + all_hidden_states += (hidden_states,) + dropout_probability = random.uniform(0, 1) + if self.training and (dropout_probability < self.layerdrop): + continue + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs, output_attentions) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(decoder_layer), + hidden_states, + combined_attention_mask, + encoder_hidden_states, + encoder_attention_mask, + None, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=combined_attention_mask, + position_embeddings=position_embeddings, + query_position_embeddings=query_position_embeddings, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + if self.config.auxiliary_loss: + hidden_states = self.layernorm(hidden_states) + intermediate += (hidden_states,) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if encoder_hidden_states is not None: + all_cross_attentions += (layer_outputs[2],) + + # finally, apply layernorm + hidden_states = self.layernorm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + # stack intermediate decoder activations + if self.config.auxiliary_loss: + intermediate = torch.stack(intermediate) + + if not return_dict: + return tuple( + v + for v in [hidden_states, all_hidden_states, all_self_attns, all_cross_attentions, intermediate] + if v is not None + ) + return DetrDecoderOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attns, + cross_attentions=all_cross_attentions, + intermediate_hidden_states=intermediate, + ) + + +# refactored from original implementation +class MaskFormerHungarianMatcher(nn.Module): + """This class computes an assignment between the labels and the predictions of the network. + + For efficiency reasons, the labels don't include the no_object. Because of this, in general, there are more + predictions than labels. In this case, we do a 1-to-1 matching of the best predictions, while the others are + un-matched (and thus treated as non-objects). + """ + + def __init__(self, cost_class: float = 1.0, cost_mask: float = 1.0, cost_dice: float = 1.0): + """Creates the matcher + + Params: + cost_class (float, *optional*, defaults to 1.0): + This is the relative weight of the classification error in the matching cost. + cost_mask (float, *optional*, defaults to 1.0): + This is the relative weight of the focal loss of the binary mask in the matching cost. + cost_dice (float, *optional*, defaults to 1.0): + This is the relative weight of the dice loss of the binary mask in the matching cost + """ + super().__init__() + if cost_class == 0 and cost_mask == 0 and cost_dice == 0: + raise ValueError("All costs cant be 0") + self.cost_class = cost_class + self.cost_mask = cost_mask + self.cost_dice = cost_dice + + @torch.no_grad() + def forward(self, masks_queries_logits, class_queries_logits, mask_labels, class_labels) -> List[Tuple[Tensor]]: + """Performs the matching + + Params: + masks_queries_logits (`torch.Tensor`): + A tensor` of dim `batch_size, num_queries, num_classes` with the + classification logits. + class_queries_logits (`torch.Tensor`): + A tensor` of dim `batch_size, num_queries, height, width` with the + predicted masks. + + class_labels (`torch.Tensor`): + A tensor` of dim `num_target_boxes` (where num_target_boxes is the number + of ground-truth objects in the target) containing the class labels. + mask_labels (`torch.Tensor`): + A tensor` of dim `num_target_boxes, height, width` containing the target + masks. + + Returns: + `List[Tuple[Tensor]]`: A list of size batch_size, containing tuples of (index_i, index_j) where: + - index_i is the indices of the selected predictions (in order) + - index_j is the indices of the corresponding selected labels (in order) + For each batch element, it holds: + len(index_i) = len(index_j) = min(num_queries, num_target_boxes). + """ + indices: List[Tuple[np.array]] = [] + + preds_masks = masks_queries_logits + preds_probs = class_queries_logits.softmax(dim=-1) + # downsample all masks in one go -> save memory + mask_labels = nn.functional.interpolate(mask_labels, size=preds_masks.shape[-2:], mode="nearest") + # iterate through batch size + for pred_probs, pred_mask, target_mask, labels in zip(preds_probs, preds_masks, mask_labels, class_labels): + # Compute the classification cost. Contrary to the loss, we don't use the NLL, + # but approximate it in 1 - proba[target class]. + # The 1 is a constant that doesn't change the matching, it can be ommitted. + cost_class = -pred_probs[:, labels] + # flatten spatial dimension "q h w -> q (h w)" + num_queries, height, width = pred_mask.shape + pred_mask_flat = pred_mask.view(num_queries, height * width) # [num_queries, H*W] + # same for target_mask "c h w -> c (h w)" + num_channels, height, width = target_mask.shape + target_mask_flat = target_mask.view(num_channels, height * width) # [num_total_labels, H*W] + # compute the focal loss between each mask pairs -> shape [NUM_QUERIES, CLASSES] + cost_mask = pair_wise_sigmoid_focal_loss(pred_mask_flat, target_mask_flat) + # Compute the dice loss betwen each mask pairs -> shape [NUM_QUERIES, CLASSES] + cost_dice = pair_wise_dice_loss(pred_mask_flat, target_mask_flat) + # final cost matrix + cost_matrix = self.cost_mask * cost_mask + self.cost_class * cost_class + self.cost_dice * cost_dice + # do the assigmented using the hungarian algorithm in scipy + assigned_indices: Tuple[np.array] = linear_sum_assignment(cost_matrix.cpu()) + indices.append(assigned_indices) + + # It could be stacked in one tensor + matched_indices = [ + (torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices + ] + return matched_indices + + def __repr__(self): + head = "Matcher " + self.__class__.__name__ + body = [ + f"cost_class: {self.cost_class}", + f"cost_mask: {self.cost_mask}", + f"cost_dice: {self.cost_dice}", + ] + _repr_indent = 4 + lines = [head] + [" " * _repr_indent + line for line in body] + return "\n".join(lines) + + +# copied and adapted from original implementation +class MaskFormerLoss(nn.Module): + def __init__( + self, + num_classes: int, + matcher: MaskFormerHungarianMatcher, + weight_dict: Dict[str, float], + eos_coef: float, + ): + """ + The MaskFormer Loss. The loss is computed very similar to DETR. The process happens in two steps: 1) we compute + hungarian assignment between ground truth masks and the outputs of the model 2) we supervise each pair of + matched ground-truth / prediction (supervise class and mask) + + Args: + num_classes (`int`): + The number of classes. + matcher (`MaskFormerHungarianMatcher`): + A torch module that computes the assigments between the predictions and labels. + weight_dict (`Dict[str, float]`): + A dictionary of weights to be applied to the different losses. + eos_coef (`float`): + Weight to apply to the null class. + """ + + super().__init__() + requires_backends(self, ["scipy"]) + self.num_classes = num_classes + self.matcher = matcher + self.weight_dict = weight_dict + self.eos_coef = eos_coef + empty_weight = torch.ones(self.num_classes + 1) + empty_weight[-1] = self.eos_coef + self.register_buffer("empty_weight", empty_weight) + + def loss_labels( + self, class_queries_logits: Tensor, class_labels: Tensor, indices: Tuple[np.array] + ) -> Dict[str, Tensor]: + """Compute the losses related to the labels using cross entropy. + + Args: + class_queries_logits (`torch.Tensor`): + A tensor of shape `batch_size, num_queries, num_classes` + class_labels (`Dict[str, Tensor]`): + A tensor of shape `batch_size, num_classes` + indices (`Tuple[np.array])`: + The indices computed by the Hungarian matcher. + + Returns: + `Dict[str, Tensor]`: A dict of `torch.Tensor` containing the following key: + - **loss_cross_entropy** -- The loss computed using cross entropy on the predicted and ground truth labels. + """ + + pred_logits = class_queries_logits + batch_size, num_queries, _ = pred_logits.shape + criterion = nn.CrossEntropyLoss(weight=self.empty_weight) + idx = self._get_predictions_permutation_indices(indices) + # shape = [BATCH, N_QUERIES] + target_classes_o = torch.cat([target[j] for target, (_, j) in zip(class_labels, indices)]) + # shape = [BATCH, N_QUERIES] + target_classes = torch.full( + (batch_size, num_queries), fill_value=self.num_classes, dtype=torch.int64, device=pred_logits.device + ) + target_classes[idx] = target_classes_o + # target_classes is a [BATCH, CLASSES, N_QUERIES], we need to permute pred_logits "b q c -> b c q" + pred_logits_permuted = pred_logits.permute(0, 2, 1) + loss_ce = criterion(pred_logits_permuted, target_classes) + losses = {"loss_cross_entropy": loss_ce} + return losses + + def loss_masks( + self, masks_queries_logits: Tensor, mask_labels: Tensor, indices: Tuple[np.array], num_masks: int + ) -> Dict[str, Tensor]: + """Compute the losses related to the masks using focal and dice loss. + + Args: + masks_queries_logits (`torch.Tensor`): + A tensor of shape `batch_size, num_queries, height, width` + mask_labels (`torch.Tensor`): + A tensor of shape `batch_size, num_queries, height, width` + indices (`Tuple[np.array])`: + The indices computed by the Hungarian matcher. + num_masks (`int)`: + The number of masks, used for normalization. + + Returns: + `Dict[str, Tensor]`: A dict of `torch.Tensor` containing two keys: + - **loss_mask** -- The loss computed using sigmoid focal loss on the predicted and ground truth masks. + - **loss_dice** -- The loss computed using dice loss on the predicted on the predicted and ground truth + masks. + """ + src_idx = self._get_predictions_permutation_indices(indices) + tgt_idx = self._get_targets_permutation_indices(indices) + pred_masks = masks_queries_logits # shape [BATCH, NUM_QUERIES, H, W] + pred_masks = pred_masks[src_idx] # shape [BATCH * NUM_QUERIES, H, W] + target_masks = mask_labels # shape [BATCH, NUM_QUERIES, H, W] + target_masks = target_masks[tgt_idx] # shape [BATCH * NUM_QUERIES, H, W] + # upsample predictions to the target size, we have to add one dim to use interpolate + pred_masks = nn.functional.interpolate( + pred_masks[:, None], size=target_masks.shape[-2:], mode="bilinear", align_corners=False + ) + pred_masks = pred_masks[:, 0].flatten(1) + + target_masks = target_masks.flatten(1) + target_masks = target_masks.view(pred_masks.shape) + losses = { + "loss_mask": sigmoid_focal_loss(pred_masks, target_masks, num_masks), + "loss_dice": dice_loss(pred_masks, target_masks, num_masks), + } + return losses + + def _get_predictions_permutation_indices(self, indices): + # permute predictions following indices + batch_indices = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)]) + predictions_indices = torch.cat([src for (src, _) in indices]) + return batch_indices, predictions_indices + + def _get_targets_permutation_indices(self, indices): + # permute labels following indices + batch_indices = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)]) + target_indices = torch.cat([tgt for (_, tgt) in indices]) + return batch_indices, target_indices + + def get_loss(self, loss, outputs, labels, indices, num_masks): + loss_map = {"labels": self.loss_labels, "masks": self.loss_masks} + if loss not in loss_map: + raise KeyError(f"{loss} not in loss_map") + return loss_map[loss](outputs, labels, indices, num_masks) + + def forward( + self, + masks_queries_logits: torch.Tensor, + class_queries_logits: torch.Tensor, + mask_labels: torch.Tensor, + class_labels: torch.Tensor, + auxiliary_predictions: Optional[Dict[str, torch.Tensor]] = None, + ) -> Dict[str, Tensor]: + """ + This performs the loss computation. + + Args: + masks_queries_logits (`torch.Tensor`): + A tensor of shape `batch_size, num_queries, height, width` + class_queries_logits (`torch.Tensor`): + A tensor of shape `batch_size, num_queries, num_classes` + mask_labels (`torch.Tensor`): + A tensor of shape `batch_size, num_classes, height, width` + class_labels (`torch.Tensor`): + A tensor of shape `batch_size, num_classes` + auxiliary_predictions (`Dict[str, torch.Tensor]`, *optional*): + if `use_auxiliary_loss` was set to `true` in [`MaskFormerConfig`], then it contains the logits from the + inner layers of the Detr's Decoder. + + Returns: + `Dict[str, Tensor]`: A dict of `torch.Tensor` containing two keys: + - **loss_cross_entropy** -- The loss computed using cross entropy on the predicted and ground truth labels. + - **loss_mask** -- The loss computed using sigmoid focal loss on the predicted and ground truth masks. + - **loss_dice** -- The loss computed using dice loss on the predicted on the predicted and ground truth + masks. + if `use_auxiliary_loss` was set to `true` in [`MaskFormerConfig`], the dictionary contains addional losses + for each auxiliary predictions. + """ + + # Retrieve the matching between the outputs of the last layer and the labels + indices = self.matcher(masks_queries_logits, class_queries_logits, mask_labels, class_labels) + + # Compute the average number of target masks accross all nodes, for normalization purposes + num_masks: Number = self.get_num_masks(class_labels, device=class_labels.device) + + # Compute all the requested losses + losses: Dict[str, Tensor] = { + **self.loss_masks(masks_queries_logits, mask_labels, indices, num_masks), + **self.loss_labels(class_queries_logits, class_labels, indices), + } + + # In case of auxiliary losses, we repeat this process with the output of each intermediate layer. + if auxiliary_predictions is not None: + for idx, aux_outputs in enumerate(auxiliary_predictions): + masks_queries_logits = aux_outputs["masks_queries_logits"] + class_queries_logits = aux_outputs["class_queries_logits"] + loss_dict = self.forward(masks_queries_logits, class_queries_logits, mask_labels, class_labels) + loss_dict = {f"{key}_{idx}": value for key, value in loss_dict.items()} + losses.update(loss_dict) + + return losses + + def get_num_masks(self, class_labels: torch.Tensor, device: torch.device) -> torch.Tensor: + # Compute the average number of target masks accross all nodes, for normalization purposes + num_masks = class_labels.shape[0] + num_masks_pt = torch.as_tensor([num_masks], dtype=torch.float, device=device) + return num_masks_pt + + +class MaskFormerSwinTransformerBackbone(nn.Module): + """ + This class uses [`MaskFormerSwinModel`] to reshape its `hidden_states` from (`batch_size, sequence_length, + hidden_size)` to (`batch_size, num_channels, height, width)`). + + Args: + config (`SwinConfig`): + The configuration used by [`MaskFormerSwinModel`]. + """ + + def __init__(self, config: SwinConfig): + super().__init__() + self.model = MaskFormerSwinModel(config) + self.hidden_states_norms = nn.ModuleList([nn.LayerNorm(out_shape) for out_shape in self.outputs_shapes]) + + def forward(self, *args, **kwargs) -> List[Tensor]: + output = self.model(*args, **kwargs, output_hidden_states=True) + hidden_states_permuted: List[Tensor] = [] + # we need to reshape the hidden state to their original spatial dimensions + # skipping the embeddings + hidden_states: Tuple[Tuple[Tensor]] = output.hidden_states[1:] + # spatial dimensions contains all the heights and widths of each stage, including after the embeddings + spatial_dimensions: Tuple[Tuple[int, int]] = output.hidden_states_spatial_dimensions + for i, (hidden_state, (height, width)) in enumerate(zip(hidden_states, spatial_dimensions)): + norm = self.hidden_states_norms[i] + # the last element corespond to the layer's last block output but before patch merging + hidden_state_unpolled = hidden_state[-1] + hidden_state_norm = norm(hidden_state_unpolled) + # our pixel decoder (FPN) expect 3D tensors (features) + batch_size, _, hidden_size = hidden_state_norm.shape + # reshape our tensor "b (h w) d -> b d h w" + hidden_state_permuted = ( + hidden_state_norm.permute(0, 2, 1).view((batch_size, hidden_size, height, width)).contiguous() + ) + hidden_states_permuted.append(hidden_state_permuted) + return hidden_states_permuted + + @property + def input_resolutions(self) -> List[int]: + return [layer.input_resolution for layer in self.model.encoder.layers] + + @property + def outputs_shapes(self) -> List[int]: + return [layer.dim for layer in self.model.encoder.layers] + + +class MaskFormerFPNConvLayer(nn.Sequential): + def __init__(self, in_features: int, out_features: int, kernel_size: int = 3, padding: int = 1): + """ + A basic module that executes conv - norm - in sequence used in MaskFormer. + + Args: + in_features (`int`): + The number of input features (channels). + out_features (`int`): + The number of outputs features (channels). + """ + super().__init__( + nn.Conv2d(in_features, out_features, kernel_size=kernel_size, padding=padding, bias=False), + nn.GroupNorm(32, out_features), + nn.ReLU(inplace=True), + ) + + +class MaskFormerFPNLayer(nn.Module): + def __init__(self, in_features: int, lateral_features: int): + """ + A Feature Pyramid Network Layer (FPN) layer. It creates a feature map by aggregating features from the previous + and backbone layer. Due to the spatial mismatch, the tensor coming from the previous layer is upsampled. + + Args: + in_features (`int`): + The number of input features (channels). + lateral_features (`int`): + The number of lateral features (channels). + """ + super().__init__() + self.proj = nn.Sequential( + nn.Conv2d(lateral_features, in_features, kernel_size=1, padding=0, bias=False), + nn.GroupNorm(32, in_features), + ) + + self.block = MaskFormerFPNConvLayer(in_features, in_features) + + def forward(self, down: Tensor, left: Tensor) -> Tensor: + left = self.proj(left) + down = nn.functional.interpolate(down, size=left.shape[-2:], mode="nearest") + down += left + down = self.block(down) + return down + + +class MaskFormerFPNModel(nn.Module): + def __init__(self, in_features: int, lateral_widths: List[int], feature_size: int = 256): + """ + Feature Pyramid Network, given an input tensor and a set of feature map of different feature/spatial size, it + creates a list of feature maps with the same feature size. + + Args: + in_features (`int`): + The number of input features (channels). + lateral_widths (`List[int]`): + A list with the features (channels) size of each lateral connection. + feature_size (int, *optional*, defaults to 256): + The features (channels) of the resulting feature maps. + """ + super().__init__() + self.stem = MaskFormerFPNConvLayer(in_features, feature_size) + self.layers = nn.Sequential( + *[MaskFormerFPNLayer(feature_size, lateral_width) for lateral_width in lateral_widths[::-1]] + ) + + def forward(self, features: List[Tensor]) -> List[Tensor]: + fpn_features = [] + last_feature = features[-1] + other_features = features[:-1] + output = self.stem(last_feature) + for layer, left in zip(self.layers, other_features[::-1]): + output = layer(output, left) + fpn_features.append(output) + return fpn_features + + +class MaskFormerPixelDecoder(nn.Module): + def __init__(self, *args, feature_size: int = 256, mask_feature_size: int = 256, **kwargs): + """ + Pixel Decoder Module proposed in [Per-Pixel Classification is Not All You Need for Semantic + Segmentation](https://arxiv.org/abs/2107.06278). It first runs the backbone's feature into a Feature Pyramid + Network creating a list of feature maps. Then, it projects the last one to the correct `mask_size`. + + Args: + feature_size (`int`, *optional*, defaults to 256): + The feature size (channel dimension) of the FPN feature maps. + mask_feature_size (`int`, *optional*, defaults to 256): + The features (channels) of the target masks size \\C_{\epsilon}\\ in the paper. + """ + super().__init__() + self.fpn = MaskFormerFPNModel(*args, feature_size=feature_size, **kwargs) + self.mask_projection = nn.Conv2d(feature_size, mask_feature_size, kernel_size=3, padding=1) + + def forward(self, features: List[Tensor], output_hidden_states: bool = False) -> MaskFormerPixelDecoderOutput: + fpn_features: List[Tensor] = self.fpn(features) + # we use the last feature map + last_feature_projected = self.mask_projection(fpn_features[-1]) + return MaskFormerPixelDecoderOutput( + last_hidden_state=last_feature_projected, hidden_states=tuple(fpn_features) if output_hidden_states else () + ) + + +# copied and adapted from original implementation, also practically equal to DetrSinePositionEmbedding +class MaskFormerSinePositionEmbedding(nn.Module): + """ + This is a more standard version of the position embedding, very similar to the one used by the Attention is all you + need paper, generalized to work on images. + """ + + def __init__( + self, num_pos_feats: int = 64, temperature: int = 10000, normalize: bool = False, scale: Optional[float] = None + ): + super().__init__() + if scale is not None and normalize is False: + raise ValueError("normalize should be True if scale is passed") + self.num_pos_feats = num_pos_feats + self.temperature = temperature + self.normalize = normalize + self.scale = 2 * torch.pi if scale is None else scale + + def forward(self, x: Tensor, mask: Optional[Tensor] = None) -> Tensor: + if mask is None: + mask = torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool) + not_mask = ~mask + y_embed = not_mask.cumsum(1, dtype=torch.float32) + x_embed = not_mask.cumsum(2, dtype=torch.float32) + if self.normalize: + eps = 1e-6 + y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale + x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale + + dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) + dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / self.num_pos_feats) + + pos_x = x_embed[:, :, :, None] / dim_t + pos_y = y_embed[:, :, :, None] / dim_t + pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) + return pos + + +class MaskformerMLPPredictionHead(nn.Sequential): + def __init__(self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int = 3): + """ + A classic Multi Layer Perceptron (MLP). + + Args: + input_dim (`int`): + The input dimensions. + hidden_dim (`int`): + The hidden dimensions. + output_dim (`int`): + The output dimensions. + num_layers (int, *optional*, defaults to 3): + The number of layers. + """ + in_dims = [input_dim] + [hidden_dim] * (num_layers - 1) + out_dims = [hidden_dim] * (num_layers - 1) + [output_dim] + + layers = [] + for i, (in_dim, out_dim) in enumerate(zip(in_dims, out_dims)): + + layer = nn.Sequential( + nn.Linear(in_dim, out_dim), nn.ReLU(inplace=True) if i < num_layers - 1 else nn.Identity() + ) + layers.append(layer) + + super().__init__(*layers) + + +class MaskFormerPixelLevelModule(nn.Module): + def __init__(self, config: MaskFormerConfig): + """ + Pixel Level Module proposed in [Per-Pixel Classification is Not All You Need for Semantic + Segmentation](https://arxiv.org/abs/2107.06278). It runs the input image through a backbone and a pixel + decoder, generating an image feature map and pixel embeddings. + + Args: + config ([`MaskFormerConfig`]): + The configuration used to instantiate this model. + """ + super().__init__() + self.encoder = MaskFormerSwinTransformerBackbone(config.backbone_config) + self.decoder = MaskFormerPixelDecoder( + in_features=self.encoder.outputs_shapes[-1], + feature_size=config.fpn_feature_size, + mask_feature_size=config.mask_feature_size, + lateral_widths=self.encoder.outputs_shapes[:-1], + ) + + def forward(self, pixel_values: Tensor, output_hidden_states: bool = False) -> MaskFormerPixelLevelModuleOutput: + features: List[Tensor] = self.encoder(pixel_values) + decoder_output: MaskFormerPixelDecoderOutput = self.decoder(features, output_hidden_states) + return MaskFormerPixelLevelModuleOutput( + # the last feature is actually the output from the last layer + encoder_last_hidden_state=features[-1], + decoder_last_hidden_state=decoder_output.last_hidden_state, + encoder_hidden_states=tuple(features) if output_hidden_states else (), + decoder_hidden_states=decoder_output.hidden_states if output_hidden_states else (), + ) + + +class MaskFormerTransformerModule(nn.Module): + """ + The MaskFormer's transformer module. + """ + + def __init__(self, in_features: int, config: MaskFormerConfig): + super().__init__() + hidden_size = config.decoder_config.hidden_size + should_project = in_features != hidden_size + self.position_embedder = MaskFormerSinePositionEmbedding(num_pos_feats=hidden_size // 2, normalize=True) + self.queries_embedder = nn.Embedding(config.decoder_config.num_queries, hidden_size) + self.input_projection = nn.Conv2d(in_features, hidden_size, kernel_size=1) if should_project else None + self.decoder = DetrDecoder(config=config.decoder_config) + + def forward( + self, image_features: Tensor, output_hidden_states: bool = False, output_attentions: bool = False + ) -> DetrDecoderOutput: + if self.input_projection is not None: + image_features = self.input_projection(image_features) + position_embeddings = self.position_embedder(image_features) + # repeat the queries "q c -> b q c" + batch_size = image_features.shape[0] + queries_embeddings = self.queries_embedder.weight.unsqueeze(0).repeat(batch_size, 1, 1) + inputs_embeds = torch.zeros_like(queries_embeddings, requires_grad=True) + + batch_size, num_channels, height, width = image_features.shape + # rearrange both image_features and position_embeddings "b c h w -> b (h w) c" + image_features = image_features.view(batch_size, num_channels, height * width).permute(0, 2, 1) + position_embeddings = position_embeddings.view(batch_size, num_channels, height * width).permute(0, 2, 1) + + decoder_output: DetrDecoderOutput = self.decoder( + inputs_embeds=inputs_embeds, + attention_mask=None, + encoder_hidden_states=image_features, + encoder_attention_mask=None, + position_embeddings=position_embeddings, + query_position_embeddings=queries_embeddings, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=None, + ) + return decoder_output + + +MASKFORMER_START_DOCSTRING = r""" + This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use + it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and + behavior. + + Parameters: + config ([`MaskFormerConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +MASKFORMER_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Pixel values can be obtained using [`AutoFeatureExtractor`]. See + [`AutoFeatureExtractor.__call__`] for details. + pixel_mask (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*): + Mask to avoid performing attention on padding pixel values. Mask values selected in `[0, 1]`: + + - 1 for pixels that are real (i.e. **not masked**), + - 0 for pixels that are padding (i.e. **masked**). + + [What are attention masks?](../glossary#attention-mask) + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of Detr's decoder attention layers. + return_dict (`bool`, *optional*): + Whether or not to return a [`~MaskFormerModelOutput`] instead of a plain tuple. +""" + + +class MaskFormerPreTrainedModel(PreTrainedModel): + config_class = MaskFormerConfig + base_model_prefix = "model" + main_input_name = "pixel_values" + + def _init_weights(self, module: nn.Module): + xavier_std = self.config.init_xavier_std + std = self.config.init_std + if isinstance(module, MaskFormerTransformerModule): + if module.input_projection is not None: + nn.init.xavier_uniform_(module.input_projection.weight, gain=xavier_std) + nn.init.constant_(module.input_projection.bias, 0) + # FPN + elif isinstance(module, MaskFormerFPNModel): + nn.init.xavier_uniform_(module.stem[0].weight, gain=xavier_std) + + elif isinstance(module, MaskFormerFPNLayer): + nn.init.xavier_uniform_(module.proj[0].weight, gain=xavier_std) + + elif isinstance(module, MaskFormerFPNConvLayer): + nn.init.xavier_uniform_(module[0].weight, gain=xavier_std) + # The MLP head + elif isinstance(module, MaskformerMLPPredictionHead): + # I was not able to find the correct initializer in the original implementation + # we'll use xavier + for layer in module: + nn.init.xavier_uniform_(layer[0].weight, gain=xavier_std) + nn.init.constant_(layer[0].bias, 0) + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + # copied from DETR + if isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, MaskFormerSwinEncoder): + module.gradient_checkpointing = value + if isinstance(module, DetrDecoder): + module.gradient_checkpointing = value + + +@add_start_docstrings( + "The bare MaskFormer Model outputting raw hidden-states without any specific head on top.", + MASKFORMER_START_DOCSTRING, +) +class MaskFormerModel(MaskFormerPreTrainedModel): + def __init__(self, config: MaskFormerConfig): + super().__init__(config) + self.pixel_level_module = MaskFormerPixelLevelModule(config) + self.transformer_module = MaskFormerTransformerModule( + in_features=self.pixel_level_module.encoder.outputs_shapes[-1], config=config + ) + + self.post_init() + + @add_start_docstrings_to_model_forward(MASKFORMER_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + processor_class=_FEAT_EXTRACTOR_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=MaskFormerModelOutput, + config_class=_CONFIG_FOR_DOC, + modality="vision", + ) + def forward( + self, + pixel_values: Tensor, + pixel_mask: Optional[Tensor] = None, + output_hidden_states: Optional[bool] = None, + output_attentions: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> MaskFormerModelOutput: + + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + batch_size, _, height, width = pixel_values.shape + + if pixel_mask is None: + pixel_mask = torch.ones((batch_size, height, width), device=pixel_values.device) + + pixel_level_module_output: MaskFormerPixelLevelModuleOutput = self.pixel_level_module( + pixel_values, output_hidden_states + ) + image_features = pixel_level_module_output.encoder_last_hidden_state + pixel_embeddings = pixel_level_module_output.decoder_last_hidden_state + + transformer_module_output: DetrDecoderOutput = self.transformer_module( + image_features, output_hidden_states, output_attentions + ) + queries = transformer_module_output.last_hidden_state + + encoder_hidden_states = pixel_level_module_output.encoder_hidden_states if output_hidden_states else () + pixel_decoder_hidden_states = pixel_level_module_output.decoder_hidden_states if output_hidden_states else () + transformer_decoder_hidden_states = transformer_module_output.hidden_states if output_hidden_states else () + + output = MaskFormerModelOutput( + encoder_last_hidden_state=image_features, + pixel_decoder_last_hidden_state=pixel_embeddings, + transformer_decoder_last_hidden_state=queries, + encoder_hidden_states=encoder_hidden_states, + pixel_decoder_hidden_states=pixel_decoder_hidden_states, + transformer_decoder_hidden_states=transformer_decoder_hidden_states, + hidden_states=encoder_hidden_states + pixel_decoder_hidden_states + transformer_decoder_hidden_states, + attentions=transformer_module_output.attentions, + ) + + if not return_dict: + output = tuple(v for v in output.values()) + + return output + + +class MaskFormerForInstanceSegmentation(MaskFormerPreTrainedModel): + def __init__(self, config: MaskFormerConfig): + super().__init__(config) + self.model = MaskFormerModel(config) + hidden_size = config.decoder_config.hidden_size + # + 1 because we add the "null" class + self.class_predictor = nn.Linear(hidden_size, config.num_labels + 1) + self.mask_embedder = MaskformerMLPPredictionHead(hidden_size, hidden_size, config.mask_feature_size) + + self.matcher = MaskFormerHungarianMatcher( + cost_class=1.0, cost_dice=config.dice_weight, cost_mask=config.mask_weight + ) + + self.weight_dict: Dict[str, float] = { + "loss_cross_entropy": config.cross_entropy_weight, + "loss_mask": config.mask_weight, + "loss_dice": config.dice_weight, + } + + self.criterion = MaskFormerLoss( + config.num_labels, + matcher=self.matcher, + weight_dict=self.weight_dict, + eos_coef=config.no_object_weight, + ) + + self.post_init() + + def get_loss_dict( + self, + masks_queries_logits: Tensor, + class_queries_logits: Tensor, + mask_labels: Tensor, + class_labels: Tensor, + auxiliary_logits: Dict[str, Tensor], + ) -> Dict[str, Tensor]: + loss_dict: Dict[str, Tensor] = self.criterion( + masks_queries_logits, class_queries_logits, mask_labels, class_labels, auxiliary_logits + ) + # weight each loss by `self.weight_dict[]` + weighted_loss_dict: Dict[str, Tensor] = { + k: v * self.weight_dict[k] for k, v in loss_dict.items() if k in self.weight_dict + } + return weighted_loss_dict + + def get_loss(self, loss_dict: Dict[str, Tensor]) -> Tensor: + return sum(loss_dict.values()) + + def get_logits(self, outputs: MaskFormerModelOutput) -> Tuple[Tensor, Tensor, Dict[str, Tensor]]: + pixel_embeddings = outputs.pixel_decoder_last_hidden_state + # get the auxiliary predictions (one for each decoder's layer) + auxiliary_logits: List[str, Tensor] = [] + # This code is a little bit cumbersome, an improvement can be to return a list of predictions. If we have auxiliary loss then we are going to return more than one element in the list + if self.config.use_auxiliary_loss: + stacked_transformer_decoder_outputs = torch.stack(outputs.transformer_decoder_hidden_states) + classes = self.class_predictor(stacked_transformer_decoder_outputs) + class_queries_logits = classes[-1] + # get the masks + mask_embeddings = self.mask_embedder(stacked_transformer_decoder_outputs) + # sum up over the channels for each embedding + binaries_masks = torch.einsum("lbqc, bchw -> lbqhw", mask_embeddings, pixel_embeddings) + masks_queries_logits = binaries_masks[-1] + # go til [:-1] because the last one is always used + for aux_binary_masks, aux_classes in zip(binaries_masks[:-1], classes[:-1]): + auxiliary_logits.append( + {"masks_queries_logits": aux_binary_masks, "class_queries_logits": aux_classes} + ) + + else: + transformer_decoder_hidden_states = outputs.transformer_decoder_last_hidden_state + classes = self.class_predictor(transformer_decoder_hidden_states) + class_queries_logits = classes + # get the masks + mask_embeddings = self.mask_embedder(transformer_decoder_hidden_states) + # sum up over the channels + masks_queries_logits = torch.einsum("bqc, bchw -> bqhw", mask_embeddings, pixel_embeddings) + + return class_queries_logits, masks_queries_logits, auxiliary_logits + + @add_start_docstrings_to_model_forward(MASKFORMER_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=MaskFormerForInstanceSegmentationOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + pixel_values: Tensor, + mask_labels: Optional[Tensor] = None, + class_labels: Optional[Tensor] = None, + pixel_mask: Optional[Tensor] = None, + output_hidden_states: Optional[bool] = None, + output_attentions: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> MaskFormerForInstanceSegmentationOutput: + r""" + mask_labels (`torch.FloatTensor`, *optional*): + The target mask of shape `(num_classes, height, width)`. + class_labels (`torch.LongTensor`, *optional*): + The target labels of shape `(num_classes)`. + + Returns: + + Examples: + + ```python + >>> from transformers import MaskFormerFeatureExtractor, MaskFormerForInstanceSegmentation + >>> from PIL import Image + >>> import requests + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + >>> feature_extractor = MaskFormerFeatureExtractor.from_pretrained("facebook/maskformer-swin-base-ade") + >>> inputs = feature_extractor(images=image, return_tensors="pt") + + >>> model = MaskFormerForInstanceSegmentation.from_pretrained("facebook/maskformer-swin-base-ade") + >>> outputs = model(**inputs) + >>> # model predicts class_queries_logits of shape `(batch_size, num_queries)` + >>> # and masks_queries_logits of shape `(batch_size, num_queries, height, width)` + >>> class_queries_logits = outputs.class_queries_logits + >>> masks_queries_logits = outputs.masks_queries_logits + + >>> # you can pass them to feature_extractor for postprocessing + >>> output = feature_extractor.post_process_segmentation(outputs) + >>> output = feature_extractor.post_process_semantic_segmentation(outputs) + + >>> output = feature_extractor.post_process_panoptic_segmentation(outputs) + ``` + """ + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs: MaskFormerModelOutput = self.model( + pixel_values, + pixel_mask, + output_hidden_states=output_hidden_states, + return_dict=True, + output_attentions=output_attentions, + ) + + loss, loss_dict, auxiliary_logits = None, None, None + + class_queries_logits, masks_queries_logits, auxiliary_logits = self.get_logits(outputs) + + if mask_labels is not None and class_labels is not None: + loss_dict: Dict[str, Tensor] = self.get_loss_dict( + masks_queries_logits, class_queries_logits, mask_labels, class_labels, auxiliary_logits + ) + loss = self.get_loss(loss_dict) + + output = MaskFormerForInstanceSegmentationOutput( + loss=loss, + **outputs, + class_queries_logits=class_queries_logits, + masks_queries_logits=masks_queries_logits, + auxiliary_logits=auxiliary_logits, + ) + + if not return_dict: + output = tuple(v for v in output.values()) + if loss is not None: + output = ((loss)) + output + return output diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 71587cba3bc9f..335a4fe471631 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -2401,6 +2401,30 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) +MASKFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = None + + +class MaskFormerForInstanceSegmentation(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class MaskFormerModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class MaskFormerPreTrainedModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class MBartForCausalLM(metaclass=DummyObject): _backends = ["torch"] diff --git a/src/transformers/utils/dummy_vision_objects.py b/src/transformers/utils/dummy_vision_objects.py index e49088d8b88ee..e0e8ec0d3dbbe 100644 --- a/src/transformers/utils/dummy_vision_objects.py +++ b/src/transformers/utils/dummy_vision_objects.py @@ -80,6 +80,13 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["vision"]) +class MaskFormerFeatureExtractor(metaclass=DummyObject): + _backends = ["vision"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["vision"]) + + class PerceiverFeatureExtractor(metaclass=DummyObject): _backends = ["vision"] diff --git a/tests/maskformer/__init__.py b/tests/maskformer/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/maskformer/test_feature_extraction_maskformer.py b/tests/maskformer/test_feature_extraction_maskformer.py new file mode 100644 index 0000000000000..3dea5255855e7 --- /dev/null +++ b/tests/maskformer/test_feature_extraction_maskformer.py @@ -0,0 +1,303 @@ +# coding=utf-8 +# Copyright 2022 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import unittest + +import numpy as np + +from transformers.file_utils import is_torch_available, is_vision_available +from transformers.testing_utils import require_torch, require_vision + +from ..test_feature_extraction_common import FeatureExtractionSavingTestMixin, prepare_image_inputs + + +if is_torch_available(): + import torch + + if is_vision_available(): + from transformers import MaskFormerFeatureExtractor + +if is_vision_available(): + from PIL import Image + + +class MaskFormerFeatureExtractionTester(unittest.TestCase): + def __init__( + self, + parent, + batch_size=7, + num_channels=3, + min_resolution=30, + max_resolution=400, + do_resize=True, + size=32, + max_size=1333, # by setting max_size > max_resolution we're effectively not testing this :p + do_normalize=True, + image_mean=[0.5, 0.5, 0.5], + image_std=[0.5, 0.5, 0.5], + ): + self.parent = parent + self.batch_size = batch_size + self.num_channels = num_channels + self.min_resolution = min_resolution + self.max_resolution = max_resolution + self.do_resize = do_resize + self.size = size + self.max_size = max_size + self.do_normalize = do_normalize + self.image_mean = image_mean + self.image_std = image_std + self.size_divisibility = 0 + + def prepare_feat_extract_dict(self): + return { + "do_resize": self.do_resize, + "size": self.size, + "max_size": self.max_size, + "do_normalize": self.do_normalize, + "image_mean": self.image_mean, + "image_std": self.image_std, + "size_divisibility": self.size_divisibility, + } + + def get_expected_values(self, image_inputs, batched=False): + """ + This function computes the expected height and width when providing images to MaskFormerFeatureExtractor, + assuming do_resize is set to True with a scalar size. + """ + if not batched: + image = image_inputs[0] + if isinstance(image, Image.Image): + w, h = image.size + else: + h, w = image.shape[1], image.shape[2] + if w < h: + expected_height = int(self.size * h / w) + expected_width = self.size + elif w > h: + expected_height = self.size + expected_width = int(self.size * w / h) + else: + expected_height = self.size + expected_width = self.size + + else: + expected_values = [] + for image in image_inputs: + expected_height, expected_width = self.get_expected_values([image]) + expected_values.append((expected_height, expected_width)) + expected_height = max(expected_values, key=lambda item: item[0])[0] + expected_width = max(expected_values, key=lambda item: item[1])[1] + + return expected_height, expected_width + + +@require_torch +@require_vision +class MaskFormerFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase): + + feature_extraction_class = MaskFormerFeatureExtractor if (is_vision_available() and is_torch_available()) else None + + def setUp(self): + self.feature_extract_tester = MaskFormerFeatureExtractionTester(self) + + @property + def feat_extract_dict(self): + return self.feature_extract_tester.prepare_feat_extract_dict() + + def test_feat_extract_properties(self): + feature_extractor = self.feature_extraction_class(**self.feat_extract_dict) + self.assertTrue(hasattr(feature_extractor, "image_mean")) + self.assertTrue(hasattr(feature_extractor, "image_std")) + self.assertTrue(hasattr(feature_extractor, "do_normalize")) + self.assertTrue(hasattr(feature_extractor, "do_resize")) + self.assertTrue(hasattr(feature_extractor, "size")) + self.assertTrue(hasattr(feature_extractor, "max_size")) + + def test_batch_feature(self): + pass + + def test_call_pil(self): + # Initialize feature_extractor + feature_extractor = self.feature_extraction_class(**self.feat_extract_dict) + # create random PIL images + image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False) + for image in image_inputs: + self.assertIsInstance(image, Image.Image) + + # Test not batched input + encoded_images = feature_extractor(image_inputs[0], return_tensors="pt").pixel_values + + expected_height, expected_width = self.feature_extract_tester.get_expected_values(image_inputs) + + self.assertEqual( + encoded_images.shape, + (1, self.feature_extract_tester.num_channels, expected_height, expected_width), + ) + + # Test batched + expected_height, expected_width = self.feature_extract_tester.get_expected_values(image_inputs, batched=True) + + encoded_images = feature_extractor(image_inputs, return_tensors="pt").pixel_values + self.assertEqual( + encoded_images.shape, + ( + self.feature_extract_tester.batch_size, + self.feature_extract_tester.num_channels, + expected_height, + expected_width, + ), + ) + + def test_call_numpy(self): + # Initialize feature_extractor + feature_extractor = self.feature_extraction_class(**self.feat_extract_dict) + # create random numpy tensors + image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False, numpify=True) + for image in image_inputs: + self.assertIsInstance(image, np.ndarray) + + # Test not batched input + encoded_images = feature_extractor(image_inputs[0], return_tensors="pt").pixel_values + + expected_height, expected_width = self.feature_extract_tester.get_expected_values(image_inputs) + + self.assertEqual( + encoded_images.shape, + (1, self.feature_extract_tester.num_channels, expected_height, expected_width), + ) + + # Test batched + encoded_images = feature_extractor(image_inputs, return_tensors="pt").pixel_values + + expected_height, expected_width = self.feature_extract_tester.get_expected_values(image_inputs, batched=True) + + self.assertEqual( + encoded_images.shape, + ( + self.feature_extract_tester.batch_size, + self.feature_extract_tester.num_channels, + expected_height, + expected_width, + ), + ) + + def test_call_pytorch(self): + # Initialize feature_extractor + feature_extractor = self.feature_extraction_class(**self.feat_extract_dict) + # create random PyTorch tensors + image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False, torchify=True) + for image in image_inputs: + self.assertIsInstance(image, torch.Tensor) + + # Test not batched input + encoded_images = feature_extractor(image_inputs[0], return_tensors="pt").pixel_values + + expected_height, expected_width = self.feature_extract_tester.get_expected_values(image_inputs) + + self.assertEqual( + encoded_images.shape, + (1, self.feature_extract_tester.num_channels, expected_height, expected_width), + ) + + # Test batched + encoded_images = feature_extractor(image_inputs, return_tensors="pt").pixel_values + + expected_height, expected_width = self.feature_extract_tester.get_expected_values(image_inputs, batched=True) + + self.assertEqual( + encoded_images.shape, + ( + self.feature_extract_tester.batch_size, + self.feature_extract_tester.num_channels, + expected_height, + expected_width, + ), + ) + + def test_equivalence_pad_and_create_pixel_mask(self): + # Initialize feature_extractors + feature_extractor_1 = self.feature_extraction_class(**self.feat_extract_dict) + feature_extractor_2 = self.feature_extraction_class(do_resize=False, do_normalize=False) + # create random PyTorch tensors + image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False, torchify=True) + for image in image_inputs: + self.assertIsInstance(image, torch.Tensor) + + # Test whether the method "pad_and_return_pixel_mask" and calling the feature extractor return the same tensors + encoded_images_with_method = feature_extractor_1.encode_inputs(image_inputs, return_tensors="pt") + encoded_images = feature_extractor_2(image_inputs, return_tensors="pt") + + self.assertTrue( + torch.allclose(encoded_images_with_method["pixel_values"], encoded_images["pixel_values"], atol=1e-4) + ) + self.assertTrue( + torch.allclose(encoded_images_with_method["pixel_mask"], encoded_images["pixel_mask"], atol=1e-4) + ) + + def comm_get_feature_extractor_inputs(self, with_annotations=False): + feature_extractor = self.feature_extraction_class(**self.feat_extract_dict) + # prepare image and target + num_classes = 8 + batch_size = self.feature_extract_tester.batch_size + annotations = None + + if with_annotations: + annotations = [ + { + "masks": np.random.rand(num_classes, 384, 384).astype(np.float32), + "labels": (np.random.rand(num_classes) > 0.5).astype(np.int64), + } + for _ in range(batch_size) + ] + + image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False) + + inputs = feature_extractor(image_inputs, annotations, return_tensors="pt", pad_and_return_pixel_mask=True) + + return inputs + + def test_with_size_divisibility(self): + size_divisibilities = [8, 16, 32] + weird_input_sizes = [(407, 802), (582, 1094)] + for size_divisibility in size_divisibilities: + feat_extract_dict = {**self.feat_extract_dict, **{"size_divisibility": size_divisibility}} + feature_extractor = self.feature_extraction_class(**feat_extract_dict) + for weird_input_size in weird_input_sizes: + inputs = feature_extractor([np.ones((3, *weird_input_size))], return_tensors="pt") + pixel_values = inputs["pixel_values"] + # check if divisible + self.assertTrue((pixel_values.shape[-1] % size_divisibility) == 0) + self.assertTrue((pixel_values.shape[-2] % size_divisibility) == 0) + + def test_call_with_numpy_annotations(self): + num_classes = 8 + batch_size = self.feature_extract_tester.batch_size + + inputs = self.comm_get_feature_extractor_inputs(with_annotations=True) + + # check the batch_size + for el in inputs.values(): + self.assertEqual(el.shape[0], batch_size) + + pixel_values = inputs["pixel_values"] + mask_labels = inputs["mask_labels"] + class_labels = inputs["class_labels"] + + self.assertEqual(pixel_values.shape[-2], mask_labels.shape[-2]) + self.assertEqual(pixel_values.shape[-1], mask_labels.shape[-1]) + self.assertEqual(mask_labels.shape[1], class_labels.shape[1]) + self.assertEqual(mask_labels.shape[1], num_classes) diff --git a/tests/maskformer/test_modeling_maskformer.py b/tests/maskformer/test_modeling_maskformer.py new file mode 100644 index 0000000000000..23c9aff46618c --- /dev/null +++ b/tests/maskformer/test_modeling_maskformer.py @@ -0,0 +1,405 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Testing suite for the PyTorch MaskFormer model. """ + +import inspect +import unittest + +import numpy as np + +from tests.test_modeling_common import floats_tensor +from transformers import MaskFormerConfig, is_torch_available, is_vision_available +from transformers.file_utils import cached_property +from transformers.testing_utils import require_torch, require_vision, slow, torch_device + +from ..test_configuration_common import ConfigTester +from ..test_modeling_common import ModelTesterMixin + + +if is_torch_available(): + import torch + + from transformers import MaskFormerForInstanceSegmentation, MaskFormerModel + + if is_vision_available(): + from transformers import MaskFormerFeatureExtractor + +if is_vision_available(): + from PIL import Image + + +class MaskFormerModelTester: + def __init__( + self, + parent, + batch_size=2, + is_training=True, + use_auxiliary_loss=False, + num_queries=100, + num_channels=3, + min_size=384, + max_size=640, + num_labels=150, + mask_feature_size=256, + ): + self.parent = parent + self.batch_size = batch_size + self.is_training = is_training + self.use_auxiliary_loss = use_auxiliary_loss + self.num_queries = num_queries + self.num_channels = num_channels + self.min_size = min_size + self.max_size = max_size + self.num_labels = num_labels + self.mask_feature_size = mask_feature_size + + def prepare_config_and_inputs(self): + pixel_values = floats_tensor([self.batch_size, self.num_channels, self.min_size, self.max_size]) + + pixel_mask = torch.ones([self.batch_size, self.min_size, self.max_size], device=torch_device) + + mask_labels = ( + torch.rand([self.batch_size, self.num_labels, self.min_size, self.max_size], device=torch_device) > 0.5 + ).float() + class_labels = (torch.rand((self.batch_size, self.num_labels), device=torch_device) > 0.5).long() + + config = self.get_config() + return config, pixel_values, pixel_mask, mask_labels, class_labels + + def get_config(self): + return MaskFormerConfig( + num_queries=self.num_queries, + num_channels=self.num_channels, + num_labels=self.num_labels, + mask_feature_size=self.mask_feature_size, + ) + + def prepare_config_and_inputs_for_common(self): + config, pixel_values, pixel_mask, _, _ = self.prepare_config_and_inputs() + inputs_dict = {"pixel_values": pixel_values, "pixel_mask": pixel_mask} + return config, inputs_dict + + def check_output_hidden_state(self, output, config): + encoder_hidden_states = output.encoder_hidden_states + pixel_decoder_hidden_states = output.pixel_decoder_hidden_states + transformer_decoder_hidden_states = output.transformer_decoder_hidden_states + + self.parent.assertTrue(len(encoder_hidden_states), len(config.backbone_config.depths)) + self.parent.assertTrue(len(pixel_decoder_hidden_states), len(config.backbone_config.depths)) + self.parent.assertTrue(len(transformer_decoder_hidden_states), config.decoder_config.decoder_layers) + + def create_and_check_maskformer_model(self, config, pixel_values, pixel_mask, output_hidden_states=False): + with torch.no_grad(): + model = MaskFormerModel(config=config) + model.to(torch_device) + model.eval() + + output = model(pixel_values=pixel_values, pixel_mask=pixel_mask) + output = model(pixel_values, output_hidden_states=True) + # the correct shape of output.transformer_decoder_hidden_states ensure the correcteness of the + # encoder and pixel decoder + self.parent.assertEqual( + output.transformer_decoder_last_hidden_state.shape, + (self.batch_size, self.num_queries, self.mask_feature_size), + ) + # let's ensure the other two hidden state exists + self.parent.assertTrue(output.pixel_decoder_last_hidden_state is not None) + self.parent.assertTrue(output.encoder_last_hidden_state is not None) + + if output_hidden_states: + self.check_output_hidden_state(output, config) + + def create_and_check_maskformer_instance_segmentation_head_model( + self, config, pixel_values, pixel_mask, mask_labels, class_labels + ): + model = MaskFormerForInstanceSegmentation(config=config) + model.to(torch_device) + model.eval() + + def comm_check_on_output(result): + # let's still check that all the required stuff is there + self.parent.assertTrue(result.transformer_decoder_hidden_states is not None) + self.parent.assertTrue(result.pixel_decoder_last_hidden_state is not None) + self.parent.assertTrue(result.encoder_last_hidden_state is not None) + # okay, now we need to check the logits shape + # due to the encoder compression, masks have a //4 spatial size + self.parent.assertEqual( + result.masks_queries_logits.shape, + (self.batch_size, self.num_queries, self.min_size // 4, self.max_size // 4), + ) + # + 1 for null class + self.parent.assertEqual( + result.class_queries_logits.shape, (self.batch_size, self.num_queries, self.num_labels + 1) + ) + + with torch.no_grad(): + result = model(pixel_values=pixel_values, pixel_mask=pixel_mask) + result = model(pixel_values) + + comm_check_on_output(result) + + result = model( + pixel_values=pixel_values, pixel_mask=pixel_mask, mask_labels=mask_labels, class_labels=class_labels + ) + + comm_check_on_output(result) + + self.parent.assertTrue(result.loss is not None) + self.parent.assertEqual(result.loss.shape, torch.Size([1])) + + +@require_torch +class MaskFormerModelTest(ModelTesterMixin, unittest.TestCase): + + all_model_classes = (MaskFormerModel, MaskFormerForInstanceSegmentation) if is_torch_available() else () + + is_encoder_decoder = False + test_torchscript = False + test_pruning = False + test_head_masking = False + test_missing_keys = False + + def setUp(self): + self.model_tester = MaskFormerModelTester(self) + self.config_tester = ConfigTester(self, config_class=MaskFormerConfig, has_text_modality=False) + + def test_config(self): + self.config_tester.run_common_tests() + + def test_maskformer_model(self): + config, inputs = self.model_tester.prepare_config_and_inputs_for_common() + self.model_tester.create_and_check_maskformer_model(config, **inputs, output_hidden_states=False) + + def test_maskformer_instance_segmentation_head_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_maskformer_instance_segmentation_head_model(*config_and_inputs) + + @unittest.skip(reason="MaskFormer does not use inputs_embeds") + def test_inputs_embeds(self): + pass + + @unittest.skip(reason="MaskFormer does not have a get_input_embeddings method") + def test_model_common_attributes(self): + pass + + @unittest.skip(reason="MaskFormer is not a generative model") + def test_generate_without_input_ids(self): + pass + + @unittest.skip(reason="MaskFormer does not use token embeddings") + def test_resize_tokens_embeddings(self): + pass + + def test_forward_signature(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + signature = inspect.signature(model.forward) + # signature.parameters is an OrderedDict => so arg_names order is deterministic + arg_names = [*signature.parameters.keys()] + + expected_arg_names = ["pixel_values"] + self.assertListEqual(arg_names[:1], expected_arg_names) + + @slow + def test_model_from_pretrained(self): + for model_name in ["facebook/maskformer-swin-small-coco"]: + model = MaskFormerModel.from_pretrained(model_name) + self.assertIsNotNone(model) + + @slow + def test_model_with_labels(self): + inputs = { + "pixel_values": torch.randn((2, 3, 384, 384)), + "mask_labels": torch.randn((2, 10, 384, 384)), + "class_labels": torch.zeros(2, 10).long(), + } + + model = MaskFormerForInstanceSegmentation(MaskFormerConfig()) + outputs = model(**inputs) + self.assertTrue(outputs.loss is not None) + + def test_hidden_states_output(self): + config, inputs = self.model_tester.prepare_config_and_inputs_for_common() + self.model_tester.create_and_check_maskformer_model(config, **inputs, output_hidden_states=True) + + def test_attention_outputs(self): + config, inputs = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + outputs = model(**inputs, output_attentions=True) + self.assertTrue(outputs.attentions is not None) + + def test_training(self): + if not self.model_tester.is_training: + return + # only MaskFormerForInstanceSegmentation has the loss + model_class = self.all_model_classes[1] + config, pixel_values, pixel_mask, mask_labels, class_labels = self.model_tester.prepare_config_and_inputs() + + model = model_class(config) + model.to(torch_device) + model.train() + + loss = model(pixel_values, mask_labels=mask_labels, class_labels=class_labels).loss + loss.backward() + + def test_retain_grad_hidden_states_attentions(self): + # only MaskFormerForInstanceSegmentation has the loss + model_class = self.all_model_classes[1] + config, pixel_values, pixel_mask, mask_labels, class_labels = self.model_tester.prepare_config_and_inputs() + config.output_hidden_states = True + config.output_attentions = True + + model = model_class(config) + model.to(torch_device) + model.train() + + outputs = model(pixel_values, mask_labels=mask_labels, class_labels=class_labels) + + encoder_hidden_states = outputs.encoder_hidden_states[0] + encoder_hidden_states.retain_grad() + + pixel_decoder_hidden_states = outputs.pixel_decoder_hidden_states[0] + pixel_decoder_hidden_states.retain_grad() + # we requires_grad=True in inputs_embeds (line 2152), the original implementation don't + transformer_decoder_hidden_states = outputs.transformer_decoder_hidden_states[0] + transformer_decoder_hidden_states.retain_grad() + + attentions = outputs.attentions[0] + attentions.retain_grad() + + outputs.loss.backward(retain_graph=True) + + self.assertIsNotNone(encoder_hidden_states.grad) + self.assertIsNotNone(pixel_decoder_hidden_states.grad) + self.assertIsNotNone(transformer_decoder_hidden_states.grad) + self.assertIsNotNone(attentions.grad) + + +TOLERANCE = 1e-4 + + +# We will verify our results on an image of cute cats +def prepare_img(): + image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png") + return image + + +@require_vision +@slow +class MaskFormerModelIntegrationTest(unittest.TestCase): + @cached_property + def model_checkpoints(self): + return "facebook/maskformer-swin-small-coco" + + @cached_property + def default_feature_extractor(self): + return MaskFormerFeatureExtractor.from_pretrained(self.model_checkpoints) if is_vision_available() else None + + def test_inference_no_head(self): + model = MaskFormerModel.from_pretrained(self.model_checkpoints).to(torch_device) + feature_extractor = self.default_feature_extractor + image = prepare_img() + inputs = feature_extractor(image, return_tensors="pt").to(torch_device) + inputs_shape = inputs["pixel_values"].shape + # check size is divisible by 32 + self.assertTrue((inputs_shape[-1] % 32) == 0 and (inputs_shape[-2] % 32) == 0) + # check size + self.assertEqual(inputs_shape, (1, 3, 800, 1088)) + + with torch.no_grad(): + outputs = model(**inputs) + + expected_slice_hidden_state = torch.tensor( + [[-0.0482, 0.9228, 0.4951], [-0.2547, 0.8017, 0.8527], [-0.0069, 0.3385, -0.0089]] + ).to(torch_device) + self.assertTrue( + torch.allclose( + outputs.encoder_last_hidden_state[0, 0, :3, :3], expected_slice_hidden_state, atol=TOLERANCE + ) + ) + + expected_slice_hidden_state = torch.tensor( + [[-0.8422, -0.8434, -0.9718], [-1.0144, -0.5565, -0.4195], [-1.0038, -0.4484, -0.1961]] + ).to(torch_device) + self.assertTrue( + torch.allclose( + outputs.pixel_decoder_last_hidden_state[0, 0, :3, :3], expected_slice_hidden_state, atol=TOLERANCE + ) + ) + + expected_slice_hidden_state = torch.tensor( + [[0.2852, -0.0159, 0.9735], [0.6254, 0.1858, 0.8529], [-0.0680, -0.4116, 1.8413]] + ).to(torch_device) + self.assertTrue( + torch.allclose( + outputs.transformer_decoder_last_hidden_state[0, :3, :3], expected_slice_hidden_state, atol=TOLERANCE + ) + ) + + def test_inference_instance_segmentation_head(self): + model = MaskFormerForInstanceSegmentation.from_pretrained(self.model_checkpoints).to(torch_device).eval() + feature_extractor = self.default_feature_extractor + image = prepare_img() + inputs = feature_extractor(image, return_tensors="pt").to(torch_device) + inputs_shape = inputs["pixel_values"].shape + # check size is divisible by 32 + self.assertTrue((inputs_shape[-1] % 32) == 0 and (inputs_shape[-2] % 32) == 0) + # check size + self.assertEqual(inputs_shape, (1, 3, 800, 1088)) + + with torch.no_grad(): + outputs = model(**inputs) + # masks_queries_logits + masks_queries_logits = outputs.masks_queries_logits + self.assertEqual( + masks_queries_logits.shape, (1, model.config.num_queries, inputs_shape[-2] // 4, inputs_shape[-1] // 4) + ) + expected_slice = torch.tensor( + [[-1.3738, -1.7725, -1.9365], [-1.5978, -1.9869, -2.1524], [-1.5796, -1.9271, -2.0940]] + ) + self.assertTrue(torch.allclose(masks_queries_logits[0, 0, :3, :3], expected_slice, atol=TOLERANCE)) + # class_queries_logits + class_queries_logits = outputs.class_queries_logits + self.assertEqual(class_queries_logits.shape, (1, model.config.num_queries, model.config.num_labels + 1)) + expected_slice = torch.tensor( + [ + [1.6512e00, -5.2572e00, -3.3519e00], + [3.6169e-02, -5.9025e00, -2.9313e00], + [1.0766e-04, -7.7630e00, -5.1263e00], + ] + ) + self.assertTrue(torch.allclose(outputs.class_queries_logits[0, :3, :3], expected_slice, atol=TOLERANCE)) + + def test_with_annotations_and_loss(self): + model = MaskFormerForInstanceSegmentation.from_pretrained(self.model_checkpoints).to(torch_device).eval() + feature_extractor = self.default_feature_extractor + + inputs = feature_extractor( + [np.zeros((3, 800, 1333)), np.zeros((3, 800, 1333))], + annotations=[ + {"masks": np.random.rand(10, 384, 384).astype(np.float32), "labels": np.zeros(10).astype(np.int64)}, + {"masks": np.random.rand(10, 384, 384).astype(np.float32), "labels": np.zeros(10).astype(np.int64)}, + ], + return_tensors="pt", + ) + + with torch.no_grad(): + outputs = model(**inputs) + + self.assertTrue(outputs.loss is not None) diff --git a/utils/check_repo.py b/utils/check_repo.py index ee7d3eec065fb..76017d5a27f57 100644 --- a/utils/check_repo.py +++ b/utils/check_repo.py @@ -169,6 +169,7 @@ "VisualBertForMultipleChoice", "TFWav2Vec2ForCTC", "TFHubertForCTC", + "MaskFormerForInstanceSegmentation", ] # Update this list for models that have multiple model types for the same