From 29ab7332e4787db671e24d788ff8f1cc8c1af6af Mon Sep 17 00:00:00 2001 From: Gael varoquaux Date: Wed, 1 Sep 2010 22:20:04 +0200 Subject: [PATCH 1/8] BUG: Fix warnings module not imported in coordinate_descent. Thanks to Pietro Berkes --- scikits/learn/glm/coordinate_descent.py | 1 + 1 file changed, 1 insertion(+) diff --git a/scikits/learn/glm/coordinate_descent.py b/scikits/learn/glm/coordinate_descent.py index 7a8978839c66d..6a8a93c3f028b 100644 --- a/scikits/learn/glm/coordinate_descent.py +++ b/scikits/learn/glm/coordinate_descent.py @@ -4,6 +4,7 @@ # # License: BSD Style. +import warnings import numpy as np from .base import LinearModel From 2d942d17323de2fcb62272b4ed14b8e1a1456359 Mon Sep 17 00:00:00 2001 From: Alexandre Gramfort Date: Wed, 1 Sep 2010 15:02:57 +0200 Subject: [PATCH 2/8] BUG : fix in Lars at the end of path + more tests (not working yet) --- scikits/learn/glm/lars.py | 7 +++--- scikits/learn/glm/tests/test_lars.py | 33 ++++++++++++++++++++-------- 2 files changed, 28 insertions(+), 12 deletions(-) diff --git a/scikits/learn/glm/lars.py b/scikits/learn/glm/lars.py index 5064ed0365f6f..7fc78c89f0894 100644 --- a/scikits/learn/glm/lars.py +++ b/scikits/learn/glm/lars.py @@ -100,10 +100,11 @@ def lars_path(X, y, max_iter=None, alpha_min=0, method="lar", precompute=True): Cov = np.ma.dot (Xna, res) imax = np.ma.argmax (np.ma.abs(Cov), fill_value=0.) #rename - Cov_max = (Cov [imax]) + Cov_max = Cov.data[imax] alpha = np.abs(Cov_max) #sum (np.abs(beta[n_iter])) - alphas [n_iter] = np.max(np.abs(np.dot(Xt, res))) #sum (np.abs(beta[n_iter])) + alphas [n_iter] = alpha + if (n_iter >= max_iter or n_pred >= max_pred ): break @@ -185,7 +186,7 @@ def lars_path(X, y, max_iter=None, alpha_min=0, method="lar", precompute=True): n_pred -= 1 drop_idx = active.pop (idx) # please please please remove this masked arrays pain from me - Xna[drop_idx] = Xna.data[drop_idx].copy() + Xna[drop_idx] = Xna.data[drop_idx] print 'dropped ', idx, ' at ', n_iter, ' iteration' Xa = Xt[active] # duplicate L[:n_pred, :n_pred] = linalg.cholesky(np.dot(Xa, Xa.T), lower=True) diff --git a/scikits/learn/glm/tests/test_lars.py b/scikits/learn/glm/tests/test_lars.py index f9ba7159538cb..30691daf1ed4f 100644 --- a/scikits/learn/glm/tests/test_lars.py +++ b/scikits/learn/glm/tests/test_lars.py @@ -41,20 +41,35 @@ def test_1(): ocur = len(cov[ C - eps < abs(cov)]) assert ocur == i + 1 +def test_lasso_gives_lstsq_solution(): + """ + Test that LARS Lasso gives least square solution at the end + of the path + """ + + alphas_, active, coef_path_ = lars_path(X, y, method="lasso") + coef_lstsq = np.linalg.lstsq(X, y)[0] + assert_array_almost_equal(coef_lstsq , coef_path_[:,-1]) + def test_lasso_lars_vs_lasso_cd(): """ Test that LassoLars and Lasso using coordinate descent give the same results """ - lasso_lars = LassoLARS(alpha=0.1) - lasso_lars.fit(X, y) - - # make sure results are the same than with Lasso Coordinate descent - lasso = Lasso(alpha=0.1) - lasso.fit(X, y, maxit=3000, tol=1e-10) - - error = np.linalg.norm(lasso_lars.coef_ - lasso.coef_) - assert error < 1e-5 + lasso_lars = LassoLARS(alpha=0.1, normalize=False) + lasso = Lasso(alpha=0.1, fit_intercept=False) + for alpha in [0.1, 0.01, 0.004]: + lasso_lars.alpha = alpha + lasso_lars.fit(X, y) + lasso.alpha = alpha + lasso.fit(X, y, maxit=5000, tol=1e-13) + + # make sure results are the same than with Lasso Coordinate descent + error = np.linalg.norm(lasso_lars.coef_ - lasso.coef_) + print lasso.coef_ + print lasso_lars.coef_ + print 'Error : ', error + 
assert error < 1e-5 From cfc090ba7a4d54eeb4c323b88d0b8de273ff9aff Mon Sep 17 00:00:00 2001 From: Alexandre Gramfort Date: Wed, 1 Sep 2010 18:16:09 +0200 Subject: [PATCH 3/8] ENH : using explained variance as score for regression problems --- scikits/learn/base.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/scikits/learn/base.py b/scikits/learn/base.py index 990f970cb872a..0e4209dbe7103 100644 --- a/scikits/learn/base.py +++ b/scikits/learn/base.py @@ -9,8 +9,6 @@ import numpy as np -from .metrics import zero_one, mean_square_error - ################################################################################ class BaseEstimator(object): """ Base class for all estimators in the scikit learn @@ -141,5 +139,6 @@ def score(self, X, y): ------- z : float """ - return - mean_square_error(self.predict(X), y) + return 1 - np.linalg.norm(y - self.predict(X))**2 \ + / np.linalg.norm(y)**2 From b44e1c877d29ae43b26f449fff2f2a86ccf36955 Mon Sep 17 00:00:00 2001 From: Alexandre Gramfort Date: Thu, 2 Sep 2010 09:20:06 +0200 Subject: [PATCH 4/8] ENH: on the use of explained_variance in mixin regressor class --- scikits/learn/base.py | 6 +++--- scikits/learn/metrics.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/scikits/learn/base.py b/scikits/learn/base.py index 0e4209dbe7103..e8fe759a27034 100644 --- a/scikits/learn/base.py +++ b/scikits/learn/base.py @@ -9,6 +9,8 @@ import numpy as np +from .metrics import explained_variance + ################################################################################ class BaseEstimator(object): """ Base class for all estimators in the scikit learn @@ -139,6 +141,4 @@ def score(self, X, y): ------- z : float """ - return 1 - np.linalg.norm(y - self.predict(X))**2 \ - / np.linalg.norm(y)**2 - + return explained_variance(y, self.predict(X)) diff --git a/scikits/learn/metrics.py b/scikits/learn/metrics.py index c5334c9c5da77..5299db4968500 100644 --- a/scikits/learn/metrics.py +++ b/scikits/learn/metrics.py @@ -179,5 +179,5 @@ def explained_variance(y_pred, y_true): """Explained variance returns the explained variance """ - return (np.var(y_true) - np.var(y_true - y_pred)) / np.var(y_true) + return 1 - np.var(y_true - y_pred) / np.var(y_true) From 6e708abb61f8221147cb217cd8f6ead3b29d16db Mon Sep 17 00:00:00 2001 From: Fabian Pedregosa Date: Thu, 2 Sep 2010 14:18:30 +0200 Subject: [PATCH 5/8] DOC: more work on svm module. --- doc/modules/classes.rst | 33 +++++-------- doc/modules/svm.rst | 104 ++++++++++++++++++++++++++++++++-------- 2 files changed, 96 insertions(+), 41 deletions(-) diff --git a/doc/modules/classes.rst b/doc/modules/classes.rst index 2652ef1653b74..a07ddb4cc25ff 100644 --- a/doc/modules/classes.rst +++ b/doc/modules/classes.rst @@ -18,6 +18,18 @@ Support Vector Machines svm.NuSVR svm.OneClass +.. _sparse_svm_class_reference: + + sparse.svm.SVC + sparse.svm.LinearSVC + sparse.svm.NuSVC + sparse.svm.SVR + sparse.svm.NuSVR + sparse.svm.OneClass + +For sparse data +----------------- + Generalized Linear Models ========================= @@ -29,24 +41,3 @@ Generalized Linear Models glm.Ridge glm.Lasso - - -For sparse data -=============== - -Support Vector Machines ------------------------ - -.. currentmodule:: scikits.learn.sparse - -.. 
autosummary:: - :toctree: generated/ - :template: class.rst - - svm.SVC - svm.LinearSVC - svm.NuSVC - svm.SVR - svm.NuSVR - svm.OneClassSVM - diff --git a/doc/modules/svm.rst b/doc/modules/svm.rst index 9f2cb4924d1e4..7c01997121d2c 100644 --- a/doc/modules/svm.rst +++ b/doc/modules/svm.rst @@ -125,10 +125,6 @@ implement this is called :class:`OneClassSVM` In this case, as it is a type of unsupervised learning, the fit method will only take as input an array X, as there are no class labels. -.. note:: - - For a complete example on one class SVM see - :ref:`example_svm_plot_oneclass.py` example. .. figure:: ../auto_examples/svm/images/plot_oneclass.png :target: ../auto_examples/svm/plot_oneclass.html @@ -140,8 +136,6 @@ Examples -------- :ref:`example_svm_plot_oneclass.py` -See :ref:`svm_examples` for a complete list of examples. - @@ -151,15 +145,23 @@ Support Vector machines for sparse data ======================================= There is support for sparse data given in any matrix in a format -supported by scipy.sparse. See module scikits.learn.sparse.svm. +supported by scipy.sparse. Classes have the same name, just prefixed +by the `sparse` namespace, and take the same arguments, with the +exception of training and test data, which is expected to be in a +matrix format defined in scipy.sparse. + +For maximum efficiency, use the CSR matrix format as defined in +`scipy.sparse.csr_matrix +`_. -:class:`SVC` +See the complete listing of classes in +:ref:`sparse_svm_class_reference`. Tips on Practical Use ===================== - * Support Vector Machine algorithms are not scale-invariant, so it + * Support Vector Machine algorithms are not scale invariant, so it is highly recommended to scale your data. For example, scale each attribute on the input vector X to [0,1] or [-1,+1], or standardize it to have mean 0 and variance 1. Note that the *same* scaling @@ -168,8 +170,8 @@ Tips on Practical Use `_ for some examples on scaling. - * nu in NuSVC/OneClassSVM/NuSVR approximates the fraction of - training errors and support vectors. + * Parameter nu in NuSVC/OneClassSVM/NuSVR approximates the fraction + of training errors and support vectors. * If data for classification are unbalanced (e.g. many positive and few negative), try different penalty parameters C. @@ -183,13 +185,26 @@ Kernel functions ================ The *kernel function* can be any of the following: + * linear: :math:`<x, x'>`. + * polynomial: :math:`(\gamma <x, x'> + r)^d`. d is specified by keyword `degree`. + * rbf (:math:`exp(-\gamma |x-x'|^2), \gamma > 0`). :math:`\gamma` is specified by keyword gamma. + * sigmoid (:math:`tanh(<x, x'> + r)`). +Different kernels are specified by keyword kernel at initialization:: + + >>> linear_svc = svm.SVC(kernel='linear') + >>> linear_svc.kernel + 'linear' + >>> rbf_svc = svm.SVC (kernel='rbf') + >>> rbf_svc.kernel + 'rbf' + Custom Kernels -------------- @@ -230,7 +245,7 @@ instance that will use that kernel:: Passing the gram matrix ~~~~~~~~~~~~~~~~~~~~~~~ -set kernel='precomputed' and pass the gram matrix instead of X in the +Set kernel='precomputed' and pass the gram matrix instead of X in the fit method. @@ -258,21 +273,70 @@ generalization error of the classifier. :align: center :scale: 50 - - SVC --- Given training vectors :math:`x_i \in R^n`, i=1,..., l, in two -classes, and a vector :math:`y \in R^l` +classes, and a vector :math:`y \in R^l` such that :math:`y_i \in {1, +-1}`, SVC solves the following primal problem: + + +..
math:: + + \min_{w, b, \zeta} \frac{1}{2} w^T w + C \sum_{i=1}^{l} \zeta_i + + + + \textrm {subject to } & y_i (w^T \phi (x_i) + b) \geq 1 - \zeta_i,\\ + & \zeta_i \geq 0, i=1, ..., l + +Its dual is + +.. math:: + + \min_{\alpha} \frac{1}{2} \alpha^T Q \alpha - e^T \alpha + + + \textrm {subject to } & y^T \alpha = 0\\ + & 0 \leq \alpha_i \leq C, i=1, ..., l + +where :math:`e` is the vector of all ones, C > 0 is the upper bound, Q +is an l by l positive semidefinite matrix, :math:`Q_{ij} \equiv K(x_i, +x_j)` and :math:`\phi (x_i)^T \phi (x_j)` is the kernel. Here training +vectors are mapped into a higher (maybe infinite) dimensional space by +the function :math:`\phi`. + + +The decision function is: + +.. math:: sgn(\sum_{i=1}^l y_i \alpha_i K(x_i, x) + \rho) + + +.. TODO: multiclass case? + +These parameters can be accessed through the members support\_ and intercept\_: + + - Member support\_ holds the product :math:`y^T \alpha` + + - Member intercept\_ of the classifier holds :math:`-\rho` + +References +~~~~~~~~~~ + +This algorithm is implemented as described in `Automatic Capacity +Tuning of Very Large VC-dimension Classifiers +`_ +and `Support-vector networks +`_ -In SVC The decision function in this case will be: -.. math:: sgn(\sum_{i=1}^l \alpha_i K(x_i, x) + \rho) -where :math:`\alpha, \rho` can be accessed through fields support\_ and -intercept\_ of the classifier instance, respectevely. +NuSVC +----- - - *penalty*. C > 0 is the penalty parameter of the error term. +We introduce a new parameter :math:`\nu` which controls the number of +support vectors and training errors. The parameter :math:`\nu \in (0, +1]` is an upper bound on the fraction of training errors and a lower +bound of the fraction of support vectors. Implementation details From 7d3a01e69e6667c585395fd585cebf644668ff35 Mon Sep 17 00:00:00 2001 From: Fabian Pedregosa Date: Thu, 2 Sep 2010 15:54:56 +0200 Subject: [PATCH 6/8] Fix in LARS: manually specify the number of iterations for the full path. Later we should implement a way to just give max_features and it will compute the full path, but at least we now have a LassoLARS that gives the same results as the coordinate descent version. --- scikits/learn/glm/lars.py | 12 +++++--- scikits/learn/glm/tests/test_lars.py | 42 +++++++++++++++------------- 2 files changed, 30 insertions(+), 24 deletions(-) diff --git a/scikits/learn/glm/lars.py b/scikits/learn/glm/lars.py index 7fc78c89f0894..719db68be0f5c 100644 --- a/scikits/learn/glm/lars.py +++ b/scikits/learn/glm/lars.py @@ -98,9 +98,10 @@ def lars_path(X, y, max_iter=None, alpha_min=0, method="lar", precompute=True): # Calculate covariance matrix and get maximum res = y - np.dot (X, beta[n_iter]) # there are better ways Cov = np.ma.dot (Xna, res) + print Cov - imax = np.ma.argmax (np.ma.abs(Cov), fill_value=0.) #rename - Cov_max = Cov.data[imax] + imax = np.ma.argmax (np.ma.abs(Cov)) #rename + Cov_max = Cov.data [imax] alpha = np.abs(Cov_max) #sum (np.abs(beta[n_iter])) alphas [n_iter] = alpha @@ -316,13 +317,14 @@ class LassoLARS (LinearModel): an alternative optimization strategy called 'coordinate descent.'
""" - def __init__(self, alpha=1.0, normalize=True): + def __init__(self, alpha=1.0, max_iter=None, normalize=True): """ XXX : add doc # will only normalize non-zero columns """ self.alpha = alpha self.normalize = normalize self.coef_ = None + self.max_iter = max_iter def fit (self, X, y, **params): """ XXX : add doc @@ -346,7 +348,9 @@ def fit (self, X, y, **params): method = 'lasso' alphas_, active, coef_path_ = lars_path(X, y, - alpha_min=alpha, method=method) + alpha_min=alpha, method=method, + max_iter=self.max_iter) + self.coef_ = coef_path_[:,-1] return self diff --git a/scikits/learn/glm/tests/test_lars.py b/scikits/learn/glm/tests/test_lars.py index 30691daf1ed4f..c7c5772804d9d 100644 --- a/scikits/learn/glm/tests/test_lars.py +++ b/scikits/learn/glm/tests/test_lars.py @@ -4,42 +4,46 @@ from nose.tools import assert_equal, assert_true -from ..lars import lars_path, LeastAngleRegression, LassoLARS +from ..lars import lars_path, LeastAngleRegression, LassoLARS, LARS from ..coordinate_descent import Lasso from scikits.learn import datasets - - n, m = 10, 10 np.random.seed (0) diabetes = datasets.load_diabetes() X, y = diabetes.data, diabetes.target -#normalize data -_xmean = X.mean(0) -_ymean = y.mean(0) -X = X - _xmean -y = y - _ymean -_norms = np.apply_along_axis (np.linalg.norm, 0, X) -nonzeros = np.flatnonzero(_norms) -X[:, nonzeros] /= _norms[nonzeros] - - def test_1(): """ Principle of LARS is to keep covariances tied and decreasing """ - - alphas_, active, coef_path_ = lars_path(X, y, 6, method="lar") + max_pred = 10 + alphas_, active, coef_path_ = lars_path(diabetes.data, diabetes.target, max_pred, method="lar") for (i, coef_) in enumerate(coef_path_.T): res = y - np.dot(X, coef_) cov = np.dot(X.T, res) C = np.max(abs(cov)) eps = 1e-3 ocur = len(cov[ C - eps < abs(cov)]) - assert ocur == i + 1 + if i < max_pred: + assert ocur == i+1 + else: + # no more than max_pred variables can go into the active set + assert ocur == max_pred + + +def test_lars_lstsq(): + """ + Test that LARS gives least square solution at the end + of the path + """ + # test that it arrives to a least squares solution + alphas_, active, coef_path_ = lars_path(diabetes.data, diabetes.target, method="lar") + coef_lstsq = np.linalg.lstsq(X, y)[0] + assert_array_almost_equal(coef_path_.T[-1], coef_lstsq) + def test_lasso_gives_lstsq_solution(): """ @@ -47,7 +51,7 @@ def test_lasso_gives_lstsq_solution(): of the path """ - alphas_, active, coef_path_ = lars_path(X, y, method="lasso") + alphas_, active, coef_path_ = lars_path(X, y, max_iter=12, method="lasso") coef_lstsq = np.linalg.lstsq(X, y)[0] assert_array_almost_equal(coef_lstsq , coef_path_[:,-1]) @@ -60,14 +64,12 @@ def test_lasso_lars_vs_lasso_cd(): lasso = Lasso(alpha=0.1, fit_intercept=False) for alpha in [0.1, 0.01, 0.004]: lasso_lars.alpha = alpha - lasso_lars.fit(X, y) + lasso_lars.fit(X, y, max_iter=12) lasso.alpha = alpha lasso.fit(X, y, maxit=5000, tol=1e-13) # make sure results are the same than with Lasso Coordinate descent error = np.linalg.norm(lasso_lars.coef_ - lasso.coef_) - print lasso.coef_ - print lasso_lars.coef_ print 'Error : ', error assert error < 1e-5 From 780b92f29c189363e98fd3221f31bcbb610ad4e8 Mon Sep 17 00:00:00 2001 From: Fabian Pedregosa Date: Thu, 2 Sep 2010 16:03:25 +0200 Subject: [PATCH 7/8] Remove "debugging" traces... 
--- scikits/learn/glm/lars.py | 1 - 1 file changed, 1 deletion(-) diff --git a/scikits/learn/glm/lars.py b/scikits/learn/glm/lars.py index 719db68be0f5c..f35808b4a4acb 100644 --- a/scikits/learn/glm/lars.py +++ b/scikits/learn/glm/lars.py @@ -98,7 +98,6 @@ def lars_path(X, y, max_iter=None, alpha_min=0, method="lar", precompute=True): # Calculate covariance matrix and get maximum res = y - np.dot (X, beta[n_iter]) # there are better ways Cov = np.ma.dot (Xna, res) - print Cov imax = np.ma.argmax (np.ma.abs(Cov)) #rename Cov_max = Cov.data [imax] From 394fab4c32f73da574694bc9221cfb251efdb61f Mon Sep 17 00:00:00 2001 From: Fabian Pedregosa Date: Thu, 2 Sep 2010 16:08:41 +0200 Subject: [PATCH 8/8] Fix doctests. --- scikits/learn/glm/lars.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scikits/learn/glm/lars.py b/scikits/learn/glm/lars.py index f35808b4a4acb..6dcdfde871db6 100644 --- a/scikits/learn/glm/lars.py +++ b/scikits/learn/glm/lars.py @@ -306,7 +306,7 @@ class LassoLARS (LinearModel): >>> from scikits.learn import glm >>> clf = glm.LassoLARS(alpha=0.1) >>> clf.fit([[-1,1], [0, 0], [1, 1]], [-1, 0, -1]) - LassoLARS(normalize=True, alpha=0.1) + LassoLARS(normalize=True, alpha=0.1, max_iter=None) >>> print clf.coef_ [ 0. -0.51649658]
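A minimal sketch of the check this series converges on, assuming the 2010-era scikits.learn API shown in the diffs above (glm.LassoLARS, glm.Lasso, datasets.load_diabetes, and the max_iter/maxit/tol fit parameters); it mirrors test_lasso_lars_vs_lasso_cd for a single value of alpha::

    import numpy as np
    from scikits.learn import datasets, glm

    # Same data as the tests in this series: the diabetes regression problem.
    diabetes = datasets.load_diabetes()
    X, y = diabetes.data, diabetes.target

    # LARS-based Lasso, run for 12 iterations as in the test.
    lasso_lars = glm.LassoLARS(alpha=0.1, normalize=False)
    lasso_lars.fit(X, y, max_iter=12)

    # Coordinate-descent Lasso on the same problem, no intercept so coef_ are comparable.
    lasso_cd = glm.Lasso(alpha=0.1, fit_intercept=False)
    lasso_cd.fit(X, y, maxit=5000, tol=1e-13)

    # Both solvers should reach the same Lasso solution; the test expects this to be < 1e-5.
    print np.linalg.norm(lasso_lars.coef_ - lasso_cd.coef_)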