
Commit

Merge pull request #2 from larsmans/pberkes-mldata
Improved error handling
pberkes committed May 30, 2011
2 parents 6e46e2f + d3bfaa5 commit df1fda0
Showing 16 changed files with 536 additions and 121 deletions.
10 changes: 5 additions & 5 deletions benchmarks/bench_sgd_covertype.py
@@ -17,7 +17,7 @@
Classifier train-time test-time error-rate
--------------------------------------------
Liblinear 9.4471s 0.0184s 0.2305
GNB 2.5426s 0.1725s 0.3633
GaussianNB 2.5426s 0.1725s 0.3633
SGD 0.2137s 0.0047s 0.2300
@@ -57,7 +57,7 @@

from scikits.learn.svm import LinearSVC
from scikits.learn.linear_model import SGDClassifier
from scikits.learn.naive_bayes import GNB
from scikits.learn.naive_bayes import GaussianNB
from scikits.learn import metrics

######################################################################
@@ -158,8 +158,8 @@ def benchmark(clf):
liblinear_err, liblinear_train_time, liblinear_test_time = liblinear_res

######################################################################
## Train GNB model
gnb_err, gnb_train_time, gnb_test_time = benchmark(GNB())
## Train GaussianNB model
gnb_err, gnb_train_time, gnb_test_time = benchmark(GaussianNB())

######################################################################
## Train SGD model
@@ -189,7 +189,7 @@ def print_row(clf_type, train_time, test_time, err):
print("-" * 44)
print_row("Liblinear", liblinear_train_time, liblinear_test_time,
liblinear_err)
print_row("GNB", gnb_train_time, gnb_test_time, gnb_err)
print_row("GaussianNB", gnb_train_time, gnb_test_time, gnb_err)
print_row("SGD", sgd_train_time, sgd_test_time, sgd_err)
print("")
print("")
3 changes: 2 additions & 1 deletion doc/modules/classes.rst
@@ -148,7 +148,8 @@ Naive Bayes
:toctree: generated/
:template: class.rst

naive_bayes.GNB
naive_bayes.GaussianNB
naive_bayes.MultinomialNB


Nearest Neighbors
83 changes: 70 additions & 13 deletions doc/modules/naive_bayes.rst
@@ -6,31 +6,88 @@ Naive Bayes


**Naive Bayes** algorithms are a set of supervised learning methods
based on applying Baye's theorem with strong (naive) independence
assumptions.
based on applying Bayes' theorem with the "naive" assumption of independence
between every pair of features. Given a class variable :math:`c` and a
dependent set of feature variables :math:`f_1` through :math:`f_n`, Bayes'
theorem states the following relationship:

The advantage of Naive Bayes approaches are:
.. math::
- It requires a small amount of training data to estimate the
parameters necessary for classification.
p(c \mid f_1,\dots,f_n) \propto p(c) p(f_1,\dots,f_n \mid c)
- In spite of their naive design and apparently over-simplified
assumptions, naive Bayes classifiers have worked quite well in
many complex real-world situations.
Using the naive independence assumption, this relationship simplifies to:

- The decoupling of the class conditional feature distributions
means that each distribution can be independently estimated as a
one dimensional distribution. This in turn helps to alleviate
problems stemming from the curse of dimensionality.
.. math::
p(c \mid f_1,\dots,f_n) \propto p(c) \prod_{i=1}^{n} p(f_i \mid c)
\Downarrow
\hat{c} = \arg\max_c p(c) \prod_{i=1}^{n} p(f_i \mid c),
where we have used the Maximum a Posteriori (MAP) estimator.

The different naive Bayes classifiers differ mainly by the assumptions they
make regarding the distribution of :math:`p(f_i \mid c)`:

In spite of their naive design and apparently over-simplified assumptions,
naive Bayes classifiers have worked quite well in many real-world situations,
famously document classification and spam filtering. They require a small
amount of training data to estimate the necessary parameters.

The decoupling of the class conditional feature distributions means that each
distribution can be independently estimated as a one dimensional distribution.
This in turn helps to alleviate problems stemming from the curse of
dimensionality.
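
As a rough illustration of the MAP decision rule above, here is a minimal
numpy sketch using made-up class priors and per-feature likelihood tables
(binary features, two classes; the numbers are purely illustrative)::

    import numpy as np

    log_prior = np.log(np.array([0.6, 0.4]))      # p(c) for classes 0 and 1
    # p(f_i = 1 | c): one row per class, one column per feature (made up)
    p1 = np.array([[0.8, 0.1, 0.3],
                   [0.2, 0.7, 0.6]])

    f = np.array([1, 0, 1])                       # observed binary feature vector
    log_lik = np.where(f == 1, np.log(p1), np.log(1 - p1)).sum(axis=1)
    print((log_prior + log_lik).argmax())         # predicted class, here 0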


Gaussian Naive Bayes
--------------------

:class:`GNB` implements the Gaussian Naive Bayes algorithm for classification.
:class:`GaussianNB` implements the Gaussian Naive Bayes algorithm for
classification. The likelihood of the features is assumed to be Gaussian:

.. math::
p(f_i \mid c) = \frac{1}{\sqrt{2\pi\sigma^2_c}} \exp\left(-\frac{(f_i - \mu_c)^2}{2\sigma^2_c}\right)

The parameters of the distribution, :math:`\sigma_c` and :math:`\mu_c`, are
estimated using maximum likelihood.
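
For concreteness, a small numpy sketch of this maximum-likelihood fit on toy
data (not the library implementation; the values are made up)::

    import numpy as np

    X = np.array([[1.0, 2.0], [1.2, 1.9], [3.0, 4.1], [3.2, 3.9]])
    y = np.array([0, 0, 1, 1])

    # maximum-likelihood estimates of mu_c and sigma^2_c, per class and feature
    mu = np.array([X[y == c].mean(axis=0) for c in (0, 1)])
    var = np.array([X[y == c].var(axis=0) for c in (0, 1)])

    x = np.array([1.1, 2.1])                      # new sample to classify
    log_lik = -0.5 * (np.log(2 * np.pi * var) + (x - mu) ** 2 / var)
    print(log_lik.sum(axis=1).argmax())           # predicted class, here 0 (uniform prior)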

.. topic:: Examples:

* :ref:`example_naive_bayes.py`,

Multinomial Naive Bayes
-----------------------

:class:`MultinomialNB` implements the Multinomial Naive Bayes algorithm for classification.
Multinomial Naive Bayes models the distribution of words in a document as a
multinomial. The distribution is parametrized by the vector
:math:`\overline{\theta_c} = (\theta_{c1},\ldots,\theta_{cn})`, where :math:`c`
is the class of the document, :math:`n` is the size of the vocabulary and
:math:`\theta_{ci}` is the probability of word :math:`i` appearing in a document
of class :math:`c`.
The likelihood of document :math:`d` is,

.. math::
p(d \mid \overline{\theta_c}) = \frac{ (\sum_i f_i)! }{\prod_i f_i !} \prod_i(\theta_{ci})^{f_i}
where :math:`f_{i}` is the frequency count of word :math:`i` in the document.
It can be shown that the maximum a posteriori (MAP) class estimate is,

.. math::
\hat{c} = \arg\max_c [ \log p(\overline{\theta_c}) + \sum_i f_i \log \theta_{ci} ]
The vector of parameters :math:`\overline{\theta_c}` is estimated by a smoothed
version of maximum likelihood,

.. math::
\hat{\theta}_{ci} = \frac{ N_{ci} + \alpha_i }{N_c + \alpha }
where :math:`N_{ci}` is the number of times word :math:`i` appears in the
training documents of class :math:`c` and :math:`N_{c}` is the total count of
words in the training documents of class :math:`c`. The smoothing priors
:math:`\alpha_i` and their sum :math:`\alpha` account for words not seen in the
learning samples.
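
A short numpy sketch of this smoothed estimate and the resulting decision
rule, on a made-up word-count matrix (toy data, not the library
implementation)::

    import numpy as np

    # toy word-count matrix: rows are documents, columns are vocabulary words
    X = np.array([[2, 1, 0, 0],
                  [3, 0, 1, 0],
                  [0, 0, 2, 3],
                  [0, 1, 1, 4]])
    y = np.array([0, 0, 1, 1])
    alpha = 1.0                                   # symmetric smoothing prior

    N_ci = np.array([X[y == c].sum(axis=0) for c in (0, 1)], dtype=float)
    N_c = N_ci.sum(axis=1, keepdims=True)
    theta = (N_ci + alpha) / (N_c + alpha * X.shape[1])

    log_prior = np.log(np.bincount(y) / float(len(y)))
    f = np.array([1, 1, 0, 0])                    # word counts of a new document
    print((log_prior + (f * np.log(theta)).sum(axis=1)).argmax())   # here 0
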
6 changes: 3 additions & 3 deletions examples/covariance/plot_lw_vs_oas.py
@@ -48,12 +48,12 @@

lw = LedoitWolf(store_precision=False)
lw.fit(X, assume_centered=True)
lw_mse[i,j] = lw.error(real_cov)
lw_mse[i,j] = lw.error_norm(real_cov, scaling=False)
lw_shrinkage[i,j] = lw.shrinkage_

oa = OAS(store_precision=False)
oa.fit(X, assume_centered=True)
oa_mse[i,j] = oa.error(real_cov)
oa_mse[i,j] = oa.error_norm(real_cov, scaling=False)
oa_shrinkage[i,j] = oa.shrinkage_

# plot MSE
@@ -62,7 +62,7 @@
label='Ledoit-Wolf', color='g')
pl.errorbar(n_samples_range, oa_mse.mean(1), yerr=oa_mse.std(1),
label='OAS', color='r')
pl.ylabel("MSE")
pl.ylabel("Squared error")
pl.legend(loc="upper right")
pl.title("Comparison of covariance estimators")
pl.xlim(5, 31)
13 changes: 10 additions & 3 deletions examples/document_classification_20newsgroups.py
@@ -45,6 +45,7 @@
from scikits.learn.linear_model import RidgeClassifier
from scikits.learn.svm.sparse import LinearSVC
from scikits.learn.linear_model.sparse import SGDClassifier
from scikits.learn.naive_bayes import MultinomialNB
from scikits.learn import metrics


@@ -128,9 +129,10 @@ def benchmark(clf):
score = metrics.f1_score(y_test, pred)
print "f1-score: %0.3f" % score

nnz = clf.coef_.nonzero()[0].shape[0]
print "non-zero coef: %d" % nnz
print
if hasattr(clf, 'coef_'):
nnz = clf.coef_.nonzero()[0].shape[0]
print "non-zero coef: %d" % nnz
print

if print_report:
print "classification report:"
@@ -165,3 +167,8 @@ def benchmark(clf):
print "Elastic-Net penalty"
sgd_results = benchmark(SGDClassifier(alpha=.0001, n_iter=50,
penalty="elasticnet"))

# Train sparse MultinomialNB
print 80 * '='
print "MultinomialNB penalty"
mnnb_results = benchmark(MultinomialNB(alpha=.01))
8 changes: 4 additions & 4 deletions examples/naive_bayes.py → examples/gaussian_naive_bayes.py
@@ -3,7 +3,7 @@
Gaussian Naive Bayes
============================
A classification example using Gaussian Naive Bayes (GNB).
A classification example using Gaussian Naive Bayes (GaussianNB).
"""

@@ -18,9 +18,9 @@
y = iris.target

################################################################################
# GNB
from scikits.learn.naive_bayes import GNB
gnb = GNB()
# GaussianNB
from scikits.learn.naive_bayes import GaussianNB
gnb = GaussianNB()

y_pred = gnb.fit(X, y).predict(X)

80 changes: 50 additions & 30 deletions examples/mlcomp_sparse_document_classification.py
@@ -12,7 +12,8 @@
http://mlcomp.org/datasets/379
Once downloaded unzip the arhive somewhere on your filesystem. For instance in::
Once downloaded unzip the archive somewhere on your filesystem.
For instance in::
% mkdir -p ~/data/mlcomp
% cd ~/data/mlcomp
@@ -49,6 +50,8 @@
from scikits.learn.linear_model.sparse import SGDClassifier
from scikits.learn.metrics import confusion_matrix
from scikits.learn.metrics import classification_report
from scikits.learn.naive_bayes import MultinomialNB


if 'MLCOMP_DATASETS_HOME' not in os.environ:
print "Please follow those instructions to get started:"
@@ -71,20 +74,6 @@
assert sp.issparse(X_train)
y_train = news_train.target

print "Training a linear classifier..."
parameters = {
'loss': 'hinge',
'penalty': 'l2',
'n_iter': 50,
'alpha': 0.00001,
'fit_intercept': True,
}
print "parameters:", parameters
t0 = time()
clf = SGDClassifier(**parameters).fit(X_train, y_train)
print "done in %fs" % (time() - t0)
print "Percentage of non zeros coef: %f" % (np.mean(clf.coef_ != 0) * 100)

print "Loading 20 newsgroups test set... "
news_test = load_mlcomp('20news-18828', 'test')
t0 = time()
@@ -101,22 +90,53 @@
print "done in %fs" % (time() - t0)
print "n_samples: %d, n_features: %d" % X_test.shape

print "Predicting the outcomes of the testing set"
t0 = time()
pred = clf.predict(X_test)
print "done in %fs" % (time() - t0)
################################################################################
# Benchmark classifiers
def benchmark(clf_class, params, name):
print "parameters:", params
t0 = time()
clf = clf_class(**params).fit(X_train, y_train)
print "done in %fs" % (time() - t0)

if hasattr(clf, 'coef_'):
print "Percentage of non zeros coef: %f" % (np.mean(clf.coef_ != 0) * 100)

print "Predicting the outcomes of the testing set"
t0 = time()
pred = clf.predict(X_test)
print "done in %fs" % (time() - t0)

print "Classification report on test set for classifier:"
print clf
print
print classification_report(y_test, pred, target_names=news_test.target_names)

cm = confusion_matrix(y_test, pred)
print "Confusion matrix:"
print cm

# Show confusion matrix
pl.matshow(cm)
pl.title('Confusion matrix of the %s classifier' % name)
pl.colorbar()


print "Testbenching a linear classifier..."
parameters = {
'loss': 'hinge',
'penalty': 'l2',
'n_iter': 50,
'alpha': 0.00001,
'fit_intercept': True,
}

benchmark(SGDClassifier, parameters, 'SGD')

print "Classification report on test set for classifier:"
print clf
print
print classification_report(y_test, pred, target_names=news_test.target_names)
print "Testbenching a MultinomialNB classifier..."
parameters = {
'alpha': 0.01
}

cm = confusion_matrix(y_test, pred)
print "Confusion matrix:"
print cm
benchmark(MultinomialNB, parameters, 'MultinomialNB')

# Show confusion matrix
pl.matshow(cm)
pl.title('Confusion matrix')
pl.colorbar()
pl.show()
