diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 2f8bce71ff220..78137d2c8a74c 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -218,8 +218,6 @@ title: DialoGPT - local: model_doc/distilbert title: DistilBERT - - local: model_doc/donut - title: Donut - local: model_doc/dpr title: DPR - local: model_doc/electra @@ -429,6 +427,8 @@ title: CLIP - local: model_doc/data2vec title: Data2Vec + - local: model_doc/donut + title: Donut - local: model_doc/flava title: FLAVA - local: model_doc/groupvit diff --git a/src/transformers/models/donut/convert_donut_to_pytorch.py b/src/transformers/models/donut/convert_donut_to_pytorch.py index c3eabc83135e7..507f10cb776cf 100644 --- a/src/transformers/models/donut/convert_donut_to_pytorch.py +++ b/src/transformers/models/donut/convert_donut_to_pytorch.py @@ -164,8 +164,15 @@ def convert_donut_checkpoint(model_name, pytorch_dump_folder_path=None, push_to_ task_prompt = task_prompt.replace("{user_input}", question) elif model_name == "naver-clova-ix/donut-base-finetuned-rvlcdip": task_prompt = "" + elif model_name in [ + "naver-clova-ix/donut-base-finetuned-cord-v1", + "naver-clova-ix/donut-base-finetuned-cord-v1-2560", + ]: + task_prompt = "" elif model_name == "naver-clova-ix/donut-base-finetuned-cord-v2": task_prompt = "s_cord-v2>" + elif model_name == "naver-clova-ix/donut-base-finetuned-zhtrainticket": + task_prompt = "" elif model_name in ["naver-clova-ix/donut-proto", "naver-clova-ix/donut-base"]: # use a random prompt task_prompt = "hello world"