Fix ConditionalFreqDist.plot(): import pyplot as a submodule, keep one
frequency list per condition and draw each as its own labelled line, ignore
requested conditions that are not in the distribution (so plotting never adds
keys), and set tick labels via set_xticklabels().

NOTE(review): this patch was recovered from a copy whose line breaks had been
collapsed. Indentation and blank lines on context/removed lines are
reconstructed from the hunk counts -- confirm against the base files before
applying. The "Iris X. Zhou" -/+ pair presumably differs only in whitespace
that is not recoverable here.

diff --git a/AUTHORS.md b/AUTHORS.md
index eef7f03b25..23e177a8a0 100644
--- a/AUTHORS.md
+++ b/AUTHORS.md
@@ -252,7 +252,8 @@
 - Osman Zubair
 - Viresh Gupta
 - Ondřej Cífka
-- Iris X. Zhou
+- Iris X. Zhou
+- Devashish Lal
 
 ## Others whose work we've taken and included in NLTK, but who didn't directly contribute it:
 
diff --git a/nltk/probability.py b/nltk/probability.py
index 80ee1180de..9dc110c0a2 100755
--- a/nltk/probability.py
+++ b/nltk/probability.py
@@ -1904,7 +1904,7 @@ def plot(self, *args, **kwargs):
         :type conditions: list
         """
         try:
-            from matplotlib import plt
+            import matplotlib.pyplot as plt  # import statement fix
         except ImportError:
             raise ValueError(
                 'The plot function requires matplotlib to be installed.'
@@ -1913,40 +1913,44 @@ def plot(self, *args, **kwargs):
 
         cumulative = _get_kwarg(kwargs, 'cumulative', False)
         percents = _get_kwarg(kwargs, 'percents', False)
-        conditions = _get_kwarg(kwargs, 'conditions', sorted(self.conditions()))
+        conditions = [c for c in _get_kwarg(kwargs, 'conditions', self.conditions()) if c in self]  # conditions should be in self
         title = _get_kwarg(kwargs, 'title', '')
         samples = _get_kwarg(
-            kwargs, 'samples', sorted(set(v for c in conditions
-                                          if v in self
-                                          for v in self[c]))
+            kwargs, 'samples', sorted(set(v
+                                          for c in conditions
+                                          for v in self[c]))
         )  # this computation could be wasted
         if "linewidth" not in kwargs:
             kwargs["linewidth"] = 2
-
-        for condition in conditions:
-            if cumulative:
-                freqs = list(self[condition]._cumulative_frequencies(samples))
-                ylabel = "Cumulative Counts"
-                legend_loc = 'lower right'
-                if percents:
-                    freqs = [f / freqs[len(freqs) - 1] * 100 for f in freqs]
-                    ylabel = "Cumulative Percents"
-            else:
-                freqs = [self[condition][sample] for sample in samples]
-                ylabel = "Counts"
-                legend_loc = 'upper right'
-                # percents = [f * 100 for f in freqs] only in ConditionalProbDist?
-            kwargs['label'] = "%s" % condition
-            ax.plot(freqs, *args, **kwargs)
-
-
-        ax.legend(loc=legend_loc)
-        ax.grid(True, color="silver")
-        ax.set_xticks(range(len(samples)), [text_type(s) for s in samples], rotation=90)
-        if title:
-            ax.set_title(title)
-        ax.set_xlabel("Samples")
-        ax.set_ylabel(ylabel)
+        ax = plt.gca()
+        if conditions:
+            freqs = []
+            for condition in conditions:
+                if cumulative:
+                    # freqs is a list of lists: one frequency list per condition
+                    freqs.append(list(self[condition]._cumulative_frequencies(samples)))
+                    ylabel = "Cumulative Counts"
+                    legend_loc = 'lower right'
+                    if percents:
+                        freqs[-1] = [f / freqs[-1][-1] * 100 for f in freqs[-1]]
+                        ylabel = "Cumulative Percents"
+                else:
+                    freqs.append([self[condition][sample] for sample in samples])
+                    ylabel = "Counts"
+                    legend_loc = 'upper right'
+                    # percents = [f * 100 for f in freqs] only in ConditionalProbDist?
+
+            for condition, freq in zip(conditions, freqs):
+                kwargs['label'] = condition  # label for each condition
+                ax.plot(freq, *args, **kwargs)
+            ax.legend(loc=legend_loc)
+            ax.grid(True, color="silver")
+            ax.set_xticks(range(len(samples)))
+            ax.set_xticklabels([text_type(s) for s in samples], rotation=90)
+            if title:
+                ax.set_title(title)
+            ax.set_xlabel("Samples")
+            ax.set_ylabel(ylabel)
         plt.show()
 
         return ax
diff --git a/nltk/test/unit/test_cfd_mutation.py b/nltk/test/unit/test_cfd_mutation.py
index befa1fc1df..7e21d7e88a 100644
--- a/nltk/test/unit/test_cfd_mutation.py
+++ b/nltk/test/unit/test_cfd_mutation.py
@@ -16,7 +16,7 @@ def test_plot(self):
         empty = ConditionalFreqDist()
         self.assertEqual(empty.conditions(),[])
         try:
-            empty.plot(conditions="BUG") # nonexistent keys shouldn't be added
+            empty.plot(conditions=["BUG"]) # nonexistent keys shouldn't be added
         except:
             pass
         self.assertEqual(empty.conditions(),[])