From 2ab790e82d0759b667cd848a4d49e6ad65e15d59 Mon Sep 17 00:00:00 2001 From: NielsRogge <48327001+NielsRogge@users.noreply.github.com> Date: Fri, 12 Aug 2022 16:40:58 +0200 Subject: [PATCH] Add Donut (#18488) * First draft * Improve script * Update script * Make conversion work * Add final_layer_norm attribute to Swin's config * Add DonutProcessor * Convert more models * Improve feature extractor and convert base models * Fix bug * Improve integration tests * Improve integration tests and add model to README * Add doc test * Add feature extractor to docs * Fix integration tests * Remove register_buffer * Fix toctree and add missing attribute * Add DonutSwin * Make conversion script work * Improve conversion script * Address comment * Fix bug * Fix another bug * Remove deprecated method from docs * Make Swin and Swinv2 untouched * Fix code examples * Fix processor * Update model_type to donut-swin * Add feature extractor tests, add token2json method, improve feature extractor * Fix failing tests, remove integration test * Add do_thumbnail for consistency * Improve code examples * Add code example for document parsing * Add DonutSwin to MODEL_NAMES_MAPPING * Add model to appropriate place in toctree * Update namespace to appropriate organization Co-authored-by: Niels Rogge --- README.md | 1 + README_ko.md | 1 + README_zh-hans.md | 1 + README_zh-hant.md | 1 + docs/source/en/_toctree.yml | 2 + docs/source/en/index.mdx | 2 + docs/source/en/model_doc/donut.mdx | 214 ++++ src/transformers/__init__.py | 12 + src/transformers/image_utils.py | 22 + src/transformers/models/__init__.py | 1 + .../models/auto/configuration_auto.py | 5 + .../models/auto/feature_extraction_auto.py | 1 + src/transformers/models/auto/modeling_auto.py | 1 + .../models/auto/processing_auto.py | 1 + src/transformers/models/donut/__init__.py | 76 ++ .../models/donut/configuration_donut_swin.py | 140 +++ .../models/donut/convert_donut_to_pytorch.py | 234 +++++ .../models/donut/feature_extraction_donut.py | 208 ++++ .../models/donut/modeling_donut_swin.py | 941 ++++++++++++++++++ .../models/donut/processing_donut.py | 156 +++ .../convert_trocr_unilm_to_pytorch.py | 0 src/transformers/utils/dummy_pt_objects.py | 17 + .../utils/dummy_vision_objects.py | 7 + src/transformers/utils/fx.py | 1 + tests/models/donut/__init__.py | 0 .../donut/test_feature_extraction_donut.py | 203 ++++ .../models/donut/test_modeling_donut_swin.py | 464 +++++++++ .../test_modeling_vision_encoder_decoder.py | 216 +++- utils/check_copies.py | 1 + utils/check_repo.py | 1 + utils/documentation_tests.txt | 1 + 31 files changed, 2924 insertions(+), 7 deletions(-) create mode 100644 docs/source/en/model_doc/donut.mdx create mode 100644 src/transformers/models/donut/__init__.py create mode 100644 src/transformers/models/donut/configuration_donut_swin.py create mode 100644 src/transformers/models/donut/convert_donut_to_pytorch.py create mode 100644 src/transformers/models/donut/feature_extraction_donut.py create mode 100644 src/transformers/models/donut/modeling_donut_swin.py create mode 100644 src/transformers/models/donut/processing_donut.py rename src/transformers/models/{vision_encoder_decoder => trocr}/convert_trocr_unilm_to_pytorch.py (100%) create mode 100644 tests/models/donut/__init__.py create mode 100644 tests/models/donut/test_feature_extraction_donut.py create mode 100644 tests/models/donut/test_modeling_donut_swin.py diff --git a/README.md b/README.md index 46a4b07c14cd3..30bc6d870bbf0 100644 --- a/README.md +++ b/README.md @@ -286,6 +286,7 @@ Current number of checkpoints: ![](https://img.shields.io/endpoint?url=https://h 1. **[DialoGPT](https://huggingface.co/docs/transformers/model_doc/dialogpt)** (from Microsoft Research) released with the paper [DialoGPT: Large-Scale Generative Pre-training for Conversational Response Generation](https://arxiv.org/abs/1911.00536) by Yizhe Zhang, Siqi Sun, Michel Galley, Yen-Chun Chen, Chris Brockett, Xiang Gao, Jianfeng Gao, Jingjing Liu, Bill Dolan. 1. **[DistilBERT](https://huggingface.co/docs/transformers/model_doc/distilbert)** (from HuggingFace), released together with the paper [DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter](https://arxiv.org/abs/1910.01108) by Victor Sanh, Lysandre Debut and Thomas Wolf. The same method has been applied to compress GPT2 into [DistilGPT2](https://github.com/huggingface/transformers/tree/main/examples/research_projects/distillation), RoBERTa into [DistilRoBERTa](https://github.com/huggingface/transformers/tree/main/examples/research_projects/distillation), Multilingual BERT into [DistilmBERT](https://github.com/huggingface/transformers/tree/main/examples/research_projects/distillation) and a German version of DistilBERT. 1. **[DiT](https://huggingface.co/docs/transformers/model_doc/dit)** (from Microsoft Research) released with the paper [DiT: Self-supervised Pre-training for Document Image Transformer](https://arxiv.org/abs/2203.02378) by Junlong Li, Yiheng Xu, Tengchao Lv, Lei Cui, Cha Zhang, Furu Wei. +1. **[Donut](https://huggingface.co/docs/transformers/main/model_doc/donut)** (from NAVER), released together with the paper [OCR-free Document Understanding Transformer](https://arxiv.org/abs/2111.15664) by Geewook Kim, Teakgyu Hong, Moonbin Yim, Jeongyeon Nam, Jinyoung Park, Jinyeong Yim, Wonseok Hwang, Sangdoo Yun, Dongyoon Han, Seunghyun Park. 1. **[DPR](https://huggingface.co/docs/transformers/model_doc/dpr)** (from Facebook) released with the paper [Dense Passage Retrieval for Open-Domain Question Answering](https://arxiv.org/abs/2004.04906) by Vladimir Karpukhin, Barlas Oğuz, Sewon Min, Patrick Lewis, Ledell Wu, Sergey Edunov, Danqi Chen, and Wen-tau Yih. 1. **[DPT](https://huggingface.co/docs/transformers/master/model_doc/dpt)** (from Intel Labs) released with the paper [Vision Transformers for Dense Prediction](https://arxiv.org/abs/2103.13413) by René Ranftl, Alexey Bochkovskiy, Vladlen Koltun. 1. **[ELECTRA](https://huggingface.co/docs/transformers/model_doc/electra)** (from Google Research/Stanford University) released with the paper [ELECTRA: Pre-training text encoders as discriminators rather than generators](https://arxiv.org/abs/2003.10555) by Kevin Clark, Minh-Thang Luong, Quoc V. Le, Christopher D. Manning. diff --git a/README_ko.md b/README_ko.md index c63fdca749da8..cc0b790ad76a8 100644 --- a/README_ko.md +++ b/README_ko.md @@ -242,6 +242,7 @@ Flax, PyTorch, TensorFlow 설치 페이지에서 이들을 conda로 설치하는 1. **[DialoGPT](https://huggingface.co/docs/transformers/model_doc/dialogpt)** (from Microsoft Research) released with the paper [DialoGPT: Large-Scale Generative Pre-training for Conversational Response Generation](https://arxiv.org/abs/1911.00536) by Yizhe Zhang, Siqi Sun, Michel Galley, Yen-Chun Chen, Chris Brockett, Xiang Gao, Jianfeng Gao, Jingjing Liu, Bill Dolan. 1. **[DistilBERT](https://huggingface.co/docs/transformers/model_doc/distilbert)** (from HuggingFace), released together with the paper [DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter](https://arxiv.org/abs/1910.01108) by Victor Sanh, Lysandre Debut and Thomas Wolf. The same method has been applied to compress GPT2 into [DistilGPT2](https://github.com/huggingface/transformers/tree/main/examples/distillation), RoBERTa into [DistilRoBERTa](https://github.com/huggingface/transformers/tree/main/examples/distillation), Multilingual BERT into [DistilmBERT](https://github.com/huggingface/transformers/tree/main/examples/distillation) and a German version of DistilBERT. 1. **[DiT](https://huggingface.co/docs/transformers/model_doc/dit)** (from Microsoft Research) released with the paper [DiT: Self-supervised Pre-training for Document Image Transformer](https://arxiv.org/abs/2203.02378) by Junlong Li, Yiheng Xu, Tengchao Lv, Lei Cui, Cha Zhang, Furu Wei. +1. **[Donut](https://huggingface.co/docs/transformers/main/model_doc/donut)** (from NAVER) released with the paper [OCR-free Document Understanding Transformer](https://arxiv.org/abs/2111.15664) by Geewook Kim, Teakgyu Hong, Moonbin Yim, Jeongyeon Nam, Jinyoung Park, Jinyeong Yim, Wonseok Hwang, Sangdoo Yun, Dongyoon Han, Seunghyun Park. 1. **[DPR](https://huggingface.co/docs/transformers/model_doc/dpr)** (from Facebook) released with the paper [Dense Passage Retrieval for Open-Domain Question Answering](https://arxiv.org/abs/2004.04906) by Vladimir Karpukhin, Barlas Oğuz, Sewon Min, Patrick Lewis, Ledell Wu, Sergey Edunov, Danqi Chen, and Wen-tau Yih. 1. **[DPT](https://huggingface.co/docs/transformers/master/model_doc/dpt)** (from Intel Labs) released with the paper [Vision Transformers for Dense Prediction](https://arxiv.org/abs/2103.13413) by René Ranftl, Alexey Bochkovskiy, Vladlen Koltun. 1. **[ELECTRA](https://huggingface.co/docs/transformers/model_doc/electra)** (from Google Research/Stanford University) released with the paper [ELECTRA: Pre-training text encoders as discriminators rather than generators](https://arxiv.org/abs/2003.10555) by Kevin Clark, Minh-Thang Luong, Quoc V. Le, Christopher D. Manning. diff --git a/README_zh-hans.md b/README_zh-hans.md index 0ab06bd96ad99..fe2fa45f71f39 100644 --- a/README_zh-hans.md +++ b/README_zh-hans.md @@ -266,6 +266,7 @@ conda install -c huggingface transformers 1. **[DialoGPT](https://huggingface.co/docs/transformers/model_doc/dialogpt)** (来自 Microsoft Research) 伴随论文 [DialoGPT: Large-Scale Generative Pre-training for Conversational Response Generation](https://arxiv.org/abs/1911.00536) 由 Yizhe Zhang, Siqi Sun, Michel Galley, Yen-Chun Chen, Chris Brockett, Xiang Gao, Jianfeng Gao, Jingjing Liu, Bill Dolan 发布。 1. **[DistilBERT](https://huggingface.co/docs/transformers/model_doc/distilbert)** (来自 HuggingFace), 伴随论文 [DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter](https://arxiv.org/abs/1910.01108) 由 Victor Sanh, Lysandre Debut and Thomas Wolf 发布。 同样的方法也应用于压缩 GPT-2 到 [DistilGPT2](https://github.com/huggingface/transformers/tree/main/examples/distillation), RoBERTa 到 [DistilRoBERTa](https://github.com/huggingface/transformers/tree/main/examples/distillation), Multilingual BERT 到 [DistilmBERT](https://github.com/huggingface/transformers/tree/main/examples/distillation) 和德语版 DistilBERT。 1. **[DiT](https://huggingface.co/docs/transformers/model_doc/dit)** (来自 Microsoft Research) 伴随论文 [DiT: Self-supervised Pre-training for Document Image Transformer](https://arxiv.org/abs/2203.02378) 由 Junlong Li, Yiheng Xu, Tengchao Lv, Lei Cui, Cha Zhang, Furu Wei 发布。 +1. **[Donut](https://huggingface.co/docs/transformers/main/model_doc/donut)** (来自 NAVER) 伴随论文 [OCR-free Document Understanding Transformer](https://arxiv.org/abs/2111.15664) 由 Geewook Kim, Teakgyu Hong, Moonbin Yim, Jeongyeon Nam, Jinyoung Park, Jinyeong Yim, Wonseok Hwang, Sangdoo Yun, Dongyoon Han, Seunghyun Park 发布。 1. **[DPR](https://huggingface.co/docs/transformers/model_doc/dpr)** (来自 Facebook) 伴随论文 [Dense Passage Retrieval for Open-Domain Question Answering](https://arxiv.org/abs/2004.04906) 由 Vladimir Karpukhin, Barlas Oğuz, Sewon Min, Patrick Lewis, Ledell Wu, Sergey Edunov, Danqi Chen, and Wen-tau Yih 发布。 1. **[DPT](https://huggingface.co/docs/transformers/master/model_doc/dpt)** (来自 Intel Labs) 伴随论文 [Vision Transformers for Dense Prediction](https://arxiv.org/abs/2103.13413) 由 René Ranftl, Alexey Bochkovskiy, Vladlen Koltun 发布。 1. **[ELECTRA](https://huggingface.co/docs/transformers/model_doc/electra)** (来自 Google Research/Stanford University) 伴随论文 [ELECTRA: Pre-training text encoders as discriminators rather than generators](https://arxiv.org/abs/2003.10555) 由 Kevin Clark, Minh-Thang Luong, Quoc V. Le, Christopher D. Manning 发布。 diff --git a/README_zh-hant.md b/README_zh-hant.md index 90f29ad031b8b..4f5a995476149 100644 --- a/README_zh-hant.md +++ b/README_zh-hant.md @@ -278,6 +278,7 @@ conda install -c huggingface transformers 1. **[DialoGPT](https://huggingface.co/docs/transformers/model_doc/dialogpt)** (from Microsoft Research) released with the paper [DialoGPT: Large-Scale Generative Pre-training for Conversational Response Generation](https://arxiv.org/abs/1911.00536) by Yizhe Zhang, Siqi Sun, Michel Galley, Yen-Chun Chen, Chris Brockett, Xiang Gao, Jianfeng Gao, Jingjing Liu, Bill Dolan. 1. **[DistilBERT](https://huggingface.co/docs/transformers/model_doc/distilbert)** (from HuggingFace), released together with the paper [DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter](https://arxiv.org/abs/1910.01108) by Victor Sanh, Lysandre Debut and Thomas Wolf. The same method has been applied to compress GPT2 into [DistilGPT2](https://github.com/huggingface/transformers/tree/main/examples/distillation), RoBERTa into [DistilRoBERTa](https://github.com/huggingface/transformers/tree/main/examples/distillation), Multilingual BERT into [DistilmBERT](https://github.com/huggingface/transformers/tree/main/examples/distillation) and a German version of DistilBERT. 1. **[DiT](https://huggingface.co/docs/transformers/model_doc/dit)** (from Microsoft Research) released with the paper [DiT: Self-supervised Pre-training for Document Image Transformer](https://arxiv.org/abs/2203.02378) by Junlong Li, Yiheng Xu, Tengchao Lv, Lei Cui, Cha Zhang, Furu Wei. +1. **[Donut](https://huggingface.co/docs/transformers/main/model_doc/donut)** (from NAVER) released with the paper [OCR-free Document Understanding Transformer](https://arxiv.org/abs/2111.15664) by Geewook Kim, Teakgyu Hong, Moonbin Yim, Jeongyeon Nam, Jinyoung Park, Jinyeong Yim, Wonseok Hwang, Sangdoo Yun, Dongyoon Han, Seunghyun Park. 1. **[DPR](https://huggingface.co/docs/transformers/model_doc/dpr)** (from Facebook) released with the paper [Dense Passage Retrieval for Open-Domain Question Answering](https://arxiv.org/abs/2004.04906) by Vladimir Karpukhin, Barlas Oğuz, Sewon Min, Patrick Lewis, Ledell Wu, Sergey Edunov, Danqi Chen, and Wen-tau Yih. 1. **[DPT](https://huggingface.co/docs/transformers/master/model_doc/dpt)** (from Intel Labs) released with the paper [Vision Transformers for Dense Prediction](https://arxiv.org/abs/2103.13413) by René Ranftl, Alexey Bochkovskiy, Vladlen Koltun. 1. **[ELECTRA](https://huggingface.co/docs/transformers/model_doc/electra)** (from Google Research/Stanford University) released with the paper [ELECTRA: Pre-training text encoders as discriminators rather than generators](https://arxiv.org/abs/2003.10555) by Kevin Clark, Minh-Thang Luong, Quoc V. Le, Christopher D. Manning. diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 32ab4c6361d3a..78137d2c8a74c 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -427,6 +427,8 @@ title: CLIP - local: model_doc/data2vec title: Data2Vec + - local: model_doc/donut + title: Donut - local: model_doc/flava title: FLAVA - local: model_doc/groupvit diff --git a/docs/source/en/index.mdx b/docs/source/en/index.mdx index 5c0d51d8b7afb..257eba8171ed1 100644 --- a/docs/source/en/index.mdx +++ b/docs/source/en/index.mdx @@ -84,6 +84,7 @@ The library currently contains JAX, PyTorch and TensorFlow implementations, pret 1. **[DialoGPT](model_doc/dialogpt)** (from Microsoft Research) released with the paper [DialoGPT: Large-Scale Generative Pre-training for Conversational Response Generation](https://arxiv.org/abs/1911.00536) by Yizhe Zhang, Siqi Sun, Michel Galley, Yen-Chun Chen, Chris Brockett, Xiang Gao, Jianfeng Gao, Jingjing Liu, Bill Dolan. 1. **[DistilBERT](model_doc/distilbert)** (from HuggingFace), released together with the paper [DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter](https://arxiv.org/abs/1910.01108) by Victor Sanh, Lysandre Debut and Thomas Wolf. The same method has been applied to compress GPT2 into [DistilGPT2](https://github.com/huggingface/transformers/tree/main/examples/research_projects/distillation), RoBERTa into [DistilRoBERTa](https://github.com/huggingface/transformers/tree/main/examples/research_projects/distillation), Multilingual BERT into [DistilmBERT](https://github.com/huggingface/transformers/tree/main/examples/research_projects/distillation) and a German version of DistilBERT. 1. **[DiT](model_doc/dit)** (from Microsoft Research) released with the paper [DiT: Self-supervised Pre-training for Document Image Transformer](https://arxiv.org/abs/2203.02378) by Junlong Li, Yiheng Xu, Tengchao Lv, Lei Cui, Cha Zhang, Furu Wei. +1. **[Donut](model_doc/donut)** (from NAVER), released together with the paper [OCR-free Document Understanding Transformer](https://arxiv.org/abs/2111.15664) by Geewook Kim, Teakgyu Hong, Moonbin Yim, Jeongyeon Nam, Jinyoung Park, Jinyeong Yim, Wonseok Hwang, Sangdoo Yun, Dongyoon Han, Seunghyun Park. 1. **[DPR](model_doc/dpr)** (from Facebook) released with the paper [Dense Passage Retrieval for Open-Domain Question Answering](https://arxiv.org/abs/2004.04906) by Vladimir Karpukhin, Barlas Oğuz, Sewon Min, Patrick Lewis, Ledell Wu, Sergey Edunov, Danqi Chen, and Wen-tau Yih. 1. **[DPT](master/model_doc/dpt)** (from Intel Labs) released with the paper [Vision Transformers for Dense Prediction](https://arxiv.org/abs/2103.13413) by René Ranftl, Alexey Bochkovskiy, Vladlen Koltun. 1. **[ELECTRA](model_doc/electra)** (from Google Research/Stanford University) released with the paper [ELECTRA: Pre-training text encoders as discriminators rather than generators](https://arxiv.org/abs/2003.10555) by Kevin Clark, Minh-Thang Luong, Quoc V. Le, Christopher D. Manning. @@ -224,6 +225,7 @@ Flax), PyTorch, and/or TensorFlow. | DeiT | ❌ | ❌ | ✅ | ✅ | ❌ | | DETR | ❌ | ❌ | ✅ | ❌ | ❌ | | DistilBERT | ✅ | ✅ | ✅ | ✅ | ✅ | +| DonutSwin | ❌ | ❌ | ✅ | ❌ | ❌ | | DPR | ✅ | ✅ | ✅ | ✅ | ❌ | | DPT | ❌ | ❌ | ✅ | ❌ | ❌ | | ELECTRA | ✅ | ✅ | ✅ | ✅ | ✅ | diff --git a/docs/source/en/model_doc/donut.mdx b/docs/source/en/model_doc/donut.mdx new file mode 100644 index 0000000000000..9c9973be022e7 --- /dev/null +++ b/docs/source/en/model_doc/donut.mdx @@ -0,0 +1,214 @@ + + +# Donut + +## Overview + +The Donut model was proposed in [OCR-free Document Understanding Transformer](https://arxiv.org/abs/2111.15664) by +Geewook Kim, Teakgyu Hong, Moonbin Yim, Jeongyeon Nam, Jinyoung Park, Jinyeong Yim, Wonseok Hwang, Sangdoo Yun, Dongyoon Han, Seunghyun Park. +Donut consists of an image Transformer encoder and an autoregressive text Transformer decoder to perform document understanding +tasks such as document image classification, form understanding and visual question answering. + +The abstract from the paper is the following: + +*Understanding document images (e.g., invoices) is a core but challenging task since it requires complex functions such as reading text and a holistic understanding of the document. Current Visual Document Understanding (VDU) methods outsource the task of reading text to off-the-shelf Optical Character Recognition (OCR) engines and focus on the understanding task with the OCR outputs. Although such OCR-based approaches have shown promising performance, they suffer from 1) high computational costs for using OCR; 2) inflexibility of OCR models on languages or types of document; 3) OCR error propagation to the subsequent process. To address these issues, in this paper, we introduce a novel OCR-free VDU model named Donut, which stands for Document understanding transformer. As the first step in OCR-free VDU research, we propose a simple architecture (i.e., Transformer) with a pre-training objective (i.e., cross-entropy loss). Donut is conceptually simple yet effective. Through extensive experiments and analyses, we show a simple OCR-free VDU model, Donut, achieves state-of-the-art performances on various VDU tasks in terms of both speed and accuracy. In addition, we offer a synthetic data generator that helps the model pre-training to be flexible in various languages and domains.* + + + + Donut high-level overview. Taken from the original paper. + +This model was contributed by [nielsr](https://huggingface.co/nielsr). The original code can be found +[here](https://github.com/clovaai/donut). + +Tips: + +- The quickest way to get started with Donut is by checking the [tutorial + notebooks](https://github.com/NielsRogge/Transformers-Tutorials/tree/master/donut), which show how to use the model + at inference time as well as fine-tuning on custom data. +- Donut is always used within the [VisionEncoderDecoder](vision-encoder-decoder) framework. + +## Inference + +Donut's [`VisionEncoderDecoder`] model accepts images as input and makes use of +[`~generation_utils.GenerationMixin.generate`] to autoregressively generate text given the input image. + +The [`DonutFeatureExtractor`] class is responsible for preprocessing the input image and +[`XLMRobertaTokenizer`/`XLMRobertaTokenizerFast`] decodes the generated target tokens to the target string. The +[`DonutProcessor`] wraps [`DonutFeatureExtractor`] and [`XLMRobertaTokenizer`/`XLMRobertaTokenizerFast`] +into a single instance to both extract the input features and decode the predicted token ids. + +- Step-by-step Document Image Classification + +```py +>>> import re + +>>> from transformers import DonutProcessor, VisionEncoderDecoderModel +>>> from datasets import load_dataset +>>> import torch + +>>> processor = DonutProcessor.from_pretrained("naver-clova-ix/donut-base-finetuned-rvlcdip") +>>> model = VisionEncoderDecoderModel.from_pretrained("naver-clova-ix/donut-base-finetuned-rvlcdip") + +>>> device = "cuda" if torch.cuda.is_available() else "cpu" +>>> model.to(device) # doctest: +IGNORE_RESULT + +>>> # load document image +>>> dataset = load_dataset("hf-internal-testing/example-documents", split="test") +>>> image = dataset[1]["image"] + +>>> # prepare decoder inputs +>>> task_prompt = "" +>>> decoder_input_ids = processor.tokenizer(task_prompt, add_special_tokens=False, return_tensors="pt").input_ids + +>>> pixel_values = processor(image, return_tensors="pt").pixel_values + +>>> outputs = model.generate( +... pixel_values.to(device), +... decoder_input_ids=decoder_input_ids.to(device), +... max_length=model.decoder.config.max_position_embeddings, +... early_stopping=True, +... pad_token_id=processor.tokenizer.pad_token_id, +... eos_token_id=processor.tokenizer.eos_token_id, +... use_cache=True, +... num_beams=1, +... bad_words_ids=[[processor.tokenizer.unk_token_id]], +... return_dict_in_generate=True, +... ) + +>>> sequence = processor.batch_decode(outputs.sequences)[0] +>>> sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "") +>>> sequence = re.sub(r"<.*?>", "", sequence, count=1).strip() # remove first task start token +>>> print(processor.token2json(sequence)) +{'class': 'advertisement'} +``` + +- Step-by-step Document Parsing + +```py +>>> import re + +>>> from transformers import DonutProcessor, VisionEncoderDecoderModel +>>> from datasets import load_dataset +>>> import torch + +>>> processor = DonutProcessor.from_pretrained("naver-clova-ix/donut-base-finetuned-cord-v2") +>>> model = VisionEncoderDecoderModel.from_pretrained("naver-clova-ix/donut-base-finetuned-cord-v2") + +>>> device = "cuda" if torch.cuda.is_available() else "cpu" +>>> model.to(device) # doctest: +IGNORE_RESULT + +>>> # load document image +>>> dataset = load_dataset("hf-internal-testing/example-documents", split="test") +>>> image = dataset[2]["image"] + +>>> # prepare decoder inputs +>>> task_prompt = "" +>>> decoder_input_ids = processor.tokenizer(task_prompt, add_special_tokens=False, return_tensors="pt").input_ids + +>>> pixel_values = processor(image, return_tensors="pt").pixel_values + +>>> outputs = model.generate( +... pixel_values.to(device), +... decoder_input_ids=decoder_input_ids.to(device), +... max_length=model.decoder.config.max_position_embeddings, +... early_stopping=True, +... pad_token_id=processor.tokenizer.pad_token_id, +... eos_token_id=processor.tokenizer.eos_token_id, +... use_cache=True, +... num_beams=1, +... bad_words_ids=[[processor.tokenizer.unk_token_id]], +... return_dict_in_generate=True, +... ) + +>>> sequence = processor.batch_decode(outputs.sequences)[0] +>>> sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "") +>>> sequence = re.sub(r"<.*?>", "", sequence, count=1).strip() # remove first task start token +>>> print(processor.token2json(sequence)) +{'menu': {'nm': 'CINNAMON SUGAR', 'unitprice': '17,000', 'cnt': '1 x', 'price': '17,000'}, 'sub_total': {'subtotal_price': '17,000'}, 'total': {'total_price': '17,000', 'cashprice': '20,000', 'changeprice': '3,000'}} +``` + +- Step-by-step Document Visual Question Answering (DocVQA) + +```py +>>> import re + +>>> from transformers import DonutProcessor, VisionEncoderDecoderModel +>>> from datasets import load_dataset +>>> import torch + +>>> processor = DonutProcessor.from_pretrained("naver-clova-ix/donut-base-finetuned-docvqa") +>>> model = VisionEncoderDecoderModel.from_pretrained("naver-clova-ix/donut-base-finetuned-docvqa") + +>>> device = "cuda" if torch.cuda.is_available() else "cpu" +>>> model.to(device) # doctest: +IGNORE_RESULT + +>>> # load document image from the DocVQA dataset +>>> dataset = load_dataset("hf-internal-testing/example-documents", split="test") +>>> image = dataset[0]["image"] + +>>> # prepare decoder inputs +>>> task_prompt = "{user_input}" +>>> question = "When is the coffee break?" +>>> prompt = task_prompt.replace("{user_input}", question) +>>> decoder_input_ids = processor.tokenizer(prompt, add_special_tokens=False, return_tensors="pt").input_ids + +>>> pixel_values = processor(image, return_tensors="pt").pixel_values + +>>> outputs = model.generate( +... pixel_values.to(device), +... decoder_input_ids=decoder_input_ids.to(device), +... max_length=model.decoder.config.max_position_embeddings, +... early_stopping=True, +... pad_token_id=processor.tokenizer.pad_token_id, +... eos_token_id=processor.tokenizer.eos_token_id, +... use_cache=True, +... num_beams=1, +... bad_words_ids=[[processor.tokenizer.unk_token_id]], +... return_dict_in_generate=True, +... ) + +>>> sequence = processor.batch_decode(outputs.sequences)[0] +>>> sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "") +>>> sequence = re.sub(r"<.*?>", "", sequence, count=1).strip() # remove first task start token +>>> print(processor.token2json(sequence)) +{'question': 'When is the coffee break?', 'answer': '11-14 to 11:39 a.m.'} +``` + +See the [model hub](https://huggingface.co/models?filter=donut) to look for Donut checkpoints. + +## Training + +We refer to the [tutorial notebooks](https://github.com/NielsRogge/Transformers-Tutorials/tree/master/donut). + +## DonutSwinConfig + +[[autodoc]] DonutSwinConfig + +## DonutFeatureExtractor + +[[autodoc]] DonutFeatureExtractor + - __call__ + +## DonutProcessor + +[[autodoc]] DonutProcessor + - __call__ + - from_pretrained + - save_pretrained + - batch_decode + - decode + +## DonutSwinModel + +[[autodoc]] DonutSwinModel + - forward \ No newline at end of file diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 2f53db07f078f..d6444e0844ff5 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -190,6 +190,7 @@ "models.dialogpt": [], "models.distilbert": ["DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "DistilBertConfig", "DistilBertTokenizer"], "models.dit": [], + "models.donut": ["DONUT_SWIN_PRETRAINED_CONFIG_ARCHIVE_MAP", "DonutProcessor", "DonutSwinConfig"], "models.dpr": [ "DPR_PRETRAINED_CONFIG_ARCHIVE_MAP", "DPRConfig", @@ -641,6 +642,7 @@ _import_structure["models.convnext"].append("ConvNextFeatureExtractor") _import_structure["models.deit"].append("DeiTFeatureExtractor") _import_structure["models.detr"].append("DetrFeatureExtractor") + _import_structure["models.donut"].append("DonutFeatureExtractor") _import_structure["models.dpt"].append("DPTFeatureExtractor") _import_structure["models.flava"].extend(["FlavaFeatureExtractor", "FlavaProcessor"]) _import_structure["models.glpn"].append("GLPNFeatureExtractor") @@ -1099,6 +1101,13 @@ "DistilBertPreTrainedModel", ] ) + _import_structure["models.donut"].extend( + [ + "DONUT_SWIN_PRETRAINED_MODEL_ARCHIVE_LIST", + "DonutSwinModel", + "DonutSwinPreTrainedModel", + ] + ) _import_structure["models.dpr"].extend( [ "DPR_CONTEXT_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST", @@ -2984,6 +2993,7 @@ from .models.deit import DEIT_PRETRAINED_CONFIG_ARCHIVE_MAP, DeiTConfig from .models.detr import DETR_PRETRAINED_CONFIG_ARCHIVE_MAP, DetrConfig from .models.distilbert import DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, DistilBertConfig, DistilBertTokenizer + from .models.donut import DONUT_SWIN_PRETRAINED_CONFIG_ARCHIVE_MAP, DonutProcessor, DonutSwinConfig from .models.dpr import ( DPR_PRETRAINED_CONFIG_ARCHIVE_MAP, DPRConfig, @@ -3375,6 +3385,7 @@ from .models.convnext import ConvNextFeatureExtractor from .models.deit import DeiTFeatureExtractor from .models.detr import DetrFeatureExtractor + from .models.donut import DonutFeatureExtractor from .models.dpt import DPTFeatureExtractor from .models.flava import FlavaFeatureExtractor, FlavaProcessor from .models.glpn import GLPNFeatureExtractor @@ -3761,6 +3772,7 @@ DistilBertModel, DistilBertPreTrainedModel, ) + from .models.donut import DONUT_SWIN_PRETRAINED_MODEL_ARCHIVE_LIST, DonutSwinModel, DonutSwinPreTrainedModel from .models.dpr import ( DPR_CONTEXT_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST, DPR_QUESTION_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST, diff --git a/src/transformers/image_utils.py b/src/transformers/image_utils.py index dd7bb326993d3..e5a395341c003 100644 --- a/src/transformers/image_utils.py +++ b/src/transformers/image_utils.py @@ -376,3 +376,25 @@ def flip_channel_order(self, image): image = self.to_numpy_array(image) return image[::-1, :, :] + + def rotate(self, image, angle, resample=PIL.Image.NEAREST, expand=0, center=None, translate=None, fillcolor=None): + """ + Returns a rotated copy of `image`. This method returns a copy of `image`, rotated the given number of degrees + counter clockwise around its centre. + + Args: + image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`): + The image to rotate. If `np.ndarray` or `torch.Tensor`, will be converted to `PIL.Image.Image` before + rotating. + + Returns: + image: A rotated `PIL.Image.Image`. + """ + self._ensure_format_supported(image) + + if not isinstance(image, PIL.Image.Image): + image = self.to_pil_image(image) + + return image.rotate( + angle, resample=resample, expand=expand, center=center, translate=translate, fillcolor=fillcolor + ) diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index 11887db91f839..fdf315b2257d8 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -52,6 +52,7 @@ dialogpt, distilbert, dit, + donut, dpr, dpt, electra, diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index c65a2762a0002..c9e6156a3843d 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -56,6 +56,7 @@ ("deit", "DeiTConfig"), ("detr", "DetrConfig"), ("distilbert", "DistilBertConfig"), + ("donut-swin", "DonutSwinConfig"), ("dpr", "DPRConfig"), ("dpt", "DPTConfig"), ("electra", "ElectraConfig"), @@ -181,6 +182,7 @@ ("deit", "DEIT_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("detr", "DETR_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("distilbert", "DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("donut-swin", "DONUT_SWIN_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("dpr", "DPR_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("dpt", "DPT_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("electra", "ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP"), @@ -304,6 +306,8 @@ ("dialogpt", "DialoGPT"), ("distilbert", "DistilBERT"), ("dit", "DiT"), + ("donut", "Donut"), + ("donut-swin", "DonutSwin"), ("dpr", "DPR"), ("dpt", "DPT"), ("electra", "ELECTRA"), @@ -420,6 +424,7 @@ ("data2vec-audio", "data2vec"), ("data2vec-text", "data2vec"), ("data2vec-vision", "data2vec"), + ("donut-swin", "donut"), ] ) diff --git a/src/transformers/models/auto/feature_extraction_auto.py b/src/transformers/models/auto/feature_extraction_auto.py index db581d03d8fb7..5c5f86d040c8f 100644 --- a/src/transformers/models/auto/feature_extraction_auto.py +++ b/src/transformers/models/auto/feature_extraction_auto.py @@ -46,6 +46,7 @@ ("deit", "DeiTFeatureExtractor"), ("detr", "DetrFeatureExtractor"), ("detr", "DetrFeatureExtractor"), + ("donut", "DonutFeatureExtractor"), ("dpt", "DPTFeatureExtractor"), ("flava", "FlavaFeatureExtractor"), ("glpn", "GLPNFeatureExtractor"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index bd4774c245b07..0e026cb48d0c0 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -56,6 +56,7 @@ ("deit", "DeiTModel"), ("detr", "DetrModel"), ("distilbert", "DistilBertModel"), + ("donut-swin", "DonutSwinModel"), ("dpr", "DPRQuestionEncoder"), ("dpt", "DPTModel"), ("electra", "ElectraModel"), diff --git a/src/transformers/models/auto/processing_auto.py b/src/transformers/models/auto/processing_auto.py index aed7b4b976137..c6f4fd98316a4 100644 --- a/src/transformers/models/auto/processing_auto.py +++ b/src/transformers/models/auto/processing_auto.py @@ -38,6 +38,7 @@ PROCESSOR_MAPPING_NAMES = OrderedDict( [ ("clip", "CLIPProcessor"), + ("donut", "DonutProcessor"), ("flava", "FlavaProcessor"), ("groupvit", "CLIPProcessor"), ("layoutlmv2", "LayoutLMv2Processor"), diff --git a/src/transformers/models/donut/__init__.py b/src/transformers/models/donut/__init__.py new file mode 100644 index 0000000000000..a01f6b11a9a99 --- /dev/null +++ b/src/transformers/models/donut/__init__.py @@ -0,0 +1,76 @@ +# flake8: noqa +# There's no way to ignore "F401 '...' imported but unused" warnings in this +# module, but to preserve other warnings. So, don't check this module at all. + +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available + + +_import_structure = { + "configuration_donut_swin": ["DONUT_SWIN_PRETRAINED_CONFIG_ARCHIVE_MAP", "DonutSwinConfig"], + "processing_donut": ["DonutProcessor"], +} + +try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["modeling_donut_swin"] = [ + "DONUT_SWIN_PRETRAINED_MODEL_ARCHIVE_LIST", + "DonutSwinModel", + "DonutSwinPreTrainedModel", + ] + +try: + if not is_vision_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["feature_extraction_donut"] = ["DonutFeatureExtractor"] + + +if TYPE_CHECKING: + from .configuration_donut_swin import DONUT_SWIN_PRETRAINED_CONFIG_ARCHIVE_MAP, DonutSwinConfig + from .processing_donut import DonutProcessor + + try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .modeling_donut_swin import ( + DONUT_SWIN_PRETRAINED_MODEL_ARCHIVE_LIST, + DonutSwinModel, + DonutSwinPreTrainedModel, + ) + + try: + if not is_vision_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .feature_extraction_donut import DonutFeatureExtractor + +else: + import sys + + sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) diff --git a/src/transformers/models/donut/configuration_donut_swin.py b/src/transformers/models/donut/configuration_donut_swin.py new file mode 100644 index 0000000000000..d3316bdc79f68 --- /dev/null +++ b/src/transformers/models/donut/configuration_donut_swin.py @@ -0,0 +1,140 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Donut Swin Transformer model configuration""" + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + +DONUT_SWIN_PRETRAINED_CONFIG_ARCHIVE_MAP = { + "naver-clova-ix/donut-base": "https://huggingface.co/naver-clova-ix/donut-base/resolve/main/config.json", + # See all Donut models at https://huggingface.co/models?filter=donut-swin +} + + +class DonutSwinConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`DonutSwinModel`]. It is used to instantiate a + Donut model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the Donut + [naver-clova-ix/donut-base](https://huggingface.co/naver-clova-ix/donut-base) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + image_size (`int`, *optional*, defaults to 224): + The size (resolution) of each image. + patch_size (`int`, *optional*, defaults to 4): + The size (resolution) of each patch. + num_channels (`int`, *optional*, defaults to 3): + The number of input channels. + embed_dim (`int`, *optional*, defaults to 96): + Dimensionality of patch embedding. + depths (`list(int)`, *optional*, defaults to [2, 2, 6, 2]): + Depth of each layer in the Transformer encoder. + num_heads (`list(int)`, *optional*, defaults to [3, 6, 12, 24]): + Number of attention heads in each layer of the Transformer encoder. + window_size (`int`, *optional*, defaults to 7): + Size of windows. + mlp_ratio (`float`, *optional*, defaults to 4.0): + Ratio of MLP hidden dimensionality to embedding dimensionality. + qkv_bias (`bool`, *optional*, defaults to True): + Whether or not a learnable bias should be added to the queries, keys and values. + hidden_dropout_prob (`float`, *optional*, defaults to 0.0): + The dropout probability for all fully connected layers in the embeddings and encoder. + attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + drop_path_rate (`float`, *optional*, defaults to 0.1): + Stochastic depth rate. + hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder. If string, `"gelu"`, `"relu"`, + `"selu"` and `"gelu_new"` are supported. + use_absolute_embeddings (`bool`, *optional*, defaults to False): + Whether or not to add absolute position embeddings to the patch embeddings. + patch_norm (`bool`, *optional*, defaults to True): + Whether or not to add layer normalization after patch embedding. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layer_norm_eps (`float`, *optional*, defaults to 1e-12): + The epsilon used by the layer normalization layers. + + Example: + + ```python + >>> from transformers import DonutSwinConfig, DonutSwinModel + + >>> # Initializing a Donut naver-clova-ix/donut-base style configuration + >>> configuration = DonutSwinConfig() + + >>> # Randomly initializing a model from the naver-clova-ix/donut-base style configuration + >>> model = DonutSwinModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + model_type = "donut-swin" + + attribute_map = { + "num_attention_heads": "num_heads", + "num_hidden_layers": "num_layers", + } + + def __init__( + self, + image_size=224, + patch_size=4, + num_channels=3, + embed_dim=96, + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 24], + window_size=7, + mlp_ratio=4.0, + qkv_bias=True, + hidden_dropout_prob=0.0, + attention_probs_dropout_prob=0.0, + drop_path_rate=0.1, + hidden_act="gelu", + use_absolute_embeddings=False, + patch_norm=True, + initializer_range=0.02, + layer_norm_eps=1e-5, + **kwargs + ): + super().__init__(**kwargs) + + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.embed_dim = embed_dim + self.depths = depths + self.num_layers = len(depths) + self.num_heads = num_heads + self.window_size = window_size + self.mlp_ratio = mlp_ratio + self.qkv_bias = qkv_bias + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.drop_path_rate = drop_path_rate + self.hidden_act = hidden_act + self.use_absolute_embeddings = use_absolute_embeddings + self.path_norm = patch_norm + self.layer_norm_eps = layer_norm_eps + self.initializer_range = initializer_range + # we set the hidden_size attribute in order to make Swin work with VisionEncoderDecoderModel + # this indicates the channel dimension after the last stage of the model + self.hidden_size = int(embed_dim * 2 ** (len(depths) - 1)) diff --git a/src/transformers/models/donut/convert_donut_to_pytorch.py b/src/transformers/models/donut/convert_donut_to_pytorch.py new file mode 100644 index 0000000000000..507f10cb776cf --- /dev/null +++ b/src/transformers/models/donut/convert_donut_to_pytorch.py @@ -0,0 +1,234 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert Donut checkpoints using the original `donut-python` library. URL: https://github.com/clovaai/donut""" + +import argparse + +import torch +from datasets import load_dataset + +from donut import DonutModel +from transformers import ( + DonutFeatureExtractor, + DonutProcessor, + DonutSwinConfig, + DonutSwinModel, + MBartConfig, + MBartForCausalLM, + VisionEncoderDecoderModel, + XLMRobertaTokenizerFast, +) + + +def get_configs(model): + original_config = model.config + + encoder_config = DonutSwinConfig( + image_size=original_config.input_size, + patch_size=4, + depths=original_config.encoder_layer, + num_heads=[4, 8, 16, 32], + window_size=original_config.window_size, + embed_dim=128, + ) + decoder_config = MBartConfig( + is_decoder=True, + is_encoder_decoder=False, + add_cross_attention=True, + decoder_layers=original_config.decoder_layer, + max_position_embeddings=original_config.max_position_embeddings, + vocab_size=len( + model.decoder.tokenizer + ), # several special tokens are added to the vocab of XLMRobertaTokenizer, see repo on the hub (added_tokens.json) + scale_embedding=True, + add_final_layer_norm=True, + ) + + return encoder_config, decoder_config + + +def rename_key(name): + if "encoder.model" in name: + name = name.replace("encoder.model", "encoder") + if "decoder.model" in name: + name = name.replace("decoder.model", "decoder") + if "patch_embed.proj" in name: + name = name.replace("patch_embed.proj", "embeddings.patch_embeddings.projection") + if "patch_embed.norm" in name: + name = name.replace("patch_embed.norm", "embeddings.norm") + if name.startswith("encoder"): + if "layers" in name: + name = "encoder." + name + if "attn.proj" in name: + name = name.replace("attn.proj", "attention.output.dense") + if "attn" in name and "mask" not in name: + name = name.replace("attn", "attention.self") + if "norm1" in name: + name = name.replace("norm1", "layernorm_before") + if "norm2" in name: + name = name.replace("norm2", "layernorm_after") + if "mlp.fc1" in name: + name = name.replace("mlp.fc1", "intermediate.dense") + if "mlp.fc2" in name: + name = name.replace("mlp.fc2", "output.dense") + + if name == "encoder.norm.weight": + name = "encoder.layernorm.weight" + if name == "encoder.norm.bias": + name = "encoder.layernorm.bias" + + return name + + +def convert_state_dict(orig_state_dict, model): + for key in orig_state_dict.copy().keys(): + val = orig_state_dict.pop(key) + + if "qkv" in key: + key_split = key.split(".") + layer_num = int(key_split[3]) + block_num = int(key_split[5]) + dim = model.encoder.encoder.layers[layer_num].blocks[block_num].attention.self.all_head_size + + if "weight" in key: + orig_state_dict[ + f"encoder.encoder.layers.{layer_num}.blocks.{block_num}.attention.self.query.weight" + ] = val[:dim, :] + orig_state_dict[ + f"encoder.encoder.layers.{layer_num}.blocks.{block_num}.attention.self.key.weight" + ] = val[dim : dim * 2, :] + orig_state_dict[ + f"encoder.encoder.layers.{layer_num}.blocks.{block_num}.attention.self.value.weight" + ] = val[-dim:, :] + else: + orig_state_dict[ + f"encoder.encoder.layers.{layer_num}.blocks.{block_num}.attention.self.query.bias" + ] = val[:dim] + orig_state_dict[ + f"encoder.encoder.layers.{layer_num}.blocks.{block_num}.attention.self.key.bias" + ] = val[dim : dim * 2] + orig_state_dict[ + f"encoder.encoder.layers.{layer_num}.blocks.{block_num}.attention.self.value.bias" + ] = val[-dim:] + elif "attn_mask" in key or key in ["encoder.model.norm.weight", "encoder.model.norm.bias"]: + # HuggingFace implementation doesn't use attn_mask buffer + # and model doesn't use final LayerNorms for the encoder + pass + else: + orig_state_dict[rename_key(key)] = val + + return orig_state_dict + + +def convert_donut_checkpoint(model_name, pytorch_dump_folder_path=None, push_to_hub=False): + # load original model + original_model = DonutModel.from_pretrained(model_name).eval() + + # load HuggingFace model + encoder_config, decoder_config = get_configs(original_model) + encoder = DonutSwinModel(encoder_config) + decoder = MBartForCausalLM(decoder_config) + model = VisionEncoderDecoderModel(encoder=encoder, decoder=decoder) + model.eval() + + state_dict = original_model.state_dict() + new_state_dict = convert_state_dict(state_dict, model) + model.load_state_dict(new_state_dict) + + # verify results on scanned document + dataset = load_dataset("hf-internal-testing/example-documents") + image = dataset["test"][0]["image"].convert("RGB") + + tokenizer = XLMRobertaTokenizerFast.from_pretrained(model_name, from_slow=True) + feature_extractor = DonutFeatureExtractor( + do_align_long_axis=original_model.config.align_long_axis, size=original_model.config.input_size[::-1] + ) + processor = DonutProcessor(feature_extractor, tokenizer) + pixel_values = processor(image, return_tensors="pt").pixel_values + + if model_name == "naver-clova-ix/donut-base-finetuned-docvqa": + task_prompt = "{user_input}" + question = "When is the coffee break?" + task_prompt = task_prompt.replace("{user_input}", question) + elif model_name == "naver-clova-ix/donut-base-finetuned-rvlcdip": + task_prompt = "" + elif model_name in [ + "naver-clova-ix/donut-base-finetuned-cord-v1", + "naver-clova-ix/donut-base-finetuned-cord-v1-2560", + ]: + task_prompt = "" + elif model_name == "naver-clova-ix/donut-base-finetuned-cord-v2": + task_prompt = "s_cord-v2>" + elif model_name == "naver-clova-ix/donut-base-finetuned-zhtrainticket": + task_prompt = "" + elif model_name in ["naver-clova-ix/donut-proto", "naver-clova-ix/donut-base"]: + # use a random prompt + task_prompt = "hello world" + else: + raise ValueError("Model name not supported") + prompt_tensors = original_model.decoder.tokenizer(task_prompt, add_special_tokens=False, return_tensors="pt")[ + "input_ids" + ] + + original_patch_embed = original_model.encoder.model.patch_embed(pixel_values) + patch_embeddings, _ = model.encoder.embeddings(pixel_values) + assert torch.allclose(original_patch_embed, patch_embeddings, atol=1e-3) + + # verify encoder hidden states + original_last_hidden_state = original_model.encoder(pixel_values) + last_hidden_state = model.encoder(pixel_values).last_hidden_state + assert torch.allclose(original_last_hidden_state, last_hidden_state, atol=1e-2) + + # verify decoder hidden states + original_logits = original_model(pixel_values, prompt_tensors, None).logits + logits = model(pixel_values, decoder_input_ids=prompt_tensors).logits + assert torch.allclose(original_logits, logits, atol=1e-3) + print("Looks ok!") + + if pytorch_dump_folder_path is not None: + print(f"Saving model and processor to {pytorch_dump_folder_path}") + model.save_pretrained(pytorch_dump_folder_path) + processor.save_pretrained(pytorch_dump_folder_path) + + if push_to_hub: + model.push_to_hub("nielsr/" + model_name.split("/")[-1], commit_message="Update model") + processor.push_to_hub("nielsr/" + model_name.split("/")[-1], commit_message="Update model") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument( + "--model_name", + default="naver-clova-ix/donut-base-finetuned-docvqa", + required=False, + type=str, + help="Name of the original model you'd like to convert.", + ) + parser.add_argument( + "--pytorch_dump_folder_path", + default=None, + required=False, + type=str, + help="Path to the output PyTorch model directory.", + ) + parser.add_argument( + "--push_to_hub", + action="store_true", + help="Whether or not to push the converted model and processor to the 🤗 hub.", + ) + + args = parser.parse_args() + convert_donut_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.push_to_hub) diff --git a/src/transformers/models/donut/feature_extraction_donut.py b/src/transformers/models/donut/feature_extraction_donut.py new file mode 100644 index 0000000000000..09bf3a6ad1c15 --- /dev/null +++ b/src/transformers/models/donut/feature_extraction_donut.py @@ -0,0 +1,208 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Feature extractor class for Donut.""" + +from typing import Optional, Tuple, Union + +import numpy as np +from PIL import Image, ImageOps + +from ...feature_extraction_utils import BatchFeature, FeatureExtractionMixin +from ...image_utils import ( + IMAGENET_STANDARD_MEAN, + IMAGENET_STANDARD_STD, + ImageFeatureExtractionMixin, + ImageInput, + is_torch_tensor, +) +from ...utils import TensorType, logging + + +logger = logging.get_logger(__name__) + + +class DonutFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin): + r""" + Constructs a Donut feature extractor. + + This feature extractor inherits from [`FeatureExtractionMixin`] which contains most of the main methods. Users + should refer to this superclass for more information regarding those methods. + + Args: + do_resize (`bool`, *optional*, defaults to `True`): + Whether to resize the shorter edge of the input to the minimum value of a certain `size`. + size (`Tuple(int)`, *optional*, defaults to [1920, 2560]): + Resize the shorter edge of the input to the minimum value of the given size. Should be a tuple of (width, + height). Only has an effect if `do_resize` is set to `True`. + resample (`int`, *optional*, defaults to `PIL.Image.BILINEAR`): + An optional resampling filter. This can be one of `PIL.Image.NEAREST`, `PIL.Image.BOX`, + `PIL.Image.BILINEAR`, `PIL.Image.HAMMING`, `PIL.Image.BICUBIC` or `PIL.Image.LANCZOS`. Only has an effect + if `do_resize` is set to `True`. + do_thumbnail (`bool`, *optional*, defaults to `True`): + Whether to thumbnail the input to the given `size`. + do_align_long_axis (`bool`, *optional*, defaults to `False`): + Whether to rotate the input if the height is greater than width. + do_pad (`bool`, *optional*, defaults to `True`): + Whether or not to pad the input to `size`. + do_normalize (`bool`, *optional*, defaults to `True`): + Whether or not to normalize the input with mean and standard deviation. + image_mean (`List[int]`, defaults to `[0.5, 0.5, 0.5]`): + The sequence of means for each channel, to be used when normalizing images. + image_std (`List[int]`, defaults to `[0.5, 0.5, 0.5]`): + The sequence of standard deviations for each channel, to be used when normalizing images. + + """ + + model_input_names = ["pixel_values"] + + def __init__( + self, + do_resize=True, + size=[1920, 2560], + resample=Image.BILINEAR, + do_thumbnail=True, + do_align_long_axis=False, + do_pad=True, + do_normalize=True, + image_mean=None, + image_std=None, + **kwargs + ): + super().__init__(**kwargs) + self.do_resize = do_resize + self.size = size + self.resample = resample + self.do_thumbnail = do_thumbnail + self.do_align_long_axis = do_align_long_axis + self.do_pad = do_pad + self.do_normalize = do_normalize + self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN + self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD + + def rotate_image(self, image, size): + if not isinstance(image, Image.Image): + image = self.to_pil_image(image) + + if (size[1] > size[0] and image.width > image.height) or (size[1] < size[0] and image.width < image.height): + image = self.rotate(image, angle=-90, expand=True) + + return image + + def thumbnail(self, image, size): + if not isinstance(image, Image.Image): + image = self.to_pil_image(image) + + image.thumbnail((size[0], size[1])) + + return image + + def pad(self, image: Image.Image, size: Tuple[int, int], random_padding: bool = False) -> Image.Image: + delta_width = size[0] - image.width + delta_height = size[1] - image.height + + if random_padding: + pad_width = np.random.randint(low=0, high=delta_width + 1) + pad_height = np.random.randint(low=0, high=delta_height + 1) + else: + pad_width = delta_width // 2 + pad_height = delta_height // 2 + + padding = (pad_width, pad_height, delta_width - pad_width, delta_height - pad_height) + return ImageOps.expand(image, padding) + + def __call__( + self, + images: ImageInput, + return_tensors: Optional[Union[str, TensorType]] = None, + random_padding=False, + **kwargs + ) -> BatchFeature: + """ + Main method to prepare for the model one or several image(s). + + + + NumPy arrays and PyTorch tensors are converted to PIL images when resizing, so the most efficient is to pass + PIL images. + + + + Args: + images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`): + The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch + tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a + number of channels, H and W are image height and width. + + random_padding (`bool`, *optional*, defaults to `False`): + Whether to randomly pad the input to `size`. + + return_tensors (`str` or [`~utils.TensorType`], *optional*, defaults to `'np'`): + If set, will return tensors of a particular framework. Acceptable values are: + + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return NumPy `np.ndarray` objects. + - `'jax'`: Return JAX `jnp.ndarray` objects. + + Returns: + [`BatchFeature`]: A [`BatchFeature`] with the following fields: + + - **pixel_values** -- Pixel values to be fed to a model, of shape (batch_size, num_channels, height, + width). + """ + # Input type checking for clearer error + valid_images = False + + # Check that images has a valid type + if isinstance(images, (Image.Image, np.ndarray)) or is_torch_tensor(images): + valid_images = True + elif isinstance(images, (list, tuple)): + if len(images) == 0 or isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]): + valid_images = True + + if not valid_images: + raise ValueError( + "Images must of type `PIL.Image.Image`, `np.ndarray` or `torch.Tensor` (single example), " + "`List[PIL.Image.Image]`, `List[np.ndarray]` or `List[torch.Tensor]` (batch of examples)." + ) + + is_batched = bool( + isinstance(images, (list, tuple)) + and (isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0])) + ) + + if not is_batched: + images = [images] + + # transformations (rotating + resizing + thumbnailing + padding + normalization) + if self.do_align_long_axis: + images = [self.rotate_image(image, self.size) for image in images] + if self.do_resize and self.size is not None: + images = [ + self.resize(image=image, size=min(self.size), resample=self.resample, default_to_square=False) + for image in images + ] + if self.do_thumbnail and self.size is not None: + images = [self.thumbnail(image=image, size=self.size) for image in images] + if self.do_pad and self.size is not None: + images = [self.pad(image=image, size=self.size, random_padding=random_padding) for image in images] + if self.do_normalize: + images = [self.normalize(image=image, mean=self.image_mean, std=self.image_std) for image in images] + + # return as BatchFeature + data = {"pixel_values": images} + encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors) + + return encoded_inputs diff --git a/src/transformers/models/donut/modeling_donut_swin.py b/src/transformers/models/donut/modeling_donut_swin.py new file mode 100644 index 0000000000000..78e5cc81c1988 --- /dev/null +++ b/src/transformers/models/donut/modeling_donut_swin.py @@ -0,0 +1,941 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch Donut Swin Transformer model. + +This implementation is identical to a regular Swin Transformer, without final layer norm on top of the final hidden +states.""" + +import collections.abc +import math +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn + +from ...activations import ACT2FN +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer +from ...utils import ( + ModelOutput, + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, +) +from .configuration_donut_swin import DonutSwinConfig + + +logger = logging.get_logger(__name__) + +# General docstring +_CONFIG_FOR_DOC = "DonutSwinConfig" +_FEAT_EXTRACTOR_FOR_DOC = "AutoFeatureExtractor" + +# Base docstring +_CHECKPOINT_FOR_DOC = "https://huggingface.co/naver-clova-ix/donut-base" +_EXPECTED_OUTPUT_SHAPE = [1, 49, 768] + +DONUT_SWIN_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "naver-clova-ix/donut-base", + # See all Donut Swin models at https://huggingface.co/models?filter=donut +] + + +@dataclass +# Copied from transformers.models.swin.modeling_swin.SwinEncoderOutput with Swin->DonutSwin +class DonutSwinEncoderOutput(ModelOutput): + """ + DonutSwin encoder's outputs, with potential hidden states and attentions. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of + shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of + shape `(batch_size, hidden_size, height, width)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to + include the spatial dimensions. + """ + + last_hidden_state: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + reshaped_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + + +@dataclass +# Copied from transformers.models.swin.modeling_swin.SwinModelOutput with Swin->DonutSwin +class DonutSwinModelOutput(ModelOutput): + """ + DonutSwin model's outputs that also contains a pooling of the last hidden states. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`, *optional*, returned when `add_pooling_layer=True` is passed): + Average pooling of the last layer hidden-state. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of + shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of + shape `(batch_size, hidden_size, height, width)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to + include the spatial dimensions. + """ + + last_hidden_state: torch.FloatTensor = None + pooler_output: Optional[torch.FloatTensor] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + reshaped_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + + +# Copied from transformers.models.swin.modeling_swin.window_partition +def window_partition(input_feature, window_size): + """ + Partitions the given input into windows. + """ + batch_size, height, width, num_channels = input_feature.shape + input_feature = input_feature.view( + batch_size, height // window_size, window_size, width // window_size, window_size, num_channels + ) + windows = input_feature.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, num_channels) + return windows + + +# Copied from transformers.models.swin.modeling_swin.window_reverse +def window_reverse(windows, window_size, height, width): + """ + Merges windows to produce higher resolution features. + """ + batch_size = math.floor(windows.shape[0] / (height * width / window_size / window_size)) + windows = windows.view(batch_size, height // window_size, width // window_size, window_size, window_size, -1) + windows = windows.permute(0, 1, 3, 2, 4, 5).contiguous().view(batch_size, height, width, -1) + return windows + + +# Copied from transformers.models.swin.modeling_swin.SwinEmbeddings with Swin->DonutSwin +class DonutSwinEmbeddings(nn.Module): + """ + Construct the patch and position embeddings. Optionally, also the mask token. + """ + + def __init__(self, config, use_mask_token=False): + super().__init__() + + self.patch_embeddings = DonutSwinPatchEmbeddings(config) + num_patches = self.patch_embeddings.num_patches + self.patch_grid = self.patch_embeddings.grid_size + self.mask_token = nn.Parameter(torch.zeros(1, 1, config.embed_dim)) if use_mask_token else None + + if config.use_absolute_embeddings: + self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.embed_dim)) + else: + self.position_embeddings = None + + self.norm = nn.LayerNorm(config.embed_dim) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward( + self, pixel_values: Optional[torch.FloatTensor], bool_masked_pos: Optional[torch.BoolTensor] = None + ) -> Tuple[torch.Tensor]: + embeddings, output_dimensions = self.patch_embeddings(pixel_values) + embeddings = self.norm(embeddings) + batch_size, seq_len, _ = embeddings.size() + + if bool_masked_pos is not None: + mask_tokens = self.mask_token.expand(batch_size, seq_len, -1) + # replace the masked visual tokens by mask_tokens + mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens) + embeddings = embeddings * (1.0 - mask) + mask_tokens * mask + + if self.position_embeddings is not None: + embeddings = embeddings + self.position_embeddings + + embeddings = self.dropout(embeddings) + + return embeddings, output_dimensions + + +# Copied from transformers.models.swin.modeling_swin.SwinPatchEmbeddings +class DonutSwinPatchEmbeddings(nn.Module): + """ + This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial + `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a + Transformer. + """ + + def __init__(self, config): + super().__init__() + image_size, patch_size = config.image_size, config.patch_size + num_channels, hidden_size = config.num_channels, config.embed_dim + image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size) + patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size) + num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.num_patches = num_patches + self.grid_size = (image_size[0] // patch_size[0], image_size[1] // patch_size[1]) + + self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size) + + def maybe_pad(self, pixel_values, height, width): + if width % self.patch_size[1] != 0: + pad_values = (0, self.patch_size[1] - width % self.patch_size[1]) + pixel_values = nn.functional.pad(pixel_values, pad_values) + if height % self.patch_size[0] != 0: + pad_values = (0, 0, 0, self.patch_size[0] - height % self.patch_size[0]) + pixel_values = nn.functional.pad(pixel_values, pad_values) + return pixel_values + + def forward(self, pixel_values: Optional[torch.FloatTensor]) -> Tuple[torch.Tensor, Tuple[int]]: + _, num_channels, height, width = pixel_values.shape + if num_channels != self.num_channels: + raise ValueError( + "Make sure that the channel dimension of the pixel values match with the one set in the configuration." + ) + # pad the input to be divisible by self.patch_size, if needed + pixel_values = self.maybe_pad(pixel_values, height, width) + embeddings = self.projection(pixel_values) + _, _, height, width = embeddings.shape + output_dimensions = (height, width) + embeddings = embeddings.flatten(2).transpose(1, 2) + + return embeddings, output_dimensions + + +# Copied from transformers.models.swin.modeling_swin.SwinPatchMerging +class DonutSwinPatchMerging(nn.Module): + """ + Patch Merging Layer. + + Args: + input_resolution (`Tuple[int]`): + Resolution of input feature. + dim (`int`): + Number of input channels. + norm_layer (`nn.Module`, *optional*, defaults to `nn.LayerNorm`): + Normalization layer class. + """ + + def __init__(self, input_resolution: Tuple[int], dim: int, norm_layer: nn.Module = nn.LayerNorm) -> None: + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def maybe_pad(self, input_feature, height, width): + should_pad = (height % 2 == 1) or (width % 2 == 1) + if should_pad: + pad_values = (0, 0, 0, width % 2, 0, height % 2) + input_feature = nn.functional.pad(input_feature, pad_values) + + return input_feature + + def forward(self, input_feature: torch.Tensor, input_dimensions: Tuple[int, int]) -> torch.Tensor: + height, width = input_dimensions + # `dim` is height * width + batch_size, dim, num_channels = input_feature.shape + + input_feature = input_feature.view(batch_size, height, width, num_channels) + # pad input to be disible by width and height, if needed + input_feature = self.maybe_pad(input_feature, height, width) + # [batch_size, height/2, width/2, num_channels] + input_feature_0 = input_feature[:, 0::2, 0::2, :] + # [batch_size, height/2, width/2, num_channels] + input_feature_1 = input_feature[:, 1::2, 0::2, :] + # [batch_size, height/2, width/2, num_channels] + input_feature_2 = input_feature[:, 0::2, 1::2, :] + # [batch_size, height/2, width/2, num_channels] + input_feature_3 = input_feature[:, 1::2, 1::2, :] + # batch_size height/2 width/2 4*num_channels + input_feature = torch.cat([input_feature_0, input_feature_1, input_feature_2, input_feature_3], -1) + input_feature = input_feature.view(batch_size, -1, 4 * num_channels) # batch_size height/2*width/2 4*C + + input_feature = self.norm(input_feature) + input_feature = self.reduction(input_feature) + + return input_feature + + +# Copied from transformers.models.swin.modeling_swin.drop_path +def drop_path(input, drop_prob=0.0, training=False, scale_by_keep=True): + """ + Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + + Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks, + however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the + layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the + argument. + """ + if drop_prob == 0.0 or not training: + return input + keep_prob = 1 - drop_prob + shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device) + random_tensor.floor_() # binarize + output = input.div(keep_prob) * random_tensor + return output + + +# Copied from transformers.models.swin.modeling_swin.SwinDropPath +class DonutSwinDropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob: Optional[float] = None) -> None: + super().__init__() + self.drop_prob = drop_prob + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return drop_path(x, self.drop_prob, self.training) + + def extra_repr(self) -> str: + return "p={}".format(self.drop_prob) + + +# Copied from transformers.models.swin.modeling_swin.SwinSelfAttention with Swin->DonutSwin +class DonutSwinSelfAttention(nn.Module): + def __init__(self, config, dim, num_heads): + super().__init__() + if dim % num_heads != 0: + raise ValueError( + f"The hidden size ({dim}) is not a multiple of the number of attention heads ({num_heads})" + ) + + self.num_attention_heads = num_heads + self.attention_head_size = int(dim / num_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + window_size = config.window_size + self.window_size = ( + window_size if isinstance(window_size, collections.abc.Iterable) else (window_size, window_size) + ) + + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1), num_heads) + ) + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) + coords_flatten = torch.flatten(coords, 1) + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] + relative_coords = relative_coords.permute(1, 2, 0).contiguous() + relative_coords[:, :, 0] += self.window_size[0] - 1 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) + self.register_buffer("relative_position_index", relative_position_index) + + self.query = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias) + self.key = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias) + self.value = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + batch_size, dim, num_channels = hidden_states.shape + mixed_query_layer = self.query(hidden_states) + + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + query_layer = self.transpose_for_scores(mixed_query_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)] + relative_position_bias = relative_position_bias.view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1 + ) + + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() + attention_scores = attention_scores + relative_position_bias.unsqueeze(0) + + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in DonutSwinModel forward() function) + mask_shape = attention_mask.shape[0] + attention_scores = attention_scores.view( + batch_size // mask_shape, mask_shape, self.num_attention_heads, dim, dim + ) + attention_scores = attention_scores + attention_mask.unsqueeze(1).unsqueeze(0) + attention_scores = attention_scores.view(-1, self.num_attention_heads, dim, dim) + + # Normalize the attention scores to probabilities. + attention_probs = nn.functional.softmax(attention_scores, dim=-1) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = attention_probs * head_mask + + context_layer = torch.matmul(attention_probs, value_layer) + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(new_context_layer_shape) + + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + + return outputs + + +# Copied from transformers.models.swin.modeling_swin.SwinSelfOutput +class DonutSwinSelfOutput(nn.Module): + def __init__(self, config, dim): + super().__init__() + self.dense = nn.Linear(dim, dim) + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + + return hidden_states + + +# Copied from transformers.models.swin.modeling_swin.SwinAttention with Swin->DonutSwin +class DonutSwinAttention(nn.Module): + def __init__(self, config, dim, num_heads): + super().__init__() + self.self = DonutSwinSelfAttention(config, dim, num_heads) + self.output = DonutSwinSelfOutput(config, dim) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads + ) + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len(heads) + self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + self_outputs = self.self(hidden_states, attention_mask, head_mask, output_attentions) + attention_output = self.output(self_outputs[0], hidden_states) + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + +# Copied from transformers.models.swin.modeling_swin.SwinIntermediate +class DonutSwinIntermediate(nn.Module): + def __init__(self, config, dim): + super().__init__() + self.dense = nn.Linear(dim, int(config.mlp_ratio * dim)) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +# Copied from transformers.models.swin.modeling_swin.SwinOutput +class DonutSwinOutput(nn.Module): + def __init__(self, config, dim): + super().__init__() + self.dense = nn.Linear(int(config.mlp_ratio * dim), dim) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + return hidden_states + + +# Copied from transformers.models.swin.modeling_swin.SwinLayer with Swin->DonutSwin +class DonutSwinLayer(nn.Module): + def __init__(self, config, dim, input_resolution, num_heads, shift_size=0): + super().__init__() + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.shift_size = shift_size + self.window_size = config.window_size + self.input_resolution = input_resolution + self.set_shift_and_window_size(input_resolution) + self.layernorm_before = nn.LayerNorm(dim, eps=config.layer_norm_eps) + self.attention = DonutSwinAttention(config, dim, num_heads) + self.drop_path = DonutSwinDropPath(config.drop_path_rate) if config.drop_path_rate > 0.0 else nn.Identity() + self.layernorm_after = nn.LayerNorm(dim, eps=config.layer_norm_eps) + self.intermediate = DonutSwinIntermediate(config, dim) + self.output = DonutSwinOutput(config, dim) + + def set_shift_and_window_size(self, input_resolution): + if min(input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(input_resolution) + + def get_attn_mask(self, height, width): + if self.shift_size > 0: + # calculate attention mask for SW-MSA + img_mask = torch.zeros((1, height, width, 1)) + height_slices = ( + slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None), + ) + width_slices = ( + slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None), + ) + count = 0 + for height_slice in height_slices: + for width_slice in width_slices: + img_mask[:, height_slice, width_slice, :] = count + count += 1 + + mask_windows = window_partition(img_mask, self.window_size) + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + else: + attn_mask = None + return attn_mask + + def maybe_pad(self, hidden_states, height, width): + pad_right = (self.window_size - width % self.window_size) % self.window_size + pad_bottom = (self.window_size - height % self.window_size) % self.window_size + pad_values = (0, 0, 0, pad_right, 0, pad_bottom) + hidden_states = nn.functional.pad(hidden_states, pad_values) + return hidden_states, pad_values + + def forward( + self, + hidden_states: torch.Tensor, + input_dimensions: Tuple[int, int], + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor, torch.Tensor]: + self.set_shift_and_window_size(input_dimensions) + height, width = input_dimensions + batch_size, _, channels = hidden_states.size() + shortcut = hidden_states + + hidden_states = self.layernorm_before(hidden_states) + hidden_states = hidden_states.view(batch_size, height, width, channels) + # pad hidden_states to multiples of window size + hidden_states, pad_values = self.maybe_pad(hidden_states, height, width) + + _, height_pad, width_pad, _ = hidden_states.shape + # cyclic shift + if self.shift_size > 0: + shifted_hidden_states = torch.roll(hidden_states, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_hidden_states = hidden_states + + # partition windows + hidden_states_windows = window_partition(shifted_hidden_states, self.window_size) + hidden_states_windows = hidden_states_windows.view(-1, self.window_size * self.window_size, channels) + attn_mask = self.get_attn_mask(height_pad, width_pad) + if attn_mask is not None: + attn_mask = attn_mask.to(hidden_states_windows.device) + + attention_outputs = self.attention( + hidden_states_windows, attn_mask, head_mask, output_attentions=output_attentions + ) + + attention_output = attention_outputs[0] + + attention_windows = attention_output.view(-1, self.window_size, self.window_size, channels) + shifted_windows = window_reverse(attention_windows, self.window_size, height_pad, width_pad) + + # reverse cyclic shift + if self.shift_size > 0: + attention_windows = torch.roll(shifted_windows, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + attention_windows = shifted_windows + + was_padded = pad_values[3] > 0 or pad_values[5] > 0 + if was_padded: + attention_windows = attention_windows[:, :height, :width, :].contiguous() + + attention_windows = attention_windows.view(batch_size, height * width, channels) + + hidden_states = shortcut + self.drop_path(attention_windows) + + layer_output = self.layernorm_after(hidden_states) + layer_output = self.intermediate(layer_output) + layer_output = hidden_states + self.output(layer_output) + + layer_outputs = (layer_output, attention_outputs[1]) if output_attentions else (layer_output,) + return layer_outputs + + +# Copied from transformers.models.swin.modeling_swin.SwinStage with Swin->DonutSwin +class DonutSwinStage(nn.Module): + def __init__(self, config, dim, input_resolution, depth, num_heads, drop_path, downsample): + super().__init__() + self.config = config + self.dim = dim + self.blocks = nn.ModuleList( + [ + DonutSwinLayer( + config=config, + dim=dim, + input_resolution=input_resolution, + num_heads=num_heads, + shift_size=0 if (i % 2 == 0) else config.window_size // 2, + ) + for i in range(depth) + ] + ) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(input_resolution, dim=dim, norm_layer=nn.LayerNorm) + else: + self.downsample = None + + self.pointing = False + + def forward( + self, + hidden_states: torch.Tensor, + input_dimensions: Tuple[int, int], + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + height, width = input_dimensions + for i, layer_module in enumerate(self.blocks): + + layer_head_mask = head_mask[i] if head_mask is not None else None + + layer_outputs = layer_module(hidden_states, input_dimensions, layer_head_mask, output_attentions) + + hidden_states = layer_outputs[0] + + if self.downsample is not None: + height_downsampled, width_downsampled = (height + 1) // 2, (width + 1) // 2 + output_dimensions = (height, width, height_downsampled, width_downsampled) + hidden_states = self.downsample(layer_outputs[0], input_dimensions) + else: + output_dimensions = (height, width, height, width) + + stage_outputs = (hidden_states, output_dimensions) + + if output_attentions: + stage_outputs += layer_outputs[1:] + return stage_outputs + + +# Copied from transformers.models.swin.modeling_swin.SwinEncoder with Swin->DonutSwin +class DonutSwinEncoder(nn.Module): + def __init__(self, config, grid_size): + super().__init__() + self.num_layers = len(config.depths) + self.config = config + dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths))] + self.layers = nn.ModuleList( + [ + DonutSwinStage( + config=config, + dim=int(config.embed_dim * 2**i_layer), + input_resolution=(grid_size[0] // (2**i_layer), grid_size[1] // (2**i_layer)), + depth=config.depths[i_layer], + num_heads=config.num_heads[i_layer], + drop_path=dpr[sum(config.depths[:i_layer]) : sum(config.depths[: i_layer + 1])], + downsample=DonutSwinPatchMerging if (i_layer < self.num_layers - 1) else None, + ) + for i_layer in range(self.num_layers) + ] + ) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + input_dimensions: Tuple[int, int], + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + output_hidden_states: Optional[bool] = False, + return_dict: Optional[bool] = True, + ) -> Union[Tuple, DonutSwinEncoderOutput]: + all_input_dimensions = () + all_hidden_states = () if output_hidden_states else None + all_reshaped_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + if output_hidden_states: + batch_size, _, hidden_size = hidden_states.shape + # rearrange b (h w) c -> b c h w + reshaped_hidden_state = hidden_states.view(batch_size, *input_dimensions, hidden_size) + reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2) + all_hidden_states += (hidden_states,) + all_reshaped_hidden_states += (reshaped_hidden_state,) + + for i, layer_module in enumerate(self.layers): + layer_head_mask = head_mask[i] if head_mask is not None else None + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs, output_attentions) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(layer_module), hidden_states, input_dimensions, layer_head_mask + ) + else: + layer_outputs = layer_module(hidden_states, input_dimensions, layer_head_mask, output_attentions) + + hidden_states = layer_outputs[0] + output_dimensions = layer_outputs[1] + + input_dimensions = (output_dimensions[-2], output_dimensions[-1]) + all_input_dimensions += (input_dimensions,) + + if output_hidden_states: + batch_size, _, hidden_size = hidden_states.shape + # rearrange b (h w) c -> b c h w + reshaped_hidden_state = hidden_states.view(batch_size, *input_dimensions, hidden_size) + reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2) + all_hidden_states += (hidden_states,) + all_reshaped_hidden_states += (reshaped_hidden_state,) + + if output_attentions: + all_self_attentions += layer_outputs[2:] + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) + + return DonutSwinEncoderOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + reshaped_hidden_states=all_reshaped_hidden_states, + ) + + +# Copied from transformers.models.swin.modeling_swin.SwinPreTrainedModel with Swin->DonutSwin +class DonutSwinPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = DonutSwinConfig + base_model_prefix = "swin" + main_input_name = "pixel_values" + supports_gradient_checkpointing = True + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, (nn.Linear, nn.Conv2d)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, DonutSwinEncoder): + module.gradient_checkpointing = value + + +SWIN_START_DOCSTRING = r""" + This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use + it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and + behavior. + + Parameters: + config ([`DonutSwinConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +SWIN_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Pixel values can be obtained using [`AutoFeatureExtractor`]. See + [`AutoFeatureExtractor.__call__`] for details. + head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare Donut Swin Model transformer outputting raw hidden-states without any specific head on top.", + SWIN_START_DOCSTRING, +) +class DonutSwinModel(DonutSwinPreTrainedModel): + def __init__(self, config, add_pooling_layer=True, use_mask_token=False): + super().__init__(config) + self.config = config + self.num_layers = len(config.depths) + self.num_features = int(config.embed_dim * 2 ** (self.num_layers - 1)) + + self.embeddings = DonutSwinEmbeddings(config, use_mask_token=use_mask_token) + self.encoder = DonutSwinEncoder(config, self.embeddings.patch_grid) + + self.pooler = nn.AdaptiveAvgPool1d(1) if add_pooling_layer else None + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embeddings.patch_embeddings + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + @add_start_docstrings_to_model_forward(SWIN_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + processor_class=_FEAT_EXTRACTOR_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=DonutSwinModelOutput, + config_class=_CONFIG_FOR_DOC, + modality="vision", + expected_output=_EXPECTED_OUTPUT_SHAPE, + ) + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + bool_masked_pos: Optional[torch.BoolTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, DonutSwinModelOutput]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, len(self.config.depths)) + + embedding_output, input_dimensions = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos) + + encoder_outputs = self.encoder( + embedding_output, + input_dimensions, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = encoder_outputs[0] + + pooled_output = None + if self.pooler is not None: + pooled_output = self.pooler(sequence_output.transpose(1, 2)) + pooled_output = torch.flatten(pooled_output, 1) + + if not return_dict: + output = (sequence_output, pooled_output) + encoder_outputs[1:] + + return output + + return DonutSwinModelOutput( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + reshaped_hidden_states=encoder_outputs.reshaped_hidden_states, + ) diff --git a/src/transformers/models/donut/processing_donut.py b/src/transformers/models/donut/processing_donut.py new file mode 100644 index 0000000000000..1b00d894bd087 --- /dev/null +++ b/src/transformers/models/donut/processing_donut.py @@ -0,0 +1,156 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Processor class for Donut. +""" +import re +import warnings +from contextlib import contextmanager + +from ...processing_utils import ProcessorMixin + + +class DonutProcessor(ProcessorMixin): + r""" + Constructs a Donut processor which wraps a Donut feature extractor and an XLMRoBERTa tokenizer into a single + processor. + + [`DonutProcessor`] offers all the functionalities of [`DonutFeatureExtractor`] and + [`XLMRobertaTokenizer`/`XLMRobertaTokenizerFast`]. See the [`~DonutProcessor.__call__`] and + [`~DonutProcessor.decode`] for more information. + + Args: + feature_extractor ([`DonutFeatureExtractor`]): + An instance of [`DonutFeatureExtractor`]. The feature extractor is a required input. + tokenizer ([`XLMRobertaTokenizer`/`XLMRobertaTokenizerFast`]): + An instance of [`XLMRobertaTokenizer`/`XLMRobertaTokenizerFast`]. The tokenizer is a required input. + """ + feature_extractor_class = "AutoFeatureExtractor" + tokenizer_class = "AutoTokenizer" + + def __init__(self, feature_extractor, tokenizer): + super().__init__(feature_extractor, tokenizer) + self.current_processor = self.feature_extractor + self._in_target_context_manager = False + + def __call__(self, *args, **kwargs): + """ + When used in normal mode, this method forwards all its arguments to AutoFeatureExtractor's + [`~AutoFeatureExtractor.__call__`] and returns its output. If used in the context + [`~DonutProcessor.as_target_processor`] this method forwards all its arguments to DonutTokenizer's + [`~DonutTokenizer.__call__`]. Please refer to the doctsring of the above two methods for more information. + """ + # For backward compatibility + if self._in_target_context_manager: + return self.current_processor(*args, **kwargs) + + images = kwargs.pop("images", None) + text = kwargs.pop("text", None) + if len(args) > 0: + images = args[0] + args = args[1:] + + if images is None and text is None: + raise ValueError("You need to specify either an `images` or `text` input to process.") + + if images is not None: + inputs = self.feature_extractor(images, *args, **kwargs) + if text is not None: + encodings = self.tokenizer(text, **kwargs) + + if text is None: + return inputs + elif images is None: + return encodings + else: + inputs["labels"] = encodings["input_ids"] + return inputs + + def batch_decode(self, *args, **kwargs): + """ + This method forwards all its arguments to DonutTokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please refer + to the docstring of this method for more information. + """ + return self.tokenizer.batch_decode(*args, **kwargs) + + def decode(self, *args, **kwargs): + """ + This method forwards all its arguments to DonutTokenizer's [`~PreTrainedTokenizer.decode`]. Please refer to the + docstring of this method for more information. + """ + return self.tokenizer.decode(*args, **kwargs) + + @contextmanager + def as_target_processor(self): + """ + Temporarily sets the tokenizer for processing the input. Useful for encoding the labels when fine-tuning TrOCR. + """ + warnings.warn( + "`as_target_processor` is deprecated and will be removed in v5 of Transformers. You can process your " + "labels by using the argument `text` of the regular `__call__` method (either in the same call as " + "your images inputs, or in a separate call." + ) + self._in_target_context_manager = True + self.current_processor = self.tokenizer + yield + self.current_processor = self.feature_extractor + self._in_target_context_manager = False + + def token2json(self, tokens, is_inner_value=False): + """ + Convert a (generated) token sequence into an ordered JSON format. + """ + output = dict() + + while tokens: + start_token = re.search(r"", tokens, re.IGNORECASE) + if start_token is None: + break + key = start_token.group(1) + end_token = re.search(rf"", tokens, re.IGNORECASE) + start_token = start_token.group() + if end_token is None: + tokens = tokens.replace(start_token, "") + else: + end_token = end_token.group() + start_token_escaped = re.escape(start_token) + end_token_escaped = re.escape(end_token) + content = re.search(f"{start_token_escaped}(.*?){end_token_escaped}", tokens, re.IGNORECASE) + if content is not None: + content = content.group(1).strip() + if r""): + leaf = leaf.strip() + if leaf in self.tokenizer.get_added_vocab() and leaf[0] == "<" and leaf[-2:] == "/>": + leaf = leaf[1:-2] # for categorical special tokens + output[key].append(leaf) + if len(output[key]) == 1: + output[key] = output[key][0] + + tokens = tokens[tokens.find(end_token) + len(end_token) :].strip() + if tokens[:6] == r"": # non-leaf nodes + return [output] + self.token2json(tokens[6:], is_inner_value=True) + + if len(output): + return [output] if is_inner_value else output + else: + return [] if is_inner_value else {"text_sequence": tokens} diff --git a/src/transformers/models/vision_encoder_decoder/convert_trocr_unilm_to_pytorch.py b/src/transformers/models/trocr/convert_trocr_unilm_to_pytorch.py similarity index 100% rename from src/transformers/models/vision_encoder_decoder/convert_trocr_unilm_to_pytorch.py rename to src/transformers/models/trocr/convert_trocr_unilm_to_pytorch.py diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index d636be655af28..96a93ecae942a 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -1682,6 +1682,23 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) +DONUT_SWIN_PRETRAINED_MODEL_ARCHIVE_LIST = None + + +class DonutSwinModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class DonutSwinPreTrainedModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + DPR_CONTEXT_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST = None diff --git a/src/transformers/utils/dummy_vision_objects.py b/src/transformers/utils/dummy_vision_objects.py index 30228e022222b..fa30432070a37 100644 --- a/src/transformers/utils/dummy_vision_objects.py +++ b/src/transformers/utils/dummy_vision_objects.py @@ -52,6 +52,13 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["vision"]) +class DonutFeatureExtractor(metaclass=DummyObject): + _backends = ["vision"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["vision"]) + + class DPTFeatureExtractor(metaclass=DummyObject): _backends = ["vision"] diff --git a/src/transformers/utils/fx.py b/src/transformers/utils/fx.py index 990f278b0d506..3c3babd403778 100644 --- a/src/transformers/utils/fx.py +++ b/src/transformers/utils/fx.py @@ -105,6 +105,7 @@ def _generate_supported_model_class_names( "deberta", "deberta-v2", "distilbert", + "donut-swin", "electra", "gpt2", "gpt_neo", diff --git a/tests/models/donut/__init__.py b/tests/models/donut/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/models/donut/test_feature_extraction_donut.py b/tests/models/donut/test_feature_extraction_donut.py new file mode 100644 index 0000000000000..38ccbf2075a9b --- /dev/null +++ b/tests/models/donut/test_feature_extraction_donut.py @@ -0,0 +1,203 @@ +# coding=utf-8 +# Copyright 2022 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import unittest + +import numpy as np + +from transformers.testing_utils import require_torch, require_vision +from transformers.utils import is_torch_available, is_vision_available + +from ...test_feature_extraction_common import FeatureExtractionSavingTestMixin, prepare_image_inputs + + +if is_torch_available(): + import torch + +if is_vision_available(): + from PIL import Image + + from transformers import DonutFeatureExtractor + + +class DonutFeatureExtractionTester(unittest.TestCase): + def __init__( + self, + parent, + batch_size=7, + num_channels=3, + image_size=18, + min_resolution=30, + max_resolution=400, + do_resize=True, + size=[20, 18], + do_thumbnail=True, + do_align_axis=False, + do_pad=True, + do_normalize=True, + image_mean=[0.5, 0.5, 0.5], + image_std=[0.5, 0.5, 0.5], + ): + self.parent = parent + self.batch_size = batch_size + self.num_channels = num_channels + self.image_size = image_size + self.min_resolution = min_resolution + self.max_resolution = max_resolution + self.do_resize = do_resize + self.size = size + self.do_thumbnail = do_thumbnail + self.do_align_axis = do_align_axis + self.do_pad = do_pad + self.do_normalize = do_normalize + self.image_mean = image_mean + self.image_std = image_std + + def prepare_feat_extract_dict(self): + return { + "do_resize": self.do_resize, + "size": self.size, + "do_thumbnail": self.do_thumbnail, + "do_align_long_axis": self.do_align_axis, + "do_pad": self.do_pad, + "do_normalize": self.do_normalize, + "image_mean": self.image_mean, + "image_std": self.image_std, + } + + +@require_torch +@require_vision +class DonutFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase): + + feature_extraction_class = DonutFeatureExtractor if is_vision_available() else None + + def setUp(self): + self.feature_extract_tester = DonutFeatureExtractionTester(self) + + @property + def feat_extract_dict(self): + return self.feature_extract_tester.prepare_feat_extract_dict() + + def test_feat_extract_properties(self): + feature_extractor = self.feature_extraction_class(**self.feat_extract_dict) + self.assertTrue(hasattr(feature_extractor, "do_resize")) + self.assertTrue(hasattr(feature_extractor, "size")) + self.assertTrue(hasattr(feature_extractor, "do_thumbnail")) + self.assertTrue(hasattr(feature_extractor, "do_align_long_axis")) + self.assertTrue(hasattr(feature_extractor, "do_pad")) + self.assertTrue(hasattr(feature_extractor, "do_normalize")) + self.assertTrue(hasattr(feature_extractor, "image_mean")) + self.assertTrue(hasattr(feature_extractor, "image_std")) + + def test_batch_feature(self): + pass + + def test_call_pil(self): + # Initialize feature_extractor + feature_extractor = self.feature_extraction_class(**self.feat_extract_dict) + # create random PIL images + image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False) + for image in image_inputs: + self.assertIsInstance(image, Image.Image) + + # Test not batched input + encoded_images = feature_extractor(image_inputs[0], return_tensors="pt").pixel_values + self.assertEqual( + encoded_images.shape, + ( + 1, + self.feature_extract_tester.num_channels, + self.feature_extract_tester.size[1], + self.feature_extract_tester.size[0], + ), + ) + + # Test batched + encoded_images = feature_extractor(image_inputs, return_tensors="pt").pixel_values + self.assertEqual( + encoded_images.shape, + ( + self.feature_extract_tester.batch_size, + self.feature_extract_tester.num_channels, + self.feature_extract_tester.size[1], + self.feature_extract_tester.size[0], + ), + ) + + def test_call_numpy(self): + # Initialize feature_extractor + feature_extractor = self.feature_extraction_class(**self.feat_extract_dict) + # create random numpy tensors + image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False, numpify=True) + for image in image_inputs: + self.assertIsInstance(image, np.ndarray) + + # Test not batched input + encoded_images = feature_extractor(image_inputs[0], return_tensors="pt").pixel_values + self.assertEqual( + encoded_images.shape, + ( + 1, + self.feature_extract_tester.num_channels, + self.feature_extract_tester.size[1], + self.feature_extract_tester.size[0], + ), + ) + + # Test batched + encoded_images = feature_extractor(image_inputs, return_tensors="pt").pixel_values + self.assertEqual( + encoded_images.shape, + ( + self.feature_extract_tester.batch_size, + self.feature_extract_tester.num_channels, + self.feature_extract_tester.size[1], + self.feature_extract_tester.size[0], + ), + ) + + def test_call_pytorch(self): + # Initialize feature_extractor + feature_extractor = self.feature_extraction_class(**self.feat_extract_dict) + # create random PyTorch tensors + image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False, torchify=True) + for image in image_inputs: + self.assertIsInstance(image, torch.Tensor) + + # Test not batched input + encoded_images = feature_extractor(image_inputs[0], return_tensors="pt").pixel_values + self.assertEqual( + encoded_images.shape, + ( + 1, + self.feature_extract_tester.num_channels, + self.feature_extract_tester.size[1], + self.feature_extract_tester.size[0], + ), + ) + + # Test batched + encoded_images = feature_extractor(image_inputs, return_tensors="pt").pixel_values + self.assertEqual( + encoded_images.shape, + ( + self.feature_extract_tester.batch_size, + self.feature_extract_tester.num_channels, + self.feature_extract_tester.size[1], + self.feature_extract_tester.size[0], + ), + ) diff --git a/tests/models/donut/test_modeling_donut_swin.py b/tests/models/donut/test_modeling_donut_swin.py new file mode 100644 index 0000000000000..f909d961880a9 --- /dev/null +++ b/tests/models/donut/test_modeling_donut_swin.py @@ -0,0 +1,464 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Testing suite for the PyTorch Donut Swin model. """ + +import collections +import inspect +import os +import pickle +import tempfile +import unittest + +from transformers import DonutSwinConfig +from transformers.testing_utils import require_torch, slow, torch_device +from transformers.utils import is_torch_available, is_torch_fx_available + +from ...test_configuration_common import ConfigTester +from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor + + +if is_torch_available(): + import torch + from torch import nn + + from transformers import DonutSwinModel + from transformers.models.donut.modeling_donut_swin import DONUT_SWIN_PRETRAINED_MODEL_ARCHIVE_LIST + +if is_torch_fx_available(): + from transformers.utils.fx import symbolic_trace + + +class DonutSwinModelTester: + def __init__( + self, + parent, + batch_size=13, + image_size=32, + patch_size=2, + num_channels=3, + embed_dim=16, + depths=[1, 2, 1], + num_heads=[2, 2, 4], + window_size=2, + mlp_ratio=2.0, + qkv_bias=True, + hidden_dropout_prob=0.0, + attention_probs_dropout_prob=0.0, + drop_path_rate=0.1, + hidden_act="gelu", + use_absolute_embeddings=False, + patch_norm=True, + initializer_range=0.02, + layer_norm_eps=1e-5, + is_training=True, + scope=None, + use_labels=True, + type_sequence_label_size=10, + encoder_stride=8, + ): + self.parent = parent + self.batch_size = batch_size + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.embed_dim = embed_dim + self.depths = depths + self.num_heads = num_heads + self.window_size = window_size + self.mlp_ratio = mlp_ratio + self.qkv_bias = qkv_bias + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.drop_path_rate = drop_path_rate + self.hidden_act = hidden_act + self.use_absolute_embeddings = use_absolute_embeddings + self.patch_norm = patch_norm + self.layer_norm_eps = layer_norm_eps + self.initializer_range = initializer_range + self.is_training = is_training + self.scope = scope + self.use_labels = use_labels + self.type_sequence_label_size = type_sequence_label_size + self.encoder_stride = encoder_stride + + def prepare_config_and_inputs(self): + pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size]) + + labels = None + if self.use_labels: + labels = ids_tensor([self.batch_size], self.type_sequence_label_size) + + config = self.get_config() + + return config, pixel_values, labels + + def get_config(self): + return DonutSwinConfig( + image_size=self.image_size, + patch_size=self.patch_size, + num_channels=self.num_channels, + embed_dim=self.embed_dim, + depths=self.depths, + num_heads=self.num_heads, + window_size=self.window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=self.qkv_bias, + hidden_dropout_prob=self.hidden_dropout_prob, + attention_probs_dropout_prob=self.attention_probs_dropout_prob, + drop_path_rate=self.drop_path_rate, + hidden_act=self.hidden_act, + use_absolute_embeddings=self.use_absolute_embeddings, + path_norm=self.patch_norm, + layer_norm_eps=self.layer_norm_eps, + initializer_range=self.initializer_range, + encoder_stride=self.encoder_stride, + ) + + def create_and_check_model(self, config, pixel_values, labels): + model = DonutSwinModel(config=config) + model.to(torch_device) + model.eval() + result = model(pixel_values) + + expected_seq_len = ((config.image_size // config.patch_size) ** 2) // (4 ** (len(config.depths) - 1)) + expected_dim = int(config.embed_dim * 2 ** (len(config.depths) - 1)) + + self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, expected_seq_len, expected_dim)) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + ( + config, + pixel_values, + labels, + ) = config_and_inputs + inputs_dict = {"pixel_values": pixel_values} + return config, inputs_dict + + +@require_torch +class DonutSwinModelTest(ModelTesterMixin, unittest.TestCase): + + all_model_classes = (DonutSwinModel,) if is_torch_available() else () + fx_compatible = True + + test_pruning = False + test_resize_embeddings = False + test_head_masking = False + + def setUp(self): + self.model_tester = DonutSwinModelTester(self) + self.config_tester = ConfigTester(self, config_class=DonutSwinConfig, embed_dim=37) + + def test_config(self): + self.create_and_test_config_common_properties() + self.config_tester.create_and_test_config_to_json_string() + self.config_tester.create_and_test_config_to_json_file() + self.config_tester.create_and_test_config_from_and_save_pretrained() + self.config_tester.create_and_test_config_with_num_labels() + self.config_tester.check_config_can_be_init_without_params() + self.config_tester.check_config_arguments_init() + + def create_and_test_config_common_properties(self): + return + + def test_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model(*config_and_inputs) + + def test_inputs_embeds(self): + # DonutSwin does not use inputs_embeds + pass + + def test_model_common_attributes(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + self.assertIsInstance(model.get_input_embeddings(), (nn.Module)) + x = model.get_output_embeddings() + self.assertTrue(x is None or isinstance(x, nn.Linear)) + + def test_forward_signature(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + signature = inspect.signature(model.forward) + # signature.parameters is an OrderedDict => so arg_names order is deterministic + arg_names = [*signature.parameters.keys()] + + expected_arg_names = ["pixel_values"] + self.assertListEqual(arg_names[:1], expected_arg_names) + + def test_attention_outputs(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.return_dict = True + + for model_class in self.all_model_classes: + inputs_dict["output_attentions"] = True + inputs_dict["output_hidden_states"] = False + config.return_dict = True + model = model_class(config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + attentions = outputs.attentions + expected_num_attentions = len(self.model_tester.depths) + self.assertEqual(len(attentions), expected_num_attentions) + + # check that output_attentions also work using config + del inputs_dict["output_attentions"] + config.output_attentions = True + window_size_squared = config.window_size**2 + model = model_class(config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + attentions = outputs.attentions + self.assertEqual(len(attentions), expected_num_attentions) + + self.assertListEqual( + list(attentions[0].shape[-3:]), + [self.model_tester.num_heads[0], window_size_squared, window_size_squared], + ) + out_len = len(outputs) + + # Check attention is always last and order is fine + inputs_dict["output_attentions"] = True + inputs_dict["output_hidden_states"] = True + model = model_class(config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + + if hasattr(self.model_tester, "num_hidden_states_types"): + added_hidden_states = self.model_tester.num_hidden_states_types + else: + # also another +1 for reshaped_hidden_states + added_hidden_states = 2 + self.assertEqual(out_len + added_hidden_states, len(outputs)) + + self_attentions = outputs.attentions + + self.assertEqual(len(self_attentions), expected_num_attentions) + + self.assertListEqual( + list(self_attentions[0].shape[-3:]), + [self.model_tester.num_heads[0], window_size_squared, window_size_squared], + ) + + def check_hidden_states_output(self, inputs_dict, config, model_class, image_size): + model = model_class(config) + model.to(torch_device) + model.eval() + + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + + hidden_states = outputs.hidden_states + + expected_num_layers = getattr( + self.model_tester, "expected_num_hidden_layers", len(self.model_tester.depths) + 1 + ) + self.assertEqual(len(hidden_states), expected_num_layers) + + # DonutSwin has a different seq_length + patch_size = ( + config.patch_size + if isinstance(config.patch_size, collections.abc.Iterable) + else (config.patch_size, config.patch_size) + ) + + num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) + + self.assertListEqual( + list(hidden_states[0].shape[-2:]), + [num_patches, self.model_tester.embed_dim], + ) + + reshaped_hidden_states = outputs.reshaped_hidden_states + self.assertEqual(len(reshaped_hidden_states), expected_num_layers) + + batch_size, num_channels, height, width = reshaped_hidden_states[0].shape + reshaped_hidden_states = ( + reshaped_hidden_states[0].view(batch_size, num_channels, height * width).permute(0, 2, 1) + ) + self.assertListEqual( + list(reshaped_hidden_states.shape[-2:]), + [num_patches, self.model_tester.embed_dim], + ) + + def test_hidden_states_output(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + image_size = ( + self.model_tester.image_size + if isinstance(self.model_tester.image_size, collections.abc.Iterable) + else (self.model_tester.image_size, self.model_tester.image_size) + ) + + for model_class in self.all_model_classes: + inputs_dict["output_hidden_states"] = True + self.check_hidden_states_output(inputs_dict, config, model_class, image_size) + + # check that output_hidden_states also work using config + del inputs_dict["output_hidden_states"] + config.output_hidden_states = True + + self.check_hidden_states_output(inputs_dict, config, model_class, image_size) + + def test_hidden_states_output_with_padding(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.patch_size = 3 + + image_size = ( + self.model_tester.image_size + if isinstance(self.model_tester.image_size, collections.abc.Iterable) + else (self.model_tester.image_size, self.model_tester.image_size) + ) + patch_size = ( + config.patch_size + if isinstance(config.patch_size, collections.abc.Iterable) + else (config.patch_size, config.patch_size) + ) + + padded_height = image_size[0] + patch_size[0] - (image_size[0] % patch_size[0]) + padded_width = image_size[1] + patch_size[1] - (image_size[1] % patch_size[1]) + + for model_class in self.all_model_classes: + inputs_dict["output_hidden_states"] = True + self.check_hidden_states_output(inputs_dict, config, model_class, (padded_height, padded_width)) + + # check that output_hidden_states also work using config + del inputs_dict["output_hidden_states"] + config.output_hidden_states = True + self.check_hidden_states_output(inputs_dict, config, model_class, (padded_height, padded_width)) + + @slow + def test_model_from_pretrained(self): + for model_name in DONUT_SWIN_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: + model = DonutSwinModel.from_pretrained(model_name) + self.assertIsNotNone(model) + + def test_initialization(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + configs_no_init = _config_zero_init(config) + for model_class in self.all_model_classes: + model = model_class(config=configs_no_init) + for name, param in model.named_parameters(): + if "embeddings" not in name and param.requires_grad: + self.assertIn( + ((param.data.mean() * 1e9).round() / 1e9).item(), + [0.0, 1.0], + msg=f"Parameter {name} of model {model_class} seems not properly initialized", + ) + + def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=False): + if not is_torch_fx_available() or not self.fx_compatible: + return + + configs_no_init = _config_zero_init(config) # To be sure we have no Nan + configs_no_init.return_dict = False + + for model_class in self.all_model_classes: + model = model_class(config=configs_no_init) + model.to(torch_device) + model.eval() + inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=output_loss) + + try: + if model.config.is_encoder_decoder: + model.config.use_cache = False # FSTM still requires this hack -> FSTM should probably be refactored similar to BART afterward + labels = inputs.get("labels", None) + input_names = ["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask"] + if labels is not None: + input_names.append("labels") + + filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names} + input_names = list(filtered_inputs.keys()) + + model_output = model(**filtered_inputs) + + traced_model = symbolic_trace(model, input_names) + traced_output = traced_model(**filtered_inputs) + else: + input_names = ["input_ids", "attention_mask", "token_type_ids", "pixel_values"] + + labels = inputs.get("labels", None) + start_positions = inputs.get("start_positions", None) + end_positions = inputs.get("end_positions", None) + if labels is not None: + input_names.append("labels") + if start_positions is not None: + input_names.append("start_positions") + if end_positions is not None: + input_names.append("end_positions") + + filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names} + input_names = list(filtered_inputs.keys()) + + model_output = model(**filtered_inputs) + + traced_model = symbolic_trace(model, input_names) + traced_output = traced_model(**filtered_inputs) + + except RuntimeError as e: + self.fail(f"Couldn't trace module: {e}") + + def flatten_output(output): + flatten = [] + for x in output: + if isinstance(x, (tuple, list)): + flatten += flatten_output(x) + elif not isinstance(x, torch.Tensor): + continue + else: + flatten.append(x) + return flatten + + model_output = flatten_output(model_output) + traced_output = flatten_output(traced_output) + num_outputs = len(model_output) + + for i in range(num_outputs): + self.assertTrue( + torch.allclose(model_output[i], traced_output[i]), + f"traced {i}th output doesn't match model {i}th output for {model_class}", + ) + + # Test that the model can be serialized and restored properly + with tempfile.TemporaryDirectory() as tmp_dir_name: + pkl_file_name = os.path.join(tmp_dir_name, "model.pkl") + try: + with open(pkl_file_name, "wb") as f: + pickle.dump(traced_model, f) + with open(pkl_file_name, "rb") as f: + loaded = pickle.load(f) + except Exception as e: + self.fail(f"Couldn't serialize / deserialize the traced model: {e}") + + loaded_output = loaded(**filtered_inputs) + loaded_output = flatten_output(loaded_output) + + for i in range(num_outputs): + self.assertTrue( + torch.allclose(model_output[i], loaded_output[i]), + f"serialized model {i}th output doesn't match model {i}th output for {model_class}", + ) diff --git a/tests/models/vision_encoder_decoder/test_modeling_vision_encoder_decoder.py b/tests/models/vision_encoder_decoder/test_modeling_vision_encoder_decoder.py index 320cdd6330626..7570888097c53 100644 --- a/tests/models/vision_encoder_decoder/test_modeling_vision_encoder_decoder.py +++ b/tests/models/vision_encoder_decoder/test_modeling_vision_encoder_decoder.py @@ -13,14 +13,22 @@ # See the License for the specific language governing permissions and # limitations under the License. - +import re import tempfile import unittest from datasets import load_dataset from packaging import version -from transformers.testing_utils import require_torch, require_vision, slow, to_2tuple, torch_device +from transformers import DonutProcessor, TrOCRProcessor +from transformers.testing_utils import ( + require_sentencepiece, + require_torch, + require_vision, + slow, + to_2tuple, + torch_device, +) from transformers.utils import cached_property, is_torch_available, is_vision_available from ...test_modeling_common import floats_tensor, ids_tensor, random_attention_mask @@ -54,7 +62,7 @@ import PIL from PIL import Image - from transformers import TrOCRProcessor, ViTFeatureExtractor + from transformers import ViTFeatureExtractor @require_torch @@ -654,8 +662,8 @@ def default_processor(self): def test_inference_handwritten(self): model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten").to(torch_device) - ds = load_dataset("hf-internal-testing/fixtures_ocr", split="test") - image = Image.open(ds[0]["file"]).convert("RGB") + dataset = load_dataset("hf-internal-testing/fixtures_ocr", split="test") + image = Image.open(dataset[0]["file"]).convert("RGB") processor = self.default_processor pixel_values = processor(images=image, return_tensors="pt").pixel_values.to(torch_device) @@ -679,8 +687,8 @@ def test_inference_handwritten(self): def test_inference_printed(self): model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-printed").to(torch_device) - ds = load_dataset("hf-internal-testing/fixtures_ocr", split="test") - image = Image.open(ds[1]["file"]).convert("RGB") + dataset = load_dataset("hf-internal-testing/fixtures_ocr", split="test") + image = Image.open(dataset[1]["file"]).convert("RGB") processor = self.default_processor pixel_values = processor(images=image, return_tensors="pt").pixel_values.to(torch_device) @@ -774,3 +782,197 @@ def generate_step(pixel_values): # should produce # ["a cat laying on top of a couch next to another cat"] self.assertEqual(preds, ["a cat laying on top of a couch next to another cat"]) + + +@require_vision +@require_torch +@require_sentencepiece +class DonutModelIntegrationTest(unittest.TestCase): + @slow + def test_inference_docvqa(self): + processor = DonutProcessor.from_pretrained("naver-clova-ix/donut-base-finetuned-docvqa") + model = VisionEncoderDecoderModel.from_pretrained("naver-clova-ix/donut-base-finetuned-docvqa").to( + torch_device + ) + + dataset = load_dataset("hf-internal-testing/example-documents", split="test") + image = dataset[0]["image"] + + pixel_values = processor(images=image, return_tensors="pt").pixel_values.to(torch_device) + decoder_input_ids = processor.tokenizer( + "", add_special_tokens=False, return_tensors="pt" + ).input_ids.to(torch_device) + + # step 1: single forward pass + with torch.no_grad(): + outputs = model(pixel_values=pixel_values, decoder_input_ids=decoder_input_ids) + logits = outputs.logits + + # verify the logits + expected_shape = torch.Size([1, 1, 57532]) + self.assertEqual(outputs.logits.shape, expected_shape) + + expected_slice = torch.tensor([24.2731, -6.4522, 32.4130]).to(torch_device) + self.assertTrue(torch.allclose(logits[0, 0, :3], expected_slice, atol=1e-4)) + + # step 2: generation + task_prompt = "{user_input}" + question = "When is the coffee break?" + prompt = task_prompt.replace("{user_input}", question) + decoder_input_ids = processor.tokenizer(prompt, add_special_tokens=False, return_tensors="pt").input_ids + decoder_input_ids = decoder_input_ids.to(torch_device) + + outputs = model.generate( + pixel_values, + decoder_input_ids=decoder_input_ids, + max_length=model.decoder.config.max_position_embeddings, + early_stopping=True, + pad_token_id=processor.tokenizer.pad_token_id, + eos_token_id=processor.tokenizer.eos_token_id, + use_cache=True, + num_beams=1, + bad_words_ids=[[processor.tokenizer.unk_token_id]], + output_scores=True, + return_dict_in_generate=True, + ) + sequence = processor.batch_decode(outputs.sequences)[0] + sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "") + sequence = re.sub(r"<.*?>", "", sequence, count=1).strip() # remove first task start token + + # verify generated sequence + self.assertEqual( + sequence, " When is the coffee break? 11-14 to 11:39 a.m." + ) + + # verify scores + self.assertEqual(len(outputs.scores), 11) + self.assertTrue( + torch.allclose( + outputs.scores[0][0, :3], torch.tensor([5.3153, -3.5276, 13.4781], device=torch_device), atol=1e-4 + ) + ) + + @slow + def test_inference_cordv2(self): + processor = DonutProcessor.from_pretrained("naver-clova-ix/donut-base-finetuned-cord-v2") + model = VisionEncoderDecoderModel.from_pretrained("naver-clova-ix/donut-base-finetuned-cord-v2").to( + torch_device + ) + + dataset = load_dataset("hf-internal-testing/example-documents", split="test") + image = dataset[2]["image"] + + pixel_values = processor(images=image, return_tensors="pt").pixel_values.to(torch_device) + decoder_input_ids = processor.tokenizer( + "", add_special_tokens=False, return_tensors="pt" + ).input_ids.to(torch_device) + + # step 1: single forward pass + with torch.no_grad(): + outputs = model(pixel_values=pixel_values, decoder_input_ids=decoder_input_ids) + logits = outputs.logits + + # verify the logits + expected_shape = torch.Size((1, 1, model.decoder.config.vocab_size)) + self.assertEqual(outputs.logits.shape, expected_shape) + + expected_slice = torch.tensor([-27.4344, -3.2686, -19.3524], device=torch_device) + self.assertTrue(torch.allclose(logits[0, 0, :3], expected_slice, atol=1e-4)) + + # step 2: generation + task_prompt = "" + decoder_input_ids = processor.tokenizer(task_prompt, add_special_tokens=False, return_tensors="pt").input_ids + decoder_input_ids = decoder_input_ids.to(torch_device) + + outputs = model.generate( + pixel_values, + decoder_input_ids=decoder_input_ids, + max_length=model.decoder.config.max_position_embeddings, + early_stopping=True, + pad_token_id=processor.tokenizer.pad_token_id, + eos_token_id=processor.tokenizer.eos_token_id, + use_cache=True, + num_beams=1, + bad_words_ids=[[processor.tokenizer.unk_token_id]], + output_scores=True, + return_dict_in_generate=True, + ) + + sequence = processor.batch_decode(outputs.sequences)[0] + sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "") + sequence = re.sub(r"<.*?>", "", sequence, count=1).strip() # remove first task start token + + # verify generated sequence + # fmt: off + expected_sequence = " CINNAMON SUGAR 17,000 1 x 17,000 17,000 17,000 20,000 3,000" # noqa: E231 + # fmt: on + self.assertEqual(sequence, expected_sequence) + + # verify scores + self.assertEqual(len(outputs.scores), 43) + self.assertTrue( + torch.allclose( + outputs.scores[0][0, :3], torch.tensor([-27.4344, -3.2686, -19.3524], device=torch_device), atol=1e-4 + ) + ) + + @slow + def test_inference_rvlcdip(self): + processor = DonutProcessor.from_pretrained("naver-clova-ix/donut-base-finetuned-rvlcdip") + model = VisionEncoderDecoderModel.from_pretrained("naver-clova-ix/donut-base-finetuned-rvlcdip").to( + torch_device + ) + + dataset = load_dataset("hf-internal-testing/example-documents", split="test") + image = dataset[1]["image"] + + pixel_values = processor(images=image, return_tensors="pt").pixel_values.to(torch_device) + + # step 1: single forward pass + decoder_input_ids = processor.tokenizer( + "", add_special_tokens=False, return_tensors="pt" + ).input_ids.to(torch_device) + with torch.no_grad(): + outputs = model(pixel_values=pixel_values, decoder_input_ids=decoder_input_ids) + logits = outputs.logits + + # verify the logits + expected_shape = torch.Size((1, 1, model.decoder.config.vocab_size)) + self.assertEqual(outputs.logits.shape, expected_shape) + + expected_slice = torch.tensor([-17.6490, -4.8381, -15.7577], device=torch_device) + self.assertTrue(torch.allclose(logits[0, 0, :3], expected_slice, atol=1e-4)) + + # step 2: generation + task_prompt = "" + decoder_input_ids = processor.tokenizer(task_prompt, add_special_tokens=False, return_tensors="pt").input_ids + decoder_input_ids = decoder_input_ids.to(torch_device) + + outputs = model.generate( + pixel_values, + decoder_input_ids=decoder_input_ids, + max_length=model.decoder.config.max_position_embeddings, + early_stopping=True, + pad_token_id=processor.tokenizer.pad_token_id, + eos_token_id=processor.tokenizer.eos_token_id, + use_cache=True, + num_beams=1, + bad_words_ids=[[processor.tokenizer.unk_token_id]], + output_scores=True, + return_dict_in_generate=True, + ) + + sequence = processor.batch_decode(outputs.sequences)[0] + sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "") + sequence = re.sub(r"<.*?>", "", sequence, count=1).strip() # remove first task start token + + # verify generated sequence + self.assertEqual(sequence, "") + + # verify scores + self.assertEqual(len(outputs.scores), 4) + self.assertTrue( + torch.allclose( + outputs.scores[0][0, :3], torch.tensor([-17.6490, -4.8381, -15.7577], device=torch_device), atol=1e-4 + ) + ) diff --git a/utils/check_copies.py b/utils/check_copies.py index e2e0e1a53e433..7d57173654468 100644 --- a/utils/check_copies.py +++ b/utils/check_copies.py @@ -471,6 +471,7 @@ def check_model_list_copy(overwrite=False, max_per_line=119): "Data2VecAudio": "Data2Vec", "Data2VecText": "Data2Vec", "Data2VecVision": "Data2Vec", + "DonutSwin": "Donut", "Marian": "MarianMT", "OpenAI GPT-2": "GPT-2", "OpenAI GPT": "GPT", diff --git a/utils/check_repo.py b/utils/check_repo.py index d2271e87ebf17..254467113d6cb 100644 --- a/utils/check_repo.py +++ b/utils/check_repo.py @@ -206,6 +206,7 @@ ("data2vec-text", "data2vec"), ("data2vec-audio", "data2vec"), ("data2vec-vision", "data2vec"), + ("donut-swin", "donut"), ] ) diff --git a/utils/documentation_tests.txt b/utils/documentation_tests.txt index 1941a7343a6bc..0edda8ae5a4c3 100644 --- a/utils/documentation_tests.txt +++ b/utils/documentation_tests.txt @@ -8,6 +8,7 @@ docs/source/en/model_doc/t5.mdx docs/source/en/model_doc/t5v1.1.mdx docs/source/en/model_doc/byt5.mdx docs/source/en/model_doc/tapex.mdx +docs/source/en/model_doc/donut.mdx docs/source/en/model_doc/encoder-decoder.mdx src/transformers/generation_utils.py src/transformers/models/albert/modeling_albert.py