Skip to content

Commit

Permalink
Add optional show argument to FreqDist.plot() (#2786)
Browse files Browse the repository at this point in the history
* Add optional 'show' argument to 'FreqDist.plot()'

* Add optional 'show' argument to 'ConditionalFreqDist.plot()'

* Fixed small issue in conditions

* Added information on *args and **kwargs to plot() methods' docstrings

* Added percents option to non-cumulative plot() for FreqDist and ConditionalFreqDist

Also modified docstrings accordingly
  • Loading branch information
tomaarsen committed Aug 26, 2021
1 parent 193cf48 commit f989fe6
Showing 1 changed file with 72 additions and 40 deletions.
112 changes: 72 additions & 40 deletions nltk/probability.py
Expand Up @@ -244,18 +244,25 @@ def max(self):
)
return self.most_common(1)[0][0]

def plot(self, *args, **kwargs):
def plot(
self, *args, title="", cumulative=False, percents=False, show=True, **kwargs
):
"""
Plot samples from the frequency distribution
displaying the most frequent sample first. If an integer
parameter is supplied, stop after this many samples have been
plotted. For a cumulative plot, specify cumulative=True.
plotted. For a cumulative plot, specify cumulative=True. Additional
**kwargs are passed to matplotlib's plot function.
(Requires Matplotlib to be installed.)
:param title: The title for the graph
:param title: The title for the graph.
:type title: str
:param cumulative: A flag to specify whether the plot is cumulative (default = False)
:type title: bool
:param cumulative: Whether the plot is cumulative. (default = False)
:type cumulative: bool
:param percents: Whether the plot uses percents instead of counts. (default = False)
:type percents: bool
:param show: Whether to show the plot, or only return the ax.
:type show: bool
"""
try:
import matplotlib.pyplot as plt
Expand All @@ -269,35 +276,35 @@ def plot(self, *args, **kwargs):
args = [len(self)]
samples = [item for item, _ in self.most_common(*args)]

cumulative = _get_kwarg(kwargs, "cumulative", False)
percents = _get_kwarg(kwargs, "percents", False)
if cumulative:
freqs = list(self._cumulative_frequencies(samples))
ylabel = "Cumulative Counts"
if percents:
freqs = [f / freqs[len(freqs) - 1] * 100 for f in freqs]
ylabel = "Cumulative Percents"
ylabel = "Cumulative "
else:
freqs = [self[sample] for sample in samples]
ylabel = "Counts"
# percents = [f * 100 for f in freqs] only in ProbDist?
ylabel = ""

if percents:
freqs = [f / self.N() * 100 for f in freqs]
ylabel += "Percents"
else:
ylabel += "Counts"

ax = plt.gca()
ax.grid(True, color="silver")

if "linewidth" not in kwargs:
kwargs["linewidth"] = 2
if "title" in kwargs:
ax.set_title(kwargs["title"])
del kwargs["title"]
if title:
ax.set_title(title)

ax.plot(freqs, **kwargs)
ax.set_xticks(range(len(samples)))
ax.set_xticklabels([str(s) for s in samples], rotation=90)
ax.set_xlabel("Samples")
ax.set_ylabel(ylabel)

plt.show()
if show:
plt.show()

return ax

Expand Down Expand Up @@ -1912,18 +1919,35 @@ def N(self):
"""
return sum(fdist.N() for fdist in self.values())

def plot(self, *args, **kwargs):
def plot(
self,
*args,
samples=None,
title="",
cumulative=False,
percents=False,
conditions=None,
show=True,
**kwargs,
):
"""
Plot the given samples from the conditional frequency distribution.
For a cumulative plot, specify cumulative=True.
For a cumulative plot, specify cumulative=True. Additional *args and
**kwargs are passed to matplotlib's plot function.
(Requires Matplotlib to be installed.)
:param samples: The samples to plot
:type samples: list
:param title: The title for the graph
:type title: str
:param cumulative: Whether the plot is cumulative. (default = False)
:type cumulative: bool
:param percents: Whether the plot uses percents instead of counts. (default = False)
:type percents: bool
:param conditions: The conditions to plot (default is all)
:type conditions: list
:param show: Whether to show the plot, or only return the ax.
:type show: bool
"""
try:
import matplotlib.pyplot as plt # import statement fix
Expand All @@ -1933,34 +1957,40 @@ def plot(self, *args, **kwargs):
"See http://matplotlib.org/"
) from e

cumulative = _get_kwarg(kwargs, "cumulative", False)
percents = _get_kwarg(kwargs, "percents", False)
conditions = [
c for c in _get_kwarg(kwargs, "conditions", self.conditions()) if c in self
] # conditions should be in self
title = _get_kwarg(kwargs, "title", "")
samples = _get_kwarg(
kwargs, "samples", sorted({v for c in conditions for v in self[c]})
) # this computation could be wasted
if not conditions:
conditions = self.conditions()
else:
conditions = [c for c in conditions if c in self]
if not samples:
samples = sorted({v for c in conditions for v in self[c]})
if "linewidth" not in kwargs:
kwargs["linewidth"] = 2
ax = plt.gca()
if len(conditions) != 0:
if conditions:
freqs = []
for condition in conditions:
if cumulative:
# freqs should be a list of list where each sub list will be a frequency of a condition
freqs.append(list(self[condition]._cumulative_frequencies(samples)))
ylabel = "Cumulative Counts"
legend_loc = "lower right"
if percents:
freqs[-1] = [f / freqs[len(freqs) - 1] * 100 for f in freqs]
ylabel = "Cumulative Percents"
freq = list(self[condition]._cumulative_frequencies(samples))
else:
freqs.append([self[condition][sample] for sample in samples])
ylabel = "Counts"
legend_loc = "upper right"
# percents = [f * 100 for f in freqs] only in ConditionalProbDist?
freq = [self[condition][sample] for sample in samples]

if percents:
freq = [f / self[condition].N() * 100 for f in freq]

freqs.append(freq)

if cumulative:
ylabel = "Cumulative "
legend_loc = "lower right"
else:
ylabel = ""
legend_loc = "upper right"

if percents:
ylabel += "Percents"
else:
ylabel += "Counts"

i = 0
for freq in freqs:
Expand All @@ -1975,7 +2005,9 @@ def plot(self, *args, **kwargs):
ax.set_title(title)
ax.set_xlabel("Samples")
ax.set_ylabel(ylabel)
plt.show()

if show:
plt.show()

return ax

Expand Down

0 comments on commit f989fe6

Please sign in to comment.