
Commit

Fix code examples
Niels Rogge authored and Niels Rogge committed Aug 8, 2022
1 parent 2a30ae5 commit a1092ed
Showing 1 changed file with 66 additions and 15 deletions.
81 changes: 66 additions & 15 deletions docs/source/en/model_doc/donut.mdx
@@ -47,20 +47,69 @@ The [`DonutFeatureExtractor`] class is responsible for preprocessing the input image and
[`DonutProcessor`] wraps [`DonutFeatureExtractor`] and [`XLMRobertaTokenizer`/`XLMRobertaTokenizerFast`]
into a single instance to both extract the input features and decode the predicted token ids.
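
For instance, both preprocessing and decoding go through the same processor instance. The snippet below is a minimal sketch that uses a blank placeholder image rather than a real scanned document:

```py
>>> from transformers import DonutProcessor
>>> from PIL import Image
>>> import numpy as np

>>> processor = DonutProcessor.from_pretrained("nielsr/donut-base-finetuned-rvlcdip")

>>> # the wrapped feature extractor turns an image into normalized pixel values
>>> image = Image.fromarray(np.zeros((1024, 768, 3), dtype=np.uint8))  # blank placeholder document image
>>> pixel_values = processor(image, return_tensors="pt").pixel_values

>>> # the wrapped tokenizer encodes task prompts and decodes predicted token ids
>>> input_ids = processor.tokenizer("<s_rvlcdip>", add_special_tokens=False, return_tensors="pt").input_ids
>>> decoded = processor.batch_decode(input_ids)[0]
```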

- Step-by-step Document Image Classification

```py
>>> import re

>>> from transformers import DonutProcessor, VisionEncoderDecoderModel
>>> from datasets import load_dataset
>>> import torch

>>> processor = DonutProcessor.from_pretrained("nielsr/donut-base-finetuned-rvlcdip")
>>> model = VisionEncoderDecoderModel.from_pretrained("nielsr/donut-base-finetuned-rvlcdip")

>>> device = "cuda" if torch.cuda.is_available() else "cpu"
>>> model.to(device) # doctest: +IGNORE_RESULT

>>> # load document image
>>> dataset = load_dataset("hf-internal-testing/example-documents", split="test")
>>> image = dataset[1]["image"]

>>> # prepare decoder inputs
>>> task_prompt = "<s_rvlcdip>"
>>> decoder_input_ids = processor.tokenizer(task_prompt, add_special_tokens=False, return_tensors="pt").input_ids

>>> pixel_values = processor(image, return_tensors="pt").pixel_values

>>> outputs = model.generate(
... pixel_values.to(device),
... decoder_input_ids=decoder_input_ids.to(device),
... max_length=model.decoder.config.max_position_embeddings,
... early_stopping=True,
... pad_token_id=processor.tokenizer.pad_token_id,
... eos_token_id=processor.tokenizer.eos_token_id,
... use_cache=True,
... num_beams=1,
... bad_words_ids=[[processor.tokenizer.unk_token_id]],
... return_dict_in_generate=True,
... )

>>> sequence = processor.batch_decode(outputs.sequences)[0]
>>> sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
>>> sequence = re.sub(r"<.*?>", "", sequence, count=1).strip() # remove first task start token
>>> print(sequence)
<s_class><advertisement/></s_class>
```

We refer to the example notebooks regarding converting the model output back to JSON (see the sketch below).
The code is exactly the same for document parsing, except that the task prompt is different (e.g. `"<s_cord-v2>"`).
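
As a minimal sketch of that conversion, [`DonutProcessor.token2json`] turns a cleaned sequence into a Python dictionary; the input string below is a stand-in for real model output, copied from the classification example above:

```py
>>> from transformers import DonutProcessor

>>> processor = DonutProcessor.from_pretrained("nielsr/donut-base-finetuned-rvlcdip")

>>> # a stand-in for the cleaned sequence produced by the classification example
>>> sequence = "<s_class><advertisement/></s_class>"
>>> processor.token2json(sequence)
{'class': 'advertisement'}
```
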
Another example can be found below:

- Step-by-step Document Visual Question Answering (DocVQA)

```py
>>> import re

>>> from transformers import DonutProcessor, VisionEncoderDecoderModel
>>> from datasets import load_dataset
>>> import torch

>>> processor = DonutProcessor.from_pretrained("nielsr/donut-base-finetuned-docvqa")
>>> model = VisionEncoderDecoderModel.from_pretrained("nielsr/donut-base-finetuned-docvqa")

>>> device = "cuda" if torch.cuda.is_available() else "cpu"
>>> model.to(device)  # doctest: +IGNORE_RESULT

>>> # load document image from the DocVQA dataset
>>> dataset = load_dataset("hf-internal-testing/example-documents", split="test")
>>> image = dataset[0]["image"]

>>> # prepare decoder inputs
>>> task_prompt = "<s_docvqa><s_question>{user_input}</s_question><s_answer>"
>>> question = "When is the coffee break?"
>>> prompt = task_prompt.replace("{user_input}", question)
>>> decoder_input_ids = processor.tokenizer(prompt, add_special_tokens=False, return_tensors="pt").input_ids

>>> pixel_values = processor(image, return_tensors="pt").pixel_values

>>> outputs = model.generate(
... pixel_values.to(device),
... decoder_input_ids=decoder_input_ids.to(device),
... max_length=model.decoder.config.max_position_embeddings,
... early_stopping=True,
... pad_token_id=processor.tokenizer.pad_token_id,
... eos_token_id=processor.tokenizer.eos_token_id,
... use_cache=True,
... num_beams=1,
... bad_words_ids=[[processor.tokenizer.unk_token_id]],
... return_dict_in_generate=True,
... )

>>> sequence = processor.batch_decode(outputs.sequences)[0]
>>> sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
>>> sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()  # remove first task start token
>>> print(processor.token2json(sequence))
{'question': 'When is the coffee break?', 'answer': '11-14 to 11:39 a.m.'}
```
