diff --git a/.circleci/config.yml b/.circleci/config.yml index 6df8ad8e89f91..969b26fd9795f 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -890,7 +890,7 @@ jobs: - run: make deps_table_check_updated - run: python utils/tests_fetcher.py --sanity_check - run_tests_layoutlmv2: + run_tests_layoutlmv2_and_v3: working_directory: ~/transformers docker: - image: circleci/python:3.7 @@ -921,7 +921,7 @@ jobs: path: ~/transformers/test_preparation.txt - run: | if [ -f test_list.txt ]; then - python -m pytest -n 1 tests/models/*layoutlmv2* --dist=loadfile -s --make-reports=tests_layoutlmv2 --durations=100 + python -m pytest -n 1 tests/models/*layoutlmv* --dist=loadfile -s --make-reports=tests_layoutlmv2_and_v3 --durations=100 fi - store_artifacts: path: ~/transformers/tests_output.txt @@ -982,7 +982,7 @@ workflows: - run_tests_pipelines_tf - run_tests_onnxruntime - run_tests_hub - - run_tests_layoutlmv2 + - run_tests_layoutlmv2_and_v3 nightly: triggers: - schedule: diff --git a/README.md b/README.md index c046393239091..dffc4f63c77b7 100644 --- a/README.md +++ b/README.md @@ -279,6 +279,7 @@ Min, Patrick Lewis, Ledell Wu, Sergey Edunov, Danqi Chen, and Wen-tau Yih. 1. **[ImageGPT](https://huggingface.co/docs/transformers/main/model_doc/imagegpt)** (from OpenAI) released with the paper [Generative Pretraining from Pixels](https://openai.com/blog/image-gpt/) by Mark Chen, Alec Radford, Rewon Child, Jeffrey Wu, Heewoo Jun, David Luan, Ilya Sutskever. 1. **[LayoutLM](https://huggingface.co/docs/transformers/model_doc/layoutlm)** (from Microsoft Research Asia) released with the paper [LayoutLM: Pre-training of Text and Layout for Document Image Understanding](https://arxiv.org/abs/1912.13318) by Yiheng Xu, Minghao Li, Lei Cui, Shaohan Huang, Furu Wei, Ming Zhou. 1. **[LayoutLMv2](https://huggingface.co/docs/transformers/model_doc/layoutlmv2)** (from Microsoft Research Asia) released with the paper [LayoutLMv2: Multi-modal Pre-training for Visually-Rich Document Understanding](https://arxiv.org/abs/2012.14740) by Yang Xu, Yiheng Xu, Tengchao Lv, Lei Cui, Furu Wei, Guoxin Wang, Yijuan Lu, Dinei Florencio, Cha Zhang, Wanxiang Che, Min Zhang, Lidong Zhou. +1. **[LayoutLMv3](https://huggingface.co/docs/transformers/main/model_doc/layoutlmv3)** (from Microsoft Research Asia) released with the paper [LayoutLMv3: Pre-training for Document AI with Unified Text and Image Masking](https://arxiv.org/abs/2204.08387) by Yupan Huang, Tengchao Lv, Lei Cui, Yutong Lu, Furu Wei. 1. **[LayoutXLM](https://huggingface.co/docs/transformers/model_doc/layoutlmv2)** (from Microsoft Research Asia) released with the paper [LayoutXLM: Multimodal Pre-training for Multilingual Visually-rich Document Understanding](https://arxiv.org/abs/2104.08836) by Yiheng Xu, Tengchao Lv, Lei Cui, Guoxin Wang, Yijuan Lu, Dinei Florencio, Cha Zhang, Furu Wei. 1. **[LED](https://huggingface.co/docs/transformers/model_doc/led)** (from AllenAI) released with the paper [Longformer: The Long-Document Transformer](https://arxiv.org/abs/2004.05150) by Iz Beltagy, Matthew E. Peters, Arman Cohan. 1. **[Longformer](https://huggingface.co/docs/transformers/model_doc/longformer)** (from AllenAI) released with the paper [Longformer: The Long-Document Transformer](https://arxiv.org/abs/2004.05150) by Iz Beltagy, Matthew E. Peters, Arman Cohan. diff --git a/README_ko.md b/README_ko.md index aab7a7c4bc2d6..1c1060ae5eb19 100644 --- a/README_ko.md +++ b/README_ko.md @@ -258,6 +258,7 @@ Flax, PyTorch, TensorFlow 설치 페이지에서 이들을 conda로 설치하는 1. **[ImageGPT](https://huggingface.co/docs/transformers/main/model_doc/imagegpt)** (from OpenAI) released with the paper [Generative Pretraining from Pixels](https://openai.com/blog/image-gpt/) by Mark Chen, Alec Radford, Rewon Child, Jeffrey Wu, Heewoo Jun, David Luan, Ilya Sutskever. 1. **[LayoutLM](https://huggingface.co/docs/transformers/model_doc/layoutlm)** (from Microsoft Research Asia) released with the paper [LayoutLM: Pre-training of Text and Layout for Document Image Understanding](https://arxiv.org/abs/1912.13318) by Yiheng Xu, Minghao Li, Lei Cui, Shaohan Huang, Furu Wei, Ming Zhou. 1. **[LayoutLMv2](https://huggingface.co/docs/transformers/model_doc/layoutlmv2)** (from Microsoft Research Asia) released with the paper [LayoutLMv2: Multi-modal Pre-training for Visually-Rich Document Understanding](https://arxiv.org/abs/2012.14740) by Yang Xu, Yiheng Xu, Tengchao Lv, Lei Cui, Furu Wei, Guoxin Wang, Yijuan Lu, Dinei Florencio, Cha Zhang, Wanxiang Che, Min Zhang, Lidong Zhou. +1. **[LayoutLMv3](https://huggingface.co/docs/transformers/main/model_doc/layoutlmv3)** (from Microsoft Research Asia) released with the paper [LayoutLMv3: Pre-training for Document AI with Unified Text and Image Masking](https://arxiv.org/abs/2204.08387) by Yupan Huang, Tengchao Lv, Lei Cui, Yutong Lu, Furu Wei. 1. **[LayoutXLM](https://huggingface.co/docs/transformers/model_doc/layoutlmv2)** (from Microsoft Research Asia) released with the paper [LayoutXLM: Multimodal Pre-training for Multilingual Visually-rich Document Understanding](https://arxiv.org/abs/2104.08836) by Yiheng Xu, Tengchao Lv, Lei Cui, Guoxin Wang, Yijuan Lu, Dinei Florencio, Cha Zhang, Furu Wei. 1. **[LED](https://huggingface.co/docs/transformers/model_doc/led)** (from AllenAI) released with the paper [Longformer: The Long-Document Transformer](https://arxiv.org/abs/2004.05150) by Iz Beltagy, Matthew E. Peters, Arman Cohan. 1. **[Longformer](https://huggingface.co/docs/transformers/model_doc/longformer)** (from AllenAI) released with the paper [Longformer: The Long-Document Transformer](https://arxiv.org/abs/2004.05150) by Iz Beltagy, Matthew E. Peters, Arman Cohan. diff --git a/README_zh-hans.md b/README_zh-hans.md index 7031fd3570bf1..386731b7d2154 100644 --- a/README_zh-hans.md +++ b/README_zh-hans.md @@ -282,6 +282,7 @@ conda install -c huggingface transformers 1. **[ImageGPT](https://huggingface.co/docs/transformers/main/model_doc/imagegpt)** (来自 OpenAI) 伴随论文 [Generative Pretraining from Pixels](https://openai.com/blog/image-gpt/) 由 Mark Chen, Alec Radford, Rewon Child, Jeffrey Wu, Heewoo Jun, David Luan, Ilya Sutskever 发布。 1. **[LayoutLM](https://huggingface.co/docs/transformers/model_doc/layoutlm)** (来自 Microsoft Research Asia) 伴随论文 [LayoutLM: Pre-training of Text and Layout for Document Image Understanding](https://arxiv.org/abs/1912.13318) 由 Yiheng Xu, Minghao Li, Lei Cui, Shaohan Huang, Furu Wei, Ming Zhou 发布。 1. **[LayoutLMv2](https://huggingface.co/docs/transformers/model_doc/layoutlmv2)** (来自 Microsoft Research Asia) 伴随论文 [LayoutLMv2: Multi-modal Pre-training for Visually-Rich Document Understanding](https://arxiv.org/abs/2012.14740) 由 Yang Xu, Yiheng Xu, Tengchao Lv, Lei Cui, Furu Wei, Guoxin Wang, Yijuan Lu, Dinei Florencio, Cha Zhang, Wanxiang Che, Min Zhang, Lidong Zhou 发布。 +1. **[LayoutLMv3](https://huggingface.co/docs/transformers/main/model_doc/layoutlmv3)** (来自 Microsoft Research Asia) 伴随论文 [LayoutLMv3: Pre-training for Document AI with Unified Text and Image Masking](https://arxiv.org/abs/2204.08387) 由 Yupan Huang, Tengchao Lv, Lei Cui, Yutong Lu, Furu Wei 发布。 1. **[LayoutXLM](https://huggingface.co/docs/transformers/model_doc/layoutlmv2)** (来自 Microsoft Research Asia) 伴随论文 [LayoutXLM: Multimodal Pre-training for Multilingual Visually-rich Document Understanding](https://arxiv.org/abs/2104.08836) 由 Yiheng Xu, Tengchao Lv, Lei Cui, Guoxin Wang, Yijuan Lu, Dinei Florencio, Cha Zhang, Furu Wei 发布。 1. **[LED](https://huggingface.co/docs/transformers/model_doc/led)** (来自 AllenAI) 伴随论文 [Longformer: The Long-Document Transformer](https://arxiv.org/abs/2004.05150) 由 Iz Beltagy, Matthew E. Peters, Arman Cohan 发布。 1. **[Longformer](https://huggingface.co/docs/transformers/model_doc/longformer)** (来自 AllenAI) 伴随论文 [Longformer: The Long-Document Transformer](https://arxiv.org/abs/2004.05150) 由 Iz Beltagy, Matthew E. Peters, Arman Cohan 发布。 diff --git a/README_zh-hant.md b/README_zh-hant.md index 5971ab404917f..5a4a7c6b72d60 100644 --- a/README_zh-hant.md +++ b/README_zh-hant.md @@ -294,6 +294,7 @@ conda install -c huggingface transformers 1. **[ImageGPT](https://huggingface.co/docs/transformers/main/model_doc/imagegpt)** (from OpenAI) released with the paper [Generative Pretraining from Pixels](https://openai.com/blog/image-gpt/) by Mark Chen, Alec Radford, Rewon Child, Jeffrey Wu, Heewoo Jun, David Luan, Ilya Sutskever. 1. **[LayoutLM](https://huggingface.co/docs/transformers/model_doc/layoutlm)** (from Microsoft Research Asia) released with the paper [LayoutLM: Pre-training of Text and Layout for Document Image Understanding](https://arxiv.org/abs/1912.13318) by Yiheng Xu, Minghao Li, Lei Cui, Shaohan Huang, Furu Wei, Ming Zhou. 1. **[LayoutLMv2](https://huggingface.co/docs/transformers/model_doc/layoutlmv2)** (from Microsoft Research Asia) released with the paper [LayoutLMv2: Multi-modal Pre-training for Visually-Rich Document Understanding](https://arxiv.org/abs/2012.14740) by Yang Xu, Yiheng Xu, Tengchao Lv, Lei Cui, Furu Wei, Guoxin Wang, Yijuan Lu, Dinei Florencio, Cha Zhang, Wanxiang Che, Min Zhang, Lidong Zhou. +1. **[LayoutLMv3](https://huggingface.co/docs/transformers/main/model_doc/layoutlmv3)** (from Microsoft Research Asia) released with the paper [LayoutLMv3: Pre-training for Document AI with Unified Text and Image Masking](https://arxiv.org/abs/2204.08387) by Yupan Huang, Tengchao Lv, Lei Cui, Yutong Lu, Furu Wei. 1. **[LayoutXLM](https://huggingface.co/docs/transformers/model_doc/layoutlmv2)** (from Microsoft Research Asia) released with the paper [LayoutXLM: Multimodal Pre-training for Multilingual Visually-rich Document Understanding](https://arxiv.org/abs/2104.08836) by Yiheng Xu, Tengchao Lv, Lei Cui, Guoxin Wang, Yijuan Lu, Dinei Florencio, Cha Zhang, Furu Wei. 1. **[LED](https://huggingface.co/docs/transformers/model_doc/led)** (from AllenAI) released with the paper [Longformer: The Long-Document Transformer](https://arxiv.org/abs/2004.05150) by Iz Beltagy, Matthew E. Peters, Arman Cohan. 1. **[Longformer](https://huggingface.co/docs/transformers/model_doc/longformer)** (from AllenAI) released with the paper [Longformer: The Long-Document Transformer](https://arxiv.org/abs/2004.05150) by Iz Beltagy, Matthew E. Peters, Arman Cohan. diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index cb67299cff4d4..c1f32a4b65f0a 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -242,6 +242,8 @@ title: LayoutLM - local: model_doc/layoutlmv2 title: LayoutLMV2 + - local: model_doc/layoutlmv3 + title: LayoutLMV3 - local: model_doc/layoutxlm title: LayoutXLM - local: model_doc/led diff --git a/docs/source/en/index.mdx b/docs/source/en/index.mdx index 35d2a99440e79..460063edab93d 100644 --- a/docs/source/en/index.mdx +++ b/docs/source/en/index.mdx @@ -100,6 +100,7 @@ The library currently contains JAX, PyTorch and TensorFlow implementations, pret 1. **[ImageGPT](model_doc/imagegpt)** (from OpenAI) released with the paper [Generative Pretraining from Pixels](https://openai.com/blog/image-gpt/) by Mark Chen, Alec Radford, Rewon Child, Jeffrey Wu, Heewoo Jun, David Luan, Ilya Sutskever. 1. **[LayoutLM](model_doc/layoutlm)** (from Microsoft Research Asia) released with the paper [LayoutLM: Pre-training of Text and Layout for Document Image Understanding](https://arxiv.org/abs/1912.13318) by Yiheng Xu, Minghao Li, Lei Cui, Shaohan Huang, Furu Wei, Ming Zhou. 1. **[LayoutLMv2](model_doc/layoutlmv2)** (from Microsoft Research Asia) released with the paper [LayoutLMv2: Multi-modal Pre-training for Visually-Rich Document Understanding](https://arxiv.org/abs/2012.14740) by Yang Xu, Yiheng Xu, Tengchao Lv, Lei Cui, Furu Wei, Guoxin Wang, Yijuan Lu, Dinei Florencio, Cha Zhang, Wanxiang Che, Min Zhang, Lidong Zhou. +1. **[LayoutLMv3](model_doc/layoutlmv3)** (from Microsoft Research Asia) released with the paper [LayoutLMv3: Pre-training for Document AI with Unified Text and Image Masking](https://arxiv.org/abs/2204.08387) by Yupan Huang, Tengchao Lv, Lei Cui, Yutong Lu, Furu Wei. 1. **[LayoutXLM](model_doc/layoutlmv2)** (from Microsoft Research Asia) released with the paper [LayoutXLM: Multimodal Pre-training for Multilingual Visually-rich Document Understanding](https://arxiv.org/abs/2104.08836) by Yiheng Xu, Tengchao Lv, Lei Cui, Guoxin Wang, Yijuan Lu, Dinei Florencio, Cha Zhang, Furu Wei. 1. **[LED](model_doc/led)** (from AllenAI) released with the paper [Longformer: The Long-Document Transformer](https://arxiv.org/abs/2004.05150) by Iz Beltagy, Matthew E. Peters, Arman Cohan. 1. **[Longformer](model_doc/longformer)** (from AllenAI) released with the paper [Longformer: The Long-Document Transformer](https://arxiv.org/abs/2004.05150) by Iz Beltagy, Matthew E. Peters, Arman Cohan. @@ -220,6 +221,7 @@ Flax), PyTorch, and/or TensorFlow. | ImageGPT | ❌ | ❌ | ✅ | ❌ | ❌ | | LayoutLM | ✅ | ✅ | ✅ | ✅ | ❌ | | LayoutLMv2 | ✅ | ✅ | ✅ | ❌ | ❌ | +| LayoutLMv3 | ✅ | ✅ | ✅ | ❌ | ❌ | | LED | ✅ | ✅ | ✅ | ✅ | ❌ | | Longformer | ✅ | ✅ | ✅ | ✅ | ❌ | | LUKE | ✅ | ❌ | ✅ | ❌ | ❌ | diff --git a/docs/source/en/model_doc/layoutlmv3.mdx b/docs/source/en/model_doc/layoutlmv3.mdx new file mode 100644 index 0000000000000..8463dbce7a0d5 --- /dev/null +++ b/docs/source/en/model_doc/layoutlmv3.mdx @@ -0,0 +1,85 @@ + + +# LayoutLMv3 + +## Overview + +The LayoutLMv3 model was proposed in [LayoutLMv3: Pre-training for Document AI with Unified Text and Image Masking](https://arxiv.org/abs/2204.08387) by Yupan Huang, Tengchao Lv, Lei Cui, Yutong Lu, Furu Wei. +LayoutLMv3 simplifies [LayoutLMv2](layoutlmv2) by using patch embeddings (as in [ViT](vit)) instead of leveraging a CNN backbone, and pre-trains the model on 3 objectives: masked language modeling (MLM), masked image modeling (MIM) +and word-patch alignment (WPA). + +The abstract from the paper is the following: + +*Self-supervised pre-training techniques have achieved remarkable progress in Document AI. Most multimodal pre-trained models use a masked language modeling objective to learn bidirectional representations on the text modality, but they differ in pre-training objectives for the image modality. This discrepancy adds difficulty to multimodal representation learning. In this paper, we propose LayoutLMv3 to pre-train multimodal Transformers for Document AI with unified text and image masking. Additionally, LayoutLMv3 is pre-trained with a word-patch alignment objective to learn cross-modal alignment by predicting whether the corresponding image patch of a text word is masked. The simple unified architecture and training objectives make LayoutLMv3 a general-purpose pre-trained model for both text-centric and image-centric Document AI tasks. Experimental results show that LayoutLMv3 achieves state-of-the-art performance not only in text-centric tasks, including form understanding, receipt understanding, and document visual question answering, but also in image-centric tasks such as document image classification and document layout analysis.* + +Tips: + +- In terms of data processing, LayoutLMv3 is identical to its predecessor [LayoutLMv2](layoutlmv2), except that: + - images need to be resized and normalized with channels in regular RGB format. LayoutLMv2 on the other hand normalizes the images internally and expects the channels in BGR format. + - text is tokenized using byte-pair encoding (BPE), as opposed to WordPiece. + Due to these differences in data preprocessing, one can use [`LayoutLMv3Processor`] which internally combines a [`LayoutLMv3FeatureExtractor`] (for the image modality) and a [`LayoutLMv3Tokenizer`]/[`LayoutLMv3TokenizerFast`] (for the text modality) to prepare all data for the model. +- Regarding usage of [`LayoutLMv3Processor`], we refer to the [usage guide](layoutlmv2#usage-LayoutLMv2Processor) of its predecessor. +- Demo notebooks for LayoutLMv3 can be found [here](https://github.com/NielsRogge/Transformers-Tutorials/tree/master/LayoutLMv3). + + + + LayoutLMv3 architecture. Taken from the original paper. + +This model was contributed by [nielsr](https://huggingface.co/nielsr). The original code can be found [here](https://github.com/microsoft/unilm/tree/master/layoutlmv3). + + +## LayoutLMv3Config + +[[autodoc]] LayoutLMv3Config + +## LayoutLMv3FeatureExtractor + +[[autodoc]] LayoutLMv3FeatureExtractor + - __call__ + +## LayoutLMv3Tokenizer + +[[autodoc]] LayoutLMv3Tokenizer + - __call__ + - save_vocabulary + +## LayoutLMv3TokenizerFast + +[[autodoc]] LayoutLMv3TokenizerFast + - __call__ + +## LayoutLMv3Processor + +[[autodoc]] LayoutLMv3Processor + - __call__ + +## LayoutLMv3Model + +[[autodoc]] LayoutLMv3Model + - forward + +## LayoutLMv3ForSequenceClassification + +[[autodoc]] LayoutLMv3ForSequenceClassification + - forward + +## LayoutLMv3ForTokenClassification + +[[autodoc]] LayoutLMv3ForTokenClassification + - forward + +## LayoutLMv3ForQuestionAnswering + +[[autodoc]] LayoutLMv3ForQuestionAnswering + - forward diff --git a/examples/research_projects/layoutlmv3/README.md b/examples/research_projects/layoutlmv3/README.md new file mode 100644 index 0000000000000..17bf4bb67cd90 --- /dev/null +++ b/examples/research_projects/layoutlmv3/README.md @@ -0,0 +1,69 @@ + + +# Token classification with LayoutLMv3 (PyTorch version) + +This directory contains a script, `run_funsd_cord.py`, that can be used to fine-tune (or evaluate) LayoutLMv3 on form understanding datasets, such as [FUNSD](https://guillaumejaume.github.io/FUNSD/) and [CORD](https://github.com/clovaai/cord). + +The script `run_funsd_cord.py` leverages the 🤗 Datasets library and the Trainer API. You can easily customize it to your needs. + +## Fine-tuning on FUNSD + +Fine-tuning LayoutLMv3 for token classification on [FUNSD](https://guillaumejaume.github.io/FUNSD/) can be done as follows: + +```bash +python run_funsd_cord.py \ + --model_name_or_path microsoft/layoutlmv3-base \ + --dataset_name funsd \ + --output_dir layoutlmv3-test \ + --do_train \ + --do_eval \ + --max_steps 1000 \ + --evaluation_strategy steps \ + --eval_steps 100 \ + --learning_rate 1e-5 \ + --load_best_model_at_end \ + --metric_for_best_model "eval_f1" \ + --push_to_hub \ + --push_to_hub°model_id layoutlmv3-finetuned-funsd +``` + +👀 The resulting model can be found here: https://huggingface.co/nielsr/layoutlmv3-finetuned-funsd. By specifying the `push_to_hub` flag, the model gets uploaded automatically to the hub (regularly), together with a model card, which includes metrics such as precision, recall and F1. Note that you can easily update the model card, as it's just a README file of the respective repo on the hub. + +There's also the "Training metrics" [tab](https://huggingface.co/nielsr/layoutlmv3-finetuned-funsd/tensorboard), which shows Tensorboard logs over the course of training. Pretty neat, huh? + +## Fine-tuning on CORD + +Fine-tuning LayoutLMv3 for token classification on [CORD](https://github.com/clovaai/cord) can be done as follows: + +```bash +python run_funsd_cord.py \ + --model_name_or_path microsoft/layoutlmv3-base \ + --dataset_name cord \ + --output_dir layoutlmv3-test \ + --do_train \ + --do_eval \ + --max_steps 1000 \ + --evaluation_strategy steps \ + --eval_steps 100 \ + --learning_rate 5e-5 \ + --load_best_model_at_end \ + --metric_for_best_model "eval_f1" \ + --push_to_hub \ + --push_to_hub°model_id layoutlmv3-finetuned-cord +``` + +👀 The resulting model can be found here: https://huggingface.co/nielsr/layoutlmv3-finetuned-cord. Note that a model card gets generated automatically in case you specify the `push_to_hub` flag. \ No newline at end of file diff --git a/examples/research_projects/layoutlmv3/requirements.txt b/examples/research_projects/layoutlmv3/requirements.txt new file mode 100644 index 0000000000000..504a8cc9870fa --- /dev/null +++ b/examples/research_projects/layoutlmv3/requirements.txt @@ -0,0 +1,2 @@ +datasets +seqeval \ No newline at end of file diff --git a/examples/research_projects/layoutlmv3/run_funsd_cord.py b/examples/research_projects/layoutlmv3/run_funsd_cord.py new file mode 100644 index 0000000000000..66be61dffccf2 --- /dev/null +++ b/examples/research_projects/layoutlmv3/run_funsd_cord.py @@ -0,0 +1,533 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2022 The HuggingFace Team All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Fine-tuning LayoutLMv3 for token classification on FUNSD or CORD. +""" +# You can also adapt this script on your own token classification task and datasets. Pointers for this are left as +# comments. + +import logging +import os +import sys +from dataclasses import dataclass, field +from typing import Optional + +import datasets +import numpy as np +from datasets import ClassLabel, load_dataset, load_metric + +import transformers +from transformers import ( + AutoConfig, + AutoModelForTokenClassification, + AutoProcessor, + HfArgumentParser, + Trainer, + TrainingArguments, + set_seed, +) +from transformers.data.data_collator import default_data_collator +from transformers.trainer_utils import get_last_checkpoint +from transformers.utils import check_min_version +from transformers.utils.versions import require_version + + +# Will error if the minimal version of Transformers is not installed. Remove at your own risks. +check_min_version("4.19.0.dev0") + +require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/token-classification/requirements.txt") + +logger = logging.getLogger(__name__) + + +@dataclass +class ModelArguments: + """ + Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. + """ + + model_name_or_path: str = field( + default="microsoft/layoutlmv3-base", + metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}, + ) + config_name: Optional[str] = field( + default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} + ) + processor_name: Optional[str] = field( + default=None, metadata={"help": "Name or path to the processor files if not the same as model_name"} + ) + cache_dir: Optional[str] = field( + default=None, + metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"}, + ) + model_revision: str = field( + default="main", + metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, + ) + use_auth_token: bool = field( + default=False, + metadata={ + "help": ( + "Will use the token generated when running `transformers-cli login` (necessary to use this script " + "with private models)." + ) + }, + ) + + +@dataclass +class DataTrainingArguments: + """ + Arguments pertaining to what data we are going to input our model for training and eval. + """ + + task_name: Optional[str] = field(default="ner", metadata={"help": "The name of the task (ner, pos...)."}) + dataset_name: Optional[str] = field( + default="nielsr/funsd-layoutlmv3", + metadata={"help": "The name of the dataset to use (via the datasets library)."}, + ) + dataset_config_name: Optional[str] = field( + default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."} + ) + train_file: Optional[str] = field( + default=None, metadata={"help": "The input training data file (a csv or JSON file)."} + ) + validation_file: Optional[str] = field( + default=None, + metadata={"help": "An optional input evaluation data file to evaluate on (a csv or JSON file)."}, + ) + test_file: Optional[str] = field( + default=None, + metadata={"help": "An optional input test data file to predict on (a csv or JSON file)."}, + ) + text_column_name: Optional[str] = field( + default=None, metadata={"help": "The column name of text to input in the file (a csv or JSON file)."} + ) + label_column_name: Optional[str] = field( + default=None, metadata={"help": "The column name of label to input in the file (a csv or JSON file)."} + ) + overwrite_cache: bool = field( + default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} + ) + preprocessing_num_workers: Optional[int] = field( + default=None, + metadata={"help": "The number of processes to use for the preprocessing."}, + ) + max_seq_length: int = field( + default=512, + metadata={ + "help": ( + "The maximum total input sequence length after tokenization. If set, sequences longer " + "than this will be truncated, sequences shorter will be padded." + ) + }, + ) + max_train_samples: Optional[int] = field( + default=None, + metadata={ + "help": ( + "For debugging purposes or quicker training, truncate the number of training examples to this " + "value if set." + ) + }, + ) + max_eval_samples: Optional[int] = field( + default=None, + metadata={ + "help": ( + "For debugging purposes or quicker training, truncate the number of evaluation examples to this " + "value if set." + ) + }, + ) + max_predict_samples: Optional[int] = field( + default=None, + metadata={ + "help": ( + "For debugging purposes or quicker training, truncate the number of prediction examples to this " + "value if set." + ) + }, + ) + label_all_tokens: bool = field( + default=False, + metadata={ + "help": ( + "Whether to put the label for one word on all tokens of generated by that word or just on the " + "one (in which case the other tokens will have a padding index)." + ) + }, + ) + return_entity_level_metrics: bool = field( + default=False, + metadata={"help": "Whether to return all the entity levels during evaluation or just the overall ones."}, + ) + + def __post_init__(self): + if self.dataset_name is None and self.train_file is None and self.validation_file is None: + raise ValueError("Need either a dataset name or a training/validation file.") + else: + if self.train_file is not None: + extension = self.train_file.split(".")[-1] + assert extension in ["csv", "json"], "`train_file` should be a csv or a json file." + if self.validation_file is not None: + extension = self.validation_file.split(".")[-1] + assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file." + self.task_name = self.task_name.lower() + + +def main(): + # See all possible arguments in src/transformers/training_args.py + # or by passing the --help flag to this script. + # We now keep distinct sets of args, for a cleaner separation of concerns. + + parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments)) + if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): + # If we pass only one argument to the script and it's the path to a json file, + # let's parse it to get our arguments. + model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) + else: + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + + # Setup logging + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + handlers=[logging.StreamHandler(sys.stdout)], + ) + + log_level = training_args.get_process_log_level() + logger.setLevel(log_level) + datasets.utils.logging.set_verbosity(log_level) + transformers.utils.logging.set_verbosity(log_level) + transformers.utils.logging.enable_default_handler() + transformers.utils.logging.enable_explicit_format() + + # Log on each process the small summary: + logger.warning( + f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" + + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" + ) + logger.info(f"Training/evaluation parameters {training_args}") + + # Detecting last checkpoint. + last_checkpoint = None + if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir: + last_checkpoint = get_last_checkpoint(training_args.output_dir) + if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: + raise ValueError( + f"Output directory ({training_args.output_dir}) already exists and is not empty. " + "Use --overwrite_output_dir to overcome." + ) + elif last_checkpoint is not None and training_args.resume_from_checkpoint is None: + logger.info( + f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " + "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." + ) + + # Set seed before initializing model. + set_seed(training_args.seed) + + # Get the datasets + # In distributed training, the load_dataset function guarantee that only one local process can concurrently + # download the dataset. + if data_args.dataset_name == "funsd": + # Downloading and loading a dataset from the hub. + dataset = load_dataset( + "nielsr/funsd-layoutlmv3", + data_args.dataset_config_name, + cache_dir=model_args.cache_dir, + use_auth_token=True if model_args.use_auth_token else None, + ) + elif data_args.dataset_name == "cord": + # Downloading and loading a dataset from the hub. + dataset = load_dataset( + "nielsr/cord-layoutlmv3", + data_args.dataset_config_name, + cache_dir=model_args.cache_dir, + use_auth_token=True if model_args.use_auth_token else None, + ) + else: + raise ValueError("This script only supports either FUNSD or CORD out-of-the-box.") + + if training_args.do_train: + column_names = dataset["train"].column_names + features = dataset["train"].features + else: + column_names = dataset["test"].column_names + features = dataset["test"].features + + image_column_name = "image" + text_column_name = "words" if "words" in column_names else "tokens" + boxes_column_name = "bboxes" + label_column_name = ( + f"{data_args.task_name}_tags" if f"{data_args.task_name}_tags" in column_names else column_names[1] + ) + + remove_columns = column_names + + # In the event the labels are not a `Sequence[ClassLabel]`, we will need to go through the dataset to get the + # unique labels. + def get_label_list(labels): + unique_labels = set() + for label in labels: + unique_labels = unique_labels | set(label) + label_list = list(unique_labels) + label_list.sort() + return label_list + + # If the labels are of type ClassLabel, they are already integers and we have the map stored somewhere. + # Otherwise, we have to get the list of labels manually. + if isinstance(features[label_column_name].feature, ClassLabel): + label_list = features[label_column_name].feature.names + # No need to convert the labels since they are already ints. + id2label = {k: v for k, v in enumerate(label_list)} + label2id = {v: k for k, v in enumerate(label_list)} + else: + label_list = get_label_list(datasets["train"][label_column_name]) + id2label = {k: v for k, v in enumerate(label_list)} + label2id = {v: k for k, v in enumerate(label_list)} + num_labels = len(label_list) + + # Load pretrained model and processor + # + # Distributed training: + # The .from_pretrained methods guarantee that only one local process can concurrently + # download model & vocab. + config = AutoConfig.from_pretrained( + model_args.config_name if model_args.config_name else model_args.model_name_or_path, + num_labels=num_labels, + finetuning_task=data_args.task_name, + cache_dir=model_args.cache_dir, + revision=model_args.model_revision, + use_auth_token=True if model_args.use_auth_token else None, + ) + + processor = AutoProcessor.from_pretrained( + model_args.processor_name if model_args.processor_name else model_args.model_name_or_path, + cache_dir=model_args.cache_dir, + use_fast=True, + revision=model_args.model_revision, + use_auth_token=True if model_args.use_auth_token else None, + add_prefix_space=True, + apply_ocr=False, + ) + + model = AutoModelForTokenClassification.from_pretrained( + model_args.model_name_or_path, + from_tf=bool(".ckpt" in model_args.model_name_or_path), + config=config, + cache_dir=model_args.cache_dir, + revision=model_args.model_revision, + use_auth_token=True if model_args.use_auth_token else None, + ) + + # Set the correspondences label/ID inside the model config + model.config.label2id = label2id + model.config.id2label = id2label + + # Preprocessing the dataset + # The processor does everything for us (prepare the image using LayoutLMv3FeatureExtractor + # and prepare the words, boxes and word-level labels using LayoutLMv3TokenizerFast) + def prepare_examples(examples): + images = examples[image_column_name] + words = examples[text_column_name] + boxes = examples[boxes_column_name] + word_labels = examples[label_column_name] + + encoding = processor( + images, + words, + boxes=boxes, + word_labels=word_labels, + truncation=True, + padding="max_length", + max_length=data_args.max_seq_length, + ) + + return encoding + + if training_args.do_train: + if "train" not in dataset: + raise ValueError("--do_train requires a train dataset") + train_dataset = dataset["train"] + if data_args.max_train_samples is not None: + train_dataset = train_dataset.select(range(data_args.max_train_samples)) + with training_args.main_process_first(desc="train dataset map pre-processing"): + train_dataset = train_dataset.map( + prepare_examples, + batched=True, + remove_columns=remove_columns, + num_proc=data_args.preprocessing_num_workers, + load_from_cache_file=not data_args.overwrite_cache, + ) + + if training_args.do_eval: + validation_name = "test" + if validation_name not in dataset: + raise ValueError("--do_eval requires a validation dataset") + eval_dataset = dataset[validation_name] + if data_args.max_eval_samples is not None: + eval_dataset = eval_dataset.select(range(data_args.max_eval_samples)) + with training_args.main_process_first(desc="validation dataset map pre-processing"): + eval_dataset = eval_dataset.map( + prepare_examples, + batched=True, + remove_columns=remove_columns, + num_proc=data_args.preprocessing_num_workers, + load_from_cache_file=not data_args.overwrite_cache, + ) + + if training_args.do_predict: + if "test" not in datasets: + raise ValueError("--do_predict requires a test dataset") + predict_dataset = datasets["test"] + if data_args.max_predict_samples is not None: + max_predict_samples = min(len(predict_dataset), data_args.max_predict_samples) + predict_dataset = predict_dataset.select(range(max_predict_samples)) + with training_args.main_process_first(desc="prediction dataset map pre-processing"): + predict_dataset = predict_dataset.map( + prepare_examples, + batched=True, + remove_columns=remove_columns, + num_proc=data_args.preprocessing_num_workers, + load_from_cache_file=not data_args.overwrite_cache, + ) + + # Metrics + metric = load_metric("seqeval") + + def compute_metrics(p): + predictions, labels = p + predictions = np.argmax(predictions, axis=2) + + # Remove ignored index (special tokens) + true_predictions = [ + [label_list[p] for (p, l) in zip(prediction, label) if l != -100] + for prediction, label in zip(predictions, labels) + ] + true_labels = [ + [label_list[l] for (p, l) in zip(prediction, label) if l != -100] + for prediction, label in zip(predictions, labels) + ] + + results = metric.compute(predictions=true_predictions, references=true_labels) + if data_args.return_entity_level_metrics: + # Unpack nested dictionaries + final_results = {} + for key, value in results.items(): + if isinstance(value, dict): + for n, v in value.items(): + final_results[f"{key}_{n}"] = v + else: + final_results[key] = value + return final_results + else: + return { + "precision": results["overall_precision"], + "recall": results["overall_recall"], + "f1": results["overall_f1"], + "accuracy": results["overall_accuracy"], + } + + # Initialize our Trainer + trainer = Trainer( + model=model, + args=training_args, + train_dataset=train_dataset if training_args.do_train else None, + eval_dataset=eval_dataset if training_args.do_eval else None, + tokenizer=processor, + data_collator=default_data_collator, + compute_metrics=compute_metrics, + ) + + # Training + if training_args.do_train: + checkpoint = None + if training_args.resume_from_checkpoint is not None: + checkpoint = training_args.resume_from_checkpoint + elif last_checkpoint is not None: + checkpoint = last_checkpoint + train_result = trainer.train(resume_from_checkpoint=checkpoint) + metrics = train_result.metrics + trainer.save_model() # Saves the tokenizer too for easy upload + + max_train_samples = ( + data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset) + ) + metrics["train_samples"] = min(max_train_samples, len(train_dataset)) + + trainer.log_metrics("train", metrics) + trainer.save_metrics("train", metrics) + trainer.save_state() + + # Evaluation + if training_args.do_eval: + logger.info("*** Evaluate ***") + + metrics = trainer.evaluate() + + max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset) + metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset)) + + trainer.log_metrics("eval", metrics) + trainer.save_metrics("eval", metrics) + + # Predict + if training_args.do_predict: + logger.info("*** Predict ***") + + predictions, labels, metrics = trainer.predict(predict_dataset, metric_key_prefix="predict") + predictions = np.argmax(predictions, axis=2) + + # Remove ignored index (special tokens) + true_predictions = [ + [label_list[p] for (p, l) in zip(prediction, label) if l != -100] + for prediction, label in zip(predictions, labels) + ] + + trainer.log_metrics("predict", metrics) + trainer.save_metrics("predict", metrics) + + # Save predictions + output_predictions_file = os.path.join(training_args.output_dir, "predictions.txt") + if trainer.is_world_process_zero(): + with open(output_predictions_file, "w") as writer: + for prediction in true_predictions: + writer.write(" ".join(prediction) + "\n") + + kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "token-classification"} + if data_args.dataset_name is not None: + kwargs["dataset_tags"] = data_args.dataset_name + if data_args.dataset_config_name is not None: + kwargs["dataset_args"] = data_args.dataset_config_name + kwargs["dataset"] = f"{data_args.dataset_name} {data_args.dataset_config_name}" + else: + kwargs["dataset"] = data_args.dataset_name + + if training_args.push_to_hub: + trainer.push_to_hub(**kwargs) + else: + trainer.create_model_card(**kwargs) + + +def _mp_fn(index): + # For xla_spawn (TPUs) + main() + + +if __name__ == "__main__": + main() diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index aff37abbec6aa..1953087d22554 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -226,6 +226,13 @@ "LayoutLMv2Processor", "LayoutLMv2Tokenizer", ], + "models.layoutlmv3": [ + "LAYOUTLMV3_PRETRAINED_CONFIG_ARCHIVE_MAP", + "LayoutLMv3Config", + "LayoutLMv3FeatureExtractor", + "LayoutLMv3Processor", + "LayoutLMv3Tokenizer", + ], "models.layoutxlm": ["LayoutXLMProcessor"], "models.led": ["LED_PRETRAINED_CONFIG_ARCHIVE_MAP", "LEDConfig", "LEDTokenizer"], "models.longformer": ["LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "LongformerConfig", "LongformerTokenizer"], @@ -504,6 +511,7 @@ _import_structure["models.herbert"].append("HerbertTokenizerFast") _import_structure["models.layoutlm"].append("LayoutLMTokenizerFast") _import_structure["models.layoutlmv2"].append("LayoutLMv2TokenizerFast") + _import_structure["models.layoutlmv3"].append("LayoutLMv3TokenizerFast") _import_structure["models.layoutxlm"].append("LayoutXLMTokenizerFast") _import_structure["models.led"].append("LEDTokenizerFast") _import_structure["models.longformer"].append("LongformerTokenizerFast") @@ -590,8 +598,7 @@ _import_structure["models.glpn"].append("GLPNFeatureExtractor") _import_structure["models.imagegpt"].append("ImageGPTFeatureExtractor") _import_structure["models.layoutlmv2"].append("LayoutLMv2FeatureExtractor") - _import_structure["models.layoutlmv2"].append("LayoutLMv2Processor") - _import_structure["models.layoutxlm"].append("LayoutXLMProcessor") + _import_structure["models.layoutlmv3"].append("LayoutLMv3FeatureExtractor") _import_structure["models.maskformer"].append("MaskFormerFeatureExtractor") _import_structure["models.perceiver"].append("PerceiverFeatureExtractor") _import_structure["models.poolformer"].append("PoolFormerFeatureExtractor") @@ -1199,6 +1206,16 @@ "LayoutLMv2PreTrainedModel", ] ) + _import_structure["models.layoutlmv3"].extend( + [ + "LAYOUTLMV3_PRETRAINED_MODEL_ARCHIVE_LIST", + "LayoutLMv3ForQuestionAnswering", + "LayoutLMv3ForSequenceClassification", + "LayoutLMv3ForTokenClassification", + "LayoutLMv3Model", + "LayoutLMv3PreTrainedModel", + ] + ) _import_structure["models.led"].extend( [ "LED_PRETRAINED_MODEL_ARCHIVE_LIST", @@ -2759,6 +2776,13 @@ LayoutLMv2Processor, LayoutLMv2Tokenizer, ) + from .models.layoutlmv3 import ( + LAYOUTLMV3_PRETRAINED_CONFIG_ARCHIVE_MAP, + LayoutLMv3Config, + LayoutLMv3FeatureExtractor, + LayoutLMv3Processor, + LayoutLMv3Tokenizer, + ) from .models.layoutxlm import LayoutXLMProcessor from .models.led import LED_PRETRAINED_CONFIG_ARCHIVE_MAP, LEDConfig, LEDTokenizer from .models.longformer import LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, LongformerConfig, LongformerTokenizer @@ -3004,6 +3028,7 @@ from .models.herbert import HerbertTokenizerFast from .models.layoutlm import LayoutLMTokenizerFast from .models.layoutlmv2 import LayoutLMv2TokenizerFast + from .models.layoutlmv3 import LayoutLMv3TokenizerFast from .models.layoutxlm import LayoutXLMTokenizerFast from .models.led import LEDTokenizerFast from .models.longformer import LongformerTokenizerFast @@ -3069,8 +3094,8 @@ from .models.flava import FlavaFeatureExtractor, FlavaProcessor from .models.glpn import GLPNFeatureExtractor from .models.imagegpt import ImageGPTFeatureExtractor - from .models.layoutlmv2 import LayoutLMv2FeatureExtractor, LayoutLMv2Processor - from .models.layoutxlm import LayoutXLMProcessor + from .models.layoutlmv2 import LayoutLMv2FeatureExtractor + from .models.layoutlmv3 import LayoutLMv3FeatureExtractor from .models.maskformer import MaskFormerFeatureExtractor from .models.perceiver import PerceiverFeatureExtractor from .models.poolformer import PoolFormerFeatureExtractor @@ -3581,6 +3606,14 @@ LayoutLMv2Model, LayoutLMv2PreTrainedModel, ) + from .models.layoutlmv3 import ( + LAYOUTLMV3_PRETRAINED_MODEL_ARCHIVE_LIST, + LayoutLMv3ForQuestionAnswering, + LayoutLMv3ForSequenceClassification, + LayoutLMv3ForTokenClassification, + LayoutLMv3Model, + LayoutLMv3PreTrainedModel, + ) from .models.led import ( LED_PRETRAINED_MODEL_ARCHIVE_LIST, LEDForConditionalGeneration, diff --git a/src/transformers/convert_slow_tokenizer.py b/src/transformers/convert_slow_tokenizer.py index 1feb8dd5fb3fc..f3915676b6dc8 100644 --- a/src/transformers/convert_slow_tokenizer.py +++ b/src/transformers/convert_slow_tokenizer.py @@ -1023,6 +1023,7 @@ def post_processor(self): "HerbertTokenizer": HerbertConverter, "LayoutLMTokenizer": BertConverter, "LayoutLMv2Tokenizer": BertConverter, + "LayoutLMv3Tokenizer": RobertaConverter, "LayoutXLMTokenizer": XLMRobertaConverter, "LongformerTokenizer": RobertaConverter, "LEDTokenizer": RobertaConverter, diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index 66910e3e0a53b..a2fc8e55e254e 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -69,6 +69,7 @@ imagegpt, layoutlm, layoutlmv2, + layoutlmv3, layoutxlm, led, longformer, diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index aa4b64fa7015c..e2154c4888f03 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -73,6 +73,7 @@ ("imagegpt", "ImageGPTConfig"), ("layoutlm", "LayoutLMConfig"), ("layoutlmv2", "LayoutLMv2Config"), + ("layoutlmv3", "LayoutLMv3Config"), ("led", "LEDConfig"), ("longformer", "LongformerConfig"), ("luke", "LukeConfig"), @@ -183,6 +184,7 @@ ("imagegpt", "IMAGEGPT_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("layoutlm", "LAYOUTLM_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("layoutlmv2", "LAYOUTLMV2_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("layoutlmv3", "LAYOUTLMV3_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("led", "LED_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("longformer", "LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("luke", "LUKE_PRETRAINED_CONFIG_ARCHIVE_MAP"), @@ -294,6 +296,7 @@ ("imagegpt", "ImageGPT"), ("layoutlm", "LayoutLM"), ("layoutlmv2", "LayoutLMv2"), + ("layoutlmv3", "LayoutLMv3"), ("layoutxlm", "LayoutXLM"), ("led", "LED"), ("longformer", "Longformer"), diff --git a/src/transformers/models/auto/feature_extraction_auto.py b/src/transformers/models/auto/feature_extraction_auto.py index 5ba7b1228544d..13ab60c38db45 100644 --- a/src/transformers/models/auto/feature_extraction_auto.py +++ b/src/transformers/models/auto/feature_extraction_auto.py @@ -51,6 +51,7 @@ ("glpn", "GLPNFeatureExtractor"), ("hubert", "Wav2Vec2FeatureExtractor"), ("layoutlmv2", "LayoutLMv2FeatureExtractor"), + ("layoutlmv3", "LayoutLMv3FeatureExtractor"), ("maskformer", "MaskFormerFeatureExtractor"), ("perceiver", "PerceiverFeatureExtractor"), ("poolformer", "PoolFormerFeatureExtractor"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 1e62a4ab8ed20..aeb2e2930c65b 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -72,6 +72,7 @@ ("imagegpt", "ImageGPTModel"), ("layoutlm", "LayoutLMModel"), ("layoutlmv2", "LayoutLMv2Model"), + ("layoutlmv3", "LayoutLMv3Model"), ("led", "LEDModel"), ("longformer", "LongformerModel"), ("luke", "LukeModel"), @@ -457,6 +458,7 @@ ("ibert", "IBertForSequenceClassification"), ("layoutlm", "LayoutLMForSequenceClassification"), ("layoutlmv2", "LayoutLMv2ForSequenceClassification"), + ("layoutlmv3", "LayoutLMv3ForSequenceClassification"), ("led", "LEDForSequenceClassification"), ("longformer", "LongformerForSequenceClassification"), ("mbart", "MBartForSequenceClassification"), @@ -505,6 +507,7 @@ ("gptj", "GPTJForQuestionAnswering"), ("ibert", "IBertForQuestionAnswering"), ("layoutlmv2", "LayoutLMv2ForQuestionAnswering"), + ("layoutlmv3", "LayoutLMv3ForQuestionAnswering"), ("led", "LEDForQuestionAnswering"), ("longformer", "LongformerForQuestionAnswering"), ("lxmert", "LxmertForQuestionAnswering"), @@ -556,6 +559,7 @@ ("ibert", "IBertForTokenClassification"), ("layoutlm", "LayoutLMForTokenClassification"), ("layoutlmv2", "LayoutLMv2ForTokenClassification"), + ("layoutlmv3", "LayoutLMv3ForTokenClassification"), ("longformer", "LongformerForTokenClassification"), ("megatron-bert", "MegatronBertForTokenClassification"), ("mobilebert", "MobileBertForTokenClassification"), diff --git a/src/transformers/models/auto/processing_auto.py b/src/transformers/models/auto/processing_auto.py index b21d29636dc9d..9eb84ef8b7b12 100644 --- a/src/transformers/models/auto/processing_auto.py +++ b/src/transformers/models/auto/processing_auto.py @@ -40,6 +40,7 @@ ("clip", "CLIPProcessor"), ("flava", "FLAVAProcessor"), ("layoutlmv2", "LayoutLMv2Processor"), + ("layoutlmv3", "LayoutLMv3Processor"), ("layoutxlm", "LayoutXLMProcessor"), ("sew", "Wav2Vec2Processor"), ("sew-d", "Wav2Vec2Processor"), diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py index 8c1ceeee2cdd5..117b92103f322 100644 --- a/src/transformers/models/auto/tokenization_auto.py +++ b/src/transformers/models/auto/tokenization_auto.py @@ -131,6 +131,7 @@ ("ibert", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)), ("layoutlm", ("LayoutLMTokenizer", "LayoutLMTokenizerFast" if is_tokenizers_available() else None)), ("layoutlmv2", ("LayoutLMv2Tokenizer", "LayoutLMv2TokenizerFast" if is_tokenizers_available() else None)), + ("layoutlmv3", ("LayoutLMv3Tokenizer", "LayoutLMv3TokenizerFast" if is_tokenizers_available() else None)), ("layoutxlm", ("LayoutXLMTokenizer", "LayoutXLMTokenizerFast" if is_tokenizers_available() else None)), ("led", ("LEDTokenizer", "LEDTokenizerFast" if is_tokenizers_available() else None)), ("longformer", ("LongformerTokenizer", "LongformerTokenizerFast" if is_tokenizers_available() else None)), diff --git a/src/transformers/models/layoutlmv2/__init__.py b/src/transformers/models/layoutlmv2/__init__.py index 73bcb3e5b82cb..beaacb815843d 100644 --- a/src/transformers/models/layoutlmv2/__init__.py +++ b/src/transformers/models/layoutlmv2/__init__.py @@ -29,6 +29,7 @@ _import_structure = { "configuration_layoutlmv2": ["LAYOUTLMV2_PRETRAINED_CONFIG_ARCHIVE_MAP", "LayoutLMv2Config"], + "processing_layoutlmv2": ["LayoutLMv2Processor"], "tokenization_layoutlmv2": ["LayoutLMv2Tokenizer"], } @@ -47,7 +48,6 @@ pass else: _import_structure["feature_extraction_layoutlmv2"] = ["LayoutLMv2FeatureExtractor"] - _import_structure["processing_layoutlmv2"] = ["LayoutLMv2Processor"] try: if not is_torch_available(): @@ -67,6 +67,7 @@ if TYPE_CHECKING: from .configuration_layoutlmv2 import LAYOUTLMV2_PRETRAINED_CONFIG_ARCHIVE_MAP, LayoutLMv2Config + from .processing_layoutlmv2 import LayoutLMv2Processor from .tokenization_layoutlmv2 import LayoutLMv2Tokenizer try: @@ -84,7 +85,6 @@ pass else: from .feature_extraction_layoutlmv2 import LayoutLMv2FeatureExtractor - from .processing_layoutlmv2 import LayoutLMv2Processor try: if not is_torch_available(): diff --git a/src/transformers/models/layoutlmv2/processing_layoutlmv2.py b/src/transformers/models/layoutlmv2/processing_layoutlmv2.py index 7b8ad938b1ccd..57f0b78aed1bb 100644 --- a/src/transformers/models/layoutlmv2/processing_layoutlmv2.py +++ b/src/transformers/models/layoutlmv2/processing_layoutlmv2.py @@ -144,3 +144,17 @@ def get_overflowing_images(self, images, overflow_to_sample_mapping): ) return images_with_overflow + + def batch_decode(self, *args, **kwargs): + """ + This method forwards all its arguments to PreTrainedTokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please + refer to the docstring of this method for more information. + """ + return self.tokenizer.batch_decode(*args, **kwargs) + + def decode(self, *args, **kwargs): + """ + This method forwards all its arguments to PreTrainedTokenizer's [`~PreTrainedTokenizer.decode`]. Please refer + to the docstring of this method for more information. + """ + return self.tokenizer.decode(*args, **kwargs) diff --git a/src/transformers/models/layoutlmv2/tokenization_layoutlmv2.py b/src/transformers/models/layoutlmv2/tokenization_layoutlmv2.py index f00420e640f42..d2cf0b2f3dcee 100644 --- a/src/transformers/models/layoutlmv2/tokenization_layoutlmv2.py +++ b/src/transformers/models/layoutlmv2/tokenization_layoutlmv2.py @@ -499,15 +499,17 @@ def _is_valid_text_input(t): is_batched = isinstance(text, (list, tuple)) and text and isinstance(text[0], (list, tuple)) words = text if text_pair is None else text_pair - assert boxes is not None, "You must provide corresponding bounding boxes" + if boxes is None: + raise ValueError("You must provide corresponding bounding boxes") if is_batched: - assert len(words) == len(boxes), "You must provide words and boxes for an equal amount of examples" + if len(words) != len(boxes): + raise ValueError("You must provide words and boxes for an equal amount of examples") for words_example, boxes_example in zip(words, boxes): - assert len(words_example) == len( - boxes_example - ), "You must provide as many words as there are bounding boxes" + if len(words_example) != len(boxes_example): + raise ValueError("You must provide as many words as there are bounding boxes") else: - assert len(words) == len(boxes), "You must provide as many words as there are bounding boxes" + if len(words) != len(boxes): + raise ValueError("You must provide as many words as there are bounding boxes") if is_batched: if text_pair is not None and len(text) != len(text_pair): diff --git a/src/transformers/models/layoutlmv2/tokenization_layoutlmv2_fast.py b/src/transformers/models/layoutlmv2/tokenization_layoutlmv2_fast.py index 27a9849548f71..b61cf5ef7633a 100644 --- a/src/transformers/models/layoutlmv2/tokenization_layoutlmv2_fast.py +++ b/src/transformers/models/layoutlmv2/tokenization_layoutlmv2_fast.py @@ -260,15 +260,17 @@ def _is_valid_text_input(t): is_batched = isinstance(text, (list, tuple)) and text and isinstance(text[0], (list, tuple)) words = text if text_pair is None else text_pair - assert boxes is not None, "You must provide corresponding bounding boxes" + if boxes is None: + raise ValueError("You must provide corresponding bounding boxes") if is_batched: - assert len(words) == len(boxes), "You must provide words and boxes for an equal amount of examples" + if len(words) != len(boxes): + raise ValueError("You must provide words and boxes for an equal amount of examples") for words_example, boxes_example in zip(words, boxes): - assert len(words_example) == len( - boxes_example - ), "You must provide as many words as there are bounding boxes" + if len(words_example) != len(boxes_example): + raise ValueError("You must provide as many words as there are bounding boxes") else: - assert len(words) == len(boxes), "You must provide as many words as there are bounding boxes" + if len(words) != len(boxes): + raise ValueError("You must provide as many words as there are bounding boxes") if is_batched: if text_pair is not None and len(text) != len(text_pair): diff --git a/src/transformers/models/layoutlmv3/__init__.py b/src/transformers/models/layoutlmv3/__init__.py new file mode 100644 index 0000000000000..a7d104040bd6d --- /dev/null +++ b/src/transformers/models/layoutlmv3/__init__.py @@ -0,0 +1,107 @@ +# flake8: noqa +# There's no way to ignore "F401 '...' imported but unused" warnings in this +# module, but to preserve other warnings. So, don't check this module at all. + +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING + +from ...utils import ( + OptionalDependencyNotAvailable, + _LazyModule, + is_tokenizers_available, + is_torch_available, + is_vision_available, +) + + +_import_structure = { + "configuration_layoutlmv3": ["LAYOUTLMV3_PRETRAINED_CONFIG_ARCHIVE_MAP", "LayoutLMv3Config"], + "processing_layoutlmv3": ["LayoutLMv3Processor"], + "tokenization_layoutlmv3": ["LayoutLMv3Tokenizer"], +} + +try: + if not is_tokenizers_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["tokenization_layoutlmv3_fast"] = ["LayoutLMv3TokenizerFast"] + +try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["modeling_layoutlmv3"] = [ + "LAYOUTLMV3_PRETRAINED_MODEL_ARCHIVE_LIST", + "LayoutLMv3ForQuestionAnswering", + "LayoutLMv3ForSequenceClassification", + "LayoutLMv3ForTokenClassification", + "LayoutLMv3Model", + "LayoutLMv3PreTrainedModel", + ] + +try: + if not is_vision_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["feature_extraction_layoutlmv3"] = ["LayoutLMv3FeatureExtractor"] + + +if TYPE_CHECKING: + from .configuration_layoutlmv3 import LAYOUTLMV3_PRETRAINED_CONFIG_ARCHIVE_MAP, LayoutLMv3Config + from .processing_layoutlmv3 import LayoutLMv3Processor + from .tokenization_layoutlmv3 import LayoutLMv3Tokenizer + + try: + if not is_tokenizers_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .tokenization_layoutlmv3_fast import LayoutLMv3TokenizerFast + + try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .modeling_layoutlmv3 import ( + LAYOUTLMV3_PRETRAINED_MODEL_ARCHIVE_LIST, + LayoutLMv3ForQuestionAnswering, + LayoutLMv3ForSequenceClassification, + LayoutLMv3ForTokenClassification, + LayoutLMv3Model, + LayoutLMv3PreTrainedModel, + ) + + try: + if not is_vision_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .feature_extraction_layoutlmv3 import LayoutLMv3FeatureExtractor + +else: + import sys + + sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) diff --git a/src/transformers/models/layoutlmv3/configuration_layoutlmv3.py b/src/transformers/models/layoutlmv3/configuration_layoutlmv3.py new file mode 100644 index 0000000000000..ebde107947a14 --- /dev/null +++ b/src/transformers/models/layoutlmv3/configuration_layoutlmv3.py @@ -0,0 +1,178 @@ +# coding=utf-8 +# Copyright 2022 Microsoft Research and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" LayoutLMv3 model configuration""" + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + +LAYOUTLMV3_PRETRAINED_CONFIG_ARCHIVE_MAP = { + "microsoft/layoutlmv3-base": "https://huggingface.co/microsoft/layoutlmv3-base/resolve/main/config.json", +} + + +class LayoutLMv3Config(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`LayoutLMv3Model`]. It is used to instantiate an + LayoutLMv3 model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the LayoutLMv3 + [microsoft/layoutlmv3-base](https://huggingface.co/microsoft/layoutlmv3-base) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + vocab_size (`int`, *optional*, defaults to 50265): + Vocabulary size of the LayoutLMv3 model. Defines the number of different tokens that can be represented by + the `inputs_ids` passed when calling [`LayoutLMv3Model`]. + hidden_size (`int`, *optional*, defaults to 768): + Dimension of the encoder layers and the pooler layer. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + intermediate_size (`int`, *optional*, defaults to 3072): + Dimension of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` are supported. + hidden_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout probabilitiy for all fully connected layers in the embeddings, encoder, and pooler. + attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout ratio for the attention probabilities. + max_position_embeddings (`int`, *optional*, defaults to 512): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + type_vocab_size (`int`, *optional*, defaults to 2): + The vocabulary size of the `token_type_ids` passed when calling [`LayoutLMv3Model`]. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layer_norm_eps (`float`, *optional*, defaults to 1e-5): + The epsilon used by the layer normalization layers. + max_2d_position_embeddings (`int`, *optional*, defaults to 1024): + The maximum value that the 2D position embedding might ever be used with. Typically set this to something + large just in case (e.g., 1024). + coordinate_size (`int`, *optional*, defaults to `128`): + Dimension of the coordinate embeddings. + shape_size (`int`, *optional*, defaults to `128`): + Dimension of the width and height embeddings. + has_relative_attention_bias (`bool`, *optional*, defaults to `True`): + Whether or not to use a relative attention bias in the self-attention mechanism. + rel_pos_bins (`int`, *optional*, defaults to 32): + The number of relative position bins to be used in the self-attention mechanism. + max_rel_pos (`int`, *optional*, defaults to 128): + The maximum number of relative positions to be used in the self-attention mechanism. + max_rel_2d_pos (`int`, *optional*, defaults to 256): + The maximum number of relative 2D positions in the self-attention mechanism. + rel_2d_pos_bins (`int`, *optional*, defaults to 64): + The number of 2D relative position bins in the self-attention mechanism. + has_spatial_attention_bias (`bool`, *optional*, defaults to `True`): + Whether or not to use a spatial attention bias in the self-attention mechanism. + visual_embed (`bool`, *optional*, defaults to `True`): + Whether or not to add patch embeddings. + input_size (`int`, *optional*, defaults to `224`): + The size (resolution) of the images. + num_channels (`int`, *optional*, defaults to `3`): + The number of channels of the images. + patch_size (`int`, *optional*, defaults to `16`) + The size (resolution) of the patches. + classifier_dropout (`float`, *optional*): + The dropout ratio for the classification head. + + Example: + + ```python + >>> from transformers import LayoutLMv3Model, LayoutLMv3Config + + >>> # Initializing a LayoutLMv3 microsoft/layoutlmv3-base style configuration + >>> configuration = LayoutLMv3Config() + + >>> # Initializing a model from the microsoft/layoutlmv3-base style configuration + >>> model = LayoutLMv3Model(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + model_type = "layoutlmv3" + + def __init__( + self, + vocab_size=50265, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=2, + initializer_range=0.02, + layer_norm_eps=1e-5, + pad_token_id=1, + bos_token_id=0, + eos_token_id=2, + max_2d_position_embeddings=1024, + coordinate_size=128, + shape_size=128, + has_relative_attention_bias=True, + rel_pos_bins=32, + max_rel_pos=128, + rel_2d_pos_bins=64, + max_rel_2d_pos=256, + has_spatial_attention_bias=True, + text_embed=True, + visual_embed=True, + input_size=224, + num_channels=3, + patch_size=16, + classifier_dropout=None, + **kwargs + ): + super().__init__( + vocab_size=vocab_size, + hidden_size=hidden_size, + num_hidden_layers=num_hidden_layers, + num_attention_heads=num_attention_heads, + intermediate_size=intermediate_size, + hidden_act=hidden_act, + hidden_dropout_prob=hidden_dropout_prob, + attention_probs_dropout_prob=attention_probs_dropout_prob, + max_position_embeddings=max_position_embeddings, + type_vocab_size=type_vocab_size, + initializer_range=initializer_range, + layer_norm_eps=layer_norm_eps, + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + **kwargs, + ) + self.max_2d_position_embeddings = max_2d_position_embeddings + self.coordinate_size = coordinate_size + self.shape_size = shape_size + self.has_relative_attention_bias = has_relative_attention_bias + self.rel_pos_bins = rel_pos_bins + self.max_rel_pos = max_rel_pos + self.has_spatial_attention_bias = has_spatial_attention_bias + self.rel_2d_pos_bins = rel_2d_pos_bins + self.max_rel_2d_pos = max_rel_2d_pos + self.text_embed = text_embed + self.visual_embed = visual_embed + self.input_size = input_size + self.num_channels = num_channels + self.patch_size = patch_size + self.classifier_dropout = classifier_dropout diff --git a/src/transformers/models/layoutlmv3/feature_extraction_layoutlmv3.py b/src/transformers/models/layoutlmv3/feature_extraction_layoutlmv3.py new file mode 100644 index 0000000000000..6f2d54529ba97 --- /dev/null +++ b/src/transformers/models/layoutlmv3/feature_extraction_layoutlmv3.py @@ -0,0 +1,242 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Feature extractor class for LayoutLMv3. +""" + +from typing import List, Optional, Union + +import numpy as np +from PIL import Image + +from ...feature_extraction_utils import BatchFeature, FeatureExtractionMixin +from ...image_utils import IMAGENET_STANDARD_MEAN, IMAGENET_STANDARD_STD, ImageFeatureExtractionMixin, is_torch_tensor +from ...utils import TensorType, is_pytesseract_available, logging, requires_backends + + +# soft dependency +if is_pytesseract_available(): + import pytesseract + +logger = logging.get_logger(__name__) + +ImageInput = Union[ + Image.Image, np.ndarray, "torch.Tensor", List[Image.Image], List[np.ndarray], List["torch.Tensor"] # noqa +] + + +def normalize_box(box, width, height): + return [ + int(1000 * (box[0] / width)), + int(1000 * (box[1] / height)), + int(1000 * (box[2] / width)), + int(1000 * (box[3] / height)), + ] + + +def apply_tesseract(image: Image.Image, lang: Optional[str]): + """Applies Tesseract OCR on a document image, and returns recognized words + normalized bounding boxes.""" + + # apply OCR + data = pytesseract.image_to_data(image, lang=lang, output_type="dict") + words, left, top, width, height = data["text"], data["left"], data["top"], data["width"], data["height"] + + # filter empty words and corresponding coordinates + irrelevant_indices = [idx for idx, word in enumerate(words) if not word.strip()] + words = [word for idx, word in enumerate(words) if idx not in irrelevant_indices] + left = [coord for idx, coord in enumerate(left) if idx not in irrelevant_indices] + top = [coord for idx, coord in enumerate(top) if idx not in irrelevant_indices] + width = [coord for idx, coord in enumerate(width) if idx not in irrelevant_indices] + height = [coord for idx, coord in enumerate(height) if idx not in irrelevant_indices] + + # turn coordinates into (left, top, left+width, top+height) format + actual_boxes = [] + for x, y, w, h in zip(left, top, width, height): + actual_box = [x, y, x + w, y + h] + actual_boxes.append(actual_box) + + image_width, image_height = image.size + + # finally, normalize the bounding boxes + normalized_boxes = [] + for box in actual_boxes: + normalized_boxes.append(normalize_box(box, image_width, image_height)) + + assert len(words) == len(normalized_boxes), "Not as many words as there are bounding boxes" + + return words, normalized_boxes + + +class LayoutLMv3FeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin): + r""" + Constructs a LayoutLMv3 feature extractor. This can be used to resize + normalize document images, as well as to + apply OCR on them in order to get a list of words and normalized bounding boxes. + + This feature extractor inherits from [`~feature_extraction_utils.PreTrainedFeatureExtractor`] which contains most + of the main methods. Users should refer to this superclass for more information regarding those methods. + + Args: + do_resize (`bool`, *optional*, defaults to `True`): + Whether to resize the input to a certain `size`. + size (`int` or `Tuple(int)`, *optional*, defaults to 224): + Resize the input to the given size. If a tuple is provided, it should be (width, height). If only an + integer is provided, then the input will be resized to (size, size). Only has an effect if `do_resize` is + set to `True`. + resample (`int`, *optional*, defaults to `PIL.Image.BILINEAR`): + An optional resampling filter. This can be one of `PIL.Image.NEAREST`, `PIL.Image.BOX`, + `PIL.Image.BILINEAR`, `PIL.Image.HAMMING`, `PIL.Image.BICUBIC` or `PIL.Image.LANCZOS`. Only has an effect + if `do_resize` is set to `True`. + do_normalize (`bool`, *optional*, defaults to `True`): + Whether or not to normalize the input with mean and standard deviation. + image_mean (`List[int]`, defaults to `[0.5, 0.5, 0.5]`): + The sequence of means for each channel, to be used when normalizing images. + image_std (`List[int]`, defaults to `[0.5, 0.5, 0.5]`): + The sequence of standard deviations for each channel, to be used when normalizing images. + apply_ocr (`bool`, *optional*, defaults to `True`): + Whether to apply the Tesseract OCR engine to get words + normalized bounding boxes. + ocr_lang (`Optional[str]`, *optional*): + The language, specified by its ISO code, to be used by the Tesseract OCR engine. By default, English is + used. + + + + LayoutLMv3FeatureExtractor uses Google's Tesseract OCR engine under the hood. + + """ + + model_input_names = ["pixel_values"] + + def __init__( + self, + do_resize=True, + size=224, + resample=Image.BILINEAR, + do_normalize=True, + image_mean=None, + image_std=None, + apply_ocr=True, + ocr_lang=None, + **kwargs + ): + super().__init__(**kwargs) + self.do_resize = do_resize + self.size = size + self.resample = resample + self.do_normalize = do_normalize + self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN + self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD + self.apply_ocr = apply_ocr + self.ocr_lang = ocr_lang + + def __call__( + self, images: ImageInput, return_tensors: Optional[Union[str, TensorType]] = None, **kwargs + ) -> BatchFeature: + """ + Main method to prepare for the model one or several image(s). + + Args: + images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`): + The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch + tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a + number of channels, H and W are image height and width. + return_tensors (`str` or [`~utils.TensorType`], *optional*, defaults to `'np'`): + If set, will return tensors of a particular framework. Acceptable values are: + + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return NumPy `np.ndarray` objects. + - `'jax'`: Return JAX `jnp.ndarray` objects. + + Returns: + [`BatchFeature`]: A [`BatchFeature`] with the following fields: + + - **pixel_values** -- Pixel values to be fed to a model, of shape (batch_size, num_channels, height, + width). + - **words** -- Optional words as identified by Tesseract OCR (only when [`LayoutLMv3FeatureExtractor`] was + initialized with `apply_ocr` set to `True`). + - **boxes** -- Optional bounding boxes as identified by Tesseract OCR, normalized based on the image size + (only when [`LayoutLMv3FeatureExtractor`] was initialized with `apply_ocr` set to `True`). + + Examples: + + ```python + >>> from transformers import LayoutLMv3FeatureExtractor + >>> from PIL import Image + + >>> image = Image.open("name_of_your_document - can be a png file, pdf, etc.").convert("RGB") + + >>> # option 1: with apply_ocr=True (default) + >>> feature_extractor = LayoutLMv3FeatureExtractor() + >>> encoding = feature_extractor(image, return_tensors="pt") + >>> print(encoding.keys()) + >>> # dict_keys(['pixel_values', 'words', 'boxes']) + + >>> # option 2: with apply_ocr=False + >>> feature_extractor = LayoutLMv3FeatureExtractor(apply_ocr=False) + >>> encoding = feature_extractor(image, return_tensors="pt") + >>> print(encoding.keys()) + >>> # dict_keys(['pixel_values']) + ```""" + + # Input type checking for clearer error + valid_images = False + + # Check that images has a valid type + if isinstance(images, (Image.Image, np.ndarray)) or is_torch_tensor(images): + valid_images = True + elif isinstance(images, (list, tuple)): + if len(images) == 0 or isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]): + valid_images = True + + if not valid_images: + raise ValueError( + "Images must of type `PIL.Image.Image`, `np.ndarray` or `torch.Tensor` (single example), " + "`List[PIL.Image.Image]`, `List[np.ndarray]` or `List[torch.Tensor]` (batch of examples), " + f"but is of type {type(images)}." + ) + + is_batched = bool( + isinstance(images, (list, tuple)) + and (isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0])) + ) + + if not is_batched: + images = [images] + + # Tesseract OCR to get words + normalized bounding boxes + if self.apply_ocr: + requires_backends(self, "pytesseract") + words_batch = [] + boxes_batch = [] + for image in images: + words, boxes = apply_tesseract(self.to_pil_image(image), self.ocr_lang) + words_batch.append(words) + boxes_batch.append(boxes) + + # transformations (resizing + normalization) + if self.do_resize and self.size is not None: + images = [self.resize(image=image, size=self.size, resample=self.resample) for image in images] + if self.do_normalize: + images = [self.normalize(image=image, mean=self.image_mean, std=self.image_std) for image in images] + + # return as BatchFeature + data = {"pixel_values": images} + encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors) + + if self.apply_ocr: + encoded_inputs["words"] = words_batch + encoded_inputs["boxes"] = boxes_batch + + return encoded_inputs diff --git a/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py b/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py new file mode 100644 index 0000000000000..f301afafd80ec --- /dev/null +++ b/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py @@ -0,0 +1,1309 @@ +# coding=utf-8 +# Copyright 2022 Microsoft Research and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch LayoutLMv3 model.""" + +import collections +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from transformers import apply_chunking_to_forward +from transformers.modeling_outputs import ( + BaseModelOutput, + QuestionAnsweringModelOutput, + SequenceClassifierOutput, + TokenClassifierOutput, +) +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import logging + +from ...activations import ACT2FN +from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings +from .configuration_layoutlmv3 import LayoutLMv3Config + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "LayoutLMv3Config" + +LAYOUTLMV3_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "microsoft/layoutlmv3-base", + "microsoft/layoutlmv3-large", + # See all LayoutLMv3 models at https://huggingface.co/models?filter=layoutlmv3 +] + +LAYOUTLMV3_START_DOCSTRING = r""" + This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use + it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and + behavior. + + Parameters: + config ([`LayoutLMv2Config`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +LAYOUTLMV3_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `{0}`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using [`LayoutLMv2Tokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + + bbox (`torch.LongTensor` of shape `({0}, 4)`, *optional*): + Bounding boxes of each input sequence tokens. Selected in the range `[0, + config.max_2d_position_embeddings-1]`. Each bounding box should be a normalized version in (x0, y0, x1, y1) + format, where (x0, y0) corresponds to the position of the upper left corner in the bounding box, and (x1, + y1) represents the position of the lower right corner. + + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Batch of document images. + + attention_mask (`torch.FloatTensor` of shape `{0}`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + token_type_ids (`torch.LongTensor` of shape `{0}`, *optional*): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, + 1]`: + + - 0 corresponds to a *sentence A* token, + - 1 corresponds to a *sentence B* token. + + [What are token type IDs?](../glossary#token-type-ids) + position_ids (`torch.LongTensor` of shape `{0}`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) + head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert *input_ids* indices into associated vectors than the + model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +class LayoutLMv3PatchEmbeddings(nn.Module): + """LayoutLMv3 image (patch) embeddings. This class also automatically interpolates the position embeddings for varying + image sizes.""" + + def __init__(self, config): + super().__init__() + + image_size = ( + config.input_size + if isinstance(config.input_size, collections.abc.Iterable) + else (config.input_size, config.input_size) + ) + patch_size = ( + config.patch_size + if isinstance(config.patch_size, collections.abc.Iterable) + else (config.patch_size, config.patch_size) + ) + self.patch_shape = (image_size[0] // patch_size[0], image_size[1] // patch_size[1]) + self.proj = nn.Conv2d(config.num_channels, config.hidden_size, kernel_size=patch_size, stride=patch_size) + + def forward(self, pixel_values, position_embedding=None): + embeddings = self.proj(pixel_values) + + if position_embedding is not None: + # interpolate the position embedding to the corresponding size + position_embedding = position_embedding.view(1, self.patch_shape[0], self.patch_shape[1], -1) + position_embedding = position_embedding.permute(0, 3, 1, 2) + patch_height, patch_width = embeddings.shape[2], embeddings.shape[3] + position_embedding = F.interpolate(position_embedding, size=(patch_height, patch_width), mode="bicubic") + embeddings = embeddings + position_embedding + + embeddings = embeddings.flatten(2).transpose(1, 2) + return embeddings + + +class LayoutLMv3TextEmbeddings(nn.Module): + """ + LayoutLMv3 text embeddings. Same as `RobertaEmbeddings` but with added spatial (layout) embeddings. + """ + + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) + self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) + + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1))) + + self.padding_idx = config.pad_token_id + self.position_embeddings = nn.Embedding( + config.max_position_embeddings, config.hidden_size, padding_idx=self.padding_idx + ) + + self.x_position_embeddings = nn.Embedding(config.max_2d_position_embeddings, config.coordinate_size) + self.y_position_embeddings = nn.Embedding(config.max_2d_position_embeddings, config.coordinate_size) + self.h_position_embeddings = nn.Embedding(config.max_2d_position_embeddings, config.shape_size) + self.w_position_embeddings = nn.Embedding(config.max_2d_position_embeddings, config.shape_size) + + def calculate_spatial_position_embeddings(self, bbox): + try: + left_position_embeddings = self.x_position_embeddings(bbox[:, :, 0]) + upper_position_embeddings = self.y_position_embeddings(bbox[:, :, 1]) + right_position_embeddings = self.x_position_embeddings(bbox[:, :, 2]) + lower_position_embeddings = self.y_position_embeddings(bbox[:, :, 3]) + except IndexError as e: + raise IndexError("The `bbox` coordinate values should be within 0-1000 range.") from e + + h_position_embeddings = self.h_position_embeddings(torch.clip(bbox[:, :, 3] - bbox[:, :, 1], 0, 1023)) + w_position_embeddings = self.w_position_embeddings(torch.clip(bbox[:, :, 2] - bbox[:, :, 0], 0, 1023)) + + # below is the difference between LayoutLMEmbeddingsV2 (torch.cat) and LayoutLMEmbeddingsV1 (add) + spatial_position_embeddings = torch.cat( + [ + left_position_embeddings, + upper_position_embeddings, + right_position_embeddings, + lower_position_embeddings, + h_position_embeddings, + w_position_embeddings, + ], + dim=-1, + ) + return spatial_position_embeddings + + def create_position_ids_from_input_ids(self, input_ids, padding_idx): + """ + Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding + symbols are ignored. This is modified from fairseq's `utils.make_positions`. + """ + # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA. + mask = input_ids.ne(padding_idx).int() + incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask)) * mask + return incremental_indices.long() + padding_idx + + def create_position_ids_from_inputs_embeds(self, inputs_embeds): + """ + We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids. + """ + input_shape = inputs_embeds.size()[:-1] + sequence_length = input_shape[1] + + position_ids = torch.arange( + self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device + ) + return position_ids.unsqueeze(0).expand(input_shape) + + def forward( + self, + input_ids=None, + bbox=None, + token_type_ids=None, + position_ids=None, + inputs_embeds=None, + ): + if position_ids is None: + if input_ids is not None: + # Create the position ids from the input token ids. Any padded tokens remain padded. + position_ids = self.create_position_ids_from_input_ids(input_ids, self.padding_idx).to( + input_ids.device + ) + else: + position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds) + + if input_ids is not None: + input_shape = input_ids.size() + else: + input_shape = inputs_embeds.size()[:-1] + + if token_type_ids is None: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device) + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + token_type_embeddings = self.token_type_embeddings(token_type_ids) + + embeddings = inputs_embeds + token_type_embeddings + position_embeddings = self.position_embeddings(position_ids) + embeddings += position_embeddings + + spatial_position_embeddings = self.calculate_spatial_position_embeddings(bbox) + + embeddings = embeddings + spatial_position_embeddings + + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +class LayoutLMv3PreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = LayoutLMv3Config + base_model_prefix = "layoutlmv3" + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, (nn.Linear, nn.Conv2d)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + +class LayoutLMv3SelfAttention(nn.Module): + def __init__(self, config): + super().__init__() + if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): + raise ValueError( + f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " + f"heads ({config.num_attention_heads})" + ) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + self.has_relative_attention_bias = config.has_relative_attention_bias + self.has_spatial_attention_bias = config.has_spatial_attention_bias + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def cogview_attention(self, attention_scores, alpha=32): + """ + https://arxiv.org/abs/2105.13290 Section 2.4 Stabilization of training: Precision Bottleneck Relaxation + (PB-Relax). A replacement of the original nn.Softmax(dim=-1)(attention_scores). Seems the new attention_probs + will result in a slower speed and a little bias. Can use torch.allclose(standard_attention_probs, + cogview_attention_probs, atol=1e-08) for comparison. The smaller atol (e.g., 1e-08), the better. + """ + scaled_attention_scores = attention_scores / alpha + max_value = scaled_attention_scores.amax(dim=(-1)).unsqueeze(-1) + new_attention_scores = (scaled_attention_scores - max_value) * alpha + return nn.Softmax(dim=-1)(new_attention_scores) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + output_attentions=False, + rel_pos=None, + rel_2d_pos=None, + ): + mixed_query_layer = self.query(hidden_states) + + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + query_layer = self.transpose_for_scores(mixed_query_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. + # The attention scores QT K/√d could be significantly larger than input elements, and result in overflow. + # Changing the computational order into QT(K/√d) alleviates the problem. (https://arxiv.org/pdf/2105.13290.pdf) + attention_scores = torch.matmul(query_layer / math.sqrt(self.attention_head_size), key_layer.transpose(-1, -2)) + + if self.has_relative_attention_bias and self.has_spatial_attention_bias: + attention_scores += (rel_pos + rel_2d_pos) / math.sqrt(self.attention_head_size) + elif self.has_relative_attention_bias: + attention_scores += rel_pos / math.sqrt(self.attention_head_size) + + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in RobertaModel forward() function) + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + # Use the trick of the CogView paper to stablize training + attention_probs = self.cogview_attention(attention_scores) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = attention_probs * head_mask + + context_layer = torch.matmul(attention_probs, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(*new_context_layer_shape) + + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + + return outputs + + +# Copied from transformers.models.roberta.modeling_roberta.RobertaSelfOutput +class LayoutLMv3SelfOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +# Copied from transformers.models.layoutlmv2.modeling_layoutlmv2.LayoutLMv2Attention with LayoutLMv2->LayoutLMv3 +class LayoutLMv3Attention(nn.Module): + def __init__(self, config): + super().__init__() + self.self = LayoutLMv3SelfAttention(config) + self.output = LayoutLMv3SelfOutput(config) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + output_attentions=False, + rel_pos=None, + rel_2d_pos=None, + ): + self_outputs = self.self( + hidden_states, + attention_mask, + head_mask, + output_attentions, + rel_pos=rel_pos, + rel_2d_pos=rel_2d_pos, + ) + attention_output = self.output(self_outputs[0], hidden_states) + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + +# Copied from transformers.models.layoutlmv2.modeling_layoutlmv2.LayoutLMv2Layer with LayoutLMv2->LayoutLMv3 +class LayoutLMv3Layer(nn.Module): + def __init__(self, config): + super().__init__() + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = LayoutLMv3Attention(config) + self.intermediate = LayoutLMv3Intermediate(config) + self.output = LayoutLMv3Output(config) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + output_attentions=False, + rel_pos=None, + rel_2d_pos=None, + ): + self_attention_outputs = self.attention( + hidden_states, + attention_mask, + head_mask, + output_attentions=output_attentions, + rel_pos=rel_pos, + rel_2d_pos=rel_2d_pos, + ) + attention_output = self_attention_outputs[0] + + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + + layer_output = apply_chunking_to_forward( + self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output + ) + outputs = (layer_output,) + outputs + + return outputs + + def feed_forward_chunk(self, attention_output): + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + +class LayoutLMv3Encoder(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.layer = nn.ModuleList([LayoutLMv3Layer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + self.has_relative_attention_bias = config.has_relative_attention_bias + self.has_spatial_attention_bias = config.has_spatial_attention_bias + + if self.has_relative_attention_bias: + self.rel_pos_bins = config.rel_pos_bins + self.max_rel_pos = config.max_rel_pos + self.rel_pos_onehot_size = config.rel_pos_bins + self.rel_pos_bias = nn.Linear(self.rel_pos_onehot_size, config.num_attention_heads, bias=False) + + if self.has_spatial_attention_bias: + self.max_rel_2d_pos = config.max_rel_2d_pos + self.rel_2d_pos_bins = config.rel_2d_pos_bins + self.rel_2d_pos_onehot_size = config.rel_2d_pos_bins + self.rel_pos_x_bias = nn.Linear(self.rel_2d_pos_onehot_size, config.num_attention_heads, bias=False) + self.rel_pos_y_bias = nn.Linear(self.rel_2d_pos_onehot_size, config.num_attention_heads, bias=False) + + def relative_position_bucket(self, relative_position, bidirectional=True, num_buckets=32, max_distance=128): + ret = 0 + if bidirectional: + num_buckets //= 2 + ret += (relative_position > 0).long() * num_buckets + n = torch.abs(relative_position) + else: + n = torch.max(-relative_position, torch.zeros_like(relative_position)) + # now n is in the range [0, inf) + + # half of the buckets are for exact increments in positions + max_exact = num_buckets // 2 + is_small = n < max_exact + + # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance + val_if_large = max_exact + ( + torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact) + ).to(torch.long) + val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1)) + + ret += torch.where(is_small, n, val_if_large) + return ret + + def _cal_1d_pos_emb(self, hidden_states, position_ids): + rel_pos_mat = position_ids.unsqueeze(-2) - position_ids.unsqueeze(-1) + + rel_pos = self.relative_position_bucket( + rel_pos_mat, + num_buckets=self.rel_pos_bins, + max_distance=self.max_rel_pos, + ) + rel_pos = F.one_hot(rel_pos, num_classes=self.rel_pos_onehot_size).type_as(hidden_states) + rel_pos = self.rel_pos_bias(rel_pos).permute(0, 3, 1, 2) + rel_pos = rel_pos.contiguous() + return rel_pos + + def _cal_2d_pos_emb(self, hidden_states, bbox): + position_coord_x = bbox[:, :, 0] + position_coord_y = bbox[:, :, 3] + rel_pos_x_2d_mat = position_coord_x.unsqueeze(-2) - position_coord_x.unsqueeze(-1) + rel_pos_y_2d_mat = position_coord_y.unsqueeze(-2) - position_coord_y.unsqueeze(-1) + rel_pos_x = self.relative_position_bucket( + rel_pos_x_2d_mat, + num_buckets=self.rel_2d_pos_bins, + max_distance=self.max_rel_2d_pos, + ) + rel_pos_y = self.relative_position_bucket( + rel_pos_y_2d_mat, + num_buckets=self.rel_2d_pos_bins, + max_distance=self.max_rel_2d_pos, + ) + rel_pos_x = F.one_hot(rel_pos_x, num_classes=self.rel_2d_pos_onehot_size).type_as(hidden_states) + rel_pos_y = F.one_hot(rel_pos_y, num_classes=self.rel_2d_pos_onehot_size).type_as(hidden_states) + rel_pos_x = self.rel_pos_x_bias(rel_pos_x).permute(0, 3, 1, 2) + rel_pos_y = self.rel_pos_y_bias(rel_pos_y).permute(0, 3, 1, 2) + rel_pos_x = rel_pos_x.contiguous() + rel_pos_y = rel_pos_y.contiguous() + rel_2d_pos = rel_pos_x + rel_pos_y + return rel_2d_pos + + def forward( + self, + hidden_states, + bbox=None, + attention_mask=None, + head_mask=None, + output_attentions=False, + output_hidden_states=False, + return_dict=True, + position_ids=None, + patch_height=None, + patch_width=None, + ): + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + rel_pos = self._cal_1d_pos_emb(hidden_states, position_ids) if self.has_relative_attention_bias else None + rel_2d_pos = self._cal_2d_pos_emb(hidden_states, bbox) if self.has_spatial_attention_bias else None + + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + # return module(*inputs, past_key_value, output_attentions, rel_pos, rel_2d_pos) + # The above line will cause error: + # RuntimeError: Trying to backward through the graph a second time + # (or directly access saved tensors after they have already been freed). + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(layer_module), + hidden_states, + attention_mask, + layer_head_mask, + output_attentions, + rel_pos, + rel_2d_pos, + ) + else: + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + output_attentions, + rel_pos=rel_pos, + rel_2d_pos=rel_2d_pos, + ) + + hidden_states = layer_outputs[0] + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + all_hidden_states, + all_self_attentions, + ] + if v is not None + ) + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + +# Copied from transformers.models.roberta.modeling_roberta.RobertaIntermediate +class LayoutLMv3Intermediate(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +# Copied from transformers.models.roberta.modeling_roberta.RobertaOutput +class LayoutLMv3Output(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +@add_start_docstrings( + "The bare LayoutLMv3 Model transformer outputting raw hidden-states without any specific head on top.", + LAYOUTLMV3_START_DOCSTRING, +) +class LayoutLMv3Model(LayoutLMv3PreTrainedModel): + _keys_to_ignore_on_load_missing = [r"position_ids"] + + def __init__(self, config): + super().__init__(config) + self.config = config + + if config.text_embed: + self.embeddings = LayoutLMv3TextEmbeddings(config) + + if config.visual_embed: + # use the default pre-training parameters for fine-tuning (e.g., input_size) + # when the input_size is larger in fine-tuning, we will interpolate the position embeddings in forward + self.patch_embed = LayoutLMv3PatchEmbeddings(config) + + size = int(config.input_size / config.patch_size) + self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) + self.pos_embed = nn.Parameter(torch.zeros(1, size * size + 1, config.hidden_size)) + self.pos_drop = nn.Dropout(p=0.0) + + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + if self.config.has_relative_attention_bias or self.config.has_spatial_attention_bias: + self.init_visual_bbox(image_size=(size, size)) + + self.norm = nn.LayerNorm(config.hidden_size, eps=1e-6) + + self.encoder = LayoutLMv3Encoder(config) + + self.init_weights() + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + def init_visual_bbox(self, image_size=(14, 14), max_len=1000): + """ + Create the bounding boxes for the visual (patch) tokens. + """ + visual_bbox_x = torch.div( + torch.arange(0, max_len * (image_size[1] + 1), max_len), image_size[1], rounding_mode="trunc" + ) + visual_bbox_y = torch.div( + torch.arange(0, max_len * (image_size[0] + 1), max_len), image_size[0], rounding_mode="trunc" + ) + visual_bbox = torch.stack( + [ + visual_bbox_x[:-1].repeat(image_size[0], 1), + visual_bbox_y[:-1].repeat(image_size[1], 1).transpose(0, 1), + visual_bbox_x[1:].repeat(image_size[0], 1), + visual_bbox_y[1:].repeat(image_size[1], 1).transpose(0, 1), + ], + dim=-1, + ).view(-1, 4) + + cls_token_box = torch.tensor([[0 + 1, 0 + 1, max_len - 1, max_len - 1]]) + self.visual_bbox = torch.cat([cls_token_box, visual_bbox], dim=0) + + def calculate_visual_bbox(self, device, dtype, batch_size): + visual_bbox = self.visual_bbox.repeat(batch_size, 1, 1) + visual_bbox = visual_bbox.to(device).type(dtype) + return visual_bbox + + def forward_image(self, pixel_values): + embeddings = self.patch_embed(pixel_values) + + # add [CLS] token + batch_size, seq_len, _ = embeddings.size() + cls_tokens = self.cls_token.expand(batch_size, -1, -1) + embeddings = torch.cat((cls_tokens, embeddings), dim=1) + + # add position embeddings + if self.pos_embed is not None: + embeddings = embeddings + self.pos_embed + + embeddings = self.pos_drop(embeddings) + embeddings = self.norm(embeddings) + + return embeddings + + @add_start_docstrings_to_model_forward(LAYOUTLMV3_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) + @replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids=None, + bbox=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + pixel_values=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + Returns: + + Examples: + + ```python + >>> from transformers import AutoProcessor, AutoModel + >>> from datasets import load_dataset + + >>> processor = AutoProcessor.from_pretrained("microsoft/layoutlmv3-base", apply_ocr=False) + >>> model = AutoModel.from_pretrained("microsoft/layoutlmv3-base") + + >>> dataset = load_dataset("nielsr/funsd-layoutlmv3", split="train") + >>> example = dataset[0] + >>> image = example["image"] + >>> words = example["tokens"] + >>> boxes = example["bboxes"] + + >>> encoding = processor(image, words, boxes=boxes, return_tensors="pt") + + >>> outputs = model(**encoding) + >>> last_hidden_states = outputs.last_hidden_state + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None: + input_shape = input_ids.size() + batch_size, seq_length = input_shape + device = input_ids.device + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + batch_size, seq_length = input_shape + device = inputs_embeds.device + elif pixel_values is not None: + batch_size = len(pixel_values) + device = pixel_values.device + else: + raise ValueError("You have to specify either input_ids or inputs_embeds or pixel_values") + + if input_ids is not None or inputs_embeds is not None: + if attention_mask is None: + attention_mask = torch.ones(((batch_size, seq_length)), device=device) + if token_type_ids is None: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) + if bbox is None: + bbox = torch.zeros(tuple(list(input_shape) + [4]), dtype=torch.long, device=device) + + embedding_output = self.embeddings( + input_ids=input_ids, + bbox=bbox, + position_ids=position_ids, + token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds, + ) + + final_bbox = final_position_ids = None + patch_height = patch_width = None + if pixel_values is not None: + patch_height, patch_width = int(pixel_values.shape[2] / self.config.patch_size), int( + pixel_values.shape[3] / self.config.patch_size + ) + visual_embeddings = self.forward_image(pixel_values) + visual_attention_mask = torch.ones( + (batch_size, visual_embeddings.shape[1]), dtype=torch.long, device=device + ) + if attention_mask is not None: + attention_mask = torch.cat([attention_mask, visual_attention_mask], dim=1) + else: + attention_mask = visual_attention_mask + + if self.config.has_relative_attention_bias or self.config.has_spatial_attention_bias: + if self.config.has_spatial_attention_bias: + visual_bbox = self.calculate_visual_bbox(device, dtype=torch.long, batch_size=batch_size) + if bbox is not None: + final_bbox = torch.cat([bbox, visual_bbox], dim=1) + else: + final_bbox = visual_bbox + + visual_position_ids = torch.arange( + 0, visual_embeddings.shape[1], dtype=torch.long, device=device + ).repeat(batch_size, 1) + if input_ids is not None or inputs_embeds is not None: + position_ids = torch.arange(0, input_shape[1], device=device).unsqueeze(0) + position_ids = position_ids.expand(input_shape) + final_position_ids = torch.cat([position_ids, visual_position_ids], dim=1) + else: + final_position_ids = visual_position_ids + + if input_ids is not None or inputs_embeds is not None: + embedding_output = torch.cat([embedding_output, visual_embeddings], dim=1) + else: + embedding_output = visual_embeddings + + embedding_output = self.LayerNorm(embedding_output) + embedding_output = self.dropout(embedding_output) + elif self.config.has_relative_attention_bias or self.config.has_spatial_attention_bias: + if self.config.has_spatial_attention_bias: + final_bbox = bbox + if self.config.has_relative_attention_bias: + position_ids = self.embeddings.position_ids[:, : input_shape[1]] + position_ids = position_ids.expand_as(input_ids) + final_position_ids = position_ids + + extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, None, device) + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + encoder_outputs = self.encoder( + embedding_output, + bbox=final_bbox, + position_ids=final_position_ids, + attention_mask=extended_attention_mask, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + patch_height=patch_height, + patch_width=patch_width, + ) + + sequence_output = encoder_outputs[0] + + if not return_dict: + return (sequence_output,) + encoder_outputs[1:] + + return BaseModelOutput( + last_hidden_state=sequence_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +class LayoutLMv3ClassificationHead(nn.Module): + """ + Head for sentence-level classification tasks. Reference: RobertaClassificationHead + """ + + def __init__(self, config, pool_feature=False): + super().__init__() + self.pool_feature = pool_feature + if pool_feature: + self.dense = nn.Linear(config.hidden_size * 3, config.hidden_size) + else: + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + classifier_dropout = ( + config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob + ) + self.dropout = nn.Dropout(classifier_dropout) + self.out_proj = nn.Linear(config.hidden_size, config.num_labels) + + def forward(self, x): + x = self.dropout(x) + x = self.dense(x) + x = torch.tanh(x) + x = self.dropout(x) + x = self.out_proj(x) + return x + + +@add_start_docstrings( + """ + LayoutLMv3 Model with a token classification head on top (a linear layer on top of the final hidden states) e.g. + for sequence labeling (information extraction) tasks such as [FUNSD](https://guillaumejaume.github.io/FUNSD/), + [SROIE](https://rrc.cvc.uab.es/?ch=13), [CORD](https://github.com/clovaai/cord) and + [Kleister-NDA](https://github.com/applicaai/kleister-nda). + """, + LAYOUTLMV3_START_DOCSTRING, +) +class LayoutLMv3ForTokenClassification(LayoutLMv3PreTrainedModel): + _keys_to_ignore_on_load_unexpected = [r"pooler"] + _keys_to_ignore_on_load_missing = [r"position_ids"] + + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.layoutlmv3 = LayoutLMv3Model(config) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + if config.num_labels < 10: + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + else: + self.classifier = LayoutLMv3ClassificationHead(config, pool_feature=False) + + self.init_weights() + + @add_start_docstrings_to_model_forward(LAYOUTLMV3_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @replace_return_docstrings(output_type=TokenClassifierOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids=None, + bbox=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + pixel_values=None, + ): + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. + + Returns: + + Examples: + + ```python + >>> from transformers import AutoProcessor, AutoModelForTokenClassification + >>> from datasets import load_dataset + + >>> processor = AutoProcessor.from_pretrained("microsoft/layoutlmv3-base", apply_ocr=False) + >>> model = AutoModelForTokenClassification.from_pretrained("microsoft/layoutlmv3-base", num_labels=7) + + >>> dataset = load_dataset("nielsr/funsd-layoutlmv3", split="train") + >>> example = dataset[0] + >>> image = example["image"] + >>> words = example["tokens"] + >>> boxes = example["bboxes"] + >>> word_labels = example["ner_tags"] + + >>> encoding = processor(image, words, boxes=boxes, word_labels=word_labels, return_tensors="pt") + + >>> outputs = model(**encoding) + >>> loss = outputs.loss + >>> logits = outputs.logits + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.layoutlmv3( + input_ids, + bbox=bbox, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + pixel_values=pixel_values, + ) + if input_ids is not None: + input_shape = input_ids.size() + else: + input_shape = inputs_embeds.size()[:-1] + + seq_length = input_shape[1] + # only take the text part of the output representations + sequence_output = outputs[0][:, :seq_length] + sequence_output = self.dropout(sequence_output) + logits = self.classifier(sequence_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + LayoutLMv3 Model with a span classification head on top for extractive question-answering tasks such as + [DocVQA](https://rrc.cvc.uab.es/?ch=17) (a linear layer on top of the text part of the hidden-states output to + compute `span start logits` and `span end logits`). + """, + LAYOUTLMV3_START_DOCSTRING, +) +class LayoutLMv3ForQuestionAnswering(LayoutLMv3PreTrainedModel): + _keys_to_ignore_on_load_unexpected = [r"pooler"] + _keys_to_ignore_on_load_missing = [r"position_ids"] + + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.layoutlmv3 = LayoutLMv3Model(config) + self.qa_outputs = LayoutLMv3ClassificationHead(config, pool_feature=False) + + self.init_weights() + + @add_start_docstrings_to_model_forward(LAYOUTLMV3_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @replace_return_docstrings(output_type=QuestionAnsweringModelOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + start_positions=None, + end_positions=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + bbox=None, + pixel_values=None, + ): + r""" + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + + Returns: + + Examples: + + ```python + >>> from transformers import AutoProcessor, AutoModelForQuestionAnswering + >>> from datasets import load_dataset + >>> import torch + + >>> processor = AutoProcessor.from_pretrained("microsoft/layoutlmv3-base", apply_ocr=False) + >>> model = AutoModelForQuestionAnswering.from_pretrained("microsoft/layoutlmv3-base") + + >>> dataset = load_dataset("nielsr/funsd-layoutlmv3", split="train") + >>> example = dataset[0] + >>> image = example["image"] + >>> question = "what's his name?" + >>> words = example["tokens"] + >>> boxes = example["bboxes"] + + >>> encoding = processor(image, question, words, boxes=boxes, return_tensors="pt") + >>> start_positions = torch.tensor([1]) + >>> end_positions = torch.tensor([3]) + + >>> outputs = model(**encoding, start_positions=start_positions, end_positions=end_positions) + >>> loss = outputs.loss + >>> start_scores = outputs.start_logits + >>> end_scores = outputs.end_logits + ```""" + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.layoutlmv3( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + bbox=bbox, + pixel_values=pixel_values, + ) + + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index) + end_positions = end_positions.clamp(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = (start_logits, end_logits) + outputs[1:] + return ((total_loss,) + output) if total_loss is not None else output + + return QuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + LayoutLMv3 Model with a sequence classification head on top (a linear layer on top of the final hidden state of the + [CLS] token) e.g. for document image classification tasks such as the + [RVL-CDIP](https://www.cs.cmu.edu/~aharley/rvl-cdip/) dataset. + """, + LAYOUTLMV3_START_DOCSTRING, +) +class LayoutLMv3ForSequenceClassification(LayoutLMv3PreTrainedModel): + _keys_to_ignore_on_load_missing = [r"position_ids"] + + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.config = config + self.layoutlmv3 = LayoutLMv3Model(config) + self.classifier = LayoutLMv3ClassificationHead(config, pool_feature=False) + + self.init_weights() + + @add_start_docstrings_to_model_forward(LAYOUTLMV3_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @replace_return_docstrings(output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + bbox=None, + pixel_values=None, + ): + """ + Returns: + + Examples: + + ```python + >>> from transformers import AutoProcessor, AutoModelForSequenceClassification + >>> from datasets import load_dataset + >>> import torch + + >>> processor = AutoProcessor.from_pretrained("microsoft/layoutlmv3-base", apply_ocr=False) + >>> model = AutoModelForSequenceClassification.from_pretrained("microsoft/layoutlmv3-base") + + >>> dataset = load_dataset("nielsr/funsd-layoutlmv3", split="train") + >>> example = dataset[0] + >>> image = example["image"] + >>> words = example["tokens"] + >>> boxes = example["bboxes"] + + >>> encoding = processor(image, words, boxes=boxes, return_tensors="pt") + >>> sequence_label = torch.tensor([1]) + + >>> outputs = model(**encoding, labels=sequence_label) + >>> loss = outputs.loss + >>> logits = outputs.logits + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.layoutlmv3( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + bbox=bbox, + pixel_values=pixel_values, + ) + + sequence_output = outputs[0][:, 0, :] + logits = self.classifier(sequence_output) + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/src/transformers/models/layoutlmv3/processing_layoutlmv3.py b/src/transformers/models/layoutlmv3/processing_layoutlmv3.py new file mode 100644 index 0000000000000..c80b2bd5f2030 --- /dev/null +++ b/src/transformers/models/layoutlmv3/processing_layoutlmv3.py @@ -0,0 +1,158 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Processor class for LayoutLMv3. +""" +from typing import List, Optional, Union + +from ...processing_utils import ProcessorMixin +from ...tokenization_utils_base import BatchEncoding, PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy +from ...utils import TensorType + + +class LayoutLMv3Processor(ProcessorMixin): + r""" + Constructs a LayoutLMv3 processor which combines a LayoutLMv3 feature extractor and a LayoutLMv3 tokenizer into a + single processor. + + [`LayoutLMv3Processor`] offers all the functionalities you need to prepare data for the model. + + It first uses [`LayoutLMv3FeatureExtractor`] to resize and normalize document images, and optionally applies OCR to + get words and normalized bounding boxes. These are then provided to [`LayoutLMv3Tokenizer`] or + [`LayoutLMv3TokenizerFast`], which turns the words and bounding boxes into token-level `input_ids`, + `attention_mask`, `token_type_ids`, `bbox`. Optionally, one can provide integer `word_labels`, which are turned + into token-level `labels` for token classification tasks (such as FUNSD, CORD). + + Args: + feature_extractor (`LayoutLMv3FeatureExtractor`): + An instance of [`LayoutLMv3FeatureExtractor`]. The feature extractor is a required input. + tokenizer (`LayoutLMv3Tokenizer` or `LayoutLMv3TokenizerFast`): + An instance of [`LayoutLMv3Tokenizer`] or [`LayoutLMv3TokenizerFast`]. The tokenizer is a required input. + """ + feature_extractor_class = "LayoutLMv3FeatureExtractor" + tokenizer_class = ("LayoutLMv3Tokenizer", "LayoutLMv3TokenizerFast") + + def __call__( + self, + images, + text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None, + text_pair: Optional[Union[PreTokenizedInput, List[PreTokenizedInput]]] = None, + boxes: Union[List[List[int]], List[List[List[int]]]] = None, + word_labels: Optional[Union[List[int], List[List[int]]]] = None, + add_special_tokens: bool = True, + padding: Union[bool, str, PaddingStrategy] = False, + truncation: Union[bool, str, TruncationStrategy] = False, + max_length: Optional[int] = None, + stride: int = 0, + pad_to_multiple_of: Optional[int] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + return_length: bool = False, + verbose: bool = True, + return_tensors: Optional[Union[str, TensorType]] = None, + **kwargs + ) -> BatchEncoding: + """ + This method first forwards the `images` argument to [`~LayoutLMv3FeatureExtractor.__call__`]. In case + [`LayoutLMv3FeatureExtractor`] was initialized with `apply_ocr` set to `True`, it passes the obtained words and + bounding boxes along with the additional arguments to [`~LayoutLMv3Tokenizer.__call__`] and returns the output, + together with resized and normalized `pixel_values`. In case [`LayoutLMv3FeatureExtractor`] was initialized + with `apply_ocr` set to `False`, it passes the words (`text`/``text_pair`) and `boxes` specified by the user + along with the additional arguments to [`~LayoutLMv3Tokenizer.__call__`] and returns the output, together with + resized and normalized `pixel_values`. + + Please refer to the docstring of the above two methods for more information. + """ + # verify input + if self.feature_extractor.apply_ocr and (boxes is not None): + raise ValueError( + "You cannot provide bounding boxes " + "if you initialized the feature extractor with apply_ocr set to True." + ) + + if self.feature_extractor.apply_ocr and (word_labels is not None): + raise ValueError( + "You cannot provide word labels if you initialized the feature extractor with apply_ocr set to True." + ) + + # first, apply the feature extractor + features = self.feature_extractor(images=images, return_tensors=return_tensors) + + # second, apply the tokenizer + if text is not None and self.feature_extractor.apply_ocr and text_pair is None: + if isinstance(text, str): + text = [text] # add batch dimension (as the feature extractor always adds a batch dimension) + text_pair = features["words"] + + encoded_inputs = self.tokenizer( + text=text if text is not None else features["words"], + text_pair=text_pair if text_pair is not None else None, + boxes=boxes if boxes is not None else features["boxes"], + word_labels=word_labels, + add_special_tokens=add_special_tokens, + padding=padding, + truncation=truncation, + max_length=max_length, + stride=stride, + pad_to_multiple_of=pad_to_multiple_of, + return_token_type_ids=return_token_type_ids, + return_attention_mask=return_attention_mask, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_offsets_mapping=return_offsets_mapping, + return_length=return_length, + verbose=verbose, + return_tensors=return_tensors, + **kwargs, + ) + + # add pixel values + images = features.pop("pixel_values") + if return_overflowing_tokens is True: + images = self.get_overflowing_images(images, encoded_inputs["overflow_to_sample_mapping"]) + encoded_inputs["pixel_values"] = images + + return encoded_inputs + + def get_overflowing_images(self, images, overflow_to_sample_mapping): + # in case there's an overflow, ensure each `input_ids` sample is mapped to its corresponding image + images_with_overflow = [] + for sample_idx in overflow_to_sample_mapping: + images_with_overflow.append(images[sample_idx]) + + if len(images_with_overflow) != len(overflow_to_sample_mapping): + raise ValueError( + "Expected length of images to be the same as the length of `overflow_to_sample_mapping`, but got" + f" {len(images_with_overflow)} and {len(overflow_to_sample_mapping)}" + ) + + return images_with_overflow + + def batch_decode(self, *args, **kwargs): + """ + This method forwards all its arguments to PreTrainedTokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please + refer to the docstring of this method for more information. + """ + return self.tokenizer.batch_decode(*args, **kwargs) + + def decode(self, *args, **kwargs): + """ + This method forwards all its arguments to PreTrainedTokenizer's [`~PreTrainedTokenizer.decode`]. Please refer + to the docstring of this method for more information. + """ + return self.tokenizer.decode(*args, **kwargs) diff --git a/src/transformers/models/layoutlmv3/tokenization_layoutlmv3.py b/src/transformers/models/layoutlmv3/tokenization_layoutlmv3.py new file mode 100644 index 0000000000000..8b0cc5fbf7428 --- /dev/null +++ b/src/transformers/models/layoutlmv3/tokenization_layoutlmv3.py @@ -0,0 +1,1478 @@ +# coding=utf-8 +# Copyright The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tokenization class for LayoutLMv3. Same as LayoutLMv2, but RoBERTa-like BPE tokenization instead of WordPiece.""" + +import json +import os +from functools import lru_cache +from typing import Dict, List, Optional, Tuple, Union + +import regex as re + +from ...tokenization_utils import AddedToken, PreTrainedTokenizer +from ...tokenization_utils_base import ( + BatchEncoding, + EncodedInput, + PreTokenizedInput, + TextInput, + TextInputPair, + TruncationStrategy, +) +from ...utils import PaddingStrategy, TensorType, add_end_docstrings, logging + + +logger = logging.get_logger(__name__) + +VOCAB_FILES_NAMES = { + "vocab_file": "vocab.json", + "merges_file": "merges.txt", +} + +PRETRAINED_VOCAB_FILES_MAP = { + "vocab_file": { + "microsoft/layoutlmv3-base": "https://huggingface.co/microsoft/layoutlmv3-base/raw/main/vocab.json", + "microsoft/layoutlmv3-large": "https://huggingface.co/microsoft/layoutlmv3-large/raw/main/vocab.json", + }, + "merges_file": { + "microsoft/layoutlmv3-base": "https://huggingface.co/microsoft/layoutlmv3-base/raw/main/merges.txt", + "microsoft/layoutlmv3-large": "https://huggingface.co/microsoft/layoutlmv3-large/raw/main/merges.txt", + }, +} + +PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { + "microsoft/layoutlmv3-base": 512, + "microsoft/layoutlmv3-large": 512, +} + + +LAYOUTLMV3_ENCODE_KWARGS_DOCSTRING = r""" + add_special_tokens (`bool`, *optional*, defaults to `True`): + Whether or not to encode the sequences with the special tokens relative to their model. + padding (`bool`, `str` or [`~file_utils.PaddingStrategy`], *optional*, defaults to `False`): + Activates and controls padding. Accepts the following values: + + - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single + sequence if provided). + - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum + acceptable input length for the model if that argument is not provided. + - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different + lengths). + truncation (`bool`, `str` or [`~tokenization_utils_base.TruncationStrategy`], *optional*, defaults to `False`): + Activates and controls truncation. Accepts the following values: + + - `True` or `'longest_first'`: Truncate to a maximum length specified with the argument `max_length` or + to the maximum acceptable input length for the model if that argument is not provided. This will + truncate token by token, removing a token from the longest sequence in the pair if a pair of + sequences (or a batch of pairs) is provided. + - `'only_first'`: Truncate to a maximum length specified with the argument `max_length` or to the + maximum acceptable input length for the model if that argument is not provided. This will only + truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided. + - `'only_second'`: Truncate to a maximum length specified with the argument `max_length` or to the + maximum acceptable input length for the model if that argument is not provided. This will only + truncate the second sequence of a pair if a pair of sequences (or a batch of pairs) is provided. + - `False` or `'do_not_truncate'` (default): No truncation (i.e., can output batch with sequence lengths + greater than the model maximum admissible input size). + max_length (`int`, *optional*): + Controls the maximum length to use by one of the truncation/padding parameters. + + If left unset or set to `None`, this will use the predefined model maximum length if a maximum length + is required by one of the truncation/padding parameters. If the model has no specific maximum input + length (like XLNet) truncation/padding to a maximum length will be deactivated. + stride (`int`, *optional*, defaults to 0): + If set to a number along with `max_length`, the overflowing tokens returned when + `return_overflowing_tokens=True` will contain some tokens from the end of the truncated sequence + returned to provide some overlap between truncated and overflowing sequences. The value of this + argument defines the number of overlapping tokens. + pad_to_multiple_of (`int`, *optional*): + If set will pad the sequence to a multiple of the provided value. This is especially useful to enable + the use of Tensor Cores on NVIDIA hardware with compute capability >= 7.5 (Volta). + return_tensors (`str` or [`~file_utils.TensorType`], *optional*): + If set, will return tensors instead of list of python integers. Acceptable values are: + + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return Numpy `np.ndarray` objects. +""" + + +LAYOUTLMV3_ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING = r""" + add_special_tokens (`bool`, *optional*, defaults to `True`): + Whether or not to encode the sequences with the special tokens relative to their model. + padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`): + Activates and controls padding. Accepts the following values: + + - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single + sequence if provided). + - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum + acceptable input length for the model if that argument is not provided. + - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different + lengths). + truncation (`bool`, `str` or [`~tokenization_utils_base.TruncationStrategy`], *optional*, defaults to `False`): + Activates and controls truncation. Accepts the following values: + + - `True` or `'longest_first'`: Truncate to a maximum length specified with the argument `max_length` or + to the maximum acceptable input length for the model if that argument is not provided. This will + truncate token by token, removing a token from the longest sequence in the pair if a pair of + sequences (or a batch of pairs) is provided. + - `'only_first'`: Truncate to a maximum length specified with the argument `max_length` or to the + maximum acceptable input length for the model if that argument is not provided. This will only + truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided. + - `'only_second'`: Truncate to a maximum length specified with the argument `max_length` or to the + maximum acceptable input length for the model if that argument is not provided. This will only + truncate the second sequence of a pair if a pair of sequences (or a batch of pairs) is provided. + - `False` or `'do_not_truncate'` (default): No truncation (i.e., can output batch with sequence lengths + greater than the model maximum admissible input size). + max_length (`int`, *optional*): + Controls the maximum length to use by one of the truncation/padding parameters. If left unset or set to + `None`, this will use the predefined model maximum length if a maximum length is required by one of the + truncation/padding parameters. If the model has no specific maximum input length (like XLNet) + truncation/padding to a maximum length will be deactivated. + stride (`int`, *optional*, defaults to 0): + If set to a number along with `max_length`, the overflowing tokens returned when + `return_overflowing_tokens=True` will contain some tokens from the end of the truncated sequence + returned to provide some overlap between truncated and overflowing sequences. The value of this + argument defines the number of overlapping tokens. + pad_to_multiple_of (`int`, *optional*): + If set will pad the sequence to a multiple of the provided value. This is especially useful to enable + the use of Tensor Cores on NVIDIA hardware with compute capability >= 7.5 (Volta). + return_tensors (`str` or [`~utils.TensorType`], *optional*): + If set, will return tensors instead of list of python integers. Acceptable values are: + + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return Numpy `np.ndarray` objects. +""" + + +@lru_cache() +# Copied from transformers.models.roberta.tokenization_roberta.bytes_to_unicode +def bytes_to_unicode(): + """ + Returns list of utf-8 byte and a mapping to unicode strings. We specifically avoids mapping to whitespace/control + characters the bpe code barfs on. + + The reversible bpe codes work on unicode strings. This means you need a large # of unicode characters in your vocab + if you want to avoid UNKs. When you're at something like a 10B token dataset you end up needing around 5K for + decent coverage. This is a significant percentage of your normal, say, 32K bpe vocab. To avoid that, we want lookup + tables between utf-8 bytes and unicode strings. + """ + bs = ( + list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1)) + ) + cs = bs[:] + n = 0 + for b in range(2**8): + if b not in bs: + bs.append(b) + cs.append(2**8 + n) + n += 1 + cs = [chr(n) for n in cs] + return dict(zip(bs, cs)) + + +# Copied from transformers.models.roberta.tokenization_roberta.get_pairs +def get_pairs(word): + """ + Return set of symbol pairs in a word. + + Word is represented as tuple of symbols (symbols being variable-length strings). + """ + pairs = set() + prev_char = word[0] + for char in word[1:]: + pairs.add((prev_char, char)) + prev_char = char + return pairs + + +class LayoutLMv3Tokenizer(PreTrainedTokenizer): + r""" + Construct a LayoutLMv3 tokenizer. Based on [`RoBERTatokenizer`] (Byte Pair Encoding or BPE). + [`LayoutLMv3Tokenizer`] can be used to turn words, word-level bounding boxes and optional word labels to + token-level `input_ids`, `attention_mask`, `token_type_ids`, `bbox`, and optional `labels` (for token + classification). + + This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to + this superclass for more information regarding those methods. + + [`LayoutLMv3Tokenizer`] runs end-to-end tokenization: punctuation splitting and wordpiece. It also turns the + word-level bounding boxes into token-level bounding boxes. + + Args: + vocab_file (`str`): + Path to the vocabulary file. + merges_file (`str`): + Path to the merges file. + errors (`str`, *optional*, defaults to `"replace"`): + Paradigm to follow when decoding bytes to UTF-8. See + [bytes.decode](https://docs.python.org/3/library/stdtypes.html#bytes.decode) for more information. + bos_token (`str`, *optional*, defaults to `""`): + The beginning of sequence token that was used during pretraining. Can be used a sequence classifier token. + + + + When building a sequence using special tokens, this is not the token that is used for the beginning of + sequence. The token used is the `cls_token`. + + + + eos_token (`str`, *optional*, defaults to `""`): + The end of sequence token. + + + + When building a sequence using special tokens, this is not the token that is used for the end of sequence. + The token used is the `sep_token`. + + + + sep_token (`str`, *optional*, defaults to `""`): + The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for + sequence classification or for a text and a question for question answering. It is also used as the last + token of a sequence built with special tokens. + cls_token (`str`, *optional*, defaults to `""`): + The classifier token which is used when doing sequence classification (classification of the whole sequence + instead of per-token classification). It is the first token of the sequence when built with special tokens. + unk_token (`str`, *optional*, defaults to `""`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + pad_token (`str`, *optional*, defaults to `""`): + The token used for padding, for example when batching sequences of different lengths. + mask_token (`str`, *optional*, defaults to `""`): + The token used for masking values. This is the token used when training this model with masked language + modeling. This is the token which the model will try to predict. + add_prefix_space (`bool`, *optional*, defaults to `False`): + Whether or not to add an initial space to the input. This allows to treat the leading word just as any + other word. (RoBERTa tokenizer detect beginning of words by the preceding space). + cls_token_box (`List[int]`, *optional*, defaults to `[0, 0, 0, 0]`): + The bounding box to use for the special [CLS] token. + sep_token_box (`List[int]`, *optional*, defaults to `[0, 0, 0, 0]`): + The bounding box to use for the special [SEP] token. + pad_token_box (`List[int]`, *optional*, defaults to `[0, 0, 0, 0]`): + The bounding box to use for the special [PAD] token. + pad_token_label (`int`, *optional*, defaults to -100): + The label to use for padding tokens. Defaults to -100, which is the `ignore_index` of PyTorch's + CrossEntropyLoss. + only_label_first_subword (`bool`, *optional*, defaults to `True`): + Whether or not to only label the first subword, in case word labels are provided. + """ + vocab_files_names = VOCAB_FILES_NAMES + pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP + max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES + model_input_names = ["input_ids", "attention_mask", "bbox"] + + def __init__( + self, + vocab_file, + merges_file, + errors="replace", + bos_token="", + eos_token="", + sep_token="", + cls_token="", + unk_token="", + pad_token="", + mask_token="", + add_prefix_space=True, + cls_token_box=[0, 0, 0, 0], + sep_token_box=[0, 0, 0, 0], + pad_token_box=[0, 0, 0, 0], + pad_token_label=-100, + only_label_first_subword=True, + **kwargs + ): + bos_token = AddedToken(bos_token, lstrip=False, rstrip=False) if isinstance(bos_token, str) else bos_token + eos_token = AddedToken(eos_token, lstrip=False, rstrip=False) if isinstance(eos_token, str) else eos_token + sep_token = AddedToken(sep_token, lstrip=False, rstrip=False) if isinstance(sep_token, str) else sep_token + cls_token = AddedToken(cls_token, lstrip=False, rstrip=False) if isinstance(cls_token, str) else cls_token + unk_token = AddedToken(unk_token, lstrip=False, rstrip=False) if isinstance(unk_token, str) else unk_token + pad_token = AddedToken(pad_token, lstrip=False, rstrip=False) if isinstance(pad_token, str) else pad_token + + # Mask token behave like a normal word, i.e. include the space before it + mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token + + super().__init__( + errors=errors, + bos_token=bos_token, + eos_token=eos_token, + unk_token=unk_token, + sep_token=sep_token, + cls_token=cls_token, + pad_token=pad_token, + mask_token=mask_token, + add_prefix_space=add_prefix_space, + cls_token_box=cls_token_box, + sep_token_box=sep_token_box, + pad_token_box=pad_token_box, + pad_token_label=pad_token_label, + only_label_first_subword=only_label_first_subword, + **kwargs, + ) + + with open(vocab_file, encoding="utf-8") as vocab_handle: + self.encoder = json.load(vocab_handle) + self.decoder = {v: k for k, v in self.encoder.items()} + self.errors = errors # how to handle errors in decoding + self.byte_encoder = bytes_to_unicode() + self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} + with open(merges_file, encoding="utf-8") as merges_handle: + bpe_merges = merges_handle.read().split("\n")[1:-1] + bpe_merges = [tuple(merge.split()) for merge in bpe_merges] + self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges)))) + self.cache = {} + self.add_prefix_space = add_prefix_space + + # Should have added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions + self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""") + + # additional properties + self.cls_token_box = cls_token_box + self.sep_token_box = sep_token_box + self.pad_token_box = pad_token_box + self.pad_token_label = pad_token_label + self.only_label_first_subword = only_label_first_subword + + @property + # Copied from transformers.models.roberta.tokenization_roberta.RobertaTokenizer.vocab_size + def vocab_size(self): + return len(self.encoder) + + # Copied from transformers.models.roberta.tokenization_roberta.RobertaTokenizer.get_vocab + def get_vocab(self): + return dict(self.encoder, **self.added_tokens_encoder) + + # Copied from transformers.models.roberta.tokenization_roberta.RobertaTokenizer.bpe + def bpe(self, token): + if token in self.cache: + return self.cache[token] + word = tuple(token) + pairs = get_pairs(word) + + if not pairs: + return token + + while True: + bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))) + if bigram not in self.bpe_ranks: + break + first, second = bigram + new_word = [] + i = 0 + while i < len(word): + try: + j = word.index(first, i) + except ValueError: + new_word.extend(word[i:]) + break + else: + new_word.extend(word[i:j]) + i = j + + if word[i] == first and i < len(word) - 1 and word[i + 1] == second: + new_word.append(first + second) + i += 2 + else: + new_word.append(word[i]) + i += 1 + new_word = tuple(new_word) + word = new_word + if len(word) == 1: + break + else: + pairs = get_pairs(word) + word = " ".join(word) + self.cache[token] = word + return word + + # Copied from transformers.models.roberta.tokenization_roberta.RobertaTokenizer._tokenize + def _tokenize(self, text): + """Tokenize a string.""" + bpe_tokens = [] + for token in re.findall(self.pat, text): + token = "".join( + self.byte_encoder[b] for b in token.encode("utf-8") + ) # Maps all our bytes to unicode strings, avoiding control tokens of the BPE (spaces in our case) + bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(" ")) + return bpe_tokens + + # Copied from transformers.models.roberta.tokenization_roberta.RobertaTokenizer._convert_token_to_id + def _convert_token_to_id(self, token): + """Converts a token (str) in an id using the vocab.""" + return self.encoder.get(token, self.encoder.get(self.unk_token)) + + # Copied from transformers.models.roberta.tokenization_roberta.RobertaTokenizer._convert_id_to_token + def _convert_id_to_token(self, index): + """Converts an index (integer) in a token (str) using the vocab.""" + return self.decoder.get(index) + + # Copied from transformers.models.roberta.tokenization_roberta.RobertaTokenizer.convert_tokens_to_string + def convert_tokens_to_string(self, tokens): + """Converts a sequence of tokens (string) in a single string.""" + text = "".join(tokens) + text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors=self.errors) + return text + + # Copied from transformers.models.roberta.tokenization_roberta.RobertaTokenizer.save_vocabulary + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: + if not os.path.isdir(save_directory): + logger.error(f"Vocabulary path ({save_directory}) should be a directory") + return + vocab_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] + ) + merge_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["merges_file"] + ) + + with open(vocab_file, "w", encoding="utf-8") as f: + f.write(json.dumps(self.encoder, ensure_ascii=False)) + + index = 0 + with open(merge_file, "w", encoding="utf-8") as writer: + writer.write("#version: 0.2\n") + for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]): + if index != token_index: + logger.warning( + f"Saving vocabulary to {merge_file}: BPE merge indices are not consecutive." + " Please check that the tokenizer is not corrupted!" + ) + index = token_index + writer.write(" ".join(bpe_tokens) + "\n") + index += 1 + + return vocab_file, merge_file + + # Copied from transformers.models.roberta.tokenization_roberta.RobertaTokenizer.build_inputs_with_special_tokens + def build_inputs_with_special_tokens( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and + adding special tokens. A RoBERTa sequence has the following format: + + - single sequence: ` X ` + - pair of sequences: ` A B ` + + Args: + token_ids_0 (`List[int]`): + List of IDs to which the special tokens will be added. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens. + """ + if token_ids_1 is None: + return [self.cls_token_id] + token_ids_0 + [self.sep_token_id] + cls = [self.cls_token_id] + sep = [self.sep_token_id] + return cls + token_ids_0 + sep + sep + token_ids_1 + sep + + # Copied from transformers.models.roberta.tokenization_roberta.RobertaTokenizer.get_special_tokens_mask + def get_special_tokens_mask( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False + ) -> List[int]: + """ + Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding + special tokens using the tokenizer `prepare_for_model` method. + + Args: + token_ids_0 (`List[int]`): + List of IDs. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + already_has_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not the token list is already formatted with special tokens for the model. + + Returns: + `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. + """ + if already_has_special_tokens: + return super().get_special_tokens_mask( + token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True + ) + + if token_ids_1 is None: + return [1] + ([0] * len(token_ids_0)) + [1] + return [1] + ([0] * len(token_ids_0)) + [1, 1] + ([0] * len(token_ids_1)) + [1] + + # Copied from transformers.models.roberta.tokenization_roberta.RobertaTokenizer.create_token_type_ids_from_sequences + def create_token_type_ids_from_sequences( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Create a mask from the two sequences passed to be used in a sequence-pair classification task. RoBERTa does not + make use of token type ids, therefore a list of zeros is returned. + + Args: + token_ids_0 (`List[int]`): + List of IDs. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: List of zeros. + """ + sep = [self.sep_token_id] + cls = [self.cls_token_id] + + if token_ids_1 is None: + return len(cls + token_ids_0 + sep) * [0] + return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0] + + def prepare_for_tokenization(self, text, is_split_into_words=False, **kwargs): + add_prefix_space = kwargs.pop("add_prefix_space", self.add_prefix_space) + # If the text starts with a token that should not be split, no space is added before the text in any case. + # It's necessary to match the fast tokenization + if ( + (is_split_into_words or add_prefix_space) + and (len(text) > 0 and not text[0].isspace()) + and sum([text.startswith(no_split_token) for no_split_token in self.unique_no_split_tokens]) == 0 + ): + text = " " + text + return (text, kwargs) + + @add_end_docstrings(LAYOUTLMV3_ENCODE_KWARGS_DOCSTRING, LAYOUTLMV3_ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING) + # Copied from transformers.models.layoutlmv2.tokenization_layoutlmv2.LayoutLMv2Tokenizer.__call__ + def __call__( + self, + text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]], + text_pair: Optional[Union[PreTokenizedInput, List[PreTokenizedInput]]] = None, + boxes: Union[List[List[int]], List[List[List[int]]]] = None, + word_labels: Optional[Union[List[int], List[List[int]]]] = None, + add_special_tokens: bool = True, + padding: Union[bool, str, PaddingStrategy] = False, + truncation: Union[bool, str, TruncationStrategy] = False, + max_length: Optional[int] = None, + stride: int = 0, + pad_to_multiple_of: Optional[int] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + return_length: bool = False, + verbose: bool = True, + **kwargs + ) -> BatchEncoding: + """ + Main method to tokenize and prepare for the model one or several sequence(s) or one or several pair(s) of + sequences with word-level normalized bounding boxes and optional labels. + + Args: + text (`str`, `List[str]`, `List[List[str]]`): + The sequence or batch of sequences to be encoded. Each sequence can be a string, a list of strings + (words of a single example or questions of a batch of examples) or a list of list of strings (batch of + words). + text_pair (`List[str]`, `List[List[str]]`): + The sequence or batch of sequences to be encoded. Each sequence should be a list of strings + (pretokenized string). + boxes (`List[List[int]]`, `List[List[List[int]]]`): + Word-level bounding boxes. Each bounding box should be normalized to be on a 0-1000 scale. + word_labels (`List[int]`, `List[List[int]]`, *optional*): + Word-level integer labels (for token classification tasks such as FUNSD, CORD). + """ + # Input type checking for clearer error + def _is_valid_text_input(t): + if isinstance(t, str): + # Strings are fine + return True + elif isinstance(t, (list, tuple)): + # List are fine as long as they are... + if len(t) == 0: + # ... empty + return True + elif isinstance(t[0], str): + # ... list of strings + return True + elif isinstance(t[0], (list, tuple)): + # ... list with an empty list or with a list of strings + return len(t[0]) == 0 or isinstance(t[0][0], str) + else: + return False + else: + return False + + if text_pair is not None: + # in case text + text_pair are provided, text = questions, text_pair = words + if not _is_valid_text_input(text): + raise ValueError("text input must of type `str` (single example) or `List[str]` (batch of examples). ") + if not isinstance(text_pair, (list, tuple)): + raise ValueError( + "Words must be of type `List[str]` (single pretokenized example), " + "or `List[List[str]]` (batch of pretokenized examples)." + ) + else: + # in case only text is provided => must be words + if not isinstance(text, (list, tuple)): + raise ValueError( + "Words must be of type `List[str]` (single pretokenized example), " + "or `List[List[str]]` (batch of pretokenized examples)." + ) + + if text_pair is not None: + is_batched = isinstance(text, (list, tuple)) + else: + is_batched = isinstance(text, (list, tuple)) and text and isinstance(text[0], (list, tuple)) + + words = text if text_pair is None else text_pair + if boxes is None: + raise ValueError("You must provide corresponding bounding boxes") + if is_batched: + if len(words) != len(boxes): + raise ValueError("You must provide words and boxes for an equal amount of examples") + for words_example, boxes_example in zip(words, boxes): + if len(words_example) != len(boxes_example): + raise ValueError("You must provide as many words as there are bounding boxes") + else: + if len(words) != len(boxes): + raise ValueError("You must provide as many words as there are bounding boxes") + + if is_batched: + if text_pair is not None and len(text) != len(text_pair): + raise ValueError( + f"batch length of `text`: {len(text)} does not match batch length of `text_pair`:" + f" {len(text_pair)}." + ) + batch_text_or_text_pairs = list(zip(text, text_pair)) if text_pair is not None else text + is_pair = bool(text_pair is not None) + return self.batch_encode_plus( + batch_text_or_text_pairs=batch_text_or_text_pairs, + is_pair=is_pair, + boxes=boxes, + word_labels=word_labels, + add_special_tokens=add_special_tokens, + padding=padding, + truncation=truncation, + max_length=max_length, + stride=stride, + pad_to_multiple_of=pad_to_multiple_of, + return_tensors=return_tensors, + return_token_type_ids=return_token_type_ids, + return_attention_mask=return_attention_mask, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_offsets_mapping=return_offsets_mapping, + return_length=return_length, + verbose=verbose, + **kwargs, + ) + else: + return self.encode_plus( + text=text, + text_pair=text_pair, + boxes=boxes, + word_labels=word_labels, + add_special_tokens=add_special_tokens, + padding=padding, + truncation=truncation, + max_length=max_length, + stride=stride, + pad_to_multiple_of=pad_to_multiple_of, + return_tensors=return_tensors, + return_token_type_ids=return_token_type_ids, + return_attention_mask=return_attention_mask, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_offsets_mapping=return_offsets_mapping, + return_length=return_length, + verbose=verbose, + **kwargs, + ) + + @add_end_docstrings(LAYOUTLMV3_ENCODE_KWARGS_DOCSTRING, LAYOUTLMV3_ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING) + # Copied from transformers.models.layoutlmv2.tokenization_layoutlmv2.LayoutLMv2Tokenizer.batch_encode_plus + def batch_encode_plus( + self, + batch_text_or_text_pairs: Union[ + List[TextInput], + List[TextInputPair], + List[PreTokenizedInput], + ], + is_pair: bool = None, + boxes: Optional[List[List[List[int]]]] = None, + word_labels: Optional[Union[List[int], List[List[int]]]] = None, + add_special_tokens: bool = True, + padding: Union[bool, str, PaddingStrategy] = False, + truncation: Union[bool, str, TruncationStrategy] = False, + max_length: Optional[int] = None, + stride: int = 0, + pad_to_multiple_of: Optional[int] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + return_length: bool = False, + verbose: bool = True, + **kwargs + ) -> BatchEncoding: + + # Backward compatibility for 'truncation_strategy', 'pad_to_max_length' + padding_strategy, truncation_strategy, max_length, kwargs = self._get_padding_truncation_strategies( + padding=padding, + truncation=truncation, + max_length=max_length, + pad_to_multiple_of=pad_to_multiple_of, + verbose=verbose, + **kwargs, + ) + + return self._batch_encode_plus( + batch_text_or_text_pairs=batch_text_or_text_pairs, + is_pair=is_pair, + boxes=boxes, + word_labels=word_labels, + add_special_tokens=add_special_tokens, + padding_strategy=padding_strategy, + truncation_strategy=truncation_strategy, + max_length=max_length, + stride=stride, + pad_to_multiple_of=pad_to_multiple_of, + return_tensors=return_tensors, + return_token_type_ids=return_token_type_ids, + return_attention_mask=return_attention_mask, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_offsets_mapping=return_offsets_mapping, + return_length=return_length, + verbose=verbose, + **kwargs, + ) + + # Copied from transformers.models.layoutlmv2.tokenization_layoutlmv2.LayoutLMv2Tokenizer._batch_encode_plus + def _batch_encode_plus( + self, + batch_text_or_text_pairs: Union[ + List[TextInput], + List[TextInputPair], + List[PreTokenizedInput], + ], + is_pair: bool = None, + boxes: Optional[List[List[List[int]]]] = None, + word_labels: Optional[List[List[int]]] = None, + add_special_tokens: bool = True, + padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, + truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE, + max_length: Optional[int] = None, + stride: int = 0, + pad_to_multiple_of: Optional[int] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + return_length: bool = False, + verbose: bool = True, + **kwargs + ) -> BatchEncoding: + + if return_offsets_mapping: + raise NotImplementedError( + "return_offset_mapping is not available when using Python tokenizers. " + "To use this feature, change your tokenizer to one deriving from " + "transformers.PreTrainedTokenizerFast." + ) + + batch_outputs = self._batch_prepare_for_model( + batch_text_or_text_pairs=batch_text_or_text_pairs, + is_pair=is_pair, + boxes=boxes, + word_labels=word_labels, + add_special_tokens=add_special_tokens, + padding_strategy=padding_strategy, + truncation_strategy=truncation_strategy, + max_length=max_length, + stride=stride, + pad_to_multiple_of=pad_to_multiple_of, + return_attention_mask=return_attention_mask, + return_token_type_ids=return_token_type_ids, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_length=return_length, + return_tensors=return_tensors, + verbose=verbose, + ) + + return BatchEncoding(batch_outputs) + + @add_end_docstrings(LAYOUTLMV3_ENCODE_KWARGS_DOCSTRING, LAYOUTLMV3_ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING) + # Copied from transformers.models.layoutlmv2.tokenization_layoutlmv2.LayoutLMv2Tokenizer._batch_prepare_for_model + def _batch_prepare_for_model( + self, + batch_text_or_text_pairs, + is_pair: bool = None, + boxes: Optional[List[List[int]]] = None, + word_labels: Optional[List[List[int]]] = None, + add_special_tokens: bool = True, + padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, + truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE, + max_length: Optional[int] = None, + stride: int = 0, + pad_to_multiple_of: Optional[int] = None, + return_tensors: Optional[str] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_length: bool = False, + verbose: bool = True, + ) -> BatchEncoding: + """ + Prepares a sequence of input id, or a pair of sequences of inputs ids so that it can be used by the model. It + adds special tokens, truncates sequences if overflowing while taking into account the special tokens and + manages a moving window (with user defined stride) for overflowing tokens. + + Args: + batch_ids_pairs: list of tokenized input ids or input ids pairs + """ + + batch_outputs = {} + for idx, example in enumerate(zip(batch_text_or_text_pairs, boxes)): + batch_text_or_text_pair, boxes_example = example + outputs = self.prepare_for_model( + batch_text_or_text_pair[0] if is_pair else batch_text_or_text_pair, + batch_text_or_text_pair[1] if is_pair else None, + boxes_example, + word_labels=word_labels[idx] if word_labels is not None else None, + add_special_tokens=add_special_tokens, + padding=PaddingStrategy.DO_NOT_PAD.value, # we pad in batch afterward + truncation=truncation_strategy.value, + max_length=max_length, + stride=stride, + pad_to_multiple_of=None, # we pad in batch afterward + return_attention_mask=False, # we pad in batch afterward + return_token_type_ids=return_token_type_ids, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_length=return_length, + return_tensors=None, # We convert the whole batch to tensors at the end + prepend_batch_axis=False, + verbose=verbose, + ) + + for key, value in outputs.items(): + if key not in batch_outputs: + batch_outputs[key] = [] + batch_outputs[key].append(value) + + batch_outputs = self.pad( + batch_outputs, + padding=padding_strategy.value, + max_length=max_length, + pad_to_multiple_of=pad_to_multiple_of, + return_attention_mask=return_attention_mask, + ) + + batch_outputs = BatchEncoding(batch_outputs, tensor_type=return_tensors) + + return batch_outputs + + @add_end_docstrings(LAYOUTLMV3_ENCODE_KWARGS_DOCSTRING) + # Copied from transformers.models.layoutlmv2.tokenization_layoutlmv2.LayoutLMv2Tokenizer.encode + def encode( + self, + text: Union[TextInput, PreTokenizedInput], + text_pair: Optional[PreTokenizedInput] = None, + boxes: Optional[List[List[int]]] = None, + word_labels: Optional[List[int]] = None, + add_special_tokens: bool = True, + padding: Union[bool, str, PaddingStrategy] = False, + truncation: Union[bool, str, TruncationStrategy] = False, + max_length: Optional[int] = None, + stride: int = 0, + pad_to_multiple_of: Optional[int] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + return_length: bool = False, + verbose: bool = True, + **kwargs + ) -> List[int]: + encoded_inputs = self.encode_plus( + text=text, + text_pair=text_pair, + boxes=boxes, + word_labels=word_labels, + add_special_tokens=add_special_tokens, + padding=padding, + truncation=truncation, + max_length=max_length, + stride=stride, + pad_to_multiple_of=pad_to_multiple_of, + return_tensors=return_tensors, + return_token_type_ids=return_token_type_ids, + return_attention_mask=return_attention_mask, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_offsets_mapping=return_offsets_mapping, + return_length=return_length, + verbose=verbose, + **kwargs, + ) + + return encoded_inputs["input_ids"] + + @add_end_docstrings(LAYOUTLMV3_ENCODE_KWARGS_DOCSTRING, LAYOUTLMV3_ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING) + # Copied from transformers.models.layoutlmv2.tokenization_layoutlmv2.LayoutLMv2Tokenizer.encode_plus + def encode_plus( + self, + text: Union[TextInput, PreTokenizedInput], + text_pair: Optional[PreTokenizedInput] = None, + boxes: Optional[List[List[int]]] = None, + word_labels: Optional[List[int]] = None, + add_special_tokens: bool = True, + padding: Union[bool, str, PaddingStrategy] = False, + truncation: Union[bool, str, TruncationStrategy] = False, + max_length: Optional[int] = None, + stride: int = 0, + pad_to_multiple_of: Optional[int] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + return_length: bool = False, + verbose: bool = True, + **kwargs + ) -> BatchEncoding: + """ + Tokenize and prepare for the model a sequence or a pair of sequences. .. warning:: This method is deprecated, + `__call__` should be used instead. + + Args: + text (`str`, `List[str]`, `List[List[str]]`): + The first sequence to be encoded. This can be a string, a list of strings or a list of list of strings. + text_pair (`List[str]` or `List[int]`, *optional*): + Optional second sequence to be encoded. This can be a list of strings (words of a single example) or a + list of list of strings (words of a batch of examples). + """ + + # Backward compatibility for 'truncation_strategy', 'pad_to_max_length' + padding_strategy, truncation_strategy, max_length, kwargs = self._get_padding_truncation_strategies( + padding=padding, + truncation=truncation, + max_length=max_length, + pad_to_multiple_of=pad_to_multiple_of, + verbose=verbose, + **kwargs, + ) + + return self._encode_plus( + text=text, + boxes=boxes, + text_pair=text_pair, + word_labels=word_labels, + add_special_tokens=add_special_tokens, + padding_strategy=padding_strategy, + truncation_strategy=truncation_strategy, + max_length=max_length, + stride=stride, + pad_to_multiple_of=pad_to_multiple_of, + return_tensors=return_tensors, + return_token_type_ids=return_token_type_ids, + return_attention_mask=return_attention_mask, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_offsets_mapping=return_offsets_mapping, + return_length=return_length, + verbose=verbose, + **kwargs, + ) + + # Copied from transformers.models.layoutlmv2.tokenization_layoutlmv2.LayoutLMv2Tokenizer._encode_plus + def _encode_plus( + self, + text: Union[TextInput, PreTokenizedInput], + text_pair: Optional[PreTokenizedInput] = None, + boxes: Optional[List[List[int]]] = None, + word_labels: Optional[List[int]] = None, + add_special_tokens: bool = True, + padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, + truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE, + max_length: Optional[int] = None, + stride: int = 0, + pad_to_multiple_of: Optional[int] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + return_length: bool = False, + verbose: bool = True, + **kwargs + ) -> BatchEncoding: + if return_offsets_mapping: + raise NotImplementedError( + "return_offset_mapping is not available when using Python tokenizers. " + "To use this feature, change your tokenizer to one deriving from " + "transformers.PreTrainedTokenizerFast. " + "More information on available tokenizers at " + "https://github.com/huggingface/transformers/pull/2674" + ) + + return self.prepare_for_model( + text=text, + text_pair=text_pair, + boxes=boxes, + word_labels=word_labels, + add_special_tokens=add_special_tokens, + padding=padding_strategy.value, + truncation=truncation_strategy.value, + max_length=max_length, + stride=stride, + pad_to_multiple_of=pad_to_multiple_of, + return_tensors=return_tensors, + prepend_batch_axis=True, + return_attention_mask=return_attention_mask, + return_token_type_ids=return_token_type_ids, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_length=return_length, + verbose=verbose, + ) + + @add_end_docstrings(LAYOUTLMV3_ENCODE_KWARGS_DOCSTRING, LAYOUTLMV3_ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING) + def prepare_for_model( + self, + text: Union[TextInput, PreTokenizedInput], + text_pair: Optional[PreTokenizedInput] = None, + boxes: Optional[List[List[int]]] = None, + word_labels: Optional[List[int]] = None, + add_special_tokens: bool = True, + padding: Union[bool, str, PaddingStrategy] = False, + truncation: Union[bool, str, TruncationStrategy] = False, + max_length: Optional[int] = None, + stride: int = 0, + pad_to_multiple_of: Optional[int] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + return_length: bool = False, + verbose: bool = True, + prepend_batch_axis: bool = False, + **kwargs + ) -> BatchEncoding: + """ + Prepares a sequence or a pair of sequences so that it can be used by the model. It adds special tokens, + truncates sequences if overflowing while taking into account the special tokens and manages a moving window + (with user defined stride) for overflowing tokens. Please Note, for *text_pair* different than `None` and + *truncation_strategy = longest_first* or `True`, it is not possible to return overflowing tokens. Such a + combination of arguments will raise an error. + + Word-level `boxes` are turned into token-level `bbox`. If provided, word-level `word_labels` are turned into + token-level `labels`. The word label is used for the first token of the word, while remaining tokens are + labeled with -100, such that they will be ignored by the loss function. + + Args: + text (`str`, `List[str]`, `List[List[str]]`): + The first sequence to be encoded. This can be a string, a list of strings or a list of list of strings. + text_pair (`List[str]` or `List[int]`, *optional*): + Optional second sequence to be encoded. This can be a list of strings (words of a single example) or a + list of list of strings (words of a batch of examples). + """ + + # Backward compatibility for 'truncation_strategy', 'pad_to_max_length' + padding_strategy, truncation_strategy, max_length, kwargs = self._get_padding_truncation_strategies( + padding=padding, + truncation=truncation, + max_length=max_length, + pad_to_multiple_of=pad_to_multiple_of, + verbose=verbose, + **kwargs, + ) + + tokens = [] + pair_tokens = [] + token_boxes = [] + pair_token_boxes = [] + labels = [] + + if text_pair is None: + if word_labels is None: + # CASE 1: document image classification (training + inference) + CASE 2: token classification (inference) + for word, box in zip(text, boxes): + if len(word) < 1: # skip empty words + continue + word_tokens = self.tokenize(word) + tokens.extend(word_tokens) + token_boxes.extend([box] * len(word_tokens)) + else: + # CASE 2: token classification (training) + for word, box, label in zip(text, boxes, word_labels): + if len(word) < 1: # skip empty words + continue + word_tokens = self.tokenize(word) + tokens.extend(word_tokens) + token_boxes.extend([box] * len(word_tokens)) + if self.only_label_first_subword: + # Use the real label id for the first token of the word, and padding ids for the remaining tokens + labels.extend([label] + [self.pad_token_label] * (len(word_tokens) - 1)) + else: + labels.extend([label] * len(word_tokens)) + else: + # CASE 3: document visual question answering (inference) + # text = question + # text_pair = words + tokens = self.tokenize(text) + token_boxes = [self.pad_token_box for _ in range(len(tokens))] + + for word, box in zip(text_pair, boxes): + if len(word) < 1: # skip empty words + continue + word_tokens = self.tokenize(word) + pair_tokens.extend(word_tokens) + pair_token_boxes.extend([box] * len(word_tokens)) + + # Create ids + pair_ids + ids = self.convert_tokens_to_ids(tokens) + pair_ids = self.convert_tokens_to_ids(pair_tokens) if pair_tokens else None + + if ( + return_overflowing_tokens + and truncation_strategy == TruncationStrategy.LONGEST_FIRST + and pair_ids is not None + ): + raise ValueError( + "Not possible to return overflowing tokens for pair of sequences with the " + "`longest_first`. Please select another truncation strategy than `longest_first`, " + "for instance `only_second` or `only_first`." + ) + + # Compute the total size of the returned encodings + pair = bool(pair_ids is not None) + len_ids = len(ids) + len_pair_ids = len(pair_ids) if pair else 0 + total_len = len_ids + len_pair_ids + (self.num_special_tokens_to_add(pair=pair) if add_special_tokens else 0) + + # Truncation: Handle max sequence length + overflowing_tokens = [] + overflowing_token_boxes = [] + overflowing_labels = [] + if truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE and max_length and total_len > max_length: + ( + ids, + token_boxes, + pair_ids, + pair_token_boxes, + labels, + overflowing_tokens, + overflowing_token_boxes, + overflowing_labels, + ) = self.truncate_sequences( + ids, + token_boxes, + pair_ids=pair_ids, + pair_token_boxes=pair_token_boxes, + labels=labels, + num_tokens_to_remove=total_len - max_length, + truncation_strategy=truncation_strategy, + stride=stride, + ) + + if return_token_type_ids and not add_special_tokens: + raise ValueError( + "Asking to return token_type_ids while setting add_special_tokens to False " + "results in an undefined behavior. Please set add_special_tokens to True or " + "set return_token_type_ids to None." + ) + + # Load from model defaults + if return_token_type_ids is None: + return_token_type_ids = "token_type_ids" in self.model_input_names + if return_attention_mask is None: + return_attention_mask = "attention_mask" in self.model_input_names + + encoded_inputs = {} + + if return_overflowing_tokens: + encoded_inputs["overflowing_tokens"] = overflowing_tokens + encoded_inputs["overflowing_token_boxes"] = overflowing_token_boxes + encoded_inputs["overflowing_labels"] = overflowing_labels + encoded_inputs["num_truncated_tokens"] = total_len - max_length + + # Add special tokens + if add_special_tokens: + sequence = self.build_inputs_with_special_tokens(ids, pair_ids) + token_type_ids = self.create_token_type_ids_from_sequences(ids, pair_ids) + token_boxes = [self.cls_token_box] + token_boxes + [self.sep_token_box] + if pair_token_boxes: + pair_token_boxes = [self.sep_token_box] + pair_token_boxes + [self.sep_token_box] + token_boxes = token_boxes + pair_token_boxes if pair else token_boxes + if labels: + labels = [self.pad_token_label] + labels + [self.pad_token_label] + else: + sequence = ids + pair_ids if pair else ids + token_type_ids = [0] * len(ids) + ([0] * len(pair_ids) if pair else []) + token_boxes = token_boxes + pair_token_boxes if pair else token_boxes + + # Build output dictionary + encoded_inputs["input_ids"] = sequence + encoded_inputs["bbox"] = token_boxes + if return_token_type_ids: + encoded_inputs["token_type_ids"] = token_type_ids + if return_special_tokens_mask: + if add_special_tokens: + encoded_inputs["special_tokens_mask"] = self.get_special_tokens_mask(ids, pair_ids) + else: + encoded_inputs["special_tokens_mask"] = [0] * len(sequence) + + if labels: + encoded_inputs["labels"] = labels + + # Check lengths + self._eventual_warn_about_too_long_sequence(encoded_inputs["input_ids"], max_length, verbose) + + # Padding + if padding_strategy != PaddingStrategy.DO_NOT_PAD or return_attention_mask: + encoded_inputs = self.pad( + encoded_inputs, + max_length=max_length, + padding=padding_strategy.value, + pad_to_multiple_of=pad_to_multiple_of, + return_attention_mask=return_attention_mask, + ) + + if return_length: + encoded_inputs["length"] = len(encoded_inputs["input_ids"]) + + batch_outputs = BatchEncoding( + encoded_inputs, tensor_type=return_tensors, prepend_batch_axis=prepend_batch_axis + ) + + return batch_outputs + + # Copied from transformers.models.layoutlmv2.tokenization_layoutlmv2.LayoutLMv2Tokenizer.truncate_sequences + def truncate_sequences( + self, + ids: List[int], + token_boxes: List[List[int]], + pair_ids: Optional[List[int]] = None, + pair_token_boxes: Optional[List[List[int]]] = None, + labels: Optional[List[int]] = None, + num_tokens_to_remove: int = 0, + truncation_strategy: Union[str, TruncationStrategy] = "longest_first", + stride: int = 0, + ) -> Tuple[List[int], List[int], List[int]]: + """ + Truncates a sequence pair in-place following the strategy. + + Args: + ids (`List[int]`): + Tokenized input ids of the first sequence. Can be obtained from a string by chaining the `tokenize` and + `convert_tokens_to_ids` methods. + token_boxes (`List[List[int]]`): + Bounding boxes of the first sequence. + pair_ids (`List[int]`, *optional*): + Tokenized input ids of the second sequence. Can be obtained from a string by chaining the `tokenize` + and `convert_tokens_to_ids` methods. + pair_token_boxes (`List[List[int]]`, *optional*): + Bounding boxes of the second sequence. + labels (`List[int]`, *optional*): + Labels of the first sequence (for token classification tasks). + num_tokens_to_remove (`int`, *optional*, defaults to 0): + Number of tokens to remove using the truncation strategy. + truncation_strategy (`str` or [`~tokenization_utils_base.TruncationStrategy`], *optional*, defaults to `False`): + The strategy to follow for truncation. Can be: + + - `'longest_first'`: Truncate to a maximum length specified with the argument `max_length` or to the + maximum acceptable input length for the model if that argument is not provided. This will truncate + token by token, removing a token from the longest sequence in the pair if a pair of sequences (or a + batch of pairs) is provided. + - `'only_first'`: Truncate to a maximum length specified with the argument `max_length` or to the + maximum acceptable input length for the model if that argument is not provided. This will only + truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided. + - `'only_second'`: Truncate to a maximum length specified with the argument `max_length` or to the + maximum acceptable input length for the model if that argument is not provided. This will only + truncate the second sequence of a pair if a pair of sequences (or a batch of pairs) is provided. + - `'do_not_truncate'` (default): No truncation (i.e., can output batch with sequence lengths greater + than the model maximum admissible input size). + stride (`int`, *optional*, defaults to 0): + If set to a positive number, the overflowing tokens returned will contain some tokens from the main + sequence returned. The value of this argument defines the number of additional tokens. + + Returns: + `Tuple[List[int], List[int], List[int]]`: The truncated `ids`, the truncated `pair_ids` and the list of + overflowing tokens. Note: The *longest_first* strategy returns empty list of overflowing tokens if a pair + of sequences (or a batch of pairs) is provided. + """ + if num_tokens_to_remove <= 0: + return ids, token_boxes, pair_ids, pair_token_boxes, labels, [], [], [] + + if not isinstance(truncation_strategy, TruncationStrategy): + truncation_strategy = TruncationStrategy(truncation_strategy) + + overflowing_tokens = [] + overflowing_token_boxes = [] + overflowing_labels = [] + if truncation_strategy == TruncationStrategy.ONLY_FIRST or ( + truncation_strategy == TruncationStrategy.LONGEST_FIRST and pair_ids is None + ): + if len(ids) > num_tokens_to_remove: + window_len = min(len(ids), stride + num_tokens_to_remove) + overflowing_tokens = ids[-window_len:] + overflowing_token_boxes = token_boxes[-window_len:] + overflowing_labels = labels[-window_len:] + ids = ids[:-num_tokens_to_remove] + token_boxes = token_boxes[:-num_tokens_to_remove] + labels = labels[:-num_tokens_to_remove] + else: + error_msg = ( + f"We need to remove {num_tokens_to_remove} to truncate the input " + f"but the first sequence has a length {len(ids)}. " + ) + if truncation_strategy == TruncationStrategy.ONLY_FIRST: + error_msg = ( + error_msg + + "Please select another truncation strategy than " + f"{truncation_strategy}, for instance 'longest_first' or 'only_second'." + ) + logger.error(error_msg) + elif truncation_strategy == TruncationStrategy.LONGEST_FIRST: + logger.warning( + "Be aware, overflowing tokens are not returned for the setting you have chosen," + f" i.e. sequence pairs with the '{TruncationStrategy.LONGEST_FIRST.value}' " + "truncation strategy. So the returned list will always be empty even if some " + "tokens have been removed." + ) + for _ in range(num_tokens_to_remove): + if pair_ids is None or len(ids) > len(pair_ids): + ids = ids[:-1] + token_boxes = token_boxes[:-1] + labels = labels[:-1] + else: + pair_ids = pair_ids[:-1] + pair_token_boxes = pair_token_boxes[:-1] + elif truncation_strategy == TruncationStrategy.ONLY_SECOND and pair_ids is not None: + if len(pair_ids) > num_tokens_to_remove: + window_len = min(len(pair_ids), stride + num_tokens_to_remove) + overflowing_tokens = pair_ids[-window_len:] + overflowing_token_boxes = pair_token_boxes[-window_len:] + pair_ids = pair_ids[:-num_tokens_to_remove] + pair_token_boxes = pair_token_boxes[:-num_tokens_to_remove] + else: + logger.error( + f"We need to remove {num_tokens_to_remove} to truncate the input " + f"but the second sequence has a length {len(pair_ids)}. " + f"Please select another truncation strategy than {truncation_strategy}, " + "for instance 'longest_first' or 'only_first'." + ) + + return ( + ids, + token_boxes, + pair_ids, + pair_token_boxes, + labels, + overflowing_tokens, + overflowing_token_boxes, + overflowing_labels, + ) + + # Copied from transformers.models.layoutlmv2.tokenization_layoutlmv2.LayoutLMv2Tokenizer._pad + def _pad( + self, + encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding], + max_length: Optional[int] = None, + padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, + pad_to_multiple_of: Optional[int] = None, + return_attention_mask: Optional[bool] = None, + ) -> dict: + """ + Pad encoded inputs (on left/right and up to predefined length or max length in the batch) + + Args: + encoded_inputs: + Dictionary of tokenized inputs (`List[int]`) or batch of tokenized inputs (`List[List[int]]`). + max_length: maximum length of the returned list and optionally padding length (see below). + Will truncate by taking into account the special tokens. + padding_strategy: PaddingStrategy to use for padding. + + - PaddingStrategy.LONGEST Pad to the longest sequence in the batch + - PaddingStrategy.MAX_LENGTH: Pad to the max length (default) + - PaddingStrategy.DO_NOT_PAD: Do not pad + The tokenizer padding sides are defined in self.padding_side: + + - 'left': pads on the left of the sequences + - 'right': pads on the right of the sequences + pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value. + This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability + >= 7.5 (Volta). + return_attention_mask: + (optional) Set to False to avoid returning attention mask (default: set to model specifics) + """ + # Load from model defaults + if return_attention_mask is None: + return_attention_mask = "attention_mask" in self.model_input_names + + required_input = encoded_inputs[self.model_input_names[0]] + + if padding_strategy == PaddingStrategy.LONGEST: + max_length = len(required_input) + + if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0): + max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of + + needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length + + # Initialize attention mask if not present. + if return_attention_mask and "attention_mask" not in encoded_inputs: + encoded_inputs["attention_mask"] = [1] * len(required_input) + + if needs_to_be_padded: + difference = max_length - len(required_input) + if self.padding_side == "right": + if return_attention_mask: + encoded_inputs["attention_mask"] = encoded_inputs["attention_mask"] + [0] * difference + if "token_type_ids" in encoded_inputs: + encoded_inputs["token_type_ids"] = ( + encoded_inputs["token_type_ids"] + [self.pad_token_type_id] * difference + ) + if "bbox" in encoded_inputs: + encoded_inputs["bbox"] = encoded_inputs["bbox"] + [self.pad_token_box] * difference + if "labels" in encoded_inputs: + encoded_inputs["labels"] = encoded_inputs["labels"] + [self.pad_token_label] * difference + if "special_tokens_mask" in encoded_inputs: + encoded_inputs["special_tokens_mask"] = encoded_inputs["special_tokens_mask"] + [1] * difference + encoded_inputs[self.model_input_names[0]] = required_input + [self.pad_token_id] * difference + elif self.padding_side == "left": + if return_attention_mask: + encoded_inputs["attention_mask"] = [0] * difference + encoded_inputs["attention_mask"] + if "token_type_ids" in encoded_inputs: + encoded_inputs["token_type_ids"] = [self.pad_token_type_id] * difference + encoded_inputs[ + "token_type_ids" + ] + if "bbox" in encoded_inputs: + encoded_inputs["bbox"] = [self.pad_token_box] * difference + encoded_inputs["bbox"] + if "labels" in encoded_inputs: + encoded_inputs["labels"] = [self.pad_token_label] * difference + encoded_inputs["labels"] + if "special_tokens_mask" in encoded_inputs: + encoded_inputs["special_tokens_mask"] = [1] * difference + encoded_inputs["special_tokens_mask"] + encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input + else: + raise ValueError("Invalid padding strategy:" + str(self.padding_side)) + + return encoded_inputs diff --git a/src/transformers/models/layoutlmv3/tokenization_layoutlmv3_fast.py b/src/transformers/models/layoutlmv3/tokenization_layoutlmv3_fast.py new file mode 100644 index 0000000000000..be5f938dbf17c --- /dev/null +++ b/src/transformers/models/layoutlmv3/tokenization_layoutlmv3_fast.py @@ -0,0 +1,853 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Fast tokenization class for LayoutLMv3. It overwrites 2 methods of the slow tokenizer class, namely _batch_encode_plus +and _encode_plus, in which the Rust tokenizer is used. +""" + +import json +from typing import Dict, List, Optional, Tuple, Union + +from tokenizers import pre_tokenizers, processors + +from ...tokenization_utils_base import ( + BatchEncoding, + EncodedInput, + PaddingStrategy, + PreTokenizedInput, + TensorType, + TextInput, + TextInputPair, + TruncationStrategy, +) +from ...tokenization_utils_fast import PreTrainedTokenizerFast +from ...utils import add_end_docstrings, logging +from .tokenization_layoutlmv3 import ( + LAYOUTLMV3_ENCODE_KWARGS_DOCSTRING, + LAYOUTLMV3_ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING, + LayoutLMv3Tokenizer, +) + + +logger = logging.get_logger(__name__) + +VOCAB_FILES_NAMES = {"vocab_file": "vocab.json", "merges_file": "merges.txt", "tokenizer_file": "tokenizer.json"} + +PRETRAINED_VOCAB_FILES_MAP = { + "vocab_file": { + "microsoft/layoutlmv3-base": "https://huggingface.co/microsoft/layoutlmv3-base/raw/main/vocab.json", + "microsoft/layoutlmv3-large": "https://huggingface.co/microsoft/layoutlmv3-large/raw/main/vocab.json", + }, + "merges_file": { + "microsoft/layoutlmv3-base": "https://huggingface.co/microsoft/layoutlmv3-base/raw/main/merges.txt", + "microsoft/layoutlmv3-large": "https://huggingface.co/microsoft/layoutlmv3-large/raw/main/merges.txt", + }, +} + +PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { + "microsoft/layoutlmv3-base": 512, + "microsoft/layoutlmv3-large": 512, +} + + +class LayoutLMv3TokenizerFast(PreTrainedTokenizerFast): + r""" + Construct a "fast" LayoutLMv3 tokenizer (backed by HuggingFace's *tokenizers* library). Based on BPE. + + This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should + refer to this superclass for more information regarding those methods. + + Args: + vocab_file (`str`): + Path to the vocabulary file. + merges_file (`str`): + Path to the merges file. + errors (`str`, *optional*, defaults to `"replace"`): + Paradigm to follow when decoding bytes to UTF-8. See + [bytes.decode](https://docs.python.org/3/library/stdtypes.html#bytes.decode) for more information. + bos_token (`str`, *optional*, defaults to `""`): + The beginning of sequence token that was used during pretraining. Can be used a sequence classifier token. + + + + When building a sequence using special tokens, this is not the token that is used for the beginning of + sequence. The token used is the `cls_token`. + + + + eos_token (`str`, *optional*, defaults to `""`): + The end of sequence token. + + + + When building a sequence using special tokens, this is not the token that is used for the end of sequence. + The token used is the `sep_token`. + + + + sep_token (`str`, *optional*, defaults to `""`): + The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for + sequence classification or for a text and a question for question answering. It is also used as the last + token of a sequence built with special tokens. + cls_token (`str`, *optional*, defaults to `""`): + The classifier token which is used when doing sequence classification (classification of the whole sequence + instead of per-token classification). It is the first token of the sequence when built with special tokens. + unk_token (`str`, *optional*, defaults to `""`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + pad_token (`str`, *optional*, defaults to `""`): + The token used for padding, for example when batching sequences of different lengths. + mask_token (`str`, *optional*, defaults to `""`): + The token used for masking values. This is the token used when training this model with masked language + modeling. This is the token which the model will try to predict. + add_prefix_space (`bool`, *optional*, defaults to `False`): + Whether or not to add an initial space to the input. This allows to treat the leading word just as any + other word. (RoBERTa tokenizer detect beginning of words by the preceding space). + trim_offsets (`bool`, *optional*, defaults to `True`): + Whether the post processing step should trim offsets to avoid including whitespaces. + cls_token_box (`List[int]`, *optional*, defaults to `[0, 0, 0, 0]`): + The bounding box to use for the special [CLS] token. + sep_token_box (`List[int]`, *optional*, defaults to `[0, 0, 0, 0]`): + The bounding box to use for the special [SEP] token. + pad_token_box (`List[int]`, *optional*, defaults to `[0, 0, 0, 0]`): + The bounding box to use for the special [PAD] token. + pad_token_label (`int`, *optional*, defaults to -100): + The label to use for padding tokens. Defaults to -100, which is the `ignore_index` of PyTorch's + CrossEntropyLoss. + only_label_first_subword (`bool`, *optional*, defaults to `True`): + Whether or not to only label the first subword, in case word labels are provided. + """ + + vocab_files_names = VOCAB_FILES_NAMES + pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP + max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES + model_input_names = ["input_ids", "attention_mask"] + slow_tokenizer_class = LayoutLMv3Tokenizer + + def __init__( + self, + vocab_file=None, + merges_file=None, + tokenizer_file=None, + errors="replace", + bos_token="", + eos_token="", + sep_token="", + cls_token="", + unk_token="", + pad_token="", + mask_token="", + add_prefix_space=True, + trim_offsets=True, + cls_token_box=[0, 0, 0, 0], + sep_token_box=[0, 0, 0, 0], + pad_token_box=[0, 0, 0, 0], + pad_token_label=-100, + only_label_first_subword=True, + **kwargs + ): + super().__init__( + vocab_file, + merges_file, + tokenizer_file=tokenizer_file, + errors=errors, + bos_token=bos_token, + eos_token=eos_token, + sep_token=sep_token, + cls_token=cls_token, + unk_token=unk_token, + pad_token=pad_token, + mask_token=mask_token, + add_prefix_space=add_prefix_space, + trim_offsets=trim_offsets, + cls_token_box=cls_token_box, + sep_token_box=sep_token_box, + pad_token_box=pad_token_box, + pad_token_label=pad_token_label, + only_label_first_subword=only_label_first_subword, + **kwargs, + ) + + pre_tok_state = json.loads(self.backend_tokenizer.pre_tokenizer.__getstate__()) + if pre_tok_state.get("add_prefix_space", add_prefix_space) != add_prefix_space: + pre_tok_class = getattr(pre_tokenizers, pre_tok_state.pop("type")) + pre_tok_state["add_prefix_space"] = add_prefix_space + self.backend_tokenizer.pre_tokenizer = pre_tok_class(**pre_tok_state) + + self.add_prefix_space = add_prefix_space + + tokenizer_component = "post_processor" + tokenizer_component_instance = getattr(self.backend_tokenizer, tokenizer_component, None) + if tokenizer_component_instance: + state = json.loads(tokenizer_component_instance.__getstate__()) + + # The lists 'sep' and 'cls' must be cased in tuples for the object `post_processor_class` + if "sep" in state: + state["sep"] = tuple(state["sep"]) + if "cls" in state: + state["cls"] = tuple(state["cls"]) + + changes_to_apply = False + + if state.get("add_prefix_space", add_prefix_space) != add_prefix_space: + state["add_prefix_space"] = add_prefix_space + changes_to_apply = True + + if state.get("trim_offsets", trim_offsets) != trim_offsets: + state["trim_offsets"] = trim_offsets + changes_to_apply = True + + if changes_to_apply: + component_class = getattr(processors, state.pop("type")) + new_value = component_class(**state) + setattr(self.backend_tokenizer, tokenizer_component, new_value) + + # additional properties + self.cls_token_box = cls_token_box + self.sep_token_box = sep_token_box + self.pad_token_box = pad_token_box + self.pad_token_label = pad_token_label + self.only_label_first_subword = only_label_first_subword + + @add_end_docstrings(LAYOUTLMV3_ENCODE_KWARGS_DOCSTRING, LAYOUTLMV3_ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING) + # Copied from transformers.models.layoutlmv2.tokenization_layoutlmv2_fast.LayoutLMv2TokenizerFast.__call__ + def __call__( + self, + text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]], + text_pair: Optional[Union[PreTokenizedInput, List[PreTokenizedInput]]] = None, + boxes: Union[List[List[int]], List[List[List[int]]]] = None, + word_labels: Optional[Union[List[int], List[List[int]]]] = None, + add_special_tokens: bool = True, + padding: Union[bool, str, PaddingStrategy] = False, + truncation: Union[bool, str, TruncationStrategy] = False, + max_length: Optional[int] = None, + stride: int = 0, + pad_to_multiple_of: Optional[int] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + return_length: bool = False, + verbose: bool = True, + **kwargs + ) -> BatchEncoding: + """ + Main method to tokenize and prepare for the model one or several sequence(s) or one or several pair(s) of + sequences with word-level normalized bounding boxes and optional labels. + + Args: + text (`str`, `List[str]`, `List[List[str]]`): + The sequence or batch of sequences to be encoded. Each sequence can be a string, a list of strings + (words of a single example or questions of a batch of examples) or a list of list of strings (batch of + words). + text_pair (`List[str]`, `List[List[str]]`): + The sequence or batch of sequences to be encoded. Each sequence should be a list of strings + (pretokenized string). + boxes (`List[List[int]]`, `List[List[List[int]]]`): + Word-level bounding boxes. Each bounding box should be normalized to be on a 0-1000 scale. + word_labels (`List[int]`, `List[List[int]]`, *optional*): + Word-level integer labels (for token classification tasks such as FUNSD, CORD). + """ + # Input type checking for clearer error + def _is_valid_text_input(t): + if isinstance(t, str): + # Strings are fine + return True + elif isinstance(t, (list, tuple)): + # List are fine as long as they are... + if len(t) == 0: + # ... empty + return True + elif isinstance(t[0], str): + # ... list of strings + return True + elif isinstance(t[0], (list, tuple)): + # ... list with an empty list or with a list of strings + return len(t[0]) == 0 or isinstance(t[0][0], str) + else: + return False + else: + return False + + if text_pair is not None: + # in case text + text_pair are provided, text = questions, text_pair = words + if not _is_valid_text_input(text): + raise ValueError("text input must of type `str` (single example) or `List[str]` (batch of examples). ") + if not isinstance(text_pair, (list, tuple)): + raise ValueError( + "Words must be of type `List[str]` (single pretokenized example), " + "or `List[List[str]]` (batch of pretokenized examples)." + ) + else: + # in case only text is provided => must be words + if not isinstance(text, (list, tuple)): + raise ValueError( + "Words must be of type `List[str]` (single pretokenized example), " + "or `List[List[str]]` (batch of pretokenized examples)." + ) + + if text_pair is not None: + is_batched = isinstance(text, (list, tuple)) + else: + is_batched = isinstance(text, (list, tuple)) and text and isinstance(text[0], (list, tuple)) + + words = text if text_pair is None else text_pair + if boxes is None: + raise ValueError("You must provide corresponding bounding boxes") + if is_batched: + if len(words) != len(boxes): + raise ValueError("You must provide words and boxes for an equal amount of examples") + for words_example, boxes_example in zip(words, boxes): + if len(words_example) != len(boxes_example): + raise ValueError("You must provide as many words as there are bounding boxes") + else: + if len(words) != len(boxes): + raise ValueError("You must provide as many words as there are bounding boxes") + + if is_batched: + if text_pair is not None and len(text) != len(text_pair): + raise ValueError( + f"batch length of `text`: {len(text)} does not match batch length of `text_pair`:" + f" {len(text_pair)}." + ) + batch_text_or_text_pairs = list(zip(text, text_pair)) if text_pair is not None else text + is_pair = bool(text_pair is not None) + return self.batch_encode_plus( + batch_text_or_text_pairs=batch_text_or_text_pairs, + is_pair=is_pair, + boxes=boxes, + word_labels=word_labels, + add_special_tokens=add_special_tokens, + padding=padding, + truncation=truncation, + max_length=max_length, + stride=stride, + pad_to_multiple_of=pad_to_multiple_of, + return_tensors=return_tensors, + return_token_type_ids=return_token_type_ids, + return_attention_mask=return_attention_mask, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_offsets_mapping=return_offsets_mapping, + return_length=return_length, + verbose=verbose, + **kwargs, + ) + else: + return self.encode_plus( + text=text, + text_pair=text_pair, + boxes=boxes, + word_labels=word_labels, + add_special_tokens=add_special_tokens, + padding=padding, + truncation=truncation, + max_length=max_length, + stride=stride, + pad_to_multiple_of=pad_to_multiple_of, + return_tensors=return_tensors, + return_token_type_ids=return_token_type_ids, + return_attention_mask=return_attention_mask, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_offsets_mapping=return_offsets_mapping, + return_length=return_length, + verbose=verbose, + **kwargs, + ) + + @add_end_docstrings(LAYOUTLMV3_ENCODE_KWARGS_DOCSTRING, LAYOUTLMV3_ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING) + # Copied from transformers.models.layoutlmv2.tokenization_layoutlmv2_fast.LayoutLMv2TokenizerFast.batch_encode_plus + def batch_encode_plus( + self, + batch_text_or_text_pairs: Union[ + List[TextInput], + List[TextInputPair], + List[PreTokenizedInput], + ], + is_pair: bool = None, + boxes: Optional[List[List[List[int]]]] = None, + word_labels: Optional[Union[List[int], List[List[int]]]] = None, + add_special_tokens: bool = True, + padding: Union[bool, str, PaddingStrategy] = False, + truncation: Union[bool, str, TruncationStrategy] = False, + max_length: Optional[int] = None, + stride: int = 0, + pad_to_multiple_of: Optional[int] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + return_length: bool = False, + verbose: bool = True, + **kwargs + ) -> BatchEncoding: + + # Backward compatibility for 'truncation_strategy', 'pad_to_max_length' + padding_strategy, truncation_strategy, max_length, kwargs = self._get_padding_truncation_strategies( + padding=padding, + truncation=truncation, + max_length=max_length, + pad_to_multiple_of=pad_to_multiple_of, + verbose=verbose, + **kwargs, + ) + + return self._batch_encode_plus( + batch_text_or_text_pairs=batch_text_or_text_pairs, + is_pair=is_pair, + boxes=boxes, + word_labels=word_labels, + add_special_tokens=add_special_tokens, + padding_strategy=padding_strategy, + truncation_strategy=truncation_strategy, + max_length=max_length, + stride=stride, + pad_to_multiple_of=pad_to_multiple_of, + return_tensors=return_tensors, + return_token_type_ids=return_token_type_ids, + return_attention_mask=return_attention_mask, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_offsets_mapping=return_offsets_mapping, + return_length=return_length, + verbose=verbose, + **kwargs, + ) + + # Copied from transformers.models.layoutlmv2.tokenization_layoutlmv2_fast.LayoutLMv2TokenizerFast.tokenize + def tokenize(self, text: str, pair: Optional[str] = None, add_special_tokens: bool = False, **kwargs) -> List[str]: + batched_input = [(text, pair)] if pair else [text] + encodings = self._tokenizer.encode_batch( + batched_input, add_special_tokens=add_special_tokens, is_pretokenized=False, **kwargs + ) + + return encodings[0].tokens + + @add_end_docstrings(LAYOUTLMV3_ENCODE_KWARGS_DOCSTRING, LAYOUTLMV3_ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING) + # Copied from transformers.models.layoutlmv2.tokenization_layoutlmv2_fast.LayoutLMv2TokenizerFast.encode_plus + def encode_plus( + self, + text: Union[TextInput, PreTokenizedInput], + text_pair: Optional[PreTokenizedInput] = None, + boxes: Optional[List[List[int]]] = None, + word_labels: Optional[List[int]] = None, + add_special_tokens: bool = True, + padding: Union[bool, str, PaddingStrategy] = False, + truncation: Union[bool, str, TruncationStrategy] = False, + max_length: Optional[int] = None, + stride: int = 0, + pad_to_multiple_of: Optional[int] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + return_length: bool = False, + verbose: bool = True, + **kwargs + ) -> BatchEncoding: + """ + Tokenize and prepare for the model a sequence or a pair of sequences. .. warning:: This method is deprecated, + `__call__` should be used instead. + + Args: + text (`str`, `List[str]`, `List[List[str]]`): + The first sequence to be encoded. This can be a string, a list of strings or a list of list of strings. + text_pair (`List[str]` or `List[int]`, *optional*): + Optional second sequence to be encoded. This can be a list of strings (words of a single example) or a + list of list of strings (words of a batch of examples). + """ + + # Backward compatibility for 'truncation_strategy', 'pad_to_max_length' + padding_strategy, truncation_strategy, max_length, kwargs = self._get_padding_truncation_strategies( + padding=padding, + truncation=truncation, + max_length=max_length, + pad_to_multiple_of=pad_to_multiple_of, + verbose=verbose, + **kwargs, + ) + + return self._encode_plus( + text=text, + boxes=boxes, + text_pair=text_pair, + word_labels=word_labels, + add_special_tokens=add_special_tokens, + padding_strategy=padding_strategy, + truncation_strategy=truncation_strategy, + max_length=max_length, + stride=stride, + pad_to_multiple_of=pad_to_multiple_of, + return_tensors=return_tensors, + return_token_type_ids=return_token_type_ids, + return_attention_mask=return_attention_mask, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_offsets_mapping=return_offsets_mapping, + return_length=return_length, + verbose=verbose, + **kwargs, + ) + + # Copied from transformers.models.layoutlmv2.tokenization_layoutlmv2_fast.LayoutLMv2TokenizerFast._batch_encode_plus with LayoutLMv2->LayoutLMv3 + def _batch_encode_plus( + self, + batch_text_or_text_pairs: Union[ + List[TextInput], + List[TextInputPair], + List[PreTokenizedInput], + ], + is_pair: bool = None, + boxes: Optional[List[List[List[int]]]] = None, + word_labels: Optional[List[List[int]]] = None, + add_special_tokens: bool = True, + padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, + truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE, + max_length: Optional[int] = None, + stride: int = 0, + pad_to_multiple_of: Optional[int] = None, + return_tensors: Optional[str] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + return_length: bool = False, + verbose: bool = True, + ) -> BatchEncoding: + + if not isinstance(batch_text_or_text_pairs, list): + raise TypeError(f"batch_text_or_text_pairs has to be a list (got {type(batch_text_or_text_pairs)})") + + # Set the truncation and padding strategy and restore the initial configuration + self.set_truncation_and_padding( + padding_strategy=padding_strategy, + truncation_strategy=truncation_strategy, + max_length=max_length, + stride=stride, + pad_to_multiple_of=pad_to_multiple_of, + ) + + if is_pair: + batch_text_or_text_pairs = [(text.split(), text_pair) for text, text_pair in batch_text_or_text_pairs] + + encodings = self._tokenizer.encode_batch( + batch_text_or_text_pairs, + add_special_tokens=add_special_tokens, + is_pretokenized=True, # we set this to True as LayoutLMv3 always expects pretokenized inputs + ) + + # Convert encoding to dict + # `Tokens` has type: Tuple[ + # List[Dict[str, List[List[int]]]] or List[Dict[str, 2D-Tensor]], + # List[EncodingFast] + # ] + # with nested dimensions corresponding to batch, overflows, sequence length + tokens_and_encodings = [ + self._convert_encoding( + encoding=encoding, + return_token_type_ids=return_token_type_ids, + return_attention_mask=return_attention_mask, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_offsets_mapping=True + if word_labels is not None + else return_offsets_mapping, # we use offsets to create the labels + return_length=return_length, + verbose=verbose, + ) + for encoding in encodings + ] + + # Convert the output to have dict[list] from list[dict] and remove the additional overflows dimension + # From (variable) shape (batch, overflows, sequence length) to ~ (batch * overflows, sequence length) + # (we say ~ because the number of overflow varies with the example in the batch) + # + # To match each overflowing sample with the original sample in the batch + # we add an overflow_to_sample_mapping array (see below) + sanitized_tokens = {} + for key in tokens_and_encodings[0][0].keys(): + stack = [e for item, _ in tokens_and_encodings for e in item[key]] + sanitized_tokens[key] = stack + sanitized_encodings = [e for _, item in tokens_and_encodings for e in item] + + # If returning overflowing tokens, we need to return a mapping + # from the batch idx to the original sample + if return_overflowing_tokens: + overflow_to_sample_mapping = [] + for i, (toks, _) in enumerate(tokens_and_encodings): + overflow_to_sample_mapping += [i] * len(toks["input_ids"]) + sanitized_tokens["overflow_to_sample_mapping"] = overflow_to_sample_mapping + + for input_ids in sanitized_tokens["input_ids"]: + self._eventual_warn_about_too_long_sequence(input_ids, max_length, verbose) + + # create the token boxes + token_boxes = [] + for batch_index in range(len(sanitized_tokens["input_ids"])): + if return_overflowing_tokens: + original_index = sanitized_tokens["overflow_to_sample_mapping"][batch_index] + else: + original_index = batch_index + token_boxes_example = [] + for id, sequence_id, word_id in zip( + sanitized_tokens["input_ids"][batch_index], + sanitized_encodings[batch_index].sequence_ids, + sanitized_encodings[batch_index].word_ids, + ): + if word_id is not None: + if is_pair and sequence_id == 0: + token_boxes_example.append(self.pad_token_box) + else: + token_boxes_example.append(boxes[original_index][word_id]) + else: + if id == self.cls_token_id: + token_boxes_example.append(self.cls_token_box) + elif id == self.sep_token_id: + token_boxes_example.append(self.sep_token_box) + elif id == self.pad_token_id: + token_boxes_example.append(self.pad_token_box) + else: + raise ValueError("Id not recognized") + token_boxes.append(token_boxes_example) + + sanitized_tokens["bbox"] = token_boxes + + # optionally, create the labels + if word_labels is not None: + labels = [] + for batch_index in range(len(sanitized_tokens["input_ids"])): + if return_overflowing_tokens: + original_index = sanitized_tokens["overflow_to_sample_mapping"][batch_index] + else: + original_index = batch_index + labels_example = [] + for id, offset, word_id in zip( + sanitized_tokens["input_ids"][batch_index], + sanitized_tokens["offset_mapping"][batch_index], + sanitized_encodings[batch_index].word_ids, + ): + if word_id is not None: + if self.only_label_first_subword: + if offset[0] == 0: + # Use the real label id for the first token of the word, and padding ids for the remaining tokens + labels_example.append(word_labels[original_index][word_id]) + else: + labels_example.append(self.pad_token_label) + else: + labels_example.append(word_labels[original_index][word_id]) + else: + labels_example.append(self.pad_token_label) + labels.append(labels_example) + + sanitized_tokens["labels"] = labels + # finally, remove offsets if the user didn't want them + if not return_offsets_mapping: + del sanitized_tokens["offset_mapping"] + + return BatchEncoding(sanitized_tokens, sanitized_encodings, tensor_type=return_tensors) + + # Copied from transformers.models.layoutlmv2.tokenization_layoutlmv2_fast.LayoutLMv2TokenizerFast._encode_plus + def _encode_plus( + self, + text: Union[TextInput, PreTokenizedInput], + text_pair: Optional[PreTokenizedInput] = None, + boxes: Optional[List[List[int]]] = None, + word_labels: Optional[List[int]] = None, + add_special_tokens: bool = True, + padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, + truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE, + max_length: Optional[int] = None, + stride: int = 0, + pad_to_multiple_of: Optional[int] = None, + return_tensors: Optional[bool] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + return_length: bool = False, + verbose: bool = True, + **kwargs + ) -> BatchEncoding: + + # make it a batched input + # 2 options: + # 1) only text, in case text must be a list of str + # 2) text + text_pair, in which case text = str and text_pair a list of str + batched_input = [(text, text_pair)] if text_pair else [text] + batched_boxes = [boxes] + batched_word_labels = [word_labels] if word_labels is not None else None + batched_output = self._batch_encode_plus( + batched_input, + is_pair=bool(text_pair is not None), + boxes=batched_boxes, + word_labels=batched_word_labels, + add_special_tokens=add_special_tokens, + padding_strategy=padding_strategy, + truncation_strategy=truncation_strategy, + max_length=max_length, + stride=stride, + pad_to_multiple_of=pad_to_multiple_of, + return_tensors=return_tensors, + return_token_type_ids=return_token_type_ids, + return_attention_mask=return_attention_mask, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_offsets_mapping=return_offsets_mapping, + return_length=return_length, + verbose=verbose, + **kwargs, + ) + + # Return tensor is None, then we can remove the leading batch axis + # Overflowing tokens are returned as a batch of output so we keep them in this case + if return_tensors is None and not return_overflowing_tokens: + batched_output = BatchEncoding( + { + key: value[0] if len(value) > 0 and isinstance(value[0], list) else value + for key, value in batched_output.items() + }, + batched_output.encodings, + ) + + self._eventual_warn_about_too_long_sequence(batched_output["input_ids"], max_length, verbose) + + return batched_output + + # Copied from transformers.models.layoutlmv2.tokenization_layoutlmv2_fast.LayoutLMv2TokenizerFast._pad + def _pad( + self, + encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding], + max_length: Optional[int] = None, + padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, + pad_to_multiple_of: Optional[int] = None, + return_attention_mask: Optional[bool] = None, + ) -> dict: + """ + Pad encoded inputs (on left/right and up to predefined length or max length in the batch) + + Args: + encoded_inputs: + Dictionary of tokenized inputs (`List[int]`) or batch of tokenized inputs (`List[List[int]]`). + max_length: maximum length of the returned list and optionally padding length (see below). + Will truncate by taking into account the special tokens. + padding_strategy: PaddingStrategy to use for padding. + + - PaddingStrategy.LONGEST Pad to the longest sequence in the batch + - PaddingStrategy.MAX_LENGTH: Pad to the max length (default) + - PaddingStrategy.DO_NOT_PAD: Do not pad + The tokenizer padding sides are defined in self.padding_side: + + - 'left': pads on the left of the sequences + - 'right': pads on the right of the sequences + pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value. + This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability + >= 7.5 (Volta). + return_attention_mask: + (optional) Set to False to avoid returning attention mask (default: set to model specifics) + """ + # Load from model defaults + if return_attention_mask is None: + return_attention_mask = "attention_mask" in self.model_input_names + + required_input = encoded_inputs[self.model_input_names[0]] + + if padding_strategy == PaddingStrategy.LONGEST: + max_length = len(required_input) + + if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0): + max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of + + needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length + + # Initialize attention mask if not present. + if return_attention_mask and "attention_mask" not in encoded_inputs: + encoded_inputs["attention_mask"] = [1] * len(required_input) + + if needs_to_be_padded: + difference = max_length - len(required_input) + if self.padding_side == "right": + if return_attention_mask: + encoded_inputs["attention_mask"] = encoded_inputs["attention_mask"] + [0] * difference + if "token_type_ids" in encoded_inputs: + encoded_inputs["token_type_ids"] = ( + encoded_inputs["token_type_ids"] + [self.pad_token_type_id] * difference + ) + if "bbox" in encoded_inputs: + encoded_inputs["bbox"] = encoded_inputs["bbox"] + [self.pad_token_box] * difference + if "labels" in encoded_inputs: + encoded_inputs["labels"] = encoded_inputs["labels"] + [self.pad_token_label] * difference + if "special_tokens_mask" in encoded_inputs: + encoded_inputs["special_tokens_mask"] = encoded_inputs["special_tokens_mask"] + [1] * difference + encoded_inputs[self.model_input_names[0]] = required_input + [self.pad_token_id] * difference + elif self.padding_side == "left": + if return_attention_mask: + encoded_inputs["attention_mask"] = [0] * difference + encoded_inputs["attention_mask"] + if "token_type_ids" in encoded_inputs: + encoded_inputs["token_type_ids"] = [self.pad_token_type_id] * difference + encoded_inputs[ + "token_type_ids" + ] + if "bbox" in encoded_inputs: + encoded_inputs["bbox"] = [self.pad_token_box] * difference + encoded_inputs["bbox"] + if "labels" in encoded_inputs: + encoded_inputs["labels"] = [self.pad_token_label] * difference + encoded_inputs["labels"] + if "special_tokens_mask" in encoded_inputs: + encoded_inputs["special_tokens_mask"] = [1] * difference + encoded_inputs["special_tokens_mask"] + encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input + else: + raise ValueError("Invalid padding strategy:" + str(self.padding_side)) + + return encoded_inputs + + # Copied from transformers.models.layoutlmv2.tokenization_layoutlmv2_fast.LayoutLMv2TokenizerFast.save_vocabulary + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: + files = self._tokenizer.model.save(save_directory, name=filename_prefix) + return tuple(files) + + def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): + output = [self.bos_token_id] + token_ids_0 + [self.eos_token_id] + if token_ids_1 is None: + return output + + return output + [self.eos_token_id] + token_ids_1 + [self.eos_token_id] + + def create_token_type_ids_from_sequences( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Args: + Create a mask from the two sequences passed to be used in a sequence-pair classification task. RoBERTa does not: + make use of token type ids, therefore a list of zeros is returned. + token_ids_0 (`List[int]`): + List of IDs. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + Returns: + `List[int]`: List of zeros. + """ + sep = [self.sep_token_id] + cls = [self.cls_token_id] + + if token_ids_1 is None: + return len(cls + token_ids_0 + sep) * [0] + return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0] diff --git a/src/transformers/models/layoutxlm/__init__.py b/src/transformers/models/layoutxlm/__init__.py index 34bcddb4a94ee..a368067146ef1 100644 --- a/src/transformers/models/layoutxlm/__init__.py +++ b/src/transformers/models/layoutxlm/__init__.py @@ -28,7 +28,9 @@ ) -_import_structure = {} +_import_structure = { + "processing_layoutxlm": ["LayoutXLMProcessor"], +} try: if not is_sentencepiece_available(): @@ -46,15 +48,9 @@ else: _import_structure["tokenization_layoutxlm_fast"] = ["LayoutXLMTokenizerFast"] -try: - if not is_vision_available(): - raise OptionalDependencyNotAvailable() -except OptionalDependencyNotAvailable: - pass -else: - _import_structure["processing_layoutxlm"] = ["LayoutXLMProcessor"] - if TYPE_CHECKING: + from .processing_layoutxlm import LayoutXLMProcessor + try: if not is_sentencepiece_available(): raise OptionalDependencyNotAvailable() @@ -71,14 +67,6 @@ else: from .tokenization_layoutxlm_fast import LayoutXLMTokenizerFast - try: - if not is_vision_available(): - raise OptionalDependencyNotAvailable() - except OptionalDependencyNotAvailable: - pass - else: - from .processing_layoutlmv2 import LayoutXLMProcessor - else: import sys diff --git a/src/transformers/models/layoutxlm/processing_layoutxlm.py b/src/transformers/models/layoutxlm/processing_layoutxlm.py index 6f45ee065957a..03423d17c27be 100644 --- a/src/transformers/models/layoutxlm/processing_layoutxlm.py +++ b/src/transformers/models/layoutxlm/processing_layoutxlm.py @@ -121,6 +121,37 @@ def __call__( ) # add pixel values - encoded_inputs["image"] = features.pop("pixel_values") + images = features.pop("pixel_values") + if return_overflowing_tokens is True: + images = self.get_overflowing_images(images, encoded_inputs["overflow_to_sample_mapping"]) + encoded_inputs["image"] = images return encoded_inputs + + def get_overflowing_images(self, images, overflow_to_sample_mapping): + # in case there's an overflow, ensure each `input_ids` sample is mapped to its corresponding image + images_with_overflow = [] + for sample_idx in overflow_to_sample_mapping: + images_with_overflow.append(images[sample_idx]) + + if len(images_with_overflow) != len(overflow_to_sample_mapping): + raise ValueError( + "Expected length of images to be the same as the length of `overflow_to_sample_mapping`, but got" + f" {len(images_with_overflow)} and {len(overflow_to_sample_mapping)}" + ) + + return images_with_overflow + + def batch_decode(self, *args, **kwargs): + """ + This method forwards all its arguments to PreTrainedTokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please + refer to the docstring of this method for more information. + """ + return self.tokenizer.batch_decode(*args, **kwargs) + + def decode(self, *args, **kwargs): + """ + This method forwards all its arguments to PreTrainedTokenizer's [`~PreTrainedTokenizer.decode`]. Please refer + to the docstring of this method for more information. + """ + return self.tokenizer.decode(*args, **kwargs) diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 8f7e291beac6a..c36546680dd66 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -2367,6 +2367,44 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) +LAYOUTLMV3_PRETRAINED_MODEL_ARCHIVE_LIST = None + + +class LayoutLMv3ForQuestionAnswering(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class LayoutLMv3ForSequenceClassification(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class LayoutLMv3ForTokenClassification(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class LayoutLMv3Model(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class LayoutLMv3PreTrainedModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + LED_PRETRAINED_MODEL_ARCHIVE_LIST = None diff --git a/src/transformers/utils/dummy_tokenizers_objects.py b/src/transformers/utils/dummy_tokenizers_objects.py index 12cec6a4a2605..7edd48cb57753 100644 --- a/src/transformers/utils/dummy_tokenizers_objects.py +++ b/src/transformers/utils/dummy_tokenizers_objects.py @@ -171,6 +171,13 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["tokenizers"]) +class LayoutLMv3TokenizerFast(metaclass=DummyObject): + _backends = ["tokenizers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["tokenizers"]) + + class LayoutXLMTokenizerFast(metaclass=DummyObject): _backends = ["tokenizers"] diff --git a/src/transformers/utils/dummy_vision_objects.py b/src/transformers/utils/dummy_vision_objects.py index b9c7d3e97f606..add3fada5b8e2 100644 --- a/src/transformers/utils/dummy_vision_objects.py +++ b/src/transformers/utils/dummy_vision_objects.py @@ -94,14 +94,7 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["vision"]) -class LayoutLMv2Processor(metaclass=DummyObject): - _backends = ["vision"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["vision"]) - - -class LayoutXLMProcessor(metaclass=DummyObject): +class LayoutLMv3FeatureExtractor(metaclass=DummyObject): _backends = ["vision"] def __init__(self, *args, **kwargs): diff --git a/tests/models/layoutlmv2/test_processor_layoutlmv2.py b/tests/models/layoutlmv2/test_processor_layoutlmv2.py index 9b269d2eb216f..4f686155adc71 100644 --- a/tests/models/layoutlmv2/test_processor_layoutlmv2.py +++ b/tests/models/layoutlmv2/test_processor_layoutlmv2.py @@ -215,10 +215,11 @@ def test_processor_case_1(self): ) # verify input_ids + # this was obtained with Tesseract 4.1.1 # fmt: off expected_decoding = "[CLS] 11 : 14 to 11 : 39 a. m 11 : 39 to 11 : 44 a. m. 11 : 44 a. m. to 12 : 25 p. m. 12 : 25 to 12 : 58 p. m. 12 : 58 to 4 : 00 p. m. 2 : 00 to 5 : 00 p. m. coffee break coffee will be served for men and women in the lobby adjacent to exhibit area. please move into exhibit area. ( exhibits open ) trrf general session ( part | ) presiding : lee a. waller trrf vice president “ introductory remarks ” lee a. waller, trrf vice presi - dent individual interviews with trrf public board members and sci - entific advisory council mem - bers conducted by trrf treasurer philip g. kuehn to get answers which the public refrigerated warehousing industry is looking for. plus questions from the floor. dr. emil m. mrak, university of cal - ifornia, chairman, trrf board ; sam r. cecil, university of georgia college of agriculture ; dr. stanley charm, tufts university school of medicine ; dr. robert h. cotton, itt continental baking company ; dr. owen fennema, university of wis - consin ; dr. robert e. hardenburg, usda. questions and answers exhibits open capt. jack stoney room trrf scientific advisory council meeting ballroom foyer [SEP]" # noqa: E231 # fmt: on - decoding = tokenizer.decode(input_processor.input_ids.squeeze().tolist()) + decoding = processor.decode(input_processor.input_ids.squeeze().tolist()) self.assertSequenceEqual(decoding, expected_decoding) # batched @@ -236,10 +237,11 @@ def test_processor_case_1(self): ) # verify input_ids + # this was obtained with Tesseract 4.1.1 # fmt: off expected_decoding = "[CLS] 7 itc limited report and accounts 2013 itc ’ s brands : an asset for the nation the consumer needs and aspirations they fulfil, the benefit they generate for millions across itc ’ s value chains, the future - ready capabilities that support them, and the value that they create for the country, have made itc ’ s brands national assets, adding to india ’ s competitiveness. it is itc ’ s aspiration to be the no 1 fmcg player in the country, driven by its new fmcg businesses. a recent nielsen report has highlighted that itc's new fmcg businesses are the fastest growing among the top consumer goods companies operating in india. itc takes justifiable pride that, along with generating economic value, these celebrated indian brands also drive the creation of larger societal capital through the virtuous cycle of sustainable and inclusive growth. di wills * ; love delightfully soft skin? aia ans source : https : / / www. industrydocuments. ucsf. edu / docs / snbx0223 [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]" # noqa: E231 # fmt: on - decoding = tokenizer.decode(input_processor.input_ids[1].tolist()) + decoding = processor.decode(input_processor.input_ids[1].tolist()) self.assertSequenceEqual(decoding, expected_decoding) @slow @@ -266,7 +268,7 @@ def test_processor_case_2(self): # verify input_ids expected_decoding = "[CLS] hello world [SEP]" - decoding = tokenizer.decode(input_processor.input_ids.squeeze().tolist()) + decoding = processor.decode(input_processor.input_ids.squeeze().tolist()) self.assertSequenceEqual(decoding, expected_decoding) # batched @@ -281,7 +283,7 @@ def test_processor_case_2(self): # verify input_ids expected_decoding = "[CLS] hello world [SEP] [PAD] [PAD] [PAD]" - decoding = tokenizer.decode(input_processor.input_ids[0].tolist()) + decoding = processor.decode(input_processor.input_ids[0].tolist()) self.assertSequenceEqual(decoding, expected_decoding) # verify bbox @@ -320,7 +322,7 @@ def test_processor_case_3(self): # verify input_ids expected_decoding = "[CLS] weirdly world [SEP]" - decoding = tokenizer.decode(input_processor.input_ids.squeeze().tolist()) + decoding = processor.decode(input_processor.input_ids.squeeze().tolist()) self.assertSequenceEqual(decoding, expected_decoding) # verify labels @@ -342,7 +344,7 @@ def test_processor_case_3(self): # verify input_ids expected_decoding = "[CLS] my name is niels [SEP]" - decoding = tokenizer.decode(input_processor.input_ids[1].tolist()) + decoding = processor.decode(input_processor.input_ids[1].tolist()) self.assertSequenceEqual(decoding, expected_decoding) # verify bbox @@ -382,10 +384,11 @@ def test_processor_case_4(self): self.assertListEqual(actual_keys, expected_keys) # verify input_ids + # this was obtained with Tesseract 4.1.1 # fmt: off expected_decoding = "[CLS] what's his name? [SEP] 11 : 14 to 11 : 39 a. m 11 : 39 to 11 : 44 a. m. 11 : 44 a. m. to 12 : 25 p. m. 12 : 25 to 12 : 58 p. m. 12 : 58 to 4 : 00 p. m. 2 : 00 to 5 : 00 p. m. coffee break coffee will be served for men and women in the lobby adjacent to exhibit area. please move into exhibit area. ( exhibits open ) trrf general session ( part | ) presiding : lee a. waller trrf vice president “ introductory remarks ” lee a. waller, trrf vice presi - dent individual interviews with trrf public board members and sci - entific advisory council mem - bers conducted by trrf treasurer philip g. kuehn to get answers which the public refrigerated warehousing industry is looking for. plus questions from the floor. dr. emil m. mrak, university of cal - ifornia, chairman, trrf board ; sam r. cecil, university of georgia college of agriculture ; dr. stanley charm, tufts university school of medicine ; dr. robert h. cotton, itt continental baking company ; dr. owen fennema, university of wis - consin ; dr. robert e. hardenburg, usda. questions and answers exhibits open capt. jack stoney room trrf scientific advisory council meeting ballroom foyer [SEP]" # noqa: E231 # fmt: on - decoding = tokenizer.decode(input_processor.input_ids.squeeze().tolist()) + decoding = processor.decode(input_processor.input_ids.squeeze().tolist()) self.assertSequenceEqual(decoding, expected_decoding) # batched @@ -400,8 +403,9 @@ def test_processor_case_4(self): self.assertListEqual(actual_keys, expected_keys) # verify input_ids + # this was obtained with Tesseract 4.1.1 expected_decoding = "[CLS] what's the time [SEP] 7 itc limited report and accounts 2013 itc ’ s [SEP]" - decoding = tokenizer.decode(input_processor.input_ids[1].tolist()) + decoding = processor.decode(input_processor.input_ids[1].tolist()) self.assertSequenceEqual(decoding, expected_decoding) # verify bbox @@ -434,7 +438,7 @@ def test_processor_case_5(self): # verify input_ids expected_decoding = "[CLS] what's his name? [SEP] hello world [SEP]" - decoding = tokenizer.decode(input_processor.input_ids.squeeze().tolist()) + decoding = processor.decode(input_processor.input_ids.squeeze().tolist()) self.assertSequenceEqual(decoding, expected_decoding) # batched @@ -450,11 +454,11 @@ def test_processor_case_5(self): # verify input_ids expected_decoding = "[CLS] how old is he? [SEP] hello world [SEP] [PAD] [PAD] [PAD]" - decoding = tokenizer.decode(input_processor.input_ids[0].tolist()) + decoding = processor.decode(input_processor.input_ids[0].tolist()) self.assertSequenceEqual(decoding, expected_decoding) expected_decoding = "[CLS] what's the time [SEP] my name is niels [SEP]" - decoding = tokenizer.decode(input_processor.input_ids[1].tolist()) + decoding = processor.decode(input_processor.input_ids[1].tolist()) self.assertSequenceEqual(decoding, expected_decoding) # verify bbox diff --git a/tests/models/layoutlmv3/__init__.py b/tests/models/layoutlmv3/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/models/layoutlmv3/test_feature_extraction_layoutlmv3.py b/tests/models/layoutlmv3/test_feature_extraction_layoutlmv3.py new file mode 100644 index 0000000000000..9d05a4b6658ef --- /dev/null +++ b/tests/models/layoutlmv3/test_feature_extraction_layoutlmv3.py @@ -0,0 +1,213 @@ +# coding=utf-8 +# Copyright 2022 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import unittest + +import numpy as np + +from transformers.testing_utils import require_pytesseract, require_torch +from transformers.utils import is_pytesseract_available, is_torch_available + +from ...test_feature_extraction_common import FeatureExtractionSavingTestMixin, prepare_image_inputs + + +if is_torch_available(): + import torch + +if is_pytesseract_available(): + from PIL import Image + + from transformers import LayoutLMv3FeatureExtractor + + +class LayoutLMv3FeatureExtractionTester(unittest.TestCase): + def __init__( + self, + parent, + batch_size=7, + num_channels=3, + image_size=18, + min_resolution=30, + max_resolution=400, + do_resize=True, + size=18, + apply_ocr=True, + ): + self.parent = parent + self.batch_size = batch_size + self.num_channels = num_channels + self.image_size = image_size + self.min_resolution = min_resolution + self.max_resolution = max_resolution + self.do_resize = do_resize + self.size = size + self.apply_ocr = apply_ocr + + def prepare_feat_extract_dict(self): + return {"do_resize": self.do_resize, "size": self.size, "apply_ocr": self.apply_ocr} + + +@require_torch +@require_pytesseract +class LayoutLMv3FeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase): + + feature_extraction_class = LayoutLMv3FeatureExtractor if is_pytesseract_available() else None + + def setUp(self): + self.feature_extract_tester = LayoutLMv3FeatureExtractionTester(self) + + @property + def feat_extract_dict(self): + return self.feature_extract_tester.prepare_feat_extract_dict() + + def test_feat_extract_properties(self): + feature_extractor = self.feature_extraction_class(**self.feat_extract_dict) + self.assertTrue(hasattr(feature_extractor, "do_resize")) + self.assertTrue(hasattr(feature_extractor, "size")) + self.assertTrue(hasattr(feature_extractor, "apply_ocr")) + + def test_batch_feature(self): + pass + + def test_call_pil(self): + # Initialize feature_extractor + feature_extractor = self.feature_extraction_class(**self.feat_extract_dict) + # create random PIL images + image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False) + for image in image_inputs: + self.assertIsInstance(image, Image.Image) + + # Test not batched input + encoding = feature_extractor(image_inputs[0], return_tensors="pt") + self.assertEqual( + encoding.pixel_values.shape, + ( + 1, + self.feature_extract_tester.num_channels, + self.feature_extract_tester.size, + self.feature_extract_tester.size, + ), + ) + + self.assertIsInstance(encoding.words, list) + self.assertIsInstance(encoding.boxes, list) + + # Test batched + encoded_images = feature_extractor(image_inputs, return_tensors="pt").pixel_values + self.assertEqual( + encoded_images.shape, + ( + self.feature_extract_tester.batch_size, + self.feature_extract_tester.num_channels, + self.feature_extract_tester.size, + self.feature_extract_tester.size, + ), + ) + + def test_call_numpy(self): + # Initialize feature_extractor + feature_extractor = self.feature_extraction_class(**self.feat_extract_dict) + # create random numpy tensors + image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False, numpify=True) + for image in image_inputs: + self.assertIsInstance(image, np.ndarray) + + # Test not batched input + encoded_images = feature_extractor(image_inputs[0], return_tensors="pt").pixel_values + self.assertEqual( + encoded_images.shape, + ( + 1, + self.feature_extract_tester.num_channels, + self.feature_extract_tester.size, + self.feature_extract_tester.size, + ), + ) + + # Test batched + encoded_images = feature_extractor(image_inputs, return_tensors="pt").pixel_values + self.assertEqual( + encoded_images.shape, + ( + self.feature_extract_tester.batch_size, + self.feature_extract_tester.num_channels, + self.feature_extract_tester.size, + self.feature_extract_tester.size, + ), + ) + + def test_call_pytorch(self): + # Initialize feature_extractor + feature_extractor = self.feature_extraction_class(**self.feat_extract_dict) + # create random PyTorch tensors + image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False, torchify=True) + for image in image_inputs: + self.assertIsInstance(image, torch.Tensor) + + # Test not batched input + encoded_images = feature_extractor(image_inputs[0], return_tensors="pt").pixel_values + self.assertEqual( + encoded_images.shape, + ( + 1, + self.feature_extract_tester.num_channels, + self.feature_extract_tester.size, + self.feature_extract_tester.size, + ), + ) + + # Test batched + encoded_images = feature_extractor(image_inputs, return_tensors="pt").pixel_values + self.assertEqual( + encoded_images.shape, + ( + self.feature_extract_tester.batch_size, + self.feature_extract_tester.num_channels, + self.feature_extract_tester.size, + self.feature_extract_tester.size, + ), + ) + + def test_LayoutLMv3_integration_test(self): + # with apply_OCR = True + feature_extractor = LayoutLMv3FeatureExtractor() + + from datasets import load_dataset + + ds = load_dataset("hf-internal-testing/fixtures_docvqa", split="test") + + image = Image.open(ds[0]["file"]).convert("RGB") + + encoding = feature_extractor(image, return_tensors="pt") + + self.assertEqual(encoding.pixel_values.shape, (1, 3, 224, 224)) + self.assertEqual(len(encoding.words), len(encoding.boxes)) + + # fmt: off + # the words and boxes were obtained with Tesseract 4.1.1 + expected_words = [['11:14', 'to', '11:39', 'a.m', '11:39', 'to', '11:44', 'a.m.', '11:44', 'a.m.', 'to', '12:25', 'p.m.', '12:25', 'to', '12:58', 'p.m.', '12:58', 'to', '4:00', 'p.m.', '2:00', 'to', '5:00', 'p.m.', 'Coffee', 'Break', 'Coffee', 'will', 'be', 'served', 'for', 'men', 'and', 'women', 'in', 'the', 'lobby', 'adjacent', 'to', 'exhibit', 'area.', 'Please', 'move', 'into', 'exhibit', 'area.', '(Exhibits', 'Open)', 'TRRF', 'GENERAL', 'SESSION', '(PART', '|)', 'Presiding:', 'Lee', 'A.', 'Waller', 'TRRF', 'Vice', 'President', '“Introductory', 'Remarks”', 'Lee', 'A.', 'Waller,', 'TRRF', 'Vice', 'Presi-', 'dent', 'Individual', 'Interviews', 'with', 'TRRF', 'Public', 'Board', 'Members', 'and', 'Sci-', 'entific', 'Advisory', 'Council', 'Mem-', 'bers', 'Conducted', 'by', 'TRRF', 'Treasurer', 'Philip', 'G.', 'Kuehn', 'to', 'get', 'answers', 'which', 'the', 'public', 'refrigerated', 'warehousing', 'industry', 'is', 'looking', 'for.', 'Plus', 'questions', 'from', 'the', 'floor.', 'Dr.', 'Emil', 'M.', 'Mrak,', 'University', 'of', 'Cal-', 'ifornia,', 'Chairman,', 'TRRF', 'Board;', 'Sam', 'R.', 'Cecil,', 'University', 'of', 'Georgia', 'College', 'of', 'Agriculture;', 'Dr.', 'Stanley', 'Charm,', 'Tufts', 'University', 'School', 'of', 'Medicine;', 'Dr.', 'Robert', 'H.', 'Cotton,', 'ITT', 'Continental', 'Baking', 'Company;', 'Dr.', 'Owen', 'Fennema,', 'University', 'of', 'Wis-', 'consin;', 'Dr.', 'Robert', 'E.', 'Hardenburg,', 'USDA.', 'Questions', 'and', 'Answers', 'Exhibits', 'Open', 'Capt.', 'Jack', 'Stoney', 'Room', 'TRRF', 'Scientific', 'Advisory', 'Council', 'Meeting', 'Ballroom', 'Foyer']] # noqa: E231 + expected_boxes = [[[141, 57, 214, 69], [228, 58, 252, 69], [141, 75, 216, 88], [230, 79, 280, 88], [142, 260, 218, 273], [230, 261, 255, 273], [143, 279, 218, 290], [231, 282, 290, 291], [143, 342, 218, 354], [231, 345, 289, 355], [202, 362, 227, 373], [143, 379, 220, 392], [231, 382, 291, 394], [144, 714, 220, 726], [231, 715, 256, 726], [144, 732, 220, 745], [232, 736, 291, 747], [144, 769, 218, 782], [231, 770, 256, 782], [141, 788, 202, 801], [215, 791, 274, 804], [143, 826, 204, 838], [215, 826, 240, 838], [142, 844, 202, 857], [215, 847, 274, 859], [334, 57, 427, 69], [440, 57, 522, 69], [369, 75, 461, 88], [469, 75, 516, 88], [528, 76, 562, 88], [570, 76, 667, 88], [675, 75, 711, 87], [721, 79, 778, 88], [789, 75, 840, 88], [369, 97, 470, 107], [484, 94, 507, 106], [518, 94, 562, 107], [576, 94, 655, 110], [668, 94, 792, 109], [804, 95, 829, 107], [369, 113, 465, 125], [477, 116, 547, 125], [562, 113, 658, 125], [671, 116, 748, 125], [761, 113, 811, 125], [369, 131, 465, 143], [477, 133, 548, 143], [563, 130, 698, 145], [710, 130, 802, 146], [336, 171, 412, 183], [423, 171, 572, 183], [582, 170, 716, 184], [728, 171, 817, 187], [829, 171, 844, 186], [338, 197, 482, 212], [507, 196, 557, 209], [569, 196, 595, 208], [610, 196, 702, 209], [505, 214, 583, 226], [595, 214, 656, 227], [670, 215, 807, 227], [335, 259, 543, 274], [556, 259, 708, 272], [372, 279, 422, 291], [435, 279, 460, 291], [474, 279, 574, 292], [587, 278, 664, 291], [676, 278, 738, 291], [751, 279, 834, 291], [372, 298, 434, 310], [335, 341, 483, 354], [497, 341, 655, 354], [667, 341, 728, 354], [740, 341, 825, 354], [335, 360, 430, 372], [442, 360, 534, 372], [545, 359, 687, 372], [697, 360, 754, 372], [765, 360, 823, 373], [334, 378, 428, 391], [440, 378, 577, 394], [590, 378, 705, 391], [720, 378, 801, 391], [334, 397, 400, 409], [370, 416, 529, 429], [544, 416, 576, 432], [587, 416, 665, 428], [677, 416, 814, 429], [372, 435, 452, 450], [465, 434, 495, 447], [511, 434, 600, 447], [611, 436, 637, 447], [649, 436, 694, 451], [705, 438, 824, 447], [369, 453, 452, 466], [464, 454, 509, 466], [522, 453, 611, 469], [625, 453, 792, 469], [370, 472, 556, 488], [570, 472, 684, 487], [697, 472, 718, 485], [732, 472, 835, 488], [369, 490, 411, 503], [425, 490, 484, 503], [496, 490, 635, 506], [645, 490, 707, 503], [718, 491, 761, 503], [771, 490, 840, 503], [336, 510, 374, 521], [388, 510, 447, 522], [460, 510, 489, 521], [503, 510, 580, 522], [592, 509, 736, 525], [745, 509, 770, 522], [781, 509, 840, 522], [338, 528, 434, 541], [448, 528, 596, 541], [609, 527, 687, 540], [700, 528, 792, 541], [336, 546, 397, 559], [407, 546, 431, 559], [443, 546, 525, 560], [537, 546, 680, 562], [688, 546, 714, 559], [722, 546, 837, 562], [336, 565, 449, 581], [461, 565, 485, 577], [497, 565, 665, 581], [681, 565, 718, 577], [732, 565, 837, 580], [337, 584, 438, 597], [452, 583, 521, 596], [535, 584, 677, 599], [690, 583, 787, 596], [801, 583, 825, 596], [338, 602, 478, 615], [492, 602, 530, 614], [543, 602, 638, 615], [650, 602, 676, 614], [688, 602, 788, 615], [802, 602, 843, 614], [337, 621, 502, 633], [516, 621, 615, 637], [629, 621, 774, 636], [789, 621, 827, 633], [337, 639, 418, 652], [432, 640, 571, 653], [587, 639, 731, 655], [743, 639, 769, 652], [780, 639, 841, 652], [338, 658, 440, 673], [455, 658, 491, 670], [508, 658, 602, 671], [616, 658, 638, 670], [654, 658, 835, 674], [337, 677, 429, 689], [337, 714, 482, 726], [495, 714, 548, 726], [561, 714, 683, 726], [338, 770, 461, 782], [474, 769, 554, 785], [489, 788, 562, 803], [576, 788, 643, 801], [656, 787, 751, 804], [764, 788, 844, 801], [334, 825, 421, 838], [430, 824, 574, 838], [584, 824, 723, 841], [335, 844, 450, 857], [464, 843, 583, 860], [628, 862, 755, 875], [769, 861, 848, 878]]] # noqa: E231 + # fmt: on + + self.assertListEqual(encoding.words, expected_words) + self.assertListEqual(encoding.boxes, expected_boxes) + + # with apply_OCR = False + feature_extractor = LayoutLMv3FeatureExtractor(apply_ocr=False) + + encoding = feature_extractor(image, return_tensors="pt") + + self.assertEqual(encoding.pixel_values.shape, (1, 3, 224, 224)) diff --git a/tests/models/layoutlmv3/test_modeling_layoutlmv3.py b/tests/models/layoutlmv3/test_modeling_layoutlmv3.py new file mode 100644 index 0000000000000..d5c8d42d22177 --- /dev/null +++ b/tests/models/layoutlmv3/test_modeling_layoutlmv3.py @@ -0,0 +1,399 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Testing suite for the PyTorch LayoutLMv3 model. """ + +import copy +import unittest + +from transformers.models.auto import get_values +from transformers.testing_utils import require_torch, slow, torch_device +from transformers.utils import cached_property, is_torch_available, is_vision_available + +from ...test_configuration_common import ConfigTester +from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask + + +if is_torch_available(): + import torch + + from transformers import ( + MODEL_FOR_MULTIPLE_CHOICE_MAPPING, + MODEL_FOR_QUESTION_ANSWERING_MAPPING, + MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING, + MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING, + LayoutLMv3Config, + LayoutLMv3ForQuestionAnswering, + LayoutLMv3ForSequenceClassification, + LayoutLMv3ForTokenClassification, + LayoutLMv3Model, + ) + from transformers.models.layoutlmv3.modeling_layoutlmv3 import LAYOUTLMV3_PRETRAINED_MODEL_ARCHIVE_LIST + +if is_vision_available(): + from PIL import Image + + from transformers import LayoutLMv3FeatureExtractor + + +class LayoutLMv3ModelTester: + def __init__( + self, + parent, + batch_size=2, + num_channels=3, + image_size=4, + patch_size=2, + text_seq_length=7, + is_training=True, + use_input_mask=True, + use_token_type_ids=True, + use_labels=True, + vocab_size=99, + hidden_size=36, + num_hidden_layers=3, + num_attention_heads=4, + intermediate_size=37, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=16, + type_sequence_label_size=2, + initializer_range=0.02, + coordinate_size=6, + shape_size=6, + num_labels=3, + num_choices=4, + scope=None, + range_bbox=1000, + ): + self.parent = parent + self.batch_size = batch_size + self.num_channels = num_channels + self.image_size = image_size + self.patch_size = patch_size + self.text_seq_length = text_seq_length + self.is_training = is_training + self.use_input_mask = use_input_mask + self.use_token_type_ids = use_token_type_ids + self.use_labels = use_labels + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.type_sequence_label_size = type_sequence_label_size + self.initializer_range = initializer_range + self.coordinate_size = coordinate_size + self.shape_size = shape_size + self.num_labels = num_labels + self.num_choices = num_choices + self.scope = scope + self.range_bbox = range_bbox + + # LayoutLMv3's sequence length equals the number of text tokens + number of patches + 1 (we add 1 for the CLS token) + self.text_seq_length = text_seq_length + self.image_seq_length = (image_size // patch_size) ** 2 + 1 + self.seq_length = self.text_seq_length + self.image_seq_length + + def prepare_config_and_inputs(self): + input_ids = ids_tensor([self.batch_size, self.text_seq_length], self.vocab_size) + + bbox = ids_tensor([self.batch_size, self.text_seq_length, 4], self.range_bbox) + # Ensure that bbox is legal + for i in range(bbox.shape[0]): + for j in range(bbox.shape[1]): + if bbox[i, j, 3] < bbox[i, j, 1]: + t = bbox[i, j, 3] + bbox[i, j, 3] = bbox[i, j, 1] + bbox[i, j, 1] = t + if bbox[i, j, 2] < bbox[i, j, 0]: + t = bbox[i, j, 2] + bbox[i, j, 2] = bbox[i, j, 0] + bbox[i, j, 0] = t + + pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size]) + + input_mask = None + if self.use_input_mask: + input_mask = random_attention_mask([self.batch_size, self.text_seq_length]) + + token_type_ids = None + if self.use_token_type_ids: + token_type_ids = ids_tensor([self.batch_size, self.text_seq_length], self.type_vocab_size) + + sequence_labels = None + token_labels = None + if self.use_labels: + sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size) + token_labels = ids_tensor([self.batch_size, self.text_seq_length], self.num_labels) + + config = LayoutLMv3Config( + vocab_size=self.vocab_size, + hidden_size=self.hidden_size, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads, + intermediate_size=self.intermediate_size, + hidden_act=self.hidden_act, + hidden_dropout_prob=self.hidden_dropout_prob, + attention_probs_dropout_prob=self.attention_probs_dropout_prob, + max_position_embeddings=self.max_position_embeddings, + type_vocab_size=self.type_vocab_size, + initializer_range=self.initializer_range, + coordinate_size=self.coordinate_size, + shape_size=self.shape_size, + input_size=self.image_size, + patch_size=self.patch_size, + ) + + return config, input_ids, bbox, pixel_values, token_type_ids, input_mask, sequence_labels, token_labels + + def create_and_check_model( + self, config, input_ids, bbox, pixel_values, token_type_ids, input_mask, sequence_labels, token_labels + ): + model = LayoutLMv3Model(config=config) + model.to(torch_device) + model.eval() + + # text + image + result = model(input_ids, pixel_values=pixel_values) + result = model( + input_ids, bbox=bbox, pixel_values=pixel_values, attention_mask=input_mask, token_type_ids=token_type_ids + ) + result = model(input_ids, bbox=bbox, pixel_values=pixel_values, token_type_ids=token_type_ids) + result = model(input_ids, bbox=bbox, pixel_values=pixel_values) + + self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)) + + # text only + result = model(input_ids) + self.parent.assertEqual( + result.last_hidden_state.shape, (self.batch_size, self.text_seq_length, self.hidden_size) + ) + + # image only + result = model(pixel_values=pixel_values) + self.parent.assertEqual( + result.last_hidden_state.shape, (self.batch_size, self.image_seq_length, self.hidden_size) + ) + + def create_and_check_for_sequence_classification( + self, config, input_ids, bbox, pixel_values, token_type_ids, input_mask, sequence_labels, token_labels + ): + config.num_labels = self.num_labels + model = LayoutLMv3ForSequenceClassification(config) + model.to(torch_device) + model.eval() + result = model( + input_ids, + bbox=bbox, + pixel_values=pixel_values, + attention_mask=input_mask, + token_type_ids=token_type_ids, + labels=sequence_labels, + ) + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels)) + + def create_and_check_for_token_classification( + self, config, input_ids, bbox, pixel_values, token_type_ids, input_mask, sequence_labels, token_labels + ): + config.num_labels = self.num_labels + model = LayoutLMv3ForTokenClassification(config=config) + model.to(torch_device) + model.eval() + result = model( + input_ids, + bbox=bbox, + pixel_values=pixel_values, + attention_mask=input_mask, + token_type_ids=token_type_ids, + labels=token_labels, + ) + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.text_seq_length, self.num_labels)) + + def create_and_check_for_question_answering( + self, config, input_ids, bbox, pixel_values, token_type_ids, input_mask, sequence_labels, token_labels + ): + model = LayoutLMv3ForQuestionAnswering(config=config) + model.to(torch_device) + model.eval() + result = model( + input_ids, + bbox=bbox, + pixel_values=pixel_values, + attention_mask=input_mask, + token_type_ids=token_type_ids, + start_positions=sequence_labels, + end_positions=sequence_labels, + ) + self.parent.assertEqual(result.start_logits.shape, (self.batch_size, self.seq_length)) + self.parent.assertEqual(result.end_logits.shape, (self.batch_size, self.seq_length)) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + ( + config, + input_ids, + bbox, + pixel_values, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + ) = config_and_inputs + inputs_dict = { + "input_ids": input_ids, + "bbox": bbox, + "pixel_values": pixel_values, + "token_type_ids": token_type_ids, + "attention_mask": input_mask, + } + return config, inputs_dict + + +@require_torch +class LayoutLMv3ModelTest(ModelTesterMixin, unittest.TestCase): + + test_pruning = False + test_torchscript = False + test_mismatched_shapes = False + + all_model_classes = ( + ( + LayoutLMv3Model, + LayoutLMv3ForSequenceClassification, + LayoutLMv3ForTokenClassification, + LayoutLMv3ForQuestionAnswering, + ) + if is_torch_available() + else () + ) + + def setUp(self): + self.model_tester = LayoutLMv3ModelTester(self) + self.config_tester = ConfigTester(self, config_class=LayoutLMv3Config, hidden_size=37) + + def _prepare_for_class(self, inputs_dict, model_class, return_labels=False): + inputs_dict = copy.deepcopy(inputs_dict) + if model_class in get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING): + inputs_dict = { + k: v.unsqueeze(1).expand(-1, self.model_tester.num_choices, -1).contiguous() + if isinstance(v, torch.Tensor) and v.ndim > 1 + else v + for k, v in inputs_dict.items() + } + if return_labels: + if model_class in get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING): + inputs_dict["labels"] = torch.ones(self.model_tester.batch_size, dtype=torch.long, device=torch_device) + elif model_class in get_values(MODEL_FOR_QUESTION_ANSWERING_MAPPING): + inputs_dict["start_positions"] = torch.zeros( + self.model_tester.batch_size, dtype=torch.long, device=torch_device + ) + inputs_dict["end_positions"] = torch.zeros( + self.model_tester.batch_size, dtype=torch.long, device=torch_device + ) + elif model_class in [ + *get_values(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING), + ]: + inputs_dict["labels"] = torch.zeros( + self.model_tester.batch_size, dtype=torch.long, device=torch_device + ) + elif model_class in [ + *get_values(MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING), + ]: + inputs_dict["labels"] = torch.zeros( + (self.model_tester.batch_size, self.model_tester.text_seq_length), + dtype=torch.long, + device=torch_device, + ) + + return inputs_dict + + def test_config(self): + self.config_tester.run_common_tests() + + def test_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model(*config_and_inputs) + + def test_model_various_embeddings(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + for type in ["absolute", "relative_key", "relative_key_query"]: + config_and_inputs[0].position_embedding_type = type + self.model_tester.create_and_check_model(*config_and_inputs) + + def test_for_sequence_classification(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_for_sequence_classification(*config_and_inputs) + + def test_for_token_classification(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_for_token_classification(*config_and_inputs) + + def test_for_question_answering(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_for_question_answering(*config_and_inputs) + + @slow + def test_model_from_pretrained(self): + for model_name in LAYOUTLMV3_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: + model = LayoutLMv3Model.from_pretrained(model_name) + self.assertIsNotNone(model) + + +# We will verify our results on an image of cute cats +def prepare_img(): + image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png") + return image + + +@require_torch +class LayoutLMv3ModelIntegrationTest(unittest.TestCase): + @cached_property + def default_feature_extractor(self): + return LayoutLMv3FeatureExtractor(apply_ocr=False) if is_vision_available() else None + + @slow + def test_inference_no_head(self): + model = LayoutLMv3Model.from_pretrained("microsoft/layoutlmv3-base").to(torch_device) + + feature_extractor = self.default_feature_extractor + image = prepare_img() + pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values.to(torch_device) + + input_ids = torch.tensor([[1, 2]]) + bbox = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]]).unsqueeze(0) + + # forward pass + outputs = model( + input_ids=input_ids.to(torch_device), + bbox=bbox.to(torch_device), + pixel_values=pixel_values.to(torch_device), + ) + + # verify the logits + expected_shape = torch.Size((1, 199, 768)) + self.assertEqual(outputs.last_hidden_state.shape, expected_shape) + + expected_slice = torch.tensor( + [[-0.0529, 0.3618, 0.1632], [-0.1587, -0.1667, -0.0400], [-0.1557, -0.1671, -0.0505]] + ).to(torch_device) + + self.assertTrue(torch.allclose(outputs.last_hidden_state[0, :3, :3], expected_slice, atol=1e-4)) diff --git a/tests/models/layoutlmv3/test_processor_layoutlmv3.py b/tests/models/layoutlmv3/test_processor_layoutlmv3.py new file mode 100644 index 0000000000000..a01b0a00cd904 --- /dev/null +++ b/tests/models/layoutlmv3/test_processor_layoutlmv3.py @@ -0,0 +1,446 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os +import shutil +import tempfile +import unittest +from typing import List + +from transformers import PreTrainedTokenizer, PreTrainedTokenizerBase, PreTrainedTokenizerFast +from transformers.models.layoutlmv3 import LayoutLMv3Tokenizer, LayoutLMv3TokenizerFast +from transformers.models.layoutlmv3.tokenization_layoutlmv3 import VOCAB_FILES_NAMES +from transformers.testing_utils import require_pytesseract, require_tokenizers, require_torch, slow +from transformers.utils import FEATURE_EXTRACTOR_NAME, cached_property, is_pytesseract_available + + +if is_pytesseract_available(): + from PIL import Image + + from transformers import LayoutLMv3FeatureExtractor, LayoutLMv3Processor + + +@require_pytesseract +@require_tokenizers +class LayoutLMv3ProcessorTest(unittest.TestCase): + tokenizer_class = LayoutLMv3Tokenizer + rust_tokenizer_class = LayoutLMv3TokenizerFast + + def setUp(self): + # Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt + vocab = [ + "l", + "o", + "w", + "e", + "r", + "s", + "t", + "i", + "d", + "n", + "\u0120", + "\u0120l", + "\u0120n", + "\u0120lo", + "\u0120low", + "er", + "\u0120lowest", + "\u0120newer", + "\u0120wider", + "", + ] + self.tmpdirname = tempfile.mkdtemp() + vocab_tokens = dict(zip(vocab, range(len(vocab)))) + merges = ["#version: 0.2", "\u0120 l", "\u0120l o", "\u0120lo w", "e r", ""] + self.special_tokens_map = {"unk_token": ""} + + self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["vocab_file"]) + self.merges_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["merges_file"]) + with open(self.vocab_file, "w", encoding="utf-8") as fp: + fp.write(json.dumps(vocab_tokens) + "\n") + with open(self.merges_file, "w", encoding="utf-8") as fp: + fp.write("\n".join(merges)) + + feature_extractor_map = { + "do_resize": True, + "size": 224, + "apply_ocr": True, + } + + self.feature_extraction_file = os.path.join(self.tmpdirname, FEATURE_EXTRACTOR_NAME) + with open(self.feature_extraction_file, "w", encoding="utf-8") as fp: + fp.write(json.dumps(feature_extractor_map) + "\n") + + def get_tokenizer(self, **kwargs) -> PreTrainedTokenizer: + return self.tokenizer_class.from_pretrained(self.tmpdirname, **kwargs) + + def get_rust_tokenizer(self, **kwargs) -> PreTrainedTokenizerFast: + return self.rust_tokenizer_class.from_pretrained(self.tmpdirname, **kwargs) + + def get_tokenizers(self, **kwargs) -> List[PreTrainedTokenizerBase]: + return [self.get_tokenizer(**kwargs), self.get_rust_tokenizer(**kwargs)] + + def get_feature_extractor(self, **kwargs): + return LayoutLMv3FeatureExtractor.from_pretrained(self.tmpdirname, **kwargs) + + def tearDown(self): + shutil.rmtree(self.tmpdirname) + + def test_save_load_pretrained_default(self): + feature_extractor = self.get_feature_extractor() + tokenizers = self.get_tokenizers() + for tokenizer in tokenizers: + processor = LayoutLMv3Processor(feature_extractor=feature_extractor, tokenizer=tokenizer) + + processor.save_pretrained(self.tmpdirname) + processor = LayoutLMv3Processor.from_pretrained(self.tmpdirname) + + self.assertEqual(processor.tokenizer.get_vocab(), tokenizer.get_vocab()) + self.assertIsInstance(processor.tokenizer, (LayoutLMv3Tokenizer, LayoutLMv3TokenizerFast)) + + self.assertEqual(processor.feature_extractor.to_json_string(), feature_extractor.to_json_string()) + self.assertIsInstance(processor.feature_extractor, LayoutLMv3FeatureExtractor) + + def test_save_load_pretrained_additional_features(self): + processor = LayoutLMv3Processor(feature_extractor=self.get_feature_extractor(), tokenizer=self.get_tokenizer()) + processor.save_pretrained(self.tmpdirname) + + # slow tokenizer + tokenizer_add_kwargs = self.get_tokenizer(bos_token="(BOS)", eos_token="(EOS)") + feature_extractor_add_kwargs = self.get_feature_extractor(do_resize=False, size=30) + + processor = LayoutLMv3Processor.from_pretrained( + self.tmpdirname, use_fast=False, bos_token="(BOS)", eos_token="(EOS)", do_resize=False, size=30 + ) + + self.assertEqual(processor.tokenizer.get_vocab(), tokenizer_add_kwargs.get_vocab()) + self.assertIsInstance(processor.tokenizer, LayoutLMv3Tokenizer) + + self.assertEqual(processor.feature_extractor.to_json_string(), feature_extractor_add_kwargs.to_json_string()) + self.assertIsInstance(processor.feature_extractor, LayoutLMv3FeatureExtractor) + + # fast tokenizer + tokenizer_add_kwargs = self.get_rust_tokenizer(bos_token="(BOS)", eos_token="(EOS)") + feature_extractor_add_kwargs = self.get_feature_extractor(do_resize=False, size=30) + + processor = LayoutLMv3Processor.from_pretrained( + self.tmpdirname, bos_token="(BOS)", eos_token="(EOS)", do_resize=False, size=30 + ) + + self.assertEqual(processor.tokenizer.get_vocab(), tokenizer_add_kwargs.get_vocab()) + self.assertIsInstance(processor.tokenizer, LayoutLMv3TokenizerFast) + + self.assertEqual(processor.feature_extractor.to_json_string(), feature_extractor_add_kwargs.to_json_string()) + self.assertIsInstance(processor.feature_extractor, LayoutLMv3FeatureExtractor) + + +# different use cases tests +@require_torch +@require_pytesseract +class LayoutLMv3ProcessorIntegrationTests(unittest.TestCase): + @cached_property + def get_images(self): + # we verify our implementation on 2 document images from the DocVQA dataset + from datasets import load_dataset + + ds = load_dataset("hf-internal-testing/fixtures_docvqa", split="test") + + image_1 = Image.open(ds[0]["file"]).convert("RGB") + image_2 = Image.open(ds[1]["file"]).convert("RGB") + + return image_1, image_2 + + @cached_property + def get_tokenizers(self): + slow_tokenizer = LayoutLMv3Tokenizer.from_pretrained("microsoft/layoutlmv3-base", add_visual_labels=False) + fast_tokenizer = LayoutLMv3TokenizerFast.from_pretrained("microsoft/layoutlmv3-base", add_visual_labels=False) + return [slow_tokenizer, fast_tokenizer] + + @slow + def test_processor_case_1(self): + # case 1: document image classification (training, inference) + token classification (inference), apply_ocr = True + + feature_extractor = LayoutLMv3FeatureExtractor() + tokenizers = self.get_tokenizers + images = self.get_images + + for tokenizer in tokenizers: + processor = LayoutLMv3Processor(feature_extractor=feature_extractor, tokenizer=tokenizer) + + # not batched + input_feat_extract = feature_extractor(images[0], return_tensors="pt") + input_processor = processor(images[0], return_tensors="pt") + + # verify keys + expected_keys = ["attention_mask", "bbox", "input_ids", "pixel_values"] + actual_keys = sorted(list(input_processor.keys())) + self.assertListEqual(actual_keys, expected_keys) + + # verify image + self.assertAlmostEqual( + input_feat_extract["pixel_values"].sum(), input_processor["pixel_values"].sum(), delta=1e-2 + ) + + # verify input_ids + # this was obtained with Tesseract 4.1.1 + # fmt: off + expected_decoding = " 11:14 to 11:39 a.m 11:39 to 11:44 a.m. 11:44 a.m. to 12:25 p.m. 12:25 to 12:58 p.m. 12:58 to 4:00 p.m. 2:00 to 5:00 p.m. Coffee Break Coffee will be served for men and women in the lobby adjacent to exhibit area. Please move into exhibit area. (Exhibits Open) TRRF GENERAL SESSION (PART |) Presiding: Lee A. Waller TRRF Vice President “Introductory Remarks” Lee A. Waller, TRRF Vice Presi- dent Individual Interviews with TRRF Public Board Members and Sci- entific Advisory Council Mem- bers Conducted by TRRF Treasurer Philip G. Kuehn to get answers which the public refrigerated warehousing industry is looking for. Plus questions from the floor. Dr. Emil M. Mrak, University of Cal- ifornia, Chairman, TRRF Board; Sam R. Cecil, University of Georgia College of Agriculture; Dr. Stanley Charm, Tufts University School of Medicine; Dr. Robert H. Cotton, ITT Continental Baking Company; Dr. Owen Fennema, University of Wis- consin; Dr. Robert E. Hardenburg, USDA. Questions and Answers Exhibits Open Capt. Jack Stoney Room TRRF Scientific Advisory Council Meeting Ballroom Foyer" # noqa: E231 + # fmt: on + decoding = processor.decode(input_processor.input_ids.squeeze().tolist()) + self.assertSequenceEqual(decoding, expected_decoding) + + # batched + input_feat_extract = feature_extractor(images, return_tensors="pt") + input_processor = processor(images, padding=True, return_tensors="pt") + + # verify keys + expected_keys = ["attention_mask", "bbox", "input_ids", "pixel_values"] + actual_keys = sorted(list(input_processor.keys())) + self.assertListEqual(actual_keys, expected_keys) + + # verify images + self.assertAlmostEqual( + input_feat_extract["pixel_values"].sum(), input_processor["pixel_values"].sum(), delta=1e-2 + ) + + # verify input_ids + # this was obtained with Tesseract 4.1.1 + # fmt: off + expected_decoding = " 7 ITC Limited REPORT AND ACCOUNTS 2013 ITC’s Brands: An Asset for the Nation The consumer needs and aspirations they fulfil, the benefit they generate for millions across ITC’s value chains, the future-ready capabilities that support them, and the value that they create for the country, have made ITC’s brands national assets, adding to India’s competitiveness. It is ITC’s aspiration to be the No 1 FMCG player in the country, driven by its new FMCG businesses. A recent Nielsen report has highlighted that ITC's new FMCG businesses are the fastest growing among the top consumer goods companies operating in India. ITC takes justifiable pride that, along with generating economic value, these celebrated Indian brands also drive the creation of larger societal capital through the virtuous cycle of sustainable and inclusive growth. DI WILLS * ; LOVE DELIGHTFULLY SOFT SKIN? aia Ans Source: https://www.industrydocuments.ucsf.edu/docs/snbx0223" # noqa: E231 + # fmt: on + decoding = processor.decode(input_processor.input_ids[1].tolist()) + self.assertSequenceEqual(decoding, expected_decoding) + + @slow + def test_processor_case_2(self): + # case 2: document image classification (training, inference) + token classification (inference), apply_ocr=False + + feature_extractor = LayoutLMv3FeatureExtractor(apply_ocr=False) + tokenizers = self.get_tokenizers + images = self.get_images + + for tokenizer in tokenizers: + processor = LayoutLMv3Processor(feature_extractor=feature_extractor, tokenizer=tokenizer) + + # not batched + words = ["hello", "world"] + boxes = [[1, 2, 3, 4], [5, 6, 7, 8]] + input_processor = processor(images[0], words, boxes=boxes, return_tensors="pt") + + # verify keys + expected_keys = ["input_ids", "bbox", "attention_mask", "pixel_values"] + actual_keys = list(input_processor.keys()) + for key in expected_keys: + self.assertIn(key, actual_keys) + + # verify input_ids + expected_decoding = " hello world" + decoding = processor.decode(input_processor.input_ids.squeeze().tolist()) + self.assertSequenceEqual(decoding, expected_decoding) + + # batched + words = [["hello", "world"], ["my", "name", "is", "niels"]] + boxes = [[[1, 2, 3, 4], [5, 6, 7, 8]], [[3, 2, 5, 1], [6, 7, 4, 2], [3, 9, 2, 4], [1, 1, 2, 3]]] + input_processor = processor(images, words, boxes=boxes, padding=True, return_tensors="pt") + + # verify keys + expected_keys = ["attention_mask", "bbox", "input_ids", "pixel_values"] + actual_keys = sorted(list(input_processor.keys())) + self.assertListEqual(actual_keys, expected_keys) + + # verify input_ids + expected_decoding = " hello world" + decoding = processor.decode(input_processor.input_ids[0].tolist()) + self.assertSequenceEqual(decoding, expected_decoding) + + # verify bbox + expected_bbox = [ + [0, 0, 0, 0], + [3, 2, 5, 1], + [6, 7, 4, 2], + [3, 9, 2, 4], + [1, 1, 2, 3], + [1, 1, 2, 3], + [0, 0, 0, 0], + ] + self.assertListEqual(input_processor.bbox[1].tolist(), expected_bbox) + + @slow + def test_processor_case_3(self): + # case 3: token classification (training), apply_ocr=False + + feature_extractor = LayoutLMv3FeatureExtractor(apply_ocr=False) + tokenizers = self.get_tokenizers + images = self.get_images + + for tokenizer in tokenizers: + processor = LayoutLMv3Processor(feature_extractor=feature_extractor, tokenizer=tokenizer) + + # not batched + words = ["weirdly", "world"] + boxes = [[1, 2, 3, 4], [5, 6, 7, 8]] + word_labels = [1, 2] + input_processor = processor(images[0], words, boxes=boxes, word_labels=word_labels, return_tensors="pt") + + # verify keys + expected_keys = ["attention_mask", "bbox", "input_ids", "labels", "pixel_values"] + actual_keys = sorted(list(input_processor.keys())) + self.assertListEqual(actual_keys, expected_keys) + + # verify input_ids + expected_decoding = " weirdly world" + decoding = processor.decode(input_processor.input_ids.squeeze().tolist()) + self.assertSequenceEqual(decoding, expected_decoding) + + # verify labels + expected_labels = [-100, 1, -100, 2, -100] + self.assertListEqual(input_processor.labels.squeeze().tolist(), expected_labels) + + # batched + words = [["hello", "world"], ["my", "name", "is", "niels"]] + boxes = [[[1, 2, 3, 4], [5, 6, 7, 8]], [[3, 2, 5, 1], [6, 7, 4, 2], [3, 9, 2, 4], [1, 1, 2, 3]]] + word_labels = [[1, 2], [6, 3, 10, 2]] + input_processor = processor( + images, words, boxes=boxes, word_labels=word_labels, padding=True, return_tensors="pt" + ) + + # verify keys + expected_keys = ["attention_mask", "bbox", "input_ids", "labels", "pixel_values"] + actual_keys = sorted(list(input_processor.keys())) + self.assertListEqual(actual_keys, expected_keys) + + # verify input_ids + expected_decoding = " my name is niels" + decoding = processor.decode(input_processor.input_ids[1].tolist()) + self.assertSequenceEqual(decoding, expected_decoding) + + # verify bbox + expected_bbox = [ + [0, 0, 0, 0], + [3, 2, 5, 1], + [6, 7, 4, 2], + [3, 9, 2, 4], + [1, 1, 2, 3], + [1, 1, 2, 3], + [0, 0, 0, 0], + ] + self.assertListEqual(input_processor.bbox[1].tolist(), expected_bbox) + + # verify labels + expected_labels = [-100, 6, 3, 10, 2, -100, -100] + self.assertListEqual(input_processor.labels[1].tolist(), expected_labels) + + @slow + def test_processor_case_4(self): + # case 4: visual question answering (inference), apply_ocr=True + + feature_extractor = LayoutLMv3FeatureExtractor() + tokenizers = self.get_tokenizers + images = self.get_images + + for tokenizer in tokenizers: + processor = LayoutLMv3Processor(feature_extractor=feature_extractor, tokenizer=tokenizer) + + # not batched + question = "What's his name?" + input_processor = processor(images[0], question, return_tensors="pt") + + # verify keys + expected_keys = ["attention_mask", "bbox", "input_ids", "pixel_values"] + actual_keys = sorted(list(input_processor.keys())) + self.assertListEqual(actual_keys, expected_keys) + + # verify input_ids + # this was obtained with Tesseract 4.1.1 + # fmt: off + expected_decoding = " What's his name? 11:14 to 11:39 a.m 11:39 to 11:44 a.m. 11:44 a.m. to 12:25 p.m. 12:25 to 12:58 p.m. 12:58 to 4:00 p.m. 2:00 to 5:00 p.m. Coffee Break Coffee will be served for men and women in the lobby adjacent to exhibit area. Please move into exhibit area. (Exhibits Open) TRRF GENERAL SESSION (PART |) Presiding: Lee A. Waller TRRF Vice President “Introductory Remarks” Lee A. Waller, TRRF Vice Presi- dent Individual Interviews with TRRF Public Board Members and Sci- entific Advisory Council Mem- bers Conducted by TRRF Treasurer Philip G. Kuehn to get answers which the public refrigerated warehousing industry is looking for. Plus questions from the floor. Dr. Emil M. Mrak, University of Cal- ifornia, Chairman, TRRF Board; Sam R. Cecil, University of Georgia College of Agriculture; Dr. Stanley Charm, Tufts University School of Medicine; Dr. Robert H. Cotton, ITT Continental Baking Company; Dr. Owen Fennema, University of Wis- consin; Dr. Robert E. Hardenburg, USDA. Questions and Answers Exhibits Open Capt. Jack Stoney Room TRRF Scientific Advisory Council Meeting Ballroom Foyer" # noqa: E231 + # fmt: on + decoding = processor.decode(input_processor.input_ids.squeeze().tolist()) + self.assertSequenceEqual(decoding, expected_decoding) + + # batched + questions = ["How old is he?", "what's the time"] + input_processor = processor( + images, questions, padding="max_length", max_length=20, truncation=True, return_tensors="pt" + ) + + # verify keys + expected_keys = ["attention_mask", "bbox", "input_ids", "pixel_values"] + actual_keys = sorted(list(input_processor.keys())) + self.assertListEqual(actual_keys, expected_keys) + + # verify input_ids + # this was obtained with Tesseract 4.1.1 + expected_decoding = " what's the time 7 ITC Limited REPORT AND ACCOUNTS 2013 ITC" + decoding = processor.decode(input_processor.input_ids[1].tolist()) + self.assertSequenceEqual(decoding, expected_decoding) + + # verify bbox + # fmt: off + expected_bbox = [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 45, 67, 80], [72, 56, 109, 67], [72, 56, 109, 67], [116, 56, 189, 67], [198, 59, 253, 66], [257, 59, 285, 66], [289, 59, 365, 66], [289, 59, 365, 66], [289, 59, 365, 66], [372, 59, 407, 66], [74, 136, 161, 158], [74, 136, 161, 158], [0, 0, 0, 0]] # noqa: E231 + # fmt: on + self.assertListEqual(input_processor.bbox[1].tolist(), expected_bbox) + + @slow + def test_processor_case_5(self): + # case 5: visual question answering (inference), apply_ocr=False + + feature_extractor = LayoutLMv3FeatureExtractor(apply_ocr=False) + tokenizers = self.get_tokenizers + images = self.get_images + + for tokenizer in tokenizers: + processor = LayoutLMv3Processor(feature_extractor=feature_extractor, tokenizer=tokenizer) + + # not batched + question = "What's his name?" + words = ["hello", "world"] + boxes = [[1, 2, 3, 4], [5, 6, 7, 8]] + input_processor = processor(images[0], question, words, boxes, return_tensors="pt") + + # verify keys + expected_keys = ["attention_mask", "bbox", "input_ids", "pixel_values"] + actual_keys = sorted(list(input_processor.keys())) + self.assertListEqual(actual_keys, expected_keys) + + # verify input_ids + expected_decoding = " What's his name? hello world" + decoding = processor.decode(input_processor.input_ids.squeeze().tolist()) + self.assertSequenceEqual(decoding, expected_decoding) + + # batched + questions = ["How old is he?", "what's the time"] + words = [["hello", "world"], ["my", "name", "is", "niels"]] + boxes = [[[1, 2, 3, 4], [5, 6, 7, 8]], [[3, 2, 5, 1], [6, 7, 4, 2], [3, 9, 2, 4], [1, 1, 2, 3]]] + input_processor = processor(images, questions, words, boxes, padding=True, return_tensors="pt") + + # verify keys + expected_keys = ["attention_mask", "bbox", "input_ids", "pixel_values"] + actual_keys = sorted(list(input_processor.keys())) + self.assertListEqual(actual_keys, expected_keys) + + # verify input_ids + expected_decoding = " How old is he? hello world" + decoding = processor.decode(input_processor.input_ids[0].tolist()) + self.assertSequenceEqual(decoding, expected_decoding) + + expected_decoding = " what's the time my name is niels" + decoding = processor.decode(input_processor.input_ids[1].tolist()) + self.assertSequenceEqual(decoding, expected_decoding) + + # verify bbox + expected_bbox = [[6, 7, 4, 2], [3, 9, 2, 4], [1, 1, 2, 3], [1, 1, 2, 3], [0, 0, 0, 0]] + self.assertListEqual(input_processor.bbox[1].tolist()[-5:], expected_bbox) diff --git a/tests/models/layoutlmv3/test_tokenization_layoutlmv3.py b/tests/models/layoutlmv3/test_tokenization_layoutlmv3.py new file mode 100644 index 0000000000000..ae12129e787f9 --- /dev/null +++ b/tests/models/layoutlmv3/test_tokenization_layoutlmv3.py @@ -0,0 +1,2345 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import json +import os +import re +import shutil +import tempfile +import unittest +from typing import List + +from transformers import AddedToken, LayoutLMv3TokenizerFast, SpecialTokensMixin, is_tf_available, is_torch_available +from transformers.models.layoutlmv3.tokenization_layoutlmv3 import VOCAB_FILES_NAMES, LayoutLMv3Tokenizer +from transformers.testing_utils import is_pt_tf_cross_test, require_pandas, require_tokenizers, require_torch, slow + +from ...test_tokenization_common import SMALL_TRAINING_CORPUS, TokenizerTesterMixin, merge_model_tokenizer_mappings + + +@require_tokenizers +@require_pandas +class LayoutLMv3TokenizationTest(TokenizerTesterMixin, unittest.TestCase): + tokenizer_class = LayoutLMv3Tokenizer + rust_tokenizer_class = LayoutLMv3TokenizerFast + test_rust_tokenizer = True + # determined by the tokenization algortihm and the way it's decoded by the fast tokenizers + space_between_special_tokens = False + test_seq2seq = False + from_pretrained_kwargs = {"cls_token": ""} + + def get_words_and_boxes(self): + words = ["lower", "newer"] + boxes = [[423, 237, 440, 251], [427, 272, 441, 287]] + + return words, boxes + + def get_words_and_boxes_batch(self): + words = [["lower", "newer"], ["new", "low"]] + boxes = [ + [[423, 237, 440, 251], [427, 272, 441, 287]], + [[961, 885, 992, 912], [256, 38, 330, 58]], + ] + + return words, boxes + + def get_question_words_and_boxes(self): + question = "what's his name?" + words = ["lower", "newer"] + boxes = [[423, 237, 440, 251], [427, 272, 441, 287]] + + return question, words, boxes + + def get_question_words_and_boxes_batch(self): + questions = ["what's his name?", "how is he called?"] + words = [["lower", "newer"], ["newer", "lower"]] + boxes = [ + [[423, 237, 440, 251], [427, 272, 441, 287]], + [[256, 38, 330, 58], [256, 38, 330, 58]], + ] + + return questions, words, boxes + + def setUp(self): + super().setUp() + + # Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt + vocab = [ + "l", + "o", + "w", + "e", + "r", + "s", + "t", + "i", + "d", + "n", + "\u0120", + "\u0120l", + "\u0120n", + "\u0120lo", + "\u0120low", + "er", + "\u0120lowest", + "\u0120newer", + "\u0120wider", + "", + ] + vocab_tokens = dict(zip(vocab, range(len(vocab)))) + merges = ["#version: 0.2", "\u0120 l", "\u0120l o", "\u0120lo w", "e r", ""] + self.special_tokens_map = {"unk_token": ""} + + self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["vocab_file"]) + self.merges_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["merges_file"]) + with open(self.vocab_file, "w", encoding="utf-8") as fp: + fp.write(json.dumps(vocab_tokens) + "\n") + with open(self.merges_file, "w", encoding="utf-8") as fp: + fp.write("\n".join(merges)) + + def get_tokenizer(self, **kwargs): + kwargs.update(self.special_tokens_map) + return self.tokenizer_class.from_pretrained(self.tmpdirname, **kwargs) + + def get_rust_tokenizer(self, **kwargs): + kwargs.update(self.special_tokens_map) + return LayoutLMv3TokenizerFast.from_pretrained(self.tmpdirname, **kwargs) + + def get_input_output_texts(self, tokenizer): + input_text = "lower newer" + output_text = "lower newer" + return input_text, output_text + + def test_full_tokenizer(self): + tokenizer = self.tokenizer_class(self.vocab_file, self.merges_file, **self.special_tokens_map) + text = "lower newer" + bpe_tokens = ["Ġlow", "er", "Ġ", "n", "e", "w", "er"] + tokens = tokenizer.tokenize(text) # , add_prefix_space=True) + self.assertListEqual(tokens, bpe_tokens) + + input_tokens = tokens + [tokenizer.unk_token] + input_bpe_tokens = [14, 15, 10, 9, 3, 2, 15, 19] + self.assertListEqual(tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens) + + @slow + def test_sequence_builders(self): + tokenizer = self.tokenizer_class.from_pretrained("microsoft/layoutlmv3-base") + + question, words, boxes = self.get_question_words_and_boxes() + + text = tokenizer.encode( + question.split(), + boxes=[tokenizer.pad_token_box for _ in range(len(question.split()))], + add_special_tokens=False, + ) + text_2 = tokenizer.encode(words, boxes=boxes, add_special_tokens=False) + + encoded_pair = tokenizer.build_inputs_with_special_tokens(text, text_2) + + assert encoded_pair == [0] + text + [2] + [2] + text_2 + [2] + + def test_add_special_tokens(self): + tokenizers: List[LayoutLMv3Tokenizer] = self.get_tokenizers(do_lower_case=False) + for tokenizer in tokenizers: + with self.subTest(f"{tokenizer.__class__.__name__}"): + + special_token = "[SPECIAL_TOKEN]" + special_token_box = [1000, 1000, 1000, 1000] + + tokenizer.add_special_tokens({"cls_token": special_token}) + encoded_special_token = tokenizer.encode( + [special_token], boxes=[special_token_box], add_special_tokens=False + ) + self.assertEqual(len(encoded_special_token), 1) + + decoded = tokenizer.decode(encoded_special_token, skip_special_tokens=True) + self.assertTrue(special_token not in decoded) + + def test_add_tokens_tokenizer(self): + tokenizers: List[LayoutLMv3Tokenizer] = self.get_tokenizers(do_lower_case=False) + for tokenizer in tokenizers: + with self.subTest(f"{tokenizer.__class__.__name__}"): + vocab_size = tokenizer.vocab_size + all_size = len(tokenizer) + + self.assertNotEqual(vocab_size, 0) + + # We usually have added tokens from the start in tests because our vocab fixtures are + # smaller than the original vocabs - let's not assert this + # self.assertEqual(vocab_size, all_size) + + new_toks = ["aaaaa", "bbbbbb", "cccccccccdddddddd"] + added_toks = tokenizer.add_tokens(new_toks) + vocab_size_2 = tokenizer.vocab_size + all_size_2 = len(tokenizer) + + self.assertNotEqual(vocab_size_2, 0) + self.assertEqual(vocab_size, vocab_size_2) + self.assertEqual(added_toks, len(new_toks)) + self.assertEqual(all_size_2, all_size + len(new_toks)) + + words = "aaaaa bbbbbb low cccccccccdddddddd l".split() + boxes = [[1000, 1000, 1000, 1000] for _ in range(len(words))] + + tokens = tokenizer.encode(words, boxes=boxes, add_special_tokens=False) + + self.assertGreaterEqual(len(tokens), 4) + self.assertGreater(tokens[0], tokenizer.vocab_size - 1) + self.assertGreater(tokens[-2], tokenizer.vocab_size - 1) + + new_toks_2 = {"eos_token": ">>>>|||<||<<|<<", "pad_token": "<<<<<|||>|>>>>|>"} + added_toks_2 = tokenizer.add_special_tokens(new_toks_2) + vocab_size_3 = tokenizer.vocab_size + all_size_3 = len(tokenizer) + + self.assertNotEqual(vocab_size_3, 0) + self.assertEqual(vocab_size, vocab_size_3) + self.assertEqual(added_toks_2, len(new_toks_2)) + self.assertEqual(all_size_3, all_size_2 + len(new_toks_2)) + + words = ">>>>|||<||<<|<< aaaaabbbbbb low cccccccccdddddddd <<<<<|||>|>>>>|> l".split() + boxes = [[1000, 1000, 1000, 1000] for _ in range(len(words))] + + tokens = tokenizer.encode( + words, + boxes=boxes, + add_special_tokens=False, + ) + + self.assertGreaterEqual(len(tokens), 6) + self.assertGreater(tokens[0], tokenizer.vocab_size - 1) + self.assertGreater(tokens[0], tokens[1]) + self.assertGreater(tokens[-2], tokenizer.vocab_size - 1) + self.assertGreater(tokens[-2], tokens[-3]) + self.assertEqual(tokens[0], tokenizer.eos_token_id) + self.assertEqual(tokens[-2], tokenizer.pad_token_id) + + @require_tokenizers + def test_encode_decode_with_spaces(self): + tokenizers = self.get_tokenizers(do_lower_case=False) + for tokenizer in tokenizers: + with self.subTest(f"{tokenizer.__class__.__name__}"): + words, boxes = self.get_words_and_boxes() + + new_toks = [AddedToken("[ABC]", normalized=False), AddedToken("[DEF]", normalized=False)] + tokenizer.add_tokens(new_toks) + input = "[ABC][DEF][ABC][DEF]" + if self.space_between_special_tokens: + output = "[ABC] [DEF] [ABC] [DEF]" + else: + output = input + encoded = tokenizer.encode(input.split(), boxes=boxes, add_special_tokens=False) + decoded = tokenizer.decode(encoded, spaces_between_special_tokens=self.space_between_special_tokens) + self.assertIn(decoded, [output, output.lower()]) + + @unittest.skip("Not implemented") + def test_right_and_left_truncation(self): + pass + + def test_encode_plus_with_padding(self): + tokenizers = self.get_tokenizers(do_lower_case=False) + for tokenizer in tokenizers: + with self.subTest(f"{tokenizer.__class__.__name__}"): + words, boxes = self.get_words_and_boxes() + + # check correct behaviour if no pad_token_id exists and add it eventually + self._check_no_pad_token_padding(tokenizer, words) + + padding_size = 10 + padding_idx = tokenizer.pad_token_id + + encoded_sequence = tokenizer.encode_plus(words, boxes=boxes, return_special_tokens_mask=True) + input_ids = encoded_sequence["input_ids"] + special_tokens_mask = encoded_sequence["special_tokens_mask"] + sequence_length = len(input_ids) + + # Test 'longest' and 'no_padding' don't do anything + tokenizer.padding_side = "right" + + not_padded_sequence = tokenizer.encode_plus( + words, + boxes=boxes, + padding=False, + return_special_tokens_mask=True, + ) + not_padded_input_ids = not_padded_sequence["input_ids"] + + not_padded_special_tokens_mask = not_padded_sequence["special_tokens_mask"] + not_padded_sequence_length = len(not_padded_input_ids) + + self.assertTrue(sequence_length == not_padded_sequence_length) + self.assertTrue(input_ids == not_padded_input_ids) + self.assertTrue(special_tokens_mask == not_padded_special_tokens_mask) + + not_padded_sequence = tokenizer.encode_plus( + words, + boxes=boxes, + padding=False, + return_special_tokens_mask=True, + ) + not_padded_input_ids = not_padded_sequence["input_ids"] + + not_padded_special_tokens_mask = not_padded_sequence["special_tokens_mask"] + not_padded_sequence_length = len(not_padded_input_ids) + + self.assertTrue(sequence_length == not_padded_sequence_length) + self.assertTrue(input_ids == not_padded_input_ids) + self.assertTrue(special_tokens_mask == not_padded_special_tokens_mask) + + # Test right padding + tokenizer.padding_side = "right" + + right_padded_sequence = tokenizer.encode_plus( + words, + boxes=boxes, + max_length=sequence_length + padding_size, + padding="max_length", + return_special_tokens_mask=True, + ) + right_padded_input_ids = right_padded_sequence["input_ids"] + + right_padded_special_tokens_mask = right_padded_sequence["special_tokens_mask"] + right_padded_sequence_length = len(right_padded_input_ids) + + self.assertTrue(sequence_length + padding_size == right_padded_sequence_length) + self.assertTrue(input_ids + [padding_idx] * padding_size == right_padded_input_ids) + self.assertTrue(special_tokens_mask + [1] * padding_size == right_padded_special_tokens_mask) + + # Test left padding + tokenizer.padding_side = "left" + left_padded_sequence = tokenizer.encode_plus( + words, + boxes=boxes, + max_length=sequence_length + padding_size, + padding="max_length", + return_special_tokens_mask=True, + ) + left_padded_input_ids = left_padded_sequence["input_ids"] + left_padded_special_tokens_mask = left_padded_sequence["special_tokens_mask"] + left_padded_sequence_length = len(left_padded_input_ids) + + self.assertTrue(sequence_length + padding_size == left_padded_sequence_length) + self.assertTrue([padding_idx] * padding_size + input_ids == left_padded_input_ids) + self.assertTrue([1] * padding_size + special_tokens_mask == left_padded_special_tokens_mask) + + if "token_type_ids" in tokenizer.model_input_names: + token_type_ids = encoded_sequence["token_type_ids"] + left_padded_token_type_ids = left_padded_sequence["token_type_ids"] + right_padded_token_type_ids = right_padded_sequence["token_type_ids"] + + assert token_type_ids + [0] * padding_size == right_padded_token_type_ids + assert [0] * padding_size + token_type_ids == left_padded_token_type_ids + + if "attention_mask" in tokenizer.model_input_names: + attention_mask = encoded_sequence["attention_mask"] + right_padded_attention_mask = right_padded_sequence["attention_mask"] + left_padded_attention_mask = left_padded_sequence["attention_mask"] + + self.assertTrue(attention_mask + [0] * padding_size == right_padded_attention_mask) + self.assertTrue([0] * padding_size + attention_mask == left_padded_attention_mask) + + def test_internal_consistency(self): + tokenizers = self.get_tokenizers() + for tokenizer in tokenizers: + with self.subTest(f"{tokenizer.__class__.__name__}"): + words, boxes = self.get_words_and_boxes() + + tokens = [] + for word in words: + tokens.extend(tokenizer.tokenize(word)) + ids = tokenizer.convert_tokens_to_ids(tokens) + ids_2 = tokenizer.encode(words, boxes=boxes, add_special_tokens=False) + self.assertListEqual(ids, ids_2) + + tokens_2 = tokenizer.convert_ids_to_tokens(ids) + self.assertNotEqual(len(tokens_2), 0) + text_2 = tokenizer.decode(ids) + self.assertIsInstance(text_2, str) + + output_text = " lower newer" + self.assertEqual(text_2, output_text) + + def test_mask_output(self): + tokenizers = self.get_tokenizers(fast=False, do_lower_case=False) + for tokenizer in tokenizers: + with self.subTest(f"{tokenizer.__class__.__name__}"): + words, boxes = self.get_words_and_boxes() + + if ( + tokenizer.build_inputs_with_special_tokens.__qualname__.split(".")[0] != "PreTrainedTokenizer" + and "token_type_ids" in tokenizer.model_input_names + ): + information = tokenizer.encode_plus(words, boxes=boxes, add_special_tokens=True) + sequences, mask = information["input_ids"], information["token_type_ids"] + self.assertEqual(len(sequences), len(mask)) + + def test_number_of_added_tokens(self): + tokenizers = self.get_tokenizers(do_lower_case=False) + for tokenizer in tokenizers: + with self.subTest(f"{tokenizer.__class__.__name__}"): + + # test 1: single sequence + words, boxes = self.get_words_and_boxes() + + sequences = tokenizer.encode(words, boxes=boxes, add_special_tokens=False) + attached_sequences = tokenizer.encode(words, boxes=boxes, add_special_tokens=True) + + # Method is implemented (e.g. not GPT-2) + if len(attached_sequences) != 2: + self.assertEqual( + tokenizer.num_special_tokens_to_add(pair=False), len(attached_sequences) - len(sequences) + ) + + # test 2: two sequences + question, words, boxes = self.get_question_words_and_boxes() + + sequences = tokenizer.encode(question, words, boxes=boxes, add_special_tokens=False) + attached_sequences = tokenizer.encode(question, words, boxes=boxes, add_special_tokens=True) + + # Method is implemented (e.g. not GPT-2) + if len(attached_sequences) != 2: + self.assertEqual( + tokenizer.num_special_tokens_to_add(pair=True), len(attached_sequences) - len(sequences) + ) + + def test_padding_to_max_length(self): + """We keep this test for backward compatibility but it should be removed when `pad_to_max_length` will be deprecated""" + tokenizers = self.get_tokenizers(do_lower_case=False) + for tokenizer in tokenizers: + with self.subTest(f"{tokenizer.__class__.__name__}"): + words, boxes = self.get_words_and_boxes() + padding_size = 10 + + # check correct behaviour if no pad_token_id exists and add it eventually + self._check_no_pad_token_padding(tokenizer, words) + + padding_idx = tokenizer.pad_token_id + + # Check that it correctly pads when a maximum length is specified along with the padding flag set to True + tokenizer.padding_side = "right" + encoded_sequence = tokenizer.encode(words, boxes=boxes) + sequence_length = len(encoded_sequence) + # FIXME: the next line should be padding(max_length) to avoid warning + padded_sequence = tokenizer.encode( + words, boxes=boxes, max_length=sequence_length + padding_size, pad_to_max_length=True + ) + padded_sequence_length = len(padded_sequence) + assert sequence_length + padding_size == padded_sequence_length + assert encoded_sequence + [padding_idx] * padding_size == padded_sequence + + # Check that nothing is done when a maximum length is not specified + encoded_sequence = tokenizer.encode(words, boxes=boxes) + sequence_length = len(encoded_sequence) + + tokenizer.padding_side = "right" + padded_sequence_right = tokenizer.encode(words, boxes=boxes, pad_to_max_length=True) + padded_sequence_right_length = len(padded_sequence_right) + assert sequence_length == padded_sequence_right_length + assert encoded_sequence == padded_sequence_right + + def test_padding(self, max_length=50): + for tokenizer, pretrained_name, kwargs in self.tokenizers_list: + with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"): + tokenizer_r = self.rust_tokenizer_class.from_pretrained(pretrained_name, **kwargs) + tokenizer_p = self.tokenizer_class.from_pretrained(pretrained_name, **kwargs) + + self.assertEqual(tokenizer_p.pad_token_id, tokenizer_r.pad_token_id) + pad_token_id = tokenizer_p.pad_token_id + + # Encode - Simple input + words, boxes = self.get_words_and_boxes() + input_r = tokenizer_r.encode(words, boxes=boxes, max_length=max_length, pad_to_max_length=True) + input_p = tokenizer_p.encode(words, boxes=boxes, max_length=max_length, pad_to_max_length=True) + self.assert_padded_input_match(input_r, input_p, max_length, pad_token_id) + input_r = tokenizer_r.encode(words, boxes=boxes, max_length=max_length, padding="max_length") + input_p = tokenizer_p.encode(words, boxes=boxes, max_length=max_length, padding="max_length") + self.assert_padded_input_match(input_r, input_p, max_length, pad_token_id) + + input_r = tokenizer_r.encode(words, boxes=boxes, padding="longest") + input_p = tokenizer_p.encode(words, boxes=boxes, padding=True) + self.assert_padded_input_match(input_r, input_p, len(input_r), pad_token_id) + + # Encode - Pair input + question, words, boxes = self.get_question_words_and_boxes() + input_r = tokenizer_r.encode( + question, words, boxes=boxes, max_length=max_length, pad_to_max_length=True + ) + input_p = tokenizer_p.encode( + question, words, boxes=boxes, max_length=max_length, pad_to_max_length=True + ) + self.assert_padded_input_match(input_r, input_p, max_length, pad_token_id) + input_r = tokenizer_r.encode(question, words, boxes=boxes, max_length=max_length, padding="max_length") + input_p = tokenizer_p.encode(question, words, boxes=boxes, max_length=max_length, padding="max_length") + self.assert_padded_input_match(input_r, input_p, max_length, pad_token_id) + input_r = tokenizer_r.encode(question, words, boxes=boxes, padding=True) + input_p = tokenizer_p.encode(question, words, boxes=boxes, padding="longest") + self.assert_padded_input_match(input_r, input_p, len(input_r), pad_token_id) + + # Encode_plus - Simple input + words, boxes = self.get_words_and_boxes() + input_r = tokenizer_r.encode_plus(words, boxes=boxes, max_length=max_length, pad_to_max_length=True) + input_p = tokenizer_p.encode_plus(words, boxes=boxes, max_length=max_length, pad_to_max_length=True) + self.assert_padded_input_match(input_r["input_ids"], input_p["input_ids"], max_length, pad_token_id) + self.assertSequenceEqual(input_r["attention_mask"], input_p["attention_mask"]) + input_r = tokenizer_r.encode_plus(words, boxes=boxes, max_length=max_length, padding="max_length") + input_p = tokenizer_p.encode_plus(words, boxes=boxes, max_length=max_length, padding="max_length") + self.assert_padded_input_match(input_r["input_ids"], input_p["input_ids"], max_length, pad_token_id) + self.assertSequenceEqual(input_r["attention_mask"], input_p["attention_mask"]) + + input_r = tokenizer_r.encode_plus(words, boxes=boxes, padding="longest") + input_p = tokenizer_p.encode_plus(words, boxes=boxes, padding=True) + self.assert_padded_input_match( + input_r["input_ids"], input_p["input_ids"], len(input_r["input_ids"]), pad_token_id + ) + + self.assertSequenceEqual(input_r["attention_mask"], input_p["attention_mask"]) + + # Encode_plus - Pair input + question, words, boxes = self.get_question_words_and_boxes() + input_r = tokenizer_r.encode_plus( + question, words, boxes=boxes, max_length=max_length, pad_to_max_length=True + ) + input_p = tokenizer_p.encode_plus( + question, words, boxes=boxes, max_length=max_length, pad_to_max_length=True + ) + self.assert_padded_input_match(input_r["input_ids"], input_p["input_ids"], max_length, pad_token_id) + self.assertSequenceEqual(input_r["attention_mask"], input_p["attention_mask"]) + input_r = tokenizer_r.encode_plus( + question, words, boxes=boxes, max_length=max_length, padding="max_length" + ) + input_p = tokenizer_p.encode_plus( + question, words, boxes=boxes, max_length=max_length, padding="max_length" + ) + self.assert_padded_input_match(input_r["input_ids"], input_p["input_ids"], max_length, pad_token_id) + self.assertSequenceEqual(input_r["attention_mask"], input_p["attention_mask"]) + input_r = tokenizer_r.encode_plus(question, words, boxes=boxes, padding="longest") + input_p = tokenizer_p.encode_plus(question, words, boxes=boxes, padding=True) + self.assert_padded_input_match( + input_r["input_ids"], input_p["input_ids"], len(input_r["input_ids"]), pad_token_id + ) + self.assertSequenceEqual(input_r["attention_mask"], input_p["attention_mask"]) + + # Batch_encode_plus - Simple input + words, boxes = self.get_words_and_boxes_batch() + + input_r = tokenizer_r.batch_encode_plus( + words, + boxes=boxes, + max_length=max_length, + pad_to_max_length=True, + ) + input_p = tokenizer_p.batch_encode_plus( + words, + boxes=boxes, + max_length=max_length, + pad_to_max_length=True, + ) + self.assert_batch_padded_input_match(input_r, input_p, max_length, pad_token_id) + + input_r = tokenizer_r.batch_encode_plus( + words, + boxes=boxes, + max_length=max_length, + padding="max_length", + ) + input_p = tokenizer_p.batch_encode_plus( + words, + boxes=boxes, + max_length=max_length, + padding="max_length", + ) + self.assert_batch_padded_input_match(input_r, input_p, max_length, pad_token_id) + + input_r = tokenizer_r.batch_encode_plus( + words, + boxes=boxes, + max_length=max_length, + padding="longest", + ) + input_p = tokenizer_p.batch_encode_plus( + words, + boxes=boxes, + max_length=max_length, + padding=True, + ) + self.assert_batch_padded_input_match(input_r, input_p, len(input_r["input_ids"][0]), pad_token_id) + + input_r = tokenizer_r.batch_encode_plus(words, boxes=boxes, padding="longest") + input_p = tokenizer_p.batch_encode_plus(words, boxes=boxes, padding=True) + self.assert_batch_padded_input_match(input_r, input_p, len(input_r["input_ids"][0]), pad_token_id) + + # Batch_encode_plus - Pair input + questions, words, boxes = self.get_question_words_and_boxes_batch() + + input_r = tokenizer_r.batch_encode_plus( + list(zip(questions, words)), + is_pair=True, + boxes=boxes, + max_length=max_length, + truncation=True, + padding="max_length", + ) + input_p = tokenizer_p.batch_encode_plus( + list(zip(questions, words)), + is_pair=True, + boxes=boxes, + max_length=max_length, + truncation=True, + padding="max_length", + ) + self.assert_batch_padded_input_match(input_r, input_p, max_length, pad_token_id) + + input_r = tokenizer_r.batch_encode_plus( + list(zip(questions, words)), + is_pair=True, + boxes=boxes, + padding=True, + ) + input_p = tokenizer_p.batch_encode_plus( + list(zip(questions, words)), + is_pair=True, + boxes=boxes, + padding="longest", + ) + self.assert_batch_padded_input_match(input_r, input_p, len(input_r["input_ids"][0]), pad_token_id) + + # Using pad on single examples after tokenization + words, boxes = self.get_words_and_boxes() + input_r = tokenizer_r.encode_plus(words, boxes=boxes) + input_r = tokenizer_r.pad(input_r) + + input_p = tokenizer_r.encode_plus(words, boxes=boxes) + input_p = tokenizer_r.pad(input_p) + + self.assert_padded_input_match( + input_r["input_ids"], input_p["input_ids"], len(input_r["input_ids"]), pad_token_id + ) + + # Using pad on single examples after tokenization + input_r = tokenizer_r.encode_plus(words, boxes=boxes) + input_r = tokenizer_r.pad(input_r, max_length=max_length, padding="max_length") + + input_p = tokenizer_r.encode_plus(words, boxes=boxes) + input_p = tokenizer_r.pad(input_p, max_length=max_length, padding="max_length") + + self.assert_padded_input_match(input_r["input_ids"], input_p["input_ids"], max_length, pad_token_id) + + # Using pad after tokenization + words, boxes = self.get_words_and_boxes_batch() + input_r = tokenizer_r.batch_encode_plus( + words, + boxes=boxes, + ) + input_r = tokenizer_r.pad(input_r) + + input_p = tokenizer_r.batch_encode_plus( + words, + boxes=boxes, + ) + input_p = tokenizer_r.pad(input_p) + + self.assert_batch_padded_input_match(input_r, input_p, len(input_r["input_ids"][0]), pad_token_id) + + # Using pad after tokenization + words, boxes = self.get_words_and_boxes_batch() + input_r = tokenizer_r.batch_encode_plus( + words, + boxes=boxes, + ) + input_r = tokenizer_r.pad(input_r, max_length=max_length, padding="max_length") + + input_p = tokenizer_r.batch_encode_plus( + words, + boxes=boxes, + ) + input_p = tokenizer_r.pad(input_p, max_length=max_length, padding="max_length") + + self.assert_batch_padded_input_match(input_r, input_p, max_length, pad_token_id) + + def test_call(self): + # Tests that all call wrap to encode_plus and batch_encode_plus + tokenizers = self.get_tokenizers(do_lower_case=False) + for tokenizer in tokenizers: + with self.subTest(f"{tokenizer.__class__.__name__}"): + # Test not batched + words, boxes = self.get_words_and_boxes() + encoded_sequences_1 = tokenizer.encode_plus(words, boxes=boxes) + encoded_sequences_2 = tokenizer(words, boxes=boxes) + self.assertEqual(encoded_sequences_1, encoded_sequences_2) + + # Test not batched pairs + question, words, boxes = self.get_question_words_and_boxes() + encoded_sequences_1 = tokenizer.encode_plus(words, boxes=boxes) + encoded_sequences_2 = tokenizer(words, boxes=boxes) + self.assertEqual(encoded_sequences_1, encoded_sequences_2) + + # Test batched + words, boxes = self.get_words_and_boxes_batch() + encoded_sequences_1 = tokenizer.batch_encode_plus(words, is_pair=False, boxes=boxes) + encoded_sequences_2 = tokenizer(words, boxes=boxes) + self.assertEqual(encoded_sequences_1, encoded_sequences_2) + + def test_batch_encode_plus_batch_sequence_length(self): + # Tests that all encoded values have the correct size + tokenizers = self.get_tokenizers(do_lower_case=False) + for tokenizer in tokenizers: + with self.subTest(f"{tokenizer.__class__.__name__}"): + words, boxes = self.get_words_and_boxes_batch() + + encoded_sequences = [ + tokenizer.encode_plus(words_example, boxes=boxes_example) + for words_example, boxes_example in zip(words, boxes) + ] + encoded_sequences_batch = tokenizer.batch_encode_plus(words, is_pair=False, boxes=boxes, padding=False) + self.assertListEqual( + encoded_sequences, self.convert_batch_encode_plus_format_to_encode_plus(encoded_sequences_batch) + ) + + maximum_length = len( + max([encoded_sequence["input_ids"] for encoded_sequence in encoded_sequences], key=len) + ) + + # check correct behaviour if no pad_token_id exists and add it eventually + self._check_no_pad_token_padding(tokenizer, words) + + encoded_sequences_padded = [ + tokenizer.encode_plus( + words_example, boxes=boxes_example, max_length=maximum_length, padding="max_length" + ) + for words_example, boxes_example in zip(words, boxes) + ] + + encoded_sequences_batch_padded = tokenizer.batch_encode_plus( + words, is_pair=False, boxes=boxes, padding=True + ) + self.assertListEqual( + encoded_sequences_padded, + self.convert_batch_encode_plus_format_to_encode_plus(encoded_sequences_batch_padded), + ) + + # check 'longest' is unsensitive to a max length + encoded_sequences_batch_padded_1 = tokenizer.batch_encode_plus( + words, is_pair=False, boxes=boxes, padding=True + ) + encoded_sequences_batch_padded_2 = tokenizer.batch_encode_plus( + words, is_pair=False, boxes=boxes, max_length=maximum_length + 10, padding="longest" + ) + for key in encoded_sequences_batch_padded_1.keys(): + self.assertListEqual( + encoded_sequences_batch_padded_1[key], + encoded_sequences_batch_padded_2[key], + ) + + # check 'no_padding' is unsensitive to a max length + encoded_sequences_batch_padded_1 = tokenizer.batch_encode_plus( + words, is_pair=False, boxes=boxes, padding=False + ) + encoded_sequences_batch_padded_2 = tokenizer.batch_encode_plus( + words, is_pair=False, boxes=boxes, max_length=maximum_length + 10, padding=False + ) + for key in encoded_sequences_batch_padded_1.keys(): + self.assertListEqual( + encoded_sequences_batch_padded_1[key], + encoded_sequences_batch_padded_2[key], + ) + + @unittest.skip("batch_encode_plus does not handle overflowing tokens.") + def test_batch_encode_plus_overflowing_tokens(self): + pass + + def test_batch_encode_plus_padding(self): + # Test that padded sequences are equivalent between batch_encode_plus and encode_plus + + # Right padding tests + tokenizers = self.get_tokenizers(do_lower_case=False) + for tokenizer in tokenizers: + with self.subTest(f"{tokenizer.__class__.__name__}"): + words, boxes = self.get_words_and_boxes_batch() + + max_length = 100 + + # check correct behaviour if no pad_token_id exists and add it eventually + self._check_no_pad_token_padding(tokenizer, words) + + encoded_sequences = [ + tokenizer.encode_plus( + words_example, boxes=boxes_example, max_length=max_length, padding="max_length" + ) + for words_example, boxes_example in zip(words, boxes) + ] + encoded_sequences_batch = tokenizer.batch_encode_plus( + words, is_pair=False, boxes=boxes, max_length=max_length, padding="max_length" + ) + self.assertListEqual( + encoded_sequences, self.convert_batch_encode_plus_format_to_encode_plus(encoded_sequences_batch) + ) + + # Left padding tests + tokenizers = self.get_tokenizers(do_lower_case=False) + for tokenizer in tokenizers: + with self.subTest(f"{tokenizer.__class__.__name__}"): + tokenizer.padding_side = "left" + words, boxes = self.get_words_and_boxes_batch() + + max_length = 100 + + # check correct behaviour if no pad_token_id exists and add it eventually + self._check_no_pad_token_padding(tokenizer, words) + + encoded_sequences = [ + tokenizer.encode_plus( + words_example, boxes=boxes_example, max_length=max_length, padding="max_length" + ) + for words_example, boxes_example in zip(words, boxes) + ] + encoded_sequences_batch = tokenizer.batch_encode_plus( + words, is_pair=False, boxes=boxes, max_length=max_length, padding="max_length" + ) + self.assertListEqual( + encoded_sequences, self.convert_batch_encode_plus_format_to_encode_plus(encoded_sequences_batch) + ) + + def test_padding_to_multiple_of(self): + tokenizers = self.get_tokenizers() + for tokenizer in tokenizers: + with self.subTest(f"{tokenizer.__class__.__name__}"): + if tokenizer.pad_token is None: + self.skipTest("No padding token.") + else: + words, boxes = self.get_words_and_boxes() + + # empty_tokens = tokenizer([""], [[]], padding=True, pad_to_multiple_of=8) + normal_tokens = tokenizer(words, boxes=boxes, padding=True, pad_to_multiple_of=8) + # for key, value in empty_tokens.items(): + # self.assertEqual(len(value) % 8, 0, f"BatchEncoding.{key} is not multiple of 8") + for key, value in normal_tokens.items(): + self.assertEqual(len(value) % 8, 0, f"BatchEncoding.{key} is not multiple of 8") + + normal_tokens = tokenizer(words, boxes=boxes, pad_to_multiple_of=8) + for key, value in normal_tokens.items(): + self.assertNotEqual(len(value) % 8, 0, f"BatchEncoding.{key} is not multiple of 8") + + # Should also work with truncation + normal_tokens = tokenizer(words, boxes=boxes, padding=True, truncation=True, pad_to_multiple_of=8) + for key, value in normal_tokens.items(): + self.assertEqual(len(value) % 8, 0, f"BatchEncoding.{key} is not multiple of 8") + + # truncation to something which is not a multiple of pad_to_multiple_of raises an error + self.assertRaises( + ValueError, + tokenizer.__call__, + words, + boxes=boxes, + padding=True, + truncation=True, + max_length=12, + pad_to_multiple_of=8, + ) + + def test_tokenizer_slow_store_full_signature(self): + signature = inspect.signature(self.tokenizer_class.__init__) + tokenizer = self.get_tokenizer() + + for parameter_name, parameter in signature.parameters.items(): + if parameter.default != inspect.Parameter.empty: + self.assertIn(parameter_name, tokenizer.init_kwargs) + + def test_build_inputs_with_special_tokens(self): + if not self.test_slow_tokenizer: + # as we don't have a slow version, we can't compare the outputs between slow and fast versions + return + + for tokenizer, pretrained_name, kwargs in self.tokenizers_list: + with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"): + tokenizer_r = self.rust_tokenizer_class.from_pretrained(pretrained_name, **kwargs) + tokenizer_p = self.tokenizer_class.from_pretrained(pretrained_name, **kwargs) + + # Input tokens id + words, boxes = self.get_words_and_boxes() + input_simple = tokenizer_p.encode(words, boxes=boxes, add_special_tokens=False) + input_pair = tokenizer_p.encode(words, boxes=boxes, add_special_tokens=False) + + # Generate output + output_r = tokenizer_r.build_inputs_with_special_tokens(input_simple) + output_p = tokenizer_p.build_inputs_with_special_tokens(input_simple) + self.assertEqual(output_p, output_r) + + # Generate pair output + output_r = tokenizer_r.build_inputs_with_special_tokens(input_simple, input_pair) + output_p = tokenizer_p.build_inputs_with_special_tokens(input_simple, input_pair) + self.assertEqual(output_p, output_r) + + def test_special_tokens_mask_input_pairs(self): + tokenizers = self.get_tokenizers(do_lower_case=False) + for tokenizer in tokenizers: + with self.subTest(f"{tokenizer.__class__.__name__}"): + words, boxes = self.get_words_and_boxes() + encoded_sequence = tokenizer.encode(words, boxes=boxes, add_special_tokens=False) + encoded_sequence_dict = tokenizer.encode_plus( + words, + boxes=boxes, + add_special_tokens=True, + return_special_tokens_mask=True, + # add_prefix_space=False, + ) + encoded_sequence_w_special = encoded_sequence_dict["input_ids"] + special_tokens_mask = encoded_sequence_dict["special_tokens_mask"] + self.assertEqual(len(special_tokens_mask), len(encoded_sequence_w_special)) + + filtered_sequence = [ + (x if not special_tokens_mask[i] else None) for i, x in enumerate(encoded_sequence_w_special) + ] + filtered_sequence = [x for x in filtered_sequence if x is not None] + self.assertEqual(encoded_sequence, filtered_sequence) + + def test_special_tokens_mask(self): + tokenizers = self.get_tokenizers(do_lower_case=False) + for tokenizer in tokenizers: + with self.subTest(f"{tokenizer.__class__.__name__}"): + words, boxes = self.get_words_and_boxes() + # Testing single inputs + encoded_sequence = tokenizer.encode(words, boxes=boxes, add_special_tokens=False) + encoded_sequence_dict = tokenizer.encode_plus( + words, boxes=boxes, add_special_tokens=True, return_special_tokens_mask=True + ) + encoded_sequence_w_special = encoded_sequence_dict["input_ids"] + special_tokens_mask = encoded_sequence_dict["special_tokens_mask"] + self.assertEqual(len(special_tokens_mask), len(encoded_sequence_w_special)) + + filtered_sequence = [x for i, x in enumerate(encoded_sequence_w_special) if not special_tokens_mask[i]] + self.assertEqual(encoded_sequence, filtered_sequence) + + def test_save_and_load_tokenizer(self): + # safety check on max_len default value so we are sure the test works + tokenizers = self.get_tokenizers() + for tokenizer in tokenizers: + with self.subTest(f"{tokenizer.__class__.__name__}"): + self.assertNotEqual(tokenizer.model_max_length, 42) + + # Now let's start the test + tokenizers = self.get_tokenizers() + for tokenizer in tokenizers: + with self.subTest(f"{tokenizer.__class__.__name__}"): + # Isolate this from the other tests because we save additional tokens/etc + words, boxes = self.get_words_and_boxes() + tmpdirname = tempfile.mkdtemp() + + before_tokens = tokenizer.encode(words, boxes=boxes, add_special_tokens=False) + before_vocab = tokenizer.get_vocab() + tokenizer.save_pretrained(tmpdirname) + + after_tokenizer = tokenizer.__class__.from_pretrained(tmpdirname) + after_tokens = after_tokenizer.encode(words, boxes=boxes, add_special_tokens=False) + after_vocab = after_tokenizer.get_vocab() + self.assertListEqual(before_tokens, after_tokens) + self.assertDictEqual(before_vocab, after_vocab) + + shutil.rmtree(tmpdirname) + + def test_right_and_left_padding(self): + tokenizers = self.get_tokenizers(do_lower_case=False) + for tokenizer in tokenizers: + with self.subTest(f"{tokenizer.__class__.__name__}"): + words, boxes = self.get_words_and_boxes() + sequence = "Sequence" + padding_size = 10 + + # check correct behaviour if no pad_token_id exists and add it eventually + self._check_no_pad_token_padding(tokenizer, sequence) + + padding_idx = tokenizer.pad_token_id + + # RIGHT PADDING - Check that it correctly pads when a maximum length is specified along with the padding flag set to True + tokenizer.padding_side = "right" + encoded_sequence = tokenizer.encode(words, boxes=boxes) + sequence_length = len(encoded_sequence) + padded_sequence = tokenizer.encode( + words, boxes=boxes, max_length=sequence_length + padding_size, padding="max_length" + ) + padded_sequence_length = len(padded_sequence) + assert sequence_length + padding_size == padded_sequence_length + assert encoded_sequence + [padding_idx] * padding_size == padded_sequence + + # LEFT PADDING - Check that it correctly pads when a maximum length is specified along with the padding flag set to True + tokenizer.padding_side = "left" + encoded_sequence = tokenizer.encode(words, boxes=boxes) + sequence_length = len(encoded_sequence) + padded_sequence = tokenizer.encode( + words, boxes=boxes, max_length=sequence_length + padding_size, padding="max_length" + ) + padded_sequence_length = len(padded_sequence) + assert sequence_length + padding_size == padded_sequence_length + assert [padding_idx] * padding_size + encoded_sequence == padded_sequence + + # RIGHT & LEFT PADDING - Check that nothing is done for 'longest' and 'no_padding' + encoded_sequence = tokenizer.encode(words, boxes=boxes) + sequence_length = len(encoded_sequence) + + tokenizer.padding_side = "right" + padded_sequence_right = tokenizer.encode(words, boxes=boxes, padding=True) + padded_sequence_right_length = len(padded_sequence_right) + assert sequence_length == padded_sequence_right_length + assert encoded_sequence == padded_sequence_right + + tokenizer.padding_side = "left" + padded_sequence_left = tokenizer.encode(words, boxes=boxes, padding="longest") + padded_sequence_left_length = len(padded_sequence_left) + assert sequence_length == padded_sequence_left_length + assert encoded_sequence == padded_sequence_left + + tokenizer.padding_side = "right" + padded_sequence_right = tokenizer.encode(words, boxes=boxes) + padded_sequence_right_length = len(padded_sequence_right) + assert sequence_length == padded_sequence_right_length + assert encoded_sequence == padded_sequence_right + + tokenizer.padding_side = "left" + padded_sequence_left = tokenizer.encode(words, boxes=boxes, padding=False) + padded_sequence_left_length = len(padded_sequence_left) + assert sequence_length == padded_sequence_left_length + assert encoded_sequence == padded_sequence_left + + def test_token_type_ids(self): + tokenizers = self.get_tokenizers() + for tokenizer in tokenizers: + with self.subTest(f"{tokenizer.__class__.__name__}"): + + # test 1: single sequence + words, boxes = self.get_words_and_boxes() + + output = tokenizer(words, boxes=boxes, return_token_type_ids=True) + + # Assert that the token type IDs have the same length as the input IDs + self.assertEqual(len(output["token_type_ids"]), len(output["input_ids"])) + + # Assert that the token type IDs have the same length as the attention mask + self.assertEqual(len(output["token_type_ids"]), len(output["attention_mask"])) + + self.assertIn(0, output["token_type_ids"]) + self.assertNotIn(1, output["token_type_ids"]) + + # test 2: two sequences (question + words) + question, words, boxes = self.get_question_words_and_boxes() + + output = tokenizer(question, words, boxes, return_token_type_ids=True) + + # Assert that the token type IDs have the same length as the input IDs + self.assertEqual(len(output["token_type_ids"]), len(output["input_ids"])) + + # Assert that the token type IDs have the same length as the attention mask + self.assertEqual(len(output["token_type_ids"]), len(output["attention_mask"])) + + self.assertIn(0, output["token_type_ids"]) + + def test_offsets_mapping(self): + for tokenizer, pretrained_name, kwargs in self.tokenizers_list: + with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"): + tokenizer_r = self.rust_tokenizer_class.from_pretrained(pretrained_name, **kwargs) + + text = ["a", "wonderful", "test"] + boxes = [[1, 8, 12, 20] for _ in range(len(text))] + + # No pair + tokens_with_offsets = tokenizer_r.encode_plus( + text, + boxes=boxes, + return_special_tokens_mask=True, + return_offsets_mapping=True, + add_special_tokens=True, + ) + added_tokens = tokenizer_r.num_special_tokens_to_add(False) + offsets = tokens_with_offsets["offset_mapping"] + + # Assert there is the same number of tokens and offsets + self.assertEqual(len(offsets), len(tokens_with_offsets["input_ids"])) + + # Assert there is online added_tokens special_tokens + self.assertEqual(sum(tokens_with_offsets["special_tokens_mask"]), added_tokens) + + # Pairs + text = "what's his name" + pair = ["a", "wonderful", "test"] + boxes = [[1, 8, 12, 20] for _ in range(len(pair))] + tokens_with_offsets = tokenizer_r.encode_plus( + text, + pair, + boxes=boxes, + return_special_tokens_mask=True, + return_offsets_mapping=True, + add_special_tokens=True, + ) + added_tokens = tokenizer_r.num_special_tokens_to_add(True) + offsets = tokens_with_offsets["offset_mapping"] + + # Assert there is the same number of tokens and offsets + self.assertEqual(len(offsets), len(tokens_with_offsets["input_ids"])) + + # Assert there is online added_tokens special_tokens + self.assertEqual(sum(tokens_with_offsets["special_tokens_mask"]), added_tokens) + + @require_torch + @slow + def test_torch_encode_plus_sent_to_model(self): + import torch + + from transformers import MODEL_MAPPING, TOKENIZER_MAPPING + + MODEL_TOKENIZER_MAPPING = merge_model_tokenizer_mappings(MODEL_MAPPING, TOKENIZER_MAPPING) + + tokenizers = self.get_tokenizers(do_lower_case=False) + for tokenizer in tokenizers: + with self.subTest(f"{tokenizer.__class__.__name__}"): + + if tokenizer.__class__ not in MODEL_TOKENIZER_MAPPING: + return + + config_class, model_class = MODEL_TOKENIZER_MAPPING[tokenizer.__class__] + config = config_class() + + if config.is_encoder_decoder or config.pad_token_id is None: + return + + model = model_class(config) + + # Make sure the model contains at least the full vocabulary size in its embedding matrix + is_using_common_embeddings = hasattr(model.get_input_embeddings(), "weight") + assert ( + (model.get_input_embeddings().weight.shape[0] >= len(tokenizer)) + if is_using_common_embeddings + else True + ) + + # Build sequence + words, boxes = self.get_words_and_boxes() + encoded_sequence = tokenizer.encode_plus(words, boxes=boxes, return_tensors="pt") + batch_encoded_sequence = tokenizer.batch_encode_plus( + [words, words], boxes=[boxes, boxes], return_tensors="pt" + ) + + # We add dummy pixel_values keys (as LayoutLMv3 actually also requires a feature extractor + # to prepare the image input) + encoded_sequence["pixel_values"] = torch.randn(1, 3, 224, 224) + batch_encoded_sequence["pixel_values"] = torch.randn(2, 3, 224, 224) + + # This should not fail + with torch.no_grad(): # saves some time + model(**encoded_sequence) + model(**batch_encoded_sequence) + + def test_rust_and_python_full_tokenizers(self): + if not self.test_rust_tokenizer: + return + + if not self.test_slow_tokenizer: + # as we don't have a slow version, we can't compare the outputs between slow and fast versions + return + + tokenizer = self.get_tokenizer() + rust_tokenizer = self.get_rust_tokenizer() + + words, boxes = self.get_words_and_boxes() + + ids = tokenizer.encode(words, boxes=boxes, add_special_tokens=False) + rust_ids = rust_tokenizer.encode(words, boxes=boxes, add_special_tokens=False) + self.assertListEqual(ids, rust_ids) + + ids = tokenizer.encode(words, boxes=boxes, add_special_tokens=True) + rust_ids = rust_tokenizer.encode(words, boxes=boxes, add_special_tokens=True) + self.assertListEqual(ids, rust_ids) + + def test_tokenization_python_rust_equals(self): + if not self.test_slow_tokenizer: + # as we don't have a slow version, we can't compare the outputs between slow and fast versions + return + + for tokenizer, pretrained_name, kwargs in self.tokenizers_list: + with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"): + tokenizer_r = self.rust_tokenizer_class.from_pretrained(pretrained_name, **kwargs) + tokenizer_p = self.tokenizer_class.from_pretrained(pretrained_name, **kwargs) + + words, boxes = self.get_words_and_boxes() + + # Ensure basic input match + input_p = tokenizer_p.encode_plus(words, boxes=boxes) + input_r = tokenizer_r.encode_plus(words, boxes=boxes) + + for key in filter( + lambda x: x in ["input_ids", "token_type_ids", "attention_mask", "bbox"], input_p.keys() + ): + self.assertSequenceEqual(input_p[key], input_r[key]) + + input_pairs_p = tokenizer_p.encode_plus(words, boxes=boxes) + input_pairs_r = tokenizer_r.encode_plus(words, boxes=boxes) + + for key in filter( + lambda x: x in ["input_ids", "token_type_ids", "attention_mask", "bbox"], input_p.keys() + ): + self.assertSequenceEqual(input_pairs_p[key], input_pairs_r[key]) + + words = ["hello" for _ in range(1000)] + boxes = [[1000, 1000, 1000, 1000] for _ in range(1000)] + + # Ensure truncation match + input_p = tokenizer_p.encode_plus(words, boxes=boxes, max_length=512, truncation=True) + input_r = tokenizer_r.encode_plus(words, boxes=boxes, max_length=512, truncation=True) + + for key in filter( + lambda x: x in ["input_ids", "token_type_ids", "attention_mask", "bbox"], input_p.keys() + ): + self.assertSequenceEqual(input_p[key], input_r[key]) + + # Ensure truncation with stride match + input_p = tokenizer_p.encode_plus( + words, boxes=boxes, max_length=512, truncation=True, stride=3, return_overflowing_tokens=True + ) + input_r = tokenizer_r.encode_plus( + words, boxes=boxes, max_length=512, truncation=True, stride=3, return_overflowing_tokens=True + ) + + for key in filter( + lambda x: x in ["input_ids", "token_type_ids", "attention_mask", "bbox"], input_p.keys() + ): + self.assertSequenceEqual(input_p[key], input_r[key][0]) + + def test_embeded_special_tokens(self): + if not self.test_slow_tokenizer: + # as we don't have a slow version, we can't compare the outputs between slow and fast versions + return + + for tokenizer, pretrained_name, kwargs in self.tokenizers_list: + with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"): + tokenizer_r = self.rust_tokenizer_class.from_pretrained(pretrained_name, **kwargs) + tokenizer_p = self.tokenizer_class.from_pretrained(pretrained_name, **kwargs) + words, boxes = self.get_words_and_boxes() + tokens_r = tokenizer_r.encode_plus( + words, + boxes=boxes, + add_special_tokens=True, + ) + tokens_p = tokenizer_p.encode_plus( + words, + boxes=boxes, + add_special_tokens=True, + ) + + for key in tokens_p.keys(): + self.assertEqual(tokens_r[key], tokens_p[key]) + + if "token_type_ids" in tokens_r: + self.assertEqual(sum(tokens_r["token_type_ids"]), sum(tokens_p["token_type_ids"])) + + tokens_r = tokenizer_r.convert_ids_to_tokens(tokens_r["input_ids"]) + tokens_p = tokenizer_p.convert_ids_to_tokens(tokens_p["input_ids"]) + self.assertSequenceEqual(tokens_r, tokens_p) + + def test_compare_add_special_tokens(self): + for tokenizer, pretrained_name, kwargs in self.tokenizers_list: + with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"): + tokenizer_r = self.rust_tokenizer_class.from_pretrained(pretrained_name, **kwargs) + + simple_num_special_tokens_to_add = tokenizer_r.num_special_tokens_to_add(pair=False) + + words, boxes = self.get_words_and_boxes() + # tokenize() + no_special_tokens = tokenizer_r.tokenize(" ".join(words), add_special_tokens=False) + with_special_tokens = tokenizer_r.tokenize(" ".join(words), add_special_tokens=True) + self.assertEqual(len(no_special_tokens), len(with_special_tokens) - simple_num_special_tokens_to_add) + + # encode() + no_special_tokens = tokenizer_r.encode(words, boxes=boxes, add_special_tokens=False) + with_special_tokens = tokenizer_r.encode(words, boxes=boxes, add_special_tokens=True) + self.assertEqual(len(no_special_tokens), len(with_special_tokens) - simple_num_special_tokens_to_add) + + # encode_plus() + no_special_tokens = tokenizer_r.encode_plus(words, boxes=boxes, add_special_tokens=False) + with_special_tokens = tokenizer_r.encode_plus(words, boxes=boxes, add_special_tokens=True) + for key in no_special_tokens.keys(): + self.assertEqual( + len(no_special_tokens[key]), + len(with_special_tokens[key]) - simple_num_special_tokens_to_add, + ) + + # # batch_encode_plus + words, boxes = self.get_words_and_boxes_batch() + + no_special_tokens = tokenizer_r.batch_encode_plus(words, boxes=boxes, add_special_tokens=False) + with_special_tokens = tokenizer_r.batch_encode_plus(words, boxes=boxes, add_special_tokens=True) + for key in no_special_tokens.keys(): + for i_no, i_with in zip(no_special_tokens[key], with_special_tokens[key]): + self.assertEqual(len(i_no), len(i_with) - simple_num_special_tokens_to_add) + + @slow + def test_layoutlmv3_truncation_integration_test(self): + words, boxes = self.get_words_and_boxes() + + tokenizer = LayoutLMv3Tokenizer.from_pretrained("microsoft/layoutlmv3-base", model_max_length=512) + + for i in range(12, 512): + new_encoded_inputs = tokenizer.encode(words, boxes=boxes, max_length=i, truncation=True) + + # Ensure that the input IDs are less than the max length defined. + self.assertLessEqual(len(new_encoded_inputs), i) + + tokenizer.model_max_length = 20 + new_encoded_inputs = tokenizer.encode(words, boxes=boxes, truncation=True) + dropped_encoded_inputs = tokenizer.encode(words, boxes=boxes, truncation=True) + + # Ensure that the input IDs are still truncated when no max_length is specified + self.assertListEqual(new_encoded_inputs, dropped_encoded_inputs) + self.assertLessEqual(len(new_encoded_inputs), 20) + + @is_pt_tf_cross_test + def test_batch_encode_plus_tensors(self): + tokenizers = self.get_tokenizers(do_lower_case=False) + for tokenizer in tokenizers: + with self.subTest(f"{tokenizer.__class__.__name__}"): + words, boxes = self.get_words_and_boxes_batch() + + # A Tensor cannot be build by sequences which are not the same size + self.assertRaises(ValueError, tokenizer.batch_encode_plus, words, boxes=boxes, return_tensors="pt") + self.assertRaises(ValueError, tokenizer.batch_encode_plus, words, boxes=boxes, return_tensors="tf") + + if tokenizer.pad_token_id is None: + self.assertRaises( + ValueError, + tokenizer.batch_encode_plus, + words, + boxes=boxes, + padding=True, + return_tensors="pt", + ) + self.assertRaises( + ValueError, + tokenizer.batch_encode_plus, + words, + boxes=boxes, + padding="longest", + return_tensors="tf", + ) + else: + pytorch_tensor = tokenizer.batch_encode_plus(words, boxes=boxes, padding=True, return_tensors="pt") + tensorflow_tensor = tokenizer.batch_encode_plus( + words, boxes=boxes, padding="longest", return_tensors="tf" + ) + encoded_sequences = tokenizer.batch_encode_plus(words, boxes=boxes, padding=True) + + for key in encoded_sequences.keys(): + pytorch_value = pytorch_tensor[key].tolist() + tensorflow_value = tensorflow_tensor[key].numpy().tolist() + encoded_value = encoded_sequences[key] + + self.assertEqual(pytorch_value, tensorflow_value, encoded_value) + + def test_sequence_ids(self): + tokenizers = self.get_tokenizers() + for tokenizer in tokenizers: + if not tokenizer.is_fast: + continue + with self.subTest(f"{tokenizer.__class__.__name__}"): + seq_0 = "Test this method." + seq_1 = ["With", "these", "inputs."] + boxes = [[1000, 1000, 1000, 1000] for _ in range(len(seq_1))] + + # We want to have sequence 0 and sequence 1 are tagged + # respectively with 0 and 1 token_ids + # (regardless of whether the model use token type ids) + # We use this assumption in the QA pipeline among other place + output = tokenizer(seq_0.split(), boxes=boxes) + self.assertIn(0, output.sequence_ids()) + + output = tokenizer(seq_0, seq_1, boxes=boxes) + self.assertIn(0, output.sequence_ids()) + self.assertIn(1, output.sequence_ids()) + + if tokenizer.num_special_tokens_to_add(pair=True): + self.assertIn(None, output.sequence_ids()) + + def test_special_tokens_initialization(self): + for tokenizer, pretrained_name, kwargs in self.tokenizers_list: + with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"): + + added_tokens = [AddedToken("", lstrip=True)] + + tokenizer_r = self.rust_tokenizer_class.from_pretrained( + pretrained_name, additional_special_tokens=added_tokens, **kwargs + ) + words = "Hey this is a token".split() + boxes = [[1000, 1000, 1000, 1000] for _ in range(len(words))] + r_output = tokenizer_r.encode(words, boxes=boxes) + + special_token_id = tokenizer_r.encode( + [""], boxes=[1000, 1000, 1000, 1000], add_special_tokens=False + )[0] + + self.assertTrue(special_token_id in r_output) + + if self.test_slow_tokenizer: + tokenizer_cr = self.rust_tokenizer_class.from_pretrained( + pretrained_name, additional_special_tokens=added_tokens, **kwargs, from_slow=True + ) + tokenizer_p = self.tokenizer_class.from_pretrained( + pretrained_name, additional_special_tokens=added_tokens, **kwargs + ) + + words = "Hey this is a token".split() + boxes = [[1000, 1000, 1000, 1000] for _ in range(len(words))] + + p_output = tokenizer_p.encode(words, boxes=boxes) + cr_output = tokenizer_cr.encode(words, boxes=boxes) + + self.assertEqual(p_output, r_output) + self.assertEqual(cr_output, r_output) + self.assertTrue(special_token_id in p_output) + self.assertTrue(special_token_id in cr_output) + + def test_training_new_tokenizer(self): + # This feature only exists for fast tokenizers + if not self.test_rust_tokenizer: + return + + tokenizer = self.get_rust_tokenizer() + new_tokenizer = tokenizer.train_new_from_iterator(SMALL_TRAINING_CORPUS, 100) + + # Test we can use the new tokenizer with something not seen during training + text = [["this", "is", "the"], ["how", "are", "you"]] + boxes = [[[1, 2, 3, 4], [5, 6, 7, 8], [1, 3, 4, 8]], [[5, 6, 7, 8], [4, 5, 6, 7], [3, 9, 2, 7]]] + inputs = new_tokenizer(text, boxes=boxes) + self.assertEqual(len(inputs["input_ids"]), 2) + decoded_input = new_tokenizer.decode(inputs["input_ids"][0], skip_special_tokens=True) + expected_result = " this is the" + + if tokenizer.backend_tokenizer.normalizer is not None: + expected_result = tokenizer.backend_tokenizer.normalizer.normalize_str(expected_result) + self.assertEqual(expected_result, decoded_input) + + # We check that the parameters of the tokenizer remained the same + # Check we have the same number of added_tokens for both pair and non-pair inputs. + self.assertEqual(tokenizer.num_special_tokens_to_add(False), new_tokenizer.num_special_tokens_to_add(False)) + self.assertEqual(tokenizer.num_special_tokens_to_add(True), new_tokenizer.num_special_tokens_to_add(True)) + + # Check we have the correct max_length for both pair and non-pair inputs. + self.assertEqual(tokenizer.max_len_single_sentence, new_tokenizer.max_len_single_sentence) + self.assertEqual(tokenizer.max_len_sentences_pair, new_tokenizer.max_len_sentences_pair) + + # Assert the set of special tokens match as we didn't ask to change them + self.assertSequenceEqual( + tokenizer.all_special_tokens_extended, + new_tokenizer.all_special_tokens_extended, + ) + + self.assertDictEqual(tokenizer.special_tokens_map, new_tokenizer.special_tokens_map) + + def test_training_new_tokenizer_with_special_tokens_change(self): + # This feature only exists for fast tokenizers + if not self.test_rust_tokenizer: + return + + tokenizer = self.get_rust_tokenizer() + # Test with a special tokens map + class_signature = inspect.signature(tokenizer.__class__) + if "cls_token" in class_signature.parameters: + new_tokenizer = tokenizer.train_new_from_iterator( + SMALL_TRAINING_CORPUS, 100, special_tokens_map={tokenizer.cls_token: ""} + ) + cls_id = new_tokenizer.get_vocab()[""] + self.assertEqual(new_tokenizer.cls_token, "") + self.assertEqual(new_tokenizer.cls_token_id, cls_id) + + # Create a new mapping from the special tokens defined in the original tokenizer + special_tokens_list = SpecialTokensMixin.SPECIAL_TOKENS_ATTRIBUTES.copy() + special_tokens_list.remove("additional_special_tokens") + special_tokens_map = {} + for token in special_tokens_list: + # Get the private one to avoid unnecessary warnings. + if getattr(tokenizer, f"_{token}") is not None: + special_token = getattr(tokenizer, token) + special_tokens_map[special_token] = f"{special_token}a" + + # Train new tokenizer + new_tokenizer = tokenizer.train_new_from_iterator( + SMALL_TRAINING_CORPUS, 100, special_tokens_map=special_tokens_map + ) + + # Check the changes + for token in special_tokens_list: + # Get the private one to avoid unnecessary warnings. + if getattr(tokenizer, f"_{token}") is None: + continue + special_token = getattr(tokenizer, token) + if special_token in special_tokens_map: + new_special_token = getattr(new_tokenizer, token) + self.assertEqual(special_tokens_map[special_token], new_special_token) + + new_id = new_tokenizer.get_vocab()[new_special_token] + self.assertEqual(getattr(new_tokenizer, f"{token}_id"), new_id) + + # Check if the AddedToken / string format has been kept + for special_token in tokenizer.all_special_tokens_extended: + if isinstance(special_token, AddedToken) and special_token.content not in special_tokens_map: + # The special token must appear identically in the list of the new tokenizer. + self.assertTrue( + special_token in new_tokenizer.all_special_tokens_extended, + f"'{special_token}' should be in {new_tokenizer.all_special_tokens_extended}", + ) + elif isinstance(special_token, AddedToken): + # The special token must appear in the list of the new tokenizer as an object of type AddedToken with + # the same parameters as the old AddedToken except the content that the user has requested to change. + special_token_str = special_token.content + new_special_token_str = special_tokens_map[special_token_str] + + find = False + for candidate in new_tokenizer.all_special_tokens_extended: + if ( + isinstance(candidate, AddedToken) + and candidate.content == new_special_token_str + and candidate.lstrip == special_token.lstrip + and candidate.rstrip == special_token.rstrip + and candidate.normalized == special_token.normalized + and candidate.single_word == special_token.single_word + ): + find = True + break + self.assertTrue( + find, + f"'{new_special_token_str}' doesn't appear in the list " + f"'{new_tokenizer.all_special_tokens_extended}' as an AddedToken with the same parameters as " + f"'{special_token}' in the list {tokenizer.all_special_tokens_extended}", + ) + elif special_token not in special_tokens_map: + # The special token must appear identically in the list of the new tokenizer. + self.assertTrue( + special_token in new_tokenizer.all_special_tokens_extended, + f"'{special_token}' should be in {new_tokenizer.all_special_tokens_extended}", + ) + + else: + # The special token must appear in the list of the new tokenizer as an object of type string. + self.assertTrue(special_tokens_map[special_token] in new_tokenizer.all_special_tokens_extended) + + # Test we can use the new tokenizer with something not seen during training + words = [["this", "is"], ["hello", "🤗"]] + boxes = [[[1, 2, 3, 4], [5, 6, 7, 8]], [[1, 2, 3, 4], [5, 6, 7, 8]]] + inputs = new_tokenizer(words, boxes=boxes) + self.assertEqual(len(inputs["input_ids"]), 2) + decoded_input = new_tokenizer.decode(inputs["input_ids"][0], skip_special_tokens=True) + expected_result = " this is" + + if tokenizer.backend_tokenizer.normalizer is not None: + expected_result = tokenizer.backend_tokenizer.normalizer.normalize_str(expected_result) + self.assertEqual(expected_result, decoded_input) + + def test_prepare_for_model(self): + tokenizers = self.get_tokenizers(do_lower_case=False) + for tokenizer in tokenizers: + # only test prepare_for_model for the slow tokenizer + if tokenizer.__class__.__name__ == "LayoutLMv3TokenizerFast": + continue + with self.subTest(f"{tokenizer.__class__.__name__}"): + words, boxes = self.get_words_and_boxes() + prepared_input_dict = tokenizer.prepare_for_model(words, boxes=boxes, add_special_tokens=True) + + input_dict = tokenizer.encode_plus(words, boxes=boxes, add_special_tokens=True) + + self.assertEqual(input_dict, prepared_input_dict) + + def test_padding_different_model_input_name(self): + if not self.test_slow_tokenizer: + # as we don't have a slow version, we can't compare the outputs between slow and fast versions + return + + for tokenizer, pretrained_name, kwargs in self.tokenizers_list: + with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"): + tokenizer_r = self.rust_tokenizer_class.from_pretrained(pretrained_name, **kwargs) + tokenizer_p = self.tokenizer_class.from_pretrained(pretrained_name, **kwargs) + self.assertEqual(tokenizer_p.pad_token_id, tokenizer_r.pad_token_id) + pad_token_id = tokenizer_p.pad_token_id + + words, boxes = self.get_words_and_boxes_batch() + + input_r = tokenizer_r.batch_encode_plus(words, boxes=boxes) + input_p = tokenizer_r.batch_encode_plus(words, boxes=boxes) + + # rename encoded batch to "inputs" + input_r["inputs"] = input_r[tokenizer_r.model_input_names[0]] + del input_r[tokenizer_r.model_input_names[0]] + + input_p["inputs"] = input_p[tokenizer_p.model_input_names[0]] + del input_p[tokenizer_p.model_input_names[0]] + + # Renaming `input_ids` to `inputs` + tokenizer_r.model_input_names = ["inputs"] + tokenizer_r.model_input_names[1:] + tokenizer_p.model_input_names = ["inputs"] + tokenizer_p.model_input_names[1:] + + input_r = tokenizer_r.pad(input_r, padding="longest") + input_p = tokenizer_r.pad(input_p, padding="longest") + + max_length = len(input_p["inputs"][0]) + self.assert_batch_padded_input_match( + input_r, input_p, max_length, pad_token_id, model_main_input_name="inputs" + ) + + def test_batch_encode_dynamic_overflowing(self): + """ + When calling batch_encode with multiple sequences, it can return different number of + overflowing encoding for each sequence: + [ + Sequence 1: [Encoding 1, Encoding 2], + Sequence 2: [Encoding 1], + Sequence 3: [Encoding 1, Encoding 2, ... Encoding N] + ] + This needs to be padded so that it can represented as a tensor + """ + for tokenizer, pretrained_name, kwargs in self.tokenizers_list: + tokenizer = self.rust_tokenizer_class.from_pretrained(pretrained_name, **kwargs) + + with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name}, {tokenizer.__class__.__name__})"): + + if is_torch_available(): + returned_tensor = "pt" + elif is_tf_available(): + returned_tensor = "tf" + else: + returned_tensor = "jax" + + # Single example + words = ["HuggingFace", "is", "solving", "NLP", "one", "commit", "at", "a", "time"] + boxes = [[i, i, i, i] for i in range(len(words))] + tokens = tokenizer.encode_plus( + words, + boxes=boxes, + max_length=6, + padding=True, + truncation=True, + return_tensors=returned_tensor, + return_overflowing_tokens=True, + ) + + for key in filter(lambda x: "overflow_to_sample_mapping" not in x, tokens.keys()): + if key != "bbox": + self.assertEqual(len(tokens[key].shape), 2) + else: + self.assertEqual(len(tokens[key].shape), 3) + + # Batch of examples + # For these 2 examples, 3 training examples will be created + words_batched = [ + ["HuggingFace", "is", "solving", "NLP", "one", "commit", "at", "a", "time"], + ["Very", "tiny", "input"], + ] + boxes_batched = [[[i, i, i, i] for i in range(len(words_item))] for words_item in words_batched] + tokens = tokenizer.batch_encode_plus( + words_batched, + boxes=boxes_batched, + max_length=6, + padding=True, + truncation="only_first", + return_tensors=returned_tensor, + return_overflowing_tokens=True, + ) + + for key in filter(lambda x: "overflow_to_sample_mapping" not in x, tokens.keys()): + if key != "bbox": + self.assertEqual(len(tokens[key].shape), 2) + self.assertEqual(tokens[key].shape[-1], 6) + else: + self.assertEqual(len(tokens[key].shape), 3) + self.assertEqual(tokens[key].shape[-1], 4) + + @unittest.skip("TO DO: overwrite this very extensive test.") + def test_alignement_methods(self): + pass + + def get_clean_sequence(self, tokenizer, with_prefix_space=False, max_length=20, min_length=5): + toks = [(i, tokenizer.decode([i], clean_up_tokenization_spaces=False)) for i in range(len(tokenizer))] + toks = list(filter(lambda t: re.match(r"^[ a-zA-Z]+$", t[1]), toks)) + toks = list( + filter( + lambda t: [t[0]] + == tokenizer.encode(t[1].split(" "), boxes=len(t[1]) * [[1, 1, 1, 1]], add_special_tokens=False), + toks, + ) + ) + if max_length is not None and len(toks) > max_length: + toks = toks[:max_length] + if min_length is not None and len(toks) < min_length and len(toks) > 0: + while len(toks) < min_length: + toks = toks + toks + # toks_str = [t[1] for t in toks] + toks_ids = [t[0] for t in toks] + + # Ensure consistency + output_txt = tokenizer.decode(toks_ids, clean_up_tokenization_spaces=False) + if " " not in output_txt and len(toks_ids) > 1: + output_txt = ( + tokenizer.decode([toks_ids[0]], clean_up_tokenization_spaces=False) + + " " + + tokenizer.decode(toks_ids[1:], clean_up_tokenization_spaces=False) + ) + if with_prefix_space: + output_txt = " " + output_txt + words = output_txt.split(" ") + boxes = [[i, i, i, i] for i in range(len(words))] + output_ids = tokenizer.encode(words, boxes=boxes, add_special_tokens=False) + + return words, boxes, output_ids + + def test_added_token_with_space_before(self): + + tokenizer_s = self.get_tokenizer() + tokenizer_f = self.get_rust_tokenizer() + + tokens_to_add = ["AAA", "bbb"] + + words_with_space = [f" {token}" for token in tokens_to_add + tokenizer_s.unique_no_split_tokens] + words_without_space = tokens_to_add + tokenizer_s.unique_no_split_tokens + boxes = [[i, i, i, i] for i in range(len(words_with_space))] + + tokens_to_add_formated = [ + AddedToken(token, rstrip=True, lstrip=True, single_word=False) for token in tokens_to_add + ] + tokenizer_s.add_tokens(tokens_to_add_formated) + tokenizer_f.add_tokens(tokens_to_add_formated) + + ids_s = tokenizer_s(words_with_space, boxes=boxes).input_ids + ids_f = tokenizer_f(words_with_space, boxes=boxes).input_ids + + tokens_s = tokenizer_s.convert_ids_to_tokens(ids_s) + tokens_f = tokenizer_f.convert_ids_to_tokens(ids_f) + + ids_s = tokenizer_s(words_without_space, boxes=boxes).input_ids + ids_f = tokenizer_f(words_without_space, boxes=boxes).input_ids + + tokens_s = tokenizer_s.convert_ids_to_tokens(ids_s) + tokens_f = tokenizer_f.convert_ids_to_tokens(ids_f) + + self.assertEqual(tokens_s, tokens_f) + + def test_maximum_encoding_length_pair_input(self): + tokenizers = self.get_tokenizers(do_lower_case=False, model_max_length=100) + for tokenizer in tokenizers: + with self.subTest(f"{tokenizer.__class__.__name__}"): + # Build a sequence from our model's vocabulary + stride = 2 + seq_0, boxes_0, ids = self.get_clean_sequence(tokenizer, max_length=20) + question_0 = " ".join(map(str, seq_0)) + if len(ids) <= 2 + stride: + seq_0 = (seq_0 + " ") * (2 + stride) + ids = None + + seq0_tokens = tokenizer(seq_0, boxes=boxes_0, add_special_tokens=False) + seq0_input_ids = seq0_tokens["input_ids"] + + self.assertGreater(len(seq0_input_ids), 2 + stride) + question_1 = "This is another sentence to be encoded." + seq_1 = ["what", "a", "weird", "test", "weirdly", "weird"] + boxes_1 = [[i, i, i, i] for i in range(1, len(seq_1) + 1)] + seq1_tokens = tokenizer(seq_1, boxes=boxes_1, add_special_tokens=False) + if abs(len(seq0_input_ids) - len(seq1_tokens["input_ids"])) <= 2: + seq1_tokens_input_ids = seq1_tokens["input_ids"] + seq1_tokens["input_ids"] + seq_1 = tokenizer.decode(seq1_tokens_input_ids, clean_up_tokenization_spaces=False) + seq_1 = seq_1.split(" ") + boxes_1 = [[i, i, i, i] for i in range(1, len(seq_1) + 1)] + seq1_tokens = tokenizer(seq_1, boxes=boxes_1, add_special_tokens=False) + seq1_input_ids = seq1_tokens["input_ids"] + + self.assertGreater(len(seq1_input_ids), 2 + stride) + + smallest = seq1_input_ids if len(seq0_input_ids) > len(seq1_input_ids) else seq0_input_ids + + # We are not using the special tokens - a bit too hard to test all the tokenizers with this + # TODO try this again later + sequence = tokenizer( + question_0, seq_1, boxes=boxes_1, add_special_tokens=False + ) # , add_prefix_space=False) + + # Test with max model input length + model_max_length = tokenizer.model_max_length + self.assertEqual(model_max_length, 100) + seq_2 = seq_0 * model_max_length + question_2 = " ".join(map(str, seq_2)) + boxes_2 = boxes_0 * model_max_length + self.assertGreater(len(seq_2), model_max_length) + + sequence1 = tokenizer(seq_1, boxes=boxes_1, add_special_tokens=False) + total_length1 = len(sequence1["input_ids"]) + sequence2 = tokenizer(question_2, seq_1, boxes=boxes_1, add_special_tokens=False) + total_length2 = len(sequence2["input_ids"]) + self.assertLess(total_length1, model_max_length, "Issue with the testing sequence, please update it.") + self.assertGreater( + total_length2, model_max_length, "Issue with the testing sequence, please update it." + ) + + # Simple + padding_strategies = ( + [False, True, "longest"] if tokenizer.pad_token and tokenizer.pad_token_id >= 0 else [False] + ) + for padding_state in padding_strategies: + with self.subTest(f"{tokenizer.__class__.__name__} Padding: {padding_state}"): + for truncation_state in [True, "longest_first", "only_first"]: + with self.subTest(f"{tokenizer.__class__.__name__} Truncation: {truncation_state}"): + output = tokenizer( + question_2, + seq_1, + boxes=boxes_1, + padding=padding_state, + truncation=truncation_state, + ) + self.assertEqual(len(output["input_ids"]), model_max_length) + self.assertEqual(len(output["bbox"]), model_max_length) + + output = tokenizer( + [question_2], + [seq_1], + boxes=[boxes_1], + padding=padding_state, + truncation=truncation_state, + ) + self.assertEqual(len(output["input_ids"][0]), model_max_length) + self.assertEqual(len(output["bbox"][0]), model_max_length) + + # Simple + output = tokenizer( + question_1, seq_2, boxes=boxes_2, padding=padding_state, truncation="only_second" + ) + self.assertEqual(len(output["input_ids"]), model_max_length) + self.assertEqual(len(output["bbox"]), model_max_length) + + output = tokenizer( + [question_1], [seq_2], boxes=[boxes_2], padding=padding_state, truncation="only_second" + ) + self.assertEqual(len(output["input_ids"][0]), model_max_length) + self.assertEqual(len(output["bbox"][0]), model_max_length) + + # Simple with no truncation + # Reset warnings + tokenizer.deprecation_warnings = {} + with self.assertLogs("transformers", level="WARNING") as cm: + output = tokenizer( + question_1, seq_2, boxes=boxes_2, padding=padding_state, truncation=False + ) + self.assertNotEqual(len(output["input_ids"]), model_max_length) + self.assertNotEqual(len(output["bbox"]), model_max_length) + self.assertEqual(len(cm.records), 1) + self.assertTrue( + cm.records[0].message.startswith( + "Token indices sequence length is longer than the specified maximum sequence length" + " for this model" + ) + ) + + tokenizer.deprecation_warnings = {} + with self.assertLogs("transformers", level="WARNING") as cm: + output = tokenizer( + [question_1], [seq_2], boxes=[boxes_2], padding=padding_state, truncation=False + ) + self.assertNotEqual(len(output["input_ids"][0]), model_max_length) + self.assertNotEqual(len(output["bbox"][0]), model_max_length) + self.assertEqual(len(cm.records), 1) + self.assertTrue( + cm.records[0].message.startswith( + "Token indices sequence length is longer than the specified maximum sequence length" + " for this model" + ) + ) + # Check the order of Sequence of input ids, overflowing tokens and bbox sequence with truncation + truncated_first_sequence = ( + tokenizer(seq_0, boxes=boxes_0, add_special_tokens=False)["input_ids"][:-2] + + tokenizer(seq_1, boxes=boxes_1, add_special_tokens=False)["input_ids"] + ) + truncated_second_sequence = ( + tokenizer(seq_0, boxes=boxes_0, add_special_tokens=False)["input_ids"] + + tokenizer(seq_1, boxes=boxes_1, add_special_tokens=False)["input_ids"][:-2] + ) + truncated_longest_sequence = ( + truncated_first_sequence + if len(seq0_input_ids) > len(seq1_input_ids) + else truncated_second_sequence + ) + + overflow_first_sequence = ( + tokenizer(seq_0, boxes=boxes_0, add_special_tokens=False)["input_ids"][-(2 + stride) :] + + tokenizer(seq_1, boxes=boxes_1, add_special_tokens=False)["input_ids"] + ) + overflow_second_sequence = ( + tokenizer(seq_0, boxes=boxes_0, add_special_tokens=False)["input_ids"] + + tokenizer(seq_1, boxes=boxes_1, add_special_tokens=False)["input_ids"][-(2 + stride) :] + ) + overflow_longest_sequence = ( + overflow_first_sequence if len(seq0_input_ids) > len(seq1_input_ids) else overflow_second_sequence + ) + + bbox_first = [[0, 0, 0, 0]] * (len(seq0_input_ids) - 2) + bbox_first_sequence = bbox_first + tokenizer(seq_1, boxes=boxes_1, add_special_tokens=False)["bbox"] + overflowing_token_bbox_first_sequence_slow = [[0, 0, 0, 0]] * (2 + stride) + overflowing_token_bbox_first_sequence_fast = [[0, 0, 0, 0]] * (2 + stride) + tokenizer( + seq_1, boxes=boxes_1, add_special_tokens=False + )["bbox"] + + bbox_second = [[0, 0, 0, 0]] * len(seq0_input_ids) + bbox_second_sequence = ( + bbox_second + tokenizer(seq_1, boxes=boxes_1, add_special_tokens=False)["bbox"][:-2] + ) + overflowing_token_bbox_second_sequence_slow = tokenizer( + seq_1, boxes=boxes_1, add_special_tokens=False + )["bbox"][-(2 + stride) :] + overflowing_token_bbox_second_sequence_fast = [[0, 0, 0, 0]] * len(seq0_input_ids) + tokenizer( + seq_1, boxes=boxes_1, add_special_tokens=False + )["bbox"][-(2 + stride) :] + + bbox_longest_sequence = ( + bbox_first_sequence if len(seq0_tokens) > len(seq1_tokens) else bbox_second_sequence + ) + overflowing_token_bbox_longest_sequence_fast = ( + overflowing_token_bbox_first_sequence_fast + if len(seq0_tokens) > len(seq1_tokens) + else overflowing_token_bbox_second_sequence_fast + ) + + # Overflowing tokens are handled quite differently in slow and fast tokenizers + if isinstance(tokenizer, LayoutLMv3TokenizerFast): + information = tokenizer( + question_0, + seq_1, + boxes=boxes_1, + max_length=len(sequence["input_ids"]) - 2, + add_special_tokens=False, + stride=stride, + truncation="longest_first", + return_overflowing_tokens=True, + # add_prefix_space=False, + ) + truncated_sequence = information["input_ids"][0] + overflowing_tokens = information["input_ids"][1] + bbox = information["bbox"][0] + overflowing_bbox = information["bbox"][1] + self.assertEqual(len(information["input_ids"]), 2) + + self.assertEqual(len(truncated_sequence), len(sequence["input_ids"]) - 2) + self.assertEqual(truncated_sequence, truncated_longest_sequence) + + self.assertEqual(len(overflowing_tokens), 2 + stride + len(smallest)) + self.assertEqual(overflowing_tokens, overflow_longest_sequence) + self.assertEqual(bbox, bbox_longest_sequence) + + self.assertEqual(len(overflowing_bbox), 2 + stride + len(smallest)) + self.assertEqual(overflowing_bbox, overflowing_token_bbox_longest_sequence_fast) + else: + # No overflowing tokens when using 'longest' in python tokenizers + with self.assertRaises(ValueError) as context: + information = tokenizer( + question_0, + seq_1, + boxes=boxes_1, + max_length=len(sequence["input_ids"]) - 2, + add_special_tokens=False, + stride=stride, + truncation="longest_first", + return_overflowing_tokens=True, + # add_prefix_space=False, + ) + + self.assertTrue( + context.exception.args[0].startswith( + "Not possible to return overflowing tokens for pair of sequences with the " + "`longest_first`. Please select another truncation strategy than `longest_first`, " + "for instance `only_second` or `only_first`." + ) + ) + + # Overflowing tokens are handled quite differently in slow and fast tokenizers + if isinstance(tokenizer, LayoutLMv3TokenizerFast): + information = tokenizer( + question_0, + seq_1, + boxes=boxes_1, + max_length=len(sequence["input_ids"]) - 2, + add_special_tokens=False, + stride=stride, + truncation=True, + return_overflowing_tokens=True, + # add_prefix_space=False, + ) + truncated_sequence = information["input_ids"][0] + overflowing_tokens = information["input_ids"][1] + bbox = information["bbox"][0] + overflowing_bbox = information["bbox"][1] + self.assertEqual(len(information["input_ids"]), 2) + + self.assertEqual(len(truncated_sequence), len(sequence["input_ids"]) - 2) + self.assertEqual(truncated_sequence, truncated_longest_sequence) + + self.assertEqual(len(overflowing_tokens), 2 + stride + len(smallest)) + self.assertEqual(overflowing_tokens, overflow_longest_sequence) + self.assertEqual(bbox, bbox_longest_sequence) + self.assertEqual(overflowing_bbox, overflowing_token_bbox_longest_sequence_fast) + else: + # No overflowing tokens when using 'longest' in python tokenizers + with self.assertRaises(ValueError) as context: + information = tokenizer( + question_0, + seq_1, + boxes=boxes_1, + max_length=len(sequence["input_ids"]) - 2, + add_special_tokens=False, + stride=stride, + truncation=True, + return_overflowing_tokens=True, + # add_prefix_space=False, + ) + + self.assertTrue( + context.exception.args[0].startswith( + "Not possible to return overflowing tokens for pair of sequences with the " + "`longest_first`. Please select another truncation strategy than `longest_first`, " + "for instance `only_second` or `only_first`." + ) + ) + + information_first_truncated = tokenizer( + question_0, + seq_1, + boxes=boxes_1, + max_length=len(sequence["input_ids"]) - 2, + add_special_tokens=False, + stride=stride, + truncation="only_first", + return_overflowing_tokens=True, + # add_prefix_space=False, + ) + # Overflowing tokens are handled quite differently in slow and fast tokenizers + if isinstance(tokenizer, LayoutLMv3TokenizerFast): + truncated_sequence = information_first_truncated["input_ids"][0] + overflowing_tokens = information_first_truncated["input_ids"][1] + bbox = information_first_truncated["bbox"][0] + overflowing_bbox = information_first_truncated["bbox"][0] + self.assertEqual(len(information_first_truncated["input_ids"]), 2) + + self.assertEqual(len(truncated_sequence), len(sequence["input_ids"]) - 2) + self.assertEqual(truncated_sequence, truncated_first_sequence) + + self.assertEqual(len(overflowing_tokens), 2 + stride + len(seq1_input_ids)) + self.assertEqual(overflowing_tokens, overflow_first_sequence) + self.assertEqual(bbox, bbox_first_sequence) + self.assertEqual(overflowing_bbox, overflowing_token_bbox_first_sequence_fast) + else: + truncated_sequence = information_first_truncated["input_ids"] + overflowing_tokens = information_first_truncated["overflowing_tokens"] + overflowing_bbox = information_first_truncated["overflowing_token_boxes"] + bbox = information_first_truncated["bbox"] + + self.assertEqual(len(truncated_sequence), len(sequence["input_ids"]) - 2) + self.assertEqual(truncated_sequence, truncated_first_sequence) + + self.assertEqual(len(overflowing_tokens), 2 + stride) + self.assertEqual(overflowing_tokens, seq0_input_ids[-(2 + stride) :]) + self.assertEqual(bbox, bbox_first_sequence) + self.assertEqual(overflowing_bbox, overflowing_token_bbox_first_sequence_slow) + + information_second_truncated = tokenizer( + question_0, + seq_1, + boxes=boxes_1, + max_length=len(sequence["input_ids"]) - 2, + add_special_tokens=False, + stride=stride, + truncation="only_second", + return_overflowing_tokens=True, + # add_prefix_space=False, + ) + # Overflowing tokens are handled quite differently in slow and fast tokenizers + if isinstance(tokenizer, LayoutLMv3TokenizerFast): + truncated_sequence = information_second_truncated["input_ids"][0] + overflowing_tokens = information_second_truncated["input_ids"][1] + bbox = information_second_truncated["bbox"][0] + overflowing_bbox = information_second_truncated["bbox"][1] + + self.assertEqual(len(information_second_truncated["input_ids"]), 2) + + self.assertEqual(len(truncated_sequence), len(sequence["input_ids"]) - 2) + self.assertEqual(truncated_sequence, truncated_second_sequence) + + self.assertEqual(len(overflowing_tokens), 2 + stride + len(seq0_input_ids)) + self.assertEqual(overflowing_tokens, overflow_second_sequence) + self.assertEqual(bbox, bbox_second_sequence) + self.assertEqual(overflowing_bbox, overflowing_token_bbox_second_sequence_fast) + else: + truncated_sequence = information_second_truncated["input_ids"] + overflowing_tokens = information_second_truncated["overflowing_tokens"] + bbox = information_second_truncated["bbox"] + overflowing_bbox = information_second_truncated["overflowing_token_boxes"] + + self.assertEqual(len(truncated_sequence), len(sequence["input_ids"]) - 2) + self.assertEqual(truncated_sequence, truncated_second_sequence) + + self.assertEqual(len(overflowing_tokens), 2 + stride) + self.assertEqual(overflowing_tokens, seq1_input_ids[-(2 + stride) :]) + self.assertEqual(bbox, bbox_second_sequence) + self.assertEqual(overflowing_bbox, overflowing_token_bbox_second_sequence_slow) + + def test_maximum_encoding_length_single_input(self): + tokenizers = self.get_tokenizers(do_lower_case=False, model_max_length=100) + for tokenizer in tokenizers: + with self.subTest(f"{tokenizer.__class__.__name__}"): + seq_0, boxes_0, ids = self.get_clean_sequence(tokenizer, max_length=20) + + sequence = tokenizer(seq_0, boxes=boxes_0, add_special_tokens=False) + total_length = len(sequence["input_ids"]) + + self.assertGreater(total_length, 4, "Issue with the testing sequence, please update it it's too short") + + # Test with max model input length + model_max_length = tokenizer.model_max_length + self.assertEqual(model_max_length, 100) + seq_1 = seq_0 * model_max_length + boxes_1 = boxes_0 * model_max_length + sequence1 = tokenizer(seq_1, boxes=boxes_1, add_special_tokens=False) + total_length1 = len(sequence1["input_ids"]) + self.assertGreater( + total_length1, model_max_length, "Issue with the testing sequence, please update it it's too short" + ) + + # Simple + padding_strategies = ( + [False, True, "longest"] if tokenizer.pad_token and tokenizer.pad_token_id >= 0 else [False] + ) + for padding_state in padding_strategies: + with self.subTest(f"Padding: {padding_state}"): + for truncation_state in [True, "longest_first", "only_first"]: + with self.subTest(f"Truncation: {truncation_state}"): + output = tokenizer( + seq_1, + boxes=boxes_1, + padding=padding_state, + truncation=truncation_state, + ) + + self.assertEqual(len(output["input_ids"]), model_max_length) + self.assertEqual(len(output["bbox"]), model_max_length) + + output = tokenizer( + [seq_1], + boxes=[boxes_1], + padding=padding_state, + truncation=truncation_state, + ) + self.assertEqual(len(output["input_ids"][0]), model_max_length) + self.assertEqual(len(output["bbox"][0]), model_max_length) + + # Simple with no truncation + # Reset warnings + tokenizer.deprecation_warnings = {} + with self.assertLogs("transformers", level="WARNING") as cm: + output = tokenizer(seq_1, boxes=boxes_1, padding=padding_state, truncation=False) + self.assertNotEqual(len(output["input_ids"]), model_max_length) + self.assertNotEqual(len(output["bbox"]), model_max_length) + self.assertEqual(len(cm.records), 1) + self.assertTrue( + cm.records[0].message.startswith( + "Token indices sequence length is longer than the specified maximum sequence length" + " for this model" + ) + ) + + tokenizer.deprecation_warnings = {} + with self.assertLogs("transformers", level="WARNING") as cm: + output = tokenizer([seq_1], boxes=[boxes_1], padding=padding_state, truncation=False) + self.assertNotEqual(len(output["input_ids"][0]), model_max_length) + self.assertNotEqual(len(output["bbox"][0]), model_max_length) + self.assertEqual(len(cm.records), 1) + self.assertTrue( + cm.records[0].message.startswith( + "Token indices sequence length is longer than the specified maximum sequence length" + " for this model" + ) + ) + # Check the order of Sequence of input ids, overflowing tokens and bbox sequence with truncation + stride = 2 + information = tokenizer( + seq_0, + boxes=boxes_0, + max_length=total_length - 2, + add_special_tokens=False, + stride=stride, + truncation=True, + return_overflowing_tokens=True, + # add_prefix_space=False, + ) + + # Overflowing tokens are handled quite differently in slow and fast tokenizers + if isinstance(tokenizer, LayoutLMv3TokenizerFast): + truncated_sequence = information["input_ids"][0] + overflowing_tokens = information["input_ids"][1] + # bbox = information["bbox"][0] + # overflowing_bbox = information["bbox"][1] + self.assertEqual(len(information["input_ids"]), 2) + + self.assertEqual(len(truncated_sequence), total_length - 2) + self.assertEqual(truncated_sequence, sequence["input_ids"][:-2]) + + self.assertEqual(len(overflowing_tokens), 2 + stride) + self.assertEqual(overflowing_tokens, sequence["input_ids"][-(2 + stride) :]) + + # self.assertEqual(bbox, sequence["bbox"][:-2]) + # self.assertEqual(overflowing_bbox, sequence["bbox"][-(2 + stride) :]) + else: + truncated_sequence = information["input_ids"] + overflowing_tokens = information["overflowing_tokens"] + # bbox = information["bbox"] + # overflowing_bbox = information["overflowing_token_boxes"] + self.assertEqual(len(truncated_sequence), total_length - 2) + self.assertEqual(truncated_sequence, sequence["input_ids"][:-2]) + + self.assertEqual(len(overflowing_tokens), 2 + stride) + self.assertEqual(overflowing_tokens, sequence["input_ids"][-(2 + stride) :]) + # self.assertEqual(bbox, sequence["bbox"][:-2]) + # self.assertEqual(overflowing_bbox, sequence["bbox"][-(2 + stride) :]) + + @unittest.skip("LayoutLMv3 tokenizer requires boxes besides sequences.") + def test_pretokenized_inputs(self): + pass + + @unittest.skip("LayoutLMv3 tokenizer always expects pretokenized inputs.") + def test_compare_pretokenized_inputs(self): + pass + + @unittest.skip("LayoutLMv3 fast tokenizer does not support prepare_for_model") + def test_compare_prepare_for_model(self): + pass + + @slow + def test_only_label_first_subword(self): + words = ["hello", "niels"] + boxes = [[1000, 1000, 1000, 1000] for _ in range(len(words))] + word_labels = [0, 1] + + # test slow tokenizer + tokenizer_p = LayoutLMv3Tokenizer.from_pretrained("microsoft/layoutlmv3-base", add_visual_labels=False) + encoding = tokenizer_p(words, boxes=boxes, word_labels=word_labels) + self.assertListEqual(encoding.labels, [-100, 0, 1, -100, -100]) + + tokenizer_p = LayoutLMv3Tokenizer.from_pretrained( + "microsoft/layoutlmv3-base", + only_label_first_subword=False, + add_visual_labels=False, + ) + encoding = tokenizer_p(words, boxes=boxes, word_labels=word_labels) + self.assertListEqual(encoding.labels, [-100, 0, 1, 1, -100]) + + # test fast tokenizer + tokenizer_r = LayoutLMv3TokenizerFast.from_pretrained("microsoft/layoutlmv3-base", add_visual_labels=False) + encoding = tokenizer_r(words, boxes=boxes, word_labels=word_labels) + self.assertListEqual(encoding.labels, [-100, 0, 1, -100, -100]) + + tokenizer_r = LayoutLMv3Tokenizer.from_pretrained( + "microsoft/layoutlmv3-base", + only_label_first_subword=False, + add_visual_labels=False, + ) + encoding = tokenizer_r(words, boxes=boxes, word_labels=word_labels) + self.assertListEqual(encoding.labels, [-100, 0, 1, 1, -100]) + + @slow + def test_layoutlmv3_integration_test(self): + + tokenizer_p = LayoutLMv3Tokenizer.from_pretrained("microsoft/layoutlmv3-base") + tokenizer_r = LayoutLMv3TokenizerFast.from_pretrained("microsoft/layoutlmv3-base") + + # There are 3 cases: + # CASE 1: document image classification (training + inference), document image token classification (inference), + # in which case only words and normalized bounding boxes are provided to the tokenizer + # CASE 2: document image token classification (training), + # in which case one also provides word labels to the tokenizer + # CASE 3: document image visual question answering (inference), + # in which case one also provides a question to the tokenizer + + # We need to test all 3 cases both on batched and non-batched inputs. + + # CASE 1: not batched + words, boxes = self.get_words_and_boxes() + + # fmt: off + expected_results = {'input_ids': [0, 795, 13964, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'bbox': [[0, 0, 0, 0], [423, 237, 440, 251], [427, 272, 441, 287], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]], 'attention_mask': [1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]} # noqa: E231 + # fmt: on + + encoding_p = tokenizer_p(words, boxes=boxes, padding="max_length", max_length=20) + encoding_r = tokenizer_r(words, boxes=boxes, padding="max_length", max_length=20) + self.assertDictEqual(dict(encoding_p), expected_results) + self.assertDictEqual(dict(encoding_r), expected_results) + + # CASE 1: batched + words, boxes = self.get_words_and_boxes_batch() + + # fmt: off + expected_results = {'input_ids': [[0, 795, 13964, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [0, 92, 614, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], 'bbox': [[[0, 0, 0, 0], [423, 237, 440, 251], [427, 272, 441, 287], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]], [[0, 0, 0, 0], [961, 885, 992, 912], [256, 38, 330, 58], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]], 'attention_mask': [[1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]} # noqa: E231 + # fmt: on + + encoding_p = tokenizer_p(words, boxes=boxes, padding="max_length", max_length=20) + encoding_r = tokenizer_r(words, boxes=boxes, padding="max_length", max_length=20) + self.assertDictEqual(dict(encoding_p), expected_results) + self.assertDictEqual(dict(encoding_r), expected_results) + + # CASE 2: not batched + words, boxes = self.get_words_and_boxes() + word_labels = [1, 2] + + # fmt: off + expected_results = {'input_ids': [0, 795, 13964, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'bbox': [[0, 0, 0, 0], [423, 237, 440, 251], [427, 272, 441, 287], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]], 'labels': [-100, 1, 2, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100], 'attention_mask': [1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]} # noqa: E231 + # fmt: on + + encoding_p = tokenizer_p(words, boxes=boxes, word_labels=word_labels, padding="max_length", max_length=20) + encoding_r = tokenizer_r(words, boxes=boxes, word_labels=word_labels, padding="max_length", max_length=20) + self.assertDictEqual(dict(encoding_p), expected_results) + self.assertDictEqual(dict(encoding_r), expected_results) + + # # CASE 2: batched + words, boxes = self.get_words_and_boxes_batch() + word_labels = [[1, 2], [2, 46]] + + # fmt: off + expected_results = {'input_ids': [[0, 795, 13964, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [0, 92, 614, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], 'bbox': [[[0, 0, 0, 0], [423, 237, 440, 251], [427, 272, 441, 287], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]], [[0, 0, 0, 0], [961, 885, 992, 912], [256, 38, 330, 58], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]], 'labels': [[-100, 1, 2, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100], [-100, 2, 46, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100]], 'attention_mask': [[1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]} # noqa: E231 + # fmt: on + + encoding_p = tokenizer_p(words, boxes=boxes, word_labels=word_labels, padding="max_length", max_length=20) + encoding_r = tokenizer_r(words, boxes=boxes, word_labels=word_labels, padding="max_length", max_length=20) + self.assertDictEqual(dict(encoding_p), expected_results) + self.assertDictEqual(dict(encoding_r), expected_results) + + # # CASE 3: not batched + question, words, boxes = self.get_question_words_and_boxes() + + # fmt: off + expected_results = {'input_ids': [0, 99, 18, 39, 766, 116, 2, 2, 795, 13964, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'bbox': [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [423, 237, 440, 251], [427, 272, 441, 287], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0]} # noqa: E231 + # fmt: on + + encoding_p = tokenizer_p(question, words, boxes, padding="max_length", max_length=20) + encoding_r = tokenizer_r(question, words, boxes, padding="max_length", max_length=20) + self.assertDictEqual(dict(encoding_p), expected_results) + self.assertDictEqual(dict(encoding_r), expected_results) + + # # CASE 3: batched + questions, words, boxes = self.get_question_words_and_boxes_batch() + + # fmt: off + expected_results = {'input_ids': [[0, 99, 18, 39, 766, 116, 2, 2, 795, 13964, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1], [0, 141, 16, 37, 373, 116, 2, 2, 13964, 795, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1]], 'bbox': [[[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [423, 237, 440, 251], [427, 272, 441, 287], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]], [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [256, 38, 330, 58], [256, 38, 330, 58], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0]]} # noqa: E231 + # fmt: on + + encoding_p = tokenizer_p(questions, words, boxes, padding="max_length", max_length=20) + encoding_r = tokenizer_r(questions, words, boxes, padding="max_length", max_length=20) + self.assertDictEqual(dict(encoding_p), expected_results) + self.assertDictEqual(dict(encoding_r), expected_results) + + @unittest.skip("Doesn't support another framework than PyTorch") + def test_np_encode_plus_sent_to_model(self): + pass diff --git a/tests/models/layoutxlm/test_processor_layoutxlm.py b/tests/models/layoutxlm/test_processor_layoutxlm.py index d208097e6c804..69865f6d9f0c6 100644 --- a/tests/models/layoutxlm/test_processor_layoutxlm.py +++ b/tests/models/layoutxlm/test_processor_layoutxlm.py @@ -177,10 +177,11 @@ def test_processor_case_1(self): ) # verify input_ids + # this was obtained with Tesseract 4.1.1 # fmt: off expected_decoding = " 11:14 to 11:39 a.m 11:39 to 11:44 a.m. 11:44 a.m. to 12:25 p.m. 12:25 to 12:58 p.m. 12:58 to 4:00 p.m. 2:00 to 5:00 p.m. Coffee Break Coffee will be served for men and women in the lobby adjacent to exhibit area. Please move into exhibit area. (Exhibits Open) TRRF GENERAL SESSION (PART |) Presiding: Lee A. Waller TRRF Vice President “Introductory Remarks” Lee A. Waller, TRRF Vice Presi- dent Individual Interviews with TRRF Public Board Members and Sci- entific Advisory Council Mem- bers Conducted by TRRF Treasurer Philip G. Kuehn to get answers which the public refrigerated warehousing industry is looking for. Plus questions from the floor. Dr. Emil M. Mrak, University of Cal- ifornia, Chairman, TRRF Board; Sam R. Cecil, University of Georgia College of Agriculture; Dr. Stanley Charm, Tufts University School of Medicine; Dr. Robert H. Cotton, ITT Continental Baking Company; Dr. Owen Fennema, University of Wis- consin; Dr. Robert E. Hardenburg, USDA. Questions and Answers Exhibits Open Capt. Jack Stoney Room TRRF Scientific Advisory Council Meeting Ballroom Foyer" # noqa: E231 # fmt: on - decoding = tokenizer.decode(input_processor.input_ids.squeeze().tolist()) + decoding = processor.decode(input_processor.input_ids.squeeze().tolist()) self.assertSequenceEqual(decoding, expected_decoding) # batched @@ -198,10 +199,11 @@ def test_processor_case_1(self): ) # verify input_ids + # this was obtained with Tesseract 4.1.1 # fmt: off expected_decoding = " 7 ITC Limited REPORT AND ACCOUNTS 2013 ITC’s Brands: An Asset for the Nation The consumer needs and aspirations they fulfil, the benefit they generate for millions across ITC’s value chains, the future-ready capabilities that support them, and the value that they create for the country, have made ITC’s brands national assets, adding to India’s competitiveness. It is ITC’s aspiration to be the No 1 FMCG player in the country, driven by its new FMCG businesses. A recent Nielsen report has highlighted that ITC's new FMCG businesses are the fastest growing among the top consumer goods companies operating in India. ITC takes justifiable pride that, along with generating economic value, these celebrated Indian brands also drive the creation of larger societal capital through the virtuous cycle of sustainable and inclusive growth. DI WILLS * ; LOVE DELIGHTFULLY SOFT SKIN? aia Ans Source: https://www.industrydocuments.ucsf.edu/docs/snbx0223" # noqa: E231 # fmt: on - decoding = tokenizer.decode(input_processor.input_ids[1].tolist()) + decoding = processor.decode(input_processor.input_ids[1].tolist()) self.assertSequenceEqual(decoding, expected_decoding) @slow @@ -228,7 +230,7 @@ def test_processor_case_2(self): # verify input_ids expected_decoding = " hello world" - decoding = tokenizer.decode(input_processor.input_ids.squeeze().tolist()) + decoding = processor.decode(input_processor.input_ids.squeeze().tolist()) self.assertSequenceEqual(decoding, expected_decoding) # batched @@ -243,7 +245,7 @@ def test_processor_case_2(self): # verify input_ids expected_decoding = " hello world" - decoding = tokenizer.decode(input_processor.input_ids[0].tolist()) + decoding = processor.decode(input_processor.input_ids[0].tolist()) self.assertSequenceEqual(decoding, expected_decoding) # verify bbox @@ -282,7 +284,7 @@ def test_processor_case_3(self): # verify input_ids expected_decoding = " weirdly world" - decoding = tokenizer.decode(input_processor.input_ids.squeeze().tolist()) + decoding = processor.decode(input_processor.input_ids.squeeze().tolist()) self.assertSequenceEqual(decoding, expected_decoding) # verify labels @@ -304,7 +306,7 @@ def test_processor_case_3(self): # verify input_ids expected_decoding = " my name is niels" - decoding = tokenizer.decode(input_processor.input_ids[1].tolist()) + decoding = processor.decode(input_processor.input_ids[1].tolist()) self.assertSequenceEqual(decoding, expected_decoding) # verify bbox @@ -344,10 +346,11 @@ def test_processor_case_4(self): self.assertListEqual(actual_keys, expected_keys) # verify input_ids + # this was obtained with Tesseract 4.1.1 # fmt: off expected_decoding = " What's his name? 11:14 to 11:39 a.m 11:39 to 11:44 a.m. 11:44 a.m. to 12:25 p.m. 12:25 to 12:58 p.m. 12:58 to 4:00 p.m. 2:00 to 5:00 p.m. Coffee Break Coffee will be served for men and women in the lobby adjacent to exhibit area. Please move into exhibit area. (Exhibits Open) TRRF GENERAL SESSION (PART |) Presiding: Lee A. Waller TRRF Vice President “Introductory Remarks” Lee A. Waller, TRRF Vice Presi- dent Individual Interviews with TRRF Public Board Members and Sci- entific Advisory Council Mem- bers Conducted by TRRF Treasurer Philip G. Kuehn to get answers which the public refrigerated warehousing industry is looking for. Plus questions from the floor. Dr. Emil M. Mrak, University of Cal- ifornia, Chairman, TRRF Board; Sam R. Cecil, University of Georgia College of Agriculture; Dr. Stanley Charm, Tufts University School of Medicine; Dr. Robert H. Cotton, ITT Continental Baking Company; Dr. Owen Fennema, University of Wis- consin; Dr. Robert E. Hardenburg, USDA. Questions and Answers Exhibits Open Capt. Jack Stoney Room TRRF Scientific Advisory Council Meeting Ballroom Foyer" # noqa: E231 # fmt: on - decoding = tokenizer.decode(input_processor.input_ids.squeeze().tolist()) + decoding = processor.decode(input_processor.input_ids.squeeze().tolist()) self.assertSequenceEqual(decoding, expected_decoding) # batched @@ -362,8 +365,9 @@ def test_processor_case_4(self): self.assertListEqual(actual_keys, expected_keys) # verify input_ids + # this was obtained with Tesseract 4.1.1 expected_decoding = " what's the time 7 ITC Limited REPORT AND ACCOUNTS 2013" - decoding = tokenizer.decode(input_processor.input_ids[1].tolist()) + decoding = processor.decode(input_processor.input_ids[1].tolist()) self.assertSequenceEqual(decoding, expected_decoding) # verify bbox @@ -396,7 +400,7 @@ def test_processor_case_5(self): # verify input_ids expected_decoding = " What's his name? hello world" - decoding = tokenizer.decode(input_processor.input_ids.squeeze().tolist()) + decoding = processor.decode(input_processor.input_ids.squeeze().tolist()) self.assertSequenceEqual(decoding, expected_decoding) # batched @@ -412,11 +416,11 @@ def test_processor_case_5(self): # verify input_ids expected_decoding = " How old is he? hello world" - decoding = tokenizer.decode(input_processor.input_ids[0].tolist()) + decoding = processor.decode(input_processor.input_ids[0].tolist()) self.assertSequenceEqual(decoding, expected_decoding) expected_decoding = " what's the time my name is niels" - decoding = tokenizer.decode(input_processor.input_ids[1].tolist()) + decoding = processor.decode(input_processor.input_ids[1].tolist()) self.assertSequenceEqual(decoding, expected_decoding) # verify bbox diff --git a/utils/documentation_tests.txt b/utils/documentation_tests.txt index ae0f39ac4fe55..3e20dc09e22cd 100644 --- a/utils/documentation_tests.txt +++ b/utils/documentation_tests.txt @@ -33,6 +33,7 @@ src/transformers/models/gpt2/modeling_gpt2.py src/transformers/models/gptj/modeling_gptj.py src/transformers/models/hubert/modeling_hubert.py src/transformers/models/layoutlmv2/modeling_layoutlmv2.py +src/transformers/models/layoutlmv3/modeling_layoutlmv3.py src/transformers/models/longformer/modeling_longformer.py src/transformers/models/longformer/modeling_tf_longformer.py src/transformers/models/marian/modeling_marian.py