Skip to content

Commit

Permalink
Added unk flag and rel freq weighting
Browse files Browse the repository at this point in the history
  • Loading branch information
mbauwens committed Jan 26, 2024
1 parent d645382 commit 90fdd00
Showing 1 changed file with 29 additions and 5 deletions.
34 changes: 29 additions & 5 deletions nltk/lm/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,12 @@ def context_counts(self, context):
self.counts[len(context) + 1][context] if context else self.counts.unigrams
)

def contains_UNKs(self, ngram):
    """Helper method to indicate whether an ngram contains an UNK token.

    A token that was never seen in training has a unigram count of 0 and
    is therefore treated as unknown.
    NOTE(review): assumes ``self.counts.unigrams`` returns 0 (not raises)
    for unseen tokens -- confirm against the NgramCounter implementation.

    :param tuple(str) ngram: The ngram whose tokens are inspected.
    :rtype: bool
    """
    # Bug fix: the previous version returned True as soon as ANY token was
    # *known* (non-zero count) -- the inverse of the stated intent.  A token
    # is unknown exactly when its unigram count is 0.
    return any(self.counts.unigrams[token] == 0 for token in ngram)


def entropy(self, text_ngrams):
"""Calculate cross-entropy of model for given evaluation text.
Expand All @@ -183,7 +189,7 @@ def perplexity(self, text_ngrams):
"""
return pow(2.0, self.entropy(text_ngrams))

def entropy_extended(self, text_ngrams, text_fdist, length_normalisation=True, rel_freq_weighting=False):
def entropy_extended(self, text_ngrams, length_normalisation=True, rel_freq_weighting=False):
"""Calculate cross-entropy of model for given evaluation text.
This implementation is based on the standard Shannon entropy,
Expand All @@ -192,16 +198,34 @@ def entropy_extended(self, text_ngrams, text_fdist, length_normalisation=True, r
In case of <UNK> tokens, weight with the minimum relative frequency in the dataset.
:param Iterable(tuple(str)) text_ngrams: A sequence of ngram tuples.
:param FreqDist text_fdist:
:param bool length_normalisation:
:param bool rel_freq_weighting:
:rtype: float
"""
probabilities = [self.score(ngram[-1], ngram[:-1]) for ngram in text_ngrams]
# TODO add function to check for UNKs

if rel_freq_weighting:
# TODO add weighting according to frequency distribution
fdist = FreqDist()

probabilities = []
for ngram in text_ngrams:
probabilities.append(self.score(ngram[-1], ngram[:-1]))
if rel_freq_weighting:
fdist[' '.join(ngram)] += 1

if rel_freq_weighting:
total_freq_fdist = sum(fdist.values())
rel_fdist = {key: fdist[key]/total_freq_fdist for key in fdist.keys()}
min_freq_rel_fdist = min(rel_fdist.values())

weighted_probabilities = []
for prob, ngram in zip(probabilities, text_ngrams):
if contains_UNK(ngram):
prob *= min_freq_rel_fdist
else:
prob *= rel_fdist[ngram]
weighted_probabilities.append(prob)
probabilities = weighted_probabilities

entropy = -1 * sum([prob * log_base2(prob) for prob in probabilities])

Expand Down

0 comments on commit 90fdd00

Please sign in to comment.