Commit 55edf0f

Merge pull request #39 from SaulLu/add_layoutlmv3
A step toward fixing `test_maximum_encoding_length_pair_input`
NielsRogge committed May 12, 2022
2 parents 56353ce + e6e25a3 commit 55edf0f
Showing 1 changed file with 20 additions and 19 deletions.

tests/models/layoutlmv3/test_tokenization_layoutlmv3.py (20 additions, 19 deletions)
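
The diff applies two small patterns throughout the test: each encoding's token ids are cached in a local (`seq0_input_ids`, `seq1_input_ids`) so that every length used in the truncation arithmetic comes from the token list rather than the word list, and the dummy word boxes now start at 1 so that no real word box equals `[0, 0, 0, 0]`, the placeholder box the expected bbox arrays are built from. A minimal standalone sketch of both patterns; the checkpoint name and setup are illustrative assumptions, not part of the test harness:

from transformers import LayoutLMv3TokenizerFast

# Illustrative checkpoint; the test gets its tokenizers from the harness.
tokenizer = LayoutLMv3TokenizerFast.from_pretrained("microsoft/layoutlmv3-base")

seq_1 = ["what", "a", "weird", "test", "weirdly", "weird"]
# Boxes start at 1 so none of them is [0, 0, 0, 0], the placeholder box
# used for the first-sequence positions in the expected bbox arrays below.
boxes_1 = [[i, i, i, i] for i in range(1, len(seq_1) + 1)]

seq1_tokens = tokenizer(seq_1, boxes=boxes_1, add_special_tokens=False)
seq1_input_ids = seq1_tokens["input_ids"]  # cached once, reused everywhere

# Subword tokenization means token count and word count can differ.
print(len(seq_1), len(seq1_input_ids))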
@@ -1741,25 +1741,24 @@ def test_maximum_encoding_length_pair_input(self):
 ids = None
 
 seq0_tokens = tokenizer(seq_0, boxes=boxes_0, add_special_tokens=False)
-self.assertGreater(len(seq0_tokens["input_ids"]), 2 + stride)
+seq0_input_ids = seq0_tokens["input_ids"]
+
+self.assertGreater(len(seq0_input_ids), 2 + stride)
 question_1 = "This is another sentence to be encoded."
 seq_1 = ["what", "a", "weird", "test", "weirdly", "weird"]
-boxes_1 = [[i, i, i, i] for i in range(len(seq_1))]
+boxes_1 = [[i, i, i, i] for i in range(1, len(seq_1) + 1)]
 seq1_tokens = tokenizer(seq_1, boxes=boxes_1, add_special_tokens=False)
-if abs(len(seq0_tokens["input_ids"]) - len(seq1_tokens["input_ids"])) <= 2:
+if abs(len(seq0_input_ids) - len(seq1_tokens["input_ids"])) <= 2:
     seq1_tokens_input_ids = seq1_tokens["input_ids"] + seq1_tokens["input_ids"]
     seq_1 = tokenizer.decode(seq1_tokens_input_ids, clean_up_tokenization_spaces=False)
     seq_1 = seq_1.split(" ")
-    boxes_1 = [[i, i, i, i] for i in range(len(seq_1))]
+    boxes_1 = [[i, i, i, i] for i in range(1, len(seq_1) + 1)]
     seq1_tokens = tokenizer(seq_1, boxes=boxes_1, add_special_tokens=False)
+seq1_input_ids = seq1_tokens["input_ids"]
 
-self.assertGreater(len(seq1_tokens["input_ids"]), 2 + stride)
+self.assertGreater(len(seq1_input_ids), 2 + stride)
 
-smallest = (
-    seq1_tokens["input_ids"]
-    if len(seq0_tokens["input_ids"]) > len(seq1_tokens["input_ids"])
-    else seq0_tokens["input_ids"]
-)
+smallest = seq1_input_ids if len(seq0_input_ids) > len(seq1_input_ids) else seq0_input_ids
 
 # We are not using the special tokens - a bit too hard to test all the tokenizers with this
 # TODO try this again later
@@ -1864,7 +1863,9 @@ def test_maximum_encoding_length_pair_input(self):
     + tokenizer(seq_1, boxes=boxes_1, add_special_tokens=False)["input_ids"][:-2]
 )
 truncated_longest_sequence = (
-    truncated_first_sequence if len(seq0_tokens) > len(seq1_tokens) else truncated_second_sequence
+    truncated_first_sequence
+    if len(seq0_input_ids) > len(seq1_input_ids)
+    else truncated_second_sequence
 )
 
 overflow_first_sequence = (
@@ -1876,24 +1877,24 @@ def test_maximum_encoding_length_pair_input(self):
     + tokenizer(seq_1, boxes=boxes_1, add_special_tokens=False)["input_ids"][-(2 + stride) :]
 )
 overflow_longest_sequence = (
-    overflow_first_sequence if len(seq0_tokens) > len(seq1_tokens) else overflow_second_sequence
+    overflow_first_sequence if len(seq0_input_ids) > len(seq1_input_ids) else overflow_second_sequence
 )
 
-bbox_first = [[0, 0, 0, 0]] * (len(seq_0) - 2)
+bbox_first = [[0, 0, 0, 0]] * (len(seq0_input_ids) - 2)
 bbox_first_sequence = bbox_first + tokenizer(seq_1, boxes=boxes_1, add_special_tokens=False)["bbox"]
 overflowing_token_bbox_first_sequence_slow = [[0, 0, 0, 0]] * (2 + stride)
 overflowing_token_bbox_first_sequence_fast = [[0, 0, 0, 0]] * (2 + stride) + tokenizer(
     seq_1, boxes=boxes_1, add_special_tokens=False
 )["bbox"]
 
-bbox_second = [[0, 0, 0, 0]] * len(seq_0)
+bbox_second = [[0, 0, 0, 0]] * len(seq0_input_ids)
 bbox_second_sequence = (
     bbox_second + tokenizer(seq_1, boxes=boxes_1, add_special_tokens=False)["bbox"][:-2]
 )
 overflowing_token_bbox_second_sequence_slow = tokenizer(
     seq_1, boxes=boxes_1, add_special_tokens=False
 )["bbox"][-(2 + stride) :]
-overflowing_token_bbox_second_sequence_fast = [[0, 0, 0, 0]] * len(seq_0) + tokenizer(
+overflowing_token_bbox_second_sequence_fast = [[0, 0, 0, 0]] * len(seq0_input_ids) + tokenizer(
     seq_1, boxes=boxes_1, add_special_tokens=False
 )["bbox"][-(2 + stride) :]
 
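The `len(seq_0)` to `len(seq0_input_ids)` replacements above are the substantive fix in this hunk: the expected bbox arrays need one `[0, 0, 0, 0]` entry per token, and a subword tokenizer can emit several tokens per word, so sizing them by the word count undercounts whenever a word splits. A toy illustration with invented ids:

seq_0 = ["Testing", "tokenization", "thoroughly"]  # 3 words
seq0_input_ids = [101, 2002, 2003, 2004, 105]      # 5 token ids (invented)

bbox_second = [[0, 0, 0, 0]] * len(seq0_input_ids)  # one box per token: 5
too_short = [[0, 0, 0, 0]] * len(seq_0)             # one box per word: 3

assert len(bbox_second) == len(seq0_input_ids)  # lines up with the encoding
assert len(too_short) < len(seq0_input_ids)     # the old expectation mismatched
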
@@ -2028,7 +2029,7 @@ def test_maximum_encoding_length_pair_input(self):
 self.assertEqual(len(truncated_sequence), len(sequence["input_ids"]) - 2)
 self.assertEqual(truncated_sequence, truncated_first_sequence)
 
-self.assertEqual(len(overflowing_tokens), 2 + stride + len(seq1_tokens["input_ids"]))
+self.assertEqual(len(overflowing_tokens), 2 + stride + len(seq1_input_ids))
 self.assertEqual(overflowing_tokens, overflow_first_sequence)
 self.assertEqual(bbox, bbox_first_sequence)
 self.assertEqual(overflowing_bbox, overflowing_token_bbox_first_sequence_fast)
@@ -2042,7 +2043,7 @@ def test_maximum_encoding_length_pair_input(self):
 self.assertEqual(truncated_sequence, truncated_first_sequence)
 
 self.assertEqual(len(overflowing_tokens), 2 + stride)
-self.assertEqual(overflowing_tokens, seq0_tokens["input_ids"][-(2 + stride) :])
+self.assertEqual(overflowing_tokens, seq0_input_ids[-(2 + stride) :])
 self.assertEqual(bbox, bbox_first_sequence)
 self.assertEqual(overflowing_bbox, overflowing_token_bbox_first_sequence_slow)
 
@@ -2069,7 +2070,7 @@ def test_maximum_encoding_length_pair_input(self):
 self.assertEqual(len(truncated_sequence), len(sequence["input_ids"]) - 2)
 self.assertEqual(truncated_sequence, truncated_second_sequence)
 
-self.assertEqual(len(overflowing_tokens), 2 + stride + len(seq0_tokens["input_ids"]))
+self.assertEqual(len(overflowing_tokens), 2 + stride + len(seq0_input_ids))
 self.assertEqual(overflowing_tokens, overflow_second_sequence)
 self.assertEqual(bbox, bbox_second_sequence)
 self.assertEqual(overflowing_bbox, overflowing_token_bbox_second_sequence_fast)
@@ -2083,7 +2084,7 @@ def test_maximum_encoding_length_pair_input(self):
 self.assertEqual(truncated_sequence, truncated_second_sequence)
 
 self.assertEqual(len(overflowing_tokens), 2 + stride)
-self.assertEqual(overflowing_tokens, seq1_tokens["input_ids"][-(2 + stride) :])
+self.assertEqual(overflowing_tokens, seq1_input_ids[-(2 + stride) :])
 self.assertEqual(bbox, bbox_second_sequence)
 self.assertEqual(overflowing_bbox, overflowing_token_bbox_second_sequence_slow)
 
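The recurring `2 + stride` in these assertions is the overflow arithmetic the test turns on: truncation removes 2 tokens, and with a stride the slow tokenizer's overflow holds those 2 tokens plus `stride` tokens of overlap, i.e. the last `2 + stride` ids of the truncated side (the fast tokenizers additionally carry the other sequence, hence the `2 + stride + len(...)` variants above). A self-contained sketch of just the slicing, with toy ids:

stride = 2
num_removed = 2                   # the test truncates the pair by 2 tokens
seq1_input_ids = list(range(10))  # toy ids standing in for an encoding

overflowing = seq1_input_ids[-(num_removed + stride):]
assert overflowing == [6, 7, 8, 9]
assert len(overflowing) == 2 + stride  # the quantity the assertions check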
