diff --git a/README.md b/README.md index 788b7dee9f95f..43e0d7acb4f52 100644 --- a/README.md +++ b/README.md @@ -332,6 +332,7 @@ Current number of checkpoints: ![](https://img.shields.io/endpoint?url=https://h 1. **[NLLB](https://huggingface.co/docs/transformers/main/model_doc/nllb)** (from Meta) released with the paper [No Language Left Behind: Scaling Human-Centered Machine Translation](https://arxiv.org/abs/2207.04672) by the NLLB team. 1. **[Nyströmformer](https://huggingface.co/docs/transformers/model_doc/nystromformer)** (from the University of Wisconsin - Madison) released with the paper [Nyströmformer: A Nyström-Based Algorithm for Approximating Self-Attention](https://arxiv.org/abs/2102.03902) by Yunyang Xiong, Zhanpeng Zeng, Rudrasis Chakraborty, Mingxing Tan, Glenn Fung, Yin Li, Vikas Singh. 1. **[OPT](https://huggingface.co/docs/transformers/master/model_doc/opt)** (from Meta AI) released with the paper [OPT: Open Pre-trained Transformer Language Models](https://arxiv.org/abs/2205.01068) by Susan Zhang, Stephen Roller, Naman Goyal, Mikel Artetxe, Moya Chen, Shuohui Chen et al. +1. **[OWL-ViT](https://huggingface.co/docs/transformers/main/model_doc/owlvit)** (from Google AI) released with the paper [Simple Open-Vocabulary Object Detection with Vision Transformers](https://arxiv.org/abs/2205.06230) by Matthias Minderer, Alexey Gritsenko, Austin Stone, Maxim Neumann, Dirk Weissenborn, Alexey Dosovitskiy, Aravindh Mahendran, Anurag Arnab, Mostafa Dehghani, Zhuoran Shen, Xiao Wang, Xiaohua Zhai, Thomas Kipf, and Neil Houlsby. 1. **[Pegasus](https://huggingface.co/docs/transformers/model_doc/pegasus)** (from Google) released with the paper [PEGASUS: Pre-training with Extracted Gap-sentences for Abstractive Summarization](https://arxiv.org/abs/1912.08777) by Jingqing Zhang, Yao Zhao, Mohammad Saleh and Peter J. Liu. 1. **[Perceiver IO](https://huggingface.co/docs/transformers/model_doc/perceiver)** (from Deepmind) released with the paper [Perceiver IO: A General Architecture for Structured Inputs & Outputs](https://arxiv.org/abs/2107.14795) by Andrew Jaegle, Sebastian Borgeaud, Jean-Baptiste Alayrac, Carl Doersch, Catalin Ionescu, David Ding, Skanda Koppula, Daniel Zoran, Andrew Brock, Evan Shelhamer, Olivier Hénaff, Matthew M. Botvinick, Andrew Zisserman, Oriol Vinyals, João Carreira. 1. **[PhoBERT](https://huggingface.co/docs/transformers/model_doc/phobert)** (from VinAI Research) released with the paper [PhoBERT: Pre-trained language models for Vietnamese](https://www.aclweb.org/anthology/2020.findings-emnlp.92/) by Dat Quoc Nguyen and Anh Tuan Nguyen. diff --git a/README_ko.md b/README_ko.md index ff953765ca5d5..91bb99cf707bb 100644 --- a/README_ko.md +++ b/README_ko.md @@ -288,6 +288,7 @@ Flax, PyTorch, TensorFlow 설치 페이지에서 이들을 conda로 설치하는 1. **[NLLB](https://huggingface.co/docs/transformers/main/model_doc/nllb)** (from Meta) released with the paper [No Language Left Behind: Scaling Human-Centered Machine Translation](https://arxiv.org/abs/2207.04672) by the NLLB team. 1. **[Nyströmformer](https://huggingface.co/docs/transformers/model_doc/nystromformer)** (from the University of Wisconsin - Madison) released with the paper [Nyströmformer: A Nyström-Based Algorithm for Approximating Self-Attention](https://arxiv.org/abs/2102.03902) by Yunyang Xiong, Zhanpeng Zeng, Rudrasis Chakraborty, Mingxing Tan, Glenn Fung, Yin Li, Vikas Singh. 1. **[OPT](https://huggingface.co/docs/transformers/master/model_doc/opt)** (from Meta AI) released with the paper [OPT: Open Pre-trained Transformer Language Models](https://arxiv.org/abs/2205.01068) by Susan Zhang, Stephen Roller, Naman Goyal, Mikel Artetxe, Moya Chen, Shuohui Chen et al. +1. **[OWL-ViT](https://huggingface.co/docs/transformers/main/model_doc/owlvit)** (from Google AI) released with the paper [Simple Open-Vocabulary Object Detection with Vision Transformers](https://arxiv.org/abs/2205.06230) by Matthias Minderer, Alexey Gritsenko, Austin Stone, Maxim Neumann, Dirk Weissenborn, Alexey Dosovitskiy, Aravindh Mahendran, Anurag Arnab, Mostafa Dehghani, Zhuoran Shen, Xiao Wang, Xiaohua Zhai, Thomas Kipf, and Neil Houlsby. 1. **[Pegasus](https://huggingface.co/docs/transformers/model_doc/pegasus)** (from Google) released with the paper [PEGASUS: Pre-training with Extracted Gap-sentences for Abstractive Summarization](https://arxiv.org/abs/1912.08777) by Jingqing Zhang, Yao Zhao, Mohammad Saleh and Peter J. Liu. 1. **[Perceiver IO](https://huggingface.co/docs/transformers/model_doc/perceiver)** (from Deepmind) released with the paper [Perceiver IO: A General Architecture for Structured Inputs & Outputs](https://arxiv.org/abs/2107.14795) by Andrew Jaegle, Sebastian Borgeaud, Jean-Baptiste Alayrac, Carl Doersch, Catalin Ionescu, David Ding, Skanda Koppula, Daniel Zoran, Andrew Brock, Evan Shelhamer, Olivier Hénaff, Matthew M. Botvinick, Andrew Zisserman, Oriol Vinyals, João Carreira. 1. **[PhoBERT](https://huggingface.co/docs/transformers/model_doc/phobert)** (from VinAI Research) released with the paper [PhoBERT: Pre-trained language models for Vietnamese](https://www.aclweb.org/anthology/2020.findings-emnlp.92/) by Dat Quoc Nguyen and Anh Tuan Nguyen. diff --git a/README_zh-hans.md b/README_zh-hans.md index 998d90e378a42..70b711711cc95 100644 --- a/README_zh-hans.md +++ b/README_zh-hans.md @@ -312,6 +312,7 @@ conda install -c huggingface transformers 1. **[NLLB](https://huggingface.co/docs/transformers/main/model_doc/nllb)** (来自 Meta) 伴随论文 [No Language Left Behind: Scaling Human-Centered Machine Translation](https://arxiv.org/abs/2207.04672) 由 the NLLB team 发布。 1. **[Nyströmformer](https://huggingface.co/docs/transformers/model_doc/nystromformer)** (来自 the University of Wisconsin - Madison) 伴随论文 [Nyströmformer: A Nyström-Based Algorithm for Approximating Self-Attention](https://arxiv.org/abs/2102.03902) 由 Yunyang Xiong, Zhanpeng Zeng, Rudrasis Chakraborty, Mingxing Tan, Glenn Fung, Yin Li, Vikas Singh 发布。 1. **[OPT](https://huggingface.co/docs/transformers/master/model_doc/opt)** (来自 Meta AI) 伴随论文 [OPT: Open Pre-trained Transformer Language Models](https://arxiv.org/abs/2205.01068) 由 Susan Zhang, Stephen Roller, Naman Goyal, Mikel Artetxe, Moya Chen, Shuohui Chen et al 发布。 +1. **[OWL-ViT](https://huggingface.co/docs/transformers/main/model_doc/owlvit)** (来自 Google AI) 伴随论文 [Simple Open-Vocabulary Object Detection with Vision Transformers](https://arxiv.org/abs/2205.06230) 由 Matthias Minderer, Alexey Gritsenko, Austin Stone, Maxim Neumann, Dirk Weissenborn, Alexey Dosovitskiy, Aravindh Mahendran, Anurag Arnab, Mostafa Dehghani, Zhuoran Shen, Xiao Wang, Xiaohua Zhai, Thomas Kipf, and Neil Houlsby 发布。 1. **[Pegasus](https://huggingface.co/docs/transformers/model_doc/pegasus)** (来自 Google) 伴随论文 [PEGASUS: Pre-training with Extracted Gap-sentences for Abstractive Summarization](https://arxiv.org/abs/1912.08777) 由 Jingqing Zhang, Yao Zhao, Mohammad Saleh and Peter J. Liu 发布。 1. **[Perceiver IO](https://huggingface.co/docs/transformers/model_doc/perceiver)** (来自 Deepmind) 伴随论文 [Perceiver IO: A General Architecture for Structured Inputs & Outputs](https://arxiv.org/abs/2107.14795) 由 Andrew Jaegle, Sebastian Borgeaud, Jean-Baptiste Alayrac, Carl Doersch, Catalin Ionescu, David Ding, Skanda Koppula, Daniel Zoran, Andrew Brock, Evan Shelhamer, Olivier Hénaff, Matthew M. Botvinick, Andrew Zisserman, Oriol Vinyals, João Carreira 发布。 1. **[PhoBERT](https://huggingface.co/docs/transformers/model_doc/phobert)** (来自 VinAI Research) 伴随论文 [PhoBERT: Pre-trained language models for Vietnamese](https://www.aclweb.org/anthology/2020.findings-emnlp.92/) 由 Dat Quoc Nguyen and Anh Tuan Nguyen 发布。 diff --git a/README_zh-hant.md b/README_zh-hant.md index 66919dfe9cd21..6e3356d1ad7fc 100644 --- a/README_zh-hant.md +++ b/README_zh-hant.md @@ -324,6 +324,7 @@ conda install -c huggingface transformers 1. **[NLLB](https://huggingface.co/docs/transformers/main/model_doc/nllb)** (from Meta) released with the paper [No Language Left Behind: Scaling Human-Centered Machine Translation](https://arxiv.org/abs/2207.04672) by the NLLB team. 1. **[Nyströmformer](https://huggingface.co/docs/transformers/model_doc/nystromformer)** (from the University of Wisconsin - Madison) released with the paper [Nyströmformer: A Nyström-Based Algorithm for Approximating Self-Attention](https://arxiv.org/abs/2102.03902) by Yunyang Xiong, Zhanpeng Zeng, Rudrasis Chakraborty, Mingxing Tan, Glenn Fung, Yin Li, Vikas Singh. 1. **[OPT](https://huggingface.co/docs/transformers/master/model_doc/opt)** (from Meta AI) released with the paper [OPT: Open Pre-trained Transformer Language Models](https://arxiv.org/abs/2205.01068) by Susan Zhang, Stephen Roller, Naman Goyal, Mikel Artetxe, Moya Chen, Shuohui Chen et al. +1. **[OWL-ViT](https://huggingface.co/docs/transformers/main/model_doc/owlvit)** (from Google AI) released with the paper [Simple Open-Vocabulary Object Detection with Vision Transformers](https://arxiv.org/abs/2205.06230) by Matthias Minderer, Alexey Gritsenko, Austin Stone, Maxim Neumann, Dirk Weissenborn, Alexey Dosovitskiy, Aravindh Mahendran, Anurag Arnab, Mostafa Dehghani, Zhuoran Shen, Xiao Wang, Xiaohua Zhai, Thomas Kipf, and Neil Houlsby. 1. **[Pegasus](https://huggingface.co/docs/transformers/model_doc/pegasus)** (from Google) released with the paper [PEGASUS: Pre-training with Extracted Gap-sentences for Abstractive Summarization](https://arxiv.org/abs/1912.08777) by Jingqing Zhang, Yao Zhao, Mohammad Saleh and Peter J. Liu. 1. **[Perceiver IO](https://huggingface.co/docs/transformers/model_doc/perceiver)** (from Deepmind) released with the paper [Perceiver IO: A General Architecture for Structured Inputs & Outputs](https://arxiv.org/abs/2107.14795) by Andrew Jaegle, Sebastian Borgeaud, Jean-Baptiste Alayrac, Carl Doersch, Catalin Ionescu, David Ding, Skanda Koppula, Daniel Zoran, Andrew Brock, Evan Shelhamer, Olivier Hénaff, Matthew M. Botvinick, Andrew Zisserman, Oriol Vinyals, João Carreira. 1. **[PhoBERT](https://huggingface.co/docs/transformers/model_doc/phobert)** (from VinAI Research) released with the paper [PhoBERT: Pre-trained language models for Vietnamese](https://www.aclweb.org/anthology/2020.findings-emnlp.92/) by Dat Quoc Nguyen and Anh Tuan Nguyen. diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 2dbce02d4bcc5..72bc4f2093cf8 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -326,6 +326,8 @@ title: Nyströmformer - local: model_doc/opt title: OPT + - local: model_doc/owlvit + title: OWL-ViT - local: model_doc/pegasus title: Pegasus - local: model_doc/perceiver diff --git a/docs/source/en/index.mdx b/docs/source/en/index.mdx index b8cd08eb04735..c1e174c3f537c 100644 --- a/docs/source/en/index.mdx +++ b/docs/source/en/index.mdx @@ -130,6 +130,7 @@ The library currently contains JAX, PyTorch and TensorFlow implementations, pret 1. **[NLLB](model_doc/nllb)** (from Meta) released with the paper [No Language Left Behind: Scaling Human-Centered Machine Translation](https://arxiv.org/abs/2207.04672) by the NLLB team. 1. **[Nyströmformer](model_doc/nystromformer)** (from the University of Wisconsin - Madison) released with the paper [Nyströmformer: A Nyström-Based Algorithm for Approximating Self-Attention](https://arxiv.org/abs/2102.03902) by Yunyang Xiong, Zhanpeng Zeng, Rudrasis Chakraborty, Mingxing Tan, Glenn Fung, Yin Li, Vikas Singh. 1. **[OPT](master/model_doc/opt)** (from Meta AI) released with the paper [OPT: Open Pre-trained Transformer Language Models](https://arxiv.org/abs/2205.01068) by Susan Zhang, Stephen Roller, Naman Goyal, Mikel Artetxe, Moya Chen, Shuohui Chen et al. +1. **[OWL-ViT](model_doc/owlvit)** (from Google AI) released with the paper [Simple Open-Vocabulary Object Detection with Vision Transformers](https://arxiv.org/abs/2205.06230) by Matthias Minderer, Alexey Gritsenko, Austin Stone, Maxim Neumann, Dirk Weissenborn, Alexey Dosovitskiy, Aravindh Mahendran, Anurag Arnab, Mostafa Dehghani, Zhuoran Shen, Xiao Wang, Xiaohua Zhai, Thomas Kipf, and Neil Houlsby. 1. **[Pegasus](model_doc/pegasus)** (from Google) released with the paper [PEGASUS: Pre-training with Extracted Gap-sentences for Abstractive Summarization](https://arxiv.org/abs/1912.08777) by Jingqing Zhang, Yao Zhao, Mohammad Saleh and Peter J. Liu. 1. **[Perceiver IO](model_doc/perceiver)** (from Deepmind) released with the paper [Perceiver IO: A General Architecture for Structured Inputs & Outputs](https://arxiv.org/abs/2107.14795) by Andrew Jaegle, Sebastian Borgeaud, Jean-Baptiste Alayrac, Carl Doersch, Catalin Ionescu, David Ding, Skanda Koppula, Daniel Zoran, Andrew Brock, Evan Shelhamer, Olivier Hénaff, Matthew M. Botvinick, Andrew Zisserman, Oriol Vinyals, João Carreira. 1. **[PhoBERT](model_doc/phobert)** (from VinAI Research) released with the paper [PhoBERT: Pre-trained language models for Vietnamese](https://www.aclweb.org/anthology/2020.findings-emnlp.92/) by Dat Quoc Nguyen and Anh Tuan Nguyen. @@ -263,6 +264,7 @@ Flax), PyTorch, and/or TensorFlow. | OpenAI GPT | ✅ | ✅ | ✅ | ✅ | ❌ | | OpenAI GPT-2 | ✅ | ✅ | ✅ | ✅ | ✅ | | OPT | ❌ | ❌ | ✅ | ✅ | ✅ | +| OWL-ViT | ❌ | ❌ | ✅ | ❌ | ❌ | | Pegasus | ✅ | ✅ | ✅ | ✅ | ✅ | | Perceiver | ✅ | ❌ | ✅ | ❌ | ❌ | | PLBart | ✅ | ❌ | ✅ | ❌ | ❌ | diff --git a/docs/source/en/model_doc/owlvit.mdx b/docs/source/en/model_doc/owlvit.mdx new file mode 100644 index 0000000000000..84747d0a6d260 --- /dev/null +++ b/docs/source/en/model_doc/owlvit.mdx @@ -0,0 +1,101 @@ + + +# OWL-ViT + +## Overview + +The OWL-ViT (short for Vision Transformer for Open-World Localization) was proposed in [Simple Open-Vocabulary Object Detection with Vision Transformers](https://arxiv.org/abs/2205.06230) by Matthias Minderer, Alexey Gritsenko, Austin Stone, Maxim Neumann, Dirk Weissenborn, Alexey Dosovitskiy, Aravindh Mahendran, Anurag Arnab, Mostafa Dehghani, Zhuoran Shen, Xiao Wang, Xiaohua Zhai, Thomas Kipf, and Neil Houlsby. OWL-ViT is an open-vocabulary object detection network trained on a variety of (image, text) pairs. It can be used to query an image with one or multiple text queries to search for and detect target objects described in text. + +The abstract from the paper is the following: + +*Combining simple architectures with large-scale pre-training has led to massive improvements in image classification. For object detection, pre-training and scaling approaches are less well established, especially in the long-tailed and open-vocabulary setting, where training data is relatively scarce. In this paper, we propose a strong recipe for transferring image-text models to open-vocabulary object detection. We use a standard Vision Transformer architecture with minimal modifications, contrastive image-text pre-training, and end-to-end detection fine-tuning. Our analysis of the scaling properties of this setup shows that increasing image-level pre-training and model size yield consistent improvements on the downstream detection task. We provide the adaptation strategies and regularizations needed to attain very strong performance on zero-shot text-conditioned and one-shot image-conditioned object detection. Code and models are available on GitHub.* + +## Usage + +OWL-ViT is a zero-shot text-conditioned object detection model. OWL-ViT uses [CLIP](clip) as its multi-modal backbone, with a ViT-like Transformer to get visual features and a causal language model to get the text features. To use CLIP for detection, OWL-ViT removes the final token pooling layer of the vision model and attaches a lightweight classification and box head to each transformer output token. Open-vocabulary classification is enabled by replacing the fixed classification layer weights with the class-name embeddings obtained from the text model. The authors first train CLIP from scratch and fine-tune it end-to-end with the classification and box heads on standard detection datasets using a bipartite matching loss. One or multiple text queries per image can be used to perform zero-shot text-conditioned object detection. + +[`OwlViTFeatureExtractor`] can be used to resize (or rescale) and normalize images for the model and [`CLIPTokenizer`] is used to encode the text. [`OwlViTProcessor`] wraps [`OwlViTFeatureExtractor`] and [`CLIPTokenizer`] into a single instance to both encode the text and prepare the images. The following example shows how to perform object detection using [`OwlViTProcessor`] and [`OwlViTForObjectDetection`]. + + +```python +>>> import requests +>>> from PIL import Image +>>> import torch + +>>> from transformers import OwlViTProcessor, OwlViTForObjectDetection + +>>> processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch32") +>>> model = OwlViTForObjectDetection.from_pretrained("google/owlvit-base-patch32") + +>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" +>>> image = Image.open(requests.get(url, stream=True).raw) + +>>> inputs = processor(text=[["a photo of a cat", "a photo of a dog"]], images=image, return_tensors="pt") + +>>> outputs = model(**inputs) +>>> logits = outputs["logits"] # Prediction logits of shape [batch_size, num_patches, num_max_text_queries] +>>> boxes = outputs["pred_boxes"] # Object box boundaries of shape [batch_size, num_patches, 4] + +>>> batch_size = boxes.shape[0] +>>> for i in range(batch_size): # Loop over sets of images and text queries +... boxes = outputs["pred_boxes"][i] +... logits = torch.max(outputs["logits"][i], dim=-1) +... scores = torch.sigmoid(logits.values) +... labels = logits.indices +``` + +This model was contributed by [adirik](https://huggingface.co/adirik). The original code can be found [here](https://github.com/google-research/scenic/tree/main/scenic/projects/owl_vit). + +## OwlViTConfig + +[[autodoc]] OwlViTConfig + - from_text_vision_configs + +## OwlViTTextConfig + +[[autodoc]] OwlViTTextConfig + +## OwlViTVisionConfig + +[[autodoc]] OwlViTVisionConfig + +## OwlViTFeatureExtractor + +[[autodoc]] OwlViTFeatureExtractor + - __call__ + +## OwlViTProcessor + +[[autodoc]] OwlViTProcessor + +## OwlViTModel + +[[autodoc]] OwlViTModel + - forward + - get_text_features + - get_image_features + +## OwlViTTextModel + +[[autodoc]] OwlViTTextModel + - forward + +## OwlViTVisionModel + +[[autodoc]] OwlViTVisionModel + - forward + +## OwlViTForObjectDetection + +[[autodoc]] OwlViTForObjectDetection + - forward diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index a3af2dc71254c..700636bb866fb 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -273,6 +273,13 @@ ], "models.openai": ["OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP", "OpenAIGPTConfig", "OpenAIGPTTokenizer"], "models.opt": ["OPTConfig"], + "models.owlvit": [ + "OWLVIT_PRETRAINED_CONFIG_ARCHIVE_MAP", + "OwlViTConfig", + "OwlViTProcessor", + "OwlViTTextConfig", + "OwlViTVisionConfig", + ], "models.pegasus": ["PEGASUS_PRETRAINED_CONFIG_ARCHIVE_MAP", "PegasusConfig", "PegasusTokenizer"], "models.perceiver": ["PERCEIVER_PRETRAINED_CONFIG_ARCHIVE_MAP", "PerceiverConfig", "PerceiverTokenizer"], "models.phobert": ["PhobertTokenizer"], @@ -641,6 +648,7 @@ _import_structure["models.levit"].append("LevitFeatureExtractor") _import_structure["models.maskformer"].append("MaskFormerFeatureExtractor") _import_structure["models.mobilevit"].append("MobileViTFeatureExtractor") + _import_structure["models.owlvit"].append("OwlViTFeatureExtractor") _import_structure["models.perceiver"].append("PerceiverFeatureExtractor") _import_structure["models.poolformer"].append("PoolFormerFeatureExtractor") _import_structure["models.segformer"].append("SegformerFeatureExtractor") @@ -1507,6 +1515,16 @@ "OPTForSequenceClassification", ] ) + _import_structure["models.owlvit"].extend( + [ + "OWLVIT_PRETRAINED_MODEL_ARCHIVE_LIST", + "OwlViTModel", + "OwlViTPreTrainedModel", + "OwlViTTextModel", + "OwlViTVisionModel", + "OwlViTForObjectDetection", + ] + ) _import_structure["models.pegasus"].extend( ["PegasusForCausalLM", "PegasusForConditionalGeneration", "PegasusModel", "PegasusPreTrainedModel"] ) @@ -3012,6 +3030,13 @@ from .models.nystromformer import NYSTROMFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, NystromformerConfig from .models.openai import OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP, OpenAIGPTConfig, OpenAIGPTTokenizer from .models.opt import OPTConfig + from .models.owlvit import ( + OWLVIT_PRETRAINED_CONFIG_ARCHIVE_MAP, + OwlViTConfig, + OwlViTProcessor, + OwlViTTextConfig, + OwlViTVisionConfig, + ) from .models.pegasus import PEGASUS_PRETRAINED_CONFIG_ARCHIVE_MAP, PegasusConfig, PegasusTokenizer from .models.perceiver import PERCEIVER_PRETRAINED_CONFIG_ARCHIVE_MAP, PerceiverConfig, PerceiverTokenizer from .models.phobert import PhobertTokenizer @@ -3328,6 +3353,7 @@ from .models.levit import LevitFeatureExtractor from .models.maskformer import MaskFormerFeatureExtractor from .models.mobilevit import MobileViTFeatureExtractor + from .models.owlvit import OwlViTFeatureExtractor from .models.perceiver import PerceiverFeatureExtractor from .models.poolformer import PoolFormerFeatureExtractor from .models.segformer import SegformerFeatureExtractor @@ -4044,6 +4070,14 @@ OPTModel, OPTPreTrainedModel, ) + from .models.owlvit import ( + OWLVIT_PRETRAINED_MODEL_ARCHIVE_LIST, + OwlViTForObjectDetection, + OwlViTModel, + OwlViTPreTrainedModel, + OwlViTTextModel, + OwlViTVisionModel, + ) from .models.pegasus import ( PegasusForCausalLM, PegasusForConditionalGeneration, diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index b6fe28a33006e..663db063d68e7 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -101,6 +101,7 @@ nystromformer, openai, opt, + owlvit, pegasus, perceiver, phobert, diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 51e8c15c11a1c..b8d5bbf223918 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -98,6 +98,7 @@ ("nystromformer", "NystromformerConfig"), ("openai-gpt", "OpenAIGPTConfig"), ("opt", "OPTConfig"), + ("owlvit", "OwlViTConfig"), ("pegasus", "PegasusConfig"), ("perceiver", "PerceiverConfig"), ("plbart", "PLBartConfig"), @@ -216,6 +217,7 @@ ("nystromformer", "NYSTROMFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("openai-gpt", "OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("opt", "OPT_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("owlvit", "OWLVIT_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("pegasus", "PEGASUS_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("perceiver", "PERCEIVER_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("plbart", "PLBART_PRETRAINED_CONFIG_ARCHIVE_MAP"), @@ -346,6 +348,7 @@ ("nystromformer", "Nyströmformer"), ("openai-gpt", "OpenAI GPT"), ("opt", "OPT"), + ("owlvit", "OWL-ViT"), ("pegasus", "Pegasus"), ("perceiver", "Perceiver"), ("phobert", "PhoBERT"), diff --git a/src/transformers/models/auto/feature_extraction_auto.py b/src/transformers/models/auto/feature_extraction_auto.py index fa39d2ea11179..57def4b27396a 100644 --- a/src/transformers/models/auto/feature_extraction_auto.py +++ b/src/transformers/models/auto/feature_extraction_auto.py @@ -58,6 +58,7 @@ ("maskformer", "MaskFormerFeatureExtractor"), ("mctct", "MCTCTFeatureExtractor"), ("mobilevit", "MobileViTFeatureExtractor"), + ("owlvit", "OwlViTFeatureExtractor"), ("perceiver", "PerceiverFeatureExtractor"), ("poolformer", "PoolFormerFeatureExtractor"), ("regnet", "ConvNextFeatureExtractor"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 1c9cf86fbf854..d443d71ca5f72 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -98,6 +98,7 @@ ("nystromformer", "NystromformerModel"), ("openai-gpt", "OpenAIGPTModel"), ("opt", "OPTModel"), + ("owlvit", "OwlViTModel"), ("pegasus", "PegasusModel"), ("perceiver", "PerceiverModel"), ("plbart", "PLBartModel"), diff --git a/src/transformers/models/auto/processing_auto.py b/src/transformers/models/auto/processing_auto.py index 9bb329852b8f9..d81dd19ea23dd 100644 --- a/src/transformers/models/auto/processing_auto.py +++ b/src/transformers/models/auto/processing_auto.py @@ -43,6 +43,7 @@ ("layoutlmv2", "LayoutLMv2Processor"), ("layoutlmv3", "LayoutLMv3Processor"), ("layoutxlm", "LayoutXLMProcessor"), + ("owlvit", "OwlViTProcessor"), ("sew", "Wav2Vec2Processor"), ("sew-d", "Wav2Vec2Processor"), ("speech_to_text", "Speech2TextProcessor"), diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py index c614d7e5c8f19..7a2dc2941fdd0 100644 --- a/src/transformers/models/auto/tokenization_auto.py +++ b/src/transformers/models/auto/tokenization_auto.py @@ -193,6 +193,7 @@ ), ("openai-gpt", ("OpenAIGPTTokenizer", "OpenAIGPTTokenizerFast" if is_tokenizers_available() else None)), ("opt", ("GPT2Tokenizer", None)), + ("owlvit", ("CLIPTokenizer", "CLIPTokenizerFast" if is_tokenizers_available() else None)), ( "pegasus", ( diff --git a/src/transformers/models/clip/configuration_clip.py b/src/transformers/models/clip/configuration_clip.py index 121fc6e65af5e..3bb22b74a4c77 100644 --- a/src/transformers/models/clip/configuration_clip.py +++ b/src/transformers/models/clip/configuration_clip.py @@ -199,6 +199,7 @@ def __init__( intermediate_size=3072, num_hidden_layers=12, num_attention_heads=12, + num_channels=3, image_size=224, patch_size=32, hidden_act="quick_gelu", @@ -216,6 +217,7 @@ def __init__( self.dropout = dropout self.num_hidden_layers = num_hidden_layers self.num_attention_heads = num_attention_heads + self.num_channels = num_channels self.patch_size = patch_size self.image_size = image_size self.initializer_range = initializer_range diff --git a/src/transformers/models/owlvit/__init__.py b/src/transformers/models/owlvit/__init__.py new file mode 100644 index 0000000000000..7fbf47124cfc0 --- /dev/null +++ b/src/transformers/models/owlvit/__init__.py @@ -0,0 +1,100 @@ +# flake8: noqa +# There's no way to ignore "F401 '...' imported but unused" warnings in this +# module, but to preserve other warnings. So, don't check this module at all. + +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import ( + OptionalDependencyNotAvailable, + _LazyModule, + is_flax_available, + is_tf_available, + is_tokenizers_available, + is_torch_available, + is_vision_available, +) + + +_import_structure = { + "configuration_owlvit": [ + "OWLVIT_PRETRAINED_CONFIG_ARCHIVE_MAP", + "OwlViTConfig", + "OwlViTTextConfig", + "OwlViTVisionConfig", + ], + "processing_owlvit": ["OwlViTProcessor"], +} + + +try: + if not is_vision_available() or not is_torch_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["feature_extraction_owlvit"] = ["OwlViTFeatureExtractor"] + +try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["modeling_owlvit"] = [ + "OWLVIT_PRETRAINED_MODEL_ARCHIVE_LIST", + "OwlViTModel", + "OwlViTPreTrainedModel", + "OwlViTTextModel", + "OwlViTVisionModel", + "OwlViTForObjectDetection", + ] + +if TYPE_CHECKING: + from .configuration_owlvit import ( + OWLVIT_PRETRAINED_CONFIG_ARCHIVE_MAP, + OwlViTConfig, + OwlViTTextConfig, + OwlViTVisionConfig, + ) + from .processing_owlvit import OwlViTProcessor + + try: + if not is_vision_available() or not is_torch_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .feature_extraction_owlvit import OwlViTFeatureExtractor + + try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .modeling_owlvit import ( + OWLVIT_PRETRAINED_MODEL_ARCHIVE_LIST, + OwlViTForObjectDetection, + OwlViTModel, + OwlViTPreTrainedModel, + OwlViTTextModel, + OwlViTVisionModel, + ) + +else: + import sys + + sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) diff --git a/src/transformers/models/owlvit/configuration_owlvit.py b/src/transformers/models/owlvit/configuration_owlvit.py new file mode 100644 index 0000000000000..85ffdbadbeff3 --- /dev/null +++ b/src/transformers/models/owlvit/configuration_owlvit.py @@ -0,0 +1,336 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" OWL-ViT model configuration""" + +import copy +import os +from typing import Dict, Union + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + +OWLVIT_PRETRAINED_CONFIG_ARCHIVE_MAP = { + "google/owlvit-base-patch32": "https://huggingface.co/google/owlvit-base-patch32/resolve/main/config.json", + "google/owlvit-base-patch16": "https://huggingface.co/google/owlvit-base-patch16/resolve/main/config.json", + "google/owlvit-large-patch14": "https://huggingface.co/google/owlvit-large-patch14/resolve/main/config.json", +} + + +class OwlViTTextConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of an [`OwlViTTextModel`]. It is used to instantiate an + OwlViT text encoder according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the OwlViT + [google/owlvit-base-patch32](https://huggingface.co/google/owlvit-base-patch32) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 49408): + Vocabulary size of the OWL-ViT text model. Defines the number of different tokens that can be represented + by the `inputs_ids` passed when calling [`OwlViTTextModel`]. + hidden_size (`int`, *optional*, defaults to 512): + Dimensionality of the encoder layers and the pooler layer. + intermediate_size (`int`, *optional*, defaults to 2048): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 8): + Number of attention heads for each attention layer in the Transformer encoder. + max_position_embeddings (`int`, *optional*, defaults to 16): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + hidden_act (`str` or `function`, *optional*, defaults to `"quick_gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` ``"quick_gelu"` are supported. layer_norm_eps (`float`, *optional*, + defaults to 1e-5): The epsilon used by the layer normalization layers. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + dropout (`float`, *optional*, defaults to 0.0): + The dropout probabilitiy for all fully connected layers in the embeddings, encoder, and pooler. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + initializer_factor (`float`, *optional*, defaults to 1): + A factor for initializing all weight matrices (should be kept to 1, used internally for initialization + testing). + + Example: + + ```python + >>> from transformers import OwlViTTextConfig, OwlViTTextModel + + >>> # Initializing a OwlViTTextModel with google/owlvit-base-patch32 style configuration + >>> configuration = OwlViTTextConfig() + + >>> # Initializing a OwlViTTextConfig from the google/owlvit-base-patch32 style configuration + >>> model = OwlViTTextModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + model_type = "owlvit_text_model" + + def __init__( + self, + vocab_size=49408, + hidden_size=512, + intermediate_size=2048, + num_hidden_layers=12, + num_attention_heads=8, + max_position_embeddings=16, + hidden_act="quick_gelu", + layer_norm_eps=0.00001, + dropout=0.0, + attention_dropout=0.0, + initializer_range=0.02, + initializer_factor=1.0, + pad_token_id=0, + bos_token_id=49406, + eos_token_id=49407, + **kwargs + ): + super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) + + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.max_position_embeddings = max_position_embeddings + self.hidden_act = hidden_act + self.layer_norm_eps = layer_norm_eps + self.dropout = dropout + self.attention_dropout = attention_dropout + self.initializer_range = initializer_range + self.initializer_factor = initializer_factor + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig": + + config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) + + # get the text config dict if we are loading from OwlViTConfig + if config_dict.get("model_type") == "owlvit": + config_dict = config_dict["text_config"] + + if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type: + logger.warning( + f"You are using a model of type {config_dict['model_type']} to instantiate a model of type " + f"{cls.model_type}. This is not supported for all configurations of models and can yield errors." + ) + + return cls.from_dict(config_dict, **kwargs) + + +class OwlViTVisionConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of an [`OwlViTVisionModel`]. It is used to instantiate + an OWL-ViT image encoder according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the OWL-ViT + [google/owlvit-base-patch32](https://huggingface.co/google/owlvit-base-patch32) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + intermediate_size (`int`, *optional*, defaults to 3072): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + image_size (`int`, *optional*, defaults to 768): + The size (resolution) of each image. + patch_size (`int`, *optional*, defaults to 32): + The size (resolution) of each patch. + hidden_act (`str` or `function`, *optional*, defaults to `"quick_gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` ``"quick_gelu"` are supported. layer_norm_eps (`float`, *optional*, + defaults to 1e-5): The epsilon used by the layer normalization layers. + dropout (`float`, *optional*, defaults to 0.0): + The dropout probabilitiy for all fully connected layers in the embeddings, encoder, and pooler. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + initializer_factor (`float``, *optional*, defaults to 1): + A factor for initializing all weight matrices (should be kept to 1, used internally for initialization + testing). + + Example: + + ```python + >>> from transformers import OwlViTVisionConfig, OwlViTVisionModel + + >>> # Initializing a OwlViTVisionModel with google/owlvit-base-patch32 style configuration + >>> configuration = OwlViTVisionConfig() + + >>> # Initializing a OwlViTVisionModel model from the google/owlvit-base-patch32 style configuration + >>> model = OwlViTVisionModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "owlvit_vision_model" + + def __init__( + self, + hidden_size=768, + intermediate_size=3072, + num_hidden_layers=12, + num_attention_heads=12, + image_size=768, + patch_size=32, + hidden_act="quick_gelu", + layer_norm_eps=0.00001, + dropout=0.0, + attention_dropout=0.0, + initializer_range=0.02, + initializer_factor=1.0, + **kwargs + ): + super().__init__(**kwargs) + + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.image_size = image_size + self.patch_size = patch_size + self.hidden_act = hidden_act + self.layer_norm_eps = layer_norm_eps + self.dropout = dropout + self.attention_dropout = attention_dropout + self.initializer_range = initializer_range + self.initializer_factor = initializer_factor + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig": + + config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) + + # get the vision config dict if we are loading from OwlViTConfig + if config_dict.get("model_type") == "owlvit": + config_dict = config_dict["vision_config"] + + if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type: + logger.warning( + f"You are using a model of type {config_dict['model_type']} to instantiate a model of type " + f"{cls.model_type}. This is not supported for all configurations of models and can yield errors." + ) + + return cls.from_dict(config_dict, **kwargs) + + +class OwlViTConfig(PretrainedConfig): + r""" + [`OwlViTConfig`] is the configuration class to store the configuration of an [`OwlViTModel`]. It is used to + instantiate an OWL-ViT model according to the specified arguments, defining the text model and vision model + configs. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + text_config_dict (`dict`, *optional*): + Dictionary of configuration options used to initialize [`OwlViTTextConfig`]. + vision_config_dict (`dict`, *optional*): + Dictionary of configuration options used to initialize [`OwlViTVisionConfig`]. + projection_dim (`int`, *optional*, defaults to 512): + Dimensionality of text and vision projection layers. + logit_scale_init_value (`float`, *optional*, defaults to 2.6592): + The inital value of the *logit_scale* parameter. Default is used as per the original OWL-ViT + implementation. + kwargs (*optional*): + Dictionary of keyword arguments. + """ + + model_type = "owlvit" + is_composition = True + + def __init__( + self, + text_config=None, + vision_config=None, + projection_dim=512, + logit_scale_init_value=2.6592, + return_dict=True, + **kwargs + ): + super().__init__(text_config=text_config, vision_config=vision_config, **kwargs) + + if text_config is None: + text_config = {} + logger.info("text_config_dict is None. Initializing the OwlViTTextConfig with default values.") + + if vision_config is None: + vision_config = {} + logger.info("vision_config_dict is None. initializing the OwlViTVisionConfig with default values.") + + self.text_config = OwlViTTextConfig(**text_config) + self.vision_config = OwlViTVisionConfig(**vision_config) + + self.projection_dim = projection_dim + self.logit_scale_init_value = logit_scale_init_value + self.return_dict = return_dict + self.initializer_factor = 1.0 + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig": + config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) + + if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type: + logger.warning( + f"You are using a model of type {config_dict['model_type']} to instantiate a model of type " + f"{cls.model_type}. This is not supported for all configurations of models and can yield errors." + ) + + return cls.from_dict(config_dict, **kwargs) + + @classmethod + def from_text_vision_configs(cls, text_config: Dict, vision_config: Dict, **kwargs): + r""" + Instantiate a [`OwlViTConfig`] (or a derived class) from owlvit text model configuration and owlvit vision + model configuration. + + Returns: + [`OwlViTConfig`]: An instance of a configuration object + """ + config_dict = {} + config_dict["text_config"] = text_config + config_dict["vision_config"] = vision_config + + return cls.from_dict(config_dict, **kwargs) + + def to_dict(self): + """ + Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`]. + + Returns: + `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance, + """ + output = copy.deepcopy(self.__dict__) + output["text_config"] = self.text_config.to_dict() + output["vision_config"] = self.vision_config.to_dict() + output["model_type"] = self.__class__.model_type + return output diff --git a/src/transformers/models/owlvit/convert_owlvit_original_flax_to_hf.py b/src/transformers/models/owlvit/convert_owlvit_original_flax_to_hf.py new file mode 100644 index 0000000000000..dde57c168adea --- /dev/null +++ b/src/transformers/models/owlvit/convert_owlvit_original_flax_to_hf.py @@ -0,0 +1,407 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert OWL-ViT checkpoints from the original repository. URL: +https://github.com/google-research/scenic/tree/main/scenic/projects/owl_vit""" + +import argparse +import collections + +import torch +import torch.nn as nn + +import jax +import jax.numpy as jnp +from clip.model import CLIP +from flax.training import checkpoints +from huggingface_hub import Repository +from transformers import ( + CLIPTokenizer, + OwlViTConfig, + OwlViTFeatureExtractor, + OwlViTForObjectDetection, + OwlViTModel, + OwlViTProcessor, +) + + +CONFIGS = { + "vit_b32": dict( + embed_dim=512, + image_resolution=768, + context_length=16, + vocab_size=49408, + vision_layers=12, + vision_width=768, + vision_patch_size=32, + transformer_width=512, + transformer_heads=8, + transformer_layers=12, + ), + "vit_b16": dict( + embed_dim=512, + image_resolution=768, + context_length=16, + vocab_size=49408, + vision_layers=12, + vision_width=768, + vision_patch_size=16, + transformer_width=512, + transformer_heads=8, + transformer_layers=12, + ), + "vit_l14": dict( + embed_dim=768, + image_resolution=840, + context_length=16, + vocab_size=49408, + vision_layers=24, + vision_width=1024, + vision_patch_size=14, + transformer_width=768, + transformer_heads=12, + transformer_layers=12, + ), +} + + +def flatten_nested_dict(params, parent_key="", sep="/"): + items = [] + + for k, v in params.items(): + new_key = parent_key + sep + k if parent_key else k + + if isinstance(v, collections.MutableMapping): + items.extend(flatten_nested_dict(v, new_key, sep=sep).items()) + else: + items.append((new_key, v)) + return dict(items) + + +def to_f32(params): + return jax.tree_map(lambda x: x.astype(jnp.float32) if x.dtype == jnp.bfloat16 else x, params) + + +def copy_attn_layer(hf_attn_layer, pt_attn_layer): + q_proj, k_proj, v_proj = pt_attn_layer.in_proj_weight.chunk(3, dim=0) + q_proj_bias, k_proj_bias, v_proj_bias = pt_attn_layer.in_proj_bias.chunk(3, dim=0) + + out_proj_weights = pt_attn_layer.out_proj.weight + out_proj_bias = pt_attn_layer.out_proj.bias + + hf_attn_layer.q_proj.weight.data = q_proj + hf_attn_layer.q_proj.bias.data = q_proj_bias + + hf_attn_layer.k_proj.weight.data = k_proj + hf_attn_layer.k_proj.bias.data = k_proj_bias + + hf_attn_layer.v_proj.weight.data = v_proj + hf_attn_layer.v_proj.bias.data = v_proj_bias + + hf_attn_layer.out_proj.weight = out_proj_weights + hf_attn_layer.out_proj.bias = out_proj_bias + + +def copy_mlp(hf_mlp, pt_mlp): + copy_linear(hf_mlp.fc1, pt_mlp.c_fc) + copy_linear(hf_mlp.fc2, pt_mlp.c_proj) + + +def copy_linear(hf_linear, pt_linear): + hf_linear.weight = pt_linear.weight + hf_linear.bias = pt_linear.bias + + +def copy_layer(hf_layer, pt_layer): + # copy layer norms + copy_linear(hf_layer.layer_norm1, pt_layer.ln_1) + copy_linear(hf_layer.layer_norm2, pt_layer.ln_2) + + # copy MLP + copy_mlp(hf_layer.mlp, pt_layer.mlp) + + # copy attn + copy_attn_layer(hf_layer.self_attn, pt_layer.attn) + + +def copy_layers(hf_layers, pt_layers): + for hf_layer, pt_layer in zip(hf_layers, pt_layers): + copy_layer(hf_layer, pt_layer) + + +def copy_encoder(hf_encoder, pt_model): + # copy embeds + hf_encoder.embeddings.token_embedding.weight = pt_model.token_embedding.weight + hf_encoder.embeddings.position_embedding.weight.data = pt_model.positional_embedding + + # copy layer norm + copy_linear(hf_encoder.final_layer_norm, pt_model.ln_final) + + # copy hidden layers + copy_layers(hf_encoder.encoder.layers, pt_model.transformer.resblocks) + + +def copy_text_model_and_projection(hf_model, pt_model): + # copy projection + hf_model.text_projection.weight.data = pt_model.text_projection.data.T + + # copy text encoder + copy_encoder(hf_model.text_model, pt_model) + + +def copy_vision_model_and_projection(hf_model, pt_model): + # copy projection + hf_model.visual_projection.weight.data = pt_model.visual.proj.data.T + + # copy layer norms + copy_linear(hf_model.vision_model.pre_layernorm, pt_model.visual.ln_pre) + copy_linear(hf_model.vision_model.post_layernorm, pt_model.visual.ln_post) + + # copy embeds + hf_model.vision_model.embeddings.patch_embedding.weight.data = pt_model.visual.conv1.weight.data + hf_model.vision_model.embeddings.class_embedding = pt_model.visual.class_embedding + hf_model.vision_model.embeddings.position_embedding.weight.data = pt_model.visual.positional_embedding.data + + # copy encoder + copy_layers(hf_model.vision_model.encoder.layers, pt_model.visual.transformer.resblocks) + + +def copy_class_merge_token(hf_model, flax_params): + flax_class_token_params = flatten_nested_dict(flax_params["backbone"]["merged_class_token"]) + + weight = torch.from_numpy(flax_class_token_params["scale"]) + bias = torch.from_numpy(flax_class_token_params["bias"]) + hf_model.layer_norm.weight = nn.Parameter(weight) + hf_model.layer_norm.bias = nn.Parameter(bias) + + +def copy_class_box_heads(hf_model, flax_params): + pt_params = hf_model.state_dict() + new_params = {} + + # Rename class prediction head flax params to pytorch HF + flax_class_params = flatten_nested_dict(flax_params["class_head"]) + + for flax_key, v in flax_class_params.items(): + torch_key = flax_key.replace("/", ".") + torch_key = torch_key.replace(".kernel", ".weight") + torch_key = torch_key.replace("Dense_0", "dense0") + torch_key = "class_head." + torch_key + + if "weight" in torch_key and v.ndim == 2: + v = v.T + + new_params[torch_key] = nn.Parameter(torch.from_numpy(v)) + + # Rename box prediction box flax params to pytorch HF + flax_box_params = flatten_nested_dict(flax_params["obj_box_head"]) + + for flax_key, v in flax_box_params.items(): + torch_key = flax_key.replace("/", ".") + torch_key = torch_key.replace(".kernel", ".weight") + torch_key = torch_key.replace("_", "").lower() + torch_key = "box_head." + torch_key + + if "weight" in torch_key and v.ndim == 2: + v = v.T + + new_params[torch_key] = nn.Parameter(torch.from_numpy(v)) + + # Copy flax params to PyTorch params + for name, param in new_params.items(): + if name in pt_params.keys(): + pt_params[name].copy_(param) + + +def copy_flax_attn_params(hf_backbone, flax_attn_params): + for k, v in flax_attn_params.items(): + if k.startswith("transformer"): + torch_key = k.replace("transformer.resblocks", "text_model.encoder.layers") + else: + torch_key = k.replace("visual.transformer.resblocks", "vision_model.encoder.layers") + + torch_key = torch_key.replace("attn", "self_attn") + torch_key = torch_key.replace("key", "k_proj") + torch_key = torch_key.replace("value", "v_proj") + torch_key = torch_key.replace("query", "q_proj") + torch_key = torch_key.replace("out", "out_proj") + + if "bias" in torch_key and v.ndim == 2: + shape = v.shape[0] * v.shape[1] + v = v.reshape(shape) + + if "weight" in torch_key and "out" in torch_key: + shape = (v.shape[0] * v.shape[1], v.shape[2]) + v = v.reshape(shape).T + + if "weight" in torch_key and "out" not in torch_key: + shape = (v.shape[0], v.shape[1] * v.shape[2]) + v = v.reshape(shape).T + + # Copy flax CLIP attn params to HF PyTorch params + v = torch.from_numpy(v) + hf_backbone.state_dict()[torch_key].copy_(v) + + +def _convert_attn_layers(params): + new_params = {} + processed_attn_layers = [] + + for k, v in params.items(): + if "attn." in k: + base = k[: k.rindex("attn.") + 5] + if base in processed_attn_layers: + continue + + processed_attn_layers.append(base) + dim = params[base + "out.weight"].shape[-1] + new_params[base + "out_proj.weight"] = params[base + "out.weight"].reshape(dim, dim).T + new_params[base + "out_proj.bias"] = params[base + "out.bias"] + else: + new_params[k] = v + return new_params + + +def convert_clip_backbone(flax_params, torch_config): + torch_model = CLIP(**torch_config) + torch_model.eval() + torch_clip_params = torch_model.state_dict() + + flax_clip_params = flatten_nested_dict(flax_params["backbone"]["clip"]) + new_torch_params = {} + + for flax_key, v in flax_clip_params.items(): + torch_key = flax_key.replace("/", ".") + torch_key = torch_key.replace("text.token_embedding.embedding", "token_embedding.kernel") + + if ( + torch_key.startswith("text.transformer") + or torch_key.startswith("text.text_projection") + or torch_key.startswith("text.ln_final") + or torch_key.startswith("text.positional_embedding") + ): + torch_key = torch_key[5:] + + torch_key = torch_key.replace("text_projection.kernel", "text_projection") + torch_key = torch_key.replace("visual.proj.kernel", "visual.proj") + torch_key = torch_key.replace(".scale", ".weight") + torch_key = torch_key.replace(".kernel", ".weight") + + if "conv" in torch_key or "downsample.0.weight" in torch_key: + v = v.transpose(3, 2, 0, 1) + + elif "weight" in torch_key and v.ndim == 2 and "embedding" not in torch_key: + # Fully connected layers are transposed, embeddings are not + v = v.T + + new_torch_params[torch_key] = v + + attn_params = _convert_attn_layers(new_torch_params) + new_torch_params.update(attn_params) + attn_params = {} + + # Copy flax CLIP backbone params to PyTorch params + for name, param in new_torch_params.items(): + if name in torch_clip_params.keys(): + + new_param = torch.from_numpy(new_torch_params[name]) + torch_clip_params[name].copy_(new_param) + else: + attn_params[name] = param + + return torch_clip_params, torch_model, attn_params + + +@torch.no_grad() +def convert_owlvit_checkpoint(pt_backbone, flax_params, attn_params, pytorch_dump_folder_path, config_path=None): + """ + Copy/paste/tweak model's weights to transformers design. + """ + repo = Repository(pytorch_dump_folder_path, clone_from=f"google/{pytorch_dump_folder_path}") + repo.git_pull() + + if config_path is not None: + config = OwlViTConfig.from_pretrained(config_path) + else: + config = OwlViTConfig() + + hf_backbone = OwlViTModel(config).eval() + hf_model = OwlViTForObjectDetection(config).eval() + + copy_text_model_and_projection(hf_backbone, pt_backbone) + copy_vision_model_and_projection(hf_backbone, pt_backbone) + hf_backbone.logit_scale = pt_backbone.logit_scale + copy_flax_attn_params(hf_backbone, attn_params) + + hf_model.owlvit = hf_backbone + copy_class_merge_token(hf_model, flax_params) + copy_class_box_heads(hf_model, flax_params) + + # Save HF model + hf_model.save_pretrained(repo.local_dir) + + # Initialize feature extractor + feature_extractor = OwlViTFeatureExtractor( + size=config.vision_config.image_size, crop_size=config.vision_config.image_size + ) + # Initialize tokenizer + tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32", pad_token="!", model_max_length=16) + + # Initialize processor + processor = OwlViTProcessor(feature_extractor=feature_extractor, tokenizer=tokenizer) + feature_extractor.save_pretrained(repo.local_dir) + processor.save_pretrained(repo.local_dir) + + repo.git_add() + repo.git_commit("Upload model and processor") + repo.git_push() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument( + "--owlvit_version", + default=None, + type=str, + required=True, + help="OWL-ViT model name [clip_b16, clip_b32, clip_l14].", + ) + parser.add_argument( + "--owlvit_checkpoint", default=None, type=str, required=True, help="Path to flax model checkpoint." + ) + parser.add_argument("--hf_config", default=None, type=str, required=True, help="Path to HF model config.") + parser.add_argument( + "--pytorch_dump_folder_path", default="hf_model", type=str, help="Path to the output PyTorch model." + ) + args = parser.parse_args() + + # Initialize PyToch clip model + model_name = args.owlvit_version + if model_name == "clip_b16": + torch_config = CONFIGS["vit_b16"] + elif model_name == "clip_b32": + torch_config = CONFIGS["vit_b32"] + elif model_name == "clip_l14": + torch_config = CONFIGS["vit_l14"] + + # Load from checkpoint and convert params to float-32 + variables = checkpoints.restore_checkpoint(args.owlvit_checkpoint, target=None)["optimizer"]["target"] + flax_params = jax.tree_map(lambda x: x.astype(jnp.float32) if x.dtype == jnp.bfloat16 else x, variables) + del variables + + # Convert CLIP backbone + pt_backbone_params, clip_pt, attn_params = convert_clip_backbone(flax_params, torch_config) + + convert_owlvit_checkpoint(clip_pt, flax_params, attn_params, args.pytorch_dump_folder_path, args.hf_config) diff --git a/src/transformers/models/owlvit/feature_extraction_owlvit.py b/src/transformers/models/owlvit/feature_extraction_owlvit.py new file mode 100644 index 0000000000000..8e0a14208551f --- /dev/null +++ b/src/transformers/models/owlvit/feature_extraction_owlvit.py @@ -0,0 +1,210 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Feature extractor class for OwlViT.""" + +from typing import List, Optional, Union + +import numpy as np +from PIL import Image + +from ...feature_extraction_utils import BatchFeature, FeatureExtractionMixin +from ...image_utils import ImageFeatureExtractionMixin, is_torch_tensor +from ...utils import TensorType, is_torch_available, logging + + +if is_torch_available(): + import torch + from torch import nn + +logger = logging.get_logger(__name__) + + +def center_to_corners_format(x): + """ + Converts a PyTorch tensor of bounding boxes of center format (center_x, center_y, width, height) to corners format + (left, top, right, bottom). + """ + x_center, y_center, width, height = x.unbind(-1) + boxes = [(x_center - 0.5 * width), (y_center - 0.5 * height), (x_center + 0.5 * width), (y_center + 0.5 * height)] + return torch.stack(boxes, dim=-1) + + +class OwlViTFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin): + r""" + Constructs an OWL-ViT feature extractor. + + This feature extractor inherits from [`FeatureExtractionMixin`] which contains most of the main methods. Users + should refer to this superclass for more information regarding those methods. + + Args: + do_resize (`bool`, *optional*, defaults to `True`): + Whether to resize the shorter edge of the input to a certain `size`. + size (`int`, *optional*, defaults to 768): + Resize the shorter edge of the input to the given size. Only has an effect if `do_resize` is set to `True`. + resample (`int`, *optional*, defaults to `PIL.Image.BICUBIC`): + An optional resampling filter. This can be one of `PIL.Image.NEAREST`, `PIL.Image.BOX`, + `PIL.Image.BILINEAR`, `PIL.Image.HAMMING`, `PIL.Image.BICUBIC` or `PIL.Image.LANCZOS`. Only has an effect + if `do_resize` is set to `True`. + do_center_crop (`bool`, *optional*, defaults to `True`): + Whether to crop the input at the center. If the input size is smaller than `crop_size` along any edge, the + image is padded with 0's and then center cropped. + crop_size (`int`, *optional*, defaults to 768): + do_normalize (`bool`, *optional*, defaults to `True`): + Whether or not to normalize the input with `image_mean` and `image_std`. Desired output size when applying + center-cropping. Only has an effect if `do_center_crop` is set to `True`. + image_mean (`List[int]`, *optional*, defaults to `[0.48145466, 0.4578275, 0.40821073]`): + The sequence of means for each channel, to be used when normalizing images. + image_std (`List[int]`, *optional*, defaults to `[0.26862954, 0.26130258, 0.27577711]`): + The sequence of standard deviations for each channel, to be used when normalizing images. + """ + + model_input_names = ["pixel_values"] + + def __init__( + self, + do_resize=True, + size=768, + resample=Image.BICUBIC, + crop_size=768, + do_center_crop=True, + do_normalize=True, + image_mean=None, + image_std=None, + **kwargs + ): + super().__init__(**kwargs) + self.size = size + self.resample = resample + self.crop_size = crop_size + self.do_resize = do_resize + self.do_center_crop = do_center_crop + self.do_normalize = do_normalize + self.image_mean = image_mean if image_mean is not None else [0.48145466, 0.4578275, 0.40821073] + self.image_std = image_std if image_std is not None else [0.26862954, 0.26130258, 0.27577711] + + def post_process(self, outputs, target_sizes): + """ + Converts the output of [`OwlViTForObjectDetection`] into the format expected by the COCO api. + + Args: + outputs ([`OwlViTObjectDetectionOutput`]): + Raw outputs of the model. + target_sizes (`torch.Tensor` of shape `(batch_size, 2)`): + Tensor containing the size (h, w) of each image of the batch. For evaluation, this must be the original + image size (before any data augmentation). For visualization, this should be the image size after data + augment, but before padding. + Returns: + `List[Dict]`: A list of dictionaries, each dictionary containing the scores, labels and boxes for an image + in the batch as predicted by the model. + """ + out_logits, out_bbox = outputs.logits, outputs.pred_boxes + + if len(out_logits) != len(target_sizes): + raise ValueError("Make sure that you pass in as many target sizes as the batch dimension of the logits") + if target_sizes.shape[1] != 2: + raise ValueError("Each element of target_sizes must contain the size (h, w) of each image of the batch") + + prob = nn.functional.softmax(out_logits, -1) + scores, labels = prob[..., :-1].max(-1) + + # Convert to [x0, y0, x1, y1] format + boxes = center_to_corners_format(out_bbox) + + # Convert from relative [0, 1] to absolute [0, height] coordinates + img_h, img_w = target_sizes.unbind(1) + scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1) + boxes = boxes * scale_fct[:, None, :] + + results = [{"scores": s, "labels": l, "boxes": b} for s, l, b in zip(scores, labels, boxes)] + + return results + + def __call__( + self, + images: Union[ + Image.Image, np.ndarray, "torch.Tensor", List[Image.Image], List[np.ndarray], List["torch.Tensor"] # noqa + ], + return_tensors: Optional[Union[str, TensorType]] = None, + **kwargs + ) -> BatchFeature: + """ + Main method to prepare for the model one or several image(s). + + + + NumPy arrays and PyTorch tensors are converted to PIL images when resizing, so the most efficient is to pass + PIL images. + + + + Args: + images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`): + The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch + tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W) or (H, W, C), + where C is a number of channels, H and W are image height and width. + + return_tensors (`str` or [`~utils.TensorType`], *optional*, defaults to `'np'`): + If set, will return tensors of a particular framework. Acceptable values are: + + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return NumPy `np.ndarray` objects. + - `'jax'`: Return JAX `jnp.ndarray` objects. + + Returns: + [`BatchFeature`]: A [`BatchFeature`] with the following fields: + + - **pixel_values** -- Pixel values to be fed to a model. + """ + # Input type checking for clearer error + valid_images = False + + # Check that images has a valid type + if isinstance(images, (Image.Image, np.ndarray)) or is_torch_tensor(images): + valid_images = True + elif isinstance(images, (list, tuple)): + if isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]): + valid_images = True + + if not valid_images: + raise ValueError( + "Images must of type `PIL.Image.Image`, `np.ndarray` or `torch.Tensor` (single example), " + "`List[PIL.Image.Image]`, `List[np.ndarray]` or `List[torch.Tensor]` (batch of examples)." + ) + + is_batched = bool( + isinstance(images, (list, tuple)) + and (isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0])) + ) + + if not is_batched: + images = [images] + + # transformations (resizing + center cropping + normalization) + if self.do_resize and self.size is not None and self.resample is not None: + images = [ + self.resize(image=image, size=self.size, resample=self.resample, default_to_square=False) + for image in images + ] + if self.do_center_crop and self.crop_size is not None: + images = [self.center_crop(image, self.crop_size) for image in images] + if self.do_normalize: + images = [self.normalize(image=image, mean=self.image_mean, std=self.image_std) for image in images] + + # return as BatchFeature + data = {"pixel_values": images} + encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors) + + return encoded_inputs diff --git a/src/transformers/models/owlvit/modeling_owlvit.py b/src/transformers/models/owlvit/modeling_owlvit.py new file mode 100644 index 0000000000000..c872b1b28fe26 --- /dev/null +++ b/src/transformers/models/owlvit/modeling_owlvit.py @@ -0,0 +1,1386 @@ +# coding=utf-8 +# Copyright 2022 Google AI and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch OWL-ViT model.""" + + +from dataclasses import dataclass +from typing import Any, Dict, Optional, Tuple, Union + +import numpy as np +import torch +import torch.utils.checkpoint +from torch import nn + +from ...activations import ACT2FN +from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling +from ...modeling_utils import PreTrainedModel +from ...utils import ( + ModelOutput, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) +from .configuration_owlvit import OwlViTConfig, OwlViTTextConfig, OwlViTVisionConfig + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "google/owlvit-base-patch32" + +# See all OwlViT models at https://huggingface.co/models?filter=owlvit +OWLVIT_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "google/owlvit-base-patch32", + "google/owlvit-base-patch16", + "google/owlvit-large-patch14", +] + + +# Copied from transformers.models.bart.modeling_bart._expand_mask +def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): + """ + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. + """ + bsz, src_len = mask.size() + tgt_len = tgt_len if tgt_len is not None else src_len + + expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) + + inverted_mask = 1.0 - expanded_mask + + return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) + + +# Copied from transformers.models.clip.modeling_clip.contrastive_loss with clip->owlvit +def contrastive_loss(logits: torch.Tensor) -> torch.Tensor: + return nn.functional.cross_entropy(logits, torch.arange(len(logits), device=logits.device)) + + +# Copied from transformers.models.clip.modeling_clip.clip_loss with clip->owlvit +def owlvit_loss(similarity: torch.Tensor) -> torch.Tensor: + caption_loss = contrastive_loss(similarity) + image_loss = contrastive_loss(similarity.T) + return (caption_loss + image_loss) / 2.0 + + +@dataclass +class OwlViTOutput(ModelOutput): + """ + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`): + Contrastive loss for image-text similarity. + logits_per_image (`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`): + The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text + similarity scores. + logits_per_text (`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`): + The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image + similarity scores. + text_embeds (`torch.FloatTensor` of shape `(batch_size * num_max_text_queries, output_dim`): + The text embeddings obtained by applying the projection layer to the pooled output of [`OwlViTTextModel`]. + image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim`): + The image embeddings obtained by applying the projection layer to the pooled output of + [`OwlViTVisionModel`]. + text_model_output (Tuple[`BaseModelOutputWithPooling`]): + The output of the [`OwlViTTextModel`]. + vision_model_output (`BaseModelOutputWithPooling`): + The output of the [`OwlViTVisionModel`]. + """ + + loss: Optional[torch.FloatTensor] = None + logits_per_image: torch.FloatTensor = None + logits_per_text: torch.FloatTensor = None + text_embeds: torch.FloatTensor = None + image_embeds: torch.FloatTensor = None + text_model_output: BaseModelOutputWithPooling = None + vision_model_output: BaseModelOutputWithPooling = None + + def to_tuple(self) -> Tuple[Any]: + return tuple( + self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple() + for k in self.keys() + ) + + +@dataclass +class OwlViTObjectDetectionOutput(ModelOutput): + """ + Output type of [`OwlViTForObjectDetection`]. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` are provided)): + Total loss as a linear combination of a negative log-likehood (cross-entropy) for class prediction and a + bounding box loss. The latter is defined as a linear combination of the L1 loss and the generalized + scale-invariant IoU loss. + loss_dict (`Dict`, *optional*): + A dictionary containing the individual losses. Useful for logging. + logits (`torch.FloatTensor` of shape `(batch_size, num_patches, num_queries)`): + Classification logits (including no-object) for all queries. + pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_patches, 4)`): + Normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height). These + values are normalized in [0, 1], relative to the size of each individual image in the batch (disregarding + possible padding). You can use [`~OwlViTFeatureExtractor.post_process`] to retrieve the unnormalized + bounding boxes. + text_embeds (`torch.FloatTensor` of shape `(batch_size, num_max_text_queries, output_dim`): + The text embeddings obtained by applying the projection layer to the pooled output of [`OwlViTTextModel`]. + image_embeds (`torch.FloatTensor` of shape `(batch_size, patch_size, patch_size, output_dim`): + Pooled output of [`OwlViTVisionModel`]. OWL-ViT represents images as a set of image patches and computes + image embeddings for each patch. + class_embeds (`torch.FloatTensor` of shape `(batch_size, num_patches, hidden_size)`): + Class embeddings of all image patches. OWL-ViT represents images as a set of image patches where the total + number of patches is (image_size / patch_size)**2. + text_model_last_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`)): + Last hidden states extracted from the [`OwlViTTextModel`]. + vision_model_last_hidden_states (`torch.FloatTensor` of shape `(batch_size, num_patches + 1, hidden_size)`)): + Last hidden states extracted from the [`OwlViTVisionModel`]. OWL-ViT represents images as a set of image + patches where the total number of patches is (image_size / patch_size)**2. + """ + + loss: Optional[torch.FloatTensor] = None + loss_dict: Optional[Dict] = None + logits: torch.FloatTensor = None + pred_boxes: torch.FloatTensor = None + text_embeds: torch.FloatTensor = None + image_embeds: torch.FloatTensor = None + class_embeds: torch.FloatTensor = None + text_model_last_hidden_states: Optional[torch.FloatTensor] = None + vision_model_last_hidden_states: Optional[torch.FloatTensor] = None + + +class OwlViTVisionEmbeddings(nn.Module): + def __init__(self, config: OwlViTVisionConfig): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.class_embedding = nn.Parameter(torch.randn(config.hidden_size)) + + self.patch_embedding = nn.Conv2d( + in_channels=config.num_channels, + out_channels=self.embed_dim, + kernel_size=config.patch_size, + stride=config.patch_size, + bias=False, + ) + + self.num_patches = (config.image_size // config.patch_size) ** 2 + self.num_positions = self.num_patches + 1 + self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) + self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1))) + + def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor: + batch_size = pixel_values.shape[0] + patch_embeds = self.patch_embedding(pixel_values) # shape = [batch_size, num_channels, height, width] + patch_embeds = patch_embeds.flatten(2).transpose(1, 2) + + class_embeds = self.class_embedding.expand(batch_size, 1, -1) + embeddings = torch.cat([class_embeds, patch_embeds], dim=1) + embeddings = embeddings + self.position_embedding(self.position_ids) + + return embeddings + + +class OwlViTTextEmbeddings(nn.Module): + def __init__(self, config: OwlViTTextConfig): + super().__init__() + self.token_embedding = nn.Embedding(config.vocab_size, config.hidden_size) + self.position_embedding = nn.Embedding(config.max_position_embeddings, config.hidden_size) + + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1))) + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + ) -> torch.Tensor: + + seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2] + + if position_ids is None: + position_ids = self.position_ids[:, :seq_length] + + if inputs_embeds is None: + inputs_embeds = self.token_embedding(input_ids) + + position_embeddings = self.position_embedding(position_ids) + embeddings = inputs_embeds + position_embeddings + + return embeddings + + +class OwlViTAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads})." + ) + self.scale = self.head_dim**-0.5 + self.dropout = config.attention_dropout + + self.k_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.v_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.q_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.out_proj = nn.Linear(self.embed_dim, self.embed_dim) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + causal_attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + bsz, tgt_len, embed_dim = hidden_states.size() + + # get query proj + query_states = self.q_proj(hidden_states) * self.scale + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + proj_shape = (bsz * self.num_heads, -1, self.head_dim) + query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) + key_states = key_states.view(*proj_shape) + value_states = value_states.view(*proj_shape) + + src_len = key_states.size(1) + attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) + + if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): + raise ValueError( + f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" + f" {attn_weights.size()}" + ) + + # apply the causal_attention_mask first + if causal_attention_mask is not None: + if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is" + f" {causal_attention_mask.size()}" + ) + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + causal_attention_mask + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, tgt_len, src_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if output_attentions: + # this operation is a bit akward, but it's required to + # make sure that attn_weights keeps its gradient. + # In order to do so, attn_weights have to reshaped + # twice and have to be reused in the following + attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) + else: + attn_weights_reshaped = None + + attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + + attn_output = torch.bmm(attn_probs, value_states) + + if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.reshape(bsz, tgt_len, embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights_reshaped + + +# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->OwlViT +class OwlViTMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.activation_fn = ACT2FN[config.hidden_act] + self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) + self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +# Copied from transformers.models.clip.modeling_clip.CLIPEncoderLayer with CLIP->OwlViT +class OwlViTEncoderLayer(nn.Module): + def __init__(self, config: OwlViTConfig): + super().__init__() + self.embed_dim = config.hidden_size + self.self_attn = OwlViTAttention(config) + self.layer_norm1 = nn.LayerNorm(self.embed_dim) + self.mlp = OwlViTMLP(config) + self.layer_norm2 = nn.LayerNorm(self.embed_dim) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + causal_attention_mask: torch.Tensor, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.FloatTensor]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + `(config.encoder_attention_heads,)`. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + """ + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + hidden_states, attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + causal_attention_mask=causal_attention_mask, + output_attentions=output_attentions, + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + +class OwlViTPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = OwlViTConfig + base_model_prefix = "owlvit" + supports_gradient_checkpointing = True + _keys_to_ignore_on_load_missing = [r"position_ids"] + + def _init_weights(self, module): + """Initialize the weights""" + factor = self.config.initializer_factor + if isinstance(module, OwlViTTextEmbeddings): + module.token_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02) + module.position_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02) + elif isinstance(module, OwlViTVisionEmbeddings): + factor = self.config.initializer_factor + nn.init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor) + nn.init.normal_(module.patch_embedding.weight, std=module.config.initializer_range * factor) + nn.init.normal_(module.position_embedding.weight, std=module.config.initializer_range * factor) + elif isinstance(module, OwlViTAttention): + factor = self.config.initializer_factor + in_proj_std = (module.embed_dim**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor + out_proj_std = (module.embed_dim**-0.5) * factor + nn.init.normal_(module.q_proj.weight, std=in_proj_std) + nn.init.normal_(module.k_proj.weight, std=in_proj_std) + nn.init.normal_(module.v_proj.weight, std=in_proj_std) + nn.init.normal_(module.out_proj.weight, std=out_proj_std) + elif isinstance(module, OwlViTMLP): + factor = self.config.initializer_factor + in_proj_std = ( + (module.config.hidden_size**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor + ) + fc_std = (2 * module.config.hidden_size) ** -0.5 * factor + nn.init.normal_(module.fc1.weight, std=fc_std) + nn.init.normal_(module.fc2.weight, std=in_proj_std) + elif isinstance(module, OwlViTModel): + nn.init.normal_( + module.text_projection.weight, + std=module.text_embed_dim**-0.5 * self.config.initializer_factor, + ) + nn.init.normal_( + module.visual_projection.weight, + std=module.vision_embed_dim**-0.5 * self.config.initializer_factor, + ) + if isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, OwlViTEncoder): + module.gradient_checkpointing = value + + +OWLVIT_START_DOCSTRING = r""" + Parameters: + This model is a PyTorch [torch.nn.Module](https: + //pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it + as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and + behavior. + config ([`OwlViTConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +OWLVIT_TEXT_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size * num_max_text_queries, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Indices can be obtained using [`CLIPTokenizer`]. See + [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are input + IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, num_max_text_queries, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + [What are attention masks?](../glossary#attention-mask) + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + +OWLVIT_VISION_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + +OWLVIT_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Indices can be obtained using [`CLIPTokenizer`]. See + [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are input + IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + [What are attention masks?](../glossary#attention-mask) + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. + return_loss (`bool`, *optional*): + Whether or not to return the contrastive loss. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + +OWLVIT_OBJECT_DETECTION_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. + input_ids (`torch.LongTensor` of shape `(batch_size * num_max_text_queries, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Indices can be obtained using [`CLIPTokenizer`]. See + [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are input + IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, num_max_text_queries, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + [What are attention masks?](../glossary#attention-mask) +""" + + +class OwlViTEncoder(nn.Module): + """ + Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a + [`OwlViTEncoderLayer`]. + + Args: + config: OwlViTConfig + """ + + def __init__(self, config: OwlViTConfig): + super().__init__() + self.layers = nn.ModuleList([OwlViTEncoderLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + def forward( + self, + inputs_embeds, + attention_mask: Optional[torch.Tensor] = None, + causal_attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutput]: + r""" + Args: + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`). + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + [What are attention masks?](../glossary#attention-mask) + causal_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Causal mask for the text model. Mask values selected in `[0, 1]`: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + [What are attention masks?](../glossary#attention-mask) + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + hidden_states = inputs_embeds + for encoder_layer in self.layers: + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs, output_attentions) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(encoder_layer), + hidden_states, + attention_mask, + causal_attention_mask, + ) + else: + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + causal_attention_mask, + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions + ) + + +class OwlViTTextTransformer(nn.Module): + def __init__(self, config: OwlViTTextConfig): + super().__init__() + self.config = config + embed_dim = config.hidden_size + self.embeddings = OwlViTTextEmbeddings(config) + self.encoder = OwlViTEncoder(config) + self.final_layer_norm = nn.LayerNorm(embed_dim) + + @add_start_docstrings_to_model_forward(OWLVIT_TEXT_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=OwlViTTextConfig) + def forward( + self, + input_ids: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + r""" + Returns: + + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids) + + num_samples, seq_len = input_shape # num_samples = batch_size * num_max_text_queries + # OWLVIT's text model uses causal mask, prepare it here. + # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324 + causal_attention_mask = self._build_causal_attention_mask(num_samples, seq_len).to(hidden_states.device) + # expand attention_mask + if attention_mask is not None: + # [num_samples, seq_len] -> [num_samples, 1, tgt_seq_len, src_seq_len] + attention_mask = _expand_mask(attention_mask, hidden_states.dtype) + + encoder_outputs = self.encoder( + inputs_embeds=hidden_states, + attention_mask=attention_mask, + causal_attention_mask=causal_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + last_hidden_state = encoder_outputs[0] + last_hidden_state = self.final_layer_norm(last_hidden_state) + + # take features from the end of tokens embedding (end of token is the highest number in each sequence) + pooled_output = last_hidden_state[torch.arange(last_hidden_state.shape[0]), input_ids.argmax(dim=-1)] + + if not return_dict: + return (last_hidden_state, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPooling( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + def _build_causal_attention_mask(self, bsz, seq_len): + # lazily create causal attention mask, with full attention between the vision tokens + # pytorch uses additive attention mask; fill with -inf + mask = torch.empty(bsz, seq_len, seq_len) + mask.fill_(torch.tensor(float("-inf"))) + mask.triu_(1) # zero out the lower diagonal + mask = mask.unsqueeze(1) # expand mask + return mask + + +class OwlViTTextModel(OwlViTPreTrainedModel): + config_class = OwlViTTextConfig + + def __init__(self, config: OwlViTTextConfig): + super().__init__(config) + self.text_model = OwlViTTextTransformer(config) + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> nn.Module: + return self.text_model.embeddings.token_embedding + + def set_input_embeddings(self, value): + self.text_model.embeddings.token_embedding = value + + @add_start_docstrings_to_model_forward(OWLVIT_TEXT_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=OwlViTTextConfig) + def forward( + self, + input_ids: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + r""" + Returns: + + Examples: + ```python + >>> from transformers import OwlViTProcessor, OwlViTTextModel + + >>> model = OwlViTTextModel.from_pretrained("google/owlvit-base-patch32") + >>> processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch32") + >>> inputs = processor( + ... text=[["a photo of a cat", "a photo of a dog"], ["photo of a astranaut"]], return_tensors="pt" + ... ) + >>> outputs = model(**inputs) + >>> last_hidden_state = outputs.last_hidden_state + >>> pooled_output = outputs.pooler_output # pooled (EOS token) states + ```""" + + # Get embeddings for all text queries in all batch samples + return self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + +class OwlViTVisionTransformer(nn.Module): + def __init__(self, config: OwlViTVisionConfig): + super().__init__() + self.config = config + + self.embeddings = OwlViTVisionEmbeddings(config) + self.pre_layernorm = nn.LayerNorm(config.hidden_size) + self.encoder = OwlViTEncoder(config) + self.post_layernorm = nn.LayerNorm(config.hidden_size) + + @add_start_docstrings_to_model_forward(OWLVIT_VISION_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=OwlViTVisionConfig) + def forward( + self, + pixel_values: torch.FloatTensor, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + use_hidden_state: Optional[bool] = True, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + r""" + Returns: + + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + hidden_states = self.embeddings(pixel_values) + hidden_states = self.pre_layernorm(hidden_states) + encoder_outputs = self.encoder( + inputs_embeds=hidden_states, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + last_hidden_state = encoder_outputs[0] + pooled_output = last_hidden_state[:, 0, :] + + if use_hidden_state: + pooled_output = self.post_layernorm(last_hidden_state) + else: + pooled_output = self.post_layernorm(pooled_output) + + if not return_dict: + return (last_hidden_state, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPooling( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +class OwlViTVisionModel(OwlViTPreTrainedModel): + config_class = OwlViTVisionConfig + main_input_name = "pixel_values" + + def __init__(self, config: OwlViTVisionConfig): + super().__init__(config) + self.vision_model = OwlViTVisionTransformer(config) + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> nn.Module: + return self.vision_model.embeddings.patch_embedding + + @add_start_docstrings_to_model_forward(OWLVIT_VISION_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=OwlViTVisionConfig) + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + r""" + Returns: + + Examples: + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import OwlViTProcessor, OwlViTVisionModel + + >>> model = OwlViTVisionModel.from_pretrained("google/owlvit-base-patch32") + >>> processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch32") + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = processor(images=image, return_tensors="pt") + + >>> outputs = model(**inputs) + >>> last_hidden_state = outputs.last_hidden_state + >>> pooled_output = outputs.pooler_output # pooled CLS states + ```""" + return self.vision_model( + pixel_values=pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + +@add_start_docstrings(OWLVIT_START_DOCSTRING) +class OwlViTModel(OwlViTPreTrainedModel): + config_class = OwlViTConfig + + def __init__(self, config: OwlViTConfig): + super().__init__(config) + + if not isinstance(config.text_config, OwlViTTextConfig): + raise ValueError( + "config.text_config is expected to be of type OwlViTTextConfig but is of type" + f" {type(config.text_config)}." + ) + + if not isinstance(config.vision_config, OwlViTVisionConfig): + raise ValueError( + "config.vision_config is expected to be of type OwlViTVisionConfig but is of type" + f" {type(config.vision_config)}." + ) + + text_config = config.text_config + vision_config = config.vision_config + + self.projection_dim = config.projection_dim + self.text_embed_dim = text_config.hidden_size + self.vision_embed_dim = vision_config.hidden_size + + self.text_model = OwlViTTextTransformer(text_config) + self.vision_model = OwlViTVisionTransformer(vision_config) + + self.visual_projection = nn.Linear(self.vision_embed_dim, self.projection_dim, bias=False) + self.text_projection = nn.Linear(self.text_embed_dim, self.projection_dim, bias=False) + self.logit_scale = nn.Parameter(torch.ones([]) * config.logit_scale_init_value) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(OWLVIT_TEXT_INPUTS_DOCSTRING) + def get_text_features( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> torch.FloatTensor: + r""" + Returns: + text_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The text embeddings obtained by + applying the projection layer to the pooled output of [`OwlViTTextModel`]. + + Examples: + ```python + >>> from transformers import OwlViTProcessor, OwlViTModel + + >>> model = OwlViTModel.from_pretrained("google/owlvit-base-patch32") + >>> processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch32") + >>> inputs = processor( + ... text=[["a photo of a cat", "a photo of a dog"], ["photo of a astranaut"]], return_tensors="pt" + ... ) + >>> text_features = model.get_text_features(**inputs) + ```""" + # Use OWL-ViT model's config for some fields (if specified) instead of those of vision & text components. + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # Get embeddings for all text queries in all batch samples + text_output = self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = text_output[1] + text_features = self.text_projection(pooled_output) + return text_features + + @add_start_docstrings_to_model_forward(OWLVIT_VISION_INPUTS_DOCSTRING) + def get_image_features( + self, + pixel_values: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + return_projected: Optional[bool] = True, + ) -> torch.FloatTensor: + r""" + Returns: + image_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The image embeddings obtained by + applying the projection layer to the pooled output of [`OwlViTVisionModel`]. + + Examples: + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import OwlViTProcessor, OwlViTModel + + >>> model = OwlViTModel.from_pretrained("google/owlvit-base-patch32") + >>> processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch32") + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + >>> inputs = processor(images=image, return_tensors="pt") + >>> image_features = model.get_image_features(**inputs) + ```""" + # Use OWL-ViT model's config for some fields (if specified) instead of those of vision & text components. + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + vision_outputs = self.vision_model( + pixel_values=pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = vision_outputs[1] # pooled_output + + # Return projected output + if return_projected: + image_features = self.visual_projection(pooled_output) + else: + image_features = pooled_output + return image_features + + @add_start_docstrings_to_model_forward(OWLVIT_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=OwlViTOutput, config_class=OwlViTConfig) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + return_loss: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, OwlViTOutput]: + r""" + Returns: + + Examples: + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import OwlViTProcessor, OwlViTModel + + >>> model = OwlViTModel.from_pretrained("google/owlvit-base-patch32") + >>> processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch32") + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + >>> inputs = processor(text=[["a photo of a cat", "a photo of a dog"]], images=image, return_tensors="pt") + >>> outputs = model(**inputs) + >>> logits_per_image = outputs.logits_per_image # this is the image-text similarity score + >>> probs = logits_per_image.softmax(dim=1) # we can take the softmax to get the label probabilities + ```""" + # Use OWL-ViT model's config for some fields (if specified) instead of those of vision & text components. + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + vision_outputs = self.vision_model( + pixel_values=pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + use_hidden_state=False, + ) + + # Get embeddings for all text queries in all batch samples + text_outputs = self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + text_embeds = text_outputs[1] + text_embeds = self.text_projection(text_embeds) + image_embeds = vision_outputs[1] + image_embeds = self.visual_projection(image_embeds) + + # normalized features + image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True) + text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True) + + # cosine similarity as logits + logit_scale = self.logit_scale.exp() + logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale + logits_per_image = logits_per_text.T + + loss = None + if return_loss: + loss = owlvit_loss(logits_per_text) + + if not return_dict: + output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs) + return ((loss,) + output) if loss is not None else output + + return OwlViTOutput( + loss=loss, + logits_per_image=logits_per_image, + logits_per_text=logits_per_text, + text_embeds=text_embeds, + image_embeds=image_embeds, + text_model_output=text_outputs, + vision_model_output=vision_outputs, + ) + + +class OwlViTBoxPredictionHead(nn.Module): + def __init__(self, config: OwlViTConfig): + super().__init__() + + width = config.vision_config.hidden_size + self.dense0 = nn.Linear(width, width) + self.dense1 = nn.Linear(width, width) + self.gelu = nn.GELU() + self.dense2 = nn.Linear(width, 4) + + def forward(self, image_features: torch.Tensor) -> torch.FloatTensor: + output = self.dense0(image_features) + output = self.gelu(output) + output = self.dense1(output) + output = self.gelu(output) + output = self.dense2(output) + return output + + +class OwlViTClassPredictionHead(nn.Module): + def __init__(self, config: OwlViTConfig): + super().__init__() + + out_dim = config.text_config.hidden_size + query_dim = config.vision_config.hidden_size + + self.dense0 = nn.Linear(query_dim, out_dim) + self.logit_shift = nn.Linear(query_dim, 1) + self.logit_scale = nn.Linear(query_dim, 1) + self.elu = nn.ELU() + + def forward( + self, + image_embeds: torch.FloatTensor, + query_embeds: torch.FloatTensor, + query_mask: torch.Tensor, + ) -> Tuple[torch.FloatTensor]: + + image_class_embeds = self.dense0(image_embeds) + + # Normalize image and text features + image_class_embeds /= torch.linalg.norm(image_class_embeds, dim=-1, keepdim=True) + 1e-6 + query_embeds /= torch.linalg.norm(query_embeds, dim=-1, keepdim=True) + 1e-6 + + # Get class predictions + pred_logits = torch.einsum("...pd,...qd->...pq", image_class_embeds, query_embeds) + + # Apply a learnable shift and scale to logits + logit_shift = self.logit_shift(image_embeds) + logit_scale = self.logit_scale(image_embeds) + logit_scale = self.elu(logit_scale) + 1 + pred_logits = (pred_logits + logit_shift) * logit_scale + + if query_mask is not None: + if query_mask.ndim > 1: + query_mask = torch.unsqueeze(query_mask, dim=-2) + + pred_logits = pred_logits.to(torch.float64) + pred_logits = torch.where(query_mask == 0, -1e6, pred_logits) + pred_logits = pred_logits.to(torch.float32) + + return (pred_logits, image_class_embeds) + + +class OwlViTForObjectDetection(OwlViTPreTrainedModel): + config_class = OwlViTConfig + main_input_name = "pixel_values" + + def __init__(self, config: OwlViTConfig): + super().__init__(config) + + self.owlvit = OwlViTModel(config) + self.class_head = OwlViTClassPredictionHead(config) + self.box_head = OwlViTBoxPredictionHead(config) + + self.layer_norm = nn.LayerNorm(config.vision_config.hidden_size) + self.sigmoid = nn.Sigmoid() + + def normalize_grid_corner_coordinates(self, feature_map: torch.FloatTensor): + # Computes normalized xy corner coordinates from feature_map. + if not feature_map.ndim == 4: + raise ValueError("Expected input shape is [batch_size, num_channels, height, width]") + + height, width = feature_map.shape[1:3] + + box_coordinates = np.stack(np.meshgrid(np.arange(1, width + 1), np.arange(1, height + 1)), axis=-1).astype( + np.float32 + ) + box_coordinates /= np.array([width, height], np.float32) + + # Flatten (h, w, 2) -> (h*w, 2) + box_coordinates = box_coordinates.reshape( + box_coordinates.shape[0] * box_coordinates.shape[1], box_coordinates.shape[2] + ) + box_coordinates = torch.from_numpy(box_coordinates) + + return box_coordinates + + def compute_box_bias(self, feature_map: torch.FloatTensor) -> torch.FloatTensor: + # The box center is biased to its position on the feature grid + box_coordinates = self.normalize_grid_corner_coordinates(feature_map) + box_coordinates = torch.clip(box_coordinates, 0.0, 1.0) + + # Unnormalize xy + box_coord_bias = torch.log(box_coordinates + 1e-4) - torch.log1p(-box_coordinates + 1e-4) + + # The box size is biased to the patch size + box_size = torch.full_like(box_coord_bias, 1.0 / feature_map.shape[-2]) + box_size_bias = torch.log(box_size + 1e-4) - torch.log1p(-box_size + 1e-4) + + # Compute box bias + box_bias = torch.cat([box_coord_bias, box_size_bias], dim=-1) + return box_bias + + def box_predictor( + self, + image_feats: torch.FloatTensor, + feature_map: torch.FloatTensor, + ) -> torch.FloatTensor: + """ + Args: + image_feats: + Features extracted from the image, returned by the `image_text_embedder` method. + feature_map: + A spatial re-arrangement of image_features, also returned by the `image_text_embedder` method. + Returns: + pred_boxes: + List of predicted boxes (cxcywh normalized to 0, 1) nested within a dictionary. + """ + # Bounding box detection head [batch_size, num_boxes, 4]. + pred_boxes = self.box_head(image_feats) + + # Compute the location of each token on the grid and use it to compute a bias for the bbox prediction + pred_boxes += self.compute_box_bias(feature_map) + pred_boxes = self.sigmoid(pred_boxes) + return pred_boxes + + def class_predictor( + self, + image_feats: torch.FloatTensor, + query_embeds: torch.FloatTensor, + query_mask: torch.Tensor, + ) -> Tuple[torch.FloatTensor, torch.FloatTensor]: + """ + Args: + image_feats: + Features extracted from the `image_text_embedder`. + query_embeds: + Text query embeddings. + query_mask: + Must be provided with query_embeddings. A mask indicating which query embeddings are valid. + """ + (pred_logits, image_class_embeds) = self.class_head(image_feats, query_embeds, query_mask) + + return (pred_logits, image_class_embeds) + + def image_text_embedder( + self, + pixel_values: torch.FloatTensor, + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + output_attentions: Optional[bool] = None, + ) -> torch.FloatTensor: + # Encode text + text_embeds = self.owlvit.get_text_features( + input_ids=input_ids, attention_mask=attention_mask, output_attentions=output_attentions + ) + + # Encode image + image_embeds = self.owlvit.get_image_features( + pixel_values, return_projected=False, output_attentions=output_attentions + ) + + # Resize class token + new_size = tuple(np.array(image_embeds.shape) - np.array((0, 1, 0))) + class_token_out = torch.broadcast_to(image_embeds[:, :1, :], new_size) + + # Merge image embedding with class tokens + image_embeds = image_embeds[:, 1:, :] * class_token_out + image_embeds = self.layer_norm(image_embeds) + + # Resize to [batch_size, num_patches, num_patches, hidden_size] + new_size = ( + image_embeds.shape[0], + int(np.sqrt(image_embeds.shape[1])), + int(np.sqrt(image_embeds.shape[1])), + image_embeds.shape[-1], + ) + image_embeds = image_embeds.reshape(new_size) + + return (image_embeds, text_embeds) + + @add_start_docstrings_to_model_forward(OWLVIT_OBJECT_DETECTION_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=OwlViTObjectDetectionOutput, config_class=OwlViTConfig) + def forward( + self, + pixel_values: torch.FloatTensor, + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> OwlViTObjectDetectionOutput: + r""" + Returns: + + Examples: + ```python + >>> import requests + >>> from PIL import Image + >>> import torch + >>> from transformers import OwlViTProcessor, OwlViTForObjectDetection + + >>> model = OwlViTModel.from_pretrained("google/owlvit-base-patch32") + >>> processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch32") + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = processor(text=[["a photo of a cat", "a photo of a dog"]], images=image, return_tensors="pt") + >>> outputs = model(**inputs) + >>> logits = outputs["logits"] # Prediction logits of shape [batch_size, num_patches, num_max_text_queries] + >>> boxes = outputs["pred_boxes"] # Object box boundaries of shape # [batch_size, num_patches, 4] + + >>> batch_size = boxes.shape[0] + >>> for i in range(batch_size): # Loop over sets of images and text queries + ... boxes = outputs["pred_boxes"][i] + ... logits = torch.max(outputs["logits"][i], dim=-1) + ... scores = torch.sigmoid(logits.values) + ... labels = logits.indices + ```""" + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + # Return last hidden states of text and vision transformers + text_model_last_hidden_states = None + vision_model_last_hidden_states = None + + if output_hidden_states: + outputs = self.owlvit( + pixel_values=pixel_values, + input_ids=input_ids, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + + text_model_last_hidden_states = outputs[-2][0] + vision_model_last_hidden_states = outputs[-1][0] + + # Embed images and text queries + feature_map, query_embeds = self.image_text_embedder( + pixel_values=pixel_values, + input_ids=input_ids, + attention_mask=attention_mask, + output_attentions=output_attentions, + ) + + batch_size, height, width, hidden_dim = feature_map.shape + image_feats = torch.reshape(feature_map, (batch_size, height * width, hidden_dim)) + + # Reshape from [batch_size * max_text_queries, hidden_dim] -> [batch_size, max_text_queries, hidden_dim] + max_text_queries = input_ids.shape[0] // batch_size + query_embeds = query_embeds.reshape(batch_size, max_text_queries, query_embeds.shape[-1]) + + # If first token is 0, then this is a padded query [batch_size, num_queries]. + input_ids = input_ids.reshape(batch_size, max_text_queries, input_ids.shape[-1]) + query_mask = input_ids[..., 0] > 0 + + # Predict object classes [batch_size, num_patches, num_queries+1] + (pred_logits, class_embeds) = self.class_predictor(image_feats, query_embeds, query_mask) + + # Predict object boxes + pred_boxes = self.box_predictor(image_feats, feature_map) + + if not return_dict: + return ( + pred_logits, + pred_boxes, + query_embeds, + feature_map, + class_embeds, + text_model_last_hidden_states, + vision_model_last_hidden_states, + ) + + return OwlViTObjectDetectionOutput( + image_embeds=feature_map, + text_embeds=query_embeds, + pred_boxes=pred_boxes, + logits=pred_logits, + class_embeds=class_embeds, + text_model_last_hidden_states=text_model_last_hidden_states, + vision_model_last_hidden_states=vision_model_last_hidden_states, + ) diff --git a/src/transformers/models/owlvit/processing_owlvit.py b/src/transformers/models/owlvit/processing_owlvit.py new file mode 100644 index 0000000000000..8dc04055bbbe4 --- /dev/null +++ b/src/transformers/models/owlvit/processing_owlvit.py @@ -0,0 +1,154 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Image/Text processor class for OWL-ViT +""" +from typing import List + +import numpy as np + +from transformers import is_flax_available, is_tf_available, is_torch_available + +from ...processing_utils import ProcessorMixin +from ...tokenization_utils_base import BatchEncoding + + +class OwlViTProcessor(ProcessorMixin): + r""" + Constructs an OWL-ViT processor which wraps [`OwlViTFeatureExtractor`] and [`CLIPTokenizer`]/[`CLIPTokenizerFast`] + into a single processor that interits both the feature extractor and tokenizer functionalities. See the + [`~OwlViTProcessor.__call__`] and [`~OwlViTProcessor.decode`] for more information. + + Args: + feature_extractor ([`OwlViTFeatureExtractor`]): + The feature extractor is a required input. + tokenizer ([`CLIPTokenizer`, `CLIPTokenizerFast`]): + The tokenizer is a required input. + """ + feature_extractor_class = "OwlViTFeatureExtractor" + tokenizer_class = ("CLIPTokenizer", "CLIPTokenizerFast") + + def __init__(self, feature_extractor, tokenizer): + super().__init__(feature_extractor, tokenizer) + + def __call__(self, text=None, images=None, padding="max_length", return_tensors="np", **kwargs): + """ + Main method to prepare for the model one or several text(s) and image(s). This method forwards the `text` and + `kwargs` arguments to CLIPTokenizerFast's [`~CLIPTokenizerFast.__call__`] if `text` is not `None` to encode: + the text. To prepare the image(s), this method forwards the `images` and `kwrags` arguments to + CLIPFeatureExtractor's [`~CLIPFeatureExtractor.__call__`] if `images` is not `None`. Please refer to the + doctsring of the above two methods for more information. + + Args: + text (`str`, `List[str]`, `List[List[str]]`): + The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings + (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set + `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). + images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, + `List[torch.Tensor]`): + The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch + tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a + number of channels, H and W are image height and width. + return_tensors (`str` or [`~utils.TensorType`], *optional*): + If set, will return tensors of a particular framework. Acceptable values are: + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return NumPy `np.ndarray` objects. + - `'jax'`: Return JAX `jnp.ndarray` objects. + Returns: + [`BatchEncoding`]: A [`BatchEncoding`] with the following fields: + - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`. + - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when + `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not + `None`). + - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. + """ + + if text is None and images is None: + raise ValueError("You have to specify at least one text or image. Both cannot be none.") + + if text is not None: + if isinstance(text, str) or (isinstance(text, List) and not isinstance(text[0], List)): + encodings = [self.tokenizer(text, padding=padding, return_tensors=return_tensors, **kwargs)] + + elif isinstance(text, List) and isinstance(text[0], List): + encodings = [] + + # Maximum number of queries across batch + max_num_queries = max([len(t) for t in text]) + + # Pad all batch samples to max number of text queries + for t in text: + if len(t) != max_num_queries: + t = t + [" "] * (max_num_queries - len(t)) + + encoding = self.tokenizer(t, padding=padding, return_tensors=return_tensors, **kwargs) + encodings.append(encoding) + else: + raise TypeError("Input text should be a string, a list of strings or a nested list of strings") + + if return_tensors == "np": + input_ids = np.concatenate([encoding["input_ids"] for encoding in encodings], axis=0) + attention_mask = np.concatenate([encoding["attention_mask"] for encoding in encodings], axis=0) + + elif return_tensors == "jax" and is_flax_available(): + import jax.numpy as jnp + + input_ids = jnp.concatenate([encoding["input_ids"] for encoding in encodings], axis=0) + attention_mask = jnp.concatenate([encoding["attention_mask"] for encoding in encodings], axis=0) + + elif return_tensors == "pt" and is_torch_available(): + import torch + + input_ids = torch.cat([encoding["input_ids"] for encoding in encodings], dim=0) + attention_mask = torch.cat([encoding["attention_mask"] for encoding in encodings], dim=0) + + elif return_tensors == "tf" and is_tf_available(): + import tensorflow as tf + + input_ids = tf.stack([encoding["input_ids"] for encoding in encodings], axis=0) + attention_mask = tf.stack([encoding["attention_mask"] for encoding in encodings], axis=0) + + else: + raise ValueError("Target return tensor type could not be returned") + + encoding = BatchEncoding() + encoding["input_ids"] = input_ids + encoding["attention_mask"] = attention_mask + + if images is not None: + image_features = self.feature_extractor(images, return_tensors=return_tensors, **kwargs) + + if text is not None and images is not None: + encoding["pixel_values"] = image_features.pixel_values + return encoding + elif text is not None: + return encoding + else: + return BatchEncoding(data=dict(**image_features), tensor_type=return_tensors) + + def batch_decode(self, *args, **kwargs): + """ + This method forwards all its arguments to CLIPTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please + refer to the docstring of this method for more information. + """ + return self.tokenizer.batch_decode(*args, **kwargs) + + def decode(self, *args, **kwargs): + """ + This method forwards all its arguments to CLIPTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to + the docstring of this method for more information. + """ + return self.tokenizer.decode(*args, **kwargs) diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index b1e97d3acf6e5..043f4cdb9dd92 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -3459,6 +3459,44 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) +OWLVIT_PRETRAINED_MODEL_ARCHIVE_LIST = None + + +class OwlViTForObjectDetection(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class OwlViTModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class OwlViTPreTrainedModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class OwlViTTextModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class OwlViTVisionModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class PegasusForCausalLM(metaclass=DummyObject): _backends = ["torch"] diff --git a/src/transformers/utils/dummy_vision_objects.py b/src/transformers/utils/dummy_vision_objects.py index 9c3946a914894..e5d2bced9e041 100644 --- a/src/transformers/utils/dummy_vision_objects.py +++ b/src/transformers/utils/dummy_vision_objects.py @@ -122,6 +122,13 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["vision"]) +class OwlViTFeatureExtractor(metaclass=DummyObject): + _backends = ["vision"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["vision"]) + + class PerceiverFeatureExtractor(metaclass=DummyObject): _backends = ["vision"] diff --git a/tests/models/owlvit/__init__.py b/tests/models/owlvit/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/models/owlvit/test_feature_extraction_owlvit.py b/tests/models/owlvit/test_feature_extraction_owlvit.py new file mode 100644 index 0000000000000..c9198280d792e --- /dev/null +++ b/tests/models/owlvit/test_feature_extraction_owlvit.py @@ -0,0 +1,201 @@ +# coding=utf-8 +# Copyright 2022 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import unittest + +import numpy as np + +from transformers.testing_utils import require_torch, require_vision +from transformers.utils import is_torch_available, is_vision_available + +from ...test_feature_extraction_common import FeatureExtractionSavingTestMixin, prepare_image_inputs + + +if is_torch_available(): + import torch + +if is_vision_available(): + from PIL import Image + + from transformers import OwlViTFeatureExtractor + + +class OwlViTFeatureExtractionTester(unittest.TestCase): + def __init__( + self, + parent, + batch_size=7, + num_channels=3, + image_size=18, + min_resolution=30, + max_resolution=400, + do_resize=True, + size=20, + do_center_crop=True, + crop_size=18, + do_normalize=True, + image_mean=[0.48145466, 0.4578275, 0.40821073], + image_std=[0.26862954, 0.26130258, 0.27577711], + do_convert_rgb=True, + ): + self.parent = parent + self.batch_size = batch_size + self.num_channels = num_channels + self.image_size = image_size + self.min_resolution = min_resolution + self.max_resolution = max_resolution + self.do_resize = do_resize + self.size = size + self.do_center_crop = do_center_crop + self.crop_size = crop_size + self.do_normalize = do_normalize + self.image_mean = image_mean + self.image_std = image_std + self.do_convert_rgb = do_convert_rgb + + def prepare_feat_extract_dict(self): + return { + "do_resize": self.do_resize, + "size": self.size, + "do_center_crop": self.do_center_crop, + "crop_size": self.crop_size, + "do_normalize": self.do_normalize, + "image_mean": self.image_mean, + "image_std": self.image_std, + "do_convert_rgb": self.do_convert_rgb, + } + + +@require_torch +@require_vision +class OwlViTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase): + + feature_extraction_class = OwlViTFeatureExtractor if is_vision_available() else None + + def setUp(self): + self.feature_extract_tester = OwlViTFeatureExtractionTester(self) + + @property + def feat_extract_dict(self): + return self.feature_extract_tester.prepare_feat_extract_dict() + + def test_feat_extract_properties(self): + feature_extractor = self.feature_extraction_class(**self.feat_extract_dict) + self.assertTrue(hasattr(feature_extractor, "do_resize")) + self.assertTrue(hasattr(feature_extractor, "size")) + self.assertTrue(hasattr(feature_extractor, "do_center_crop")) + self.assertTrue(hasattr(feature_extractor, "center_crop")) + self.assertTrue(hasattr(feature_extractor, "do_normalize")) + self.assertTrue(hasattr(feature_extractor, "image_mean")) + self.assertTrue(hasattr(feature_extractor, "image_std")) + self.assertTrue(hasattr(feature_extractor, "do_convert_rgb")) + + def test_call_pil(self): + # Initialize feature_extractor + feature_extractor = self.feature_extraction_class(**self.feat_extract_dict) + # create random PIL images + image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False) + + for image in image_inputs: + self.assertIsInstance(image, Image.Image) + + # Test not batched input + encoded_images = feature_extractor(image_inputs[0], return_tensors="pt").pixel_values + self.assertEqual( + encoded_images.shape, + ( + 1, + self.feature_extract_tester.num_channels, + self.feature_extract_tester.crop_size, + self.feature_extract_tester.crop_size, + ), + ) + + # Test batched + encoded_images = feature_extractor(image_inputs, return_tensors="pt").pixel_values + self.assertEqual( + encoded_images.shape, + ( + self.feature_extract_tester.batch_size, + self.feature_extract_tester.num_channels, + self.feature_extract_tester.crop_size, + self.feature_extract_tester.crop_size, + ), + ) + + def test_call_numpy(self): + # Initialize feature_extractor + feature_extractor = self.feature_extraction_class(**self.feat_extract_dict) + # create random numpy tensors + image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False, numpify=True) + for image in image_inputs: + self.assertIsInstance(image, np.ndarray) + + # Test not batched input + encoded_images = feature_extractor(image_inputs[0], return_tensors="pt").pixel_values + self.assertEqual( + encoded_images.shape, + ( + 1, + self.feature_extract_tester.num_channels, + self.feature_extract_tester.crop_size, + self.feature_extract_tester.crop_size, + ), + ) + + # Test batched + encoded_images = feature_extractor(image_inputs, return_tensors="pt").pixel_values + self.assertEqual( + encoded_images.shape, + ( + self.feature_extract_tester.batch_size, + self.feature_extract_tester.num_channels, + self.feature_extract_tester.crop_size, + self.feature_extract_tester.crop_size, + ), + ) + + def test_call_pytorch(self): + # Initialize feature_extractor + feature_extractor = self.feature_extraction_class(**self.feat_extract_dict) + # create random PyTorch tensors + image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False, torchify=True) + for image in image_inputs: + self.assertIsInstance(image, torch.Tensor) + + # Test not batched input + encoded_images = feature_extractor(image_inputs[0], return_tensors="pt").pixel_values + self.assertEqual( + encoded_images.shape, + ( + 1, + self.feature_extract_tester.num_channels, + self.feature_extract_tester.crop_size, + self.feature_extract_tester.crop_size, + ), + ) + + # Test batched + encoded_images = feature_extractor(image_inputs, return_tensors="pt").pixel_values + self.assertEqual( + encoded_images.shape, + ( + self.feature_extract_tester.batch_size, + self.feature_extract_tester.num_channels, + self.feature_extract_tester.crop_size, + self.feature_extract_tester.crop_size, + ), + ) diff --git a/tests/models/owlvit/test_modeling_owlvit.py b/tests/models/owlvit/test_modeling_owlvit.py new file mode 100644 index 0000000000000..b45d37dd81802 --- /dev/null +++ b/tests/models/owlvit/test_modeling_owlvit.py @@ -0,0 +1,815 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Testing suite for the PyTorch OwlViT model. """ + + +import inspect +import os +import tempfile +import unittest +from typing import Dict, List, Tuple + +import numpy as np + +import requests +from transformers import OwlViTConfig, OwlViTTextConfig, OwlViTVisionConfig +from transformers.testing_utils import require_torch, require_vision, slow, torch_device +from transformers.utils import is_torch_available, is_vision_available + +from ...test_configuration_common import ConfigTester +from ...test_modeling_common import ( + ModelTesterMixin, + _config_zero_init, + floats_tensor, + ids_tensor, + random_attention_mask, +) + + +if is_torch_available(): + import torch + from torch import nn + + from transformers import OwlViTForObjectDetection, OwlViTModel, OwlViTTextModel, OwlViTVisionModel + from transformers.models.owlvit.modeling_owlvit import OWLVIT_PRETRAINED_MODEL_ARCHIVE_LIST + + +if is_vision_available(): + from PIL import Image + + from transformers import OwlViTProcessor + + +class OwlViTVisionModelTester: + def __init__( + self, + parent, + batch_size=12, + image_size=32, + patch_size=2, + num_channels=3, + is_training=True, + hidden_size=32, + num_hidden_layers=5, + num_attention_heads=4, + intermediate_size=37, + dropout=0.1, + attention_dropout=0.1, + initializer_range=0.02, + scope=None, + ): + self.parent = parent + self.batch_size = batch_size + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.is_training = is_training + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.dropout = dropout + self.attention_dropout = attention_dropout + self.initializer_range = initializer_range + self.scope = scope + + # in ViT, the seq length equals the number of patches + 1 (we add 1 for the [CLS] token) + num_patches = (image_size // patch_size) ** 2 + self.seq_length = num_patches + 1 + + def prepare_config_and_inputs(self): + pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size]) + config = self.get_config() + + return config, pixel_values + + def get_config(self): + return OwlViTVisionConfig( + image_size=self.image_size, + patch_size=self.patch_size, + num_channels=self.num_channels, + hidden_size=self.hidden_size, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads, + intermediate_size=self.intermediate_size, + dropout=self.dropout, + attention_dropout=self.attention_dropout, + initializer_range=self.initializer_range, + ) + + def create_and_check_model(self, config, pixel_values): + model = OwlViTVisionModel(config=config) + model.to(torch_device) + model.eval() + + pixel_values = pixel_values.to(torch.float32) + + with torch.no_grad(): + result = model(pixel_values) + # expected sequence length = num_patches + 1 (we add 1 for the [CLS] token) + num_patches = (self.image_size // self.patch_size) ** 2 + self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, num_patches + 1, self.hidden_size)) + self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, num_patches + 1, self.hidden_size)) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + config, pixel_values = config_and_inputs + inputs_dict = {"pixel_values": pixel_values} + return config, inputs_dict + + +@require_torch +class OwlViTVisionModelTest(ModelTesterMixin, unittest.TestCase): + """ + Here we also overwrite some of the tests of test_modeling_common.py, as OWLVIT does not use input_ids, inputs_embeds, + attention_mask and seq_length. + """ + + all_model_classes = (OwlViTVisionModel,) if is_torch_available() else () + fx_compatible = False + test_pruning = False + test_resize_embeddings = False + test_head_masking = False + + def setUp(self): + self.model_tester = OwlViTVisionModelTester(self) + self.config_tester = ConfigTester( + self, config_class=OwlViTVisionConfig, has_text_modality=False, hidden_size=37 + ) + + def test_config(self): + self.config_tester.run_common_tests() + + @unittest.skip(reason="OWLVIT does not use inputs_embeds") + def test_inputs_embeds(self): + pass + + def test_model_common_attributes(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + self.assertIsInstance(model.get_input_embeddings(), (nn.Module)) + x = model.get_output_embeddings() + self.assertTrue(x is None or isinstance(x, nn.Linear)) + + def test_forward_signature(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + signature = inspect.signature(model.forward) + # signature.parameters is an OrderedDict => so arg_names order is deterministic + arg_names = [*signature.parameters.keys()] + + expected_arg_names = ["pixel_values"] + self.assertListEqual(arg_names[:1], expected_arg_names) + + def test_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model(*config_and_inputs) + + @unittest.skip(reason="OWL-ViT does not support training yet") + def test_training(self): + pass + + @unittest.skip(reason="OWL-ViT does not support training yet") + def test_training_gradient_checkpointing(self): + pass + + @unittest.skip(reason="OwlViTVisionModel has no base class and is not available in MODEL_MAPPING") + def test_save_load_fast_init_from_base(self): + pass + + @unittest.skip(reason="OwlViTVisionModel has no base class and is not available in MODEL_MAPPING") + def test_save_load_fast_init_to_base(self): + pass + + @slow + def test_model_from_pretrained(self): + for model_name in OWLVIT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: + model = OwlViTVisionModel.from_pretrained(model_name) + self.assertIsNotNone(model) + + +class OwlViTTextModelTester: + def __init__( + self, + parent, + batch_size=12, + num_queries=4, + seq_length=16, + is_training=True, + use_input_mask=True, + use_labels=True, + vocab_size=99, + hidden_size=64, + num_hidden_layers=12, + num_attention_heads=4, + intermediate_size=37, + dropout=0.1, + attention_dropout=0.1, + max_position_embeddings=16, + initializer_range=0.02, + scope=None, + ): + self.parent = parent + self.batch_size = batch_size + self.num_queries = num_queries + self.seq_length = seq_length + self.is_training = is_training + self.use_input_mask = use_input_mask + self.use_labels = use_labels + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.dropout = dropout + self.attention_dropout = attention_dropout + self.max_position_embeddings = max_position_embeddings + self.initializer_range = initializer_range + self.scope = scope + + def prepare_config_and_inputs(self): + input_ids = ids_tensor([self.batch_size * self.num_queries, self.seq_length], self.vocab_size) + input_mask = None + + if self.use_input_mask: + input_mask = random_attention_mask([self.batch_size * self.num_queries, self.seq_length]) + + if input_mask is not None: + num_text, seq_length = input_mask.shape + + rnd_start_indices = np.random.randint(1, seq_length - 1, size=(num_text,)) + for idx, start_index in enumerate(rnd_start_indices): + input_mask[idx, :start_index] = 1 + input_mask[idx, start_index:] = 0 + + config = self.get_config() + + return config, input_ids, input_mask + + def get_config(self): + return OwlViTTextConfig( + vocab_size=self.vocab_size, + hidden_size=self.hidden_size, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads, + intermediate_size=self.intermediate_size, + dropout=self.dropout, + attention_dropout=self.attention_dropout, + max_position_embeddings=self.max_position_embeddings, + initializer_range=self.initializer_range, + ) + + def create_and_check_model(self, config, input_ids, input_mask): + model = OwlViTTextModel(config=config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + result = model(input_ids=input_ids, attention_mask=input_mask) + + self.parent.assertEqual( + result.last_hidden_state.shape, (self.batch_size * self.num_queries, self.seq_length, self.hidden_size) + ) + self.parent.assertEqual(result.pooler_output.shape, (self.batch_size * self.num_queries, self.hidden_size)) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + config, input_ids, input_mask = config_and_inputs + inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask} + return config, inputs_dict + + +@require_torch +class OwlViTTextModelTest(ModelTesterMixin, unittest.TestCase): + + all_model_classes = (OwlViTTextModel,) if is_torch_available() else () + fx_compatible = False + test_pruning = False + test_head_masking = False + + def setUp(self): + self.model_tester = OwlViTTextModelTester(self) + self.config_tester = ConfigTester(self, config_class=OwlViTTextConfig, hidden_size=37) + + def test_config(self): + self.config_tester.run_common_tests() + + def test_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model(*config_and_inputs) + + @unittest.skip(reason="OWL-ViT does not support training yet") + def test_training(self): + pass + + @unittest.skip(reason="OWL-ViT does not support training yet") + def test_training_gradient_checkpointing(self): + pass + + @unittest.skip(reason="OWLVIT does not use inputs_embeds") + def test_inputs_embeds(self): + pass + + @unittest.skip(reason="OwlViTTextModel has no base class and is not available in MODEL_MAPPING") + def test_save_load_fast_init_from_base(self): + pass + + @unittest.skip(reason="OwlViTTextModel has no base class and is not available in MODEL_MAPPING") + def test_save_load_fast_init_to_base(self): + pass + + @slow + def test_model_from_pretrained(self): + for model_name in OWLVIT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: + model = OwlViTTextModel.from_pretrained(model_name) + self.assertIsNotNone(model) + + +class OwlViTModelTester: + def __init__(self, parent, is_training=True): + self.parent = parent + self.text_model_tester = OwlViTTextModelTester(parent) + self.vision_model_tester = OwlViTVisionModelTester(parent) + self.is_training = is_training + self.text_config = self.text_model_tester.get_config().to_dict() + self.vision_config = self.vision_model_tester.get_config().to_dict() + + def prepare_config_and_inputs(self): + text_config, input_ids, attention_mask = self.text_model_tester.prepare_config_and_inputs() + vision_config, pixel_values = self.vision_model_tester.prepare_config_and_inputs() + config = self.get_config() + return config, input_ids, attention_mask, pixel_values + + def get_config(self): + return OwlViTConfig.from_text_vision_configs(self.text_config, self.vision_config, projection_dim=64) + + def create_and_check_model(self, config, input_ids, attention_mask, pixel_values): + model = OwlViTModel(config).to(torch_device).eval() + + with torch.no_grad(): + result = model( + input_ids=input_ids, + pixel_values=pixel_values, + attention_mask=attention_mask, + ) + + image_logits_size = ( + self.vision_model_tester.batch_size, + self.text_model_tester.batch_size * self.text_model_tester.num_queries, + ) + text_logits_size = ( + self.text_model_tester.batch_size * self.text_model_tester.num_queries, + self.vision_model_tester.batch_size, + ) + self.parent.assertEqual(result.logits_per_image.shape, image_logits_size) + self.parent.assertEqual(result.logits_per_text.shape, text_logits_size) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + config, input_ids, attention_mask, pixel_values = config_and_inputs + inputs_dict = { + "pixel_values": pixel_values, + "input_ids": input_ids, + "attention_mask": attention_mask, + "return_loss": False, + } + return config, inputs_dict + + +@require_torch +class OwlViTModelTest(ModelTesterMixin, unittest.TestCase): + all_model_classes = (OwlViTModel,) if is_torch_available() else () + fx_compatible = False + test_head_masking = False + test_pruning = False + test_resize_embeddings = False + test_attention_outputs = False + + def setUp(self): + self.model_tester = OwlViTModelTester(self) + + def test_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model(*config_and_inputs) + + @unittest.skip(reason="Hidden_states is tested in individual model tests") + def test_hidden_states_output(self): + pass + + @unittest.skip(reason="Inputs_embeds is tested in individual model tests") + def test_inputs_embeds(self): + pass + + @unittest.skip(reason="Retain_grad is tested in individual model tests") + def test_retain_grad_hidden_states_attentions(self): + pass + + @unittest.skip(reason="OwlViTModel does not have input/output embeddings") + def test_model_common_attributes(self): + pass + + # override as the `logit_scale` parameter initilization is different for OWLVIT + def test_initialization(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + configs_no_init = _config_zero_init(config) + for model_class in self.all_model_classes: + model = model_class(config=configs_no_init) + for name, param in model.named_parameters(): + if param.requires_grad: + # check if `logit_scale` is initilized as per the original implementation + if name == "logit_scale": + self.assertAlmostEqual( + param.data.item(), + np.log(1 / 0.07), + delta=1e-3, + msg=f"Parameter {name} of model {model_class} seems not properly initialized", + ) + else: + self.assertIn( + ((param.data.mean() * 1e9).round() / 1e9).item(), + [0.0, 1.0], + msg=f"Parameter {name} of model {model_class} seems not properly initialized", + ) + + def _create_and_check_torchscript(self, config, inputs_dict): + if not self.test_torchscript: + return + + configs_no_init = _config_zero_init(config) # To be sure we have no Nan + configs_no_init.torchscript = True + configs_no_init.return_dict = False + for model_class in self.all_model_classes: + model = model_class(config=configs_no_init) + model.to(torch_device) + model.eval() + + try: + input_ids = inputs_dict["input_ids"] + pixel_values = inputs_dict["pixel_values"] # OWLVIT needs pixel_values + traced_model = torch.jit.trace(model, (input_ids, pixel_values)) + except RuntimeError: + self.fail("Couldn't trace module.") + + with tempfile.TemporaryDirectory() as tmp_dir_name: + pt_file_name = os.path.join(tmp_dir_name, "traced_model.pt") + + try: + torch.jit.save(traced_model, pt_file_name) + except Exception: + self.fail("Couldn't save module.") + + try: + loaded_model = torch.jit.load(pt_file_name) + except Exception: + self.fail("Couldn't load module.") + + model.to(torch_device) + model.eval() + + loaded_model.to(torch_device) + loaded_model.eval() + + model_state_dict = model.state_dict() + loaded_model_state_dict = loaded_model.state_dict() + + self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys())) + + models_equal = True + for layer_name, p1 in model_state_dict.items(): + p2 = loaded_model_state_dict[layer_name] + if p1.data.ne(p2.data).sum() > 0: + models_equal = False + + self.assertTrue(models_equal) + + def test_load_vision_text_config(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + # Save OwlViTConfig and check if we can load OwlViTVisionConfig from it + with tempfile.TemporaryDirectory() as tmp_dir_name: + config.save_pretrained(tmp_dir_name) + vision_config = OwlViTVisionConfig.from_pretrained(tmp_dir_name) + self.assertDictEqual(config.vision_config.to_dict(), vision_config.to_dict()) + + # Save OwlViTConfig and check if we can load OwlViTTextConfig from it + with tempfile.TemporaryDirectory() as tmp_dir_name: + config.save_pretrained(tmp_dir_name) + text_config = OwlViTTextConfig.from_pretrained(tmp_dir_name) + self.assertDictEqual(config.text_config.to_dict(), text_config.to_dict()) + + @slow + def test_model_from_pretrained(self): + for model_name in OWLVIT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: + model = OwlViTModel.from_pretrained(model_name) + self.assertIsNotNone(model) + + +class OwlViTForObjectDetectionTester: + def __init__(self, parent, is_training=True): + self.parent = parent + self.text_model_tester = OwlViTTextModelTester(parent) + self.vision_model_tester = OwlViTVisionModelTester(parent) + self.is_training = is_training + self.text_config = self.text_model_tester.get_config().to_dict() + self.vision_config = self.vision_model_tester.get_config().to_dict() + + def prepare_config_and_inputs(self): + text_config, input_ids, attention_mask = self.text_model_tester.prepare_config_and_inputs() + vision_config, pixel_values = self.vision_model_tester.prepare_config_and_inputs() + config = self.get_config() + return config, pixel_values, input_ids, attention_mask + + def get_config(self): + return OwlViTConfig.from_text_vision_configs(self.text_config, self.vision_config, projection_dim=64) + + def create_and_check_model(self, config, pixel_values, input_ids, attention_mask): + model = OwlViTForObjectDetection(config).to(torch_device).eval() + with torch.no_grad(): + result = model( + pixel_values=pixel_values, + input_ids=input_ids, + attention_mask=attention_mask, + return_dict=True, + ) + + pred_boxes_size = ( + self.vision_model_tester.batch_size, + (self.vision_model_tester.image_size // self.vision_model_tester.patch_size) ** 2, + 4, + ) + pred_logits_size = ( + self.vision_model_tester.batch_size, + (self.vision_model_tester.image_size // self.vision_model_tester.patch_size) ** 2, + 4, + ) + pred_class_embeds_size = ( + self.vision_model_tester.batch_size, + (self.vision_model_tester.image_size // self.vision_model_tester.patch_size) ** 2, + self.text_model_tester.hidden_size, + ) + self.parent.assertEqual(result.pred_boxes.shape, pred_boxes_size) + self.parent.assertEqual(result.logits.shape, pred_logits_size) + self.parent.assertEqual(result.class_embeds.shape, pred_class_embeds_size) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + config, pixel_values, input_ids, attention_mask = config_and_inputs + inputs_dict = { + "pixel_values": pixel_values, + "input_ids": input_ids, + "attention_mask": attention_mask, + } + return config, inputs_dict + + +@require_torch +class OwlViTForObjectDetectionTest(ModelTesterMixin, unittest.TestCase): + all_model_classes = (OwlViTForObjectDetection,) if is_torch_available() else () + fx_compatible = False + test_head_masking = False + test_pruning = False + test_resize_embeddings = False + test_attention_outputs = False + + def setUp(self): + self.model_tester = OwlViTForObjectDetectionTester(self) + + def test_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model(*config_and_inputs) + + @unittest.skip(reason="Hidden_states is tested in individual model tests") + def test_hidden_states_output(self): + pass + + @unittest.skip(reason="Inputs_embeds is tested in individual model tests") + def test_inputs_embeds(self): + pass + + @unittest.skip(reason="Retain_grad is tested in individual model tests") + def test_retain_grad_hidden_states_attentions(self): + pass + + @unittest.skip(reason="OwlViTModel does not have input/output embeddings") + def test_model_common_attributes(self): + pass + + @unittest.skip(reason="Test_initialization is tested in individual model tests") + def test_initialization(self): + pass + + @unittest.skip(reason="Test_forward_signature is tested in individual model tests") + def test_forward_signature(self): + pass + + @unittest.skip(reason="Test_save_load_fast_init_from_base is tested in individual model tests") + def test_save_load_fast_init_from_base(self): + pass + + @unittest.skip(reason="OWL-ViT does not support training yet") + def test_training(self): + pass + + @unittest.skip(reason="OWL-ViT does not support training yet") + def test_training_gradient_checkpointing(self): + pass + + def _create_and_check_torchscript(self, config, inputs_dict): + if not self.test_torchscript: + return + + configs_no_init = _config_zero_init(config) # To be sure we have no Nan + configs_no_init.torchscript = True + configs_no_init.return_dict = False + for model_class in self.all_model_classes: + model = model_class(config=configs_no_init) + model.to(torch_device) + model.eval() + + try: + input_ids = inputs_dict["input_ids"] + pixel_values = inputs_dict["pixel_values"] # OWLVIT needs pixel_values + traced_model = torch.jit.trace(model, (input_ids, pixel_values)) + except RuntimeError: + self.fail("Couldn't trace module.") + + with tempfile.TemporaryDirectory() as tmp_dir_name: + pt_file_name = os.path.join(tmp_dir_name, "traced_model.pt") + + try: + torch.jit.save(traced_model, pt_file_name) + except Exception: + self.fail("Couldn't save module.") + + try: + loaded_model = torch.jit.load(pt_file_name) + except Exception: + self.fail("Couldn't load module.") + + model.to(torch_device) + model.eval() + + loaded_model.to(torch_device) + loaded_model.eval() + + model_state_dict = model.state_dict() + loaded_model_state_dict = loaded_model.state_dict() + + self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys())) + + models_equal = True + for layer_name, p1 in model_state_dict.items(): + p2 = loaded_model_state_dict[layer_name] + if p1.data.ne(p2.data).sum() > 0: + models_equal = False + + self.assertTrue(models_equal) + + def test_model_outputs_equivalence(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + def set_nan_tensor_to_zero(t): + t[t != t] = 0 + return t + + def check_equivalence(model, tuple_inputs, dict_inputs, additional_kwargs={}): + with torch.no_grad(): + tuple_output = model(**tuple_inputs, return_dict=False, **additional_kwargs) + dict_output = model(**dict_inputs, return_dict=True, **additional_kwargs).to_tuple() + + def recursive_check(tuple_object, dict_object): + if isinstance(tuple_object, (List, Tuple)): + for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object): + recursive_check(tuple_iterable_value, dict_iterable_value) + elif isinstance(tuple_object, Dict): + for tuple_iterable_value, dict_iterable_value in zip( + tuple_object.values(), dict_object.values() + ): + recursive_check(tuple_iterable_value, dict_iterable_value) + elif tuple_object is None: + return + else: + self.assertTrue( + torch.allclose( + set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5 + ), + msg=( + "Tuple and dict output are not equal. Difference:" + f" {torch.max(torch.abs(tuple_object - dict_object))}. Tuple has `nan`:" + f" {torch.isnan(tuple_object).any()} and `inf`: {torch.isinf(tuple_object)}. Dict has" + f" `nan`: {torch.isnan(dict_object).any()} and `inf`: {torch.isinf(dict_object)}." + ), + ) + + recursive_check(tuple_output, dict_output) + + for model_class in self.all_model_classes: + model = model_class(config) + model.to(torch_device) + model.eval() + + tuple_inputs = self._prepare_for_class(inputs_dict, model_class) + dict_inputs = self._prepare_for_class(inputs_dict, model_class) + check_equivalence(model, tuple_inputs, dict_inputs) + + @slow + def test_model_from_pretrained(self): + for model_name in OWLVIT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: + model = OwlViTForObjectDetection.from_pretrained(model_name) + self.assertIsNotNone(model) + + +# We will verify our results on an image of cute cats +def prepare_img(): + url = "http://images.cocodataset.org/val2017/000000039769.jpg" + im = Image.open(requests.get(url, stream=True).raw) + return im + + +@require_vision +@require_torch +class OwlViTModelIntegrationTest(unittest.TestCase): + @slow + def test_inference(self): + model_name = "google/owlvit-base-patch32" + model = OwlViTModel.from_pretrained(model_name).to(torch_device) + processor = OwlViTProcessor.from_pretrained(model_name) + + image = prepare_img() + inputs = processor( + text=[["a photo of a cat", "a photo of a dog"]], + images=image, + max_length=16, + padding="max_length", + return_tensors="pt", + ).to(torch_device) + + # forward pass + with torch.no_grad(): + outputs = model(**inputs) + + # verify the logits + self.assertEqual( + outputs.logits_per_image.shape, + torch.Size( + ( + inputs.pixel_values.shape[0], + inputs.input_ids.shape[0] * inputs.input_ids.shape[1] * inputs.pixel_values.shape[0], + ) + ), + ) + self.assertEqual( + outputs.logits_per_text.shape, + torch.Size( + ( + inputs.input_ids.shape[0] * inputs.input_ids.shape[1] * inputs.pixel_values.shape[0], + inputs.pixel_values.shape[0], + ) + ), + ) + + expected_logits = torch.tensor([[1.0115, 0.9982]], device=torch_device) + + self.assertTrue(torch.allclose(outputs.logits_per_image, expected_logits, atol=1e-3)) + + @slow + def test_inference_object_detection(self): + model_name = "google/owlvit-base-patch32" + model = OwlViTForObjectDetection.from_pretrained(model_name).to(torch_device) + + processor = OwlViTProcessor.from_pretrained(model_name) + + image = prepare_img() + inputs = processor( + text=[["a photo of a cat", "a photo of a dog"]], + images=image, + max_length=16, + padding="max_length", + return_tensors="pt", + ).to(torch_device) + + with torch.no_grad(): + outputs = model(**inputs) + + num_queries = int((model.config.vision_config.image_size / model.config.vision_config.patch_size) ** 2) + self.assertEqual(outputs.pred_boxes.shape, torch.Size((1, num_queries, 4))) + expected_slice_boxes = torch.tensor( + [[0.0143, 0.0236, 0.0285], [0.0649, 0.0247, 0.0437], [0.0601, 0.0446, 0.0699]] + ).to(torch_device) + self.assertTrue(torch.allclose(outputs.pred_boxes[0, :3, :3], expected_slice_boxes, atol=1e-4)) diff --git a/tests/models/owlvit/test_processor_owlvit.py b/tests/models/owlvit/test_processor_owlvit.py new file mode 100644 index 0000000000000..e37f45b15c8bb --- /dev/null +++ b/tests/models/owlvit/test_processor_owlvit.py @@ -0,0 +1,241 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os +import shutil +import tempfile +import unittest + +import numpy as np +import pytest + +from transformers import CLIPTokenizer, CLIPTokenizerFast +from transformers.models.clip.tokenization_clip import VOCAB_FILES_NAMES +from transformers.testing_utils import require_vision +from transformers.utils import FEATURE_EXTRACTOR_NAME, is_vision_available + + +if is_vision_available(): + from PIL import Image + + from transformers import OwlViTFeatureExtractor, OwlViTProcessor + + +@require_vision +class OwlViTProcessorTest(unittest.TestCase): + def setUp(self): + self.tmpdirname = tempfile.mkdtemp() + + # fmt: off + vocab = ["", "l", "o", "w", "e", "r", "s", "t", "i", "d", "n", "lo", "l", "w", "r", "t", "low", "er", "lowest", "newer", "wider", "", "<|startoftext|>", "<|endoftext|>"] + # fmt: on + vocab_tokens = dict(zip(vocab, range(len(vocab)))) + merges = ["#version: 0.2", "l o", "lo w", "e r", ""] + self.special_tokens_map = {"unk_token": ""} + + self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["vocab_file"]) + self.merges_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["merges_file"]) + with open(self.vocab_file, "w", encoding="utf-8") as fp: + fp.write(json.dumps(vocab_tokens) + "\n") + with open(self.merges_file, "w", encoding="utf-8") as fp: + fp.write("\n".join(merges)) + + feature_extractor_map = { + "do_resize": True, + "size": 20, + "do_center_crop": True, + "crop_size": 18, + "do_normalize": True, + "image_mean": [0.48145466, 0.4578275, 0.40821073], + "image_std": [0.26862954, 0.26130258, 0.27577711], + } + self.feature_extractor_file = os.path.join(self.tmpdirname, FEATURE_EXTRACTOR_NAME) + with open(self.feature_extractor_file, "w", encoding="utf-8") as fp: + json.dump(feature_extractor_map, fp) + + def get_tokenizer(self, **kwargs): + return CLIPTokenizer.from_pretrained(self.tmpdirname, pad_token="!", **kwargs) + + def get_rust_tokenizer(self, **kwargs): + return CLIPTokenizerFast.from_pretrained(self.tmpdirname, pad_token="!", **kwargs) + + def get_feature_extractor(self, **kwargs): + return OwlViTFeatureExtractor.from_pretrained(self.tmpdirname, **kwargs) + + def tearDown(self): + shutil.rmtree(self.tmpdirname) + + def prepare_image_inputs(self): + """This function prepares a list of PIL images, or a list of numpy arrays if one specifies numpify=True, + or a list of PyTorch tensors if one specifies torchify=True. + """ + + image_inputs = [np.random.randint(255, size=(3, 30, 400), dtype=np.uint8)] + + image_inputs = [Image.fromarray(np.moveaxis(x, 0, -1)) for x in image_inputs] + + return image_inputs + + def test_save_load_pretrained_default(self): + tokenizer_slow = self.get_tokenizer() + tokenizer_fast = self.get_rust_tokenizer() + feature_extractor = self.get_feature_extractor() + + processor_slow = OwlViTProcessor(tokenizer=tokenizer_slow, feature_extractor=feature_extractor) + processor_slow.save_pretrained(self.tmpdirname) + processor_slow = OwlViTProcessor.from_pretrained(self.tmpdirname, use_fast=False) + + processor_fast = OwlViTProcessor(tokenizer=tokenizer_fast, feature_extractor=feature_extractor) + processor_fast.save_pretrained(self.tmpdirname) + processor_fast = OwlViTProcessor.from_pretrained(self.tmpdirname) + + self.assertEqual(processor_slow.tokenizer.get_vocab(), tokenizer_slow.get_vocab()) + self.assertEqual(processor_fast.tokenizer.get_vocab(), tokenizer_fast.get_vocab()) + self.assertEqual(tokenizer_slow.get_vocab(), tokenizer_fast.get_vocab()) + self.assertIsInstance(processor_slow.tokenizer, CLIPTokenizer) + self.assertIsInstance(processor_fast.tokenizer, CLIPTokenizerFast) + + self.assertEqual(processor_slow.feature_extractor.to_json_string(), feature_extractor.to_json_string()) + self.assertEqual(processor_fast.feature_extractor.to_json_string(), feature_extractor.to_json_string()) + self.assertIsInstance(processor_slow.feature_extractor, OwlViTFeatureExtractor) + self.assertIsInstance(processor_fast.feature_extractor, OwlViTFeatureExtractor) + + def test_save_load_pretrained_additional_features(self): + processor = OwlViTProcessor(tokenizer=self.get_tokenizer(), feature_extractor=self.get_feature_extractor()) + processor.save_pretrained(self.tmpdirname) + + tokenizer_add_kwargs = self.get_tokenizer(bos_token="(BOS)", eos_token="(EOS)") + feature_extractor_add_kwargs = self.get_feature_extractor(do_normalize=False) + + processor = OwlViTProcessor.from_pretrained( + self.tmpdirname, bos_token="(BOS)", eos_token="(EOS)", do_normalize=False + ) + + self.assertEqual(processor.tokenizer.get_vocab(), tokenizer_add_kwargs.get_vocab()) + self.assertIsInstance(processor.tokenizer, CLIPTokenizerFast) + + self.assertEqual(processor.feature_extractor.to_json_string(), feature_extractor_add_kwargs.to_json_string()) + self.assertIsInstance(processor.feature_extractor, OwlViTFeatureExtractor) + + def test_feature_extractor(self): + feature_extractor = self.get_feature_extractor() + tokenizer = self.get_tokenizer() + + processor = OwlViTProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor) + + image_input = self.prepare_image_inputs() + + input_feat_extract = feature_extractor(image_input, return_tensors="np") + input_processor = processor(images=image_input, return_tensors="np") + + for key in input_feat_extract.keys(): + self.assertAlmostEqual(input_feat_extract[key].sum(), input_processor[key].sum(), delta=1e-2) + + def test_tokenizer(self): + feature_extractor = self.get_feature_extractor() + tokenizer = self.get_tokenizer() + + processor = OwlViTProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor) + + input_str = "lower newer" + + encoded_processor = processor(text=input_str, return_tensors="np") + + encoded_tok = tokenizer(input_str, return_tensors="np") + + for key in encoded_tok.keys(): + self.assertListEqual(encoded_tok[key][0].tolist(), encoded_processor[key][0].tolist()) + + def test_processor(self): + feature_extractor = self.get_feature_extractor() + tokenizer = self.get_tokenizer() + + processor = OwlViTProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor) + + input_str = "lower newer" + image_input = self.prepare_image_inputs() + + inputs = processor(text=input_str, images=image_input) + + self.assertListEqual(list(inputs.keys()), ["input_ids", "attention_mask", "pixel_values"]) + + # test if it raises when no input is passed + with pytest.raises(ValueError): + processor() + + def test_processor_with_text_list(self): + model_name = "google/owlvit-base-patch32" + processor = OwlViTProcessor.from_pretrained(model_name) + + input_text = ["cat", "nasa badge"] + inputs = processor(text=input_text) + + seq_length = 16 + self.assertListEqual(list(inputs.keys()), ["input_ids", "attention_mask"]) + self.assertEqual(inputs["input_ids"].shape, (2, seq_length)) + + # test if it raises when no input is passed + with pytest.raises(ValueError): + processor() + + def test_processor_with_nested_text_list(self): + model_name = "google/owlvit-base-patch32" + processor = OwlViTProcessor.from_pretrained(model_name) + + input_texts = [["cat", "nasa badge"], ["person"]] + inputs = processor(text=input_texts) + + seq_length = 16 + batch_size = len(input_texts) + num_max_text_queries = max([len(texts) for texts in input_texts]) + + self.assertListEqual(list(inputs.keys()), ["input_ids", "attention_mask"]) + self.assertEqual(inputs["input_ids"].shape, (batch_size * num_max_text_queries, seq_length)) + + # test if it raises when no input is passed + with pytest.raises(ValueError): + processor() + + def test_processor_case(self): + model_name = "google/owlvit-base-patch32" + processor = OwlViTProcessor.from_pretrained(model_name) + + input_texts = ["cat", "nasa badge"] + inputs = processor(text=input_texts) + + seq_length = 16 + input_ids = inputs["input_ids"] + predicted_ids = [ + [49406, 2368, 49407, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [49406, 6841, 11301, 49407, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ] + + self.assertListEqual(list(inputs.keys()), ["input_ids", "attention_mask"]) + self.assertEqual(inputs["input_ids"].shape, (2, seq_length)) + self.assertListEqual(list(input_ids[0]), predicted_ids[0]) + self.assertListEqual(list(input_ids[1]), predicted_ids[1]) + + def test_tokenizer_decode(self): + feature_extractor = self.get_feature_extractor() + tokenizer = self.get_tokenizer() + + processor = OwlViTProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor) + + predicted_ids = [[1, 4, 5, 8, 1, 0, 8], [3, 4, 3, 1, 1, 8, 9]] + + decoded_processor = processor.batch_decode(predicted_ids) + decoded_tok = tokenizer.batch_decode(predicted_ids) + + self.assertListEqual(decoded_tok, decoded_processor) diff --git a/utils/check_config_docstrings.py b/utils/check_config_docstrings.py index d7fb8acd5c530..bcbbace39e0e7 100644 --- a/utils/check_config_docstrings.py +++ b/utils/check_config_docstrings.py @@ -41,6 +41,7 @@ CONFIG_CLASSES_TO_IGNORE_FOR_DOCSTRING_CHECKPOINT_CHECK = { "CLIPConfig", + "OwlViTConfig", "GroupViTConfig", "DecisionTransformerConfig", "EncoderDecoderConfig", diff --git a/utils/check_repo.py b/utils/check_repo.py index 90d4218bdd6e6..47fee163137f1 100644 --- a/utils/check_repo.py +++ b/utils/check_repo.py @@ -166,6 +166,9 @@ "LukeForEntityPairClassification", "LukeForEntitySpanClassification", "OpenAIGPTDoubleHeadsModel", + "OwlViTTextModel", + "OwlViTVisionModel", + "OwlViTForObjectDetection", "RagModel", "RagSequenceForGeneration", "RagTokenForGeneration",