addition of r2_score function #24

Closed
wants to merge 5 commits
6 changes: 3 additions & 3 deletions scikits/learn/base.py
@@ -10,7 +10,7 @@
 
 import numpy as np
 
-from .metrics import explained_variance_score
+from .metrics import r2_score
 
 ################################################################################
 def clone(estimator, safe=True):
@@ -236,7 +236,7 @@ class RegressorMixin(object):
     """
 
     def score(self, X, y):
-        """ Returns the explained variance of the prediction
+        """ Returns the coefficient of determination of the prediction
 
         Parameters
         ----------
@@ -249,7 +249,7 @@ def score(self, X, y):
         -------
         z : float
         """
-        return explained_variance_score(y, self.predict(X))
+        return r2_score(y, self.predict(X))
 
 
 ################################################################################
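For context, a minimal usage sketch of the changed score method (not part of the diff; it assumes the LinearRegression estimator from scikits/learn/linear_model with the usual fit/predict interface):

    import numpy as np
    from scikits.learn.linear_model import LinearRegression  # assumed import path

    X = np.array([[0.0], [1.0], [2.0], [3.0]])
    y = np.array([0.0, 1.0, 2.0, 3.0])

    model = LinearRegression()
    model.fit(X, y)
    # RegressorMixin.score now computes r2_score(y, self.predict(X))
    # instead of explained_variance_score(y, self.predict(X)).
    print(model.score(X, y))  # 1.0 for a perfect linear fit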
2 changes: 1 addition & 1 deletion scikits/learn/grid_search.py
@@ -176,7 +176,7 @@ class GridSearchCV(BaseEstimator):
     >>> svr = SVR()
     >>> clf = GridSearchCV(svr, parameters, n_jobs=1)
     >>> clf.fit(X, y).predict([[-0.8, -1]])
-    array([ 1.14])
+    array([ 1.13101459])
     """
 
     def __init__(self, estimator, param_grid, loss_func=None, score_func=None,
6 changes: 3 additions & 3 deletions scikits/learn/linear_model/base.py
@@ -12,7 +12,7 @@
 import numpy as np
 
 from ..base import BaseEstimator, RegressorMixin
-from ..metrics import explained_variance_score
+from ..metrics import r2_score
 
 ###
 ### TODO: intercept for all models
@@ -41,9 +41,9 @@ def predict(self, X):
         X = np.asanyarray(X)
         return np.dot(X, self.coef_) + self.intercept_
 
-    def _explained_variance(self, X, y):
+    def _r2_score(self, X, y):
         """Compute explained variance a.k.a. r^2"""
-        return explained_variance_score(y, self.predict(X))
+        return r2_score(y, self.predict(X))
 
     @staticmethod
     def _center_data(X, y, fit_intercept):
4 changes: 2 additions & 2 deletions scikits/learn/linear_model/bayes.py
@@ -207,7 +207,7 @@ def fit(self, X, y, **params):
 
         self._set_intercept(Xmean, ymean)
         # Store explained variance for __str__
-        self.explained_variance_ = self._explained_variance(X, y)
+        self.r2_score_ = self._r2_score(X, y)
         return self
 
 
@@ -420,5 +420,5 @@ def fit(self, X, y, **params):
 
         self._set_intercept(Xmean, ymean)
         # Store explained variance for __str__
-        self.explained_variance_ = self._explained_variance(X, y)
+        self.r2_score_ = self._r2_score(X, y)
         return self
4 changes: 2 additions & 2 deletions scikits/learn/linear_model/coordinate_descent.py
@@ -102,7 +102,7 @@ def fit(self, X, y, maxit=1000, tol=1e-4, coef_init=None, **params):
                           ' to increase the number of iterations')
 
         # Store explained variance for __str__
-        self.explained_variance_ = self._explained_variance(X, y)
+        self.r2_score_ = self._r2_score(X, y)
 
         # return self for chaining fit and predict calls
         return self
@@ -354,7 +354,7 @@ def fit(self, X, y, cv=None, **fit_params):
 
         self.coef_ = model.coef_
         self.intercept_ = model.intercept_
-        self.explained_variance_ = model.explained_variance_
+        self.r2_score_ = model.r2_score_
         self.alpha = model.alpha
         self.alphas = np.asarray(alphas)
         return self
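A short sketch of reading the renamed attribute after fitting (not part of the diff; Lasso here stands in for any estimator in this module, and its import path is an assumption):

    import numpy as np
    from scikits.learn.linear_model import Lasso  # assumed import path

    X = np.array([[0.0], [1.0], [2.0], [3.0]])
    y = np.array([0.0, 1.0, 2.0, 3.0])

    model = Lasso(alpha=0.1)
    model.fit(X, y)
    # fit() now stores the training-set R^2 under the new name:
    print(model.r2_score_)  # was model.explained_variance_ before this change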
21 changes: 20 additions & 1 deletion scikits/learn/metrics.py
@@ -499,7 +499,7 @@ def precision_recall_curve(y_true, probas_pred):
 def explained_variance_score(y_true, y_pred):
     """Explained variance regression score function
 
-    Best possible score is 1.0, lower values are worst.
+    Best possible score is 1.0, lower values are worse.
 
     Note: the explained variance is not a symmetric function.
 
@@ -512,6 +512,25 @@ def explained_variance_score(y_true, y_pred):
     y_pred : array-like
     """
     return 1 - np.var(y_true - y_pred) / np.var(y_true)
+
+
+def r2_score(y_true, y_pred):
+    """R^2 (coefficient of determination) regression score function
+
+    Best possible score is 1.0, lower values are worse.
+
+    Note: R^2 is not a symmetric function.
+
+    Returns the R^2 score.
+
+    Parameters
+    ----------
+    y_true : array-like
+
+    y_pred : array-like
+    """
+    return 1 - ((y_true - y_pred)**2).sum() / ((y_true - y_true.mean())**2).sum()
 
 
 ###############################################################################
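To make the difference between the two metrics concrete, here is a small standalone check (not part of the diff; it needs only NumPy). With a constant prediction bias, the explained variance stays at 1.0 because the residuals have zero variance, while R^2 drops because it measures the residuals themselves, not their spread around their own mean:

    import numpy as np

    y_true = np.array([1.0, 2.0, 3.0, 4.0])
    y_pred = y_true + 0.5  # predictions systematically biased by +0.5

    # Explained variance: 1 - Var(residuals) / Var(y_true); a constant bias vanishes.
    print(1 - np.var(y_true - y_pred) / np.var(y_true))  # 1.0

    # R^2: 1 - SS_res / SS_tot; the bias is penalized.
    print(1 - ((y_true - y_pred)**2).sum() / ((y_true - y_true.mean())**2).sum())  # 0.8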
10 changes: 8 additions & 2 deletions scikits/learn/tests/test_metrics.py
@@ -2,7 +2,7 @@
 import numpy as np
 import nose
 
-from numpy.testing import assert_
+from nose.tools import assert_true
 from numpy.testing import assert_array_almost_equal
 from numpy.testing import assert_array_equal
 from numpy.testing import assert_equal, assert_almost_equal
@@ -13,6 +13,7 @@
 from ..metrics import classification_report
 from ..metrics import confusion_matrix
 from ..metrics import explained_variance_score
+from ..metrics import r2_score
 from ..metrics import f1_score
 from ..metrics import mean_square_error
 from ..metrics import precision_recall_curve
@@ -222,6 +223,9 @@ def test_losses():
     assert_almost_equal(explained_variance_score(y_true, y_pred), -0.04, 2)
     assert_almost_equal(explained_variance_score(y_true, y_true), 1.00, 2)
 
+    assert_almost_equal(r2_score(y_true, y_pred), -0.04, 2)
+    assert_almost_equal(r2_score(y_true, y_true), 1.00, 2)
+
 
 def test_symmetry():
     """Test the symmetry of score and loss functions"""
@@ -233,8 +237,10 @@ def test_symmetry():
     assert_almost_equal(mean_square_error(y_true, y_pred),
                         mean_square_error(y_pred, y_true))
     # not symmetric
-    assert_(explained_variance_score(y_true, y_pred) != \
+    assert_true(explained_variance_score(y_true, y_pred) != \
             explained_variance_score(y_pred, y_true))
+    assert_true(r2_score(y_true, y_pred) != \
+                r2_score(y_pred, y_true))
     # FIXME: precision and recall aren't symmetric either
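As a quick illustration of the asymmetry the last test checks (not part of the diff): swapping the arguments leaves the residual sum of squares unchanged but swaps which array's spread supplies the denominator, so the score depends on argument order:

    import numpy as np

    def r2(y_true, y_pred):  # mirrors the new r2_score, for illustration only
        return 1 - ((y_true - y_pred)**2).sum() / ((y_true - y_true.mean())**2).sum()

    a = np.array([1.0, 2.0, 3.0])
    b = np.array([1.0, 2.0, 2.0])

    # Same numerator both ways, different denominator:
    print(r2(a, b))  #  0.5
    print(r2(b, a))  # -0.5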