
Commit

Merge pull request #3014 from tomaarsen/perf/tokenize
Tackle performance and accuracy regression of sentence tokenizer since NLTK 3.6.6
stevenbird committed Jul 4, 2022
2 parents 6de8254 + 4073f2b commit 86b11fb
Showing 2 changed files with 115 additions and 40 deletions.
30 changes: 30 additions & 0 deletions nltk/test/unit/test_tokenize.py
@@ -13,6 +13,7 @@
TreebankWordTokenizer,
TweetTokenizer,
punkt,
sent_tokenize,
word_tokenize,
)

@@ -809,3 +810,32 @@ class ExtLangVars(punkt.PunktLanguageVars):
)
# The sentence should be split into two sections,
# with one split and hence one decision.

@pytest.mark.parametrize(
"sentences, expected",
[
(
"this is a test. . new sentence.",
["this is a test.", ".", "new sentence."],
),
("This. . . That", ["This.", ".", ".", "That"]),
("This..... That", ["This..... That"]),
("This... That", ["This... That"]),
("This.. . That", ["This.. .", "That"]),
("This. .. That", ["This.", ".. That"]),
("This. ,. That", ["This.", ",.", "That"]),
("This!!! That", ["This!!!", "That"]),
("This! That", ["This!", "That"]),
(
"1. This is R .\n2. This is A .\n3. That's all",
["1.", "This is R .", "2.", "This is A .", "3.", "That's all"],
),
(
"1. This is R .\t2. This is A .\t3. That's all",
["1.", "This is R .", "2.", "This is A .", "3.", "That's all"],
),
("Hello.\tThere", ["Hello.", "There"]),
],
)
def test_sent_tokenize(self, sentences: str, expected: List[str]):
assert sent_tokenize(sentences) == expected
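
For reference, the behaviour pinned down by the parametrized cases above can be checked interactively. A minimal sketch, assuming an NLTK build that includes this change and the pretrained punkt model (nltk.download("punkt")); the expected outputs are copied from the test table:

from nltk.tokenize import sent_tokenize

print(sent_tokenize("this is a test. . new sentence."))
# ['this is a test.', '.', 'new sentence.']
print(sent_tokenize("1. This is R .\n2. This is A .\n3. That's all"))
# ['1.', 'This is R .', '2.', 'This is A .', '3.', "That's all"]
print(sent_tokenize("Hello.\tThere"))
# ['Hello.', 'There']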
125 changes: 85 additions & 40 deletions nltk/tokenize/punkt.py
@@ -7,6 +7,7 @@
# Edward Loper <edloper@gmail.com> (rewrite)
# Joel Nothman <jnothman@student.usyd.edu.au> (almost rewrite)
# Arthur Darcet <arthur@darcet.fr> (fixes)
# Tom Aarsen <> (tackle ReDoS & performance issues)
# URL: <https://www.nltk.org/>
# For license information, see LICENSE.TXT

@@ -106,7 +107,9 @@

import math
import re
import string
from collections import defaultdict
from typing import Any, Dict, Iterator, List, Match, Optional, Tuple, Union

from nltk.probability import FreqDist
from nltk.tokenize.api import TokenizerI
@@ -578,7 +581,9 @@ def _tokenize_words(self, plaintext):
# { Annotation Procedures
# ////////////////////////////////////////////////////////////

def _annotate_first_pass(self, tokens):
def _annotate_first_pass(
self, tokens: Iterator[PunktToken]
) -> Iterator[PunktToken]:
"""
Perform the first pass of annotation, which makes decisions
        based purely on the word type of each word:
@@ -599,7 +604,7 @@ def _annotate_first_pass(self, tokens):
self._first_pass_annotation(aug_tok)
yield aug_tok

def _first_pass_annotation(self, aug_tok):
def _first_pass_annotation(self, aug_tok: PunktToken) -> None:
"""
Performs type-based annotation on a single token.
"""
@@ -1269,13 +1274,13 @@ def train(self, train_text, verbose=False):
# { Tokenization
# ////////////////////////////////////////////////////////////

def tokenize(self, text, realign_boundaries=True):
def tokenize(self, text: str, realign_boundaries: bool = True) -> List[str]:
"""
Given a text, returns a list of the sentences in that text.
"""
return list(self.sentences_from_text(text, realign_boundaries))

def debug_decisions(self, text):
def debug_decisions(self, text: str) -> Iterator[Dict[str, Any]]:
"""
Classifies candidate periods as sentence breaks, yielding a dict for
each that may be used to understand why the decision was made.
@@ -1311,7 +1316,9 @@ def debug_decisions(self, text):
"break_decision": tokens[0].sentbreak,
}

def span_tokenize(self, text, realign_boundaries=True):
def span_tokenize(
self, text: str, realign_boundaries: bool = True
) -> Iterator[Tuple[int, int]]:
"""
Given a text, generates (start, end) spans of sentences
in the text.
@@ -1322,7 +1329,9 @@ def span_tokenize(self, text, realign_boundaries=True):
for sentence in slices:
yield (sentence.start, sentence.stop)

def sentences_from_text(self, text, realign_boundaries=True):
def sentences_from_text(
self, text: str, realign_boundaries: bool = True
) -> List[str]:
"""
Given a text, generates the sentences in that text by only
testing candidate sentence breaks. If realign_boundaries is
@@ -1331,7 +1340,18 @@ def sentences_from_text(self, text, realign_boundaries=True):
"""
return [text[s:e] for s, e in self.span_tokenize(text, realign_boundaries)]

def _match_potential_end_contexts(self, text):
def _get_last_whitespace_index(self, text: str) -> int:
"""
Given a text, find the index of the *last* occurrence of *any*
whitespace character, i.e. " ", "\n", "\t", "\r", etc.
If none is found, return 0.
"""
for i in range(len(text) - 1, -1, -1):
if text[i] in string.whitespace:
return i
return 0

def _match_potential_end_contexts(self, text: str) -> Iterator[Tuple[Match, str]]:
"""
Given a text, find the matches of potential sentence breaks,
alongside the contexts surrounding these sentence breaks.
@@ -1362,35 +1382,50 @@ def _match_potential_end_contexts(self, text):
>>> pst = PunktSentenceTokenizer()
>>> text = "Very bad acting!!! I promise."
>>> pst._match_potential_end_contexts(text)
>>> list(pst._match_potential_end_contexts(text))
[(<re.Match object; span=(17, 18), match='!'>, 'acting!!! I')]
:param text: String of one or more sentences
:type text: str
:return: List of match-context tuples.
:rtype: List[Tuple[re.Match, str]]
"""
before_words = {}
matches = []
for match in reversed(list(self._lang_vars.period_context_re().finditer(text))):
# Ignore matches that have already been captured by matches to the right of this match
if matches and match.end() > before_start:
continue
# Find the word before the current match
split = text[: match.start()].rsplit(maxsplit=1)
before_start = len(split[0]) if len(split) == 2 else 0
before_words[match] = split[-1] if split else ""
matches.append(match)

return [
(
match,
before_words[match] + match.group() + match.group("after_tok"),
:return: Generator of match-context tuples.
:rtype: Iterator[Tuple[Match, str]]
"""
previous_slice = slice(0, 0)
previous_match = None
for match in self._lang_vars.period_context_re().finditer(text):

# Get the slice of the previous word
before_text = text[previous_slice.stop : match.start()]
last_space_index = self._get_last_whitespace_index(before_text)
if last_space_index:
last_space_index += previous_slice.stop
else:
last_space_index = previous_slice.start
prev_word_slice = slice(last_space_index, match.start())

# If the previous slice does not overlap with this slice, then
# we can yield the previous match and slice. If there is an overlap,
# then we do not yield the previous match and slice.
if previous_match and previous_slice.stop <= prev_word_slice.start:
yield (
previous_match,
text[previous_slice]
+ previous_match.group()
+ previous_match.group("after_tok"),
)
previous_match = match
previous_slice = prev_word_slice

# Yield the last match and context, if it exists
if previous_match:
yield (
previous_match,
text[previous_slice]
+ previous_match.group()
+ previous_match.group("after_tok"),
)
for match in matches[::-1]
]

def _slices_from_text(self, text):
def _slices_from_text(self, text: str) -> Iterator[slice]:
last_break = 0
for match, context in self._match_potential_end_contexts(text):
if self.text_contains_sentbreak(context):
@@ -1404,7 +1439,9 @@ def _slices_from_text(self, text):
# The last sentence should not contain trailing whitespace.
yield slice(last_break, len(text.rstrip()))

def _realign_boundaries(self, text, slices):
def _realign_boundaries(
self, text: str, slices: Iterator[slice]
) -> Iterator[slice]:
"""
Attempts to realign punctuation that falls after the period but
should otherwise be included in the same sentence.
@@ -1434,7 +1471,7 @@ def _realign_boundaries(self, text, slices):
if text[sentence1]:
yield sentence1

def text_contains_sentbreak(self, text):
def text_contains_sentbreak(self, text: str) -> bool:
"""
Returns True if the given text includes a sentence break.
"""
@@ -1446,7 +1483,7 @@ def text_contains_sentbreak(self, text):
found = True
return False

def sentences_from_text_legacy(self, text):
def sentences_from_text_legacy(self, text: str) -> Iterator[str]:
"""
Given a text, generates the sentences in that text. Annotates all
tokens, rather than just those with possible sentence breaks. Should
@@ -1455,7 +1492,9 @@ def sentences_from_text_legacy(self, text):
tokens = self._annotate_tokens(self._tokenize_words(text))
return self._build_sentence_list(text, tokens)

def sentences_from_tokens(self, tokens):
def sentences_from_tokens(
self, tokens: Iterator[PunktToken]
) -> Iterator[PunktToken]:
"""
Given a sequence of tokens, generates lists of tokens, each list
corresponding to a sentence.
@@ -1470,7 +1509,7 @@ def sentences_from_tokens(self, tokens):
if sentence:
yield sentence

def _annotate_tokens(self, tokens):
def _annotate_tokens(self, tokens: Iterator[PunktToken]) -> Iterator[PunktToken]:
"""
Given a set of tokens augmented with markers for line-start and
paragraph-start, returns an iterator through those tokens with full
@@ -1491,7 +1530,9 @@ def _annotate_tokens(self, tokens):

return tokens

def _build_sentence_list(self, text, tokens):
def _build_sentence_list(
self, text: str, tokens: Iterator[PunktToken]
) -> Iterator[str]:
"""
Given the original text and the list of augmented word tokens,
construct and return a tokenized list of sentence strings.
@@ -1546,7 +1587,7 @@ def _build_sentence_list(self, text, tokens):
yield sentence

# [XX] TESTING
def dump(self, tokens):
def dump(self, tokens: Iterator[PunktToken]) -> None:
print("writing to /tmp/punkt.new...")
with open("/tmp/punkt.new", "w") as outfile:
for aug_tok in tokens:
@@ -1569,7 +1610,9 @@ def dump(self, tokens):
# { Annotation Procedures
# ////////////////////////////////////////////////////////////

def _annotate_second_pass(self, tokens):
def _annotate_second_pass(
self, tokens: Iterator[PunktToken]
) -> Iterator[PunktToken]:
"""
Performs a token-based classification (section 4) over the given
tokens, making use of the orthographic heuristic (4.1.1), collocation
@@ -1579,7 +1622,9 @@ def _annotate_second_pass(self, tokens):
self._second_pass_annotation(token1, token2)
yield token1

def _second_pass_annotation(self, aug_tok1, aug_tok2):
def _second_pass_annotation(
self, aug_tok1: PunktToken, aug_tok2: Optional[PunktToken]
) -> Optional[str]:
"""
Performs token-based classification over a pair of contiguous tokens
updating the first.
@@ -1658,7 +1703,7 @@ def _second_pass_annotation(self, aug_tok1, aug_tok2):

return

def _ortho_heuristic(self, aug_tok):
def _ortho_heuristic(self, aug_tok: PunktToken) -> Union[bool, str]:
"""
Decide whether the given token is the first token in a sentence.
"""
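
A note on the _match_potential_end_contexts rewrite above: the removed implementation iterated over the candidate matches in reverse and called text[: match.start()].rsplit(maxsplit=1) for every candidate, re-scanning the whole prefix each time, which is what slowed down long, whitespace-poor inputs. The replacement makes a single forward pass and only walks backwards over the text between consecutive candidates. Below is a simplified, self-contained sketch of that idea, not the NLTK code itself: the regex is an illustrative stand-in for punkt's period_context_re(), the helper names are hypothetical, and the word-start arithmetic skips the whitespace character itself.

import re
import string
from typing import Iterator, Tuple

# Illustrative stand-in for PunktLanguageVars.period_context_re(): a sentence-ending
# character whose lookahead captures either more punctuation or whitespace plus the
# next token as the trailing context.
END_CONTEXT = re.compile(r"[.!?](?=(?P<after_tok>[.!?]|\s+\S+))")


def last_whitespace_index(text: str) -> int:
    # Index of the last whitespace character in `text`, or 0 if there is none.
    for i in range(len(text) - 1, -1, -1):
        if text[i] in string.whitespace:
            return i
    return 0


def match_potential_end_contexts(text: str) -> Iterator[Tuple[re.Match, str]]:
    # Yield (match, context) pairs, where the context is the word preceding the
    # candidate sentence break plus the break and its trailing context. Only the
    # text between consecutive candidates is scanned, so a single pass suffices.
    previous_slice = slice(0, 0)
    previous_match = None
    for match in END_CONTEXT.finditer(text):
        # The preceding word starts just after the last whitespace seen since
        # the previous candidate (the +1 skips the whitespace itself).
        before_text = text[previous_slice.stop : match.start()]
        space = last_whitespace_index(before_text)
        start = previous_slice.stop + space + 1 if space else previous_slice.start
        word_slice = slice(start, match.start())

        # Emit the previous candidate only once we know its context does not
        # run into the current one (e.g. consecutive "!!!" share one context).
        if previous_match and previous_slice.stop <= word_slice.start:
            yield previous_match, (
                text[previous_slice]
                + previous_match.group()
                + previous_match.group("after_tok")
            )
        previous_match, previous_slice = match, word_slice

    if previous_match:
        yield previous_match, (
            text[previous_slice]
            + previous_match.group()
            + previous_match.group("after_tok")
        )

With this sketch, list(match_potential_end_contexts("Very bad acting!!! I promise.")) yields a single pair whose context is 'acting!!! I', in line with the doctest shown in the diff; the final 'promise.' is not a candidate here because nothing follows it.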
