diff --git a/src/transformers/models/donut/processing_donut.py b/src/transformers/models/donut/processing_donut.py
index 1b00d894bd087..da9e89c1d8b63 100644
--- a/src/transformers/models/donut/processing_donut.py
+++ b/src/transformers/models/donut/processing_donut.py
@@ -108,10 +108,13 @@ def as_target_processor(self):
self.current_processor = self.feature_extractor
self._in_target_context_manager = False
- def token2json(self, tokens, is_inner_value=False):
+ def token2json(self, tokens, is_inner_value=False, added_vocab=None):
"""
Convert a (generated) token sequence into an ordered JSON format.
"""
+ if added_vocab is None:
+ added_vocab = self.tokenizer.get_added_vocab()
+
output = dict()
while tokens:
@@ -131,7 +134,7 @@ def token2json(self, tokens, is_inner_value=False):
if content is not None:
content = content.group(1).strip()
                     if r"<s_" in content and r"</s_" in content:  # non-leaf nodes
-                        value = self.token2json(content, is_inner_value=True)
+                        value = self.token2json(content, is_inner_value=True, added_vocab=added_vocab)
                         if value:
                             if len(value) == 1:
                                 value = value[0]
@@ -140,7 +143,7 @@ def token2json(self, tokens, is_inner_value=False):
                         output[key] = []
                         for leaf in content.split(r"<sep/>"):
leaf = leaf.strip()
- if leaf in self.tokenizer.get_added_vocab() and leaf[0] == "<" and leaf[-2:] == "/>":
+ if leaf in added_vocab and leaf[0] == "<" and leaf[-2:] == "/>":
leaf = leaf[1:-2] # for categorical special tokens
output[key].append(leaf)
if len(output[key]) == 1:
@@ -148,7 +151,7 @@ def token2json(self, tokens, is_inner_value=False):
tokens = tokens[tokens.find(end_token) + len(end_token) :].strip()
                 if tokens[:6] == r"<sep/>":  # non-leaf nodes
- return [output] + self.token2json(tokens[6:], is_inner_value=True)
+ return [output] + self.token2json(tokens[6:], is_inner_value=True, added_vocab=added_vocab)
if len(output):
return [output] if is_inner_value else output
diff --git a/tests/models/donut/test_processing_donut.py b/tests/models/donut/test_processing_donut.py
new file mode 100644
index 0000000000000..cad0e37bc5195
--- /dev/null
+++ b/tests/models/donut/test_processing_donut.py
@@ -0,0 +1,48 @@
+# coding=utf-8
+# Copyright 2022 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import unittest
+
+from transformers import DonutProcessor
+
+
+DONUT_PRETRAINED_MODEL_NAME = "naver-clova-ix/donut-base"
+
+
+class DonutProcessorTest(unittest.TestCase):
+ def setUp(self):
+ self.processor = DonutProcessor.from_pretrained(DONUT_PRETRAINED_MODEL_NAME)
+
+ def test_token2json(self):
+ expected_json = {
+ "name": "John Doe",
+ "age": "99",
+ "city": "Atlanta",
+ "state": "GA",
+ "zip": "30301",
+ "phone": "123-4567",
+ "nicknames": [{"nickname": "Johnny"}, {"nickname": "JD"}],
+ }
+
+ sequence = (
+            "<s_name>John Doe</s_name><s_age>99</s_age><s_city>Atlanta</s_city>"
+            "<s_state>GA</s_state><s_zip>30301</s_zip><s_phone>123-4567</s_phone>"
+            "<s_nicknames><s_nickname>Johnny</s_nickname>"
+            "<sep/><s_nickname>JD</s_nickname></s_nicknames>"
+ )
+ actual_json = self.processor.token2json(sequence)
+
+ self.assertDictEqual(actual_json, expected_json)