diff --git a/README.md b/README.md
index 46a4b07c14cd3..30bc6d870bbf0 100644
--- a/README.md
+++ b/README.md
@@ -286,6 +286,7 @@ Current number of checkpoints: ![](https://img.shields.io/endpoint?url=https://h
1. **[DialoGPT](https://huggingface.co/docs/transformers/model_doc/dialogpt)** (from Microsoft Research) released with the paper [DialoGPT: Large-Scale Generative Pre-training for Conversational Response Generation](https://arxiv.org/abs/1911.00536) by Yizhe Zhang, Siqi Sun, Michel Galley, Yen-Chun Chen, Chris Brockett, Xiang Gao, Jianfeng Gao, Jingjing Liu, Bill Dolan.
1. **[DistilBERT](https://huggingface.co/docs/transformers/model_doc/distilbert)** (from HuggingFace), released together with the paper [DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter](https://arxiv.org/abs/1910.01108) by Victor Sanh, Lysandre Debut and Thomas Wolf. The same method has been applied to compress GPT2 into [DistilGPT2](https://github.com/huggingface/transformers/tree/main/examples/research_projects/distillation), RoBERTa into [DistilRoBERTa](https://github.com/huggingface/transformers/tree/main/examples/research_projects/distillation), Multilingual BERT into [DistilmBERT](https://github.com/huggingface/transformers/tree/main/examples/research_projects/distillation) and a German version of DistilBERT.
1. **[DiT](https://huggingface.co/docs/transformers/model_doc/dit)** (from Microsoft Research) released with the paper [DiT: Self-supervised Pre-training for Document Image Transformer](https://arxiv.org/abs/2203.02378) by Junlong Li, Yiheng Xu, Tengchao Lv, Lei Cui, Cha Zhang, Furu Wei.
+1. **[Donut](https://huggingface.co/docs/transformers/main/model_doc/donut)** (from NAVER), released together with the paper [OCR-free Document Understanding Transformer](https://arxiv.org/abs/2111.15664) by Geewook Kim, Teakgyu Hong, Moonbin Yim, Jeongyeon Nam, Jinyoung Park, Jinyeong Yim, Wonseok Hwang, Sangdoo Yun, Dongyoon Han, Seunghyun Park.
1. **[DPR](https://huggingface.co/docs/transformers/model_doc/dpr)** (from Facebook) released with the paper [Dense Passage Retrieval for Open-Domain Question Answering](https://arxiv.org/abs/2004.04906) by Vladimir Karpukhin, Barlas Oğuz, Sewon Min, Patrick Lewis, Ledell Wu, Sergey Edunov, Danqi Chen, and Wen-tau Yih.
1. **[DPT](https://huggingface.co/docs/transformers/master/model_doc/dpt)** (from Intel Labs) released with the paper [Vision Transformers for Dense Prediction](https://arxiv.org/abs/2103.13413) by René Ranftl, Alexey Bochkovskiy, Vladlen Koltun.
1. **[ELECTRA](https://huggingface.co/docs/transformers/model_doc/electra)** (from Google Research/Stanford University) released with the paper [ELECTRA: Pre-training text encoders as discriminators rather than generators](https://arxiv.org/abs/2003.10555) by Kevin Clark, Minh-Thang Luong, Quoc V. Le, Christopher D. Manning.
diff --git a/README_ko.md b/README_ko.md
index c63fdca749da8..cc0b790ad76a8 100644
--- a/README_ko.md
+++ b/README_ko.md
@@ -242,6 +242,7 @@ Flax, PyTorch, TensorFlow 설치 페이지에서 이들을 conda로 설치하는
1. **[DialoGPT](https://huggingface.co/docs/transformers/model_doc/dialogpt)** (from Microsoft Research) released with the paper [DialoGPT: Large-Scale Generative Pre-training for Conversational Response Generation](https://arxiv.org/abs/1911.00536) by Yizhe Zhang, Siqi Sun, Michel Galley, Yen-Chun Chen, Chris Brockett, Xiang Gao, Jianfeng Gao, Jingjing Liu, Bill Dolan.
1. **[DistilBERT](https://huggingface.co/docs/transformers/model_doc/distilbert)** (from HuggingFace), released together with the paper [DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter](https://arxiv.org/abs/1910.01108) by Victor Sanh, Lysandre Debut and Thomas Wolf. The same method has been applied to compress GPT2 into [DistilGPT2](https://github.com/huggingface/transformers/tree/main/examples/distillation), RoBERTa into [DistilRoBERTa](https://github.com/huggingface/transformers/tree/main/examples/distillation), Multilingual BERT into [DistilmBERT](https://github.com/huggingface/transformers/tree/main/examples/distillation) and a German version of DistilBERT.
1. **[DiT](https://huggingface.co/docs/transformers/model_doc/dit)** (from Microsoft Research) released with the paper [DiT: Self-supervised Pre-training for Document Image Transformer](https://arxiv.org/abs/2203.02378) by Junlong Li, Yiheng Xu, Tengchao Lv, Lei Cui, Cha Zhang, Furu Wei.
+1. **[Donut](https://huggingface.co/docs/transformers/main/model_doc/donut)** (from NAVER) released with the paper [OCR-free Document Understanding Transformer](https://arxiv.org/abs/2111.15664) by Geewook Kim, Teakgyu Hong, Moonbin Yim, Jeongyeon Nam, Jinyoung Park, Jinyeong Yim, Wonseok Hwang, Sangdoo Yun, Dongyoon Han, Seunghyun Park.
1. **[DPR](https://huggingface.co/docs/transformers/model_doc/dpr)** (from Facebook) released with the paper [Dense Passage Retrieval for Open-Domain Question Answering](https://arxiv.org/abs/2004.04906) by Vladimir Karpukhin, Barlas Oğuz, Sewon Min, Patrick Lewis, Ledell Wu, Sergey Edunov, Danqi Chen, and Wen-tau Yih.
1. **[DPT](https://huggingface.co/docs/transformers/master/model_doc/dpt)** (from Intel Labs) released with the paper [Vision Transformers for Dense Prediction](https://arxiv.org/abs/2103.13413) by René Ranftl, Alexey Bochkovskiy, Vladlen Koltun.
1. **[ELECTRA](https://huggingface.co/docs/transformers/model_doc/electra)** (from Google Research/Stanford University) released with the paper [ELECTRA: Pre-training text encoders as discriminators rather than generators](https://arxiv.org/abs/2003.10555) by Kevin Clark, Minh-Thang Luong, Quoc V. Le, Christopher D. Manning.
diff --git a/README_zh-hans.md b/README_zh-hans.md
index 0ab06bd96ad99..fe2fa45f71f39 100644
--- a/README_zh-hans.md
+++ b/README_zh-hans.md
@@ -266,6 +266,7 @@ conda install -c huggingface transformers
1. **[DialoGPT](https://huggingface.co/docs/transformers/model_doc/dialogpt)** (来自 Microsoft Research) 伴随论文 [DialoGPT: Large-Scale Generative Pre-training for Conversational Response Generation](https://arxiv.org/abs/1911.00536) 由 Yizhe Zhang, Siqi Sun, Michel Galley, Yen-Chun Chen, Chris Brockett, Xiang Gao, Jianfeng Gao, Jingjing Liu, Bill Dolan 发布。
1. **[DistilBERT](https://huggingface.co/docs/transformers/model_doc/distilbert)** (来自 HuggingFace), 伴随论文 [DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter](https://arxiv.org/abs/1910.01108) 由 Victor Sanh, Lysandre Debut and Thomas Wolf 发布。 同样的方法也应用于压缩 GPT-2 到 [DistilGPT2](https://github.com/huggingface/transformers/tree/main/examples/distillation), RoBERTa 到 [DistilRoBERTa](https://github.com/huggingface/transformers/tree/main/examples/distillation), Multilingual BERT 到 [DistilmBERT](https://github.com/huggingface/transformers/tree/main/examples/distillation) 和德语版 DistilBERT。
1. **[DiT](https://huggingface.co/docs/transformers/model_doc/dit)** (来自 Microsoft Research) 伴随论文 [DiT: Self-supervised Pre-training for Document Image Transformer](https://arxiv.org/abs/2203.02378) 由 Junlong Li, Yiheng Xu, Tengchao Lv, Lei Cui, Cha Zhang, Furu Wei 发布。
+1. **[Donut](https://huggingface.co/docs/transformers/main/model_doc/donut)** (来自 NAVER) 伴随论文 [OCR-free Document Understanding Transformer](https://arxiv.org/abs/2111.15664) 由 Geewook Kim, Teakgyu Hong, Moonbin Yim, Jeongyeon Nam, Jinyoung Park, Jinyeong Yim, Wonseok Hwang, Sangdoo Yun, Dongyoon Han, Seunghyun Park 发布。
1. **[DPR](https://huggingface.co/docs/transformers/model_doc/dpr)** (来自 Facebook) 伴随论文 [Dense Passage Retrieval for Open-Domain Question Answering](https://arxiv.org/abs/2004.04906) 由 Vladimir Karpukhin, Barlas Oğuz, Sewon Min, Patrick Lewis, Ledell Wu, Sergey Edunov, Danqi Chen, and Wen-tau Yih 发布。
1. **[DPT](https://huggingface.co/docs/transformers/master/model_doc/dpt)** (来自 Intel Labs) 伴随论文 [Vision Transformers for Dense Prediction](https://arxiv.org/abs/2103.13413) 由 René Ranftl, Alexey Bochkovskiy, Vladlen Koltun 发布。
1. **[ELECTRA](https://huggingface.co/docs/transformers/model_doc/electra)** (来自 Google Research/Stanford University) 伴随论文 [ELECTRA: Pre-training text encoders as discriminators rather than generators](https://arxiv.org/abs/2003.10555) 由 Kevin Clark, Minh-Thang Luong, Quoc V. Le, Christopher D. Manning 发布。
diff --git a/README_zh-hant.md b/README_zh-hant.md
index 90f29ad031b8b..4f5a995476149 100644
--- a/README_zh-hant.md
+++ b/README_zh-hant.md
@@ -278,6 +278,7 @@ conda install -c huggingface transformers
1. **[DialoGPT](https://huggingface.co/docs/transformers/model_doc/dialogpt)** (from Microsoft Research) released with the paper [DialoGPT: Large-Scale Generative Pre-training for Conversational Response Generation](https://arxiv.org/abs/1911.00536) by Yizhe Zhang, Siqi Sun, Michel Galley, Yen-Chun Chen, Chris Brockett, Xiang Gao, Jianfeng Gao, Jingjing Liu, Bill Dolan.
1. **[DistilBERT](https://huggingface.co/docs/transformers/model_doc/distilbert)** (from HuggingFace), released together with the paper [DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter](https://arxiv.org/abs/1910.01108) by Victor Sanh, Lysandre Debut and Thomas Wolf. The same method has been applied to compress GPT2 into [DistilGPT2](https://github.com/huggingface/transformers/tree/main/examples/distillation), RoBERTa into [DistilRoBERTa](https://github.com/huggingface/transformers/tree/main/examples/distillation), Multilingual BERT into [DistilmBERT](https://github.com/huggingface/transformers/tree/main/examples/distillation) and a German version of DistilBERT.
1. **[DiT](https://huggingface.co/docs/transformers/model_doc/dit)** (from Microsoft Research) released with the paper [DiT: Self-supervised Pre-training for Document Image Transformer](https://arxiv.org/abs/2203.02378) by Junlong Li, Yiheng Xu, Tengchao Lv, Lei Cui, Cha Zhang, Furu Wei.
+1. **[Donut](https://huggingface.co/docs/transformers/main/model_doc/donut)** (from NAVER) released with the paper [OCR-free Document Understanding Transformer](https://arxiv.org/abs/2111.15664) by Geewook Kim, Teakgyu Hong, Moonbin Yim, Jeongyeon Nam, Jinyoung Park, Jinyeong Yim, Wonseok Hwang, Sangdoo Yun, Dongyoon Han, Seunghyun Park.
1. **[DPR](https://huggingface.co/docs/transformers/model_doc/dpr)** (from Facebook) released with the paper [Dense Passage Retrieval for Open-Domain Question Answering](https://arxiv.org/abs/2004.04906) by Vladimir Karpukhin, Barlas Oğuz, Sewon Min, Patrick Lewis, Ledell Wu, Sergey Edunov, Danqi Chen, and Wen-tau Yih.
1. **[DPT](https://huggingface.co/docs/transformers/master/model_doc/dpt)** (from Intel Labs) released with the paper [Vision Transformers for Dense Prediction](https://arxiv.org/abs/2103.13413) by René Ranftl, Alexey Bochkovskiy, Vladlen Koltun.
1. **[ELECTRA](https://huggingface.co/docs/transformers/model_doc/electra)** (from Google Research/Stanford University) released with the paper [ELECTRA: Pre-training text encoders as discriminators rather than generators](https://arxiv.org/abs/2003.10555) by Kevin Clark, Minh-Thang Luong, Quoc V. Le, Christopher D. Manning.
diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml
index 32ab4c6361d3a..78137d2c8a74c 100644
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -427,6 +427,8 @@
title: CLIP
- local: model_doc/data2vec
title: Data2Vec
+ - local: model_doc/donut
+ title: Donut
- local: model_doc/flava
title: FLAVA
- local: model_doc/groupvit
diff --git a/docs/source/en/index.mdx b/docs/source/en/index.mdx
index 5c0d51d8b7afb..257eba8171ed1 100644
--- a/docs/source/en/index.mdx
+++ b/docs/source/en/index.mdx
@@ -84,6 +84,7 @@ The library currently contains JAX, PyTorch and TensorFlow implementations, pret
1. **[DialoGPT](model_doc/dialogpt)** (from Microsoft Research) released with the paper [DialoGPT: Large-Scale Generative Pre-training for Conversational Response Generation](https://arxiv.org/abs/1911.00536) by Yizhe Zhang, Siqi Sun, Michel Galley, Yen-Chun Chen, Chris Brockett, Xiang Gao, Jianfeng Gao, Jingjing Liu, Bill Dolan.
1. **[DistilBERT](model_doc/distilbert)** (from HuggingFace), released together with the paper [DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter](https://arxiv.org/abs/1910.01108) by Victor Sanh, Lysandre Debut and Thomas Wolf. The same method has been applied to compress GPT2 into [DistilGPT2](https://github.com/huggingface/transformers/tree/main/examples/research_projects/distillation), RoBERTa into [DistilRoBERTa](https://github.com/huggingface/transformers/tree/main/examples/research_projects/distillation), Multilingual BERT into [DistilmBERT](https://github.com/huggingface/transformers/tree/main/examples/research_projects/distillation) and a German version of DistilBERT.
1. **[DiT](model_doc/dit)** (from Microsoft Research) released with the paper [DiT: Self-supervised Pre-training for Document Image Transformer](https://arxiv.org/abs/2203.02378) by Junlong Li, Yiheng Xu, Tengchao Lv, Lei Cui, Cha Zhang, Furu Wei.
+1. **[Donut](model_doc/donut)** (from NAVER), released together with the paper [OCR-free Document Understanding Transformer](https://arxiv.org/abs/2111.15664) by Geewook Kim, Teakgyu Hong, Moonbin Yim, Jeongyeon Nam, Jinyoung Park, Jinyeong Yim, Wonseok Hwang, Sangdoo Yun, Dongyoon Han, Seunghyun Park.
1. **[DPR](model_doc/dpr)** (from Facebook) released with the paper [Dense Passage Retrieval for Open-Domain Question Answering](https://arxiv.org/abs/2004.04906) by Vladimir Karpukhin, Barlas Oğuz, Sewon Min, Patrick Lewis, Ledell Wu, Sergey Edunov, Danqi Chen, and Wen-tau Yih.
1. **[DPT](master/model_doc/dpt)** (from Intel Labs) released with the paper [Vision Transformers for Dense Prediction](https://arxiv.org/abs/2103.13413) by René Ranftl, Alexey Bochkovskiy, Vladlen Koltun.
1. **[ELECTRA](model_doc/electra)** (from Google Research/Stanford University) released with the paper [ELECTRA: Pre-training text encoders as discriminators rather than generators](https://arxiv.org/abs/2003.10555) by Kevin Clark, Minh-Thang Luong, Quoc V. Le, Christopher D. Manning.
@@ -224,6 +225,7 @@ Flax), PyTorch, and/or TensorFlow.
| DeiT | ❌ | ❌ | ✅ | ✅ | ❌ |
| DETR | ❌ | ❌ | ✅ | ❌ | ❌ |
| DistilBERT | ✅ | ✅ | ✅ | ✅ | ✅ |
+| DonutSwin | ❌ | ❌ | ✅ | ❌ | ❌ |
| DPR | ✅ | ✅ | ✅ | ✅ | ❌ |
| DPT | ❌ | ❌ | ✅ | ❌ | ❌ |
| ELECTRA | ✅ | ✅ | ✅ | ✅ | ✅ |
diff --git a/docs/source/en/model_doc/donut.mdx b/docs/source/en/model_doc/donut.mdx
new file mode 100644
index 0000000000000..9c9973be022e7
--- /dev/null
+++ b/docs/source/en/model_doc/donut.mdx
@@ -0,0 +1,214 @@
+<!--Copyright 2022 The HuggingFace Team. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
+the License. You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
+an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
+specific language governing permissions and limitations under the License.
+-->
+
+# Donut
+
+## Overview
+
+The Donut model was proposed in [OCR-free Document Understanding Transformer](https://arxiv.org/abs/2111.15664) by
+Geewook Kim, Teakgyu Hong, Moonbin Yim, Jeongyeon Nam, Jinyoung Park, Jinyeong Yim, Wonseok Hwang, Sangdoo Yun, Dongyoon Han, Seunghyun Park.
+Donut consists of an image Transformer encoder and an autoregressive text Transformer decoder to perform document understanding
+tasks such as document image classification, form understanding and visual question answering.
+
+The abstract from the paper is the following:
+
+*Understanding document images (e.g., invoices) is a core but challenging task since it requires complex functions such as reading text and a holistic understanding of the document. Current Visual Document Understanding (VDU) methods outsource the task of reading text to off-the-shelf Optical Character Recognition (OCR) engines and focus on the understanding task with the OCR outputs. Although such OCR-based approaches have shown promising performance, they suffer from 1) high computational costs for using OCR; 2) inflexibility of OCR models on languages or types of document; 3) OCR error propagation to the subsequent process. To address these issues, in this paper, we introduce a novel OCR-free VDU model named Donut, which stands for Document understanding transformer. As the first step in OCR-free VDU research, we propose a simple architecture (i.e., Transformer) with a pre-training objective (i.e., cross-entropy loss). Donut is conceptually simple yet effective. Through extensive experiments and analyses, we show a simple OCR-free VDU model, Donut, achieves state-of-the-art performances on various VDU tasks in terms of both speed and accuracy. In addition, we offer a synthetic data generator that helps the model pre-training to be flexible in various languages and domains.*
+
+
+<small> Donut high-level overview. Taken from the <a href="https://arxiv.org/abs/2111.15664">original paper</a>. </small>
+
+This model was contributed by [nielsr](https://huggingface.co/nielsr). The original code can be found
+[here](https://github.com/clovaai/donut).
+
+Tips:
+
+- The quickest way to get started with Donut is by checking the [tutorial
+ notebooks](https://github.com/NielsRogge/Transformers-Tutorials/tree/master/donut), which show how to use the model
+ at inference time as well as how to fine-tune it on custom data.
+- Donut is always used within the [VisionEncoderDecoder](vision-encoder-decoder) framework; a minimal sketch of that pairing follows below.
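+
+Under the hood, a Donut checkpoint pairs a Swin-style image encoder with an mBART-style text decoder. The sketch below composes a randomly initialized model this way; the configuration values are the library defaults, not necessarily those of the released checkpoints:
+
+```py
+from transformers import (
+    DonutSwinConfig,
+    DonutSwinModel,
+    MBartConfig,
+    MBartForCausalLM,
+    VisionEncoderDecoderModel,
+)
+
+# a randomly initialized Donut-style model: Swin encoder + mBART decoder
+encoder = DonutSwinModel(DonutSwinConfig())
+decoder = MBartForCausalLM(MBartConfig(is_decoder=True, is_encoder_decoder=False, add_cross_attention=True))
+model = VisionEncoderDecoderModel(encoder=encoder, decoder=decoder)
+
+# pretrained checkpoints load directly as a VisionEncoderDecoderModel
+model = VisionEncoderDecoderModel.from_pretrained("naver-clova-ix/donut-base")
+```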
+
+## Inference
+
+Donut's [`VisionEncoderDecoder`] model accepts images as input and makes use of
+[`~generation_utils.GenerationMixin.generate`] to autoregressively generate text given the input image.
+
+The [`DonutFeatureExtractor`] class is responsible for preprocessing the input image and
+[`XLMRobertaTokenizer`/`XLMRobertaTokenizerFast`] decodes the generated target tokens to the target string. The
+[`DonutProcessor`] wraps [`DonutFeatureExtractor`] and [`XLMRobertaTokenizer`/`XLMRobertaTokenizerFast`]
+into a single instance to both extract the input features and decode the predicted token ids.
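+
+Because [`DonutProcessor`] is only a thin wrapper, the two underlying components stay accessible on the processor instance. A short sketch:
+
+```py
+from transformers import DonutProcessor
+
+processor = DonutProcessor.from_pretrained("naver-clova-ix/donut-base-finetuned-rvlcdip")
+
+# the wrapped components can still be used on their own
+feature_extractor = processor.feature_extractor  # DonutFeatureExtractor
+tokenizer = processor.tokenizer  # XLMRobertaTokenizer or XLMRobertaTokenizerFast
+```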
+
+- Step-by-step Document Image Classification
+
+```py
+>>> import re
+
+>>> from transformers import DonutProcessor, VisionEncoderDecoderModel
+>>> from datasets import load_dataset
+>>> import torch
+
+>>> processor = DonutProcessor.from_pretrained("naver-clova-ix/donut-base-finetuned-rvlcdip")
+>>> model = VisionEncoderDecoderModel.from_pretrained("naver-clova-ix/donut-base-finetuned-rvlcdip")
+
+>>> device = "cuda" if torch.cuda.is_available() else "cpu"
+>>> model.to(device) # doctest: +IGNORE_RESULT
+
+>>> # load document image
+>>> dataset = load_dataset("hf-internal-testing/example-documents", split="test")
+>>> image = dataset[1]["image"]
+
+>>> # prepare decoder inputs
+>>> task_prompt = "<s_rvlcdip>"
+>>> decoder_input_ids = processor.tokenizer(task_prompt, add_special_tokens=False, return_tensors="pt").input_ids
+
+>>> pixel_values = processor(image, return_tensors="pt").pixel_values
+
+>>> outputs = model.generate(
+... pixel_values.to(device),
+... decoder_input_ids=decoder_input_ids.to(device),
+... max_length=model.decoder.config.max_position_embeddings,
+... early_stopping=True,
+... pad_token_id=processor.tokenizer.pad_token_id,
+... eos_token_id=processor.tokenizer.eos_token_id,
+... use_cache=True,
+... num_beams=1,
+... bad_words_ids=[[processor.tokenizer.unk_token_id]],
+... return_dict_in_generate=True,
+... )
+
+>>> sequence = processor.batch_decode(outputs.sequences)[0]
+>>> sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
+>>> sequence = re.sub(r"<.*?>", "", sequence, count=1).strip() # remove first task start token
+>>> print(processor.token2json(sequence))
+{'class': 'advertisement'}
+```
+
+- Step-by-step Document Parsing
+
+```py
+>>> import re
+
+>>> from transformers import DonutProcessor, VisionEncoderDecoderModel
+>>> from datasets import load_dataset
+>>> import torch
+
+>>> processor = DonutProcessor.from_pretrained("naver-clova-ix/donut-base-finetuned-cord-v2")
+>>> model = VisionEncoderDecoderModel.from_pretrained("naver-clova-ix/donut-base-finetuned-cord-v2")
+
+>>> device = "cuda" if torch.cuda.is_available() else "cpu"
+>>> model.to(device) # doctest: +IGNORE_RESULT
+
+>>> # load document image
+>>> dataset = load_dataset("hf-internal-testing/example-documents", split="test")
+>>> image = dataset[2]["image"]
+
+>>> # prepare decoder inputs
+>>> task_prompt = "<s_cord-v2>"
+>>> decoder_input_ids = processor.tokenizer(task_prompt, add_special_tokens=False, return_tensors="pt").input_ids
+
+>>> pixel_values = processor(image, return_tensors="pt").pixel_values
+
+>>> outputs = model.generate(
+... pixel_values.to(device),
+... decoder_input_ids=decoder_input_ids.to(device),
+... max_length=model.decoder.config.max_position_embeddings,
+... early_stopping=True,
+... pad_token_id=processor.tokenizer.pad_token_id,
+... eos_token_id=processor.tokenizer.eos_token_id,
+... use_cache=True,
+... num_beams=1,
+... bad_words_ids=[[processor.tokenizer.unk_token_id]],
+... return_dict_in_generate=True,
+... )
+
+>>> sequence = processor.batch_decode(outputs.sequences)[0]
+>>> sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
+>>> sequence = re.sub(r"<.*?>", "", sequence, count=1).strip() # remove first task start token
+>>> print(processor.token2json(sequence))
+{'menu': {'nm': 'CINNAMON SUGAR', 'unitprice': '17,000', 'cnt': '1 x', 'price': '17,000'}, 'sub_total': {'subtotal_price': '17,000'}, 'total': {'total_price': '17,000', 'cashprice': '20,000', 'changeprice': '3,000'}}
+```
+
+- Step-by-step Document Visual Question Answering (DocVQA)
+
+```py
+>>> import re
+
+>>> from transformers import DonutProcessor, VisionEncoderDecoderModel
+>>> from datasets import load_dataset
+>>> import torch
+
+>>> processor = DonutProcessor.from_pretrained("naver-clova-ix/donut-base-finetuned-docvqa")
+>>> model = VisionEncoderDecoderModel.from_pretrained("naver-clova-ix/donut-base-finetuned-docvqa")
+
+>>> device = "cuda" if torch.cuda.is_available() else "cpu"
+>>> model.to(device) # doctest: +IGNORE_RESULT
+
+>>> # load document image from the DocVQA dataset
+>>> dataset = load_dataset("hf-internal-testing/example-documents", split="test")
+>>> image = dataset[0]["image"]
+
+>>> # prepare decoder inputs
+>>> task_prompt = "<s_docvqa><s_question>{user_input}</s_question><s_answer>"
+>>> question = "When is the coffee break?"
+>>> prompt = task_prompt.replace("{user_input}", question)
+>>> decoder_input_ids = processor.tokenizer(prompt, add_special_tokens=False, return_tensors="pt").input_ids
+
+>>> pixel_values = processor(image, return_tensors="pt").pixel_values
+
+>>> outputs = model.generate(
+... pixel_values.to(device),
+... decoder_input_ids=decoder_input_ids.to(device),
+... max_length=model.decoder.config.max_position_embeddings,
+... early_stopping=True,
+... pad_token_id=processor.tokenizer.pad_token_id,
+... eos_token_id=processor.tokenizer.eos_token_id,
+... use_cache=True,
+... num_beams=1,
+... bad_words_ids=[[processor.tokenizer.unk_token_id]],
+... return_dict_in_generate=True,
+... )
+
+>>> sequence = processor.batch_decode(outputs.sequences)[0]
+>>> sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
+>>> sequence = re.sub(r"<.*?>", "", sequence, count=1).strip() # remove first task start token
+>>> print(processor.token2json(sequence))
+{'question': 'When is the coffee break?', 'answer': '11-14 to 11:39 a.m.'}
+```
+
+See the [model hub](https://huggingface.co/models?filter=donut) to look for Donut checkpoints.
+
+## Training
+
+We refer to the [tutorial notebooks](https://github.com/NielsRogge/Transformers-Tutorials/tree/master/donut).
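+
+For a rough idea of what a training step looks like, here is a minimal, hypothetical sketch. The task token `<s_mytask>` and the target sequence are illustrative placeholders; the notebooks cover the full recipe, including building target sequences from annotations:
+
+```py
+from datasets import load_dataset
+from transformers import DonutProcessor, VisionEncoderDecoderModel
+
+processor = DonutProcessor.from_pretrained("naver-clova-ix/donut-base")
+model = VisionEncoderDecoderModel.from_pretrained("naver-clova-ix/donut-base")
+
+# register a (hypothetical) task start token and resize the decoder embeddings accordingly
+processor.tokenizer.add_special_tokens({"additional_special_tokens": ["<s_mytask>"]})
+model.decoder.resize_token_embeddings(len(processor.tokenizer))
+model.config.pad_token_id = processor.tokenizer.pad_token_id
+model.config.decoder_start_token_id = processor.tokenizer.convert_tokens_to_ids("<s_mytask>")
+
+dataset = load_dataset("hf-internal-testing/example-documents", split="test")
+pixel_values = processor(dataset[0]["image"], return_tensors="pt").pixel_values
+
+# hypothetical ground truth in Donut's structured token format
+target = "<s_mytask>some ground truth" + processor.tokenizer.eos_token
+labels = processor.tokenizer(target, add_special_tokens=False, return_tensors="pt").input_ids
+
+outputs = model(pixel_values=pixel_values, labels=labels)  # labels are shifted internally
+outputs.loss.backward()
+```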
+
+## DonutSwinConfig
+
+[[autodoc]] DonutSwinConfig
+
+## DonutFeatureExtractor
+
+[[autodoc]] DonutFeatureExtractor
+ - __call__
+
+## DonutProcessor
+
+[[autodoc]] DonutProcessor
+ - __call__
+ - from_pretrained
+ - save_pretrained
+ - batch_decode
+ - decode
+
+## DonutSwinModel
+
+[[autodoc]] DonutSwinModel
+ - forward
\ No newline at end of file
diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py
index 2f53db07f078f..d6444e0844ff5 100755
--- a/src/transformers/__init__.py
+++ b/src/transformers/__init__.py
@@ -190,6 +190,7 @@
"models.dialogpt": [],
"models.distilbert": ["DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "DistilBertConfig", "DistilBertTokenizer"],
"models.dit": [],
+ "models.donut": ["DONUT_SWIN_PRETRAINED_CONFIG_ARCHIVE_MAP", "DonutProcessor", "DonutSwinConfig"],
"models.dpr": [
"DPR_PRETRAINED_CONFIG_ARCHIVE_MAP",
"DPRConfig",
@@ -641,6 +642,7 @@
_import_structure["models.convnext"].append("ConvNextFeatureExtractor")
_import_structure["models.deit"].append("DeiTFeatureExtractor")
_import_structure["models.detr"].append("DetrFeatureExtractor")
+ _import_structure["models.donut"].append("DonutFeatureExtractor")
_import_structure["models.dpt"].append("DPTFeatureExtractor")
_import_structure["models.flava"].extend(["FlavaFeatureExtractor", "FlavaProcessor"])
_import_structure["models.glpn"].append("GLPNFeatureExtractor")
@@ -1099,6 +1101,13 @@
"DistilBertPreTrainedModel",
]
)
+ _import_structure["models.donut"].extend(
+ [
+ "DONUT_SWIN_PRETRAINED_MODEL_ARCHIVE_LIST",
+ "DonutSwinModel",
+ "DonutSwinPreTrainedModel",
+ ]
+ )
_import_structure["models.dpr"].extend(
[
"DPR_CONTEXT_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST",
@@ -2984,6 +2993,7 @@
from .models.deit import DEIT_PRETRAINED_CONFIG_ARCHIVE_MAP, DeiTConfig
from .models.detr import DETR_PRETRAINED_CONFIG_ARCHIVE_MAP, DetrConfig
from .models.distilbert import DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, DistilBertConfig, DistilBertTokenizer
+ from .models.donut import DONUT_SWIN_PRETRAINED_CONFIG_ARCHIVE_MAP, DonutProcessor, DonutSwinConfig
from .models.dpr import (
DPR_PRETRAINED_CONFIG_ARCHIVE_MAP,
DPRConfig,
@@ -3375,6 +3385,7 @@
from .models.convnext import ConvNextFeatureExtractor
from .models.deit import DeiTFeatureExtractor
from .models.detr import DetrFeatureExtractor
+ from .models.donut import DonutFeatureExtractor
from .models.dpt import DPTFeatureExtractor
from .models.flava import FlavaFeatureExtractor, FlavaProcessor
from .models.glpn import GLPNFeatureExtractor
@@ -3761,6 +3772,7 @@
DistilBertModel,
DistilBertPreTrainedModel,
)
+ from .models.donut import DONUT_SWIN_PRETRAINED_MODEL_ARCHIVE_LIST, DonutSwinModel, DonutSwinPreTrainedModel
from .models.dpr import (
DPR_CONTEXT_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST,
DPR_QUESTION_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST,
diff --git a/src/transformers/image_utils.py b/src/transformers/image_utils.py
index dd7bb326993d3..e5a395341c003 100644
--- a/src/transformers/image_utils.py
+++ b/src/transformers/image_utils.py
@@ -376,3 +376,25 @@ def flip_channel_order(self, image):
image = self.to_numpy_array(image)
return image[::-1, :, :]
+
+ def rotate(self, image, angle, resample=PIL.Image.NEAREST, expand=0, center=None, translate=None, fillcolor=None):
+ """
+ Returns a copy of `image`, rotated the given number of degrees counter clockwise around its centre.
+
+ Args:
+ image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`):
+ The image to rotate. If `np.ndarray` or `torch.Tensor`, will be converted to `PIL.Image.Image` before
+ rotating.
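+ angle (`float` or `int`):
+ The number of degrees counter clockwise to rotate the image.
+ resample (`int`, *optional*, defaults to `PIL.Image.NEAREST`):
+ An optional resampling filter.
+ expand (`int`, *optional*, defaults to 0):
+ If true, expands the output to be large enough to hold the entire rotated image.
+ center (`Tuple[int, int]`, *optional*):
+ Optional center of rotation. Default is the center of the image.
+ translate (`Tuple[int, int]`, *optional*):
+ An optional post-rotate translation.
+ fillcolor (`int` or `Tuple[int, ...]`, *optional*):
+ An optional color for the area outside the rotated image.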
+
+ Returns:
+ image: A rotated `PIL.Image.Image`.
+ """
+ self._ensure_format_supported(image)
+
+ if not isinstance(image, PIL.Image.Image):
+ image = self.to_pil_image(image)
+
+ return image.rotate(
+ angle, resample=resample, expand=expand, center=center, translate=translate, fillcolor=fillcolor
+ )
diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py
index 11887db91f839..fdf315b2257d8 100644
--- a/src/transformers/models/__init__.py
+++ b/src/transformers/models/__init__.py
@@ -52,6 +52,7 @@
dialogpt,
distilbert,
dit,
+ donut,
dpr,
dpt,
electra,
diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py
index c65a2762a0002..c9e6156a3843d 100644
--- a/src/transformers/models/auto/configuration_auto.py
+++ b/src/transformers/models/auto/configuration_auto.py
@@ -56,6 +56,7 @@
("deit", "DeiTConfig"),
("detr", "DetrConfig"),
("distilbert", "DistilBertConfig"),
+ ("donut-swin", "DonutSwinConfig"),
("dpr", "DPRConfig"),
("dpt", "DPTConfig"),
("electra", "ElectraConfig"),
@@ -181,6 +182,7 @@
("deit", "DEIT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("detr", "DETR_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("distilbert", "DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
+ ("donut-swin", "DONUT_SWIN_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("dpr", "DPR_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("dpt", "DPT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("electra", "ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP"),
@@ -304,6 +306,8 @@
("dialogpt", "DialoGPT"),
("distilbert", "DistilBERT"),
("dit", "DiT"),
+ ("donut", "Donut"),
+ ("donut-swin", "DonutSwin"),
("dpr", "DPR"),
("dpt", "DPT"),
("electra", "ELECTRA"),
@@ -420,6 +424,7 @@
("data2vec-audio", "data2vec"),
("data2vec-text", "data2vec"),
("data2vec-vision", "data2vec"),
+ ("donut-swin", "donut"),
]
)
diff --git a/src/transformers/models/auto/feature_extraction_auto.py b/src/transformers/models/auto/feature_extraction_auto.py
index db581d03d8fb7..5c5f86d040c8f 100644
--- a/src/transformers/models/auto/feature_extraction_auto.py
+++ b/src/transformers/models/auto/feature_extraction_auto.py
@@ -46,6 +46,7 @@
("deit", "DeiTFeatureExtractor"),
("detr", "DetrFeatureExtractor"),
("detr", "DetrFeatureExtractor"),
+ ("donut", "DonutFeatureExtractor"),
("dpt", "DPTFeatureExtractor"),
("flava", "FlavaFeatureExtractor"),
("glpn", "GLPNFeatureExtractor"),
diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py
index bd4774c245b07..0e026cb48d0c0 100644
--- a/src/transformers/models/auto/modeling_auto.py
+++ b/src/transformers/models/auto/modeling_auto.py
@@ -56,6 +56,7 @@
("deit", "DeiTModel"),
("detr", "DetrModel"),
("distilbert", "DistilBertModel"),
+ ("donut-swin", "DonutSwinModel"),
("dpr", "DPRQuestionEncoder"),
("dpt", "DPTModel"),
("electra", "ElectraModel"),
diff --git a/src/transformers/models/auto/processing_auto.py b/src/transformers/models/auto/processing_auto.py
index aed7b4b976137..c6f4fd98316a4 100644
--- a/src/transformers/models/auto/processing_auto.py
+++ b/src/transformers/models/auto/processing_auto.py
@@ -38,6 +38,7 @@
PROCESSOR_MAPPING_NAMES = OrderedDict(
[
("clip", "CLIPProcessor"),
+ ("donut", "DonutProcessor"),
("flava", "FlavaProcessor"),
("groupvit", "CLIPProcessor"),
("layoutlmv2", "LayoutLMv2Processor"),
diff --git a/src/transformers/models/donut/__init__.py b/src/transformers/models/donut/__init__.py
new file mode 100644
index 0000000000000..a01f6b11a9a99
--- /dev/null
+++ b/src/transformers/models/donut/__init__.py
@@ -0,0 +1,76 @@
+# flake8: noqa
+# There's no way to ignore "F401 '...' imported but unused" warnings in this
+# module, but to preserve other warnings. So, don't check this module at all.
+
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available
+
+
+_import_structure = {
+ "configuration_donut_swin": ["DONUT_SWIN_PRETRAINED_CONFIG_ARCHIVE_MAP", "DonutSwinConfig"],
+ "processing_donut": ["DonutProcessor"],
+}
+
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["modeling_donut_swin"] = [
+ "DONUT_SWIN_PRETRAINED_MODEL_ARCHIVE_LIST",
+ "DonutSwinModel",
+ "DonutSwinPreTrainedModel",
+ ]
+
+try:
+ if not is_vision_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["feature_extraction_donut"] = ["DonutFeatureExtractor"]
+
+
+if TYPE_CHECKING:
+ from .configuration_donut_swin import DONUT_SWIN_PRETRAINED_CONFIG_ARCHIVE_MAP, DonutSwinConfig
+ from .processing_donut import DonutProcessor
+
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .modeling_donut_swin import (
+ DONUT_SWIN_PRETRAINED_MODEL_ARCHIVE_LIST,
+ DonutSwinModel,
+ DonutSwinPreTrainedModel,
+ )
+
+ try:
+ if not is_vision_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .feature_extraction_donut import DonutFeatureExtractor
+
+else:
+ import sys
+
+ sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
diff --git a/src/transformers/models/donut/configuration_donut_swin.py b/src/transformers/models/donut/configuration_donut_swin.py
new file mode 100644
index 0000000000000..d3316bdc79f68
--- /dev/null
+++ b/src/transformers/models/donut/configuration_donut_swin.py
@@ -0,0 +1,140 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" Donut Swin Transformer model configuration"""
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+DONUT_SWIN_PRETRAINED_CONFIG_ARCHIVE_MAP = {
+ "naver-clova-ix/donut-base": "https://huggingface.co/naver-clova-ix/donut-base/resolve/main/config.json",
+ # See all Donut models at https://huggingface.co/models?filter=donut-swin
+}
+
+
+class DonutSwinConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`DonutSwinModel`]. It is used to instantiate a
+ Donut model according to the specified arguments, defining the model architecture. Instantiating a configuration
+ with the defaults will yield a similar configuration to that of the Donut
+ [naver-clova-ix/donut-base](https://huggingface.co/naver-clova-ix/donut-base) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ image_size (`int`, *optional*, defaults to 224):
+ The size (resolution) of each image.
+ patch_size (`int`, *optional*, defaults to 4):
+ The size (resolution) of each patch.
+ num_channels (`int`, *optional*, defaults to 3):
+ The number of input channels.
+ embed_dim (`int`, *optional*, defaults to 96):
+ Dimensionality of patch embedding.
+ depths (`list(int)`, *optional*, defaults to `[2, 2, 6, 2]`):
+ Depth of each layer in the Transformer encoder.
+ num_heads (`list(int)`, *optional*, defaults to `[3, 6, 12, 24]`):
+ Number of attention heads in each layer of the Transformer encoder.
+ window_size (`int`, *optional*, defaults to 7):
+ Size of windows.
+ mlp_ratio (`float`, *optional*, defaults to 4.0):
+ Ratio of MLP hidden dimensionality to embedding dimensionality.
+ qkv_bias (`bool`, *optional*, defaults to `True`):
+ Whether or not a learnable bias should be added to the queries, keys and values.
+ hidden_dropout_prob (`float`, *optional*, defaults to 0.0):
+ The dropout probability for all fully connected layers in the embeddings and encoder.
+ attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ drop_path_rate (`float`, *optional*, defaults to 0.1):
+ Stochastic depth rate.
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
+ The non-linear activation function (function or string) in the encoder. If string, `"gelu"`, `"relu"`,
+ `"selu"` and `"gelu_new"` are supported.
+ use_absolute_embeddings (`bool`, *optional*, defaults to `False`):
+ Whether or not to add absolute position embeddings to the patch embeddings.
+ patch_norm (`bool`, *optional*, defaults to `True`):
+ Whether or not to add layer normalization after patch embedding.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ layer_norm_eps (`float`, *optional*, defaults to 1e-5):
+ The epsilon used by the layer normalization layers.
+
+ Example:
+
+ ```python
+ >>> from transformers import DonutSwinConfig, DonutSwinModel
+
+ >>> # Initializing a Donut naver-clova-ix/donut-base style configuration
+ >>> configuration = DonutSwinConfig()
+
+ >>> # Randomly initializing a model from the naver-clova-ix/donut-base style configuration
+ >>> model = DonutSwinModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+ model_type = "donut-swin"
+
+ attribute_map = {
+ "num_attention_heads": "num_heads",
+ "num_hidden_layers": "num_layers",
+ }
+
+ def __init__(
+ self,
+ image_size=224,
+ patch_size=4,
+ num_channels=3,
+ embed_dim=96,
+ depths=[2, 2, 6, 2],
+ num_heads=[3, 6, 12, 24],
+ window_size=7,
+ mlp_ratio=4.0,
+ qkv_bias=True,
+ hidden_dropout_prob=0.0,
+ attention_probs_dropout_prob=0.0,
+ drop_path_rate=0.1,
+ hidden_act="gelu",
+ use_absolute_embeddings=False,
+ patch_norm=True,
+ initializer_range=0.02,
+ layer_norm_eps=1e-5,
+ **kwargs
+ ):
+ super().__init__(**kwargs)
+
+ self.image_size = image_size
+ self.patch_size = patch_size
+ self.num_channels = num_channels
+ self.embed_dim = embed_dim
+ self.depths = depths
+ self.num_layers = len(depths)
+ self.num_heads = num_heads
+ self.window_size = window_size
+ self.mlp_ratio = mlp_ratio
+ self.qkv_bias = qkv_bias
+ self.hidden_dropout_prob = hidden_dropout_prob
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
+ self.drop_path_rate = drop_path_rate
+ self.hidden_act = hidden_act
+ self.use_absolute_embeddings = use_absolute_embeddings
+ self.patch_norm = patch_norm
+ self.layer_norm_eps = layer_norm_eps
+ self.initializer_range = initializer_range
+ # we set the hidden_size attribute in order to make Swin work with VisionEncoderDecoderModel
+ # this indicates the channel dimension after the last stage of the model
+ self.hidden_size = int(embed_dim * 2 ** (len(depths) - 1))
diff --git a/src/transformers/models/donut/convert_donut_to_pytorch.py b/src/transformers/models/donut/convert_donut_to_pytorch.py
new file mode 100644
index 0000000000000..507f10cb776cf
--- /dev/null
+++ b/src/transformers/models/donut/convert_donut_to_pytorch.py
@@ -0,0 +1,234 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Convert Donut checkpoints using the original `donut-python` library. URL: https://github.com/clovaai/donut"""
+
+import argparse
+
+import torch
+from datasets import load_dataset
+
+from donut import DonutModel
+from transformers import (
+ DonutFeatureExtractor,
+ DonutProcessor,
+ DonutSwinConfig,
+ DonutSwinModel,
+ MBartConfig,
+ MBartForCausalLM,
+ VisionEncoderDecoderModel,
+ XLMRobertaTokenizerFast,
+)
+
+
+def get_configs(model):
+ original_config = model.config
+
+ encoder_config = DonutSwinConfig(
+ image_size=original_config.input_size,
+ patch_size=4,
+ depths=original_config.encoder_layer,
+ num_heads=[4, 8, 16, 32],
+ window_size=original_config.window_size,
+ embed_dim=128,
+ )
+ decoder_config = MBartConfig(
+ is_decoder=True,
+ is_encoder_decoder=False,
+ add_cross_attention=True,
+ decoder_layers=original_config.decoder_layer,
+ max_position_embeddings=original_config.max_position_embeddings,
+ vocab_size=len(
+ model.decoder.tokenizer
+ ), # several special tokens are added to the vocab of XLMRobertaTokenizer, see repo on the hub (added_tokens.json)
+ scale_embedding=True,
+ add_final_layer_norm=True,
+ )
+
+ return encoder_config, decoder_config
+
+
+def rename_key(name):
+ if "encoder.model" in name:
+ name = name.replace("encoder.model", "encoder")
+ if "decoder.model" in name:
+ name = name.replace("decoder.model", "decoder")
+ if "patch_embed.proj" in name:
+ name = name.replace("patch_embed.proj", "embeddings.patch_embeddings.projection")
+ if "patch_embed.norm" in name:
+ name = name.replace("patch_embed.norm", "embeddings.norm")
+ if name.startswith("encoder"):
+ if "layers" in name:
+ name = "encoder." + name
+ if "attn.proj" in name:
+ name = name.replace("attn.proj", "attention.output.dense")
+ if "attn" in name and "mask" not in name:
+ name = name.replace("attn", "attention.self")
+ if "norm1" in name:
+ name = name.replace("norm1", "layernorm_before")
+ if "norm2" in name:
+ name = name.replace("norm2", "layernorm_after")
+ if "mlp.fc1" in name:
+ name = name.replace("mlp.fc1", "intermediate.dense")
+ if "mlp.fc2" in name:
+ name = name.replace("mlp.fc2", "output.dense")
+
+ if name == "encoder.norm.weight":
+ name = "encoder.layernorm.weight"
+ if name == "encoder.norm.bias":
+ name = "encoder.layernorm.bias"
+
+ return name
+
+
+def convert_state_dict(orig_state_dict, model):
+ for key in orig_state_dict.copy().keys():
+ val = orig_state_dict.pop(key)
+
+ if "qkv" in key:
+ key_split = key.split(".")
+ layer_num = int(key_split[3])
+ block_num = int(key_split[5])
+ dim = model.encoder.encoder.layers[layer_num].blocks[block_num].attention.self.all_head_size
+
+ if "weight" in key:
+ orig_state_dict[
+ f"encoder.encoder.layers.{layer_num}.blocks.{block_num}.attention.self.query.weight"
+ ] = val[:dim, :]
+ orig_state_dict[
+ f"encoder.encoder.layers.{layer_num}.blocks.{block_num}.attention.self.key.weight"
+ ] = val[dim : dim * 2, :]
+ orig_state_dict[
+ f"encoder.encoder.layers.{layer_num}.blocks.{block_num}.attention.self.value.weight"
+ ] = val[-dim:, :]
+ else:
+ orig_state_dict[
+ f"encoder.encoder.layers.{layer_num}.blocks.{block_num}.attention.self.query.bias"
+ ] = val[:dim]
+ orig_state_dict[
+ f"encoder.encoder.layers.{layer_num}.blocks.{block_num}.attention.self.key.bias"
+ ] = val[dim : dim * 2]
+ orig_state_dict[
+ f"encoder.encoder.layers.{layer_num}.blocks.{block_num}.attention.self.value.bias"
+ ] = val[-dim:]
+ elif "attn_mask" in key or key in ["encoder.model.norm.weight", "encoder.model.norm.bias"]:
+ # HuggingFace implementation doesn't use attn_mask buffer
+ # and model doesn't use final LayerNorms for the encoder
+ pass
+ else:
+ orig_state_dict[rename_key(key)] = val
+
+ return orig_state_dict
+
+
+def convert_donut_checkpoint(model_name, pytorch_dump_folder_path=None, push_to_hub=False):
+ # load original model
+ original_model = DonutModel.from_pretrained(model_name).eval()
+
+ # load HuggingFace model
+ encoder_config, decoder_config = get_configs(original_model)
+ encoder = DonutSwinModel(encoder_config)
+ decoder = MBartForCausalLM(decoder_config)
+ model = VisionEncoderDecoderModel(encoder=encoder, decoder=decoder)
+ model.eval()
+
+ state_dict = original_model.state_dict()
+ new_state_dict = convert_state_dict(state_dict, model)
+ model.load_state_dict(new_state_dict)
+
+ # verify results on scanned document
+ dataset = load_dataset("hf-internal-testing/example-documents")
+ image = dataset["test"][0]["image"].convert("RGB")
+
+ tokenizer = XLMRobertaTokenizerFast.from_pretrained(model_name, from_slow=True)
+ feature_extractor = DonutFeatureExtractor(
+ do_align_long_axis=original_model.config.align_long_axis, size=original_model.config.input_size[::-1]
+ )
+ processor = DonutProcessor(feature_extractor, tokenizer)
+ pixel_values = processor(image, return_tensors="pt").pixel_values
+
+ if model_name == "naver-clova-ix/donut-base-finetuned-docvqa":
+ task_prompt = "<s_docvqa><s_question>{user_input}</s_question><s_answer>"
+ question = "When is the coffee break?"
+ task_prompt = task_prompt.replace("{user_input}", question)
+ elif model_name == "naver-clova-ix/donut-base-finetuned-rvlcdip":
+ task_prompt = "<s_rvlcdip>"
+ elif model_name in [
+ "naver-clova-ix/donut-base-finetuned-cord-v1",
+ "naver-clova-ix/donut-base-finetuned-cord-v1-2560",
+ ]:
+ task_prompt = "<s_cord-v1>"
+ elif model_name == "naver-clova-ix/donut-base-finetuned-cord-v2":
+ task_prompt = "s_cord-v2>"
+ elif model_name == "naver-clova-ix/donut-base-finetuned-zhtrainticket":
+ task_prompt = "<s_zhtrainticket>"
+ elif model_name in ["naver-clova-ix/donut-proto", "naver-clova-ix/donut-base"]:
+ # use a random prompt
+ task_prompt = "hello world"
+ else:
+ raise ValueError("Model name not supported")
+ prompt_tensors = original_model.decoder.tokenizer(task_prompt, add_special_tokens=False, return_tensors="pt")[
+ "input_ids"
+ ]
+
+ original_patch_embed = original_model.encoder.model.patch_embed(pixel_values)
+ patch_embeddings, _ = model.encoder.embeddings(pixel_values)
+ assert torch.allclose(original_patch_embed, patch_embeddings, atol=1e-3)
+
+ # verify encoder hidden states
+ original_last_hidden_state = original_model.encoder(pixel_values)
+ last_hidden_state = model.encoder(pixel_values).last_hidden_state
+ assert torch.allclose(original_last_hidden_state, last_hidden_state, atol=1e-2)
+
+ # verify decoder hidden states
+ original_logits = original_model(pixel_values, prompt_tensors, None).logits
+ logits = model(pixel_values, decoder_input_ids=prompt_tensors).logits
+ assert torch.allclose(original_logits, logits, atol=1e-3)
+ print("Looks ok!")
+
+ if pytorch_dump_folder_path is not None:
+ print(f"Saving model and processor to {pytorch_dump_folder_path}")
+ model.save_pretrained(pytorch_dump_folder_path)
+ processor.save_pretrained(pytorch_dump_folder_path)
+
+ if push_to_hub:
+ model.push_to_hub("nielsr/" + model_name.split("/")[-1], commit_message="Update model")
+ processor.push_to_hub("nielsr/" + model_name.split("/")[-1], commit_message="Update model")
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ # Required parameters
+ parser.add_argument(
+ "--model_name",
+ default="naver-clova-ix/donut-base-finetuned-docvqa",
+ required=False,
+ type=str,
+ help="Name of the original model you'd like to convert.",
+ )
+ parser.add_argument(
+ "--pytorch_dump_folder_path",
+ default=None,
+ required=False,
+ type=str,
+ help="Path to the output PyTorch model directory.",
+ )
+ parser.add_argument(
+ "--push_to_hub",
+ action="store_true",
+ help="Whether or not to push the converted model and processor to the 🤗 hub.",
+ )
+
+ args = parser.parse_args()
+ convert_donut_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.push_to_hub)
diff --git a/src/transformers/models/donut/feature_extraction_donut.py b/src/transformers/models/donut/feature_extraction_donut.py
new file mode 100644
index 0000000000000..09bf3a6ad1c15
--- /dev/null
+++ b/src/transformers/models/donut/feature_extraction_donut.py
@@ -0,0 +1,208 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Feature extractor class for Donut."""
+
+from typing import Optional, Tuple, Union
+
+import numpy as np
+from PIL import Image, ImageOps
+
+from ...feature_extraction_utils import BatchFeature, FeatureExtractionMixin
+from ...image_utils import (
+ IMAGENET_STANDARD_MEAN,
+ IMAGENET_STANDARD_STD,
+ ImageFeatureExtractionMixin,
+ ImageInput,
+ is_torch_tensor,
+)
+from ...utils import TensorType, logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class DonutFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
+ r"""
+ Constructs a Donut feature extractor.
+
+ This feature extractor inherits from [`FeatureExtractionMixin`] which contains most of the main methods. Users
+ should refer to this superclass for more information regarding those methods.
+
+ Args:
+ do_resize (`bool`, *optional*, defaults to `True`):
+ Whether to resize the shorter edge of the input to the minimum value of a certain `size`.
+ size (`Tuple[int]`, *optional*, defaults to `[1920, 2560]`):
+ Resize the shorter edge of the input to the minimum value of the given size. Should be a tuple of (width,
+ height). Only has an effect if `do_resize` is set to `True`.
+ resample (`int`, *optional*, defaults to `PIL.Image.BILINEAR`):
+ An optional resampling filter. This can be one of `PIL.Image.NEAREST`, `PIL.Image.BOX`,
+ `PIL.Image.BILINEAR`, `PIL.Image.HAMMING`, `PIL.Image.BICUBIC` or `PIL.Image.LANCZOS`. Only has an effect
+ if `do_resize` is set to `True`.
+ do_thumbnail (`bool`, *optional*, defaults to `True`):
+ Whether to thumbnail the input to the given `size`.
+ do_align_long_axis (`bool`, *optional*, defaults to `False`):
+ Whether to rotate the input if the height is greater than width.
+ do_pad (`bool`, *optional*, defaults to `True`):
+ Whether or not to pad the input to `size`.
+ do_normalize (`bool`, *optional*, defaults to `True`):
+ Whether or not to normalize the input with mean and standard deviation.
+ image_mean (`List[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`):
+ The sequence of means for each channel, to be used when normalizing images.
+ image_std (`List[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`):
+ The sequence of standard deviations for each channel, to be used when normalizing images.
+
+ """
+
+ model_input_names = ["pixel_values"]
+
+ def __init__(
+ self,
+ do_resize=True,
+ size=[1920, 2560],
+ resample=Image.BILINEAR,
+ do_thumbnail=True,
+ do_align_long_axis=False,
+ do_pad=True,
+ do_normalize=True,
+ image_mean=None,
+ image_std=None,
+ **kwargs
+ ):
+ super().__init__(**kwargs)
+ self.do_resize = do_resize
+ self.size = size
+ self.resample = resample
+ self.do_thumbnail = do_thumbnail
+ self.do_align_long_axis = do_align_long_axis
+ self.do_pad = do_pad
+ self.do_normalize = do_normalize
+ self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
+ self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
+
+ def rotate_image(self, image, size):
+ if not isinstance(image, Image.Image):
+ image = self.to_pil_image(image)
+
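+ # rotate by -90 degrees when the long axis of the image does not match the long axis of the target size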
+ if (size[1] > size[0] and image.width > image.height) or (size[1] < size[0] and image.width < image.height):
+ image = self.rotate(image, angle=-90, expand=True)
+
+ return image
+
+ def thumbnail(self, image, size):
+ if not isinstance(image, Image.Image):
+ image = self.to_pil_image(image)
+
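+ # PIL's in-place thumbnail only ever downscales, preserving the aspect ratio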
+ image.thumbnail((size[0], size[1]))
+
+ return image
+
+ def pad(self, image: Image.Image, size: Tuple[int, int], random_padding: bool = False) -> Image.Image:
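+ # pad up to the target size; random offsets act as augmentation during training, otherwise center the image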
+ delta_width = size[0] - image.width
+ delta_height = size[1] - image.height
+
+ if random_padding:
+ pad_width = np.random.randint(low=0, high=delta_width + 1)
+ pad_height = np.random.randint(low=0, high=delta_height + 1)
+ else:
+ pad_width = delta_width // 2
+ pad_height = delta_height // 2
+
+ padding = (pad_width, pad_height, delta_width - pad_width, delta_height - pad_height)
+ return ImageOps.expand(image, padding)
+
+ def __call__(
+ self,
+ images: ImageInput,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ random_padding=False,
+ **kwargs
+ ) -> BatchFeature:
+ """
+ Main method to prepare for the model one or several image(s).
+
+ <Tip warning={true}>
+
+ NumPy arrays and PyTorch tensors are converted to PIL images when resizing, so the most efficient is to pass
+ PIL images.
+
+ </Tip>
+
+ Args:
+ images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
+ The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
+ tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
+ number of channels, H and W are image height and width.
+
+ random_padding (`bool`, *optional*, defaults to `False`):
+ Whether to randomly pad the input to `size`.
+
+ return_tensors (`str` or [`~utils.TensorType`], *optional*, defaults to `'np'`):
+ If set, will return tensors of a particular framework. Acceptable values are:
+
+ - `'tf'`: Return TensorFlow `tf.constant` objects.
+ - `'pt'`: Return PyTorch `torch.Tensor` objects.
+ - `'np'`: Return NumPy `np.ndarray` objects.
+ - `'jax'`: Return JAX `jnp.ndarray` objects.
+
+ Returns:
+ [`BatchFeature`]: A [`BatchFeature`] with the following fields:
+
+ - **pixel_values** -- Pixel values to be fed to a model, of shape (batch_size, num_channels, height,
+ width).
+ """
+ # Input type checking for clearer error
+ valid_images = False
+
+ # Check that images has a valid type
+ if isinstance(images, (Image.Image, np.ndarray)) or is_torch_tensor(images):
+ valid_images = True
+ elif isinstance(images, (list, tuple)):
+ if len(images) == 0 or isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]):
+ valid_images = True
+
+ if not valid_images:
+ raise ValueError(
+ "Images must of type `PIL.Image.Image`, `np.ndarray` or `torch.Tensor` (single example), "
+ "`List[PIL.Image.Image]`, `List[np.ndarray]` or `List[torch.Tensor]` (batch of examples)."
+ )
+
+ is_batched = bool(
+ isinstance(images, (list, tuple))
+ and (isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]))
+ )
+
+ if not is_batched:
+ images = [images]
+
+ # transformations (rotating + resizing + thumbnailing + padding + normalization)
+ if self.do_align_long_axis:
+ images = [self.rotate_image(image, self.size) for image in images]
+ if self.do_resize and self.size is not None:
+ images = [
+ self.resize(image=image, size=min(self.size), resample=self.resample, default_to_square=False)
+ for image in images
+ ]
+ if self.do_thumbnail and self.size is not None:
+ images = [self.thumbnail(image=image, size=self.size) for image in images]
+ if self.do_pad and self.size is not None:
+ images = [self.pad(image=image, size=self.size, random_padding=random_padding) for image in images]
+ if self.do_normalize:
+ images = [self.normalize(image=image, mean=self.image_mean, std=self.image_std) for image in images]
+
+ # return as BatchFeature
+ data = {"pixel_values": images}
+ encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)
+
+ return encoded_inputs
diff --git a/src/transformers/models/donut/modeling_donut_swin.py b/src/transformers/models/donut/modeling_donut_swin.py
new file mode 100644
index 0000000000000..78e5cc81c1988
--- /dev/null
+++ b/src/transformers/models/donut/modeling_donut_swin.py
@@ -0,0 +1,941 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" PyTorch Donut Swin Transformer model.
+
+This implementation is identical to a regular Swin Transformer, except that it omits the final layer norm on top of
+the final hidden states."""
+
+import collections.abc
+import math
+from dataclasses import dataclass
+from typing import Optional, Tuple, Union
+
+import torch
+import torch.utils.checkpoint
+from torch import nn
+
+from ...activations import ACT2FN
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
+from ...utils import (
+ ModelOutput,
+ add_code_sample_docstrings,
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ logging,
+)
+from .configuration_donut_swin import DonutSwinConfig
+
+
+logger = logging.get_logger(__name__)
+
+# General docstring
+_CONFIG_FOR_DOC = "DonutSwinConfig"
+_FEAT_EXTRACTOR_FOR_DOC = "AutoFeatureExtractor"
+
+# Base docstring
+_CHECKPOINT_FOR_DOC = "https://huggingface.co/naver-clova-ix/donut-base"
+_EXPECTED_OUTPUT_SHAPE = [1, 49, 768]
+
+DONUT_SWIN_PRETRAINED_MODEL_ARCHIVE_LIST = [
+ "naver-clova-ix/donut-base",
+ # See all Donut Swin models at https://huggingface.co/models?filter=donut
+]
+
+
+@dataclass
+# Copied from transformers.models.swin.modeling_swin.SwinEncoderOutput with Swin->DonutSwin
+class DonutSwinEncoderOutput(ModelOutput):
+ """
+ DonutSwin encoder's outputs, with potential hidden states and attentions.
+
+ Args:
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ Sequence of hidden-states at the output of the last layer of the model.
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
+ shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
+ shape `(batch_size, hidden_size, height, width)`.
+
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
+ include the spatial dimensions.
+ """
+
+ last_hidden_state: torch.FloatTensor = None
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
+ reshaped_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+
+
+@dataclass
+# Copied from transformers.models.swin.modeling_swin.SwinModelOutput with Swin->DonutSwin
+class DonutSwinModelOutput(ModelOutput):
+ """
+ DonutSwin model's outputs that also contains a pooling of the last hidden states.
+
+ Args:
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ Sequence of hidden-states at the output of the last layer of the model.
+ pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`, *optional*, returned when `add_pooling_layer=True` is passed):
+ Average pooling of the last layer hidden-state.
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
+ shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
+ shape `(batch_size, hidden_size, height, width)`.
+
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
+ include the spatial dimensions.
+ """
+
+ last_hidden_state: torch.FloatTensor = None
+ pooler_output: Optional[torch.FloatTensor] = None
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
+ reshaped_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+
+
+# Copied from transformers.models.swin.modeling_swin.window_partition
+def window_partition(input_feature, window_size):
+ """
+ Partitions the given input into windows.
+ """
+ batch_size, height, width, num_channels = input_feature.shape
+ input_feature = input_feature.view(
+ batch_size, height // window_size, window_size, width // window_size, window_size, num_channels
+ )
+ windows = input_feature.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, num_channels)
+ return windows
+
+
+# Copied from transformers.models.swin.modeling_swin.window_reverse
+def window_reverse(windows, window_size, height, width):
+ """
+ Merges windows to produce higher resolution features.
+ """
+ batch_size = math.floor(windows.shape[0] / (height * width / window_size / window_size))
+ windows = windows.view(batch_size, height // window_size, width // window_size, window_size, window_size, -1)
+ windows = windows.permute(0, 1, 3, 2, 4, 5).contiguous().view(batch_size, height, width, -1)
+ return windows
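+
+# Shape sketch (illustrative): for an input of shape (1, 8, 8, C) and window_size=4,
+# window_partition returns (4, 4, 4, C) -- four non-overlapping 4x4 windows -- and
+# window_reverse(windows, 4, 8, 8) restores the original (1, 8, 8, C) layout.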
+
+
+# Copied from transformers.models.swin.modeling_swin.SwinEmbeddings with Swin->DonutSwin
+class DonutSwinEmbeddings(nn.Module):
+ """
+ Construct the patch and position embeddings. Optionally, also the mask token.
+ """
+
+ def __init__(self, config, use_mask_token=False):
+ super().__init__()
+
+ self.patch_embeddings = DonutSwinPatchEmbeddings(config)
+ num_patches = self.patch_embeddings.num_patches
+ self.patch_grid = self.patch_embeddings.grid_size
+ self.mask_token = nn.Parameter(torch.zeros(1, 1, config.embed_dim)) if use_mask_token else None
+
+ if config.use_absolute_embeddings:
+ self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.embed_dim))
+ else:
+ self.position_embeddings = None
+
+ self.norm = nn.LayerNorm(config.embed_dim)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ def forward(
+ self, pixel_values: Optional[torch.FloatTensor], bool_masked_pos: Optional[torch.BoolTensor] = None
+ ) -> Tuple[torch.Tensor]:
+ embeddings, output_dimensions = self.patch_embeddings(pixel_values)
+ embeddings = self.norm(embeddings)
+ batch_size, seq_len, _ = embeddings.size()
+
+ if bool_masked_pos is not None:
+ mask_tokens = self.mask_token.expand(batch_size, seq_len, -1)
+ # replace the masked visual tokens by mask_tokens
+ mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
+ embeddings = embeddings * (1.0 - mask) + mask_tokens * mask
+
+ if self.position_embeddings is not None:
+ embeddings = embeddings + self.position_embeddings
+
+ embeddings = self.dropout(embeddings)
+
+ return embeddings, output_dimensions
+
+
+# Copied from transformers.models.swin.modeling_swin.SwinPatchEmbeddings
+class DonutSwinPatchEmbeddings(nn.Module):
+ """
+ This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
+ `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
+ Transformer.
+ """
+
+ def __init__(self, config):
+ super().__init__()
+ image_size, patch_size = config.image_size, config.patch_size
+ num_channels, hidden_size = config.num_channels, config.embed_dim
+ image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
+ patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
+ num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
+ self.image_size = image_size
+ self.patch_size = patch_size
+ self.num_channels = num_channels
+ self.num_patches = num_patches
+ self.grid_size = (image_size[0] // patch_size[0], image_size[1] // patch_size[1])
+
+ self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
+
+ def maybe_pad(self, pixel_values, height, width):
+ if width % self.patch_size[1] != 0:
+ pad_values = (0, self.patch_size[1] - width % self.patch_size[1])
+ pixel_values = nn.functional.pad(pixel_values, pad_values)
+ if height % self.patch_size[0] != 0:
+ pad_values = (0, 0, 0, self.patch_size[0] - height % self.patch_size[0])
+ pixel_values = nn.functional.pad(pixel_values, pad_values)
+ return pixel_values
+
+ def forward(self, pixel_values: Optional[torch.FloatTensor]) -> Tuple[torch.Tensor, Tuple[int]]:
+ _, num_channels, height, width = pixel_values.shape
+ if num_channels != self.num_channels:
+ raise ValueError(
+ "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
+ )
+ # pad the input to be divisible by self.patch_size, if needed
+ pixel_values = self.maybe_pad(pixel_values, height, width)
+ embeddings = self.projection(pixel_values)
+ _, _, height, width = embeddings.shape
+ output_dimensions = (height, width)
+ embeddings = embeddings.flatten(2).transpose(1, 2)
+
+ return embeddings, output_dimensions
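+
+    # Shape sketch (illustrative, assuming embed_dim=96 and patch_size=4): pixel
+    # values of shape (B, 3, 224, 224) are projected to (B, 96, 56, 56), then
+    # flattened and transposed to embeddings of shape (B, 3136, 96), with
+    # output_dimensions == (56, 56).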
+
+
+# Copied from transformers.models.swin.modeling_swin.SwinPatchMerging
+class DonutSwinPatchMerging(nn.Module):
+ """
+ Patch Merging Layer.
+
+ Args:
+ input_resolution (`Tuple[int]`):
+ Resolution of input feature.
+ dim (`int`):
+ Number of input channels.
+ norm_layer (`nn.Module`, *optional*, defaults to `nn.LayerNorm`):
+ Normalization layer class.
+ """
+
+ def __init__(self, input_resolution: Tuple[int], dim: int, norm_layer: nn.Module = nn.LayerNorm) -> None:
+ super().__init__()
+ self.input_resolution = input_resolution
+ self.dim = dim
+ self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
+ self.norm = norm_layer(4 * dim)
+
+ def maybe_pad(self, input_feature, height, width):
+ should_pad = (height % 2 == 1) or (width % 2 == 1)
+ if should_pad:
+ pad_values = (0, 0, 0, width % 2, 0, height % 2)
+ input_feature = nn.functional.pad(input_feature, pad_values)
+
+ return input_feature
+
+ def forward(self, input_feature: torch.Tensor, input_dimensions: Tuple[int, int]) -> torch.Tensor:
+ height, width = input_dimensions
+ # `dim` is height * width
+ batch_size, dim, num_channels = input_feature.shape
+
+ input_feature = input_feature.view(batch_size, height, width, num_channels)
+        # pad input so that its height and width are divisible by 2, if needed
+ input_feature = self.maybe_pad(input_feature, height, width)
+ # [batch_size, height/2, width/2, num_channels]
+ input_feature_0 = input_feature[:, 0::2, 0::2, :]
+ # [batch_size, height/2, width/2, num_channels]
+ input_feature_1 = input_feature[:, 1::2, 0::2, :]
+ # [batch_size, height/2, width/2, num_channels]
+ input_feature_2 = input_feature[:, 0::2, 1::2, :]
+ # [batch_size, height/2, width/2, num_channels]
+ input_feature_3 = input_feature[:, 1::2, 1::2, :]
+ # batch_size height/2 width/2 4*num_channels
+ input_feature = torch.cat([input_feature_0, input_feature_1, input_feature_2, input_feature_3], -1)
+ input_feature = input_feature.view(batch_size, -1, 4 * num_channels) # batch_size height/2*width/2 4*C
+
+ input_feature = self.norm(input_feature)
+ input_feature = self.reduction(input_feature)
+
+ return input_feature
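+
+    # Shape sketch (illustrative): patch merging halves each spatial dimension and
+    # doubles the channel count, e.g. an input of shape (B, 56*56, 96) with
+    # input_dimensions (56, 56) becomes (B, 28*28, 192) after norm + reduction.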
+
+
+# Copied from transformers.models.swin.modeling_swin.drop_path
+def drop_path(input, drop_prob=0.0, training=False, scale_by_keep=True):
+ """
+ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+
+ Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
+ however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
+ See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
+ layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
+ argument.
+ """
+ if drop_prob == 0.0 or not training:
+ return input
+ keep_prob = 1 - drop_prob
+ shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
+ random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
+ random_tensor.floor_() # binarize
+ output = input.div(keep_prob) * random_tensor
+ return output
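+
+# Numeric sketch (illustrative): with drop_prob=0.1 during training, keep_prob is
+# 0.9, so each sample's residual branch is zeroed with probability 0.1 and the
+# surviving samples are rescaled by 1/0.9 to keep the expected value unchanged.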
+
+
+# Copied from transformers.models.swin.modeling_swin.SwinDropPath
+class DonutSwinDropPath(nn.Module):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
+
+ def __init__(self, drop_prob: Optional[float] = None) -> None:
+ super().__init__()
+ self.drop_prob = drop_prob
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return drop_path(x, self.drop_prob, self.training)
+
+ def extra_repr(self) -> str:
+ return "p={}".format(self.drop_prob)
+
+
+# Copied from transformers.models.swin.modeling_swin.SwinSelfAttention with Swin->DonutSwin
+class DonutSwinSelfAttention(nn.Module):
+ def __init__(self, config, dim, num_heads):
+ super().__init__()
+ if dim % num_heads != 0:
+ raise ValueError(
+ f"The hidden size ({dim}) is not a multiple of the number of attention heads ({num_heads})"
+ )
+
+ self.num_attention_heads = num_heads
+ self.attention_head_size = int(dim / num_heads)
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
+ window_size = config.window_size
+ self.window_size = (
+ window_size if isinstance(window_size, collections.abc.Iterable) else (window_size, window_size)
+ )
+
+ self.relative_position_bias_table = nn.Parameter(
+ torch.zeros((2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1), num_heads)
+ )
+
+ # get pair-wise relative position index for each token inside the window
+ coords_h = torch.arange(self.window_size[0])
+ coords_w = torch.arange(self.window_size[1])
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w]))
+ coords_flatten = torch.flatten(coords, 1)
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous()
+ relative_coords[:, :, 0] += self.window_size[0] - 1
+ relative_coords[:, :, 1] += self.window_size[1] - 1
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
+ relative_position_index = relative_coords.sum(-1)
+ self.register_buffer("relative_position_index", relative_position_index)
+
+ self.query = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias)
+ self.key = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias)
+ self.value = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias)
+
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
+
+ def transpose_for_scores(self, x):
+ new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
+ x = x.view(new_x_shape)
+ return x.permute(0, 2, 1, 3)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = False,
+ ) -> Tuple[torch.Tensor]:
+ batch_size, dim, num_channels = hidden_states.shape
+ mixed_query_layer = self.query(hidden_states)
+
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
+ query_layer = self.transpose_for_scores(mixed_query_layer)
+
+ # Take the dot product between "query" and "key" to get the raw attention scores.
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
+
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
+
+ relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)]
+ relative_position_bias = relative_position_bias.view(
+ self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1
+ )
+
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
+ attention_scores = attention_scores + relative_position_bias.unsqueeze(0)
+
+ if attention_mask is not None:
+            # Apply the attention mask (precomputed for all layers in DonutSwinModel forward() function)
+ mask_shape = attention_mask.shape[0]
+ attention_scores = attention_scores.view(
+ batch_size // mask_shape, mask_shape, self.num_attention_heads, dim, dim
+ )
+ attention_scores = attention_scores + attention_mask.unsqueeze(1).unsqueeze(0)
+ attention_scores = attention_scores.view(-1, self.num_attention_heads, dim, dim)
+
+ # Normalize the attention scores to probabilities.
+ attention_probs = nn.functional.softmax(attention_scores, dim=-1)
+
+ # This is actually dropping out entire tokens to attend to, which might
+ # seem a bit unusual, but is taken from the original Transformer paper.
+ attention_probs = self.dropout(attention_probs)
+
+ # Mask heads if we want to
+ if head_mask is not None:
+ attention_probs = attention_probs * head_mask
+
+ context_layer = torch.matmul(attention_probs, value_layer)
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+ context_layer = context_layer.view(new_context_layer_shape)
+
+ outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
+
+ return outputs
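+
+    # Size sketch (illustrative): for window_size=(7, 7), the relative position
+    # bias table holds (2 * 7 - 1) ** 2 = 169 entries per head, indexed via the
+    # precomputed relative_position_index for each pair of the 49 window tokens.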
+
+
+# Copied from transformers.models.swin.modeling_swin.SwinSelfOutput
+class DonutSwinSelfOutput(nn.Module):
+ def __init__(self, config, dim):
+ super().__init__()
+ self.dense = nn.Linear(dim, dim)
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
+
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+
+ return hidden_states
+
+
+# Copied from transformers.models.swin.modeling_swin.SwinAttention with Swin->DonutSwin
+class DonutSwinAttention(nn.Module):
+ def __init__(self, config, dim, num_heads):
+ super().__init__()
+ self.self = DonutSwinSelfAttention(config, dim, num_heads)
+ self.output = DonutSwinSelfOutput(config, dim)
+ self.pruned_heads = set()
+
+ def prune_heads(self, heads):
+ if len(heads) == 0:
+ return
+ heads, index = find_pruneable_heads_and_indices(
+ heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
+ )
+
+ # Prune linear layers
+ self.self.query = prune_linear_layer(self.self.query, index)
+ self.self.key = prune_linear_layer(self.self.key, index)
+ self.self.value = prune_linear_layer(self.self.value, index)
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
+
+ # Update hyper params and store pruned heads
+ self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
+ self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
+ self.pruned_heads = self.pruned_heads.union(heads)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = False,
+ ) -> Tuple[torch.Tensor]:
+ self_outputs = self.self(hidden_states, attention_mask, head_mask, output_attentions)
+ attention_output = self.output(self_outputs[0], hidden_states)
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
+ return outputs
+
+
+# Copied from transformers.models.swin.modeling_swin.SwinIntermediate
+class DonutSwinIntermediate(nn.Module):
+ def __init__(self, config, dim):
+ super().__init__()
+ self.dense = nn.Linear(dim, int(config.mlp_ratio * dim))
+ if isinstance(config.hidden_act, str):
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
+ else:
+ self.intermediate_act_fn = config.hidden_act
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.intermediate_act_fn(hidden_states)
+ return hidden_states
+
+
+# Copied from transformers.models.swin.modeling_swin.SwinOutput
+class DonutSwinOutput(nn.Module):
+ def __init__(self, config, dim):
+ super().__init__()
+ self.dense = nn.Linear(int(config.mlp_ratio * dim), dim)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ return hidden_states
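+
+    # Note (illustrative): together with DonutSwinIntermediate, this forms the
+    # standard transformer MLP block, expanding dim -> mlp_ratio * dim -> dim.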
+
+
+# Copied from transformers.models.swin.modeling_swin.SwinLayer with Swin->DonutSwin
+class DonutSwinLayer(nn.Module):
+ def __init__(self, config, dim, input_resolution, num_heads, shift_size=0):
+ super().__init__()
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
+ self.shift_size = shift_size
+ self.window_size = config.window_size
+ self.input_resolution = input_resolution
+ self.set_shift_and_window_size(input_resolution)
+ self.layernorm_before = nn.LayerNorm(dim, eps=config.layer_norm_eps)
+ self.attention = DonutSwinAttention(config, dim, num_heads)
+ self.drop_path = DonutSwinDropPath(config.drop_path_rate) if config.drop_path_rate > 0.0 else nn.Identity()
+ self.layernorm_after = nn.LayerNorm(dim, eps=config.layer_norm_eps)
+ self.intermediate = DonutSwinIntermediate(config, dim)
+ self.output = DonutSwinOutput(config, dim)
+
+ def set_shift_and_window_size(self, input_resolution):
+ if min(input_resolution) <= self.window_size:
+ # if window size is larger than input resolution, we don't partition windows
+ self.shift_size = 0
+ self.window_size = min(input_resolution)
+
+ def get_attn_mask(self, height, width):
+ if self.shift_size > 0:
+ # calculate attention mask for SW-MSA
+ img_mask = torch.zeros((1, height, width, 1))
+ height_slices = (
+ slice(0, -self.window_size),
+ slice(-self.window_size, -self.shift_size),
+ slice(-self.shift_size, None),
+ )
+ width_slices = (
+ slice(0, -self.window_size),
+ slice(-self.window_size, -self.shift_size),
+ slice(-self.shift_size, None),
+ )
+ count = 0
+ for height_slice in height_slices:
+ for width_slice in width_slices:
+ img_mask[:, height_slice, width_slice, :] = count
+ count += 1
+
+ mask_windows = window_partition(img_mask, self.window_size)
+ mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
+ else:
+ attn_mask = None
+ return attn_mask
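+
+    # Note (illustrative): the -100.0 entries act as a large negative bias added
+    # to the attention scores before softmax, effectively blocking attention
+    # between tokens that belong to different windows after the cyclic shift.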
+
+ def maybe_pad(self, hidden_states, height, width):
+ pad_right = (self.window_size - width % self.window_size) % self.window_size
+ pad_bottom = (self.window_size - height % self.window_size) % self.window_size
+ pad_values = (0, 0, 0, pad_right, 0, pad_bottom)
+ hidden_states = nn.functional.pad(hidden_states, pad_values)
+ return hidden_states, pad_values
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ input_dimensions: Tuple[int, int],
+ head_mask: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = False,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ self.set_shift_and_window_size(input_dimensions)
+ height, width = input_dimensions
+ batch_size, _, channels = hidden_states.size()
+ shortcut = hidden_states
+
+ hidden_states = self.layernorm_before(hidden_states)
+ hidden_states = hidden_states.view(batch_size, height, width, channels)
+ # pad hidden_states to multiples of window size
+ hidden_states, pad_values = self.maybe_pad(hidden_states, height, width)
+
+ _, height_pad, width_pad, _ = hidden_states.shape
+ # cyclic shift
+ if self.shift_size > 0:
+ shifted_hidden_states = torch.roll(hidden_states, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
+ else:
+ shifted_hidden_states = hidden_states
+
+ # partition windows
+ hidden_states_windows = window_partition(shifted_hidden_states, self.window_size)
+ hidden_states_windows = hidden_states_windows.view(-1, self.window_size * self.window_size, channels)
+ attn_mask = self.get_attn_mask(height_pad, width_pad)
+ if attn_mask is not None:
+ attn_mask = attn_mask.to(hidden_states_windows.device)
+
+ attention_outputs = self.attention(
+ hidden_states_windows, attn_mask, head_mask, output_attentions=output_attentions
+ )
+
+ attention_output = attention_outputs[0]
+
+ attention_windows = attention_output.view(-1, self.window_size, self.window_size, channels)
+ shifted_windows = window_reverse(attention_windows, self.window_size, height_pad, width_pad)
+
+ # reverse cyclic shift
+ if self.shift_size > 0:
+ attention_windows = torch.roll(shifted_windows, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
+ else:
+ attention_windows = shifted_windows
+
+ was_padded = pad_values[3] > 0 or pad_values[5] > 0
+ if was_padded:
+ attention_windows = attention_windows[:, :height, :width, :].contiguous()
+
+ attention_windows = attention_windows.view(batch_size, height * width, channels)
+
+ hidden_states = shortcut + self.drop_path(attention_windows)
+
+ layer_output = self.layernorm_after(hidden_states)
+ layer_output = self.intermediate(layer_output)
+ layer_output = hidden_states + self.output(layer_output)
+
+ layer_outputs = (layer_output, attention_outputs[1]) if output_attentions else (layer_output,)
+ return layer_outputs
+
+
+# Copied from transformers.models.swin.modeling_swin.SwinStage with Swin->DonutSwin
+class DonutSwinStage(nn.Module):
+ def __init__(self, config, dim, input_resolution, depth, num_heads, drop_path, downsample):
+ super().__init__()
+ self.config = config
+ self.dim = dim
+ self.blocks = nn.ModuleList(
+ [
+ DonutSwinLayer(
+ config=config,
+ dim=dim,
+ input_resolution=input_resolution,
+ num_heads=num_heads,
+ shift_size=0 if (i % 2 == 0) else config.window_size // 2,
+ )
+ for i in range(depth)
+ ]
+ )
+
+ # patch merging layer
+ if downsample is not None:
+ self.downsample = downsample(input_resolution, dim=dim, norm_layer=nn.LayerNorm)
+ else:
+ self.downsample = None
+
+ self.pointing = False
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ input_dimensions: Tuple[int, int],
+ head_mask: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = False,
+ ) -> Tuple[torch.Tensor]:
+ height, width = input_dimensions
+ for i, layer_module in enumerate(self.blocks):
+
+ layer_head_mask = head_mask[i] if head_mask is not None else None
+
+ layer_outputs = layer_module(hidden_states, input_dimensions, layer_head_mask, output_attentions)
+
+ hidden_states = layer_outputs[0]
+
+ if self.downsample is not None:
+ height_downsampled, width_downsampled = (height + 1) // 2, (width + 1) // 2
+ output_dimensions = (height, width, height_downsampled, width_downsampled)
+ hidden_states = self.downsample(layer_outputs[0], input_dimensions)
+ else:
+ output_dimensions = (height, width, height, width)
+
+ stage_outputs = (hidden_states, output_dimensions)
+
+ if output_attentions:
+ stage_outputs += layer_outputs[1:]
+ return stage_outputs
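+
+    # Note (illustrative): blocks alternate between regular window attention
+    # (shift_size=0 for even indices) and shifted window attention
+    # (shift_size=window_size // 2 for odd indices), so information propagates
+    # across window boundaries as the stage deepens.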
+
+
+# Copied from transformers.models.swin.modeling_swin.SwinEncoder with Swin->DonutSwin
+class DonutSwinEncoder(nn.Module):
+ def __init__(self, config, grid_size):
+ super().__init__()
+ self.num_layers = len(config.depths)
+ self.config = config
+ dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths))]
+ self.layers = nn.ModuleList(
+ [
+ DonutSwinStage(
+ config=config,
+ dim=int(config.embed_dim * 2**i_layer),
+ input_resolution=(grid_size[0] // (2**i_layer), grid_size[1] // (2**i_layer)),
+ depth=config.depths[i_layer],
+ num_heads=config.num_heads[i_layer],
+ drop_path=dpr[sum(config.depths[:i_layer]) : sum(config.depths[: i_layer + 1])],
+ downsample=DonutSwinPatchMerging if (i_layer < self.num_layers - 1) else None,
+ )
+ for i_layer in range(self.num_layers)
+ ]
+ )
+
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ input_dimensions: Tuple[int, int],
+ head_mask: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = False,
+ output_hidden_states: Optional[bool] = False,
+ return_dict: Optional[bool] = True,
+ ) -> Union[Tuple, DonutSwinEncoderOutput]:
+ all_input_dimensions = ()
+ all_hidden_states = () if output_hidden_states else None
+ all_reshaped_hidden_states = () if output_hidden_states else None
+ all_self_attentions = () if output_attentions else None
+
+ if output_hidden_states:
+ batch_size, _, hidden_size = hidden_states.shape
+ # rearrange b (h w) c -> b c h w
+ reshaped_hidden_state = hidden_states.view(batch_size, *input_dimensions, hidden_size)
+ reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2)
+ all_hidden_states += (hidden_states,)
+ all_reshaped_hidden_states += (reshaped_hidden_state,)
+
+ for i, layer_module in enumerate(self.layers):
+ layer_head_mask = head_mask[i] if head_mask is not None else None
+
+ if self.gradient_checkpointing and self.training:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs, output_attentions)
+
+ return custom_forward
+
+ layer_outputs = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(layer_module), hidden_states, input_dimensions, layer_head_mask
+ )
+ else:
+ layer_outputs = layer_module(hidden_states, input_dimensions, layer_head_mask, output_attentions)
+
+ hidden_states = layer_outputs[0]
+ output_dimensions = layer_outputs[1]
+
+ input_dimensions = (output_dimensions[-2], output_dimensions[-1])
+ all_input_dimensions += (input_dimensions,)
+
+ if output_hidden_states:
+ batch_size, _, hidden_size = hidden_states.shape
+ # rearrange b (h w) c -> b c h w
+ reshaped_hidden_state = hidden_states.view(batch_size, *input_dimensions, hidden_size)
+ reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2)
+ all_hidden_states += (hidden_states,)
+ all_reshaped_hidden_states += (reshaped_hidden_state,)
+
+ if output_attentions:
+ all_self_attentions += layer_outputs[2:]
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
+
+ return DonutSwinEncoderOutput(
+ last_hidden_state=hidden_states,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attentions,
+ reshaped_hidden_states=all_reshaped_hidden_states,
+ )
+
+
+# Copied from transformers.models.swin.modeling_swin.SwinPreTrainedModel with Swin->DonutSwin
+class DonutSwinPreTrainedModel(PreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = DonutSwinConfig
+ base_model_prefix = "swin"
+ main_input_name = "pixel_values"
+ supports_gradient_checkpointing = True
+
+ def _init_weights(self, module):
+ """Initialize the weights"""
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
+ # Slightly different from the TF version which uses truncated_normal for initialization
+ # cf https://github.com/pytorch/pytorch/pull/5617
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ if isinstance(module, DonutSwinEncoder):
+ module.gradient_checkpointing = value
+
+
+SWIN_START_DOCSTRING = r"""
+ This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
+ it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
+ behavior.
+
+ Parameters:
+ config ([`DonutSwinConfig`]): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+SWIN_INPUTS_DOCSTRING = r"""
+ Args:
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+ Pixel values. Pixel values can be obtained using [`AutoFeatureExtractor`]. See
+ [`AutoFeatureExtractor.__call__`] for details.
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+@add_start_docstrings(
+ "The bare Donut Swin Model transformer outputting raw hidden-states without any specific head on top.",
+ SWIN_START_DOCSTRING,
+)
+class DonutSwinModel(DonutSwinPreTrainedModel):
+ def __init__(self, config, add_pooling_layer=True, use_mask_token=False):
+ super().__init__(config)
+ self.config = config
+ self.num_layers = len(config.depths)
+ self.num_features = int(config.embed_dim * 2 ** (self.num_layers - 1))
+
+ self.embeddings = DonutSwinEmbeddings(config, use_mask_token=use_mask_token)
+ self.encoder = DonutSwinEncoder(config, self.embeddings.patch_grid)
+
+ self.pooler = nn.AdaptiveAvgPool1d(1) if add_pooling_layer else None
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.embeddings.patch_embeddings
+
+ def _prune_heads(self, heads_to_prune):
+ """
+        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer}. See the
+        base class `PreTrainedModel` for details.
+ """
+ for layer, heads in heads_to_prune.items():
+ self.encoder.layer[layer].attention.prune_heads(heads)
+
+ @add_start_docstrings_to_model_forward(SWIN_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ processor_class=_FEAT_EXTRACTOR_FOR_DOC,
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=DonutSwinModelOutput,
+ config_class=_CONFIG_FOR_DOC,
+ modality="vision",
+ expected_output=_EXPECTED_OUTPUT_SHAPE,
+ )
+ def forward(
+ self,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ bool_masked_pos: Optional[torch.BoolTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, DonutSwinModelOutput]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if pixel_values is None:
+ raise ValueError("You have to specify pixel_values")
+
+ # Prepare head mask if needed
+        # 1.0 in head_mask indicates we keep the head
+ # attention_probs has shape bsz x n_heads x N x N
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+ head_mask = self.get_head_mask(head_mask, len(self.config.depths))
+
+ embedding_output, input_dimensions = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos)
+
+ encoder_outputs = self.encoder(
+ embedding_output,
+ input_dimensions,
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ sequence_output = encoder_outputs[0]
+
+ pooled_output = None
+ if self.pooler is not None:
+ pooled_output = self.pooler(sequence_output.transpose(1, 2))
+ pooled_output = torch.flatten(pooled_output, 1)
+
+ if not return_dict:
+ output = (sequence_output, pooled_output) + encoder_outputs[1:]
+
+ return output
+
+ return DonutSwinModelOutput(
+ last_hidden_state=sequence_output,
+ pooler_output=pooled_output,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ reshaped_hidden_states=encoder_outputs.reshaped_hidden_states,
+ )
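+
+
+# Usage sketch (illustrative, hedged): encoding images with the bare encoder,
+# assuming the `naver-clova-ix/donut-base` checkpoint; `image` is a PIL image.
+#
+#     from transformers import AutoFeatureExtractor, DonutSwinModel
+#
+#     feature_extractor = AutoFeatureExtractor.from_pretrained("naver-clova-ix/donut-base")
+#     model = DonutSwinModel.from_pretrained("naver-clova-ix/donut-base")
+#     inputs = feature_extractor(images=image, return_tensors="pt")
+#     outputs = model(**inputs)
+#     # outputs.last_hidden_state: (batch_size, sequence_length, hidden_size)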
diff --git a/src/transformers/models/donut/processing_donut.py b/src/transformers/models/donut/processing_donut.py
new file mode 100644
index 0000000000000..1b00d894bd087
--- /dev/null
+++ b/src/transformers/models/donut/processing_donut.py
@@ -0,0 +1,156 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Processor class for Donut.
+"""
+import re
+import warnings
+from contextlib import contextmanager
+
+from ...processing_utils import ProcessorMixin
+
+
+class DonutProcessor(ProcessorMixin):
+ r"""
+ Constructs a Donut processor which wraps a Donut feature extractor and an XLMRoBERTa tokenizer into a single
+ processor.
+
+ [`DonutProcessor`] offers all the functionalities of [`DonutFeatureExtractor`] and
+ [`XLMRobertaTokenizer`/`XLMRobertaTokenizerFast`]. See the [`~DonutProcessor.__call__`] and
+ [`~DonutProcessor.decode`] for more information.
+
+ Args:
+ feature_extractor ([`DonutFeatureExtractor`]):
+ An instance of [`DonutFeatureExtractor`]. The feature extractor is a required input.
+ tokenizer ([`XLMRobertaTokenizer`/`XLMRobertaTokenizerFast`]):
+ An instance of [`XLMRobertaTokenizer`/`XLMRobertaTokenizerFast`]. The tokenizer is a required input.
+ """
+ feature_extractor_class = "AutoFeatureExtractor"
+ tokenizer_class = "AutoTokenizer"
+
+ def __init__(self, feature_extractor, tokenizer):
+ super().__init__(feature_extractor, tokenizer)
+ self.current_processor = self.feature_extractor
+ self._in_target_context_manager = False
+
+ def __call__(self, *args, **kwargs):
+ """
+ When used in normal mode, this method forwards all its arguments to AutoFeatureExtractor's
+ [`~AutoFeatureExtractor.__call__`] and returns its output. If used in the context
+ [`~DonutProcessor.as_target_processor`] this method forwards all its arguments to DonutTokenizer's
+        [`~DonutTokenizer.__call__`]. Please refer to the docstring of the above two methods for more information.
+ """
+ # For backward compatibility
+ if self._in_target_context_manager:
+ return self.current_processor(*args, **kwargs)
+
+ images = kwargs.pop("images", None)
+ text = kwargs.pop("text", None)
+ if len(args) > 0:
+ images = args[0]
+ args = args[1:]
+
+ if images is None and text is None:
+ raise ValueError("You need to specify either an `images` or `text` input to process.")
+
+ if images is not None:
+ inputs = self.feature_extractor(images, *args, **kwargs)
+ if text is not None:
+ encodings = self.tokenizer(text, **kwargs)
+
+ if text is None:
+ return inputs
+ elif images is None:
+ return encodings
+ else:
+ inputs["labels"] = encodings["input_ids"]
+ return inputs
+
+ def batch_decode(self, *args, **kwargs):
+ """
+ This method forwards all its arguments to DonutTokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please refer
+ to the docstring of this method for more information.
+ """
+ return self.tokenizer.batch_decode(*args, **kwargs)
+
+ def decode(self, *args, **kwargs):
+ """
+ This method forwards all its arguments to DonutTokenizer's [`~PreTrainedTokenizer.decode`]. Please refer to the
+ docstring of this method for more information.
+ """
+ return self.tokenizer.decode(*args, **kwargs)
+
+ @contextmanager
+ def as_target_processor(self):
+ """
+        Temporarily sets the tokenizer for processing the input. Useful for encoding the labels when fine-tuning Donut.
+ """
+ warnings.warn(
+ "`as_target_processor` is deprecated and will be removed in v5 of Transformers. You can process your "
+ "labels by using the argument `text` of the regular `__call__` method (either in the same call as "
+ "your images inputs, or in a separate call."
+ )
+ self._in_target_context_manager = True
+ self.current_processor = self.tokenizer
+ yield
+ self.current_processor = self.feature_extractor
+ self._in_target_context_manager = False
+
+ def token2json(self, tokens, is_inner_value=False):
+ """
+ Convert a (generated) token sequence into an ordered JSON format.
+ """
+ output = dict()
+
+ while tokens:
+ start_token = re.search(r"", tokens, re.IGNORECASE)
+ if start_token is None:
+ break
+ key = start_token.group(1)
+ end_token = re.search(rf"", tokens, re.IGNORECASE)
+ start_token = start_token.group()
+ if end_token is None:
+ tokens = tokens.replace(start_token, "")
+ else:
+ end_token = end_token.group()
+ start_token_escaped = re.escape(start_token)
+ end_token_escaped = re.escape(end_token)
+ content = re.search(f"{start_token_escaped}(.*?){end_token_escaped}", tokens, re.IGNORECASE)
+ if content is not None:
+ content = content.group(1).strip()
+ if r""):
+ leaf = leaf.strip()
+ if leaf in self.tokenizer.get_added_vocab() and leaf[0] == "<" and leaf[-2:] == "/>":
+ leaf = leaf[1:-2] # for categorical special tokens
+ output[key].append(leaf)
+ if len(output[key]) == 1:
+ output[key] = output[key][0]
+
+ tokens = tokens[tokens.find(end_token) + len(end_token) :].strip()
+ if tokens[:6] == r"": # non-leaf nodes
+ return [output] + self.token2json(tokens[6:], is_inner_value=True)
+
+ if len(output):
+ return [output] if is_inner_value else output
+ else:
+ return [] if is_inner_value else {"text_sequence": tokens}
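+
+
+# Behaviour sketch (illustrative, hedged): for a generated sequence such as
+#     "<s_menu><s_name>Latte</s_name><s_price>4.50</s_price></s_menu>"
+# token2json returns {"menu": {"name": "Latte", "price": "4.50"}}; a sequence
+# without any <s_...> tag falls back to {"text_sequence": tokens}.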
diff --git a/src/transformers/models/vision_encoder_decoder/convert_trocr_unilm_to_pytorch.py b/src/transformers/models/trocr/convert_trocr_unilm_to_pytorch.py
similarity index 100%
rename from src/transformers/models/vision_encoder_decoder/convert_trocr_unilm_to_pytorch.py
rename to src/transformers/models/trocr/convert_trocr_unilm_to_pytorch.py
diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py
index d636be655af28..96a93ecae942a 100644
--- a/src/transformers/utils/dummy_pt_objects.py
+++ b/src/transformers/utils/dummy_pt_objects.py
@@ -1682,6 +1682,23 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
+DONUT_SWIN_PRETRAINED_MODEL_ARCHIVE_LIST = None
+
+
+class DonutSwinModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class DonutSwinPreTrainedModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
DPR_CONTEXT_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST = None
diff --git a/src/transformers/utils/dummy_vision_objects.py b/src/transformers/utils/dummy_vision_objects.py
index 30228e022222b..fa30432070a37 100644
--- a/src/transformers/utils/dummy_vision_objects.py
+++ b/src/transformers/utils/dummy_vision_objects.py
@@ -52,6 +52,13 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["vision"])
+class DonutFeatureExtractor(metaclass=DummyObject):
+ _backends = ["vision"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["vision"])
+
+
class DPTFeatureExtractor(metaclass=DummyObject):
_backends = ["vision"]
diff --git a/src/transformers/utils/fx.py b/src/transformers/utils/fx.py
index 990f278b0d506..3c3babd403778 100644
--- a/src/transformers/utils/fx.py
+++ b/src/transformers/utils/fx.py
@@ -105,6 +105,7 @@ def _generate_supported_model_class_names(
"deberta",
"deberta-v2",
"distilbert",
+ "donut-swin",
"electra",
"gpt2",
"gpt_neo",
diff --git a/tests/models/donut/__init__.py b/tests/models/donut/__init__.py
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/tests/models/donut/test_feature_extraction_donut.py b/tests/models/donut/test_feature_extraction_donut.py
new file mode 100644
index 0000000000000..38ccbf2075a9b
--- /dev/null
+++ b/tests/models/donut/test_feature_extraction_donut.py
@@ -0,0 +1,203 @@
+# coding=utf-8
+# Copyright 2022 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import unittest
+
+import numpy as np
+
+from transformers.testing_utils import require_torch, require_vision
+from transformers.utils import is_torch_available, is_vision_available
+
+from ...test_feature_extraction_common import FeatureExtractionSavingTestMixin, prepare_image_inputs
+
+
+if is_torch_available():
+ import torch
+
+if is_vision_available():
+ from PIL import Image
+
+ from transformers import DonutFeatureExtractor
+
+
+class DonutFeatureExtractionTester(unittest.TestCase):
+ def __init__(
+ self,
+ parent,
+ batch_size=7,
+ num_channels=3,
+ image_size=18,
+ min_resolution=30,
+ max_resolution=400,
+ do_resize=True,
+ size=[20, 18],
+ do_thumbnail=True,
+ do_align_axis=False,
+ do_pad=True,
+ do_normalize=True,
+ image_mean=[0.5, 0.5, 0.5],
+ image_std=[0.5, 0.5, 0.5],
+ ):
+ self.parent = parent
+ self.batch_size = batch_size
+ self.num_channels = num_channels
+ self.image_size = image_size
+ self.min_resolution = min_resolution
+ self.max_resolution = max_resolution
+ self.do_resize = do_resize
+ self.size = size
+ self.do_thumbnail = do_thumbnail
+ self.do_align_axis = do_align_axis
+ self.do_pad = do_pad
+ self.do_normalize = do_normalize
+ self.image_mean = image_mean
+ self.image_std = image_std
+
+ def prepare_feat_extract_dict(self):
+ return {
+ "do_resize": self.do_resize,
+ "size": self.size,
+ "do_thumbnail": self.do_thumbnail,
+ "do_align_long_axis": self.do_align_axis,
+ "do_pad": self.do_pad,
+ "do_normalize": self.do_normalize,
+ "image_mean": self.image_mean,
+ "image_std": self.image_std,
+ }
+
+
+@require_torch
+@require_vision
+class DonutFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
+
+ feature_extraction_class = DonutFeatureExtractor if is_vision_available() else None
+
+ def setUp(self):
+ self.feature_extract_tester = DonutFeatureExtractionTester(self)
+
+ @property
+ def feat_extract_dict(self):
+ return self.feature_extract_tester.prepare_feat_extract_dict()
+
+ def test_feat_extract_properties(self):
+ feature_extractor = self.feature_extraction_class(**self.feat_extract_dict)
+ self.assertTrue(hasattr(feature_extractor, "do_resize"))
+ self.assertTrue(hasattr(feature_extractor, "size"))
+ self.assertTrue(hasattr(feature_extractor, "do_thumbnail"))
+ self.assertTrue(hasattr(feature_extractor, "do_align_long_axis"))
+ self.assertTrue(hasattr(feature_extractor, "do_pad"))
+ self.assertTrue(hasattr(feature_extractor, "do_normalize"))
+ self.assertTrue(hasattr(feature_extractor, "image_mean"))
+ self.assertTrue(hasattr(feature_extractor, "image_std"))
+
+ def test_batch_feature(self):
+ pass
+
+ def test_call_pil(self):
+ # Initialize feature_extractor
+ feature_extractor = self.feature_extraction_class(**self.feat_extract_dict)
+ # create random PIL images
+ image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False)
+ for image in image_inputs:
+ self.assertIsInstance(image, Image.Image)
+
+ # Test not batched input
+ encoded_images = feature_extractor(image_inputs[0], return_tensors="pt").pixel_values
+ self.assertEqual(
+ encoded_images.shape,
+ (
+ 1,
+ self.feature_extract_tester.num_channels,
+ self.feature_extract_tester.size[1],
+ self.feature_extract_tester.size[0],
+ ),
+ )
+
+ # Test batched
+ encoded_images = feature_extractor(image_inputs, return_tensors="pt").pixel_values
+ self.assertEqual(
+ encoded_images.shape,
+ (
+ self.feature_extract_tester.batch_size,
+ self.feature_extract_tester.num_channels,
+ self.feature_extract_tester.size[1],
+ self.feature_extract_tester.size[0],
+ ),
+ )
+
+ def test_call_numpy(self):
+ # Initialize feature_extractor
+ feature_extractor = self.feature_extraction_class(**self.feat_extract_dict)
+ # create random numpy tensors
+ image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False, numpify=True)
+ for image in image_inputs:
+ self.assertIsInstance(image, np.ndarray)
+
+ # Test not batched input
+ encoded_images = feature_extractor(image_inputs[0], return_tensors="pt").pixel_values
+ self.assertEqual(
+ encoded_images.shape,
+ (
+ 1,
+ self.feature_extract_tester.num_channels,
+ self.feature_extract_tester.size[1],
+ self.feature_extract_tester.size[0],
+ ),
+ )
+
+ # Test batched
+ encoded_images = feature_extractor(image_inputs, return_tensors="pt").pixel_values
+ self.assertEqual(
+ encoded_images.shape,
+ (
+ self.feature_extract_tester.batch_size,
+ self.feature_extract_tester.num_channels,
+ self.feature_extract_tester.size[1],
+ self.feature_extract_tester.size[0],
+ ),
+ )
+
+ def test_call_pytorch(self):
+ # Initialize feature_extractor
+ feature_extractor = self.feature_extraction_class(**self.feat_extract_dict)
+ # create random PyTorch tensors
+ image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False, torchify=True)
+ for image in image_inputs:
+ self.assertIsInstance(image, torch.Tensor)
+
+ # Test not batched input
+ encoded_images = feature_extractor(image_inputs[0], return_tensors="pt").pixel_values
+ self.assertEqual(
+ encoded_images.shape,
+ (
+ 1,
+ self.feature_extract_tester.num_channels,
+ self.feature_extract_tester.size[1],
+ self.feature_extract_tester.size[0],
+ ),
+ )
+
+ # Test batched
+ encoded_images = feature_extractor(image_inputs, return_tensors="pt").pixel_values
+ self.assertEqual(
+ encoded_images.shape,
+ (
+ self.feature_extract_tester.batch_size,
+ self.feature_extract_tester.num_channels,
+ self.feature_extract_tester.size[1],
+ self.feature_extract_tester.size[0],
+ ),
+ )
diff --git a/tests/models/donut/test_modeling_donut_swin.py b/tests/models/donut/test_modeling_donut_swin.py
new file mode 100644
index 0000000000000..f909d961880a9
--- /dev/null
+++ b/tests/models/donut/test_modeling_donut_swin.py
@@ -0,0 +1,464 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" Testing suite for the PyTorch Donut Swin model. """
+
+import collections
+import inspect
+import os
+import pickle
+import tempfile
+import unittest
+
+from transformers import DonutSwinConfig
+from transformers.testing_utils import require_torch, slow, torch_device
+from transformers.utils import is_torch_available, is_torch_fx_available
+
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor
+
+
+if is_torch_available():
+ import torch
+ from torch import nn
+
+ from transformers import DonutSwinModel
+ from transformers.models.donut.modeling_donut_swin import DONUT_SWIN_PRETRAINED_MODEL_ARCHIVE_LIST
+
+if is_torch_fx_available():
+ from transformers.utils.fx import symbolic_trace
+
+
+class DonutSwinModelTester:
+ def __init__(
+ self,
+ parent,
+ batch_size=13,
+ image_size=32,
+ patch_size=2,
+ num_channels=3,
+ embed_dim=16,
+ depths=[1, 2, 1],
+ num_heads=[2, 2, 4],
+ window_size=2,
+ mlp_ratio=2.0,
+ qkv_bias=True,
+ hidden_dropout_prob=0.0,
+ attention_probs_dropout_prob=0.0,
+ drop_path_rate=0.1,
+ hidden_act="gelu",
+ use_absolute_embeddings=False,
+ patch_norm=True,
+ initializer_range=0.02,
+ layer_norm_eps=1e-5,
+ is_training=True,
+ scope=None,
+ use_labels=True,
+ type_sequence_label_size=10,
+ encoder_stride=8,
+ ):
+ self.parent = parent
+ self.batch_size = batch_size
+ self.image_size = image_size
+ self.patch_size = patch_size
+ self.num_channels = num_channels
+ self.embed_dim = embed_dim
+ self.depths = depths
+ self.num_heads = num_heads
+ self.window_size = window_size
+ self.mlp_ratio = mlp_ratio
+ self.qkv_bias = qkv_bias
+ self.hidden_dropout_prob = hidden_dropout_prob
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
+ self.drop_path_rate = drop_path_rate
+ self.hidden_act = hidden_act
+ self.use_absolute_embeddings = use_absolute_embeddings
+ self.patch_norm = patch_norm
+ self.layer_norm_eps = layer_norm_eps
+ self.initializer_range = initializer_range
+ self.is_training = is_training
+ self.scope = scope
+ self.use_labels = use_labels
+ self.type_sequence_label_size = type_sequence_label_size
+ self.encoder_stride = encoder_stride
+
+ def prepare_config_and_inputs(self):
+ pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
+
+ labels = None
+ if self.use_labels:
+ labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
+
+ config = self.get_config()
+
+ return config, pixel_values, labels
+
+ def get_config(self):
+ return DonutSwinConfig(
+ image_size=self.image_size,
+ patch_size=self.patch_size,
+ num_channels=self.num_channels,
+ embed_dim=self.embed_dim,
+ depths=self.depths,
+ num_heads=self.num_heads,
+ window_size=self.window_size,
+ mlp_ratio=self.mlp_ratio,
+ qkv_bias=self.qkv_bias,
+ hidden_dropout_prob=self.hidden_dropout_prob,
+ attention_probs_dropout_prob=self.attention_probs_dropout_prob,
+ drop_path_rate=self.drop_path_rate,
+ hidden_act=self.hidden_act,
+ use_absolute_embeddings=self.use_absolute_embeddings,
+            patch_norm=self.patch_norm,
+ layer_norm_eps=self.layer_norm_eps,
+ initializer_range=self.initializer_range,
+ encoder_stride=self.encoder_stride,
+ )
+
+ def create_and_check_model(self, config, pixel_values, labels):
+ model = DonutSwinModel(config=config)
+ model.to(torch_device)
+ model.eval()
+ result = model(pixel_values)
+
+ expected_seq_len = ((config.image_size // config.patch_size) ** 2) // (4 ** (len(config.depths) - 1))
+ expected_dim = int(config.embed_dim * 2 ** (len(config.depths) - 1))
+
+ self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, expected_seq_len, expected_dim))
+
+ def prepare_config_and_inputs_for_common(self):
+ config_and_inputs = self.prepare_config_and_inputs()
+ (
+ config,
+ pixel_values,
+ labels,
+ ) = config_and_inputs
+ inputs_dict = {"pixel_values": pixel_values}
+ return config, inputs_dict
+
+
+@require_torch
+class DonutSwinModelTest(ModelTesterMixin, unittest.TestCase):
+
+ all_model_classes = (DonutSwinModel,) if is_torch_available() else ()
+ fx_compatible = True
+
+ test_pruning = False
+ test_resize_embeddings = False
+ test_head_masking = False
+
+ def setUp(self):
+ self.model_tester = DonutSwinModelTester(self)
+ self.config_tester = ConfigTester(self, config_class=DonutSwinConfig, embed_dim=37)
+
+ def test_config(self):
+ self.create_and_test_config_common_properties()
+ self.config_tester.create_and_test_config_to_json_string()
+ self.config_tester.create_and_test_config_to_json_file()
+ self.config_tester.create_and_test_config_from_and_save_pretrained()
+ self.config_tester.create_and_test_config_with_num_labels()
+ self.config_tester.check_config_can_be_init_without_params()
+ self.config_tester.check_config_arguments_init()
+
+ def create_and_test_config_common_properties(self):
+ return
+
+ def test_model(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_model(*config_and_inputs)
+
+ def test_inputs_embeds(self):
+ # DonutSwin does not use inputs_embeds
+ pass
+
+ def test_model_common_attributes(self):
+ config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+
+ for model_class in self.all_model_classes:
+ model = model_class(config)
+ self.assertIsInstance(model.get_input_embeddings(), nn.Module)
+ x = model.get_output_embeddings()
+ self.assertTrue(x is None or isinstance(x, nn.Linear))
+
+ def test_forward_signature(self):
+ config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+
+ for model_class in self.all_model_classes:
+ model = model_class(config)
+ signature = inspect.signature(model.forward)
+ # signature.parameters is an OrderedDict => so arg_names order is deterministic
+ arg_names = [*signature.parameters.keys()]
+
+ expected_arg_names = ["pixel_values"]
+ self.assertListEqual(arg_names[:1], expected_arg_names)
+
+ def test_attention_outputs(self):
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ config.return_dict = True
+
+ for model_class in self.all_model_classes:
+ inputs_dict["output_attentions"] = True
+ inputs_dict["output_hidden_states"] = False
+ config.return_dict = True
+ model = model_class(config)
+ model.to(torch_device)
+ model.eval()
+ with torch.no_grad():
+ outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+ attentions = outputs.attentions
+ expected_num_attentions = len(self.model_tester.depths)
+ self.assertEqual(len(attentions), expected_num_attentions)
+
+ # check that output_attentions also work using config
+ del inputs_dict["output_attentions"]
+ config.output_attentions = True
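+ # DonutSwin computes self-attention within local windows, so each attention map has
+ # shape (num_heads, window_size**2, window_size**2) rather than (num_heads, seq_len, seq_len)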
+ window_size_squared = config.window_size**2
+ model = model_class(config)
+ model.to(torch_device)
+ model.eval()
+ with torch.no_grad():
+ outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+ attentions = outputs.attentions
+ self.assertEqual(len(attentions), expected_num_attentions)
+
+ self.assertListEqual(
+ list(attentions[0].shape[-3:]),
+ [self.model_tester.num_heads[0], window_size_squared, window_size_squared],
+ )
+ out_len = len(outputs)
+
+ # Check attention is always last and order is fine
+ inputs_dict["output_attentions"] = True
+ inputs_dict["output_hidden_states"] = True
+ model = model_class(config)
+ model.to(torch_device)
+ model.eval()
+ with torch.no_grad():
+ outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+
+ if hasattr(self.model_tester, "num_hidden_states_types"):
+ added_hidden_states = self.model_tester.num_hidden_states_types
+ else:
+ # hidden_states adds one output, and reshaped_hidden_states adds another
+ added_hidden_states = 2
+ self.assertEqual(out_len + added_hidden_states, len(outputs))
+
+ self_attentions = outputs.attentions
+
+ self.assertEqual(len(self_attentions), expected_num_attentions)
+
+ self.assertListEqual(
+ list(self_attentions[0].shape[-3:]),
+ [self.model_tester.num_heads[0], window_size_squared, window_size_squared],
+ )
+
+ def check_hidden_states_output(self, inputs_dict, config, model_class, image_size):
+ model = model_class(config)
+ model.to(torch_device)
+ model.eval()
+
+ with torch.no_grad():
+ outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+
+ hidden_states = outputs.hidden_states
+
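+ # one hidden state per stage, plus one for the initial patch embeddings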
+ expected_num_layers = getattr(
+ self.model_tester, "expected_num_hidden_layers", len(self.model_tester.depths) + 1
+ )
+ self.assertEqual(len(hidden_states), expected_num_layers)
+
+ # DonutSwin has a different seq_length
+ patch_size = (
+ config.patch_size
+ if isinstance(config.patch_size, collections.abc.Iterable)
+ else (config.patch_size, config.patch_size)
+ )
+
+ num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
+
+ self.assertListEqual(
+ list(hidden_states[0].shape[-2:]),
+ [num_patches, self.model_tester.embed_dim],
+ )
+
+ reshaped_hidden_states = outputs.reshaped_hidden_states
+ self.assertEqual(len(reshaped_hidden_states), expected_num_layers)
+
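+ # reshaped_hidden_states come as (batch, num_channels, height, width); flatten the
+ # spatial dimensions back to (batch, seq_len, num_channels) to compare with hidden_states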
+ batch_size, num_channels, height, width = reshaped_hidden_states[0].shape
+ reshaped_hidden_states = (
+ reshaped_hidden_states[0].view(batch_size, num_channels, height * width).permute(0, 2, 1)
+ )
+ self.assertListEqual(
+ list(reshaped_hidden_states.shape[-2:]),
+ [num_patches, self.model_tester.embed_dim],
+ )
+
+ def test_hidden_states_output(self):
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ image_size = (
+ self.model_tester.image_size
+ if isinstance(self.model_tester.image_size, collections.abc.Iterable)
+ else (self.model_tester.image_size, self.model_tester.image_size)
+ )
+
+ for model_class in self.all_model_classes:
+ inputs_dict["output_hidden_states"] = True
+ self.check_hidden_states_output(inputs_dict, config, model_class, image_size)
+
+ # check that output_hidden_states also work using config
+ del inputs_dict["output_hidden_states"]
+ config.output_hidden_states = True
+
+ self.check_hidden_states_output(inputs_dict, config, model_class, image_size)
+
+ def test_hidden_states_output_with_padding(self):
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ config.patch_size = 3
+
+ image_size = (
+ self.model_tester.image_size
+ if isinstance(self.model_tester.image_size, collections.abc.Iterable)
+ else (self.model_tester.image_size, self.model_tester.image_size)
+ )
+ patch_size = (
+ config.patch_size
+ if isinstance(config.patch_size, collections.abc.Iterable)
+ else (config.patch_size, config.patch_size)
+ )
+
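+ # with patch_size=3 the image size is no longer evenly divisible, so DonutSwin pads the
+ # input on the bottom/right and hidden states are computed on the padded resolution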
+ padded_height = image_size[0] + patch_size[0] - (image_size[0] % patch_size[0])
+ padded_width = image_size[1] + patch_size[1] - (image_size[1] % patch_size[1])
+
+ for model_class in self.all_model_classes:
+ inputs_dict["output_hidden_states"] = True
+ self.check_hidden_states_output(inputs_dict, config, model_class, (padded_height, padded_width))
+
+ # check that output_hidden_states also work using config
+ del inputs_dict["output_hidden_states"]
+ config.output_hidden_states = True
+ self.check_hidden_states_output(inputs_dict, config, model_class, (padded_height, padded_width))
+
+ @slow
+ def test_model_from_pretrained(self):
+ for model_name in DONUT_SWIN_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
+ model = DonutSwinModel.from_pretrained(model_name)
+ self.assertIsNotNone(model)
+
+ def test_initialization(self):
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ configs_no_init = _config_zero_init(config)
+ for model_class in self.all_model_classes:
+ model = model_class(config=configs_no_init)
+ for name, param in model.named_parameters():
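+ # with all initializer ranges zeroed out, every non-embedding parameter should be
+ # exactly 0.0 or 1.0 (the rounding below guards against tiny numerical noise)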
+ if "embeddings" not in name and param.requires_grad:
+ self.assertIn(
+ ((param.data.mean() * 1e9).round() / 1e9).item(),
+ [0.0, 1.0],
+ msg=f"Parameter {name} of model {model_class} seems not properly initialized",
+ )
+
+ def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=False):
+ if not is_torch_fx_available() or not self.fx_compatible:
+ return
+
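+ # trace each model with torch.fx, check that the traced outputs match the eager outputs,
+ # then make sure the traced module survives a pickle round-trip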
+ configs_no_init = _config_zero_init(config) # To be sure we have no NaNs
+ configs_no_init.return_dict = False
+
+ for model_class in self.all_model_classes:
+ model = model_class(config=configs_no_init)
+ model.to(torch_device)
+ model.eval()
+ inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=output_loss)
+
+ try:
+ if model.config.is_encoder_decoder:
+ model.config.use_cache = False # FSMT still requires this hack -> FSMT should probably be refactored to be similar to BART afterward
+ labels = inputs.get("labels", None)
+ input_names = ["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask"]
+ if labels is not None:
+ input_names.append("labels")
+
+ filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
+ input_names = list(filtered_inputs.keys())
+
+ model_output = model(**filtered_inputs)
+
+ traced_model = symbolic_trace(model, input_names)
+ traced_output = traced_model(**filtered_inputs)
+ else:
+ input_names = ["input_ids", "attention_mask", "token_type_ids", "pixel_values"]
+
+ labels = inputs.get("labels", None)
+ start_positions = inputs.get("start_positions", None)
+ end_positions = inputs.get("end_positions", None)
+ if labels is not None:
+ input_names.append("labels")
+ if start_positions is not None:
+ input_names.append("start_positions")
+ if end_positions is not None:
+ input_names.append("end_positions")
+
+ filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
+ input_names = list(filtered_inputs.keys())
+
+ model_output = model(**filtered_inputs)
+
+ traced_model = symbolic_trace(model, input_names)
+ traced_output = traced_model(**filtered_inputs)
+
+ except RuntimeError as e:
+ self.fail(f"Couldn't trace module: {e}")
+
+ def flatten_output(output):
+ flatten = []
+ for x in output:
+ if isinstance(x, (tuple, list)):
+ flatten += flatten_output(x)
+ elif not isinstance(x, torch.Tensor):
+ continue
+ else:
+ flatten.append(x)
+ return flatten
+
+ model_output = flatten_output(model_output)
+ traced_output = flatten_output(traced_output)
+ num_outputs = len(model_output)
+
+ for i in range(num_outputs):
+ self.assertTrue(
+ torch.allclose(model_output[i], traced_output[i]),
+ f"traced {i}th output doesn't match model {i}th output for {model_class}",
+ )
+
+ # Test that the model can be serialized and restored properly
+ with tempfile.TemporaryDirectory() as tmp_dir_name:
+ pkl_file_name = os.path.join(tmp_dir_name, "model.pkl")
+ try:
+ with open(pkl_file_name, "wb") as f:
+ pickle.dump(traced_model, f)
+ with open(pkl_file_name, "rb") as f:
+ loaded = pickle.load(f)
+ except Exception as e:
+ self.fail(f"Couldn't serialize / deserialize the traced model: {e}")
+
+ loaded_output = loaded(**filtered_inputs)
+ loaded_output = flatten_output(loaded_output)
+
+ for i in range(num_outputs):
+ self.assertTrue(
+ torch.allclose(model_output[i], loaded_output[i]),
+ f"serialized model {i}th output doesn't match model {i}th output for {model_class}",
+ )
diff --git a/tests/models/vision_encoder_decoder/test_modeling_vision_encoder_decoder.py b/tests/models/vision_encoder_decoder/test_modeling_vision_encoder_decoder.py
index 320cdd6330626..7570888097c53 100644
--- a/tests/models/vision_encoder_decoder/test_modeling_vision_encoder_decoder.py
+++ b/tests/models/vision_encoder_decoder/test_modeling_vision_encoder_decoder.py
@@ -13,14 +13,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
+import re
import tempfile
import unittest
from datasets import load_dataset
from packaging import version
-from transformers.testing_utils import require_torch, require_vision, slow, to_2tuple, torch_device
+from transformers import DonutProcessor, TrOCRProcessor
+from transformers.testing_utils import (
+ require_sentencepiece,
+ require_torch,
+ require_vision,
+ slow,
+ to_2tuple,
+ torch_device,
+)
from transformers.utils import cached_property, is_torch_available, is_vision_available
from ...test_modeling_common import floats_tensor, ids_tensor, random_attention_mask
@@ -54,7 +62,7 @@
import PIL
from PIL import Image
- from transformers import TrOCRProcessor, ViTFeatureExtractor
+ from transformers import ViTFeatureExtractor
@require_torch
@@ -654,8 +662,8 @@ def default_processor(self):
def test_inference_handwritten(self):
model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten").to(torch_device)
- ds = load_dataset("hf-internal-testing/fixtures_ocr", split="test")
- image = Image.open(ds[0]["file"]).convert("RGB")
+ dataset = load_dataset("hf-internal-testing/fixtures_ocr", split="test")
+ image = Image.open(dataset[0]["file"]).convert("RGB")
processor = self.default_processor
pixel_values = processor(images=image, return_tensors="pt").pixel_values.to(torch_device)
@@ -679,8 +687,8 @@ def test_inference_handwritten(self):
def test_inference_printed(self):
model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-printed").to(torch_device)
- ds = load_dataset("hf-internal-testing/fixtures_ocr", split="test")
- image = Image.open(ds[1]["file"]).convert("RGB")
+ dataset = load_dataset("hf-internal-testing/fixtures_ocr", split="test")
+ image = Image.open(dataset[1]["file"]).convert("RGB")
processor = self.default_processor
pixel_values = processor(images=image, return_tensors="pt").pixel_values.to(torch_device)
@@ -774,3 +782,197 @@ def generate_step(pixel_values):
# should produce
# ["a cat laying on top of a couch next to another cat"]
self.assertEqual(preds, ["a cat laying on top of a couch next to another cat"])
+
+
+@require_vision
+@require_torch
+@require_sentencepiece
+class DonutModelIntegrationTest(unittest.TestCase):
+ @slow
+ def test_inference_docvqa(self):
+ processor = DonutProcessor.from_pretrained("naver-clova-ix/donut-base-finetuned-docvqa")
+ model = VisionEncoderDecoderModel.from_pretrained("naver-clova-ix/donut-base-finetuned-docvqa").to(
+ torch_device
+ )
+
+ dataset = load_dataset("hf-internal-testing/example-documents", split="test")
+ image = dataset[0]["image"]
+
+ pixel_values = processor(images=image, return_tensors="pt").pixel_values.to(torch_device)
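+ # Donut conditions the decoder on a task-specific start token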
+ decoder_input_ids = processor.tokenizer(
+ "", add_special_tokens=False, return_tensors="pt"
+ ).input_ids.to(torch_device)
+
+ # step 1: single forward pass
+ with torch.no_grad():
+ outputs = model(pixel_values=pixel_values, decoder_input_ids=decoder_input_ids)
+ logits = outputs.logits
+
+ # verify the logits
+ expected_shape = torch.Size([1, 1, 57532])
+ self.assertEqual(outputs.logits.shape, expected_shape)
+
+ expected_slice = torch.tensor([24.2731, -6.4522, 32.4130]).to(torch_device)
+ self.assertTrue(torch.allclose(logits[0, 0, :3], expected_slice, atol=1e-4))
+
+ # step 2: generation
+ task_prompt = "{user_input}"
+ question = "When is the coffee break?"
+ prompt = task_prompt.replace("{user_input}", question)
+ decoder_input_ids = processor.tokenizer(prompt, add_special_tokens=False, return_tensors="pt").input_ids
+ decoder_input_ids = decoder_input_ids.to(torch_device)
+
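+ # greedy decoding (num_beams=1) up to the decoder's maximum length; <unk> is banned
+ # via bad_words_ids so the model cannot emit it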
+ outputs = model.generate(
+ pixel_values,
+ decoder_input_ids=decoder_input_ids,
+ max_length=model.decoder.config.max_position_embeddings,
+ early_stopping=True,
+ pad_token_id=processor.tokenizer.pad_token_id,
+ eos_token_id=processor.tokenizer.eos_token_id,
+ use_cache=True,
+ num_beams=1,
+ bad_words_ids=[[processor.tokenizer.unk_token_id]],
+ output_scores=True,
+ return_dict_in_generate=True,
+ )
+ sequence = processor.batch_decode(outputs.sequences)[0]
+ sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
+ sequence = re.sub(r"<.*?>", "", sequence, count=1).strip() # remove first task start token
+
+ # verify generated sequence
+ self.assertEqual(
+ sequence, " When is the coffee break? 11-14 to 11:39 a.m."
+ )
+
+ # verify scores
+ self.assertEqual(len(outputs.scores), 11)
+ self.assertTrue(
+ torch.allclose(
+ outputs.scores[0][0, :3], torch.tensor([5.3153, -3.5276, 13.4781], device=torch_device), atol=1e-4
+ )
+ )
+
+ @slow
+ def test_inference_cordv2(self):
+ processor = DonutProcessor.from_pretrained("naver-clova-ix/donut-base-finetuned-cord-v2")
+ model = VisionEncoderDecoderModel.from_pretrained("naver-clova-ix/donut-base-finetuned-cord-v2").to(
+ torch_device
+ )
+
+ dataset = load_dataset("hf-internal-testing/example-documents", split="test")
+ image = dataset[2]["image"]
+
+ pixel_values = processor(images=image, return_tensors="pt").pixel_values.to(torch_device)
+ decoder_input_ids = processor.tokenizer(
+ "", add_special_tokens=False, return_tensors="pt"
+ ).input_ids.to(torch_device)
+
+ # step 1: single forward pass
+ with torch.no_grad():
+ outputs = model(pixel_values=pixel_values, decoder_input_ids=decoder_input_ids)
+ logits = outputs.logits
+
+ # verify the logits
+ expected_shape = torch.Size((1, 1, model.decoder.config.vocab_size))
+ self.assertEqual(outputs.logits.shape, expected_shape)
+
+ expected_slice = torch.tensor([-27.4344, -3.2686, -19.3524], device=torch_device)
+ self.assertTrue(torch.allclose(logits[0, 0, :3], expected_slice, atol=1e-4))
+
+ # step 2: generation
+ task_prompt = ""
+ decoder_input_ids = processor.tokenizer(task_prompt, add_special_tokens=False, return_tensors="pt").input_ids
+ decoder_input_ids = decoder_input_ids.to(torch_device)
+
+ outputs = model.generate(
+ pixel_values,
+ decoder_input_ids=decoder_input_ids,
+ max_length=model.decoder.config.max_position_embeddings,
+ early_stopping=True,
+ pad_token_id=processor.tokenizer.pad_token_id,
+ eos_token_id=processor.tokenizer.eos_token_id,
+ use_cache=True,
+ num_beams=1,
+ bad_words_ids=[[processor.tokenizer.unk_token_id]],
+ output_scores=True,
+ return_dict_in_generate=True,
+ )
+
+ sequence = processor.batch_decode(outputs.sequences)[0]
+ sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
+ sequence = re.sub(r"<.*?>", "", sequence, count=1).strip() # remove first task start token
+
+ # verify generated sequence
+ # fmt: off
+ expected_sequence = " CINNAMON SUGAR 17,000 1 x 17,000 17,000 17,000 20,000 3,000" # noqa: E231
+ # fmt: on
+ self.assertEqual(sequence, expected_sequence)
+
+ # verify scores
+ self.assertEqual(len(outputs.scores), 43)
+ self.assertTrue(
+ torch.allclose(
+ outputs.scores[0][0, :3], torch.tensor([-27.4344, -3.2686, -19.3524], device=torch_device), atol=1e-4
+ )
+ )
+
+ @slow
+ def test_inference_rvlcdip(self):
+ processor = DonutProcessor.from_pretrained("naver-clova-ix/donut-base-finetuned-rvlcdip")
+ model = VisionEncoderDecoderModel.from_pretrained("naver-clova-ix/donut-base-finetuned-rvlcdip").to(
+ torch_device
+ )
+
+ dataset = load_dataset("hf-internal-testing/example-documents", split="test")
+ image = dataset[1]["image"]
+
+ pixel_values = processor(images=image, return_tensors="pt").pixel_values.to(torch_device)
+
+ # step 1: single forward pass
+ decoder_input_ids = processor.tokenizer(
+ "", add_special_tokens=False, return_tensors="pt"
+ ).input_ids.to(torch_device)
+ with torch.no_grad():
+ outputs = model(pixel_values=pixel_values, decoder_input_ids=decoder_input_ids)
+ logits = outputs.logits
+
+ # verify the logits
+ expected_shape = torch.Size((1, 1, model.decoder.config.vocab_size))
+ self.assertEqual(outputs.logits.shape, expected_shape)
+
+ expected_slice = torch.tensor([-17.6490, -4.8381, -15.7577], device=torch_device)
+ self.assertTrue(torch.allclose(logits[0, 0, :3], expected_slice, atol=1e-4))
+
+ # step 2: generation
+ task_prompt = ""
+ decoder_input_ids = processor.tokenizer(task_prompt, add_special_tokens=False, return_tensors="pt").input_ids
+ decoder_input_ids = decoder_input_ids.to(torch_device)
+
+ outputs = model.generate(
+ pixel_values,
+ decoder_input_ids=decoder_input_ids,
+ max_length=model.decoder.config.max_position_embeddings,
+ early_stopping=True,
+ pad_token_id=processor.tokenizer.pad_token_id,
+ eos_token_id=processor.tokenizer.eos_token_id,
+ use_cache=True,
+ num_beams=1,
+ bad_words_ids=[[processor.tokenizer.unk_token_id]],
+ output_scores=True,
+ return_dict_in_generate=True,
+ )
+
+ sequence = processor.batch_decode(outputs.sequences)[0]
+ sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
+ sequence = re.sub(r"<.*?>", "", sequence, count=1).strip() # remove first task start token
+
+ # verify generated sequence
+ self.assertEqual(sequence, "<s_class><advertisement/></s_class>")
+
+ # verify scores
+ self.assertEqual(len(outputs.scores), 4)
+ self.assertTrue(
+ torch.allclose(
+ outputs.scores[0][0, :3], torch.tensor([-17.6490, -4.8381, -15.7577], device=torch_device), atol=1e-4
+ )
+ )
diff --git a/utils/check_copies.py b/utils/check_copies.py
index e2e0e1a53e433..7d57173654468 100644
--- a/utils/check_copies.py
+++ b/utils/check_copies.py
@@ -471,6 +471,7 @@ def check_model_list_copy(overwrite=False, max_per_line=119):
"Data2VecAudio": "Data2Vec",
"Data2VecText": "Data2Vec",
"Data2VecVision": "Data2Vec",
+ "DonutSwin": "Donut",
"Marian": "MarianMT",
"OpenAI GPT-2": "GPT-2",
"OpenAI GPT": "GPT",
diff --git a/utils/check_repo.py b/utils/check_repo.py
index d2271e87ebf17..254467113d6cb 100644
--- a/utils/check_repo.py
+++ b/utils/check_repo.py
@@ -206,6 +206,7 @@
("data2vec-text", "data2vec"),
("data2vec-audio", "data2vec"),
("data2vec-vision", "data2vec"),
+ ("donut-swin", "donut"),
]
)
diff --git a/utils/documentation_tests.txt b/utils/documentation_tests.txt
index 1941a7343a6bc..0edda8ae5a4c3 100644
--- a/utils/documentation_tests.txt
+++ b/utils/documentation_tests.txt
@@ -8,6 +8,7 @@ docs/source/en/model_doc/t5.mdx
docs/source/en/model_doc/t5v1.1.mdx
docs/source/en/model_doc/byt5.mdx
docs/source/en/model_doc/tapex.mdx
+docs/source/en/model_doc/donut.mdx
docs/source/en/model_doc/encoder-decoder.mdx
src/transformers/generation_utils.py
src/transformers/models/albert/modeling_albert.py