Skip to content

Commit

Permalink
Merge pull request #2319 from BLaZeKiLL/BLaZeKiLL-polt-bug-fix
Browse files Browse the repository at this point in the history
Plot function bugs fixed
  • Loading branch information
stevenbird committed Jul 4, 2019
2 parents f6a4f38 + ecdcc57 commit 8bcc98a
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 32 deletions.
3 changes: 2 additions & 1 deletion AUTHORS.md
Expand Up @@ -252,7 +252,8 @@
- Osman Zubair <https://github.com/okz12>
- Viresh Gupta <https://github.com/virresh>
- Ondřej Cífka <https://github.com/cifkao>
- Iris X. Zhou <https://github.com/irisxzhou>
- Iris X. Zhou <https://github.com/irisxzhou>
- Devashish Lal <https://github.com/BLaZeKiLL>


## Others whose work we've taken and included in NLTK, but who didn't directly contribute it:
Expand Down
66 changes: 36 additions & 30 deletions nltk/probability.py
Expand Up @@ -1904,7 +1904,7 @@ def plot(self, *args, **kwargs):
:type conditions: list
"""
try:
from matplotlib import plt
import matplotlib.pyplot as plt #import statment fix
except ImportError:
raise ValueError(
'The plot function requires matplotlib to be installed.'
Expand All @@ -1913,40 +1913,46 @@ def plot(self, *args, **kwargs):

cumulative = _get_kwarg(kwargs, 'cumulative', False)
percents = _get_kwarg(kwargs, 'percents', False)
conditions = _get_kwarg(kwargs, 'conditions', sorted(self.conditions()))
conditions = [c for c in _get_kwarg(kwargs, 'conditions', self.conditions()) if c in self] # conditions should be in self
title = _get_kwarg(kwargs, 'title', '')
samples = _get_kwarg(
kwargs, 'samples', sorted(set(v for c in conditions
if v in self
for v in self[c]))
kwargs, 'samples', sorted(set(v
for c in conditions
for v in self[c]))
) # this computation could be wasted
if "linewidth" not in kwargs:
kwargs["linewidth"] = 2

for condition in conditions:
if cumulative:
freqs = list(self[condition]._cumulative_frequencies(samples))
ylabel = "Cumulative Counts"
legend_loc = 'lower right'
if percents:
freqs = [f / freqs[len(freqs) - 1] * 100 for f in freqs]
ylabel = "Cumulative Percents"
else:
freqs = [self[condition][sample] for sample in samples]
ylabel = "Counts"
legend_loc = 'upper right'
# percents = [f * 100 for f in freqs] only in ConditionalProbDist?
kwargs['label'] = "%s" % condition
ax.plot(freqs, *args, **kwargs)


ax.legend(loc=legend_loc)
ax.grid(True, color="silver")
ax.set_xticks(range(len(samples)), [text_type(s) for s in samples], rotation=90)
if title:
ax.set_title(title)
ax.set_xlabel("Samples")
ax.set_ylabel(ylabel)
ax = plt.gca()
if (len(conditions) != 0):
freqs = []
for condition in conditions:
if cumulative:
# freqs should be a list of list where each sub list will be a frequency of a condition
freqs.append(list(self[condition]._cumulative_frequencies(samples)))
ylabel = "Cumulative Counts"
legend_loc = 'lower right'
if percents:
freqs[-1] = [f / freqs[len(freqs) - 1] * 100 for f in freqs]
ylabel = "Cumulative Percents"
else:
freqs.append([self[condition][sample] for sample in samples])
ylabel = "Counts"
legend_loc = 'upper right'
# percents = [f * 100 for f in freqs] only in ConditionalProbDist?

i = 0
for freq in freqs:
kwargs['label'] = conditions[i] #label for each condition
i += 1
ax.plot(freq, *args, **kwargs)
ax.legend(loc=legend_loc)
ax.grid(True, color="silver")
ax.set_xticks(range(len(samples)))
ax.set_xticklabels([text_type(s) for s in samples], rotation=90)
if title:
ax.set_title(title)
ax.set_xlabel("Samples")
ax.set_ylabel(ylabel)
plt.show()

return ax
Expand Down
2 changes: 1 addition & 1 deletion nltk/test/unit/test_cfd_mutation.py
Expand Up @@ -16,7 +16,7 @@ def test_plot(self):
empty = ConditionalFreqDist()
self.assertEqual(empty.conditions(),[])
try:
empty.plot(conditions="BUG") # nonexistent keys shouldn't be added
empty.plot(conditions=["BUG"]) # nonexistent keys shouldn't be added
except:
pass
self.assertEqual(empty.conditions(),[])
Expand Down

0 comments on commit 8bcc98a

Please sign in to comment.