Added multi Bleu functionality and tests #2793

Merged: 6 commits, Nov 20, 2021
Changes from 4 commits
2 changes: 1 addition & 1 deletion nltk/test/bleu.doctest
@@ -9,7 +9,7 @@ If the candidate has no alignment to any of the references, the BLEU score is 0.
>>> bleu(
... ['The candidate has no alignment to any of the references'.split()],
... 'John loves Mary'.split(),
... [1],
... (1,),
... )
0

130 changes: 127 additions & 3 deletions nltk/test/unit/translate/test_bleu.py
@@ -120,7 +120,7 @@ def test_zero_matches(self):

# Test BLEU to nth order of n-grams, where n is len(hypothesis).
for n in range(1, len(hypothesis)):
weights = [1.0 / n] * n # Uniform weights.
weights = (1.0 / n,) * n # Uniform weights.
assert sentence_bleu(references, hypothesis, weights) == 0

def test_full_matches(self):
@@ -130,7 +130,7 @@ def test_full_matches(self):

# Test BLEU to nth order of n-grams, where n is len(hypothesis).
for n in range(1, len(hypothesis)):
weights = [1.0 / n] * n # Uniform weights.
weights = (1.0 / n,) * n # Uniform weights.
assert sentence_bleu(references, hypothesis, weights) == 1.0

def test_partial_matches_hypothesis_longer_than_reference(self):
@@ -153,7 +153,7 @@ def test_case_where_n_is_bigger_than_hypothesis_length(self):
references = ["John loves Mary ?".split()]
hypothesis = "John loves Mary".split()
n = len(hypothesis) + 1 #
weights = [1.0 / n] * n # Uniform weights.
weights = (1.0 / n,) * n # Uniform weights.
# Since no n-gram matches were found, the result should be zero:
# exp(w_1 * log(p_1) + w_2 * log(p_2) + w_3 * log(p_3) + w_4 * log(0)) = exp(-inf) = 0
self.assertAlmostEqual(
@@ -269,3 +269,127 @@ def test_corpus_bleu_with_bad_sentence(self):
)
except AttributeError: # unittest.TestCase.assertWarns is only supported in Python >= 3.2.
self.assertAlmostEqual(corpus_bleu(references, hypotheses), 0.0, places=4)


class TestBLEUWithMultipleWeights(unittest.TestCase):
def test_corpus_bleu_with_multiple_weights(self):
hyp1 = [
"It",
"is",
"a",
"guide",
"to",
"action",
"which",
"ensures",
"that",
"the",
"military",
"always",
"obeys",
"the",
"commands",
"of",
"the",
"party",
]
ref1a = [
"It",
"is",
"a",
"guide",
"to",
"action",
"that",
"ensures",
"that",
"the",
"military",
"will",
"forever",
"heed",
"Party",
"commands",
]
ref1b = [
"It",
"is",
"the",
"guiding",
"principle",
"which",
"guarantees",
"the",
"military",
"forces",
"always",
"being",
"under",
"the",
"command",
"of",
"the",
"Party",
]
ref1c = [
"It",
"is",
"the",
"practical",
"guide",
"for",
"the",
"army",
"always",
"to",
"heed",
"the",
"directions",
"of",
"the",
"party",
]
hyp2 = [
"he",
"read",
"the",
"book",
"because",
"he",
"was",
"interested",
"in",
"world",
"history",
]
ref2a = [
"he",
"was",
"interested",
"in",
"world",
"history",
"because",
"he",
"read",
"the",
"book",
]
weight_1 = (1, 0, 0, 0)
weight_2 = (0.25, 0.25, 0.25, 0.25)
weight_3 = (0, 0, 0, 0, 1)

bleu_scores = corpus_bleu(
list_of_references=[[ref1a, ref1b, ref1c], [ref2a]],
hypotheses=[hyp1, hyp2],
weights=[weight_1, weight_2, weight_3],
)
assert bleu_scores[0] == corpus_bleu(
[[ref1a, ref1b, ref1c], [ref2a]], [hyp1, hyp2], weight_1
)
assert bleu_scores[1] == corpus_bleu(
[[ref1a, ref1b, ref1c], [ref2a]], [hyp1, hyp2], weight_2
)
assert bleu_scores[2] == corpus_bleu(
[[ref1a, ref1b, ref1c], [ref2a]], [hyp1, hyp2], weight_3
)
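For readers skimming the diff, here is a minimal usage sketch of the behaviour the new test above exercises, with toy sentences that are not part of the PR: a single weights tuple still yields one float, while a list of tuples yields one score per tuple from a single pass over the corpus.

from nltk.translate.bleu_score import corpus_bleu

# Toy data, illustrative only (not taken from the PR).
references = [["the quick brown fox jumps over the lazy dog".split()]]
hypotheses = ["the quick brown fox jumped over the lazy dog".split()]

# A single weights tuple keeps the old behaviour: one float is returned.
single = corpus_bleu(references, hypotheses, weights=(0.25, 0.25, 0.25, 0.25))

# A list of weights tuples returns one score per tuple, in the same order,
# while the n-gram counts are only collected once.
multi = corpus_bleu(
    references,
    hypotheses,
    weights=[(1, 0, 0, 0), (0.5, 0.5), (0.25, 0.25, 0.25, 0.25)],
)
assert isinstance(multi, list) and len(multi) == 3
assert multi[2] == single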
42 changes: 29 additions & 13 deletions nltk/translate/bleu_score.py
@@ -150,8 +150,8 @@ def corpus_bleu(
:type list_of_references: list(list(list(str)))
:param hypotheses: a list of hypothesis sentences
:type hypotheses: list(list(str))
:param weights: weights for unigrams, bigrams, trigrams and so on
:type weights: list(float)
:param weights: weights for unigrams, bigrams, trigrams and so on (one tuple, or a list of tuples of weights)
:type weights: tuple(float) or list(tuple(float))
:param smoothing_function:
:type smoothing_function: SmoothingFunction
:param auto_reweigh: Option to re-normalize the weights uniformly.
@@ -169,11 +169,17 @@
"The number of hypotheses and their reference(s) should be the " "same "
)

if weights and isinstance(weights[0], float):
weights = [weights]
elif isinstance(weights, tuple):
weights = [weights]
max_weight_length = max(len(weight) for weight in weights)

# Iterate through each hypothesis and their corresponding references.
for references, hypothesis in zip(list_of_references, hypotheses):
# For each order of ngram, calculate the numerator and
# denominator for the corpus-level modified precision.
for i, _ in enumerate(weights, start=1):
for i in range(1, max_weight_length + 1):
p_i = modified_precision(references, hypothesis, i)
p_numerators[i] += p_i.numerator
p_denominators[i] += p_i.denominator
@@ -187,23 +193,23 @@
# Calculate corpus-level brevity penalty.
bp = brevity_penalty(ref_lengths, hyp_lengths)

# Uniformly re-weighting based on maximum hypothesis lengths if largest
# order of n-grams < 4 and weights is set at default.
if auto_reweigh:
if hyp_lengths < 4 and weights == (0.25, 0.25, 0.25, 0.25):
weights = (1 / hyp_lengths,) * hyp_lengths
# # Uniformly re-weighting based on maximum hypothesis lengths if largest
# # order of n-grams < 4 and weights is set at default.
# if auto_reweigh:
# if hyp_lengths < 4 and weights == (0.25, 0.25, 0.25, 0.25):
# weights = (1 / hyp_lengths,) * hyp_lengths

# Collects the various precision values for the different ngram orders.
p_n = [
Fraction(p_numerators[i], p_denominators[i], _normalize=False)
for i, _ in enumerate(weights, start=1)
for i in range(1, max_weight_length + 1)
]

# Returns 0 if there's no matching n-grams
# We only need to check for p_numerators[1] == 0, since if there's
# no unigrams, there won't be any higher order ngrams.
if p_numerators[1] == 0:
return 0
return 0 if len(weights) == 1 else [0] * len(weights)

# If there's no smoothing, use method0 from the SmoothingFunction class.
if not smoothing_function:
@@ -215,9 +221,19 @@
p_n = smoothing_function(
p_n, references=references, hypothesis=hypothesis, hyp_len=hyp_lengths
)
s = (w_i * math.log(p_i) for w_i, p_i in zip(weights, p_n))
s = bp * math.exp(math.fsum(s))
return s

bleu_scores = []
for weight in weights:
# Uniformly re-weighting based on maximum hypothesis lengths if largest
# order of n-grams < 4 and weights is set at default.
if auto_reweigh:
if hyp_lengths < 4 and weight == (0.25, 0.25, 0.25, 0.25):
weight = (1 / hyp_lengths,) * hyp_lengths

s = (w_i * math.log(p_i) for w_i, p_i in zip(weight, p_n))
s = bp * math.exp(math.fsum(s))
bleu_scores.append(s)
return bleu_scores[0] if len(weights) == 1 else bleu_scores


def modified_precision(references, hypothesis, n):
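As a quick sanity check of the per-weight scoring loop added to corpus_bleu above, the same final step can be reproduced by hand on made-up precision values; the numbers below are illustrative and not taken from the PR.

import math

# Made-up modified precisions for n-gram orders 1..4 and a brevity penalty of 1.0.
p_n = [0.8, 0.6, 0.4, 0.2]
bp = 1.0

weights = [(1, 0, 0, 0), (0.25, 0.25, 0.25, 0.25)]

bleu_scores = []
for weight in weights:
    # BLEU = BP * exp(sum_i w_i * log(p_i)), using only as many orders as the tuple has.
    s = bp * math.exp(math.fsum(w_i * math.log(p_i) for w_i, p_i in zip(weight, p_n)))
    bleu_scores.append(s)

print(bleu_scores)  # approximately [0.8, 0.4427]; the second entry equals (0.8 * 0.6 * 0.4 * 0.2) ** 0.25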