
Add LayoutLMv2 + LayoutXLM #12604

Merged
merged 114 commits on Aug 30, 2021

Conversation

NielsRogge (Contributor) commented Jul 9, 2021

What does this PR do?

This PR adds Microsoft's LayoutLMv2 and LayoutXLM models, in PyTorch. The latter is a multilingual version of LayoutLMv2. For now, I haven't added any documentation for LayoutXLM; I'm not sure whether we need a new model directory + documentation page for it, since one can load a LayoutXLM model like so:
model = LayoutLMv2Model.from_pretrained("microsoft/layoutxlm-base").

LayoutLMv2 is an improvement over LayoutLM (it improves the state of the art across several benchmarks, including new ones) that incorporates visual, text and layout information to understand scanned documents. Detectron2 is used for its visual backbone (a ResNeXt-FPN).

The original repo only includes LayoutLMv2Model and LayoutLMv2ForTokenClassification. However, in the paper the authors also use the model to classify document images (on RVL-CDIP) and to perform visual question answering (on DocVQA). Therefore, I've added LayoutLMv2ForSequenceClassification and LayoutLMv2ForQuestionAnswering. I've modelled them as described in the paper, but there's no official implementation to compare against.
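
As a rough illustration of how these two new heads are meant to be used (this is my sketch, not code from the PR; the checkpoint name and num_labels=16, the number of RVL-CDIP classes, are assumptions for illustration):

from transformers import LayoutLMv2ForSequenceClassification

# Document image classification head on top of LayoutLMv2.
# The classification head is newly initialized here; the inputs (input_ids, bbox,
# image, ...) would come from the processor introduced later in this thread.
model = LayoutLMv2ForSequenceClassification.from_pretrained(
    "microsoft/layoutlmv2-base-uncased", num_labels=16
)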

Fixes #11932 #12194

Who can review?

@LysandreJik @sgugger

To do:

  • fix tests (there's still one test failing, namely test_initialization) => Lysandre, it would be great if you could help me fix that one; it has to do with one of the layers of the backbone. An integration test is also added.
  • install Detectron2 + pytesseract to run all tests on CircleCI.
  • perhaps define custom ModelOutputs, as the length of the hidden states and attentions is actually seq_length + config.image_feature_pool_shape[0] * config.image_feature_pool_shape[1] instead of just seq_length. Update: will add a comment to the "Tips" section of the documentation instead (see the short sketch after this list).
  • write documentation about LayoutLMv2FeatureExtractor, LayoutLMv2Tokenizer and LayoutLMv2Processor
  • make some more demo notebooks.
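
To make the length remark above concrete, here is a tiny sketch (the value of image_feature_pool_shape is an assumed default; check config.image_feature_pool_shape for the real one):

# Hidden states/attentions cover the text tokens plus the flattened visual feature map.
seq_length = 512
image_feature_pool_shape = [7, 7, 256]  # assumed default, for illustration only
total_length = seq_length + image_feature_pool_shape[0] * image_feature_pool_shape[1]
print(total_length)  # 512 + 7 * 7 = 561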

Notes:

  • I know some variable names could be better (for example rel_pos_bias in the configuration). However, if we update the names, people will no longer be able to easily convert models from the original repo to HuggingFace and vice versa. The authors used HuggingFace for their entire codebase (Transformers, the Trainer, Datasets, ...). The model is already uploaded by the authors to the hub.
  • There is still some code in the modeling file for distributed training, namely to convert BatchNorm to SyncBatchNorm when distributed training is available. I guess this should be removed? UPDATE: moved to a separate method.

sgugger (Collaborator) left a comment

Thanks a lot for adding this model!
For LayoutXLM, I don't think we need a new page if we can use the same architecture and tokenizer without changes. Just mention on the doc page that the architecture supports both.

Don't forget to add the model to the main README!

LysandreJik (Member) left a comment

Great work implementing this @NielsRogge, and thank you for implementing the integration tests.

The docs are very understandable, great work. If you have some notebooks available, it would be great to put them in the documentation as well.

jasonkit commented Jul 16, 2021

> Thanks a lot for adding this model!
> For LayoutXLM, I don't think we need a new page if we can use the same architecture and tokenizer without changes. Just mention on the doc page that the architecture supports both.
>
> Don't forget to add the model to the main README!
@sgugger

Just want to point out that LayoutLMv2's tokenizer is a subclass of BertTokenizer, while LayoutXLM's tokenizer is a subclass of XLMRobertaTokenizer (and this is what makes LayoutXLM cross-lingual).

As far as I know, this is the only difference between LayoutLMv2 and LayoutXLM.

NielsRogge (Contributor, Author) commented Jul 16, 2021

@jasonkit thanks for pointing that out, I will create a separate LayoutXLMTokenizer which inherits from XLMRobertaTokenizer.

sgugger (Collaborator) commented Jul 16, 2021

Note that if the tokenizer is the same as XLMRobertaTokenizer, you don't need to create a new class; you can just set the right tokenizer_class in the config.
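
As a rough sketch of that mechanism (whether the LayoutXLM checkpoint actually sets this field is an assumption on my part):

from transformers import AutoConfig, AutoTokenizer

# If a checkpoint's config.json contains "tokenizer_class": "XLMRobertaTokenizer",
# AutoTokenizer instantiates that class instead of the model type's default tokenizer.
config = AutoConfig.from_pretrained("microsoft/layoutxlm-base")
print(config.tokenizer_class)  # "XLMRobertaTokenizer" if the field is set, else None

tokenizer = AutoTokenizer.from_pretrained("microsoft/layoutxlm-base")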

NielsRogge (Contributor, Author) commented Jul 16, 2021

Hmm ok, I see that this wasn't done for LayoutLMTokenizer, which was created as a separate class but is actually just BertTokenizer. Can you point me to an example where this was done?

sgugger (Collaborator) commented Jul 16, 2021

Sure: there is BigBirdPegasus, for instance, which uses the same tokenizer as BigBird; here is an example of a config file for a BigBirdPegasus checkpoint that sets the tokenizer class.

Jordy-VL commented

Can't wait to test this ;) Thanks for the community effort!

NielsRogge (Contributor, Author) commented Jul 27, 2021

@sgugger after internal discussion, I have created a new LayoutLMv2Processor. A Processor combines a FeatureExtractor (which handles the image-related stuff) and a Tokenizer (which handles the text-related stuff). So this is ideal for multi-modal models. Processors have previously been defined for Wav2Vec2 and CLIP.

However, there's a difference between the processors defined for Wav2Vec2/CLIP and the one for LayoutLMv2. The former are just a wrapper around a feature extractor and a tokenizer, acting as one or the other at a given moment. The LayoutLMv2 processor, on the other hand, applies both in sequence: it first uses the feature extractor to apply OCR to the document images to get words + bounding boxes, which are then provided to the tokenizer, which converts them to token-level input_ids, attention_mask, token_type_ids and bbox. By combining the feature extractor and the tokenizer, the processor really does everything for the user: you just give it a document image as input, and the inputs required by the model come out. Also note that one can initialize the feature extractor with apply_ocr set to True or False, depending on whether users want to run OCR on the document images themselves or let the feature extractor use PyTesseract (the default). For now, there are 5 different use cases for the processor; see the integration tests in test_processor_layoutlmv2.py for all of them.

Also, an additional feature (which I think people will like) is that one can optionally provide word-level labels to the processor, and these will then automatically be converted to token-level labels. You could see it a bit as if the tokenize_and_align_labels function is incorporated into the processor (actually into the tokenizer, but I assume people will just use the processor).

Happy to get your review :) As you will see, LayoutLMv2FeatureExtractor is fairly minimal; it does two things: 1) resize images to 224x224 and, optionally, 2) apply OCR to get words + boxes. LayoutLMv2Tokenizer is a bit more extensive (it also handles padding/truncation of token-level bounding boxes, etc.). Finally, LayoutLMv2Processor makes everything simpler by providing a single front-facing API; a minimal usage sketch is below.
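
A minimal sketch of the intended end-to-end usage (the checkpoint name and image path are placeholders for illustration):

from PIL import Image
from transformers import LayoutLMv2Processor

processor = LayoutLMv2Processor.from_pretrained("microsoft/layoutlmv2-base-uncased")

image = Image.open("document.png").convert("RGB")

# With apply_ocr=True (the default), the feature extractor runs PyTesseract to get
# words + normalized boxes, and the tokenizer turns them into token-level inputs.
encoding = processor(image, return_tensors="pt")
print(encoding.keys())  # image, input_ids, attention_mask, token_type_ids, bbox

# If you run OCR yourself, initialize the feature extractor with apply_ocr=False and
# pass words/boxes (and optionally word-level labels) explicitly, e.g.:
# encoding = processor(image, words, boxes=boxes, word_labels=word_labels, return_tensors="pt")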

sgugger (Collaborator) left a comment

The design with the feature extractor looks great to me!

LysandreJik (Member) left a comment

From a quick look, the processor/feature extractor/tokenizer approach looks good to me. Let me know when you're happy with the final state and I'll play with it to test out the API deeper.

Impressive test suite!

dcyoung commented Aug 6, 2021

@NielsRogge From what I can tell, the fast tokenizer is no longer supported in this PR. When using the existing implementation of LayoutLMv2Tokenizer for token classification/sequence labeling, I've been following the original repo's arguments:

      padding="max_length",
      pad_to_multiple_of=8,
      max_length=512,
      truncation=True,
      return_overflowing_tokens=True,
      is_split_into_words=True,

as a means of creating multiple sequences from longer input samples. I believe return_overflowing_tokens is unsupported by the tokenizer in this PR without a Fast implementation. Is there a different way to achieve multiple sequences per input sample with the new tokenizer?

NielsRogge (Contributor, Author) commented Aug 6, 2021

Hi @dcyoung,

I'm currently working on implementing a fast tokenizer, but the slow tokenizer supports the return_overflowing_tokens argument.

The API of the tokenizer is a bit more extensive for LayoutLMv2. You can pass a list of words and corresponding (normalized) boxes, and the tokenizer will automatically turn everything into token-level input_ids, attention_mask, token_type_ids and bbox. It will also pad/truncate boxes if you specify the relevant arguments. Small example:

from transformers import LayoutLMv2Tokenizer

tokenizer = LayoutLMv2Tokenizer.from_pretrained("microsoft/layoutlmv2-base-uncased")

words = ["hello", "world"]
boxes = [[1,2,3,4], [5,6,7,8]]

encoded_inputs = tokenizer(words, boxes=boxes, return_tensors="pt")

Can you try it out? It will also return overflowing token boxes if you want it to.

dcyoung commented Aug 6, 2021

> Can you try it out? It will also return overflowing token boxes if you want it to.

Yup, that works fine for me. Though I'm wondering about creating batches of sequences from a single "long" input sample that overflows the 512-token limit. This is for SER tasks where I'd like to consider every token in a document, which requires splitting the original sequence into multiple 512-token sequences. Previously, the tokenize_and_align_labels and DataCollatorForKeyValueExtraction implementations accomplished this. I'm curious how best to achieve the same behavior with the new setup.

    tokenizer = LayoutLMv2Tokenizer.from_pretrained(
        "microsoft/layoutlmv2-base-uncased",
    )

    n = 2000
    words = n * ["hello"]
    boxes = n * [[1, 2, 3, 4]]

    encoded_inputs = tokenizer(
        words,
        boxes=boxes,
        padding="max_length",
        pad_to_multiple_of=8,
        max_length=512,
        truncation=True,
        return_overflowing_tokens=True,
        is_split_into_words=True,
        return_tensors="pt",
    )
    print(encoded_inputs.keys())
    for k, v in encoded_inputs.items():
        print(k, v.size())
dict_keys(['overflowing_tokens', 'overflowing_token_boxes', 'num_truncated_tokens', 'input_ids', 'bbox', 'token_type_ids', 'attention_mask'])
overflowing_tokens torch.Size([1, 1490])
overflowing_token_boxes torch.Size([1, 1490, 4])
num_truncated_tokens torch.Size([1])
input_ids torch.Size([1, 512])
bbox torch.Size([1, 512, 4])
token_type_ids torch.Size([1, 512])
attention_mask torch.Size([1, 512])

I see now from the outputs above that the tokenizer does return overflowing tokens. However, I don't see the overflow_to_sample_mapping key, which was previously used by tokenize_and_align_labels. Does the current tokenizer support this behavior at the moment? If so, what arguments yield this batching behavior? And if not, do you have a suggestion on the easiest way to achieve something similar?

Would this require splitting the overflowing_tokens and overflowing_token_boxes into new sequences, manually adding the special tokens, and padding the last sample of fewer than 512 tokens? Or alternatively, tokenizing without truncation and using a data collator which splits and pads?

dcyoung commented Aug 10, 2021

@NielsRogge I took a pass at batching the overflow tokens. In the processor, I added some logic to modify the encoded_inputs like so:

from math import ceil
from typing import List

import torch
from transformers.file_utils import PaddingStrategy
from transformers.tokenization_utils_base import BatchEncoding


class LayoutLMv2Processor:
    ...

    def prepare_overflow(self, encoded_inputs: BatchEncoding) -> List[BatchEncoding]:
        # Number of tokens that didn't fit into the first 512-token sequence.
        num_truncated_tokens = max(
            0, int(encoded_inputs.get("num_truncated_tokens", [0])[0])
        )
        # 510 = 512 minus the [CLS] and [SEP] special tokens added to each sequence.
        max_source_tokens_per_sample = 510
        num_extra_samples = ceil(num_truncated_tokens / max_source_tokens_per_sample)
        extra_encoded_inputs = []
        for i in range(num_extra_samples):
            start_idx = i * max_source_tokens_per_sample
            tokens = encoded_inputs["overflowing_tokens"][0][
                start_idx : start_idx + max_source_tokens_per_sample
            ].tolist()
            boxes = encoded_inputs["overflowing_token_boxes"][0][
                start_idx : start_idx + max_source_tokens_per_sample
            ].tolist()
            labels = encoded_inputs["overflowing_labels"][0][
                start_idx : start_idx + max_source_tokens_per_sample
            ].tolist()
            seq_len = len(tokens)

            # Re-add special tokens ([CLS] = 101, [SEP] = 102 in the uncased BERT vocab)
            # with their conventional boxes/labels, then pad to a multiple of 8.
            # Note: this relies on the tokenizer's private _pad helper.
            padded = self.tokenizer._pad(
                encoded_inputs={
                    "input_ids": [101] + tokens + [102],
                    "bbox": [[0, 0, 0, 0]] + boxes + [[1000, 1000, 1000, 1000]],
                    "token_type_ids": (2 + seq_len) * [0],
                    "labels": [-100] + labels + [-100],
                    "attention_mask": (2 + seq_len) * [1],
                },
                max_length=512,
                padding_strategy=PaddingStrategy.MAX_LENGTH,
                pad_to_multiple_of=8,
                return_attention_mask=True,
            )
            extra_encoded_inputs.append(
                {
                    # Re-use the same image features for every extra chunk of the document.
                    "image": torch.clone(encoded_inputs["image"]),
                    **{k: torch.tensor(v).unsqueeze(0) for k, v in padded.items()},
                }
            )

        return extra_encoded_inputs

However, this required adding an additional overflowing_labels output during tokenization, similar to the current calculation of overflowing_token_boxes and overflowing_tokens. This is a small change, but it's easier to accomplish in the tokenizer source than after the fact.

Using this processor, I am able to generate batches of sequences from a long input sequence. While I haven't had a chance to test thoroughly, I am able to run such a batch through the model just fine and produce the corresponding logits. Ex:

encoded_inputs = processor(
    img,
    words,
    boxes=bboxes,
    word_labels=word_label_ids,
    return_tensors="pt",
    padding="max_length",
    pad_to_multiple_of=8,
    max_length=512,
    truncation=True,
    return_overflowing_tokens=True,
    is_split_into_words=True,
    batch_overflow=True,
)
extra_encoded_inputs = processor.prepare_overflow(encoded_inputs)
for model_inputs in [encoded_inputs] + extra_encoded_inputs:
    outputs = model(**model_inputs)
    print("Predicted Logits: ", outputs.logits.size())

Does this seem like a reasonable approach, and if so, would it be possible to add the overflowing_labels change to the tokenizer? Or perhaps you can think of a better abstraction for the batching process within the tokenizer itself?

NielsRogge (Contributor, Author) commented Aug 13, 2021

> I see now from the outputs above that the tokenizer does return overflowing tokens. However, I don't see the overflow_to_sample_mapping key, which was previously used by tokenize_and_align_labels. Does the current tokenizer support this behavior at the moment? If so, what arguments yield this batching behavior? And if not, do you have a suggestion on the easiest way to achieve something similar?

The overflow_to_sample_mapping is something that is only supported by fast tokenizers. I'm currently working on LayoutLMv2TokenizerFast and will merge it into this branch once it's ready. Thanks for your feedback!
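
For reference, this is roughly what the mapping looks like with an existing fast tokenizer (a plain BERT fast tokenizer is used here purely to illustrate the feature, not the LayoutLMv2 API):

from transformers import BertTokenizerFast

tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")

# Two samples; the first overflows max_length and gets split into several chunks.
texts = [2000 * "hello ", "world"]
encoding = tokenizer(
    texts,
    max_length=512,
    truncation=True,
    return_overflowing_tokens=True,
)
# overflow_to_sample_mapping[i] says which original sample chunk i came from,
# e.g. [0, 0, 0, 0, 1]: four chunks for the first sample, one for the second.
print(encoding["overflow_to_sample_mapping"])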

> Are you planning to add the LayoutLMv2/XLMForRelationExtraction models that we can find in the original repo?

Yes, but perhaps in a future PR, because it's not clear to me how they use the model at inference time.

If you have other questions, could you please post them elsewhere instead of on this thread, just to keep this PR a bit clean? :) Perhaps we can set up a Slack channel to discuss this model. If you can give me your email address, I'll set it up.

Thanks!

lvaleriu commented

> If you have other questions, could you please post them elsewhere instead of on this thread, just to keep this PR a bit clean? :) Perhaps we can set up a Slack channel to discuss this model. If you can give me your email address, I'll set it up.

You're right about redirecting me to a dedicated channel. Here is my email: lacatusu.valeriu@gmail.com.

Thank you!


Successfully merging this pull request may close these issues.

LayoutLMv2 Model
9 participants