Fixing 1-length special tokens cut. (huggingface#13862)
Narsil authored and LysandreJik committed Oct 6, 2021
1 parent d7db364 commit c2901b0
Showing 2 changed files with 53 additions and 29 deletions.
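What the fix does, in short: the Trie that tokenization_utils uses to split text on added/special tokens previously missed cuts for tokens that are a single character long (per the commit title), and this commit also adds handling for a match that only completes at the very end of the input. The snippet below simply restates the behavior asserted by the two new tests at the bottom of this diff; it assumes Trie is importable from transformers.tokenization_utils (the file changed here) and is an illustration, not part of the commit itself.

from transformers.tokenization_utils import Trie

trie = Trie()
trie.add("A")                  # a 1-length token
print(trie.split("ABC"))       # ["A", "BC"]
print(trie.split("BCA"))       # ["BC", "A"]  -- the match only completes at the last character

trie = Trie()
trie.add("TOKEN]")
trie.add("[SPECIAL_TOKEN]")
print(trie.split("This is something [SPECIAL_TOKEN]"))
# ["This is something ", "[SPECIAL_TOKEN]"]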
70 changes: 41 additions & 29 deletions src/transformers/tokenization_utils.py
@@ -20,6 +20,7 @@
import itertools
import re
import unicodedata
from collections import OrderedDict
from typing import Any, Dict, List, Optional, Tuple, Union, overload

from .file_utils import PaddingStrategy, TensorType, add_end_docstrings
@@ -102,7 +103,6 @@ def split(self, text: str) -> List[str]:
>>> trie.split("[CLS] This is a extra_id_100")
["[CLS]", " This is a ", "extra_id_100"]
"""

# indexes are counted left of the chars index.
# "hello", index 0, is left of h, index 1 is between h and e.
# index 5 is right of the "o".
@@ -115,7 +115,7 @@ def split(self, text: str) -> List[str]:
# If the trie contains, "blowing", and "lower" and we encounter the
# string "blower", we need to split into ["b", "lower"].
# This is where we need to keep track of multiple possible starts.
states = {}
states = OrderedDict()

# This will contain every indices where we need
# to cut.
@@ -144,36 +144,36 @@ def split(self, text: str) -> List[str]:

# In this case, we already have partial matches (But unfinished)
for start, trie_pointer in states.items():
if current_char in trie_pointer:
if "" in trie_pointer:
# This is a final match, we need to reset and
# store the results in `offsets`.

# Lookahead to match longest first
# Important in case of extra_id_1 vs extra_id_100
lookahead_index = current
end = current
next_char = text[lookahead_index] if lookahead_index < len(text) else None
while next_char in trie_pointer:
trie_pointer = trie_pointer[next_char]
lookahead_index += 1
if "" in trie_pointer:
end = lookahead_index
skip = lookahead_index

if lookahead_index == len(text):
# End of string
break
next_char = text[lookahead_index]
# End lookahead

# Storing and resetting
offsets.append(start)
offsets.append(end)
reset = True
elif current_char in trie_pointer:
# The current character being looked at has a match within the trie
# update the pointer (it will be stored back into states later).
trie_pointer = trie_pointer[current_char]
if "" in trie_pointer:
# This is a final match, we need to reset and
# store the results in `offsets`.

# Lookahead to match longest first
# Important in case of extra_id_1 vs extra_id_100
lookahead_index = current + 1
end = current + 1
next_char = text[lookahead_index] if lookahead_index < len(text) else None
while next_char in trie_pointer:
trie_pointer = trie_pointer[next_char]
lookahead_index += 1
if "" in trie_pointer:
end = lookahead_index
skip = lookahead_index

if lookahead_index == len(text):
# End of string
break
next_char = text[lookahead_index]
# End lookahead

# Storing and resetting
offsets.append(start)
offsets.append(end)
reset = True

# Storing back the new pointer into the states.
# Partial matches got longer by one.
@@ -198,6 +198,18 @@ def split(self, text: str) -> List[str]:
if current_char in self.data:
states[current] = self.data[current_char]

# We have a cut at the end with states.
for start, trie_pointer in states.items():
if "" in trie_pointer:
# This is a final match, we need to reset and
# store the results in `offsets`.
end = len(text)
offsets.append(start)
offsets.append(end)
# Longest cut is always the one with lower start so the first
# item so we need to break.
break

# We have all the offsets now, we just need to do the actual splitting.
# We need to eventually add the first part of the string and the eventual
# last part.
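The tail of split (turning the collected offsets into the returned substrings) lies outside the hunks shown above. As a reading aid, here is a rough, hypothetical sketch of that step — not the verbatim library code — consistent with the docstring example earlier in the file:

def offsets_to_pieces(text, offsets):
    # `offsets` alternates cut starts and ends collected by the scan above.
    # Appending len(text) keeps any trailing plain-text segment.
    pieces = []
    previous = 0
    for cut in offsets + [len(text)]:
        if cut > previous:  # skip zero-width segments, e.g. a match starting at index 0
            pieces.append(text[previous:cut])
        previous = cut
    return pieces

# For the docstring example ("[CLS]", "extra_id_1", "extra_id_100" over
# "[CLS] This is a extra_id_100") the scan should collect offsets [0, 5, 16, 28],
# which this helper turns into ["[CLS]", " This is a ", "extra_id_100"].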
12 changes: 12 additions & 0 deletions tests/test_tokenization_common.py
@@ -3562,3 +3562,15 @@ def test_trie_split(self):
trie.add("extra_id_1")
trie.add("extra_id_100")
self.assertEqual(trie.split("[CLS] This is a extra_id_100"), ["[CLS]", " This is a ", "extra_id_100"])

def test_trie_single(self):
trie = Trie()
trie.add("A")
self.assertEqual(trie.split("ABC"), ["A", "BC"])
self.assertEqual(trie.split("BCA"), ["BC", "A"])

def test_trie_final(self):
trie = Trie()
trie.add("TOKEN]")
trie.add("[SPECIAL_TOKEN]")
self.assertEqual(trie.split("This is something [SPECIAL_TOKEN]"), ["This is something ", "[SPECIAL_TOKEN]"])
