From dc580a8ef5ee2a8aea80498388690e2213118efd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=A9mie=20du=20Boisberranger?= <34657725+jeremiedbb@users.noreply.github.com> Date: Thu, 8 Dec 2022 13:04:29 +0100 Subject: [PATCH] Release 1.2.0 [cd build] (#25121) --- build_tools/azure/install_win.sh | 4 +- doc/computing/parallelism.rst | 7 ++ doc/developers/maintainer.rst | 1 + doc/whats_new/v1.2.rst | 24 ++++- .../plot_release_highlights_1_2_0.py | 38 +++++++- setup.cfg | 2 +- setup.py | 23 ++++- sklearn/__init__.py | 2 +- sklearn/decomposition/_nmf.py | 23 +---- sklearn/ensemble/_gb.py | 7 +- .../gradient_boosting.py | 15 ++-- .../tests/test_gradient_boosting.py | 26 ++++-- .../ensemble/tests/test_gradient_boosting.py | 7 +- sklearn/inspection/_plot/decision_boundary.py | 12 ++- .../tests/test_boundary_decision_display.py | 10 +++ .../tests/test_plot_partial_dependence.py | 2 +- .../_argkmin.pyx.tp | 7 +- .../_radius_neighbors.pyx.tp | 7 +- .../test_pairwise_distances_reduction.py | 76 ++++++++++++++-- .../neural_network/_multilayer_perceptron.py | 30 ++++++- sklearn/neural_network/tests/test_mlp.py | 41 +++++++-- sklearn/tests/test_public_functions.py | 87 +++++++++++++++---- sklearn/utils/_param_validation.py | 33 +++++-- sklearn/utils/estimator_checks.py | 5 +- sklearn/utils/extmath.py | 8 +- sklearn/utils/tests/test_extmath.py | 24 +++++ sklearn/utils/tests/test_param_validation.py | 83 ++++++++++++------ sklearn/utils/tests/test_validation.py | 21 +++++ sklearn/utils/validation.py | 7 ++ 29 files changed, 497 insertions(+), 135 deletions(-)
diff --git a/build_tools/azure/install_win.sh b/build_tools/azure/install_win.sh index b28bc86270925..ab559a1878971 100755 --- a/build_tools/azure/install_win.sh +++ b/build_tools/azure/install_win.sh @@ -7,9 +7,7 @@ set -x source build_tools/shared.sh if [[ "$DISTRIB" == "conda" ]]; then - conda update -n base conda -y - conda install pip -y - pip install "$(get_dep conda-lock min)" + conda install -c conda-forge "$(get_dep conda-lock min)" -y conda-lock install --name $VIRTUALENV $LOCK_FILE source activate $VIRTUALENV else
diff --git a/doc/computing/parallelism.rst b/doc/computing/parallelism.rst index 97e3e2866083f..b7add493a88b1 100644 --- a/doc/computing/parallelism.rst +++ b/doc/computing/parallelism.rst @@ -299,6 +299,13 @@ When this environment variable is set to a non-zero value, the `Cython` directive `boundscheck` is set to `True`. This is useful for finding segfaults. +`SKLEARN_BUILD_ENABLE_DEBUG_SYMBOLS` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +When this environment variable is set to a non-zero value, debug symbols +will be included in the compiled C extensions. Debug symbols are only +configured for POSIX systems.
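For illustration, a minimal sketch of the compile-flag selection this variable controls, mirroring the setup.py hunk later in this patch (the helper name is hypothetical, not part of the patch):

    import os

    # Mirrors the patched setup.py logic: -O2/-g0 by default on POSIX, -g when
    # SKLEARN_BUILD_ENABLE_DEBUG_SYMBOLS is set to a non-zero value.
    def extra_compile_args(optimization_level="O2"):
        if os.name != "posix":
            return [f"/{optimization_level}"]
        args = [f"-{optimization_level}"]
        if os.environ.get("SKLEARN_BUILD_ENABLE_DEBUG_SYMBOLS", "0") != "0":
            args.append("-g")   # keep debug symbols for e.g. gdb sessions
        else:
            args.append("-g0")  # strip symbols, reducing the extension binary size
        return args

    print(extra_compile_args())  # e.g. ['-O2', '-g0'] on Linux without the variable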
+ `SKLEARN_PAIRWISE_DIST_CHUNK_SIZE` ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
diff --git a/doc/developers/maintainer.rst b/doc/developers/maintainer.rst index 41fd571ae0389..b769851736b54 100644 --- a/doc/developers/maintainer.rst +++ b/doc/developers/maintainer.rst @@ -310,6 +310,7 @@ The following GitHub checklist might be helpful in a release PR:: * [ ] upload the wheels and source tarball to PyPI * [ ] https://github.com/scikit-learn/scikit-learn/releases publish (except for RC) * [ ] announce on the mailing list, on Twitter, and on LinkedIn + * [ ] update SECURITY.md in main branch (except for RC) Merging Pull Requests ---------------------
diff --git a/doc/whats_new/v1.2.rst b/doc/whats_new/v1.2.rst index 6a63033878f84..90f25c9a586a5 100644 --- a/doc/whats_new/v1.2.rst +++ b/doc/whats_new/v1.2.rst @@ -411,11 +411,15 @@ Changelog :mod:`sklearn.inspection` ......................... -- |Enhancement| Extended :func:`inspection.partial_dependence` and +- |MajorFeature| Extended :func:`inspection.partial_dependence` and :class:`inspection.PartialDependenceDisplay` to handle categorical features. :pr:`18298` by :user:`Madhura Jayaratne ` and :user:`Guillaume Lemaitre `. +- |Fix| :class:`inspection.DecisionBoundaryDisplay` now raises an error if the + input data is not 2-dimensional. + :pr:`25077` by :user:`Arturo Amor `. + :mod:`sklearn.kernel_approximation` ................................... @@ -641,6 +645,16 @@ Changelog dtype for `numpy.float32` inputs. :pr:`22665` by :user:`Julien Jerphanion `. +:mod:`sklearn.neural_network` +............................. + +- |Fix| :class:`neural_network.MLPClassifier` and + :class:`neural_network.MLPRegressor` always expose the attributes `best_loss_`, + `validation_scores_`, and `best_validation_score_`. `best_loss_` is set to + `None` when `early_stopping=True`, while `validation_scores_` and + `best_validation_score_` are set to `None` when `early_stopping=False`. + :pr:`24683` by :user:`Guillaume Lemaitre `. + :mod:`sklearn.pipeline` ....................... @@ -696,6 +710,10 @@ Changelog - |Enhancement| :func:`utils.validation.column_or_1d` now accepts a `dtype` parameter to specify `y`'s dtype. :pr:`22629` by `Thomas Fan`_. +- |Enhancement| :func:`utils.extmath.cartesian` now accepts arrays with different + `dtype` and will cast the output to the most permissive `dtype`. + :pr:`25067` by :user:`Guillaume Lemaitre `. + - |Fix| :func:`utils.multiclass.type_of_target` now properly handles sparse matrices. :pr:`14862` by :user:`LĂ©onard Binet `. @@ -705,6 +723,10 @@ Changelog - |Fix| :func:`utils.estimator_checks.check_estimator` now takes into account the `requires_positive_X` tag correctly. :pr:`24667` by `Thomas Fan`_. +- |Fix| :func:`utils.check_array` now supports Pandas Series with `pd.NA` + by raising a better error message or returning a compatible `ndarray`. + :pr:`25080` by `Thomas Fan`_. + - |API| The extra keyword parameters of :func:`utils.extmath.density` are deprecated and will be removed in 1.4. :pr:`24523` by :user:`Mia Bajic `.
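The `utils.extmath.cartesian` entry above can be exercised with a short sketch (assuming scikit-learn 1.2 is installed; the input arrays are arbitrary examples):

    import numpy as np
    from sklearn.utils.extmath import cartesian

    # int32 x float64 inputs: the output dtype is now the most permissive dtype
    # under NumPy type promotion instead of the first array's dtype.
    out = cartesian([np.array([1, 2, 3], dtype=np.int32),
                     np.array([4, 5], dtype=np.float64)])
    print(out.shape)  # (6, 2): every pairing of the two input arrays
    print(out.dtype)  # float64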
diff --git a/examples/release_highlights/plot_release_highlights_1_2_0.py b/examples/release_highlights/plot_release_highlights_1_2_0.py index 32b1108caa920..8165c3bc4eed0 100644 --- a/examples/release_highlights/plot_release_highlights_1_2_0.py +++ b/examples/release_highlights/plot_release_highlights_1_2_0.py @@ -93,6 +93,42 @@ hist_no_interact, X, y, cv=5, n_jobs=2, train_sizes=np.linspace(0.1, 1, 5) ) +# %% +# :class:`~inspection.PartialDependenceDisplay` exposes a new parameter +# `categorical_features` to display partial dependence for categorical features +# using bar plots and heatmaps. +from sklearn.datasets import fetch_openml + +X, y = fetch_openml( + "titanic", version=1, as_frame=True, return_X_y=True, parser="pandas" +) +X = X.select_dtypes(["number", "category"]).drop(columns=["body"]) + +# %% +from sklearn.preprocessing import OrdinalEncoder +from sklearn.pipeline import make_pipeline + +categorical_features = ["pclass", "sex", "embarked"] +model = make_pipeline( + ColumnTransformer( + transformers=[("cat", OrdinalEncoder(), categorical_features)], + remainder="passthrough", + ), + HistGradientBoostingRegressor(random_state=0), +).fit(X, y) + +# %% +from sklearn.inspection import PartialDependenceDisplay + +fig, ax = plt.subplots(figsize=(14, 4), constrained_layout=True) +_ = PartialDependenceDisplay.from_estimator( + model, + X, + features=["age", "sex", ("pclass", "sex")], + categorical_features=categorical_features, + ax=ax, +) + # %% # Faster parser in :func:`~datasets.fetch_openml` # ----------------------------------------------- @@ -100,8 +136,6 @@ # more memory and CPU efficient. In v1.4, the default will change to # `parser="auto"` which will automatically use the `"pandas"` parser for dense # data and `"liac-arff"` for sparse data. 
-from sklearn.datasets import fetch_openml - X, y = fetch_openml( "titanic", version=1, as_frame=True, return_X_y=True, parser="pandas" ) diff --git a/setup.cfg b/setup.cfg index 6976ebb2a3819..081e78c92d480 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,5 +1,5 @@ [options] -packages = find_namespace: +packages = find: [options.packages.find] include = sklearn* diff --git a/setup.py b/setup.py index 27773c8a57faa..10756972dcafd 100755 --- a/setup.py +++ b/setup.py @@ -497,8 +497,24 @@ def configure_extension_modules(): is_pypy = platform.python_implementation() == "PyPy" np_include = numpy.get_include() - default_libraries = ["m"] if os.name == "posix" else [] - default_extra_compile_args = ["-O3"] + + optimization_level = "O2" + if os.name == "posix": + default_extra_compile_args = [f"-{optimization_level}"] + default_libraries = ["m"] + else: + default_extra_compile_args = [f"/{optimization_level}"] + default_libraries = [] + + build_with_debug_symbols = ( + os.environ.get("SKLEARN_BUILD_ENABLE_DEBUG_SYMBOLS", "0") != "0" + ) + if os.name == "posix": + if build_with_debug_symbols: + default_extra_compile_args.append("-g") + else: + # Setting -g0 will strip symbols, reducing the binary size of extensions + default_extra_compile_args.append("-g0") cython_exts = [] for submodule, extensions in extension_config.items(): @@ -608,9 +624,8 @@ def setup_package(): cmdclass=cmdclass, python_requires=python_requires, install_requires=min_deps.tag_to_packages["install"], - package_data={"": ["*.pxd"]}, + package_data={"": ["*.csv", "*.gz", "*.txt", "*.pxd", "*.rst", "*.jpg"]}, zip_safe=False, # the package can run out of an .egg file - include_package_data=True, extras_require={ key: min_deps.tag_to_packages[key] for key in ["examples", "docs", "tests", "benchmark"] diff --git a/sklearn/__init__.py b/sklearn/__init__.py index f9e62d3ad969a..f7cbf20b178e9 100644 --- a/sklearn/__init__.py +++ b/sklearn/__init__.py @@ -39,7 +39,7 @@ # Dev branch marker is: 'X.Y.dev' or 'X.Y.devN' where N is an integer. # 'X.Y.dev0' is the canonical version of 'X.Y.dev' # -__version__ = "1.2.0rc1" +__version__ = "1.2.0" # On OSX, we can get a runtime error due to multiple OpenMP libraries loaded diff --git a/sklearn/decomposition/_nmf.py b/sklearn/decomposition/_nmf.py index 16b6efca955d1..5243946a93e44 100644 --- a/sklearn/decomposition/_nmf.py +++ b/sklearn/decomposition/_nmf.py @@ -890,25 +890,7 @@ def _fit_multiplicative_update( "X": ["array-like", "sparse matrix"], "W": ["array-like", None], "H": ["array-like", None], - "n_components": [Interval(Integral, 1, None, closed="left"), None], - "init": [ - StrOptions({"random", "nndsvd", "nndsvda", "nndsvdar", "custom"}), - None, - ], "update_H": ["boolean"], - "solver": [StrOptions({"mu", "cd"})], - "beta_loss": [ - StrOptions({"frobenius", "kullback-leibler", "itakura-saito"}), - Real, - ], - "tol": [Interval(Real, 0, None, closed="left")], - "max_iter": [Interval(Integral, 1, None, closed="left")], - "alpha_W": [Interval(Real, 0, None, closed="left")], - "alpha_H": [Interval(Real, 0, None, closed="left"), StrOptions({"same"})], - "l1_ratio": [Interval(Real, 0, 1, closed="both")], - "random_state": ["random_state"], - "verbose": ["verbose"], - "shuffle": ["boolean"], } ) def non_negative_factorization( @@ -1107,8 +1089,6 @@ def non_negative_factorization( >>> W, H, n_iter = non_negative_factorization( ... 
X, n_components=2, init='random', random_state=0) """ - X = check_array(X, accept_sparse=("csr", "csc"), dtype=[np.float64, np.float32]) - est = NMF( n_components=n_components, init=init, @@ -1123,6 +1103,9 @@ def non_negative_factorization( verbose=verbose, shuffle=shuffle, ) + est._validate_params() + + X = check_array(X, accept_sparse=("csr", "csc"), dtype=[np.float64, np.float32]) with config_context(assume_finite=True): W, H, n_iter = est._fit_transform(X, W=W, H=H, update_H=update_H) diff --git a/sklearn/ensemble/_gb.py b/sklearn/ensemble/_gb.py index f6af1150203e5..ab123076ee72e 100644 --- a/sklearn/ensemble/_gb.py +++ b/sklearn/ensemble/_gb.py @@ -488,8 +488,11 @@ def fit(self, X, y, sample_weight=None, monitor=None): try: self.init_.fit(X, y, sample_weight=sample_weight) except TypeError as e: - # regular estimator without SW support - raise ValueError(msg) from e + if "unexpected keyword argument 'sample_weight'" in str(e): + # regular estimator without SW support + raise ValueError(msg) from e + else: # regular estimator whose input checking failed + raise except ValueError as e: if ( "pass parameters to specific steps of " diff --git a/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py b/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py index af9225933100c..38f021ec5f82d 100644 --- a/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py +++ b/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py @@ -270,18 +270,21 @@ def _check_categories(self, X): if missing.any(): categories = categories[~missing] + if hasattr(self, "feature_names_in_"): + feature_name = f"'{self.feature_names_in_[f_idx]}'" + else: + feature_name = f"at index {f_idx}" + if categories.size > self.max_bins: raise ValueError( - f"Categorical feature at index {f_idx} is " - "expected to have a " - f"cardinality <= {self.max_bins}" + f"Categorical feature {feature_name} is expected to " + f"have a cardinality <= {self.max_bins}" ) if (categories >= self.max_bins).any(): raise ValueError( - f"Categorical feature at index {f_idx} is " - "expected to be encoded with " - f"values < {self.max_bins}" + f"Categorical feature {feature_name} is expected to " + f"be encoded with values < {self.max_bins}" ) else: categories = None diff --git a/sklearn/ensemble/_hist_gradient_boosting/tests/test_gradient_boosting.py b/sklearn/ensemble/_hist_gradient_boosting/tests/test_gradient_boosting.py index d1a8f56bbd479..25d245b52eaf7 100644 --- a/sklearn/ensemble/_hist_gradient_boosting/tests/test_gradient_boosting.py +++ b/sklearn/ensemble/_hist_gradient_boosting/tests/test_gradient_boosting.py @@ -58,10 +58,6 @@ def _make_dumb_dataset(n_samples): @pytest.mark.parametrize( "params, err_msg", [ - ( - {"interaction_cst": "string"}, - "", - ), ( {"interaction_cst": [0, 1]}, "Interaction constraints must be a sequence of tuples or lists", @@ -1141,20 +1137,32 @@ def test_categorical_spec_no_categories(Est, categorical_features, as_array): @pytest.mark.parametrize( "Est", (HistGradientBoostingClassifier, HistGradientBoostingRegressor) ) -def test_categorical_bad_encoding_errors(Est): +@pytest.mark.parametrize( + "use_pandas, feature_name", [(False, "at index 0"), (True, "'f0'")] +) +def test_categorical_bad_encoding_errors(Est, use_pandas, feature_name): # Test errors when categories are encoded incorrectly gb = Est(categorical_features=[True], max_bins=2) - X = np.array([[0, 1, 2]]).T + if use_pandas: + pd = pytest.importorskip("pandas") + X = pd.DataFrame({"f0": [0, 1, 2]}) + else: + X = 
np.array([[0, 1, 2]]).T y = np.arange(3) - msg = "Categorical feature at index 0 is expected to have a cardinality <= 2" + msg = f"Categorical feature {feature_name} is expected to have a cardinality <= 2" with pytest.raises(ValueError, match=msg): gb.fit(X, y) - X = np.array([[0, 2]]).T + if use_pandas: + X = pd.DataFrame({"f0": [0, 2]}) + else: + X = np.array([[0, 2]]).T y = np.arange(2) - msg = "Categorical feature at index 0 is expected to be encoded with values < 2" + msg = ( + f"Categorical feature {feature_name} is expected to be encoded with values < 2" + ) with pytest.raises(ValueError, match=msg): gb.fit(X, y)
diff --git a/sklearn/ensemble/tests/test_gradient_boosting.py b/sklearn/ensemble/tests/test_gradient_boosting.py index 4c355332b1b81..4e90f5ce54e67 100644 --- a/sklearn/ensemble/tests/test_gradient_boosting.py +++ b/sklearn/ensemble/tests/test_gradient_boosting.py @@ -27,6 +27,7 @@ from sklearn.utils._testing import assert_array_almost_equal from sklearn.utils._testing import assert_array_equal from sklearn.utils._testing import skip_if_32bit +from sklearn.utils._param_validation import InvalidParameterError from sklearn.exceptions import DataConversionWarning from sklearn.exceptions import NotFittedError from sklearn.dummy import DummyClassifier, DummyRegressor @@ -1265,14 +1266,14 @@ def test_gradient_boosting_with_init_pipeline(): # Passing sample_weight to a pipeline raises a ValueError. This test makes # sure we make the distinction between ValueError raised by a pipeline that - # was passed sample_weight, and a ValueError raised by a regular estimator - # whose input checking failed. + # was passed sample_weight, and an InvalidParameterError raised by a regular + # estimator whose input checking failed. invalid_nu = 1.5 err_msg = ( "The 'nu' parameter of NuSVR must be a float in the" f" range (0.0, 1.0]. Got {invalid_nu} instead." ) - with pytest.raises(ValueError, match=re.escape(err_msg)): + with pytest.raises(InvalidParameterError, match=re.escape(err_msg)): # Note that NuSVR properly supports sample_weight init = NuSVR(gamma="auto", nu=invalid_nu) gb = GradientBoostingRegressor(init=init)
diff --git a/sklearn/inspection/_plot/decision_boundary.py b/sklearn/inspection/_plot/decision_boundary.py index 86836a81f7207..22b4590d9bc3c 100644 --- a/sklearn/inspection/_plot/decision_boundary.py +++ b/sklearn/inspection/_plot/decision_boundary.py @@ -6,7 +6,11 @@ from ...utils import check_matplotlib_support from ...utils import _safe_indexing from ...base import is_regressor -from ...utils.validation import check_is_fitted, _is_arraylike_not_scalar +from ...utils.validation import ( + check_is_fitted, + _is_arraylike_not_scalar, + _num_features, +) def _check_boundary_response_method(estimator, response_method): @@ -316,6 +320,12 @@ def from_estimator( f"Got {plot_method} instead." ) + num_features = _num_features(X) + if num_features != 2: + raise ValueError( + f"n_features must be equal to 2. Got {num_features} instead."
+ ) + x0, x1 = _safe_indexing(X, 0, axis=1), _safe_indexing(X, 1, axis=1) x0_min, x0_max = x0.min() - eps, x0.max() + eps diff --git a/sklearn/inspection/_plot/tests/test_boundary_decision_display.py b/sklearn/inspection/_plot/tests/test_boundary_decision_display.py index 8981c9d5a5e83..97b1b98e3db93 100644 --- a/sklearn/inspection/_plot/tests/test_boundary_decision_display.py +++ b/sklearn/inspection/_plot/tests/test_boundary_decision_display.py @@ -38,6 +38,16 @@ def fitted_clf(): return LogisticRegression().fit(X, y) +def test_input_data_dimension(pyplot): + """Check that we raise an error when `X` does not have exactly 2 features.""" + X, y = make_classification(n_samples=10, n_features=4, random_state=0) + + clf = LogisticRegression().fit(X, y) + msg = "n_features must be equal to 2. Got 4 instead." + with pytest.raises(ValueError, match=msg): + DecisionBoundaryDisplay.from_estimator(estimator=clf, X=X) + + def test_check_boundary_response_method_auto(): """Check _check_boundary_response_method behavior with 'auto'.""" diff --git a/sklearn/inspection/_plot/tests/test_plot_partial_dependence.py b/sklearn/inspection/_plot/tests/test_plot_partial_dependence.py index f48c579d04528..329485ba918d6 100644 --- a/sklearn/inspection/_plot/tests/test_plot_partial_dependence.py +++ b/sklearn/inspection/_plot/tests/test_plot_partial_dependence.py @@ -676,7 +676,7 @@ def test_plot_partial_dependence_does_not_override_ylabel( def test_plot_partial_dependence_with_categorical( pyplot, categorical_features, array_type ): - X = [["A", 1, "A"], ["B", 0, "C"], ["C", 2, "B"]] + X = [[1, 1, "A"], [2, 0, "C"], [3, 2, "B"]] column_name = ["col_A", "col_B", "col_C"] X = _convert_container(X, array_type, columns_name=column_name) y = np.array([1.2, 0.5, 0.45]).T diff --git a/sklearn/metrics/_pairwise_distances_reduction/_argkmin.pyx.tp b/sklearn/metrics/_pairwise_distances_reduction/_argkmin.pyx.tp index eec2e2aabdd06..b8afe5c3cd5f8 100644 --- a/sklearn/metrics/_pairwise_distances_reduction/_argkmin.pyx.tp +++ b/sklearn/metrics/_pairwise_distances_reduction/_argkmin.pyx.tp @@ -330,11 +330,8 @@ cdef class EuclideanArgKmin{{name_suffix}}(ArgKmin{{name_suffix}}): metric_kwargs=None, ): if ( - metric_kwargs is not None and - len(metric_kwargs) > 0 and ( - "Y_norm_squared" not in metric_kwargs or - "X_norm_squared" not in metric_kwargs - ) + isinstance(metric_kwargs, dict) and + (metric_kwargs.keys() - {"X_norm_squared", "Y_norm_squared"}) ): warnings.warn( f"Some metric_kwargs have been passed ({metric_kwargs}) but aren't " diff --git a/sklearn/metrics/_pairwise_distances_reduction/_radius_neighbors.pyx.tp b/sklearn/metrics/_pairwise_distances_reduction/_radius_neighbors.pyx.tp index 0fdc3bb50203f..b3f20cac3ea08 100644 --- a/sklearn/metrics/_pairwise_distances_reduction/_radius_neighbors.pyx.tp +++ b/sklearn/metrics/_pairwise_distances_reduction/_radius_neighbors.pyx.tp @@ -336,11 +336,8 @@ cdef class EuclideanRadiusNeighbors{{name_suffix}}(RadiusNeighbors{{name_suffix} metric_kwargs=None, ): if ( - metric_kwargs is not None and - len(metric_kwargs) > 0 and ( - "Y_norm_squared" not in metric_kwargs or - "X_norm_squared" not in metric_kwargs - ) + isinstance(metric_kwargs, dict) and + (metric_kwargs.keys() - {"X_norm_squared", "Y_norm_squared"}) ): warnings.warn( f"Some metric_kwargs have been passed ({metric_kwargs}) but aren't " diff --git a/sklearn/metrics/tests/test_pairwise_distances_reduction.py b/sklearn/metrics/tests/test_pairwise_distances_reduction.py index c334087c65448..4fe8013cd3602 100644 --- 
a/sklearn/metrics/tests/test_pairwise_distances_reduction.py +++ b/sklearn/metrics/tests/test_pairwise_distances_reduction.py @@ -1,5 +1,6 @@ import itertools import re +import warnings from collections import defaultdict import numpy as np @@ -620,19 +621,44 @@ def test_argkmin_factory_method_wrong_usages(): with pytest.raises(ValueError, match="ndarray is not C-contiguous"): ArgKmin.compute(X=np.asfortranarray(X), Y=Y, k=k, metric=metric) + # A UserWarning must be raised in this case. unused_metric_kwargs = {"p": 3} - message = ( - r"Some metric_kwargs have been passed \({'p': 3}\) but aren't usable for this" - r" case \(" - r"EuclideanArgKmin64." - ) + message = r"Some metric_kwargs have been passed \({'p': 3}\) but" with pytest.warns(UserWarning, match=message): ArgKmin.compute( X=X, Y=Y, k=k, metric=metric, metric_kwargs=unused_metric_kwargs ) + # A UserWarning must be raised in this case. + metric_kwargs = { + "p": 3, # unused + "Y_norm_squared": sqeuclidean_row_norms(Y, num_threads=2), + } + + message = r"Some metric_kwargs have been passed \({'p': 3, 'Y_norm_squared'" + + with pytest.warns(UserWarning, match=message): + ArgKmin.compute(X=X, Y=Y, k=k, metric=metric, metric_kwargs=metric_kwargs) + + # No user warning must be raised in this case. + metric_kwargs = { + "X_norm_squared": sqeuclidean_row_norms(X, num_threads=2), + } + with warnings.catch_warnings(): + warnings.simplefilter("error", category=UserWarning) + ArgKmin.compute(X=X, Y=Y, k=k, metric=metric, metric_kwargs=metric_kwargs) + + # No user warning must be raised in this case. + metric_kwargs = { + "X_norm_squared": sqeuclidean_row_norms(X, num_threads=2), + "Y_norm_squared": sqeuclidean_row_norms(Y, num_threads=2), + } + with warnings.catch_warnings(): + warnings.simplefilter("error", category=UserWarning) + ArgKmin.compute(X=X, Y=Y, k=k, metric=metric, metric_kwargs=metric_kwargs) + def test_radius_neighbors_factory_method_wrong_usages(): rng = np.random.RandomState(1) @@ -683,16 +709,48 @@ def test_radius_neighbors_factory_method_wrong_usages(): unused_metric_kwargs = {"p": 3} - message = ( - r"Some metric_kwargs have been passed \({'p': 3}\) but aren't usable for this" - r" case \(EuclideanRadiusNeighbors64" - ) + # A UserWarning must be raised in this case. + message = r"Some metric_kwargs have been passed \({'p': 3}\) but" with pytest.warns(UserWarning, match=message): RadiusNeighbors.compute( X=X, Y=Y, radius=radius, metric=metric, metric_kwargs=unused_metric_kwargs ) + # A UserWarning must be raised in this case. + metric_kwargs = { + "p": 3, # unused + "Y_norm_squared": sqeuclidean_row_norms(Y, num_threads=2), + } + + message = r"Some metric_kwargs have been passed \({'p': 3, 'Y_norm_squared'" + + with pytest.warns(UserWarning, match=message): + RadiusNeighbors.compute( + X=X, Y=Y, radius=radius, metric=metric, metric_kwargs=metric_kwargs + ) + + # No user warning must be raised in this case. + metric_kwargs = { + "X_norm_squared": sqeuclidean_row_norms(X, num_threads=2), + "Y_norm_squared": sqeuclidean_row_norms(Y, num_threads=2), + } + with warnings.catch_warnings(): + warnings.simplefilter("error", category=UserWarning) + RadiusNeighbors.compute( + X=X, Y=Y, radius=radius, metric=metric, metric_kwargs=metric_kwargs + ) + + # No user warning must be raised in this case. 
+ metric_kwargs = { + "X_norm_squared": sqeuclidean_row_norms(X, num_threads=2), + } + with warnings.catch_warnings(): + warnings.simplefilter("error", category=UserWarning) + RadiusNeighbors.compute( + X=X, Y=Y, radius=radius, metric=metric, metric_kwargs=metric_kwargs + ) + @pytest.mark.parametrize( "n_samples_X, n_samples_Y", [(100, 100), (500, 100), (100, 500)] )
diff --git a/sklearn/neural_network/_multilayer_perceptron.py b/sklearn/neural_network/_multilayer_perceptron.py index 8e0dfdbbcade2..082c0200871cd 100644 --- a/sklearn/neural_network/_multilayer_perceptron.py +++ b/sklearn/neural_network/_multilayer_perceptron.py @@ -391,8 +391,11 @@ def _initialize(self, y, layer_units, dtype): if self.early_stopping: self.validation_scores_ = [] self.best_validation_score_ = -np.inf + self.best_loss_ = None else: self.best_loss_ = np.inf + self.validation_scores_ = None + self.best_validation_score_ = None def _init_coef(self, fan_in, fan_out, dtype): # Use the initialization method recommended by @@ -686,6 +689,7 @@ def _fit_stochastic( # restore best weights self.coefs_ = self._best_coefs self.intercepts_ = self._best_intercepts + self.validation_scores_ = self.validation_scores_ def _update_no_improvement_count(self, early_stopping, X_val, y_val): if early_stopping: @@ -919,12 +923,24 @@ class MLPClassifier(ClassifierMixin, BaseMultilayerPerceptron): loss_ : float The current loss computed with the loss function. - best_loss_ : float + best_loss_ : float or None The minimum loss reached by the solver throughout fitting. + If `early_stopping=True`, this attribute is set to `None`. Refer to + the `best_validation_score_` fitted attribute instead. loss_curve_ : list of shape (`n_iter_`,) The ith element in the list represents the loss at the ith iteration. + validation_scores_ : list of shape (`n_iter_`,) or None + The score at each iteration on a held-out validation set. The score + reported is the accuracy score. Only available if `early_stopping=True`, + otherwise the attribute is set to `None`. + + best_validation_score_ : float or None + The best validation score (i.e. accuracy score) that triggered the + early stopping. Only available if `early_stopping=True`, otherwise the + attribute is set to `None`. + t_ : int The number of training samples seen by the solver during fitting. @@ -1388,11 +1404,23 @@ class MLPRegressor(RegressorMixin, BaseMultilayerPerceptron): best_loss_ : float The minimum loss reached by the solver throughout fitting. + If `early_stopping=True`, this attribute is set to `None`. Refer to + the `best_validation_score_` fitted attribute instead. loss_curve_ : list of shape (`n_iter_`,) Loss value evaluated at the end of each training step. The ith element in the list represents the loss at the ith iteration. + validation_scores_ : list of shape (`n_iter_`,) or None + The score at each iteration on a held-out validation set. The score + reported is the R2 score. Only available if `early_stopping=True`, + otherwise the attribute is set to `None`. + + best_validation_score_ : float or None + The best validation score (i.e. R2 score) that triggered the + early stopping. Only available if `early_stopping=True`, otherwise the + attribute is set to `None`. + t_ : int The number of training samples seen by the solver during fitting.
Mathematically equals `n_iters * X.shape[0]`, it means diff --git a/sklearn/neural_network/tests/test_mlp.py b/sklearn/neural_network/tests/test_mlp.py index 4dda507a90381..94612130419f7 100644 --- a/sklearn/neural_network/tests/test_mlp.py +++ b/sklearn/neural_network/tests/test_mlp.py @@ -668,20 +668,36 @@ def test_verbose_sgd(): assert "Iteration" in output.getvalue() -def test_early_stopping(): +@pytest.mark.parametrize("MLPEstimator", [MLPClassifier, MLPRegressor]) +def test_early_stopping(MLPEstimator): X = X_digits_binary[:100] y = y_digits_binary[:100] tol = 0.2 - clf = MLPClassifier(tol=tol, max_iter=3000, solver="sgd", early_stopping=True) - clf.fit(X, y) - assert clf.max_iter > clf.n_iter_ + mlp_estimator = MLPEstimator( + tol=tol, max_iter=3000, solver="sgd", early_stopping=True + ) + mlp_estimator.fit(X, y) + assert mlp_estimator.max_iter > mlp_estimator.n_iter_ + + assert mlp_estimator.best_loss_ is None + assert isinstance(mlp_estimator.validation_scores_, list) - valid_scores = clf.validation_scores_ - best_valid_score = clf.best_validation_score_ + valid_scores = mlp_estimator.validation_scores_ + best_valid_score = mlp_estimator.best_validation_score_ assert max(valid_scores) == best_valid_score assert best_valid_score + tol > valid_scores[-2] assert best_valid_score + tol > valid_scores[-1] + # check that the attributes `validation_scores_` and `best_validation_score_` + # are set to None when `early_stopping=False` + mlp_estimator = MLPEstimator( + tol=tol, max_iter=3000, solver="sgd", early_stopping=False + ) + mlp_estimator.fit(X, y) + assert mlp_estimator.validation_scores_ is None + assert mlp_estimator.best_validation_score_ is None + assert mlp_estimator.best_loss_ is not None + def test_adaptive_learning_rate(): X = [[3, 2], [1, 6]] @@ -876,3 +892,16 @@ def test_mlp_loading_from_joblib_partial_fit(tmp_path): # finetuned model learned the new target predicted_value = load_estimator.predict(fine_tune_features) assert_allclose(predicted_value, fine_tune_target, rtol=1e-4) + + +@pytest.mark.parametrize("MLPEstimator", [MLPClassifier, MLPRegressor]) +def test_mlp_warm_start_with_early_stopping(MLPEstimator): + """Check that early stopping works with warm start.""" + mlp = MLPEstimator( + max_iter=10, random_state=0, warm_start=True, early_stopping=True + ) + mlp.fit(X_iris, y_iris) + n_validation_scores = len(mlp.validation_scores_) + mlp.set_params(max_iter=20) + mlp.fit(X_iris, y_iris) + assert len(mlp.validation_scores_) > n_validation_scores diff --git a/sklearn/tests/test_public_functions.py b/sklearn/tests/test_public_functions.py index 85cd0638a5ef3..9f500cadd959b 100644 --- a/sklearn/tests/test_public_functions.py +++ b/sklearn/tests/test_public_functions.py @@ -6,20 +6,10 @@ from sklearn.utils._param_validation import generate_invalid_param_val from sklearn.utils._param_validation import generate_valid_param from sklearn.utils._param_validation import make_constraint +from sklearn.utils._param_validation import InvalidParameterError -PARAM_VALIDATION_FUNCTION_LIST = [ - "sklearn.cluster.kmeans_plusplus", - "sklearn.svm.l1_min_c", - "sklearn.metrics.accuracy_score", -] - - -@pytest.mark.parametrize("func_module", PARAM_VALIDATION_FUNCTION_LIST) -def test_function_param_validation(func_module): - """Check that an informative error is raised when the value of a parameter does not - have an appropriate type or value. 
- """ +def _get_func_info(func_module): module_name, func_name = func_module.rsplit(".", 1) module = import_module(module_name) func = getattr(module, func_name) @@ -30,12 +20,23 @@ def test_function_param_validation(func_module): for p in func_sig.parameters.values() if p.kind not in (p.VAR_POSITIONAL, p.VAR_KEYWORD) ] - parameter_constraints = getattr(func, "_skl_parameter_constraints") - # generate valid values for the required parameters + # The parameters `*args` and `**kwargs` are ignored since we cannot generate + # constraints. required_params = [ p.name for p in func_sig.parameters.values() if p.default is p.empty ] + + return func, func_name, func_params, required_params + + +def _check_function_param_validation( + func, func_name, func_params, required_params, parameter_constraints +): + """Check that an informative error is raised when the value of a parameter does not + have an appropriate type or value. + """ + # generate valid values for the required parameters valid_required_params = {} for param_name in required_params: if parameter_constraints[param_name] == "no_validation": @@ -72,7 +73,7 @@ def test_function_param_validation(func_module): ) # First, check that the error is raised if param doesn't match any valid type. - with pytest.raises(ValueError, match=match): + with pytest.raises(InvalidParameterError, match=match): func(**{**valid_required_params, param_name: param_with_bad_type}) # Then, for constraints that are more than a type constraint, check that the @@ -86,5 +87,59 @@ def test_function_param_validation(func_module): except NotImplementedError: continue - with pytest.raises(ValueError, match=match): + with pytest.raises(InvalidParameterError, match=match): func(**{**valid_required_params, param_name: bad_value}) + + +PARAM_VALIDATION_FUNCTION_LIST = [ + "sklearn.cluster.kmeans_plusplus", + "sklearn.metrics.accuracy_score", + "sklearn.svm.l1_min_c", +] + + +@pytest.mark.parametrize("func_module", PARAM_VALIDATION_FUNCTION_LIST) +def test_function_param_validation(func_module): + """Check param validation for public functions that are not wrappers around + estimators. + """ + func, func_name, func_params, required_params = _get_func_info(func_module) + + parameter_constraints = getattr(func, "_skl_parameter_constraints") + + _check_function_param_validation( + func, func_name, func_params, required_params, parameter_constraints + ) + + +PARAM_VALIDATION_CLASS_WRAPPER_LIST = [ + ("sklearn.decomposition.non_negative_factorization", "sklearn.decomposition.NMF"), +] + + +@pytest.mark.parametrize( + "func_module, class_module", PARAM_VALIDATION_CLASS_WRAPPER_LIST +) +def test_class_wrapper_param_validation(func_module, class_module): + """Check param validation for public functions that are wrappers around + estimators. 
+ """ + func, func_name, func_params, required_params = _get_func_info(func_module) + + module_name, class_name = class_module.rsplit(".", 1) + module = import_module(module_name) + klass = getattr(module, class_name) + + parameter_constraints_func = getattr(func, "_skl_parameter_constraints") + parameter_constraints_class = getattr(klass, "_parameter_constraints") + parameter_constraints = { + **parameter_constraints_class, + **parameter_constraints_func, + } + parameter_constraints = { + k: v for k, v in parameter_constraints.items() if k in func_params + } + + _check_function_param_validation( + func, func_name, func_params, required_params, parameter_constraints + ) diff --git a/sklearn/utils/_param_validation.py b/sklearn/utils/_param_validation.py index 797063a31dd96..9fa51d465d0a7 100644 --- a/sklearn/utils/_param_validation.py +++ b/sklearn/utils/_param_validation.py @@ -7,6 +7,7 @@ from numbers import Integral from numbers import Real import operator +import re import warnings import numpy as np @@ -16,6 +17,14 @@ from .validation import _is_arraylike_not_scalar +class InvalidParameterError(ValueError, TypeError): + """Custom exception to be raised when the parameter of a class/method/function + does not have a valid type or value. + """ + + # Inherits from ValueError and TypeError to keep backward compatibility. + + def validate_parameter_constraints(parameter_constraints, params, caller_name): """Validate types and values of given parameters. @@ -50,13 +59,6 @@ def validate_parameter_constraints(parameter_constraints, params, caller_name): caller_name : str The name of the estimator or function or method that called this function. """ - if len(set(parameter_constraints) - set(params)) != 0: - raise ValueError( - f"The parameter constraints {list(parameter_constraints)}" - " contain unexpected parameters" - f" {set(parameter_constraints) - set(params)}" - ) - for param_name, param_val in params.items(): # We allow parameters to not have a constraint so that third party estimators # can inherit from sklearn estimators without having to necessarily use the @@ -92,7 +94,7 @@ def validate_parameter_constraints(parameter_constraints, params, caller_name): f" {constraints[-1]}" ) - raise ValueError( + raise InvalidParameterError( f"The {param_name!r} parameter of {caller_name} must be" f" {constraints_str}. Got {param_val!r} instead." ) @@ -185,7 +187,20 @@ def wrapper(*args, **kwargs): validate_parameter_constraints( parameter_constraints, params, caller_name=func.__qualname__ ) - return func(*args, **kwargs) + + try: + return func(*args, **kwargs) + except InvalidParameterError as e: + # When the function is just a wrapper around an estimator, we allow + # the function to delegate validation to the estimator, but we replace + # the name of the estimator by the name of the function in the error + # message to avoid confusion. + msg = re.sub( + r"parameter of \w+ must be", + f"parameter of {func.__qualname__} must be", + str(e), + ) + raise InvalidParameterError(msg) from e return wrapper diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py index b080591714b37..fba40e4a266c8 100644 --- a/sklearn/utils/estimator_checks.py +++ b/sklearn/utils/estimator_checks.py @@ -61,6 +61,7 @@ from ..utils.validation import check_is_fitted from ..utils._param_validation import make_constraint from ..utils._param_validation import generate_invalid_param_val +from ..utils._param_validation import InvalidParameterError from . 
import shuffle from ._tags import ( @@ -4082,7 +4083,7 @@ def check_param_validation(name, estimator_orig): # the method is not accessible with the current set of parameters continue - with raises(ValueError, match=match, err_msg=err_msg): + with raises(InvalidParameterError, match=match, err_msg=err_msg): if any( isinstance(X_type, str) and X_type.endswith("labels") for X_type in _safe_tags(estimator, key="X_types") @@ -4110,7 +4111,7 @@ def check_param_validation(name, estimator_orig): # the method is not accessible with the current set of parameters continue - with raises(ValueError, match=match, err_msg=err_msg): + with raises(InvalidParameterError, match=match, err_msg=err_msg): if any( X_type.endswith("labels") for X_type in _safe_tags(estimator, key="X_types") diff --git a/sklearn/utils/extmath.py b/sklearn/utils/extmath.py index 85063d5888cc0..02e65704274c1 100644 --- a/sklearn/utils/extmath.py +++ b/sklearn/utils/extmath.py @@ -718,6 +718,12 @@ def cartesian(arrays, out=None): ------- out : ndarray of shape (M, len(arrays)) Array containing the cartesian products formed of input arrays. + If not provided, the `dtype` of the output array is set to the most + permissive `dtype` of the input arrays, according to NumPy type + promotion. + + .. versionadded:: 1.2 + Add support for arrays of different types. Notes ----- @@ -743,12 +749,12 @@ def cartesian(arrays, out=None): """ arrays = [np.asarray(x) for x in arrays] shape = (len(x) for x in arrays) - dtype = arrays[0].dtype ix = np.indices(shape) ix = ix.reshape(len(arrays), -1).T if out is None: + dtype = np.result_type(*arrays) # find the most permissive dtype out = np.empty_like(ix, dtype=dtype) for n, arr in enumerate(arrays): diff --git a/sklearn/utils/tests/test_extmath.py b/sklearn/utils/tests/test_extmath.py index d3626a1efbe0b..84285356c0897 100644 --- a/sklearn/utils/tests/test_extmath.py +++ b/sklearn/utils/tests/test_extmath.py @@ -637,6 +637,30 @@ def test_cartesian(): assert_array_equal(x[:, np.newaxis], cartesian((x,))) +@pytest.mark.parametrize( + "arrays, output_dtype", + [ + ( + [np.array([1, 2, 3], dtype=np.int32), np.array([4, 5], dtype=np.int64)], + np.dtype(np.int64), + ), + ( + [np.array([1, 2, 3], dtype=np.int32), np.array([4, 5], dtype=np.float64)], + np.dtype(np.float64), + ), + ( + [np.array([1, 2, 3], dtype=np.int32), np.array(["x", "y"], dtype=object)], + np.dtype(object), + ), + ], +) +def test_cartesian_mix_types(arrays, output_dtype): + """Check that the cartesian product works with mixed types.""" + output = cartesian(arrays) + + assert output.dtype == output_dtype + + def test_logistic_sigmoid(): # Check correctness and robustness of logistic sigmoid implementation def naive_log_logistic(x): diff --git a/sklearn/utils/tests/test_param_validation.py b/sklearn/utils/tests/test_param_validation.py index fd73797582631..85cd06d0f38b8 100644 --- a/sklearn/utils/tests/test_param_validation.py +++ b/sklearn/utils/tests/test_param_validation.py @@ -28,6 +28,7 @@ from sklearn.utils._param_validation import generate_invalid_param_val from sklearn.utils._param_validation import generate_valid_param from sklearn.utils._param_validation import validate_params +from sklearn.utils._param_validation import InvalidParameterError # Some helpers for the tests @@ -433,40 +434,38 @@ def test_make_constraint_unknown(): def test_validate_params(): """Check that validate_params works no matter how the arguments are passed""" - with pytest.raises(ValueError, match="The 'a' parameter of _func must be"): + with pytest.raises( + 
InvalidParameterError, match="The 'a' parameter of _func must be" + ): _func("wrong", c=1) - with pytest.raises(ValueError, match="The 'b' parameter of _func must be"): + with pytest.raises( + InvalidParameterError, match="The 'b' parameter of _func must be" + ): _func(*[1, "wrong"], c=1) - with pytest.raises(ValueError, match="The 'c' parameter of _func must be"): + with pytest.raises( + InvalidParameterError, match="The 'c' parameter of _func must be" + ): _func(1, **{"c": "wrong"}) - with pytest.raises(ValueError, match="The 'd' parameter of _func must be"): + with pytest.raises( + InvalidParameterError, match="The 'd' parameter of _func must be" + ): _func(1, c=1, d="wrong") # check in the presence of extra positional and keyword args - with pytest.raises(ValueError, match="The 'b' parameter of _func must be"): + with pytest.raises( + InvalidParameterError, match="The 'b' parameter of _func must be" + ): _func(0, *["wrong", 2, 3], c=4, **{"e": 5}) - with pytest.raises(ValueError, match="The 'c' parameter of _func must be"): + with pytest.raises( + InvalidParameterError, match="The 'c' parameter of _func must be" + ): _func(0, *[1, 2, 3], c="four", **{"e": 5}) -def test_validate_params_match_error(): - """Check that an informative error is raised when there are constraints - that have no matching function paramaters - """ - - @validate_params({"a": [int], "c": [int]}) - def func(a, b): - pass - - match = r"The parameter constraints .* contain unexpected parameters {'c'}" - with pytest.raises(ValueError, match=match): - func(1, 2) - - def test_validate_params_missing_params(): """Check that no error is raised when there are parameters without constraints @@ -488,19 +487,24 @@ def test_decorate_validated_function(): # outer decorator does not interfere with validation with pytest.warns(FutureWarning, match="Function _func is deprecated"): - with pytest.raises(ValueError, match=r"The 'c' parameter of _func must be"): + with pytest.raises( + InvalidParameterError, match=r"The 'c' parameter of _func must be" + ): decorated_function(1, 2, c="wrong") def test_validate_params_method(): """Check that validate_params works with methods""" - with pytest.raises(ValueError, match="The 'a' parameter of _Class._method must be"): + with pytest.raises( + InvalidParameterError, match="The 'a' parameter of _Class._method must be" + ): _Class()._method("wrong") # validated method can be decorated with pytest.warns(FutureWarning, match="Function _deprecated_method is deprecated"): with pytest.raises( - ValueError, match="The 'a' parameter of _Class._deprecated_method must be" + InvalidParameterError, + match="The 'a' parameter of _Class._deprecated_method must be", ): _Class()._deprecated_method("wrong") @@ -510,7 +514,9 @@ def test_validate_params_estimator(): # no validation in init est = _Estimator("wrong") - with pytest.raises(ValueError, match="The 'a' parameter of _Estimator must be"): + with pytest.raises( + InvalidParameterError, match="The 'a' parameter of _Estimator must be" + ): est.fit() @@ -531,7 +537,9 @@ def f(param): f({"a": 1, "b": 2, "c": 3}) f([1, 2, 3]) - with pytest.raises(ValueError, match="The 'param' parameter") as exc_info: + with pytest.raises( + InvalidParameterError, match="The 'param' parameter" + ) as exc_info: f(param="bad") # the list option is not exposed in the error message @@ -551,7 +559,9 @@ def f(param): f("auto") f("warn") - with pytest.raises(ValueError, match="The 'param' parameter") as exc_info: + with pytest.raises( + InvalidParameterError, match="The 'param'
parameter" + ) as exc_info: f(param="bad") # the "warn" option is not exposed in the error message @@ -596,7 +606,7 @@ def f(param1=None, param2=None): pass # param1 is validated - with pytest.raises(ValueError, match="The 'param1' parameter"): + with pytest.raises(InvalidParameterError, match="The 'param1' parameter"): f(param1="wrong") # param2 is not validated: any type is valid. @@ -633,3 +643,22 @@ def test_cv_objects(): assert constraint.is_satisfied_by([([1, 2], [3, 4]), ([3, 4], [1, 2])]) assert constraint.is_satisfied_by(None) assert not constraint.is_satisfied_by("not a CV object") + + +def test_third_party_estimator(): + """Check that the validation from a scikit-learn estimator inherited by a third + party estimator does not impose a match between the dict of constraints and the + parameters of the estimator. + """ + + class ThirdPartyEstimator(_Estimator): + def __init__(self, b): + self.b = b + super().__init__(a=0) + + def fit(self, X=None, y=None): + super().fit(X, y) + + # does not raise, even though "b" is not in the constraints dict and "a" is not + # a parameter of the estimator. + ThirdPartyEstimator(b=0).fit() diff --git a/sklearn/utils/tests/test_validation.py b/sklearn/utils/tests/test_validation.py index 30e37c7330ecb..538802dd2f8a8 100644 --- a/sklearn/utils/tests/test_validation.py +++ b/sklearn/utils/tests/test_validation.py @@ -447,6 +447,27 @@ def test_check_array_pandas_na_support(pd_dtype, dtype, expected_dtype): check_array(X, force_all_finite=True) +def test_check_array_panadas_na_support_series(): + """Check check_array is correct with pd.NA in a series.""" + pd = pytest.importorskip("pandas") + + X_int64 = pd.Series([1, 2, pd.NA], dtype="Int64") + + msg = "Input contains NaN" + with pytest.raises(ValueError, match=msg): + check_array(X_int64, force_all_finite=True, ensure_2d=False) + + X_out = check_array(X_int64, force_all_finite=False, ensure_2d=False) + assert_allclose(X_out, [1, 2, np.nan]) + assert X_out.dtype == np.float64 + + X_out = check_array( + X_int64, force_all_finite=False, ensure_2d=False, dtype=np.float32 + ) + assert_allclose(X_out, [1, 2, np.nan]) + assert X_out.dtype == np.float32 + + def test_check_array_pandas_dtype_casting(): # test that data-frames with homogeneous dtype are not upcast pd = pytest.importorskip("pandas") diff --git a/sklearn/utils/validation.py b/sklearn/utils/validation.py index 7de0fe200607b..0b5a75f8ed2bb 100644 --- a/sklearn/utils/validation.py +++ b/sklearn/utils/validation.py @@ -777,6 +777,13 @@ def check_array( if all(isinstance(dtype_iter, np.dtype) for dtype_iter in dtypes_orig): dtype_orig = np.result_type(*dtypes_orig) + elif hasattr(array, "iloc") and hasattr(array, "dtype"): + # array is a pandas series + pandas_requires_conversion = _pandas_dtype_needs_early_conversion(array.dtype) + if pandas_requires_conversion: + # Set to None, to convert to a np.dtype that works with array.dtype + dtype_orig = None + if dtype_numeric: if dtype_orig is not None and dtype_orig.kind == "O": # if input is object, convert to float.