Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Plot function bugs fixed #2319

Merged
merged 5 commits into from Jul 4, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 2 additions & 1 deletion AUTHORS.md
Expand Up @@ -252,7 +252,8 @@
- Osman Zubair <https://github.com/okz12>
- Viresh Gupta <https://github.com/virresh>
- Ondřej Cífka <https://github.com/cifkao>
- Iris X. Zhou <https://github.com/irisxzhou>
- Iris X. Zhou <https://github.com/irisxzhou>
- Devashish Lal <https://github.com/BLaZeKiLL>


## Others whose work we've taken and included in NLTK, but who didn't directly contribute it:
Expand Down
66 changes: 36 additions & 30 deletions nltk/probability.py
Expand Up @@ -1904,7 +1904,7 @@ def plot(self, *args, **kwargs):
:type conditions: list
"""
try:
from matplotlib import plt
import matplotlib.pyplot as plt #import statment fix
except ImportError:
raise ValueError(
'The plot function requires matplotlib to be installed.'
Expand All @@ -1913,40 +1913,46 @@ def plot(self, *args, **kwargs):

cumulative = _get_kwarg(kwargs, 'cumulative', False)
percents = _get_kwarg(kwargs, 'percents', False)
conditions = _get_kwarg(kwargs, 'conditions', sorted(self.conditions()))
conditions = [c for c in _get_kwarg(kwargs, 'conditions', self.conditions()) if c in self] # conditions should be in self
title = _get_kwarg(kwargs, 'title', '')
samples = _get_kwarg(
kwargs, 'samples', sorted(set(v for c in conditions
if v in self
for v in self[c]))
kwargs, 'samples', sorted(set(v
for c in conditions
for v in self[c]))
) # this computation could be wasted
if "linewidth" not in kwargs:
kwargs["linewidth"] = 2

for condition in conditions:
if cumulative:
freqs = list(self[condition]._cumulative_frequencies(samples))
ylabel = "Cumulative Counts"
legend_loc = 'lower right'
if percents:
freqs = [f / freqs[len(freqs) - 1] * 100 for f in freqs]
ylabel = "Cumulative Percents"
else:
freqs = [self[condition][sample] for sample in samples]
ylabel = "Counts"
legend_loc = 'upper right'
# percents = [f * 100 for f in freqs] only in ConditionalProbDist?
kwargs['label'] = "%s" % condition
ax.plot(freqs, *args, **kwargs)


ax.legend(loc=legend_loc)
ax.grid(True, color="silver")
ax.set_xticks(range(len(samples)), [text_type(s) for s in samples], rotation=90)
if title:
ax.set_title(title)
ax.set_xlabel("Samples")
ax.set_ylabel(ylabel)
ax = plt.gca()
if (len(conditions) != 0):
freqs = []
for condition in conditions:
if cumulative:
# freqs should be a list of list where each sub list will be a frequency of a condition
freqs.append(list(self[condition]._cumulative_frequencies(samples)))
ylabel = "Cumulative Counts"
legend_loc = 'lower right'
if percents:
freqs[-1] = [f / freqs[len(freqs) - 1] * 100 for f in freqs]
ylabel = "Cumulative Percents"
else:
freqs.append([self[condition][sample] for sample in samples])
ylabel = "Counts"
legend_loc = 'upper right'
# percents = [f * 100 for f in freqs] only in ConditionalProbDist?

i = 0
for freq in freqs:
kwargs['label'] = conditions[i] #label for each condition
i += 1
ax.plot(freq, *args, **kwargs)
ax.legend(loc=legend_loc)
ax.grid(True, color="silver")
ax.set_xticks(range(len(samples)))
ax.set_xticklabels([text_type(s) for s in samples], rotation=90)
if title:
ax.set_title(title)
ax.set_xlabel("Samples")
ax.set_ylabel(ylabel)
plt.show()

return ax
Expand Down
2 changes: 1 addition & 1 deletion nltk/test/unit/test_cfd_mutation.py
Expand Up @@ -16,7 +16,7 @@ def test_plot(self):
empty = ConditionalFreqDist()
self.assertEqual(empty.conditions(),[])
try:
empty.plot(conditions="BUG") # nonexistent keys shouldn't be added
empty.plot(conditions=["BUG"]) # nonexistent keys shouldn't be added
except:
pass
self.assertEqual(empty.conditions(),[])
Expand Down