From f989fe65d421e7ea4d1037a00f07eaeee3ad6a29 Mon Sep 17 00:00:00 2001 From: Tom Aarsen <37621491+tomaarsen@users.noreply.github.com> Date: Thu, 26 Aug 2021 02:19:49 +0200 Subject: [PATCH] Add optional `show` argument to `FreqDist.plot()` (#2786) * Add optional 'show' argument to 'FreqDist.plot()' * Add optional 'show' argument to 'ConditionalFreqDist.plot()' * Fixed small issue in conditions * Added information on *args and **kwargs to plot() methods' docstrings * Added percents option to non-cumulative plot() for FreqDist and ConditionalFreqDist Also modified docstrings accordingly --- nltk/probability.py | 112 ++++++++++++++++++++++++++++---------------- 1 file changed, 72 insertions(+), 40 deletions(-) diff --git a/nltk/probability.py b/nltk/probability.py index f2063a3109..6b6ddae594 100755 --- a/nltk/probability.py +++ b/nltk/probability.py @@ -244,18 +244,25 @@ def max(self): ) return self.most_common(1)[0][0] - def plot(self, *args, **kwargs): + def plot( + self, *args, title="", cumulative=False, percents=False, show=True, **kwargs + ): """ Plot samples from the frequency distribution displaying the most frequent sample first. If an integer parameter is supplied, stop after this many samples have been - plotted. For a cumulative plot, specify cumulative=True. + plotted. For a cumulative plot, specify cumulative=True. Additional + **kwargs are passed to matplotlib's plot function. (Requires Matplotlib to be installed.) - :param title: The title for the graph + :param title: The title for the graph. :type title: str - :param cumulative: A flag to specify whether the plot is cumulative (default = False) - :type title: bool + :param cumulative: Whether the plot is cumulative. (default = False) + :type cumulative: bool + :param percents: Whether the plot uses percents instead of counts. (default = False) + :type percents: bool + :param show: Whether to show the plot, or only return the ax. + :type show: bool """ try: import matplotlib.pyplot as plt @@ -269,27 +276,26 @@ def plot(self, *args, **kwargs): args = [len(self)] samples = [item for item, _ in self.most_common(*args)] - cumulative = _get_kwarg(kwargs, "cumulative", False) - percents = _get_kwarg(kwargs, "percents", False) if cumulative: freqs = list(self._cumulative_frequencies(samples)) - ylabel = "Cumulative Counts" - if percents: - freqs = [f / freqs[len(freqs) - 1] * 100 for f in freqs] - ylabel = "Cumulative Percents" + ylabel = "Cumulative " else: freqs = [self[sample] for sample in samples] - ylabel = "Counts" - # percents = [f * 100 for f in freqs] only in ProbDist? + ylabel = "" + + if percents: + freqs = [f / self.N() * 100 for f in freqs] + ylabel += "Percents" + else: + ylabel += "Counts" ax = plt.gca() ax.grid(True, color="silver") if "linewidth" not in kwargs: kwargs["linewidth"] = 2 - if "title" in kwargs: - ax.set_title(kwargs["title"]) - del kwargs["title"] + if title: + ax.set_title(title) ax.plot(freqs, **kwargs) ax.set_xticks(range(len(samples))) @@ -297,7 +303,8 @@ def plot(self, *args, **kwargs): ax.set_xlabel("Samples") ax.set_ylabel(ylabel) - plt.show() + if show: + plt.show() return ax @@ -1912,18 +1919,35 @@ def N(self): """ return sum(fdist.N() for fdist in self.values()) - def plot(self, *args, **kwargs): + def plot( + self, + *args, + samples=None, + title="", + cumulative=False, + percents=False, + conditions=None, + show=True, + **kwargs, + ): """ Plot the given samples from the conditional frequency distribution. - For a cumulative plot, specify cumulative=True. + For a cumulative plot, specify cumulative=True. Additional *args and + **kwargs are passed to matplotlib's plot function. (Requires Matplotlib to be installed.) :param samples: The samples to plot :type samples: list :param title: The title for the graph :type title: str + :param cumulative: Whether the plot is cumulative. (default = False) + :type cumulative: bool + :param percents: Whether the plot uses percents instead of counts. (default = False) + :type percents: bool :param conditions: The conditions to plot (default is all) :type conditions: list + :param show: Whether to show the plot, or only return the ax. + :type show: bool """ try: import matplotlib.pyplot as plt # import statement fix @@ -1933,34 +1957,40 @@ def plot(self, *args, **kwargs): "See http://matplotlib.org/" ) from e - cumulative = _get_kwarg(kwargs, "cumulative", False) - percents = _get_kwarg(kwargs, "percents", False) - conditions = [ - c for c in _get_kwarg(kwargs, "conditions", self.conditions()) if c in self - ] # conditions should be in self - title = _get_kwarg(kwargs, "title", "") - samples = _get_kwarg( - kwargs, "samples", sorted({v for c in conditions for v in self[c]}) - ) # this computation could be wasted + if not conditions: + conditions = self.conditions() + else: + conditions = [c for c in conditions if c in self] + if not samples: + samples = sorted({v for c in conditions for v in self[c]}) if "linewidth" not in kwargs: kwargs["linewidth"] = 2 ax = plt.gca() - if len(conditions) != 0: + if conditions: freqs = [] for condition in conditions: if cumulative: # freqs should be a list of list where each sub list will be a frequency of a condition - freqs.append(list(self[condition]._cumulative_frequencies(samples))) - ylabel = "Cumulative Counts" - legend_loc = "lower right" - if percents: - freqs[-1] = [f / freqs[len(freqs) - 1] * 100 for f in freqs] - ylabel = "Cumulative Percents" + freq = list(self[condition]._cumulative_frequencies(samples)) else: - freqs.append([self[condition][sample] for sample in samples]) - ylabel = "Counts" - legend_loc = "upper right" - # percents = [f * 100 for f in freqs] only in ConditionalProbDist? + freq = [self[condition][sample] for sample in samples] + + if percents: + freq = [f / self[condition].N() * 100 for f in freq] + + freqs.append(freq) + + if cumulative: + ylabel = "Cumulative " + legend_loc = "lower right" + else: + ylabel = "" + legend_loc = "upper right" + + if percents: + ylabel += "Percents" + else: + ylabel += "Counts" i = 0 for freq in freqs: @@ -1975,7 +2005,9 @@ def plot(self, *args, **kwargs): ax.set_title(title) ax.set_xlabel("Samples") ax.set_ylabel(ylabel) - plt.show() + + if show: + plt.show() return ax