Use more natural class_weight="auto" heuristic
amueller committed Mar 5, 2015
1 parent 8dbe3f8 commit 76badb1
Showing 4 changed files with 41 additions and 9 deletions.
5 changes: 4 additions & 1 deletion doc/whats_new.rst
@@ -199,6 +199,9 @@ Enhancements
- The outcome of :func:`manifold.spectral_embedding` was made deterministic
  by flipping the sign of eigenvectors. By `Hasil Sharma`_.

- Improved heuristic for ``class_weight="auto"`` for classifiers supporting
  ``class_weight`` by Hanna Wallach and `Andreas Müller`_.


Documentation improvements
..........................
@@ -323,7 +326,7 @@ Bug fixes
in GMM. By `Alexis Mignon`_.

- Fixed an error in the computation of conditional probabilities in
  :class:`naive_bayes.BernoulliNB`. By `Hanna Wallach`_.
  :class:`naive_bayes.BernoulliNB`. By Hanna Wallach.

- Make the method ``radius_neighbors`` of
:class:`neighbors.NearestNeighbors` return the samples lying on the
14 changes: 10 additions & 4 deletions sklearn/utils/class_weight.py
@@ -15,8 +15,8 @@ def compute_class_weight(class_weight, classes, y):
    Parameters
    ----------
    class_weight : dict, 'auto' or None
        If 'auto', class weights will be given inverse proportional
        to the frequency of the class in the data.
        If 'auto', class weights will be given by
        ``n_samples / (n_classes * np.bincount(y))``.
        If a dictionary is given, keys are classes and values
        are corresponding class weights.
        If None is given, the class weights will be uniform.
@@ -32,6 +32,11 @@ def compute_class_weight(class_weight, classes, y):
    -------
    class_weight_vect : ndarray, shape (n_classes,)
        Array with class_weight_vect[i] the weight for i-th class

    References
    ----------
    The "auto" heuristic is inspired by
    Logistic Regression in Rare Events Data, King, Zeng, 2001.
    """
    # Import error caused by circular imports.
    from ..preprocessing import LabelEncoder
@@ -47,8 +52,9 @@ def compute_class_weight(class_weight, classes, y):
            raise ValueError("classes should have valid labels that are in y")

        # inversely proportional to the number of samples in the class
        recip_freq = 1. / bincount(y_ind)
        weight = recip_freq[le.transform(classes)] / np.mean(recip_freq)
        recip_freq = len(y) / (len(le.classes_) *
                               bincount(y_ind).astype(np.float64))
        weight = recip_freq[le.transform(classes)]
    else:
        # user-defined dictionary
        weight = np.ones(classes.shape[0], dtype=np.float64, order='C')
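(Illustrative note, not part of the commit.) A minimal sketch of what the new heuristic returns, using the positional compute_class_weight(class_weight, classes, y) API shown above; the toy labels are hypothetical:

import numpy as np
from sklearn.utils.class_weight import compute_class_weight

y = np.array([1, 1, 1, -1, -1])   # hypothetical: 3 samples of class 1, 2 of class -1
classes = np.array([-1, 1])

# New heuristic: n_samples / (n_classes * np.bincount(y))
# class -1: 5. / (2 * 2) = 1.25; class 1: 5. / (2 * 3) ~ 0.833
print(compute_class_weight("auto", classes, y))
# For comparison, the removed heuristic (reciprocal frequencies divided by
# their mean) gave [1.2, 0.8] for the same data.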
5 changes: 2 additions & 3 deletions sklearn/utils/estimator_checks.py
@@ -905,10 +905,9 @@ def check_class_weight_auto_linear_classifier(name, Classifier):
    coef_auto = classifier.fit(X, y).coef_.copy()

    # Count each label occurrence to reweight manually
    mean_weight = (1. / 3 + 1. / 2) / 2
    class_weight = {
        1: 1. / 3 / mean_weight,
        -1: 1. / 2 / mean_weight,
        1: 5. / (2 * 3),
        -1: 5. / (2 * 2)
    }
    classifier.set_params(class_weight=class_weight)
    coef_manual = classifier.fit(X, y).coef_.copy()
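(Illustrative note, not part of the commit.) The manual dictionary above follows directly from the new formula: the check's weights imply three samples of class 1 and two of class -1, so n_samples / (n_classes * count) gives 5. / (2 * 3) and 5. / (2 * 2). A quick hypothetical cross-check:

import numpy as np
from sklearn.utils.class_weight import compute_class_weight

y = np.array([1, 1, 1, -1, -1])  # hypothetical labels matching the implied counts
auto = compute_class_weight("auto", np.array([1, -1]), y)
assert np.allclose(auto, [5. / (2 * 3), 5. / (2 * 2)])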
26 changes: 25 additions & 1 deletion sklearn/utils/tests/test_class_weight.py
@@ -1,5 +1,8 @@
import numpy as np

from sklearn.linear_model import LogisticRegression
from sklearn.datasets import make_blobs

from sklearn.utils.class_weight import compute_class_weight
from sklearn.utils.class_weight import compute_sample_weight

@@ -26,6 +29,27 @@ def test_compute_class_weight_not_present():
    assert_raises(ValueError, compute_class_weight, "auto", classes, y)


def test_compute_class_weight_invariance():
    # test that results with class_weight="auto" are invariant against
    # class imbalance if the number of samples is identical
    X, y = make_blobs(centers=2, random_state=0)
    # create dataset where class 1 is duplicated twice
    X_1 = np.vstack([X] + [X[y == 1]] * 2)
    y_1 = np.hstack([y] + [y[y == 1]] * 2)
    # create dataset where class 0 is duplicated twice
    X_0 = np.vstack([X] + [X[y == 0]] * 2)
    y_0 = np.hstack([y] + [y[y == 0]] * 2)
    # duplicate everything
    X_ = np.vstack([X] * 2)
    y_ = np.hstack([y] * 2)
    # results should be identical
    logreg1 = LogisticRegression(class_weight="auto").fit(X_1, y_1)
    logreg0 = LogisticRegression(class_weight="auto").fit(X_0, y_0)
    logreg = LogisticRegression(class_weight="auto").fit(X_, y_)
    assert_array_almost_equal(logreg1.coef_, logreg0.coef_)
    assert_array_almost_equal(logreg.coef_, logreg0.coef_)


def test_compute_class_weight_auto_negative():
"""Test compute_class_weight when labels are negative"""
# Test with balanced class labels.
@@ -116,7 +140,7 @@ def test_compute_sample_weight_with_subsample():
    # Test with a bootstrap subsample
    y = np.asarray([1, 1, 1, 2, 2, 2])
    sample_weight = compute_sample_weight("auto", y, [0, 1, 1, 2, 2, 3])
    expected = np.asarray([1/3., 1/3., 1/3., 5/3., 5/3., 5/3.])
    expected = np.asarray([1 / 3., 1 / 3., 1 / 3., 5 / 3., 5 / 3., 5 / 3.])
    assert_array_almost_equal(sample_weight, expected)

    # Test with a bootstrap subsample for multi-output
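(Illustrative note, not part of the commit.) The invariance exercised by test_compute_class_weight_invariance above holds because each class's total weight under the new heuristic is count * n_samples / (n_classes * count) = n_samples / n_classes, regardless of the class counts. A quick hypothetical check:

import numpy as np
from sklearn.utils.class_weight import compute_class_weight

y = np.array([0] * 10 + [1] * 90)  # hypothetical imbalanced labels
w = compute_class_weight("auto", np.array([0, 1]), y)
print(10 * w[0], 90 * w[1])  # both 50.0 == 100 / 2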
