Skip to content

Commit

Permalink
MAINT Clean up deprecations for 1.5: in NearestCentroid (#28813)
Browse files Browse the repository at this point in the history
Co-authored-by: Guillaume Lemaitre <guillaume@probabl.ai>
  • Loading branch information
jeremiedbb and glemaitre committed May 2, 2024
1 parent e19d8c2 commit 4c89b3b
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 78 deletions.
71 changes: 13 additions & 58 deletions sklearn/neighbors/_nearest_centroid.py
Expand Up @@ -7,14 +7,11 @@
#
# License: BSD 3 clause

import warnings
from numbers import Real

import numpy as np
from scipy import sparse as sp

from sklearn.metrics.pairwise import _VALID_METRICS

from ..base import BaseEstimator, ClassifierMixin, _fit_context
from ..metrics.pairwise import pairwise_distances_argmin
from ..preprocessing import LabelEncoder
Expand All @@ -34,25 +31,17 @@ class NearestCentroid(ClassifierMixin, BaseEstimator):
Parameters
----------
metric : str or callable, default="euclidean"
Metric to use for distance computation. See the documentation of
`scipy.spatial.distance
<https://docs.scipy.org/doc/scipy/reference/spatial.distance.html>`_ and
the metrics listed in
:class:`~sklearn.metrics.pairwise.distance_metrics` for valid metric
values. Note that "wminkowski", "seuclidean" and "mahalanobis" are not
supported.
The centroid for the samples corresponding to each class is
the point from which the sum of the distances (according to the metric)
of all samples that belong to that particular class is minimized.
If the `"manhattan"` metric is provided, this centroid is the median
and for all other metrics, the centroid is now set to be the mean.
.. deprecated:: 1.3
Support for metrics other than `euclidean` and `manhattan` and for
callables was deprecated in version 1.3 and will be removed in
version 1.5.
metric : {"euclidean", "manhattan"}, default="euclidean"
Metric to use for distance computation.
If `metric="euclidean"`, the centroid for the samples corresponding to each
class is the arithmetic mean, which minimizes the sum of squared L2 distances.
If `metric="manhattan"`, the centroid is the feature-wise median, which
minimizes the sum of L1 distances.
.. versionchanged:: 1.5
All metrics but `"euclidean"` and `"manhattan"` were deprecated and
now raise an error.
.. versionchanged:: 0.19
`metric='precomputed'` was deprecated and now raises an error
Expand Down Expand Up @@ -108,15 +97,8 @@ class NearestCentroid(ClassifierMixin, BaseEstimator):
[1]
"""

_valid_metrics = set(_VALID_METRICS) - {"mahalanobis", "seuclidean", "wminkowski"}

_parameter_constraints: dict = {
"metric": [
StrOptions(
_valid_metrics, deprecated=_valid_metrics - {"manhattan", "euclidean"}
),
callable,
],
"metric": [StrOptions({"manhattan", "euclidean"})],
"shrink_threshold": [Interval(Real, 0, None, closed="neither"), None],
}

Expand All @@ -143,19 +125,6 @@ def fit(self, X, y):
self : object
Fitted estimator.
"""
if isinstance(self.metric, str) and self.metric not in (
"manhattan",
"euclidean",
):
warnings.warn(
(
"Support for distance metrics other than euclidean and "
"manhattan and for callables was deprecated in version "
"1.3 and will be removed in version 1.5."
),
FutureWarning,
)

# If X is sparse and the metric is "manhattan", store it in csc format,
# which makes it easier to calculate the median.
if self.metric == "manhattan":
Expand Down Expand Up @@ -195,14 +164,7 @@ def fit(self, X, y):
self.centroids_[cur_class] = np.median(X[center_mask], axis=0)
else:
self.centroids_[cur_class] = csc_median_axis_0(X[center_mask])
else:
# TODO(1.5) remove warning when metric is only manhattan or euclidean
if self.metric != "euclidean":
warnings.warn(
"Averaging for metrics other than "
"euclidean and manhattan not supported. "
"The average is set to be the mean."
)
else: # metric == "euclidean"
self.centroids_[cur_class] = X[center_mask].mean(axis=0)

if self.shrink_threshold:
Expand Down Expand Up @@ -231,7 +193,6 @@ def fit(self, X, y):
self.centroids_ = dataset_centroid_[np.newaxis, :] + msd
return self

# TODO(1.5) remove note about precomputed metric
def predict(self, X):
"""Perform classification on an array of test vectors `X`.
Expand All @@ -246,12 +207,6 @@ def predict(self, X):
-------
C : ndarray of shape (n_samples,)
The predicted classes.
Notes
-----
If the metric constructor parameter is `"precomputed"`, `X` is assumed
to be the distance matrix between the data to be predicted and
`self.centroids_`.
"""
check_is_fitted(self)

Expand Down
22 changes: 2 additions & 20 deletions sklearn/neighbors/tests/test_nearest_centroid.py
Expand Up @@ -56,21 +56,17 @@ def test_classification_toy(csr_container):
assert_array_equal(clf.predict(T_csr.tolil()), true_result)


# TODO(1.5): Remove filterwarnings when support for some metrics is removed
@pytest.mark.filterwarnings("ignore:Support for distance metrics:FutureWarning:sklearn")
def test_iris():
# Check consistency on dataset iris.
for metric in ("euclidean", "cosine"):
for metric in ("euclidean", "manhattan"):
clf = NearestCentroid(metric=metric).fit(iris.data, iris.target)
score = np.mean(clf.predict(iris.data) == iris.target)
assert score > 0.9, "Failed with score = " + str(score)


# TODO(1.5): Remove filterwarnings when support for some metrics is removed
@pytest.mark.filterwarnings("ignore:Support for distance metrics:FutureWarning:sklearn")
def test_iris_shrinkage():
# Check consistency on dataset iris, when using shrinkage.
for metric in ("euclidean", "cosine"):
for metric in ("euclidean", "manhattan"):
for shrink_threshold in [None, 0.1, 0.5]:
clf = NearestCentroid(metric=metric, shrink_threshold=shrink_threshold)
clf = clf.fit(iris.data, iris.target)
Expand Down Expand Up @@ -151,20 +147,6 @@ def test_manhattan_metric(csr_container):
assert_array_equal(dense_centroid, [[-1, -1], [1, 1]])


# TODO(1.5): remove this test
@pytest.mark.parametrize(
"metric", sorted(list(NearestCentroid._valid_metrics - {"manhattan", "euclidean"}))
)
def test_deprecated_distance_metric_supports(metric):
# Check that a warning is raised for all deprecated distance metric supports
clf = NearestCentroid(metric=metric)
with pytest.warns(
FutureWarning,
match="Support for distance metrics other than euclidean and manhattan",
):
clf.fit(X, y)


def test_features_zero_var():
# Test that features with 0 variance throw error

Expand Down

0 comments on commit 4c89b3b

Please sign in to comment.