diff --git a/benchmarks/bench_sgd_covertype.py b/benchmarks/bench_sgd_covertype.py index e08f11820fdba..e42f0cfdf8459 100644 --- a/benchmarks/bench_sgd_covertype.py +++ b/benchmarks/bench_sgd_covertype.py @@ -17,7 +17,7 @@ Classifier train-time test-time error-rate -------------------------------------------- Liblinear 9.4471s 0.0184s 0.2305 - GNB 2.5426s 0.1725s 0.3633 + GaussianNB 2.5426s 0.1725s 0.3633 SGD 0.2137s 0.0047s 0.2300 @@ -57,7 +57,7 @@ from scikits.learn.svm import LinearSVC from scikits.learn.linear_model import SGDClassifier -from scikits.learn.naive_bayes import GNB +from scikits.learn.naive_bayes import GaussianNB from scikits.learn import metrics ###################################################################### @@ -158,8 +158,8 @@ def benchmark(clf): liblinear_err, liblinear_train_time, liblinear_test_time = liblinear_res ###################################################################### -## Train GNB model -gnb_err, gnb_train_time, gnb_test_time = benchmark(GNB()) +## Train GaussianNB model +gnb_err, gnb_train_time, gnb_test_time = benchmark(GaussianNB()) ###################################################################### ## Train SGD model @@ -189,7 +189,7 @@ def print_row(clf_type, train_time, test_time, err): print("-" * 44) print_row("Liblinear", liblinear_train_time, liblinear_test_time, liblinear_err) -print_row("GNB", gnb_train_time, gnb_test_time, gnb_err) +print_row("GaussianNB", gnb_train_time, gnb_test_time, gnb_err) print_row("SGD", sgd_train_time, sgd_test_time, sgd_err) print("") print("") diff --git a/doc/modules/classes.rst b/doc/modules/classes.rst index 3fcdc4aff60c1..65b258b871119 100644 --- a/doc/modules/classes.rst +++ b/doc/modules/classes.rst @@ -148,7 +148,8 @@ Naive Bayes :toctree: generated/ :template: class.rst - naive_bayes.GNB + naive_bayes.GaussianNB + naive_bayes.MultinomialNB Nearest Neighbors diff --git a/doc/modules/naive_bayes.rst b/doc/modules/naive_bayes.rst index 860aa445ad78b..389ecf94cbe40 100644 --- a/doc/modules/naive_bayes.rst +++ b/doc/modules/naive_bayes.rst @@ -6,31 +6,88 @@ Naive Bayes **Naive Bayes** algorithms are a set of supervised learning methods -based on applying Baye's theorem with strong (naive) independence -assumptions. +based on applying Bayes' theorem with the "naive" assumption of independence +between every pair of features. Given a class variable :math:`c` and a +dependent set of feature variables :math:`f_1` through :math:`f_n`, Bayes' +theorem states the following relationship: -The advantage of Naive Bayes approaches are: +.. math:: - - It requires a small amount of training data to estimate the - parameters necessary for classification. + p(c \mid f_1,\dots,f_n) \propto p(c) p(\mid f_1,\dots,f_n \mid c) - - In spite of their naive design and apparently over-simplified - assumptions, naive Bayes classifiers have worked quite well in - many complex real-world situations. +Using the naive assumption this relationship is simplified: - - The decoupling of the class conditional feature distributions - means that each distribution can be independently estimated as a - one dimensional distribution. This in turn helps to alleviate - problems stemming from the curse of dimensionality. +.. math:: + + p(c \mid f_1,\dots,f_n) \propto p(c) \prod_{i=1}^{n} p(f_i \mid c) + + \Downarrow + + \hat{c} = \arg\max_c p(c) \prod_{i=1}^{n} p(f_i \mid c), + +where we used the Maximum a Posteriori estimator. + +The different naive Bayes classifiers differ by the assumption on the +distribution of :math:`p(f_i \mid c)`: + +In spite of their naive design and apparently over-simplified assumptions, +naive Bayes classifiers have worked quite well in many real-world situations, +famously document classification and spam filtering. They requires a small +amount of training data to estimate the necessary parameters. + +The decoupling of the class conditional feature distributions means that each +distribution can be independently estimated as a one dimensional distribution. +This in turn helps to alleviate problems stemming from the curse of +dimensionality. Gaussian Naive Bayes -------------------- -:class:`GNB` implements the Gaussian Naive Bayes algorithm for classification. +:class:`GaussianNB` implements the Gaussian Naive Bayes algorithm for +classification. The likelihood of the features is assumed to be gaussian: + +.. math:: + p(f_i \mid c) &= \frac{1}{\sqrt{2\pi\sigma^2_c}} \exp^{-\frac{ (f_i - \mu_c)^2}{2\pi\sigma^2_c}} +The parameters of the distribution, :math:`\sigma_c` and :math:`\mu_c` are +estimated using maximum likelihood. .. topic:: Examples: * :ref:`example_naive_bayes.py`, + +Multinomial Naive Bayes +----------------------- + +:class:`MultinomialNB` implements the Multinomial Naive Bayes algorithm for classification. +Multinomial Naive Bayes models the distribution of words in a document as a +multinomial. The distribution is parametrized by the vector +:math:`\overline{\theta_c} = (\theta_{c1},\ldots,\theta_{cn})` where :math:`c` +is the class of document, :math:`n` is the size of the vocabulary and :math:`\theta_{ci}` +is the probability of word :math:`i` appearing in a document of class :math:`c`. +The likelihood of document :math:`d` is, + +.. math:: + + p(d \mid \overline{\theta_c}) &= \frac{ (\sum_i f_i)! }{\prod_i f_i !} \prod_i(\theta_{ci})^{f_i} + +where :math:`f_{i}` is the frequency count of word :math:`i`. It can be shown +that the maximum posterior probability is, + +.. math:: + + \hat{c} = \arg\max_c [ \log p(\overline{\theta_c}) + \sum_i f_i \log \theta_{ci} ] + +The vector of parameters :math:`\overline{\theta_c}` is estimated by a smoothed +version of maximum likelihood, + +.. math:: + + \hat{\theta}_{ci} = \frac{ N_{ci} + \alpha_i }{N_c + \alpha } + +where :math:`N_{ci}` is the number of times word :math:`i` appears in a document +of class :math:`c` and :math:`N_{c}` is the total count of words in a document +of class :math:`c`. The smoothness priors :math:`\alpha_i` and their sum +:math:`\alpha` account for words not seen in the learning samples. diff --git a/examples/covariance/plot_lw_vs_oas.py b/examples/covariance/plot_lw_vs_oas.py index ddce24e75ce51..1e13c185bbe32 100644 --- a/examples/covariance/plot_lw_vs_oas.py +++ b/examples/covariance/plot_lw_vs_oas.py @@ -48,12 +48,12 @@ lw = LedoitWolf(store_precision=False) lw.fit(X, assume_centered=True) - lw_mse[i,j] = lw.error(real_cov) + lw_mse[i,j] = lw.error_norm(real_cov, scaling=False) lw_shrinkage[i,j] = lw.shrinkage_ oa = OAS(store_precision=False) oa.fit(X, assume_centered=True) - oa_mse[i,j] = oa.error(real_cov) + oa_mse[i,j] = oa.error_norm(real_cov, scaling=False) oa_shrinkage[i,j] = oa.shrinkage_ # plot MSE @@ -62,7 +62,7 @@ label='Ledoit-Wolf', color='g') pl.errorbar(n_samples_range, oa_mse.mean(1), yerr=oa_mse.std(1), label='OAS', color='r') -pl.ylabel("MSE") +pl.ylabel("Squared error") pl.legend(loc="upper right") pl.title("Comparison of covariance estimators") pl.xlim(5, 31) diff --git a/examples/document_classification_20newsgroups.py b/examples/document_classification_20newsgroups.py index 228160d080c00..ac2841261123b 100644 --- a/examples/document_classification_20newsgroups.py +++ b/examples/document_classification_20newsgroups.py @@ -45,6 +45,7 @@ from scikits.learn.linear_model import RidgeClassifier from scikits.learn.svm.sparse import LinearSVC from scikits.learn.linear_model.sparse import SGDClassifier +from scikits.learn.naive_bayes import MultinomialNB from scikits.learn import metrics @@ -128,9 +129,10 @@ def benchmark(clf): score = metrics.f1_score(y_test, pred) print "f1-score: %0.3f" % score - nnz = clf.coef_.nonzero()[0].shape[0] - print "non-zero coef: %d" % nnz - print + if hasattr(clf, 'coef_'): + nnz = clf.coef_.nonzero()[0].shape[0] + print "non-zero coef: %d" % nnz + print if print_report: print "classification report:" @@ -165,3 +167,8 @@ def benchmark(clf): print "Elastic-Net penalty" sgd_results = benchmark(SGDClassifier(alpha=.0001, n_iter=50, penalty="elasticnet")) + +# Train sparse MultinomialNB +print 80 * '=' +print "MultinomialNB penalty" +mnnb_results = benchmark(MultinomialNB(alpha=.01)) diff --git a/examples/naive_bayes.py b/examples/gaussian_naive_bayes.py similarity index 77% rename from examples/naive_bayes.py rename to examples/gaussian_naive_bayes.py index 657d386bee037..7c61ae2c1bb3a 100644 --- a/examples/naive_bayes.py +++ b/examples/gaussian_naive_bayes.py @@ -3,7 +3,7 @@ Gaussian Naive Bayes ============================ -A classification example using Gaussian Naive Bayes (GNB). +A classification example using Gaussian Naive Bayes (GaussianNB). """ @@ -18,9 +18,9 @@ y = iris.target ################################################################################ -# GNB -from scikits.learn.naive_bayes import GNB -gnb = GNB() +# GaussianNB +from scikits.learn.naive_bayes import GaussianNB +gnb = GaussianNB() y_pred = gnb.fit(X, y).predict(X) diff --git a/examples/mlcomp_sparse_document_classification.py b/examples/mlcomp_sparse_document_classification.py index 71e65ab2625bc..6080e0ba88308 100644 --- a/examples/mlcomp_sparse_document_classification.py +++ b/examples/mlcomp_sparse_document_classification.py @@ -12,7 +12,8 @@ http://mlcomp.org/datasets/379 -Once downloaded unzip the arhive somewhere on your filesystem. For instance in:: +Once downloaded unzip the archive somewhere on your filesystem. +For instance in:: % mkdir -p ~/data/mlcomp % cd ~/data/mlcomp @@ -49,6 +50,8 @@ from scikits.learn.linear_model.sparse import SGDClassifier from scikits.learn.metrics import confusion_matrix from scikits.learn.metrics import classification_report +from scikits.learn.naive_bayes import MultinomialNB + if 'MLCOMP_DATASETS_HOME' not in os.environ: print "Please follow those instructions to get started:" @@ -71,20 +74,6 @@ assert sp.issparse(X_train) y_train = news_train.target -print "Training a linear classifier..." -parameters = { - 'loss': 'hinge', - 'penalty': 'l2', - 'n_iter': 50, - 'alpha': 0.00001, - 'fit_intercept': True, -} -print "parameters:", parameters -t0 = time() -clf = SGDClassifier(**parameters).fit(X_train, y_train) -print "done in %fs" % (time() - t0) -print "Percentage of non zeros coef: %f" % (np.mean(clf.coef_ != 0) * 100) - print "Loading 20 newsgroups test set... " news_test = load_mlcomp('20news-18828', 'test') t0 = time() @@ -101,22 +90,53 @@ print "done in %fs" % (time() - t0) print "n_samples: %d, n_features: %d" % X_test.shape -print "Predicting the outcomes of the testing set" -t0 = time() -pred = clf.predict(X_test) -print "done in %fs" % (time() - t0) +################################################################################ +# Benchmark classifiers +def benchmark(clf_class, params, name): + print "parameters:", params + t0 = time() + clf = clf_class(**params).fit(X_train, y_train) + print "done in %fs" % (time() - t0) + + if hasattr(clf, 'coef_'): + print "Percentage of non zeros coef: %f" % (np.mean(clf.coef_ != 0) * 100) + + print "Predicting the outcomes of the testing set" + t0 = time() + pred = clf.predict(X_test) + print "done in %fs" % (time() - t0) + + print "Classification report on test set for classifier:" + print clf + print + print classification_report(y_test, pred, target_names=news_test.target_names) + + cm = confusion_matrix(y_test, pred) + print "Confusion matrix:" + print cm + + # Show confusion matrix + pl.matshow(cm) + pl.title('Confusion matrix of the %s classifier' % name) + pl.colorbar() + + +print "Testbenching a linear classifier..." +parameters = { + 'loss': 'hinge', + 'penalty': 'l2', + 'n_iter': 50, + 'alpha': 0.00001, + 'fit_intercept': True, +} + +benchmark(SGDClassifier, parameters, 'SGD') -print "Classification report on test set for classifier:" -print clf -print -print classification_report(y_test, pred, target_names=news_test.target_names) +print "Testbenching a MultinomialNB classifier..." +parameters = { + 'alpha': 0.01 +} -cm = confusion_matrix(y_test, pred) -print "Confusion matrix:" -print cm +benchmark(MultinomialNB, parameters, 'MultinomialNB') -# Show confusion matrix -pl.matshow(cm) -pl.title('Confusion matrix') -pl.colorbar() pl.show() diff --git a/scikits/learn/covariance/empirical_covariance_.py b/scikits/learn/covariance/empirical_covariance_.py index 62807ac9690a0..80860ef56150d 100644 --- a/scikits/learn/covariance/empirical_covariance_.py +++ b/scikits/learn/covariance/empirical_covariance_.py @@ -165,7 +165,8 @@ def score(self, X_test, assume_centered=False): return res - def error(self, comp_cov, error_type='mse'): + def error_norm(self, comp_cov, norm='frobenius', scaling=True, + squared=True): """Computes the Mean Squared Error between two covariance estimators. (In the sense of the Frobenius norm) @@ -173,11 +174,18 @@ def error(self, comp_cov, error_type='mse'): ---------- comp_cov: array-like, shape = [n_features, n_features] The covariance which to be compared to. - error_type: str - The type of error. Available error types: - - 'mse': Mean Squared Error (default) = tr(A^t.A) / n_features - - 'rmse': Root Mean Squared Error = sqrt(tr(A^t.A) / n_features - - 'sse': Sum of Squared Errors = tr(A^t.A) + norm: str + The type of norm used to compute the error. Available error types: + - 'frobenius' (default): sqrt(tr(A^t.A)) + - 'spectral': sqrt(max(eigenvalues(A^t.A)) + where A is the error (comp_cov - self.covariance_). + scaling: bool + If True (default), the squared error norm is divided by n_features + If False, the squared error norm is not rescaled + squared: bool + Whether to compute the squared error norm or the error norm. + If True (default), the squared error norm is returned. + If False, the error norm is returned. Returns ------- @@ -185,15 +193,23 @@ def error(self, comp_cov, error_type='mse'): `self` and `comp_cov` covariance estimators. """ - diff = comp_cov - self.covariance_ - sse = np.sum(diff ** 2) - if error_type == 'mse': - error = sse / diff.shape[0] - elif error_type == 'rmse': - error = np.sqrt(sse / diff.shape[0]) - elif error_type == 'sse': - error = sse + # compute the error + error = comp_cov - self.covariance_ + # compute the error norm + if norm == "frobenius": + squared_norm = np.sum(error ** 2) + elif norm == "spectral": + squared_norm = np.amax(linalg.svdvals(np.dot(error.T, error))) else: - raise Exception('Error type \"%s\" not implemented yet' %error_type) + raise NotImplementedError( + "Only spectral and frobenius norms are implemented") + # optionaly scale the error norm + if scaling: + squared_norm = squared_norm / error.shape[0] + # finally get either the squared norm or the norm + if squared: + result = squared_norm + else: + result = np.sqrt(squared_norm) - return error + return result diff --git a/scikits/learn/covariance/tests/test_covariance.py b/scikits/learn/covariance/tests/test_covariance.py index 09b475607084d..28463cab42d50 100644 --- a/scikits/learn/covariance/tests/test_covariance.py +++ b/scikits/learn/covariance/tests/test_covariance.py @@ -23,22 +23,18 @@ def test_covariance(): cov = EmpiricalCovariance() cov.fit(X) assert_array_almost_equal(empirical_covariance(X), cov.covariance_, 4) - assert_almost_equal(cov.error(empirical_covariance(X)), 0) + assert_almost_equal(cov.error_norm(empirical_covariance(X)), 0) assert_almost_equal( - cov.error(empirical_covariance(X), error_type='rmse'), 0) - assert_almost_equal( - cov.error(empirical_covariance(X), error_type='sse'), 0) + cov.error_norm(empirical_covariance(X), norm='spectral'), 0) # test with n_features = 1 X_1d = X[:,0] cov = EmpiricalCovariance() cov.fit(X_1d) assert_array_almost_equal(empirical_covariance(X_1d), cov.covariance_, 4) - assert_almost_equal(cov.error(empirical_covariance(X_1d)), 0) - assert_almost_equal( - cov.error(empirical_covariance(X_1d), error_type='rmse'), 0) + assert_almost_equal(cov.error_norm(empirical_covariance(X_1d)), 0) assert_almost_equal( - cov.error(empirical_covariance(X_1d), error_type='sse'), 0) + cov.error_norm(empirical_covariance(X_1d), norm='spectral'), 0) # test integer type X_integer = np.asarray([[0,1],[1,0]]) diff --git a/scikits/learn/cross_val.py b/scikits/learn/cross_val.py index 90061ec7f289d..7bb0c5d58de85 100644 --- a/scikits/learn/cross_val.py +++ b/scikits/learn/cross_val.py @@ -12,7 +12,6 @@ from .utils.extmath import factorial, combinations from .utils.fixes import unique from .utils import check_arrays -from .utils import check_random_state from .externals.joblib import Parallel, delayed @@ -559,7 +558,7 @@ def __init__(self, n, n_bootstraps=3, n_train=0.5, n_test=None, (self.n_train, n)) if isinstance(n_test, float) and n_test >= 0.0 and n_test <= 1.0: - self.n_test = ceil(test * n) + self.n_test = ceil(n_test * n) elif isinstance(n_test, int): self.n_test = n_test elif n_test is None: @@ -652,9 +651,9 @@ def cross_val_score(estimator, X, y=None, score_func=None, cv=None, iid=False, if cv is None: indices = hasattr(X, 'tocsr') if y is not None and is_classifier(estimator): - cv = StratifiedKFold(y, k=3, indices=True) + cv = StratifiedKFold(y, k=3, indices=indices) else: - cv = KFold(n_samples, k=3, indices=True) + cv = KFold(n_samples, k=3, indices=indices) if score_func is None: assert hasattr(estimator, 'score'), ValueError( "If no score_func is specified, the estimator passed " diff --git a/scikits/learn/datasets/lfw.py b/scikits/learn/datasets/lfw.py index 4d3775c4607d1..546becd770463 100644 --- a/scikits/learn/datasets/lfw.py +++ b/scikits/learn/datasets/lfw.py @@ -44,6 +44,9 @@ from .base import Bunch +logger = logging.getLogger(__name__) + + BASE_URL = "http://vis-www.cs.umass.edu/lfw/" ARCHIVE_NAME = "lfw.tgz" FUNNELED_ARCHIVE_NAME = "lfw-funneled.tgz" @@ -89,7 +92,7 @@ def check_fetch_lfw(data_home=None, funneled=True, download_if_missing=True): if not exists(target_filepath): if download_if_missing: url = BASE_URL + target_filename - logging.warn("Downloading LFW metadata: %s", url) + logger.warn("Downloading LFW metadata: %s", url) downloader = urllib.urlopen(BASE_URL + target_filename) data = downloader.read() open(target_filepath, 'wb').write(data) @@ -100,7 +103,7 @@ def check_fetch_lfw(data_home=None, funneled=True, download_if_missing=True): if not exists(archive_path): if download_if_missing: - logging.warn("Downloading LFW data (~200MB): %s", archive_url) + logger.warn("Downloading LFW data (~200MB): %s", archive_url) downloader = urllib.urlopen(archive_url) data = downloader.read() # don't open file until download is complete @@ -109,7 +112,7 @@ def check_fetch_lfw(data_home=None, funneled=True, download_if_missing=True): raise IOError("%s is missing" % target_filepath) import tarfile - logging.info("Decompressing the data archive to %s", data_folder_path) + logger.info("Decompressing the data archive to %s", data_folder_path) tarfile.open(archive_path, "r:gz").extractall(path=lfw_home) remove(archive_path) @@ -148,7 +151,7 @@ def _load_imgs(file_paths, slice_, color, resize): # arrays for i, file_path in enumerate(file_paths): if i % 1000 == 0: - logging.info("Loading face #%05d / %05d", i + 1, n_faces) + logger.info("Loading face #%05d / %05d", i + 1, n_faces) face = np.asarray(imread(file_path)[slice_], dtype=np.float32) face /= 255.0 # scale uint8 coded colors to the [0.0, 1.0] floats if resize is not None: @@ -260,7 +263,7 @@ def fetch_lfw_people(data_home=None, funneled=True, resize=0.5, lfw_home, data_folder_path = check_fetch_lfw( data_home=data_home, funneled=funneled, download_if_missing=download_if_missing) - logging.info('Loading LFW people faces from %s', lfw_home) + logger.info('Loading LFW people faces from %s', lfw_home) # wrap the loader in a memoizing function that will return memmaped data # arrays for optimal memory usage @@ -398,7 +401,7 @@ def fetch_lfw_pairs(subset='train', data_home=None, funneled=True, resize=0.5, lfw_home, data_folder_path = check_fetch_lfw( data_home=data_home, funneled=funneled, download_if_missing=download_if_missing) - logging.info('Loading %s LFW pairs from %s', subset, lfw_home) + logger.info('Loading %s LFW pairs from %s', subset, lfw_home) # wrap the loader in a memoizing function that will return memmaped data # arrays for optimal memory usage diff --git a/scikits/learn/datasets/mldata.py b/scikits/learn/datasets/mldata.py index ef273d0689efb..ab07930513026 100755 --- a/scikits/learn/datasets/mldata.py +++ b/scikits/learn/datasets/mldata.py @@ -5,8 +5,9 @@ from scipy import io +import os from os.path import join, exists -from os import makedirs +from shutil import copyfileobj import urllib2 from .base import get_data_home, Bunch @@ -103,22 +104,27 @@ def fetch_mldata(dataname, target_name='label', data_name='data', data_home = get_data_home(data_home=data_home) data_home = join(data_home, 'mldata') if not exists(data_home): - makedirs(data_home) + os.makedirs(data_home) matlab_name = dataname + '.mat' filename = join(data_home, matlab_name) # if the file does not exist, download it if not exists(filename): - urlname = MLDATA_BASE_URL % (dataname) + urlname = MLDATA_BASE_URL % urllib2.quote(dataname) try: mldata_url = urllib2.urlopen(urlname) - except urllib2.URLError: - msg = "Dataset '%s' not found on mldata.org." % dataname - raise IOError(msg) + except urllib2.HTTPError, e: + if e.code == 404: + e.msg = "Dataset '%s' not found on mldata.org." % dataname + raise # store Matlab file - with open(filename, 'w+b') as matlab_file: - matlab_file.write(mldata_url.read()) + try: + with open(filename, 'w+b') as matlab_file: + copyfileobj(mldata_url, matlab_file) + except: + os.remove(filename) + raise mldata_url.close() # load dataset matlab file diff --git a/scikits/learn/datasets/twenty_newsgroups.py b/scikits/learn/datasets/twenty_newsgroups.py index 7dd3de4dacb91..7b92383df1743 100644 --- a/scikits/learn/datasets/twenty_newsgroups.py +++ b/scikits/learn/datasets/twenty_newsgroups.py @@ -44,6 +44,9 @@ from .base import load_filenames +logger = logging.getLogger(__name__) + + URL = ("http://people.csail.mit.edu/jrennie/" "20Newsgroups/20news-bydate.tar.gz") ARCHIVE_NAME = "20news-bydate.tar.gz" @@ -98,13 +101,13 @@ def fetch_20newsgroups(data_home=None, subset='train', categories=None, if not os.path.exists(archive_path): if download_if_missing: - logging.warn("Downloading dataset from %s (14 MB)", URL) + logger.warn("Downloading dataset from %s (14 MB)", URL) opener = urllib.urlopen(URL) open(archive_path, 'wb').write(opener.read()) else: raise IOError("%s is missing" % archive_path) - logging.info("Decompressing %s", archive_path) + logger.info("Decompressing %s", archive_path) tarfile.open(archive_path, "r:gz").extractall(path=twenty_home) os.remove(archive_path) diff --git a/scikits/learn/naive_bayes.py b/scikits/learn/naive_bayes.py index d03d47bf2e64a..78a55ff3916ca 100644 --- a/scikits/learn/naive_bayes.py +++ b/scikits/learn/naive_bayes.py @@ -1,18 +1,34 @@ -""" Naives Bayes classifiers. +""" +Naive Bayes models +================== + +Naive Bayes algorithms are a set of supervised learning methods based on +applying Bayes' theorem with strong (naive) feature independence assumptions. + +See http://scikit-learn.sourceforge.net/modules/naive_bayes.html for +complete documentation. """ # Author: Vincent Michel # Minor fixes by Fabian Pedregosa +# MultinomialNB classifier by: +# Amit Aides +# Yehuda Finkelstein +# Lars Buitinck +# (parts based on earlier work by Mathieu Blondel) # # License: BSD Style. -import numpy as np from .base import BaseEstimator, ClassifierMixin +from .utils import safe_asanyarray +from .utils.extmath import safe_sparse_dot +import numpy as np +from scipy.sparse import issparse -class GNB(BaseEstimator, ClassifierMixin): +class GaussianNB(BaseEstimator, ClassifierMixin): """ - Gaussian Naive Bayes (GNB) + Gaussian Naive Bayes (GaussianNB) Parameters ---------- @@ -25,7 +41,7 @@ class GNB(BaseEstimator, ClassifierMixin): Attributes ---------- - proba_y : array, shape = [n_classes] + class_prior : array, shape = [n_classes] probability of each class. theta : array, shape [n_classes * n_features] @@ -34,7 +50,6 @@ class GNB(BaseEstimator, ClassifierMixin): sigma : array, shape [n_classes * n_features] variance of each feature for the different class - Methods ------- fit(X, y) : self @@ -55,10 +70,10 @@ class GNB(BaseEstimator, ClassifierMixin): >>> import numpy as np >>> X = np.array([[-1, -1], [-2, -1], [-3, -2], [1, 1], [2, 1], [3, 2]]) >>> Y = np.array([1, 1, 1, 2, 2, 2]) - >>> from scikits.learn.naive_bayes import GNB - >>> clf = GNB() + >>> from scikits.learn.naive_bayes import GaussianNB + >>> clf = GaussianNB() >>> clf.fit(X, Y) - GNB() + GaussianNB() >>> print clf.predict([[-0.8, -1]]) [1] @@ -79,7 +94,6 @@ def fit(self, X, y): y : array-like, shape = [n_samples] Target values. - Returns ------- self : object @@ -91,15 +105,15 @@ def fit(self, X, y): theta = [] sigma = [] - proba_y = [] + class_prior = [] unique_y = np.unique(y) for yi in unique_y: theta.append(np.mean(X[y == yi, :], 0)) sigma.append(np.var(X[y == yi, :], 0)) - proba_y.append(np.float(np.sum(y == yi)) / np.size(y)) + class_prior.append(np.float(np.sum(y == yi)) / np.size(y)) self.theta = np.array(theta) self.sigma = np.array(sigma) - self.proba_y = np.array(proba_y) + self.class_prior = np.array(class_prior) self.unique_y = unique_y return self @@ -121,8 +135,8 @@ def predict(self, X): def _joint_log_likelihood(self, X): joint_log_likelihood = [] - for i in range(np.size(self.unique_y)): - jointi = np.log(self.proba_y[i]) + for i in xrange(np.size(self.unique_y)): + jointi = np.log(self.class_prior[i]) n_ij = - 0.5 * np.sum(np.log(np.pi * self.sigma[i, :])) n_ij -= 0.5 * np.sum(((X - self.theta[i, :]) ** 2) / \ (self.sigma[i, :]), 1) @@ -176,3 +190,213 @@ def predict_log_proba(self, X): aB[sup] = np.exp(logaB[sup]) log_proba -= np.log(np.sum(aB, axis=1))[:, np.newaxis] + B return log_proba + + +def asanyarray_or_csr(X): + if issparse(X): + return X.tocsr(), True + else: + return np.asanyarray(X), False + + +def atleast2d_or_csr(X): + if issparse(X): + return X.tocsr() + else: + return np.atleast_2d(X) + + +class MultinomialNB(BaseEstimator, ClassifierMixin): + """ + Naive Bayes classifier for multinomial models + + The multinomial Naive Bayes classifier is suitable for classification with + discrete features (e.g., word counts for text classification). The + multinomial distribution normally requires integer feature counts. However, + in practice, fractional counts such as tf-idf may also work. + + This class is designed to handle both dense and sparse data; it will enter + "sparse mode" if its training matrix (X) is a sparse matrix. + + Parameters + ---------- + alpha: float, optional (default=1.0) + Additive (Laplace/Lidstone) smoothing parameter + (0 for no smoothing). + fit_prior: boolean + Whether to learn class prior probabilities or not. + + Methods + ------- + fit(X, y) : self + Fit the model + + predict(X) : array + Predict using the model. + + predict_proba(X) : array + Predict the probability of each class using the model. + + predict_log_proba(X) : array + Predict the log probability of each class using the model. + + Attributes + ---------- + class_log_prior_, intercept_ : array, shape = [n_classes] + Log probability of each class (smoothed). + + feature_log_prob_, coef_ : array, shape = [n_classes, n_features] + Empirical log probability of features given a class, P(x_i|y). + + (class_log_prior_ and feature_log_prob_ are properties referring to + intercept_ and feature_log_prob_, respectively.) + + Examples + -------- + >>> import numpy as np + >>> X = np.random.randint(5, size=(6, 100)) + >>> Y = np.array([1, 2, 3, 4, 5, 6]) + >>> from scikits.learn.naive_bayes import MultinomialNB + >>> clf = MultinomialNB() + >>> clf.fit(X, Y) + MultinomialNB(alpha=1.0, fit_prior=True) + >>> print clf.predict(X[2]) + [3] + + References + ---------- + For the rationale behind the names coef_ and intercept_, i.e. naive Bayes + as a linear classifier, see J. Rennie et al. (2003), Tackling the poor + assumptions of naive Bayes text classifiers, Proc. ICML. + """ + + def __init__(self, alpha=1.0, fit_prior=True): + self.alpha = alpha + self.fit_prior = fit_prior + + def fit(self, X, y, class_prior=None): + """Fit Multinomial Naive Bayes according to X, y + + Parameters + ---------- + X : array-like, shape = [n_samples, n_features] + Training vectors, where n_samples is the number of samples and + n_features is the number of features. X may be a sparse matrix. + + y : array-like, shape = [n_samples] + Target values. + + class_prior : array, shape [n_classes] + Prior probability per class. + + Returns + ------- + self : object + Returns self. + """ + X, self.sparse = asanyarray_or_csr(X) + y = safe_asanyarray(y) + + self.unique_y = np.unique(y) + n_classes = self.unique_y.size + + self.intercept_ = None + if not self.fit_prior: + self.intercept_ = np.ones(n_classes) / n_classes + if class_prior: + assert len(class_prior) == n_classes, \ + 'Number of priors must match number of classs' + self.intercept_ = np.array(class_prior) + + # N_c is the count of all features in all samples of class c. + # N_c_i is the a count of feature i in all samples of class c. + N_c_i_temp = [] + if self.intercept_ is None: + class_prior = [] + + for yi in self.unique_y: + if self.sparse: + row_ind = np.nonzero(y == yi)[0] + N_c_i_temp.append(np.array(X[row_ind, :].sum(axis=0)).ravel()) + else: + N_c_i_temp.append(np.sum(X[y == yi, :], 0)) + if self.intercept_ is None: + class_prior.append(np.float(np.sum(y == yi)) / y.size) + + N_c_i = np.array(N_c_i_temp) + N_c = np.sum(N_c_i, axis=1) + + # Estimate (and smooth) the parameters of the distribution + # + self.coef_ = (np.log(N_c_i + self.alpha) + - np.log(N_c.reshape(-1, 1) + + self.alpha * X.shape[1])) + if self.intercept_ is None: + self.intercept_ = np.log(class_prior) + + return self + + class_log_prior_ = property(lambda self: self.intercept__) + feature_log_prob_ = property(lambda self: self.coef_) + + def predict(self, X): + """ + Perform classification on an array of test vectors X. + + Parameters + ---------- + X : array-like, shape = [n_samples, n_features] + + Returns + ------- + C : array, shape = [n_samples] + """ + joint_log_likelihood = self._joint_log_likelihood(X) + y_pred = self.unique_y[np.argmax(joint_log_likelihood, axis=0)] + + return y_pred + + def _joint_log_likelihood(self, X): + """Calculate the posterior log probability of the samples X""" + + X = atleast2d_or_csr(X) + + jll = safe_sparse_dot(self.coef_, X.T) + return jll + np.atleast_2d(self.intercept_).T + + def predict_proba(self, X): + """ + Return probability estimates for the test vector X. + + Parameters + ---------- + X : array-like, shape = [n_samples, n_features] + + Returns + ------- + C : array-like, shape = [n_samples, n_classes] + Returns the probability of the sample for each class in + the model, where classes are ordered by arithmetical + order. + """ + return np.exp(self.predict_log_proba(X)) + + def predict_log_proba(self, X): + """ + Return log-probability estimates for the test vector X. + + Parameters + ---------- + X : array-like, shape = [n_samples, n_features] + + Returns + ------- + C : array-like, shape = [n_samples, n_classes] + Returns the log-probability of the sample for each class + in the model, where classes are ordered by arithmetical + order. + """ + jll = self._joint_log_likelihood(X) + # normalize by P(x) = P(f_1, ..., f_n) + log_prob_x = np.logaddexp.reduce(jll[:, np.newaxis]) + return jll - log_prob_x diff --git a/scikits/learn/svm/src/libsvm/svm.cpp b/scikits/learn/svm/src/libsvm/svm.cpp index b65de7203cf9e..7fb2109d1bb8a 100644 --- a/scikits/learn/svm/src/libsvm/svm.cpp +++ b/scikits/learn/svm/src/libsvm/svm.cpp @@ -3038,7 +3038,7 @@ const char *PREFIX(check_parameter)(const PREFIX(problem) *prob, const svm_param { max_nr_class *= 2; label = (int *)realloc(label,max_nr_class*sizeof(int)); - count = (double *)realloc(count,max_nr_class*sizeof(int)); + count = (double *)realloc(count,max_nr_class*sizeof(double)); } label[nr_class] = this_label; diff --git a/scikits/learn/tests/test_naive_bayes.py b/scikits/learn/tests/test_naive_bayes.py index 6f6fe34a722c1..d22c7b70ca019 100644 --- a/scikits/learn/tests/test_naive_bayes.py +++ b/scikits/learn/tests/test_naive_bayes.py @@ -1,4 +1,7 @@ +import cPickle as pickle +from cStringIO import StringIO import numpy as np +import scipy.sparse from numpy.testing import assert_array_equal, assert_array_almost_equal from .. import naive_bayes @@ -12,14 +15,94 @@ def test_gnb(): """ Gaussian Naive Bayes classification. - This checks that GNB implements fit and predict and returns + This checks that GaussianNB implements fit and predict and returns correct values for a simple toy dataset. """ - clf = naive_bayes.GNB() + clf = naive_bayes.GaussianNB() y_pred = clf.fit(X, y).predict(X) assert_array_equal(y_pred, y) y_pred_proba = clf.predict_proba(X) y_pred_log_proba = clf.predict_log_proba(X) assert_array_almost_equal(np.log(y_pred_proba), y_pred_log_proba, 8) + + +# Data is 6 random points in an 100 dimensional space classified to +# three classes. +X2 = np.random.randint(5, size=(6, 100)) +y2 = np.array([1, 1, 2, 2, 3, 3]) + + +def test_mnnb(): + """ + Multinomial Naive Bayes classification. + + This checks that MultinomialNB implements fit and predict and returns + correct values for a simple toy dataset. + """ + + # + # Check the ability to predict the learning set. + # + clf = naive_bayes.MultinomialNB() + y_pred = clf.fit(X2, y2).predict(X2) + + assert_array_equal(y_pred, y2) + + # + # Verify that np.log(clf.predict_proba(X)) gives the same results as + # clf.predict_log_proba(X) + # + y_pred_proba = clf.predict_proba(X2) + y_pred_log_proba = clf.predict_log_proba(X2) + assert_array_almost_equal(np.log(y_pred_proba), y_pred_log_proba, 8) + + +def test_sparse_mnnb(): + """ + Multinomial Naive Bayes classification for sparse data. + + This checks that MultinomialNB implements fit and predict and returns + correct values for a simple toy dataset. + """ + + X2S = scipy.sparse.csr_matrix(X2) + + # + # Check the ability to predict the learning set. + # + clf = naive_bayes.MultinomialNB() + y_pred = clf.fit(X2S, y2).predict(X2S) + + assert_array_equal(y_pred, y2) + + # + # Verify that np.log(clf.predict_proba(X)) gives the same results as + # clf.predict_log_proba(X) + # + y_pred_proba = clf.predict_proba(X2S) + y_pred_log_proba = clf.predict_log_proba(X2S) + assert_array_almost_equal(np.log(y_pred_proba), y_pred_log_proba, 8) + + +def test_mnnb_pickle(): + '''Test picklability of multinomial NB''' + + clf = naive_bayes.MultinomialNB(alpha=2, fit_prior=False).fit(X, y) + y_pred = clf.predict(X) + + store = StringIO() + pickle.dump(clf, store) + clf = pickle.load(StringIO(store.getvalue())) + + assert_array_equal(y_pred, clf.predict(X)) + + +def test_mnnb_predict_proba(): + '''Test multinomial NB's probability scores''' + + clf = naive_bayes.MultinomialNB().fit([[0,1], [0,1], [1,0]], [0,0,1]) + assert clf.predict([0,1]) == 0 + assert np.sum(clf.predict_proba([0,1])) == 1 + assert np.sum(clf.predict_proba([1,0])) == 1