diff --git a/docs/source/en/model_doc/donut.mdx b/docs/source/en/model_doc/donut.mdx
index 845d0ac4393646..88cb24b87b0060 100644
--- a/docs/source/en/model_doc/donut.mdx
+++ b/docs/source/en/model_doc/donut.mdx
@@ -47,20 +47,69 @@ The [`DonutFeatureExtractor`] class is responsible for preprocessing the input i
 [`DonutProcessor`] wraps [`DonutFeatureExtractor`] and [`XLMRobertaTokenizer`/`XLMRobertaTokenizerFast`]
 into a single instance to both extract the input features and decode the predicted token ids.
 
+- Step-by-step Document Image Classification
+
+```py
+>>> import re
+
+>>> from transformers import DonutProcessor, VisionEncoderDecoderModel
+>>> from datasets import load_dataset
+>>> import torch
+
+>>> processor = DonutProcessor.from_pretrained("nielsr/donut-base-finetuned-rvlcdip")
+>>> model = VisionEncoderDecoderModel.from_pretrained("nielsr/donut-base-finetuned-rvlcdip")
+
+>>> device = "cuda" if torch.cuda.is_available() else "cpu"
+>>> model.to(device)  # doctest: +IGNORE_RESULT
+
+>>> # load document image
+>>> dataset = load_dataset("hf-internal-testing/example-documents", split="test")
+>>> image = dataset[1]["image"]
+
+>>> # prepare decoder inputs
+>>> task_prompt = "<s_rvlcdip>"
+>>> decoder_input_ids = processor.tokenizer(task_prompt, add_special_tokens=False, return_tensors="pt").input_ids
+
+>>> pixel_values = processor(image, return_tensors="pt").pixel_values
+
+>>> outputs = model.generate(
+...     pixel_values.to(device),
+...     decoder_input_ids=decoder_input_ids.to(device),
+...     max_length=model.decoder.config.max_position_embeddings,
+...     early_stopping=True,
+...     pad_token_id=processor.tokenizer.pad_token_id,
+...     eos_token_id=processor.tokenizer.eos_token_id,
+...     use_cache=True,
+...     num_beams=1,
+...     bad_words_ids=[[processor.tokenizer.unk_token_id]],
+...     return_dict_in_generate=True,
+... )
+
+>>> sequence = processor.batch_decode(outputs.sequences)[0]
+>>> sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
+>>> sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()  # remove first task start token
+>>> print(sequence)
+<s_class><advertisement/></s_class>
+```
+
+We refer to the example notebooks regarding converting the model output back to JSON.
+The code is exactly the same for document parsing, except that the task prompt is different (e.g. "<s_cord-v2>").
+Another example can be found below:
+
 - Step-by-step Document Visual Question Answering (DocVQA)
 
-``` py
+```py
 >>> import re
 
 >>> from transformers import DonutProcessor, VisionEncoderDecoderModel
 >>> from datasets import load_dataset
 >>> import torch
 
->>> processor = DonutProcessor.from_pretrained("nielsr/donut-base-finetuned-docvqa") 
+>>> processor = DonutProcessor.from_pretrained("nielsr/donut-base-finetuned-docvqa")
 >>> model = VisionEncoderDecoderModel.from_pretrained("nielsr/donut-base-finetuned-docvqa")
 
 >>> device = "cuda" if torch.cuda.is_available() else "cpu"
->>> model.to(device)
+>>> model.to(device)  # doctest: +IGNORE_RESULT
 
 >>> # load document image from the DocVQA dataset
 >>> dataset = load_dataset("hf-internal-testing/example-documents", split="test")
@@ -72,18 +121,20 @@ into a single instance to both extract the input features and decode the predict
 >>> prompt = task_prompt.replace("{user_input}", question)
 >>> decoder_input_ids = processor.tokenizer(prompt, add_special_tokens=False, return_tensors="pt").input_ids
 
->>> pixel_values = processor(image, return_tensors="pt").pixel_values
-
->>> outputs = model.generate(pixel_values.to(device),
-...                          decoder_input_ids=decoder_input_ids.to(device),
-...                          max_length=model.decoder.config.max_position_embeddings,
-...                          early_stopping=True,
-...                          pad_token_id=processor.tokenizer.pad_token_id,
-...                          eos_token_id=processor.tokenizer.eos_token_id,
-...                          use_cache=True,
-...                          num_beams=1,
-...                          bad_words_ids=[[processor.tokenizer.unk_token_id]],
-...                          return_dict_in_generate=True)
+>>> pixel_values = processor(image, return_tensors="pt").pixel_values
+
+>>> outputs = model.generate(
+...     pixel_values.to(device),
+...     decoder_input_ids=decoder_input_ids.to(device),
+...     max_length=model.decoder.config.max_position_embeddings,
+...     early_stopping=True,
+...     pad_token_id=processor.tokenizer.pad_token_id,
+...     eos_token_id=processor.tokenizer.eos_token_id,
+...     use_cache=True,
+...     num_beams=1,
+...     bad_words_ids=[[processor.tokenizer.unk_token_id]],
+...     return_dict_in_generate=True,
+... )
 
 >>> sequence = processor.batch_decode(outputs.sequences)[0]
 >>> sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
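For reference, here is a minimal sketch of the document parsing variant that the new prose alludes to, written as plain Python rather than doctest syntax. The `naver-clova-ix/donut-base-finetuned-cord-v2` checkpoint and the choice of image index are assumptions, not part of this diff; `DonutProcessor.token2json` is used as one way to convert the generated sequence back to JSON, as the prose suggests.

```py
import re

import torch
from datasets import load_dataset
from transformers import DonutProcessor, VisionEncoderDecoderModel

# Assumed checkpoint: a Donut model fine-tuned on CORD-v2 for receipt parsing.
processor = DonutProcessor.from_pretrained("naver-clova-ix/donut-base-finetuned-cord-v2")
model = VisionEncoderDecoderModel.from_pretrained("naver-clova-ix/donut-base-finetuned-cord-v2")

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

# Any document image works here; index 2 is assumed to be a receipt-style image.
dataset = load_dataset("hf-internal-testing/example-documents", split="test")
image = dataset[2]["image"]

# Identical to the classification example, except for the task prompt.
task_prompt = "<s_cord-v2>"
decoder_input_ids = processor.tokenizer(task_prompt, add_special_tokens=False, return_tensors="pt").input_ids
pixel_values = processor(image, return_tensors="pt").pixel_values

outputs = model.generate(
    pixel_values.to(device),
    decoder_input_ids=decoder_input_ids.to(device),
    max_length=model.decoder.config.max_position_embeddings,
    early_stopping=True,
    pad_token_id=processor.tokenizer.pad_token_id,
    eos_token_id=processor.tokenizer.eos_token_id,
    use_cache=True,
    num_beams=1,
    bad_words_ids=[[processor.tokenizer.unk_token_id]],
    return_dict_in_generate=True,
)

sequence = processor.batch_decode(outputs.sequences)[0]
sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()  # drop the task start token

# token2json converts the tag-structured output (<s_...>...</s_...>) into a nested dict.
print(processor.token2json(sequence))
```

Only the checkpoint and the task prompt differ from the classification example above; the generation arguments and post-processing carry over unchanged.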