Fixing 1-length special tokens cut. #13862

Merged
merged 1 commit on Oct 5, 2021
src/transformers/tokenization_utils.py: 70 changes (41 additions, 29 deletions)
@@ -20,6 +20,7 @@
 import itertools
 import re
 import unicodedata
+from collections import OrderedDict
 from typing import Any, Dict, List, Optional, Tuple, Union, overload
 
 from .file_utils import PaddingStrategy, TensorType, add_end_docstrings
@@ -102,7 +103,6 @@ def split(self, text: str) -> List[str]:
         >>> trie.split("[CLS] This is a extra_id_100")
         ["[CLS]", " This is a ", "extra_id_100"]
         """
-
         # indexes are counted left of the chars index.
         # "hello", index 0, is left of h, index 1 is between h and e.
         # index 5 is right of the "o".
@@ -115,7 +115,7 @@ def split(self, text: str) -> List[str]:
         # If the trie contains, "blowing", and "lower" and we encounter the
         # string "blower", we need to split into ["b", "lower"].
         # This is where we need to keep track of multiple possible starts.
-        states = {}
+        states = OrderedDict()
 
         # This will contain every indices where we need
         # to cut.
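
For context on the `states` bookkeeping below: `Trie.add` stores tokens as nested dicts, with the empty string `""` marking the end of a complete token (only the key is ever checked; its value is irrelevant). A minimal sketch of the structure the trie pointers walk through, using the import path the tests use:

from transformers.tokenization_utils import Trie

trie = Trie()
trie.add("extra_id_1")
trie.add("extra_id_100")

# trie.data is now a chain of nested dicts; the "" key flags a complete token:
# {"e": {"x": {"t": {"r": {"a": {"_": {"i": {"d": {"_": {"1": {
#     "": 1,                # "extra_id_1" is a complete token here...
#     "0": {"0": {"": 1}},  # ...but "extra_id_100" continues past it
# }}}}}}}}}}}
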
@@ -144,36 +144,36 @@ def split(self, text: str) -> List[str]:
 
             # In this case, we already have partial matches (But unfinished)
             for start, trie_pointer in states.items():
-                if current_char in trie_pointer:
+                if "" in trie_pointer:
+                    # This is a final match, we need to reset and
+                    # store the results in `offsets`.
+
+                    # Lookahead to match longest first
+                    # Important in case of extra_id_1 vs extra_id_100
+                    lookahead_index = current
+                    end = current
+                    next_char = text[lookahead_index] if lookahead_index < len(text) else None
+                    while next_char in trie_pointer:
+                        trie_pointer = trie_pointer[next_char]
+                        lookahead_index += 1
+                        if "" in trie_pointer:
+                            end = lookahead_index
+                            skip = lookahead_index
+
+                        if lookahead_index == len(text):
+                            # End of string
+                            break
+                        next_char = text[lookahead_index]
+                    # End lookahead
+
+                    # Storing and resetting
+                    offsets.append(start)
+                    offsets.append(end)
+                    reset = True
+                elif current_char in trie_pointer:
                     # The current character being looked at has a match within the trie
                     # update the pointer (it will be stored back into states later).
                     trie_pointer = trie_pointer[current_char]
-                    if "" in trie_pointer:
-                        # This is a final match, we need to reset and
-                        # store the results in `offsets`.
-
-                        # Lookahead to match longest first
-                        # Important in case of extra_id_1 vs extra_id_100
-                        lookahead_index = current + 1
-                        end = current + 1
-                        next_char = text[lookahead_index] if lookahead_index < len(text) else None
-                        while next_char in trie_pointer:
-                            trie_pointer = trie_pointer[next_char]
-                            lookahead_index += 1
-                            if "" in trie_pointer:
-                                end = lookahead_index
-                                skip = lookahead_index
-
-                            if lookahead_index == len(text):
-                                # End of string
-                                break
-                            next_char = text[lookahead_index]
-                        # End lookahead
-
-                        # Storing and resetting
-                        offsets.append(start)
-                        offsets.append(end)
-                        reset = True
 
                     # Storing back the new pointer into the states.
                     # Partial matches got longer by one.
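
The reordering above is the heart of the fix. Previously a state was only inspected after `current_char in trie_pointer` succeeded, so the terminator check ran one character too late: for a 1-length token such as "A", the pointer already contains `""` when the state is next visited, the following character usually has no continuation in the trie, and the completed match was silently dropped. Checking `"" in trie_pointer` first, with the lookahead starting at `current` instead of `current + 1`, catches the match before the state can be discarded. A hand trace of the old failure mode, continuing the `Trie` sketch above:

trie = Trie()
trie.add("A")

# Old ordering, walking "ABC":
#   current=0, "A": states[0] = {"": 1}   (a 1-char token, already complete)
#   current=1, "B": "B" not in {"": 1} -> the state is dropped, no cut recorded
#   result: ["ABC"], the special token is swallowed
# New ordering: "" is seen first at current=1, so the cut [0, 1] is recorded.
print(trie.split("ABC"))  # ["A", "BC"]
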
@@ -198,6 +198,18 @@ def split(self, text: str) -> List[str]:
             if current_char in self.data:
                 states[current] = self.data[current_char]
 
+        # We have a cut at the end with states.
+        for start, trie_pointer in states.items():
+            if "" in trie_pointer:
+                # This is a final match, we need to reset and
+                # store the results in `offsets`.
+                end = len(text)
+                offsets.append(start)
+                offsets.append(end)
+                # Longest cut is always the one with lower start so the first
+                # item so we need to break.
+                break
+
         # We have all the offsets now, we just need to do the actual splitting.
         # We need to eventually add the first part of the string and the eventual
         # last part.
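
This trailing loop handles matches that are still open when the text runs out, which the main loop never gets to finalize. It is also why `states` became an `OrderedDict` (making the iteration order explicit rather than relying on dict insertion-order guarantees): states are inserted in increasing order of `start`, so the first terminal state is the one with the lowest start, i.e. the longest match, and the `break` stops a second, overlapping cut from being recorded. A worked version of `test_trie_final` below:

trie = Trie()
trie.add("TOKEN]")
trie.add("[SPECIAL_TOKEN]")
text = "This is something [SPECIAL_TOKEN]"

# When the character loop ends, two states are still open and both terminal:
#   start=18 -> "[SPECIAL_TOKEN]" fully matched ("" in pointer)
#   start=27 -> "TOKEN]" fully matched as well
# Insertion order yields start=18 first, and the break skips the overlapping
# start=27 state, so the only cut appended is [18, 33].
print(trie.split(text))  # ["This is something ", "[SPECIAL_TOKEN]"]
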
tests/test_tokenization_common.py: 12 changes (12 additions, 0 deletions)
@@ -3562,3 +3562,15 @@ def test_trie_split(self):
         trie.add("extra_id_1")
         trie.add("extra_id_100")
         self.assertEqual(trie.split("[CLS] This is a extra_id_100"), ["[CLS]", " This is a ", "extra_id_100"])
+
+    def test_trie_single(self):
+        trie = Trie()
+        trie.add("A")
+        self.assertEqual(trie.split("ABC"), ["A", "BC"])
+        self.assertEqual(trie.split("BCA"), ["BC", "A"])
+
+    def test_trie_final(self):
+        trie = Trie()
+        trie.add("TOKEN]")
+        trie.add("[SPECIAL_TOKEN]")
+        self.assertEqual(trie.split("This is something [SPECIAL_TOKEN]"), ["This is something ", "[SPECIAL_TOKEN]"])