diff --git a/sklearn/neighbors/_nearest_centroid.py b/sklearn/neighbors/_nearest_centroid.py index 75086ee25448e..c9c99aeeaadb2 100644 --- a/sklearn/neighbors/_nearest_centroid.py +++ b/sklearn/neighbors/_nearest_centroid.py @@ -7,14 +7,11 @@ # # License: BSD 3 clause -import warnings from numbers import Real import numpy as np from scipy import sparse as sp -from sklearn.metrics.pairwise import _VALID_METRICS - from ..base import BaseEstimator, ClassifierMixin, _fit_context from ..metrics.pairwise import pairwise_distances_argmin from ..preprocessing import LabelEncoder @@ -34,25 +31,17 @@ class NearestCentroid(ClassifierMixin, BaseEstimator): Parameters ---------- - metric : str or callable, default="euclidean" - Metric to use for distance computation. See the documentation of - `scipy.spatial.distance - `_ and - the metrics listed in - :class:`~sklearn.metrics.pairwise.distance_metrics` for valid metric - values. Note that "wminkowski", "seuclidean" and "mahalanobis" are not - supported. - - The centroids for the samples corresponding to each class is - the point from which the sum of the distances (according to the metric) - of all samples that belong to that particular class are minimized. - If the `"manhattan"` metric is provided, this centroid is the median - and for all other metrics, the centroid is now set to be the mean. - - .. deprecated:: 1.3 - Support for metrics other than `euclidean` and `manhattan` and for - callables was deprecated in version 1.3 and will be removed in - version 1.5. + metric : {"euclidean", "manhattan"}, default="euclidean" + Metric to use for distance computation. + + If `metric="euclidean"`, the centroid for the samples corresponding to each + class is the arithmetic mean, which minimizes the sum of squared L2 distances. + If `metric="manhattan"`, the centroid is the feature-wise median, which + minimizes the sum of L1 distances. + + .. 
versionchanged:: 1.5 + All metrics but `"euclidean"` and `"manhattan"` were deprecated and + now raise an error. .. versionchanged:: 0.19 `metric='precomputed'` was deprecated and now raises an error @@ -108,15 +97,8 @@ class NearestCentroid(ClassifierMixin, BaseEstimator): [1] """ - _valid_metrics = set(_VALID_METRICS) - {"mahalanobis", "seuclidean", "wminkowski"} - _parameter_constraints: dict = { - "metric": [ - StrOptions( - _valid_metrics, deprecated=_valid_metrics - {"manhattan", "euclidean"} - ), - callable, - ], + "metric": [StrOptions({"manhattan", "euclidean"})], "shrink_threshold": [Interval(Real, 0, None, closed="neither"), None], } @@ -143,19 +125,6 @@ def fit(self, X, y): self : object Fitted estimator. """ - if isinstance(self.metric, str) and self.metric not in ( - "manhattan", - "euclidean", - ): - warnings.warn( - ( - "Support for distance metrics other than euclidean and " - "manhattan and for callables was deprecated in version " - "1.3 and will be removed in version 1.5." - ), - FutureWarning, - ) - # If X is sparse and the metric is "manhattan", store it in a csc # format is easier to calculate the median. if self.metric == "manhattan": @@ -195,14 +164,7 @@ def fit(self, X, y): self.centroids_[cur_class] = np.median(X[center_mask], axis=0) else: self.centroids_[cur_class] = csc_median_axis_0(X[center_mask]) - else: - # TODO(1.5) remove warning when metric is only manhattan or euclidean - if self.metric != "euclidean": - warnings.warn( - "Averaging for metrics other than " - "euclidean and manhattan not supported. " - "The average is set to be the mean." - ) + else: # metric == "euclidean" self.centroids_[cur_class] = X[center_mask].mean(axis=0) if self.shrink_threshold: @@ -231,7 +193,6 @@ def fit(self, X, y): self.centroids_ = dataset_centroid_[np.newaxis, :] + msd return self - # TODO(1.5) remove note about precomputed metric def predict(self, X): """Perform classification on an array of test vectors `X`. 
@@ -246,12 +207,6 @@ def predict(self, X): ------- C : ndarray of shape (n_samples,) The predicted classes. - - Notes - ----- - If the metric constructor parameter is `"precomputed"`, `X` is assumed - to be the distance matrix between the data to be predicted and - `self.centroids_`. """ check_is_fitted(self) diff --git a/sklearn/neighbors/tests/test_nearest_centroid.py b/sklearn/neighbors/tests/test_nearest_centroid.py index 09c2501818fd3..5ce792ac29d56 100644 --- a/sklearn/neighbors/tests/test_nearest_centroid.py +++ b/sklearn/neighbors/tests/test_nearest_centroid.py @@ -56,21 +56,17 @@ def test_classification_toy(csr_container): assert_array_equal(clf.predict(T_csr.tolil()), true_result) -# TODO(1.5): Remove filterwarnings when support for some metrics is removed -@pytest.mark.filterwarnings("ignore:Support for distance metrics:FutureWarning:sklearn") def test_iris(): # Check consistency on dataset iris. - for metric in ("euclidean", "cosine"): + for metric in ("euclidean", "manhattan"): clf = NearestCentroid(metric=metric).fit(iris.data, iris.target) score = np.mean(clf.predict(iris.data) == iris.target) assert score > 0.9, "Failed with score = " + str(score) -# TODO(1.5): Remove filterwarnings when support for some metrics is removed -@pytest.mark.filterwarnings("ignore:Support for distance metrics:FutureWarning:sklearn") def test_iris_shrinkage(): # Check consistency on dataset iris, when using shrinkage. 
- for metric in ("euclidean", "cosine"): + for metric in ("euclidean", "manhattan"): for shrink_threshold in [None, 0.1, 0.5]: clf = NearestCentroid(metric=metric, shrink_threshold=shrink_threshold) clf = clf.fit(iris.data, iris.target) @@ -151,20 +147,6 @@ def test_manhattan_metric(csr_container): assert_array_equal(dense_centroid, [[-1, -1], [1, 1]]) -# TODO(1.5): remove this test -@pytest.mark.parametrize( - "metric", sorted(list(NearestCentroid._valid_metrics - {"manhattan", "euclidean"})) -) -def test_deprecated_distance_metric_supports(metric): - # Check that a warning is raised for all deprecated distance metric supports - clf = NearestCentroid(metric=metric) - with pytest.warns( - FutureWarning, - match="Support for distance metrics other than euclidean and manhattan", - ): - clf.fit(X, y) - - def test_features_zero_var(): # Test that features with 0 variance throw error