MAINT Clean up deprecations for 1.5: in log_loss (#28851)
Co-authored-by: Guillaume Lemaitre <guillaume@probabl.ai>
jeremiedbb and glemaitre committed Apr 29, 2024
1 parent f9cab76 commit 19c068f
Showing 3 changed files with 75 additions and 103 deletions.
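As context for the diff below: after this cleanup, `log_loss` no longer accepts an `eps` argument, always clips probabilities at the machine epsilon of `y_pred`'s dtype, and no longer renormalizes rows that do not sum to one (it warns instead). A minimal sketch of the resulting behavior, assuming scikit-learn >= 1.5 (the arrays are illustrative):

import numpy as np
from sklearn.metrics import log_loss

y_true = [0, 1, 1]
y_pred = np.array([[0.9, 0.1], [0.2, 0.8], [0.3, 0.7]])

# Clipping now always uses np.finfo(y_pred.dtype).eps internally.
loss = log_loss(y_true, y_pred)

# Passing eps is no longer supported and raises TypeError in 1.5:
# log_loss(y_true, y_pred, eps=1e-15)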
59 changes: 16 additions & 43 deletions sklearn/metrics/_classification.py
@@ -2816,16 +2816,13 @@ def hamming_loss(y_true, y_pred, *, sample_weight=None):
{
"y_true": ["array-like"],
"y_pred": ["array-like"],
"eps": [StrOptions({"auto"}), Interval(Real, 0, 1, closed="both")],
"normalize": ["boolean"],
"sample_weight": ["array-like", None],
"labels": ["array-like", None],
},
prefer_skip_nested_validation=True,
)
-def log_loss(
-    y_true, y_pred, *, eps="auto", normalize=True, sample_weight=None, labels=None
-):
+def log_loss(y_true, y_pred, *, normalize=True, sample_weight=None, labels=None):
r"""Log loss, aka logistic loss or cross-entropy loss.
This is the loss function used in (multinomial) logistic regression
@@ -2855,19 +2852,8 @@ def log_loss(
ordered alphabetically, as done by
:class:`~sklearn.preprocessing.LabelBinarizer`.
-    eps : float or "auto", default="auto"
-        Log loss is undefined for p=0 or p=1, so probabilities are
-        clipped to `max(eps, min(1 - eps, p))`. The default will depend on the
-        data type of `y_pred` and is set to `np.finfo(y_pred.dtype).eps`.
-        .. versionadded:: 1.2
-        .. versionchanged:: 1.2
-           The default value changed from `1e-15` to `"auto"` that is
-           equivalent to `np.finfo(y_pred.dtype).eps`.
-        .. deprecated:: 1.3
-           `eps` is deprecated in 1.3 and will be removed in 1.5.
+        `y_pred` values are clipped to `[eps, 1-eps]` where `eps` is the machine
+        precision for `y_pred`'s dtype.
normalize : bool, default=True
If true, return the mean loss per sample.
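Since the clipping bound now follows `y_pred`'s dtype, the effective `eps` varies with precision. A quick illustration using the standard IEEE 754 machine epsilons:

import numpy as np

for dt in (np.float16, np.float32, np.float64):
    # machine epsilon used as the clipping bound for this dtype
    print(dt.__name__, np.finfo(dt).eps)
# float16 0.000977, float32 1.1920929e-07, float64 2.220446049250313e-16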
@@ -2907,18 +2893,6 @@ def log_loss(
y_pred = check_array(
y_pred, ensure_2d=False, dtype=[np.float64, np.float32, np.float16]
)
-    if eps == "auto":
-        eps = np.finfo(y_pred.dtype).eps
-    else:
-        # TODO: Remove user defined eps in 1.5
-        warnings.warn(
-            (
-                "Setting the eps parameter is deprecated and will "
-                "be removed in 1.5. Instead eps will always have"
-                "a default value of `np.finfo(y_pred.dtype).eps`."
-            ),
-            FutureWarning,
-        )

check_consistent_length(y_pred, y_true, sample_weight)
lb = LabelBinarizer()
@@ -2949,16 +2923,26 @@ def log_loss(
1 - transformed_labels, transformed_labels, axis=1
)

-    # Clipping
-    y_pred = np.clip(y_pred, eps, 1 - eps)

# If y_pred is of single dimension, assume y_true to be binary
# and then check.
if y_pred.ndim == 1:
y_pred = y_pred[:, np.newaxis]
if y_pred.shape[1] == 1:
y_pred = np.append(1 - y_pred, y_pred, axis=1)

+    eps = np.finfo(y_pred.dtype).eps
+
+    # Make sure y_pred is normalized
+    y_pred_sum = y_pred.sum(axis=1)
+    if not np.allclose(y_pred_sum, 1, rtol=np.sqrt(eps)):
+        warnings.warn(
+            "The y_pred values do not sum to one. Make sure to pass probabilities.",
+            UserWarning,
+        )
+
+    # Clipping
+    y_pred = np.clip(y_pred, eps, 1 - eps)

# Check if dimensions are consistent.
transformed_labels = check_array(transformed_labels)
if len(lb.classes_) != y_pred.shape[1]:
@@ -2979,17 +2963,6 @@
"labels: {0}".format(lb.classes_)
)

-    # Renormalize
-    y_pred_sum = y_pred.sum(axis=1)
-    if not np.isclose(y_pred_sum, 1, rtol=1e-15, atol=5 * eps).all():
-        warnings.warn(
-            (
-                "The y_pred values do not sum to one. Starting from 1.5 this"
-                "will result in an error."
-            ),
-            UserWarning,
-        )
-    y_pred = y_pred / y_pred_sum[:, np.newaxis]
loss = -xlogy(transformed_labels, y_pred).sum(axis=1)

return float(_average(loss, weights=sample_weight, normalize=normalize))
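With the renormalization block above removed, rows of `y_pred` that do not sum to one are now used as-is and only trigger a `UserWarning`. A small sketch of the new behavior (the probabilities here are made up for illustration):

import numpy as np
from sklearn.metrics import log_loss

y_true = [0, 1]
y_pred = np.array([[0.5, 0.6], [0.4, 0.7]])  # rows sum to 1.1, not 1.0

# Emits: UserWarning: The y_pred values do not sum to one. Make sure to pass probabilities.
loss = log_loss(y_true, y_pred)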
104 changes: 47 additions & 57 deletions sklearn/metrics/tests/test_classification.py
@@ -2624,62 +2624,37 @@ def test_log_loss():
)
loss = log_loss(y_true, y_pred)
loss_true = -np.mean(bernoulli.logpmf(np.array(y_true) == "yes", y_pred[:, 1]))
-    assert_almost_equal(loss, loss_true)
+    assert_allclose(loss, loss_true)

# multiclass case; adapted from http://bit.ly/RJJHWA
y_true = [1, 0, 2]
y_pred = [[0.2, 0.7, 0.1], [0.6, 0.2, 0.2], [0.6, 0.1, 0.3]]
loss = log_loss(y_true, y_pred, normalize=True)
-    assert_almost_equal(loss, 0.6904911)
+    assert_allclose(loss, 0.6904911)

# check that we got all the shapes and axes right
# by doubling the length of y_true and y_pred
y_true *= 2
y_pred *= 2
loss = log_loss(y_true, y_pred, normalize=False)
-    assert_almost_equal(loss, 0.6904911 * 6, decimal=6)
-
-    user_warning_msg = "y_pred values do not sum to one"
-    # check eps and handling of absolute zero and one probabilities
-    y_pred = np.asarray(y_pred) > 0.5
-    with pytest.warns(FutureWarning):
-        loss = log_loss(y_true, y_pred, normalize=True, eps=0.1)
-    with pytest.warns(UserWarning, match=user_warning_msg):
-        assert_almost_equal(loss, log_loss(y_true, np.clip(y_pred, 0.1, 0.9)))
-
-    # binary case: check correct boundary values for eps = 0
-    with pytest.warns(FutureWarning):
-        assert log_loss([0, 1], [0, 1], eps=0) == 0
-    with pytest.warns(FutureWarning):
-        assert log_loss([0, 1], [0, 0], eps=0) == np.inf
-    with pytest.warns(FutureWarning):
-        assert log_loss([0, 1], [1, 1], eps=0) == np.inf
-
-    # multiclass case: check correct boundary values for eps = 0
-    with pytest.warns(FutureWarning):
-        assert log_loss([0, 1, 2], [[1, 0, 0], [0, 1, 0], [0, 0, 1]], eps=0) == 0
-    with pytest.warns(FutureWarning):
-        assert (
-            log_loss([0, 1, 2], [[0, 0.5, 0.5], [0, 1, 0], [0, 0, 1]], eps=0) == np.inf
-        )
+    assert_allclose(loss, 0.6904911 * 6)

# raise error if number of classes are not equal.
y_true = [1, 0, 2]
-    y_pred = [[0.2, 0.7], [0.6, 0.5], [0.4, 0.1]]
+    y_pred = [[0.3, 0.7], [0.6, 0.4], [0.4, 0.6]]
with pytest.raises(ValueError):
log_loss(y_true, y_pred)

# case when y_true is a string array object
y_true = ["ham", "spam", "spam", "ham"]
-    y_pred = [[0.2, 0.7], [0.6, 0.5], [0.4, 0.1], [0.7, 0.2]]
-    with pytest.warns(UserWarning, match=user_warning_msg):
-        loss = log_loss(y_true, y_pred)
-    assert_almost_equal(loss, 1.0383217, decimal=6)
+    y_pred = [[0.3, 0.7], [0.6, 0.4], [0.4, 0.6], [0.7, 0.3]]
+    loss = log_loss(y_true, y_pred)
+    assert_allclose(loss, 0.7469410)

# test labels option

y_true = [2, 2]
-    y_pred = [[0.2, 0.7], [0.6, 0.5]]
+    y_pred = [[0.2, 0.8], [0.6, 0.4]]
y_score = np.array([[0.1, 0.9], [0.1, 0.9]])
error_str = (
r"y_true contains only one label \(2\). Please provide "
@@ -2688,50 +2663,66 @@ def test_log_loss():
with pytest.raises(ValueError, match=error_str):
log_loss(y_true, y_pred)

-    y_pred = [[0.2, 0.7], [0.6, 0.5], [0.2, 0.3]]
-    error_str = "Found input variables with inconsistent numbers of samples: [3, 2]"
-    (ValueError, error_str, log_loss, y_true, y_pred)
+    y_pred = [[0.2, 0.8], [0.6, 0.4], [0.7, 0.3]]
+    error_str = r"Found input variables with inconsistent numbers of samples: \[3, 2\]"
+    with pytest.raises(ValueError, match=error_str):
+        log_loss(y_true, y_pred)

# works when the labels argument is used

true_log_loss = -np.mean(np.log(y_score[:, 1]))
calculated_log_loss = log_loss(y_true, y_score, labels=[1, 2])
-    assert_almost_equal(calculated_log_loss, true_log_loss)
+    assert_allclose(calculated_log_loss, true_log_loss)

# ensure labels work when len(np.unique(y_true)) != y_pred.shape[1]
y_true = [1, 2, 2]
-    y_score2 = [[0.2, 0.7, 0.3], [0.6, 0.5, 0.3], [0.3, 0.9, 0.1]]
-    with pytest.warns(UserWarning, match=user_warning_msg):
-        loss = log_loss(y_true, y_score2, labels=[1, 2, 3])
-    assert_almost_equal(loss, 1.0630345, decimal=6)
+    y_score2 = [[0.7, 0.1, 0.2], [0.2, 0.7, 0.1], [0.1, 0.7, 0.2]]
+    loss = log_loss(y_true, y_score2, labels=[1, 2, 3])
+    assert_allclose(loss, -np.log(0.7))


@pytest.mark.parametrize("dtype", [np.float64, np.float32, np.float16])
def test_log_loss_eps(dtype):
"""Check the behaviour internal eps that changes depending on the input dtype.
def test_log_loss_eps_auto(global_dtype):
"""Check the behaviour of `eps="auto"` that changes depending on the input
array dtype.
Non-regression test for:
https://github.com/scikit-learn/scikit-learn/issues/24315
"""
y_true = np.array([0, 1], dtype=global_dtype)
y_pred = y_true.copy()
y_true = np.array([0, 1], dtype=dtype)
y_pred = np.array([1, 0], dtype=dtype)

loss = log_loss(y_true, y_pred, eps="auto")
loss = log_loss(y_true, y_pred)
assert np.isfinite(loss)
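For intuition about the `np.isfinite` assertion above: with completely wrong hard 0/1 predictions, clipping bounds the per-sample loss at -log(eps). A rough sanity check, not part of the test suite:

import numpy as np

for dt in (np.float16, np.float32, np.float64):
    # largest per-sample log loss once probabilities are clipped to [eps, 1 - eps]
    print(dt.__name__, -np.log(np.finfo(dt).eps))
# float16 ~6.93, float32 ~15.94, float64 ~36.04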


-def test_log_loss_eps_auto_float16():
-    """Check the behaviour of `eps="auto"` for np.float16"""
-    y_true = np.array([0, 1], dtype=np.float16)
-    y_pred = y_true.copy()
-
-    loss = log_loss(y_true, y_pred, eps="auto")
-    assert np.isfinite(loss)
+@pytest.mark.parametrize("dtype", [np.float64, np.float32, np.float16])
+def test_log_loss_not_probabilities_warning(dtype):
+    """Check that log_loss raises a warning when y_pred values don't sum to 1."""
+    y_true = np.array([0, 1, 1, 0])
+    y_pred = np.array([[0.2, 0.7], [0.6, 0.3], [0.4, 0.7], [0.8, 0.3]], dtype=dtype)
+
+    with pytest.warns(UserWarning, match="The y_pred values do not sum to one."):
+        log_loss(y_true, y_pred)


+@pytest.mark.parametrize(
+    "y_true, y_pred",
+    [
+        ([0, 1, 0], [0, 1, 0]),
+        ([0, 1, 0], [[1, 0], [0, 1], [1, 0]]),
+        ([0, 1, 2], [[1, 0, 0], [0, 1, 0], [0, 0, 1]]),
+    ],
+)
+def test_log_loss_perfect_predictions(y_true, y_pred):
+    """Check that log_loss returns 0 for perfect predictions."""
+    # Because of the clipping, the result is not exactly 0
+    assert log_loss(y_true, y_pred) == pytest.approx(0)


def test_log_loss_pandas_input():
# case when input is a pandas series and dataframe gh-5715
y_tr = np.array(["ham", "spam", "spam", "ham"])
-    y_pr = np.array([[0.2, 0.7], [0.6, 0.5], [0.4, 0.1], [0.7, 0.2]])
+    y_pr = np.array([[0.3, 0.7], [0.6, 0.4], [0.4, 0.6], [0.7, 0.3]])
types = [(MockDataFrame, MockDataFrame)]
try:
from pandas import DataFrame, Series
Expand All @@ -2742,9 +2733,8 @@ def test_log_loss_pandas_input():
for TrueInputType, PredInputType in types:
# y_pred dataframe, y_true series
y_true, y_pred = TrueInputType(y_tr), PredInputType(y_pr)
-        with pytest.warns(UserWarning, match="y_pred values do not sum to one"):
-            loss = log_loss(y_true, y_pred)
-        assert_almost_equal(loss, 1.0383217, decimal=6)
+        loss = log_loss(y_true, y_pred)
+        assert_allclose(loss, 0.7469410)


def test_brier_score_loss():
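The updated reference constants in these tests can be checked by hand: log loss is the mean negative log of the probability assigned to the true class (classes ordered alphabetically, so column 0 is "ham" and column 1 is "spam"). A quick verification sketch:

import numpy as np

# multiclass case above: true-class probabilities 0.7, 0.6, 0.3
print(-np.log([0.7, 0.6, 0.3]).mean())  # ~0.6904911

# string-label case: true-class probabilities 0.3, 0.4, 0.6, 0.7
print(-np.log([0.3, 0.4, 0.6, 0.7]).mean())  # ~0.7469410

# labels=[1, 2, 3] case: the true class gets probability 0.7 every time
print(-np.log(0.7))  # ~0.3566749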
15 changes: 12 additions & 3 deletions sklearn/metrics/tests/test_common.py
@@ -637,7 +637,10 @@ def test_sample_order_invariance_multilabel_and_multioutput():
# Generate some data
y_true = random_state.randint(0, 2, size=(20, 25))
y_pred = random_state.randint(0, 2, size=(20, 25))
-    y_score = random_state.normal(size=y_true.shape)
+    y_score = random_state.uniform(size=y_true.shape)
+
+    # Some metrics (e.g. log_loss) require y_score to be probabilities (sum to 1)
+    y_score /= y_score.sum(axis=1, keepdims=True)

y_true_shuffle, y_pred_shuffle, y_score_shuffle = shuffle(
y_true, y_pred, y_score, random_state=0
@@ -1566,7 +1569,10 @@ def test_multilabel_sample_weight_invariance(name):
)
y_true = np.vstack([ya, yb])
y_pred = np.vstack([ya, ya])
-    y_score = random_state.randint(1, 4, size=y_true.shape)
+    y_score = random_state.uniform(size=y_true.shape)
+
+    # Some metrics (e.g. log_loss) require y_score to be probabilities (sum to 1)
+    y_score /= y_score.sum(axis=1, keepdims=True)

metric = ALL_METRICS[name]
if name in THRESHOLDED_METRICS:
@@ -1629,7 +1635,10 @@ def test_thresholded_multilabel_multioutput_permutations_invariance(name):
random_state = check_random_state(0)
n_samples, n_classes = 20, 4
y_true = random_state.randint(0, 2, size=(n_samples, n_classes))
-    y_score = random_state.normal(size=y_true.shape)
+    y_score = random_state.uniform(size=y_true.shape)
+
+    # Some metrics (e.g. log_loss) require y_score to be probabilities (sum to 1)
+    y_score /= y_score.sum(axis=1, keepdims=True)

# Makes sure all samples have at least one label. This works around errors
# when running metrics where average="sample"
