
Add reference to entropy implementation used #3229

Merged
merged 9 commits on Jan 29, 2024
39 changes: 39 additions & 0 deletions nltk/lm/api.py
@@ -12,6 +12,7 @@
from bisect import bisect
from itertools import accumulate

from nltk import FreqDist
from nltk.lm.counter import NgramCounter
from nltk.lm.util import log_base2
from nltk.lm.vocabulary import Vocabulary
@@ -163,6 +164,9 @@ def context_counts(self, context):
def entropy(self, text_ngrams):
"""Calculate cross-entropy of model for given evaluation text.

This implementation is based on the Shannon-McMillan-Breiman theorem,
as used and referenced by Dan Jurafsky and Jordan Boyd-Graber.

:param Iterable(tuple(str)) text_ngrams: A sequence of ngram tuples.
:rtype: float

@@ -179,6 +183,41 @@ def perplexity(self, text_ngrams):
"""
return pow(2.0, self.entropy(text_ngrams))
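A minimal usage sketch of the entropy/perplexity pair documented above (the training sentences and test bigrams are made up for illustration): cross-entropy here is the average negative log2 score per ngram, per the Shannon-McMillan-Breiman approximation referenced in the docstring, and perplexity is 2 raised to that value.

from nltk.lm import MLE
from nltk.lm.preprocessing import padded_everygram_pipeline

train = [["a", "b", "c"], ["a", "c", "b"]]
train_ngrams, vocab = padded_everygram_pipeline(2, train)
lm = MLE(2)
lm.fit(train_ngrams, vocab)

test_bigrams = [("a", "b"), ("b", "c")]  # ngram tuples from the evaluation text
print(lm.entropy(test_bigrams))          # mean of -log2 P(word | context)
print(lm.perplexity(test_bigrams))       # 2 ** entropy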

def entropy_extended(self, text_ngrams, text_fdist, length_normalisation=True, rel_freq_weighting=False):
"""Calculate cross-entropy of model for given evaluation text.

This implementation is based on the standard Shannon entropy,
extended with the option to normalise the entropy by the number of ngrams in the text
and/or to weight each ngram's contribution by its relative frequency.
For <UNK> tokens, the weight falls back to the minimum relative frequency in the dataset.

:param Iterable(tuple(str)) text_ngrams: A sequence of ngram tuples.
:param FreqDist text_fdist: Frequency distribution of the evaluation ngrams, used for relative-frequency weighting.
:param bool length_normalisation: If True, normalise the entropy by the number of ngrams.
:param bool rel_freq_weighting: If True, weight each ngram's contribution by its relative frequency.
:rtype: float

"""
text_ngrams = list(text_ngrams)
probabilities = [self.score(ngram[-1], ngram[:-1]) for ngram in text_ngrams]

if rel_freq_weighting:
    # Assumed reading of the TODOs in this diff: weight each ngram's log probability
    # by its relative frequency in ``text_fdist``; ngrams absent from it (e.g. <UNK>)
    # fall back to the minimum relative frequency, as described in the docstring.
    min_rel_freq = min(text_fdist.freq(sample) for sample in text_fdist)
    weights = [text_fdist.freq(ngram) if ngram in text_fdist else min_rel_freq for ngram in text_ngrams]
    entropy = -1 * sum(weight * log_base2(prob) for weight, prob in zip(weights, probabilities))
else:
    entropy = -1 * sum(prob * log_base2(prob) for prob in probabilities)

if length_normalisation:
    entropy /= len(probabilities)

return entropy

def perplexity_extended(self, text_ngrams, text_fdist, normalised=True, rel_freq_weighted=False):
"""Calculates the perplexity of the given text based on the extended version of the entropy method.

This is simply 2 ** cross-entropy for the text, so the arguments are the same.

"""
return pow(2.0, self.entropy_extended(text_ngrams, text_fdist, normalised, rel_freq_weighted))
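A companion sketch for the extended variants, reusing the fitted lm and test_bigrams from the example above. It assumes text_fdist is a FreqDist over the evaluation ngrams themselves; the diff leaves its exact contents unspecified, so treat that as an assumption.

from nltk import FreqDist

text_fdist = FreqDist(test_bigrams)  # assumed: counts of the evaluation ngrams

h = lm.entropy_extended(test_bigrams, text_fdist,
                        length_normalisation=True, rel_freq_weighting=True)
print(h)                                              # length-normalised, frequency-weighted entropy
print(lm.perplexity_extended(test_bigrams, text_fdist))  # 2 ** entropy_extended with default flags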

def generate(self, num_words=1, text_seed=None, random_seed=None):
"""Generate words from the model.
