
Add LayoutLMv3 #17060
Merged · 60 commits into huggingface:main · May 24, 2022

Conversation

NielsRogge (Contributor) commented May 3, 2022

What does this PR do?

This PR implements LayoutLMv3. LayoutLMv3 doesn't require a Detectron2 backbone anymore (yay!).

The PR also includes an example script that can be used to reproduce the results of the paper.
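
For illustration, here is a minimal usage sketch of the API this PR adds (the dummy white image, toy boxes, label values and num_labels below are made up for the example; see the docs and the example script for real usage):

from PIL import Image
from transformers import LayoutLMv3Processor, LayoutLMv3ForTokenClassification

# apply_ocr=False because we provide our own words and boxes
processor = LayoutLMv3Processor.from_pretrained("microsoft/layoutlmv3-base", apply_ocr=False)
model = LayoutLMv3ForTokenClassification.from_pretrained("microsoft/layoutlmv3-base", num_labels=2)

image = Image.new("RGB", (224, 224), color="white")  # dummy document image
words = ["hello", "world"]
boxes = [[1, 2, 3, 4], [5, 6, 7, 8]]  # normalized 0-1000 coordinates
word_labels = [0, 1]

encoding = processor(image, words, boxes=boxes, word_labels=word_labels, return_tensors="pt")
outputs = model(**encoding)
print(outputs.loss, outputs.logits.shape)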

Fixes #16914

To do:

  • fix remaining tokenizer tests. These are very black-boxy to me. Pinging @SaulLu here.
  • add model to doc tests
  • remove is_detection logic
  • Make sure the slow tests involving PyTesseract pass
  • Merge add_layoutlmv3_simplify branch

HuggingFaceDocBuilderDev commented May 3, 2022

The documentation is not available anymore as the PR was closed or merged.

SaulLu (Contributor) commented May 3, 2022

Just to keep track of the current status.

As discussed offline, I think the next step to "solve" the tokenization tests is to figure out how ["hello", "world"] is tokenized in the original code: is it [0, 42891, 8331, 2] (['<s>', 'Ġhello', 'Ġworld', '</s>']) or [0, 20760, 232, 2] (['<s>', 'hello', 'world', '</s>']) or something else? 😊

NielsRogge (Contributor, Author) commented May 3, 2022

As seen here, text is tokenized using RobertaTokenizer, where one provides is_split_into_words=True. Hence, ["hello", "world"] is tokenized as follows:

from transformers import RobertaTokenizer

tokenizer = RobertaTokenizer.from_pretrained("microsoft/layoutlmv3-base")

text = ["hello", "world"]
encoding = tokenizer(text, is_split_into_words=True)
print(encoding["input_ids"])  # [0, 20760, 232, 2]

So this results in [0, 20760, 232, 2].

SaulLu (Contributor) commented May 5, 2022

Thanks for the clarification! I've opened a PR on your branch (NielsRogge#38) which proposes several changes including 1) changing the default behaviour so that by default a space prefix is added and including all the changes needed to make it work and 2) some small changes to resolve several of the tests that were failing.

I wonder if we shouldn't just remove the option to set add_prefix_space to False because the result will not be satisfactory for decoding and I'm not sure we want to do any fancy tricks to make it "work". (Or at least we should log a message to warn the user that the option is risky).
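
For context, a minimal sketch of why skipping the prefix space is lossy for decoding pre-split words (this uses the plain roberta-base tokenizer, not this PR's tokenizer, purely to illustrate the mechanism):

from transformers import RobertaTokenizer

tokenizer = RobertaTokenizer.from_pretrained("roberta-base")

# Without a prefix space, each word is encoded as if it were glued to the previous one,
# so the word boundary disappears when decoding.
ids = []
for word in ["hello", "world"]:
    ids += tokenizer.encode(word, add_special_tokens=False)
print(tokenizer.decode(ids))  # "helloworld"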

sgugger (Collaborator) left a comment

Nice work adding this new model!
There is a lot of commented-out code in the test_tokenization file. Not sure if it's to fix or to clean up, but it should be removed before merging the PR. LGTM otherwise!

Resolved review threads on:

  • docs/source/en/_toctree.yml
  • examples/research_projects/layoutlmv3/data_collator.py
  • src/transformers/models/layoutlmv3/modeling_layoutlmv3.py (two threads)
  • src/transformers/models/layoutlmv3/test.py

Inline review comment (Collaborator) on modeling_layoutlmv3.py:

if bidirectional:
    num_buckets //= 2
    ret += (relative_position > 0).long() * num_buckets
    n = torch.abs(relative_position)

Would love a better name for n.

ducviet00 commented

Hi @NielsRogge

As in issue #13554 and PR #17092, when input_ids is longer than the model's max_length, it is split into multiple inputs, but pixel_values still contains only 1 image. Are you going to fix this right now, or in a next PR?

How to reproduce

from PIL import Image
from transformers import AutoProcessor, AutoModelForTokenClassification

processor = AutoProcessor.from_pretrained("microsoft/layoutlmv3-large")
processor.feature_extractor.apply_ocr = False
model = AutoModelForTokenClassification.from_pretrained("microsoft/layoutlmv3-large")

image = Image.new("RGB", (224, 224), color="white")  # dummy page image (undefined in the original snippet)
words = ["hello" for _ in range(1000)]
boxes = [[0, 1, 2, 3] for _ in range(1000)]
encoding = processor(
    image,
    text=words,
    boxes=boxes,
    truncation=True,
    padding="max_length",
    return_overflowing_tokens=True,
    return_tensors="pt",
)

print(encoding["input_ids"].shape)     # torch.Size([2, 512])
print(encoding["pixel_values"].shape)  # torch.Size([1, 3, 224, 224])
overflow_to_sample_mapping = encoding.pop("overflow_to_sample_mapping")
model(**encoding)
# ---> RuntimeError: Sizes of tensors must match except in dimension 1.
# Expected size 4 but got size 1 for tensor number 1 in the list.
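
For reference, a common workaround until this is handled in the processor (sketched here, not part of this PR, and assuming all overflowing chunks come from the same page image) is to repeat pixel_values per chunk using overflow_to_sample_mapping before calling the model:

# continuing from the snippet above
encoding = processor(
    image,
    text=words,
    boxes=boxes,
    truncation=True,
    padding="max_length",
    return_overflowing_tokens=True,
    return_tensors="pt",
)
# each overflowing chunk points back to the sample (and thus the image) it came from
mapping = encoding.pop("overflow_to_sample_mapping")
encoding["pixel_values"] = encoding["pixel_values"][mapping]  # now torch.Size([2, 3, 224, 224])
outputs = model(**encoding)  # no size mismatch anymore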

LysandreJik (Member) left a comment

That's very impressive! Have the tokenization and test files been copied from others? I see the # Copied from statements only in the modeling file; it would likely greatly help reviewing if they were also in the other files that have copied parts of the code.
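
For readers unfamiliar with the convention, such a marker is a single comment placed above the copied object so the repository's consistency checks keep it in sync with its source; an illustrative example (not a literal line from this PR) would be:

# Copied from transformers.models.roberta.modeling_roberta.RobertaSelfOutput with Roberta->LayoutLMv3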

truncated_sequence = information_first_truncated["input_ids"][0]
overflowing_tokens = information_first_truncated["input_ids"][1]
bbox = information_first_truncated["bbox"][0]
overflowing_bbox = information_first_truncated["bbox"][0]
Review comment (Contributor):

A small note to keep in mind the ongoing discussion we had offline: I'm not sure I understand why the element at position 0 is taken and not the one at position 1. 🙂

sina-ehsani commented

Thank you so much for your fantastic work. I was wondering if you plan to include the object detection task in LayoutLMv3 as well. I noticed that the PubLayNet fine-tuned model weights have already been uploaded to HuggingFace, but I couldn't find any documentation on this capability in this repository.

LysandreJik (Member) left a comment
Impressive effort! LGTM!

And thanks for adding the # Copied from statements, makes the review easier.

dcyoung commented May 19, 2022

EDIT: Just realized these are the visual tokens... controlled via add_visual_labels

@NielsRogge Thanks for this contribution!
While testing the processor, I'm seeing extra padding on the resulting labels that I did not expect and have not experienced with older versions of LayoutLMv2Processor.

import numpy as np
from transformers.models.auto.processing_auto import AutoProcessor

processor = AutoProcessor.from_pretrained(
    pretrained_model_name_or_path="microsoft/layoutlmv3-base",
    use_fast=True,
    add_prefix_space=True,
    apply_ocr=False,
)

# not batched
words = ["hello", "world"]
boxes = [[1, 2, 3, 4], [5, 6, 7, 8]]
word_labels = [1, 2]
image = np.zeros((224, 224, 3), dtype=np.uint8)
results = processor(
    image, words, boxes=boxes, word_labels=word_labels, return_tensors="pt"
)
for k, v in results.items():
    print(k, v.size())

labels = results.labels.squeeze().tolist()
print(labels)

output:

input_ids torch.Size([1, 8])
attention_mask torch.Size([1, 8])
bbox torch.Size([1, 8, 4])
labels torch.Size([1, 205])
pixel_values torch.Size([1, 3, 224, 224])
[-100, 1, -100, -100, -100, -100, 2, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100]

This happens beyond the maximum sequence length as well, where the labels will have dimension seq_length + ~197.
Is this expected?

NielsRogge (Contributor, Author) commented

Hi @dcyoung,

thanks for taking a look. Actually, you make a great point: I implemented it the way the original implementation does (where the authors label all visual tokens with -100 and just add a classifier on top of the entire sequence_output); however, it makes a lot of sense to simplify the code in LayoutLMv3ForTokenClassification and not make the processor do this.

Thanks a lot!
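
For illustration, a rough sketch of that simplification (hypothetical code, not necessarily what ends up in the PR): since the visual patch tokens are appended after the text tokens, the classifier can be applied to just the text part of sequence_output, and the processor no longer needs to pad the labels for the visual tokens:

import torch
from torch import nn

def classify_text_tokens(sequence_output: torch.Tensor, input_ids: torch.Tensor, classifier: nn.Linear):
    # text tokens come first in LayoutLMv3's sequence; visual patch tokens are appended after them
    seq_length = input_ids.size(1)
    text_output = sequence_output[:, :seq_length, :]  # drop the visual tokens
    return classifier(text_output)                    # logits only for the text tokens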

NielsRogge (Contributor, Author) commented May 20, 2022

And hi @sina-ehsani,

unfortunately I'm (for now) not planning to add the object detection part, because the framework being used (Mask R-CNN) is a ridiculous amount of code and it's not straightforward - for now - to add this to the Transformers library (as there's a "one model, one file" philosophy). So I'd advise using the original repository for that.

It may be that in the future we add this framework, but I'm actually much more a fan of simpler frameworks like DETR and YOLOS. It would be great if someone fine-tuned a YOLOS model initialized with the weights of the Document Image Transformer (DiT). I feel like you would get the same performance.

NielsRogge merged commit 31ee80d into huggingface:main on May 24, 2022
elusenji pushed a commit to elusenji/transformers that referenced this pull request Jun 12, 2022
* Make forward pass work
* More improvements
* Remove unused imports
* Remove timm dependency
* Improve loss calculation of token classifier
* Fix most tests
* Add docs
* Add model integration test
* Make all tests pass
* Add LayoutLMv3FeatureExtractor
* Improve integration test + make fixup
* Add example script
* Fix style
* Add LayoutLMv3Processor
* Fix style
* Add option to add visual labels
* Make more tokenizer tests pass
* Fix more tests
* Make more tests pass
* Fix bug and improve docs
* Fix import of processors
* Improve docstrings
* Fix toctree and improve docs
* Fix auto tokenizer
* Move tests to model folder
* Move tests to model folder
* change default behavior add_prefix_space
* add prefix space for fast
* add_prefix_spcae set to True for Fast
* no space before `unique_no_split` token
* add test to hightligh special treatment of added tokens
* fix `test_batch_encode_dynamic_overflowing` by building a long enough example
* fix `test_full_tokenizer` with add_prefix_token
* Fix tokenizer integration test
* Make the code more readable
* Add tests for LayoutLMv3Processor
* Fix style
* Add model to README and update init
* Apply suggestions from code review
* Replace asserts by value errors
* Add suggestion by @ducviet00
* Add model to doc tests
* Simplify script
* Improve README
* a step ahead to fix
* Update pair_input_test
* Make all tokenizer tests pass - phew
* Make style
* Add LayoutLMv3 to CI job
* Fix auto mapping
* Fix CI job name
* Make all processor tests pass
* Make tests of LayoutLMv2 and LayoutXLM consistent
* Add copied from statements to fast tokenizer
* Add copied from statements to slow tokenizer
* Remove add_visual_labels attribute
* Fix tests
* Add link to notebooks
* Improve docs of LayoutLMv3Processor
* Fix reference to section

Co-authored-by: SaulLu <lucilesaul.com@gmail.com>
Co-authored-by: Niels Rogge <nielsrogge@Nielss-MacBook-Pro.local>
divinit7 commented

Thank you so much for adding the model. I had a question on segment position embeddings: how do you create segment position embeddings during inference, when the labels are unknown and you just have bounding boxes from an OCR engine? In this notebook the test set also contains segment-level bounding boxes. I have trained a model with segment-level embeddings for my use case and it doesn't perform well on token-level 2D embeddings during inference.
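
For what it's worth, one common way to build segment-level boxes at inference time is to give every word the union box of its segment; a small self-contained sketch (not from this PR, and assuming the OCR engine provides a line or block id per word):

from collections import defaultdict

def union_box(boxes):
    # smallest box enclosing all the given word boxes
    return [min(b[0] for b in boxes), min(b[1] for b in boxes),
            max(b[2] for b in boxes), max(b[3] for b in boxes)]

def segment_level_boxes(word_boxes, segment_ids):
    # segment_ids: one id per word, e.g. the OCR engine's line or block id
    groups = defaultdict(list)
    for box, seg_id in zip(word_boxes, segment_ids):
        groups[seg_id].append(box)
    seg_box = {seg_id: union_box(boxes) for seg_id, boxes in groups.items()}
    return [seg_box[seg_id] for seg_id in segment_ids]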

jordanparker6 commented Aug 9, 2022

Regarding YOLOS: thanks for the idea. I will have a go at this.

My understanding is that the unilm repo uses Detectron2 (Mask R-CNN) as the backbone for object detection in LayoutLMv3, for benchmarking compatibility. Would it be possible to swap out the image backbone for a vision transformer when training LayoutLMv3? I saw in the paper:

LayoutLMv3 is the first multimodal model in Document AI that does not rely on a pre-trained CNN or Faster R-CNN backbone to extract visual features, which significantly saves parameters and eliminates region annotations.

My understanding is that LayoutLMv3 is able to generalise better thanks to the unsupervised pre-training over the MIM+MLM+WPA objectives. It also learns correlations between the text and visual inputs that benefit it on downstream tasks. YOLOS wouldn't include this key text information in document layout analysis.

Please correct me if I am wrong... I am learning here.

jordanparker6 commented

@NielsRogge

This thread has led me to hack together a model that combines the YolosLoss and YolosObjectDetection head with the LayoutLMv3Model to build a LayoutLMv3ObjectDetection prediction head.

Changes to the LayoutLMv3Config and LayoutLMv3FeatureExtractor had to be made to allow for this.

This approach avoids the Mask R-CNN discussed.

Is this something you would be interested in reviewing and integrating if I open a PR?

Or does it deviate too significantly from the research paper?
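
For discussion, a very rough sketch of the idea described above (hypothetical names and wiring, not the commenter's actual code; the set-prediction matching loss from DETR/YOLOS is omitted): put a class head and a box MLP on top of LayoutLMv3Model's output tokens.

import torch
from torch import nn
from transformers import LayoutLMv3Model

class LayoutLMv3ForObjectDetectionSketch(nn.Module):
    def __init__(self, checkpoint="microsoft/layoutlmv3-base", num_labels=11):
        super().__init__()
        self.backbone = LayoutLMv3Model.from_pretrained(checkpoint)
        hidden_size = self.backbone.config.hidden_size
        self.class_head = nn.Linear(hidden_size, num_labels + 1)  # +1 for the "no object" class
        self.bbox_head = nn.Sequential(
            nn.Linear(hidden_size, hidden_size), nn.ReLU(),
            nn.Linear(hidden_size, 4),  # (cx, cy, w, h), normalized
        )

    def forward(self, pixel_values, input_ids=None, bbox=None, attention_mask=None):
        outputs = self.backbone(
            input_ids=input_ids, bbox=bbox,
            attention_mask=attention_mask, pixel_values=pixel_values,
        )
        hidden_states = outputs.last_hidden_state
        # one class/box prediction per output token; a Hungarian matching loss
        # (as in DETR/YOLOS) would be applied on top of these predictions
        return self.class_head(hidden_states), self.bbox_head(hidden_states).sigmoid()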
