diff --git a/src/transformers/models/donut/processing_donut.py b/src/transformers/models/donut/processing_donut.py
index 1b00d894bd087..da9e89c1d8b63 100644
--- a/src/transformers/models/donut/processing_donut.py
+++ b/src/transformers/models/donut/processing_donut.py
@@ -108,10 +108,13 @@ def as_target_processor(self):
         self.current_processor = self.feature_extractor
         self._in_target_context_manager = False
 
-    def token2json(self, tokens, is_inner_value=False):
+    def token2json(self, tokens, is_inner_value=False, added_vocab=None):
         """
         Convert a (generated) token sequence into an ordered JSON format.
         """
+        if added_vocab is None:
+            added_vocab = self.tokenizer.get_added_vocab()
+
         output = dict()
 
         while tokens:
@@ -131,7 +134,7 @@ def token2json(self, tokens, is_inner_value=False):
                 if content is not None:
                     content = content.group(1).strip()
                     if r"<s_" in content and r"</s_" in content:  # non-leaf node
-                        value = self.token2json(content, is_inner_value=True)
+                        value = self.token2json(content, is_inner_value=True, added_vocab=added_vocab)
                         if value:
                             if len(value) == 1:
                                 value = value[0]
@@ -140,7 +143,7 @@ def token2json(self, tokens, is_inner_value=False):
                         output[key] = []
                         for leaf in content.split(r"<sep/>"):
                             leaf = leaf.strip()
-                            if leaf in self.tokenizer.get_added_vocab() and leaf[0] == "<" and leaf[-2:] == "/>":
+                            if leaf in added_vocab and leaf[0] == "<" and leaf[-2:] == "/>":
                                 leaf = leaf[1:-2]  # for categorical special tokens
                             output[key].append(leaf)
                         if len(output[key]) == 1:
@@ -148,7 +151,7 @@
 
             tokens = tokens[tokens.find(end_token) + len(end_token) :].strip()
             if tokens[:6] == r"<sep/>":  # non-leaf nodes
-                return [output] + self.token2json(tokens[6:], is_inner_value=True)
+                return [output] + self.token2json(tokens[6:], is_inner_value=True, added_vocab=added_vocab)
 
         if len(output):
             return [output] if is_inner_value else output
diff --git a/tests/models/donut/test_processing_donut.py b/tests/models/donut/test_processing_donut.py
new file mode 100644
index 0000000000000..cad0e37bc5195
--- /dev/null
+++ b/tests/models/donut/test_processing_donut.py
@@ -0,0 +1,48 @@
+# coding=utf-8
+# Copyright 2022 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import unittest
+
+from transformers import DonutProcessor
+
+
+DONUT_PRETRAINED_MODEL_NAME = "naver-clova-ix/donut-base"
+
+
+class DonutProcessorTest(unittest.TestCase):
+    def setUp(self):
+        self.processor = DonutProcessor.from_pretrained(DONUT_PRETRAINED_MODEL_NAME)
+
+    def test_token2json(self):
+        expected_json = {
+            "name": "John Doe",
+            "age": "99",
+            "city": "Atlanta",
+            "state": "GA",
+            "zip": "30301",
+            "phone": "123-4567",
+            "nicknames": [{"nickname": "Johnny"}, {"nickname": "JD"}],
+        }
+
+        sequence = (
+            "<s_name>John Doe</s_name><s_age>99</s_age><s_city>Atlanta</s_city>"
+            "<s_state>GA</s_state><s_zip>30301</s_zip><s_phone>123-4567</s_phone>"
+            "<s_nicknames><s_nickname>Johnny</s_nickname>"
+            "<sep/><s_nickname>JD</s_nickname></s_nicknames>"
+        )
+        actual_json = self.processor.token2json(sequence)
+
+        self.assertDictEqual(actual_json, expected_json)