diff --git a/doc/modules/classes.rst b/doc/modules/classes.rst
index caae838f68f36..d55becb0c512a 100644
--- a/doc/modules/classes.rst
+++ b/doc/modules/classes.rst
@@ -1126,6 +1126,7 @@ See the :ref:`visualizations` section of the user guide for further details.
   metrics.ConfusionMatrixDisplay
   metrics.DetCurveDisplay
   metrics.PrecisionRecallDisplay
+   metrics.PredictionErrorDisplay
   metrics.RocCurveDisplay
   calibration.CalibrationDisplay
diff --git a/doc/modules/model_evaluation.rst b/doc/modules/model_evaluation.rst
index 89c4ad3ec2cb8..1788fc806ab53 100644
--- a/doc/modules/model_evaluation.rst
+++ b/doc/modules/model_evaluation.rst
@@ -2711,6 +2711,80 @@ Here are some usage examples of the :func:`d2_absolute_error_score` function::
  >>> d2_absolute_error_score(y_true, y_pred)
  0.0

+.. _visualization_regression_evaluation:
+
+Visual evaluation of regression models
+--------------------------------------
+
+Among methods to assess the quality of regression models, scikit-learn provides
+the :class:`~sklearn.metrics.PredictionErrorDisplay` class. It allows one to
+visually inspect the prediction errors of a model in two different manners.
+
+.. image:: ../auto_examples/model_selection/images/sphx_glr_plot_cv_predict_001.png
+   :target: ../auto_examples/model_selection/plot_cv_predict.html
+   :scale: 75
+   :align: center
+
+The plot on the left shows the actual values vs. the predicted values. For a
+noise-free regression task aiming to predict the (conditional) expectation of
+`y`, a perfect regression model would display data points on the diagonal
+defined by predicted equal to actual values. The further away from this optimal
+line, the larger the error of the model. In a more realistic setting with
+irreducible noise, that is, when not all the variations of `y` can be explained
+by features in `X`, the best model would lead to a cloud of points densely
+arranged around the diagonal.
+
+Note that the above only holds when the predicted value is the expected value
+of `y` given `X`. This is typically the case for regression models that
+minimize the mean squared error objective function or more generally the
+:ref:`mean Tweedie deviance <mean_tweedie_deviance>` for any value of its
+"power" parameter.
+
+When plotting the predictions of an estimator that predicts a quantile
+of `y` given `X`, e.g. :class:`~sklearn.linear_model.QuantileRegressor`
+or any other model minimizing the :ref:`pinball loss <pinball_loss>`, a
+fraction of the points is expected to lie above or below the diagonal,
+depending on the estimated quantile level.
+
+All in all, while intuitive to read, this plot does not really inform us about
+what to do to obtain a better model.
+
+The right-hand side plot shows the residuals (i.e. the difference between the
+actual and the predicted values) vs. the predicted values.
+
+This plot makes it easier to visualize whether the residuals follow a
+`homoscedastic or heteroscedastic
+<https://en.wikipedia.org/wiki/Homoscedasticity_and_heteroscedasticity>`_
+distribution.
+
+In particular, if the true distribution of `y|X` is Poisson or Gamma, it is
+expected that the variance of the residuals of the optimal model grows with
+the predicted value of `E[y|X]` (either linearly for Poisson or quadratically
+for Gamma).
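+For illustration, the sketch below (an editorial aside, not part of the linked
+example; the synthetic data and variable names are arbitrary) shows this effect
+for a Poisson-distributed target::
+
+    import numpy as np
+    from sklearn.linear_model import PoissonRegressor
+    from sklearn.metrics import PredictionErrorDisplay
+
+    rng = np.random.RandomState(0)
+    X = rng.uniform(0, 3, size=(1_000, 1))
+    # y|X is Poisson, so Var[y|X] = E[y|X] grows with the conditional mean
+    y = rng.poisson(lam=np.exp(X.ravel()))
+
+    model = PoissonRegressor().fit(X, y)
+    # the spread of the residuals widens as the predicted value increases
+    PredictionErrorDisplay.from_estimator(model, X, y, kind="residual_vs_predicted")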
+
+When fitting a linear least squares regression model (see
+:class:`~sklearn.linear_model.LinearRegression` and
+:class:`~sklearn.linear_model.Ridge`), we can use this plot to check
+if some of the `model assumptions
+<https://en.wikipedia.org/wiki/Ordinary_least_squares#Assumptions>`_
+are met, in particular that the residuals should be uncorrelated, their
+expected value should be null and their variance should be constant
+(homoscedasticity).
+
+If this is not the case, and in particular if the residual plot shows a
+banana-shaped structure, this is a hint that the model is likely mis-specified
+and that non-linear feature engineering or switching to a non-linear regression
+model might be useful.
+
+Refer to the example below to see a model evaluation that makes use of this
+display.
+
+.. topic:: Example:
+
+  * See :ref:`sphx_glr_auto_examples_compose_plot_transformed_target.py` for
+    an example on how to use :class:`~sklearn.metrics.PredictionErrorDisplay`
+    to visualize the prediction quality improvement of a regression model
+    obtained by transforming the target before learning.

.. _clustering_metrics:

diff --git a/doc/visualizations.rst b/doc/visualizations.rst
index 7755b6f74d1f9..f692fd8efd1df 100644
--- a/doc/visualizations.rst
+++ b/doc/visualizations.rst
@@ -86,5 +86,6 @@ Display Objects
   metrics.ConfusionMatrixDisplay
   metrics.DetCurveDisplay
   metrics.PrecisionRecallDisplay
+   metrics.PredictionErrorDisplay
   metrics.RocCurveDisplay
   model_selection.LearningCurveDisplay
diff --git a/doc/whats_new/v1.2.rst b/doc/whats_new/v1.2.rst
index 6003fa7f20407..759bbf411fe19 100644
--- a/doc/whats_new/v1.2.rst
+++ b/doc/whats_new/v1.2.rst
@@ -504,6 +504,13 @@ Changelog
  of a binary classification problem.
  :pr:`22518` by :user:`Arturo Amor <ArturoAmorQ>`.

+- |Feature| Add :class:`metrics.PredictionErrorDisplay` to plot residuals vs
+  predicted and actual vs predicted to qualitatively assess the behavior of a
+  regressor. The display can be created with the class methods
+  :func:`metrics.PredictionErrorDisplay.from_estimator` and
+  :func:`metrics.PredictionErrorDisplay.from_predictions`. :pr:`18020` by
+  :user:`Guillaume Lemaitre <glemaitre>`.
+
 - |Feature| :func:`metrics.roc_auc_score` now supports micro-averaging
   (`average="micro"`) for the One-vs-Rest multiclass case (`multi_class="ovr"`).
   :pr:`24338` by :user:`Arturo Amor <ArturoAmorQ>`.
diff --git a/examples/compose/plot_transformed_target.py b/examples/compose/plot_transformed_target.py
index 2454affb349cf..7e45d8b6c1c0f 100644
--- a/examples/compose/plot_transformed_target.py
+++ b/examples/compose/plot_transformed_target.py
@@ -15,20 +15,12 @@
# Author: Guillaume Lemaitre
# License: BSD 3 clause

-import numpy as np
-import matplotlib.pyplot as plt
-
-from sklearn.datasets import make_regression
-from sklearn.model_selection import train_test_split
-from sklearn.linear_model import RidgeCV
-from sklearn.compose import TransformedTargetRegressor
-from sklearn.metrics import median_absolute_error, r2_score
+print(__doc__)

# %%
# Synthetic example
-##############################################################################
-
-# %%
+###################
+#
# A synthetic random regression dataset is generated. The targets ``y`` are
# modified by:
#
#   1. translating all targets such that all entries are
#      non-negative (by adding the absolute value of the lowest ``y``) and
#   2. applying an exponential function to obtain non-linear
#      targets which cannot be fitted using a simple linear model.
#
# Therefore, a logarithmic (`np.log1p`) and an exponential function
# (`np.expm1`) will be used to transform the targets before training a linear
# regression model and using it for prediction.
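+#
+# (As an illustrative aside: ``np.expm1`` exactly inverts ``np.log1p``, so
+# the transformation can be undone without loss; for instance,
+# ``np.allclose(np.expm1(np.log1p([0.5, 10.0, 1e4])), [0.5, 10.0, 1e4])``
+# evaluates to ``True``.)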
+import numpy as np +from sklearn.datasets import make_regression -X, y = make_regression(n_samples=10000, noise=100, random_state=0) +X, y = make_regression(n_samples=10_000, noise=100, random_state=0) y = np.expm1((y + abs(y.min())) / 200) y_trans = np.log1p(y) # %% # Below we plot the probability density functions of the target # before and after applying the logarithmic functions. +import matplotlib.pyplot as plt +from sklearn.model_selection import train_test_split f, (ax0, ax1) = plt.subplots(1, 2) @@ -62,8 +58,8 @@ ax1.set_xlabel("Target") ax1.set_title("Transformed target distribution") -f.suptitle("Synthetic data", y=0.06, x=0.53) -f.tight_layout(rect=[0.05, 0.05, 0.95, 0.95]) +f.suptitle("Synthetic data", y=1.05) +plt.tight_layout() X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0) @@ -72,168 +68,165 @@ # non-linearity, the model trained will not be precise during # prediction. Subsequently, a logarithmic function is used to linearize the # targets, allowing better prediction even with a similar linear model as -# reported by the median absolute error (MAE). +# reported by the median absolute error (MedAE). +from sklearn.metrics import median_absolute_error, r2_score + + +def compute_score(y_true, y_pred): + return { + "R2": f"{r2_score(y_true, y_pred):.3f}", + "MedAE": f"{median_absolute_error(y_true, y_pred):.3f}", + } + + +# %% +from sklearn.compose import TransformedTargetRegressor +from sklearn.linear_model import RidgeCV +from sklearn.metrics import PredictionErrorDisplay f, (ax0, ax1) = plt.subplots(1, 2, sharey=True) -# Use linear model -regr = RidgeCV() -regr.fit(X_train, y_train) -y_pred = regr.predict(X_test) -# Plot results -ax0.scatter(y_test, y_pred) -ax0.plot([0, 2000], [0, 2000], "--k") -ax0.set_ylabel("Target predicted") -ax0.set_xlabel("True Target") -ax0.set_title("Ridge regression \n without target transformation") -ax0.text( - 100, - 1750, - r"$R^2$=%.2f, MAE=%.2f" - % (r2_score(y_test, y_pred), median_absolute_error(y_test, y_pred)), -) -ax0.set_xlim([0, 2000]) -ax0.set_ylim([0, 2000]) -# Transform targets and use same linear model -regr_trans = TransformedTargetRegressor( + +ridge_cv = RidgeCV().fit(X_train, y_train) +y_pred_ridge = ridge_cv.predict(X_test) + +ridge_cv_with_trans_target = TransformedTargetRegressor( regressor=RidgeCV(), func=np.log1p, inverse_func=np.expm1 +).fit(X_train, y_train) +y_pred_ridge_with_trans_target = ridge_cv_with_trans_target.predict(X_test) + +PredictionErrorDisplay.from_predictions( + y_test, + y_pred_ridge, + kind="actual_vs_predicted", + ax=ax0, + scatter_kwargs={"alpha": 0.5}, ) -regr_trans.fit(X_train, y_train) -y_pred = regr_trans.predict(X_test) - -ax1.scatter(y_test, y_pred) -ax1.plot([0, 2000], [0, 2000], "--k") -ax1.set_ylabel("Target predicted") -ax1.set_xlabel("True Target") -ax1.set_title("Ridge regression \n with target transformation") -ax1.text( - 100, - 1750, - r"$R^2$=%.2f, MAE=%.2f" - % (r2_score(y_test, y_pred), median_absolute_error(y_test, y_pred)), +PredictionErrorDisplay.from_predictions( + y_test, + y_pred_ridge_with_trans_target, + kind="actual_vs_predicted", + ax=ax1, + scatter_kwargs={"alpha": 0.5}, ) -ax1.set_xlim([0, 2000]) -ax1.set_ylim([0, 2000]) -f.suptitle("Synthetic data", y=0.035) -f.tight_layout(rect=[0.05, 0.05, 0.95, 0.95]) +# Add the score in the legend of each axis +for ax, y_pred in zip([ax0, ax1], [y_pred_ridge, y_pred_ridge_with_trans_target]): + for name, score in compute_score(y_test, y_pred).items(): + ax.plot([], [], " ", label=f"{name}={score}") + 
ax.legend(loc="upper left") + +ax0.set_title("Ridge regression \n without target transformation") +ax1.set_title("Ridge regression \n with target transformation") +f.suptitle("Synthetic data", y=1.05) +plt.tight_layout() # %% # Real-world data set -############################################################################### +##################### # # In a similar manner, the Ames housing data set is used to show the impact # of transforming the targets before learning a model. In this example, the # target to be predicted is the selling price of each house. - from sklearn.datasets import fetch_openml -from sklearn.preprocessing import QuantileTransformer, quantile_transform +from sklearn.preprocessing import quantile_transform ames = fetch_openml(name="house_prices", as_frame=True, parser="pandas") # Keep only numeric columns X = ames.data.select_dtypes(np.number) # Remove columns with NaN or Inf values X = X.drop(columns=["LotFrontage", "GarageYrBlt", "MasVnrArea"]) -y = ames.target +# Let the price be in k$ +y = ames.target / 1000 y_trans = quantile_transform( y.to_frame(), n_quantiles=900, output_distribution="normal", copy=True ).squeeze() + # %% # A :class:`~sklearn.preprocessing.QuantileTransformer` is used to normalize # the target distribution before applying a # :class:`~sklearn.linear_model.RidgeCV` model. - f, (ax0, ax1) = plt.subplots(1, 2) ax0.hist(y, bins=100, density=True) ax0.set_ylabel("Probability") ax0.set_xlabel("Target") -ax0.text(s="Target distribution", x=1.2e5, y=9.8e-6, fontsize=12) -ax0.ticklabel_format(axis="both", style="sci", scilimits=(0, 0)) +ax0.set_title("Target distribution") ax1.hist(y_trans, bins=100, density=True) ax1.set_ylabel("Probability") ax1.set_xlabel("Target") -ax1.text(s="Transformed target distribution", x=-6.8, y=0.479, fontsize=12) +ax1.set_title("Transformed target distribution") -f.suptitle("Ames housing data: selling price", y=0.04) -f.tight_layout(rect=[0.05, 0.05, 0.95, 0.95]) +f.suptitle("Ames housing data: selling price", y=1.05) +plt.tight_layout() +# %% X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=1) # %% # The effect of the transformer is weaker than on the synthetic data. However, # the transformation results in an increase in :math:`R^2` and large decrease -# of the MAE. The residual plot (predicted target - true target vs predicted +# of the MedAE. The residual plot (predicted target - true target vs predicted # target) without target transformation takes on a curved, 'reverse smile' # shape due to residual values that vary depending on the value of predicted # target. With target transformation, the shape is more linear indicating # better model fit. 
+from sklearn.preprocessing import QuantileTransformer f, (ax0, ax1) = plt.subplots(2, 2, sharey="row", figsize=(6.5, 8)) -regr = RidgeCV() -regr.fit(X_train, y_train) -y_pred = regr.predict(X_test) - -ax0[0].scatter(y_pred, y_test, s=8) -ax0[0].plot([0, 7e5], [0, 7e5], "--k") -ax0[0].set_ylabel("True target") -ax0[0].set_xlabel("Predicted target") -ax0[0].text( - s="Ridge regression \n without target transformation", - x=-5e4, - y=8e5, - fontsize=12, - multialignment="center", -) -ax0[0].text( - 3e4, - 64e4, - r"$R^2$=%.2f, MAE=%.2f" - % (r2_score(y_test, y_pred), median_absolute_error(y_test, y_pred)), -) -ax0[0].set_xlim([0, 7e5]) -ax0[0].set_ylim([0, 7e5]) -ax0[0].ticklabel_format(axis="both", style="sci", scilimits=(0, 0)) +ridge_cv = RidgeCV().fit(X_train, y_train) +y_pred_ridge = ridge_cv.predict(X_test) -ax1[0].scatter(y_pred, (y_pred - y_test), s=8) -ax1[0].set_ylabel("Residual") -ax1[0].set_xlabel("Predicted target") -ax1[0].ticklabel_format(axis="both", style="sci", scilimits=(0, 0)) - -regr_trans = TransformedTargetRegressor( +ridge_cv_with_trans_target = TransformedTargetRegressor( regressor=RidgeCV(), transformer=QuantileTransformer(n_quantiles=900, output_distribution="normal"), +).fit(X_train, y_train) +y_pred_ridge_with_trans_target = ridge_cv_with_trans_target.predict(X_test) + +# plot the actual vs predicted values +PredictionErrorDisplay.from_predictions( + y_test, + y_pred_ridge, + kind="actual_vs_predicted", + ax=ax0[0], + scatter_kwargs={"alpha": 0.5}, ) -regr_trans.fit(X_train, y_train) -y_pred = regr_trans.predict(X_test) - -ax0[1].scatter(y_pred, y_test, s=8) -ax0[1].plot([0, 7e5], [0, 7e5], "--k") -ax0[1].set_ylabel("True target") -ax0[1].set_xlabel("Predicted target") -ax0[1].text( - s="Ridge regression \n with target transformation", - x=-5e4, - y=8e5, - fontsize=12, - multialignment="center", -) -ax0[1].text( - 3e4, - 64e4, - r"$R^2$=%.2f, MAE=%.2f" - % (r2_score(y_test, y_pred), median_absolute_error(y_test, y_pred)), +PredictionErrorDisplay.from_predictions( + y_test, + y_pred_ridge_with_trans_target, + kind="actual_vs_predicted", + ax=ax0[1], + scatter_kwargs={"alpha": 0.5}, ) -ax0[1].set_xlim([0, 7e5]) -ax0[1].set_ylim([0, 7e5]) -ax0[1].ticklabel_format(axis="both", style="sci", scilimits=(0, 0)) - -ax1[1].scatter(y_pred, (y_pred - y_test), s=8) -ax1[1].set_ylabel("Residual") -ax1[1].set_xlabel("Predicted target") -ax1[1].ticklabel_format(axis="both", style="sci", scilimits=(0, 0)) -f.suptitle("Ames housing data: selling price", y=0.035) +# Add the score in the legend of each axis +for ax, y_pred in zip([ax0[0], ax0[1]], [y_pred_ridge, y_pred_ridge_with_trans_target]): + for name, score in compute_score(y_test, y_pred).items(): + ax.plot([], [], " ", label=f"{name}={score}") + ax.legend(loc="upper left") + +ax0[0].set_title("Ridge regression \n without target transformation") +ax0[1].set_title("Ridge regression \n with target transformation") + +# plot the residuals vs the predicted values +PredictionErrorDisplay.from_predictions( + y_test, + y_pred_ridge, + kind="residual_vs_predicted", + ax=ax1[0], + scatter_kwargs={"alpha": 0.5}, +) +PredictionErrorDisplay.from_predictions( + y_test, + y_pred_ridge_with_trans_target, + kind="residual_vs_predicted", + ax=ax1[1], + scatter_kwargs={"alpha": 0.5}, +) +ax1[0].set_title("Ridge regression \n without target transformation") +ax1[1].set_title("Ridge regression \n with target transformation") +f.suptitle("Ames housing data: selling price", y=1.05) +plt.tight_layout() plt.show() diff --git 
a/examples/ensemble/plot_stack_predictors.py b/examples/ensemble/plot_stack_predictors.py index df776c873761f..56a82ded5b725 100644 --- a/examples/ensemble/plot_stack_predictors.py +++ b/examples/ensemble/plot_stack_predictors.py @@ -22,9 +22,9 @@ # %% # Download the dataset -############################################################################## +###################### # -# We will use `Ames Housing`_ dataset which was first compiled by Dean De Cock +# We will use the `Ames Housing`_ dataset which was first compiled by Dean De Cock # and became better known after it was used in Kaggle challenge. It is a set # of 1460 residential homes in Ames, Iowa, each described by 80 features. We # will use it to predict the final logarithmic price of the houses. In this @@ -82,10 +82,9 @@ def load_ames_housing(): X, y = load_ames_housing() - # %% # Make pipeline to preprocess the data -############################################################################## +###################################### # # Before we can use Ames dataset we still need to do some preprocessing. # First, we will select the categorical and numerical columns of the dataset to @@ -147,7 +146,7 @@ def load_ames_housing(): # %% # Stack of predictors on a single data set -############################################################################## +########################################## # # It is sometimes tedious to find the model which will best perform on a given # dataset. Stacking provide an alternative by combining the outputs of several @@ -199,73 +198,54 @@ def load_ames_housing(): # %% # Measure and plot the results -############################################################################## +############################## # # Now we can use Ames Housing dataset to make the predictions. We check the # performance of each individual predictor as well as of the stack of the # regressors. -# -# The function ``plot_regression_results`` is used to plot the predicted and -# true targets. 
import time import matplotlib.pyplot as plt +from sklearn.metrics import PredictionErrorDisplay from sklearn.model_selection import cross_validate, cross_val_predict - -def plot_regression_results(ax, y_true, y_pred, title, scores, elapsed_time): - """Scatter plot of the predicted vs true targets.""" - ax.plot( - [y_true.min(), y_true.max()], [y_true.min(), y_true.max()], "--r", linewidth=2 - ) - ax.scatter(y_true, y_pred, alpha=0.2) - - ax.spines["top"].set_visible(False) - ax.spines["right"].set_visible(False) - ax.get_xaxis().tick_bottom() - ax.get_yaxis().tick_left() - ax.spines["left"].set_position(("outward", 10)) - ax.spines["bottom"].set_position(("outward", 10)) - ax.set_xlim([y_true.min(), y_true.max()]) - ax.set_ylim([y_true.min(), y_true.max()]) - ax.set_xlabel("Measured") - ax.set_ylabel("Predicted") - extra = plt.Rectangle( - (0, 0), 0, 0, fc="w", fill=False, edgecolor="none", linewidth=0 - ) - ax.legend([extra], [scores], loc="upper left") - title = title + "\n Evaluation in {:.2f} seconds".format(elapsed_time) - ax.set_title(title) - - fig, axs = plt.subplots(2, 2, figsize=(9, 7)) axs = np.ravel(axs) for ax, (name, est) in zip( axs, estimators + [("Stacking Regressor", stacking_regressor)] ): + scorers = {"R2": "r2", "MAE": "neg_mean_absolute_error"} + start_time = time.time() - score = cross_validate( - est, X, y, scoring=["r2", "neg_mean_absolute_error"], n_jobs=2, verbose=0 + scores = cross_validate( + est, X, y, scoring=list(scorers.values()), n_jobs=-1, verbose=0 ) elapsed_time = time.time() - start_time - y_pred = cross_val_predict(est, X, y, n_jobs=2, verbose=0) - - plot_regression_results( - ax, - y, - y_pred, - name, - (r"$R^2={:.2f} \pm {:.2f}$" + "\n" + r"$MAE={:.2f} \pm {:.2f}$").format( - np.mean(score["test_r2"]), - np.std(score["test_r2"]), - -np.mean(score["test_neg_mean_absolute_error"]), - np.std(score["test_neg_mean_absolute_error"]), - ), - elapsed_time, + y_pred = cross_val_predict(est, X, y, n_jobs=-1, verbose=0) + scores = { + key: ( + f"{np.abs(np.mean(scores[f'test_{value}'])):.2f} +- " + f"{np.std(scores[f'test_{value}']):.2f}" + ) + for key, value in scorers.items() + } + + display = PredictionErrorDisplay.from_predictions( + y_true=y, + y_pred=y_pred, + kind="actual_vs_predicted", + ax=ax, + scatter_kwargs={"alpha": 0.2, "color": "tab:blue"}, + line_kwargs={"color": "tab:red"}, ) + ax.set_title(f"{name}\nEvaluation in {elapsed_time:.2f} seconds") + + for name, score in scores.items(): + ax.plot([], [], " ", label=f"{name}: {score}") + ax.legend(loc="upper left") plt.suptitle("Single predictors versus stacked predictors") plt.tight_layout() diff --git a/examples/inspection/plot_linear_model_coefficient_interpretation.py b/examples/inspection/plot_linear_model_coefficient_interpretation.py index 5ee2677570635..b9de243666fb1 100644 --- a/examples/inspection/plot_linear_model_coefficient_interpretation.py +++ b/examples/inspection/plot_linear_model_coefficient_interpretation.py @@ -66,6 +66,7 @@ # Our target for prediction: the wage. # Wages are described as floating-point number in dollars per hour. +# %% y = survey.target.values.ravel() survey.target.head() @@ -168,30 +169,31 @@ # for example, the median absolute error of the model. 
from sklearn.metrics import median_absolute_error +from sklearn.metrics import PredictionErrorDisplay -y_pred = model.predict(X_train) - -mae = median_absolute_error(y_train, y_pred) -string_score = f"MAE on training set: {mae:.2f} $/hour" +mae_train = median_absolute_error(y_train, model.predict(X_train)) y_pred = model.predict(X_test) -mae = median_absolute_error(y_test, y_pred) -string_score += f"\nMAE on testing set: {mae:.2f} $/hour" +mae_test = median_absolute_error(y_test, y_pred) +scores = { + "MedAE on training set": f"{mae_train:.2f} $/hour", + "MedAE on testing set": f"{mae_test:.2f} $/hour", +} # %% -fig, ax = plt.subplots(figsize=(5, 5)) -plt.scatter(y_test, y_pred) -ax.plot([0, 1], [0, 1], transform=ax.transAxes, ls="--", c="red") -plt.text(3, 20, string_score) -plt.title("Ridge model, small regularization") -plt.ylabel("Model predictions") -plt.xlabel("Truths") -plt.xlim([0, 27]) -_ = plt.ylim([0, 27]) +_, ax = plt.subplots(figsize=(5, 5)) +display = PredictionErrorDisplay.from_predictions( + y_test, y_pred, kind="actual_vs_predicted", ax=ax, scatter_kwargs={"alpha": 0.5} +) +ax.set_title("Ridge model, small regularization") +for name, score in scores.items(): + ax.plot([], [], " ", label=f"{name}: {score}") +ax.legend(loc="upper left") +plt.tight_layout() # %% # The model learnt is far from being a good model making accurate predictions: # this is obvious when looking at the plot above, where good predictions -# should lie on the red line. +# should lie on the black dashed line. # # In the following section, we will interpret the coefficients of the model. # While we do so, we should keep in mind that any conclusion we draw is @@ -437,25 +439,23 @@ # model using, for example, the median absolute error of the model and the R # squared coefficient. -y_pred = model.predict(X_train) -mae = median_absolute_error(y_train, y_pred) -string_score = f"MAE on training set: {mae:.2f} $/hour" +mae_train = median_absolute_error(y_train, model.predict(X_train)) y_pred = model.predict(X_test) -mae = median_absolute_error(y_test, y_pred) -string_score += f"\nMAE on testing set: {mae:.2f} $/hour" - -# %% -fig, ax = plt.subplots(figsize=(6, 6)) -plt.scatter(y_test, y_pred) -ax.plot([0, 1], [0, 1], transform=ax.transAxes, ls="--", c="red") - -plt.text(3, 20, string_score) - -plt.title("Ridge model, small regularization, normalized variables") -plt.ylabel("Model predictions") -plt.xlabel("Truths") -plt.xlim([0, 27]) -_ = plt.ylim([0, 27]) +mae_test = median_absolute_error(y_test, y_pred) +scores = { + "MedAE on training set": f"{mae_train:.2f} $/hour", + "MedAE on testing set": f"{mae_test:.2f} $/hour", +} + +_, ax = plt.subplots(figsize=(5, 5)) +display = PredictionErrorDisplay.from_predictions( + y_test, y_pred, kind="actual_vs_predicted", ax=ax, scatter_kwargs={"alpha": 0.5} +) +ax.set_title("Ridge model, small regularization") +for name, score in scores.items(): + ax.plot([], [], " ", label=f"{name}: {score}") +ax.legend(loc="upper left") +plt.tight_layout() # %% # For the coefficient analysis, scaling is not needed this time because it @@ -533,26 +533,23 @@ # %% # Then we check the quality of the predictions. 
- -y_pred = model.predict(X_train) -mae = median_absolute_error(y_train, y_pred) -string_score = f"MAE on training set: {mae:.2f} $/hour" +mae_train = median_absolute_error(y_train, model.predict(X_train)) y_pred = model.predict(X_test) -mae = median_absolute_error(y_test, y_pred) -string_score += f"\nMAE on testing set: {mae:.2f} $/hour" - -# %% -fig, ax = plt.subplots(figsize=(6, 6)) -plt.scatter(y_test, y_pred) -ax.plot([0, 1], [0, 1], transform=ax.transAxes, ls="--", c="red") - -plt.text(3, 20, string_score) - -plt.title("Ridge model, optimum regularization, normalized variables") -plt.ylabel("Model predictions") -plt.xlabel("Truths") -plt.xlim([0, 27]) -_ = plt.ylim([0, 27]) +mae_test = median_absolute_error(y_test, y_pred) +scores = { + "MedAE on training set": f"{mae_train:.2f} $/hour", + "MedAE on testing set": f"{mae_test:.2f} $/hour", +} + +_, ax = plt.subplots(figsize=(5, 5)) +display = PredictionErrorDisplay.from_predictions( + y_test, y_pred, kind="actual_vs_predicted", ax=ax, scatter_kwargs={"alpha": 0.5} +) +ax.set_title("Ridge model, optimum regularization") +for name, score in scores.items(): + ax.plot([], [], " ", label=f"{name}: {score}") +ax.legend(loc="upper left") +plt.tight_layout() # %% # The ability to reproduce the data of the regularized model is similar to @@ -640,25 +637,23 @@ # %% # Then we check the quality of the predictions. -y_pred = model.predict(X_train) -mae = median_absolute_error(y_train, y_pred) -string_score = f"MAE on training set: {mae:.2f} $/hour" +mae_train = median_absolute_error(y_train, model.predict(X_train)) y_pred = model.predict(X_test) -mae = median_absolute_error(y_test, y_pred) -string_score += f"\nMAE on testing set: {mae:.2f} $/hour" - -# %% -fig, ax = plt.subplots(figsize=(6, 6)) -plt.scatter(y_test, y_pred) -ax.plot([0, 1], [0, 1], transform=ax.transAxes, ls="--", c="red") - -plt.text(3, 20, string_score) - -plt.title("Lasso model, regularization, normalized variables") -plt.ylabel("Model predictions") -plt.xlabel("Truths") -plt.xlim([0, 27]) -_ = plt.ylim([0, 27]) +mae_test = median_absolute_error(y_test, y_pred) +scores = { + "MedAE on training set": f"{mae_train:.2f} $/hour", + "MedAE on testing set": f"{mae_test:.2f} $/hour", +} + +_, ax = plt.subplots(figsize=(6, 6)) +display = PredictionErrorDisplay.from_predictions( + y_test, y_pred, kind="actual_vs_predicted", ax=ax, scatter_kwargs={"alpha": 0.5} +) +ax.set_title("Lasso model, optimum regularization") +for name, score in scores.items(): + ax.plot([], [], " ", label=f"{name}: {score}") +ax.legend(loc="upper left") +plt.tight_layout() # %% # For our dataset, again the model is not very predictive. diff --git a/examples/model_selection/plot_cv_predict.py b/examples/model_selection/plot_cv_predict.py index 82ef0b8b81ae6..7fd843c535c85 100644 --- a/examples/model_selection/plot_cv_predict.py +++ b/examples/model_selection/plot_cv_predict.py @@ -4,26 +4,75 @@ ==================================== This example shows how to use -:func:`~sklearn.model_selection.cross_val_predict` to visualize prediction +:func:`~sklearn.model_selection.cross_val_predict` together with +:class:`~sklearn.metrics.PredictionErrorDisplay` to visualize prediction errors. - """ -from sklearn import datasets +# %% +# We will load the diabetes dataset and create an instance of a linear +# regression model. 
+from sklearn.datasets import load_diabetes
+from sklearn.linear_model import LinearRegression
+
+X, y = load_diabetes(return_X_y=True)
+lr = LinearRegression()
+
+# %%
+# :func:`~sklearn.model_selection.cross_val_predict` returns an array of the
+# same size as `y` where each entry is a prediction obtained by cross
+# validation.
from sklearn.model_selection import cross_val_predict
-from sklearn import linear_model
-import matplotlib.pyplot as plt

-lr = linear_model.LinearRegression()
-X, y = datasets.load_diabetes(return_X_y=True)
+y_pred = cross_val_predict(lr, X, y, cv=10)

-# cross_val_predict returns an array of the same size as `y` where each entry
-# is a prediction obtained by cross validation:
-predicted = cross_val_predict(lr, X, y, cv=10)
+# %%
+# Since `cv=10`, we trained 10 models, and each model was used to predict on
+# one of the 10 folds. We can now use the
+# :class:`~sklearn.metrics.PredictionErrorDisplay` to visualize the
+# prediction errors.
+#
+# On the left axis, we plot the observed values :math:`y` vs. the predicted
+# values :math:`\hat{y}` given by the models. On the right axis, we plot the
+# residuals (i.e. the difference between the observed values and the predicted
+# values) vs. the predicted values.
+import matplotlib.pyplot as plt
+from sklearn.metrics import PredictionErrorDisplay

-fig, ax = plt.subplots()
-ax.scatter(y, predicted, edgecolors=(0, 0, 0))
-ax.plot([y.min(), y.max()], [y.min(), y.max()], "k--", lw=4)
-ax.set_xlabel("Measured")
-ax.set_ylabel("Predicted")
+fig, axs = plt.subplots(ncols=2, figsize=(8, 4))
+PredictionErrorDisplay.from_predictions(
+    y,
+    y_pred=y_pred,
+    kind="actual_vs_predicted",
+    subsample=100,
+    ax=axs[0],
+    random_state=0,
+)
+axs[0].set_title("Actual vs. Predicted values")
+PredictionErrorDisplay.from_predictions(
+    y,
+    y_pred=y_pred,
+    kind="residual_vs_predicted",
+    subsample=100,
+    ax=axs[1],
+    random_state=0,
+)
+axs[1].set_title("Residuals vs. Predicted values")
+fig.suptitle("Plotting cross-validated predictions")
+plt.tight_layout()
plt.show()
+
+# %%
+# It is important to note that we used
+# :func:`~sklearn.model_selection.cross_val_predict` for visualization
+# purposes only in this example.
+#
+# It would be problematic to quantitatively assess the model performance by
+# computing a single performance metric from the concatenated predictions
+# returned by :func:`~sklearn.model_selection.cross_val_predict` when the
+# different CV folds vary in size and distribution.
+#
+# It is recommended to compute per-fold performance metrics using
+# :func:`~sklearn.model_selection.cross_val_score` or
+# :func:`~sklearn.model_selection.cross_validate` instead.
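+
+# %%
+# As an illustrative sketch (reusing the `lr`, `X` and `y` defined above),
+# per-fold scores can be computed as follows:
+from sklearn.model_selection import cross_val_score
+
+fold_r2 = cross_val_score(lr, X, y, cv=10, scoring="r2")
+print(f"R2 per fold: {fold_r2.mean():.3f} +/- {fold_r2.std():.3f}")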
diff --git a/sklearn/metrics/__init__.py b/sklearn/metrics/__init__.py
index 25bb540d65a91..4224bfbb9c04c 100644
--- a/sklearn/metrics/__init__.py
+++ b/sklearn/metrics/__init__.py
@@ -92,8 +92,8 @@
from ._plot.det_curve import DetCurveDisplay
from ._plot.roc_curve import RocCurveDisplay
from ._plot.precision_recall_curve import PrecisionRecallDisplay
-
from ._plot.confusion_matrix import ConfusionMatrixDisplay
+from ._plot.regression import PredictionErrorDisplay


__all__ = [
@@ -163,6 +163,7 @@
    "precision_recall_curve",
    "precision_recall_fscore_support",
    "precision_score",
+    "PredictionErrorDisplay",
    "r2_score",
    "rand_score",
    "recall_score",
diff --git a/sklearn/metrics/_plot/regression.py b/sklearn/metrics/_plot/regression.py
new file mode 100644
index 0000000000000..46440c3e133b1
--- /dev/null
+++ b/sklearn/metrics/_plot/regression.py
@@ -0,0 +1,406 @@
+import numbers
+
+import numpy as np
+
+from ...utils import check_matplotlib_support
+from ...utils import check_random_state
+from ...utils import _safe_indexing
+
+
+class PredictionErrorDisplay:
+    """Visualization of the prediction error of a regression model.
+
+    This tool can display "residuals vs predicted" or "actual vs predicted"
+    using scatter plots to qualitatively assess the behavior of a regressor,
+    preferably on held-out data points.
+
+    See the details in the docstrings of
+    :func:`~sklearn.metrics.PredictionErrorDisplay.from_estimator` or
+    :func:`~sklearn.metrics.PredictionErrorDisplay.from_predictions` to
+    create a visualizer. All parameters are stored as attributes.
+
+    For general information regarding `scikit-learn` visualization tools, read
+    more in the :ref:`Visualization Guide <visualizations>`.
+    For details regarding interpreting these plots, refer to the
+    :ref:`Model Evaluation Guide <visualization_regression_evaluation>`.
+
+    .. versionadded:: 1.2
+
+    Parameters
+    ----------
+    y_true : ndarray of shape (n_samples,)
+        True values.
+
+    y_pred : ndarray of shape (n_samples,)
+        Prediction values.
+
+    Attributes
+    ----------
+    line_ : matplotlib Artist
+        Optimal line representing `y_true == y_pred`. Therefore, it is a
+        diagonal line for `kind="actual_vs_predicted"` and a horizontal line
+        for `kind="residual_vs_predicted"`.
+
+    errors_lines_ : matplotlib Artist or None
+        Residual lines. Set to `None` when residual lines are not drawn.
+
+    scatter_ : matplotlib Artist
+        Scatter data points.
+
+    ax_ : matplotlib Axes
+        Axes containing the plot.
+
+    figure_ : matplotlib Figure
+        Figure containing the scatter and lines.
+
+    See Also
+    --------
+    PredictionErrorDisplay.from_estimator : Prediction error visualization
+        given an estimator and some data.
+    PredictionErrorDisplay.from_predictions : Prediction error visualization
+        given the true and predicted targets.
+
+    Examples
+    --------
+    >>> import matplotlib.pyplot as plt
+    >>> from sklearn.datasets import load_diabetes
+    >>> from sklearn.linear_model import Ridge
+    >>> from sklearn.metrics import PredictionErrorDisplay
+    >>> X, y = load_diabetes(return_X_y=True)
+    >>> ridge = Ridge().fit(X, y)
+    >>> y_pred = ridge.predict(X)
+    >>> display = PredictionErrorDisplay(y_true=y, y_pred=y_pred)
+    >>> display.plot()
+    <...>
+    >>> plt.show()
+    """
+
+    def __init__(self, *, y_true, y_pred):
+        self.y_true = y_true
+        self.y_pred = y_pred
+
+    def plot(
+        self,
+        ax=None,
+        *,
+        kind="residual_vs_predicted",
+        scatter_kwargs=None,
+        line_kwargs=None,
+    ):
+        """Plot visualization.
+
+        Keyword arguments in `scatter_kwargs` and `line_kwargs` are passed to
+        matplotlib's `scatter` and `plot` calls, respectively.
+
+        Parameters
+        ----------
+        ax : matplotlib axes, default=None
+            Axes object to plot on. If `None`, a new figure and axes is
+            created.
+
+        kind : {"actual_vs_predicted", "residual_vs_predicted"}, \
+                default="residual_vs_predicted"
+            The type of plot to draw:
+
+            - "actual_vs_predicted" draws the observed values (y-axis) vs.
+              the predicted values (x-axis).
+            - "residual_vs_predicted" draws the residuals, i.e. the difference
+              between the observed and the predicted values (y-axis), vs. the
+              predicted values (x-axis).
+
+        scatter_kwargs : dict, default=None
+            Dictionary with keywords passed to the `matplotlib.pyplot.scatter`
+            call.
+
+        line_kwargs : dict, default=None
+            Dictionary with keywords passed to the `matplotlib.pyplot.plot`
+            call to draw the optimal line.
+
+        Returns
+        -------
+        display : :class:`~sklearn.metrics.PredictionErrorDisplay`
+            Object that stores computed values.
+        """
+        check_matplotlib_support(f"{self.__class__.__name__}.plot")
+
+        expected_kind = ("actual_vs_predicted", "residual_vs_predicted")
+        if kind not in expected_kind:
+            raise ValueError(
+                f"`kind` must be one of {', '.join(expected_kind)}. "
+                f"Got {kind!r} instead."
+            )
+
+        import matplotlib.pyplot as plt
+
+        if scatter_kwargs is None:
+            scatter_kwargs = {}
+        if line_kwargs is None:
+            line_kwargs = {}
+
+        default_scatter_kwargs = {"color": "tab:blue", "alpha": 0.8}
+        default_line_kwargs = {"color": "black", "alpha": 0.7, "linestyle": "--"}
+
+        scatter_kwargs = {**default_scatter_kwargs, **scatter_kwargs}
+        line_kwargs = {**default_line_kwargs, **line_kwargs}
+
+        if ax is None:
+            _, ax = plt.subplots()
+
+        if kind == "actual_vs_predicted":
+            max_value = max(np.max(self.y_true), np.max(self.y_pred))
+            min_value = min(np.min(self.y_true), np.min(self.y_pred))
+            self.line_ = ax.plot(
+                [min_value, max_value], [min_value, max_value], **line_kwargs
+            )[0]
+
+            x_data, y_data = self.y_pred, self.y_true
+            xlabel, ylabel = "Predicted values", "Actual values"
+
+            self.scatter_ = ax.scatter(x_data, y_data, **scatter_kwargs)
+
+            # force to have a squared axis
+            ax.set_aspect("equal", adjustable="datalim")
+            ax.set_xticks(np.linspace(min_value, max_value, num=5))
+            ax.set_yticks(np.linspace(min_value, max_value, num=5))
+        else:  # kind == "residual_vs_predicted"
+            self.line_ = ax.plot(
+                [np.min(self.y_pred), np.max(self.y_pred)],
+                [0, 0],
+                **line_kwargs,
+            )[0]
+            self.scatter_ = ax.scatter(
+                self.y_pred, self.y_true - self.y_pred, **scatter_kwargs
+            )
+            xlabel, ylabel = "Predicted values", "Residuals (actual - predicted)"
+
+        ax.set(xlabel=xlabel, ylabel=ylabel)
+
+        self.ax_ = ax
+        self.figure_ = ax.figure
+
+        return self
+
+    @classmethod
+    def from_estimator(
+        cls,
+        estimator,
+        X,
+        y,
+        *,
+        kind="residual_vs_predicted",
+        subsample=1_000,
+        random_state=None,
+        ax=None,
+        scatter_kwargs=None,
+        line_kwargs=None,
+    ):
+        """Plot the prediction error given a regressor and some data.
+
+        For general information regarding `scikit-learn` visualization tools,
+        read more in the :ref:`Visualization Guide <visualizations>`.
+        For details regarding interpreting these plots, refer to the
+        :ref:`Model Evaluation Guide <visualization_regression_evaluation>`.
+
+        .. versionadded:: 1.2
+
+        Parameters
+        ----------
+        estimator : estimator instance
+            Fitted regressor or a fitted :class:`~sklearn.pipeline.Pipeline`
+            in which the last estimator is a regressor.
+
+        X : {array-like, sparse matrix} of shape (n_samples, n_features)
+            Input values.
+
+        y : array-like of shape (n_samples,)
+            Target values.
+
+        kind : {"actual_vs_predicted", "residual_vs_predicted"}, \
+                default="residual_vs_predicted"
+            The type of plot to draw:
+
+            - "actual_vs_predicted" draws the observed values (y-axis) vs.
+              the predicted values (x-axis).
+            - "residual_vs_predicted" draws the residuals, i.e. the difference
+              between the observed and the predicted values (y-axis), vs. the
+              predicted values (x-axis).
+
+        subsample : float, int or None, default=1_000
+            Subsampling strategy for the points shown on the scatter plot. If
+            `float`, it should be between 0 and 1 and represents the
+            proportion of the original dataset. If `int`, it represents the
+            number of samples displayed on the scatter plot. If `None`, no
+            subsampling will be applied. By default, at most 1,000 samples
+            will be displayed.
+
+        random_state : int or RandomState, default=None
+            Controls the randomness when `subsample` is not `None`.
+            See :term:`Glossary <random_state>` for details.
+
+        ax : matplotlib axes, default=None
+            Axes object to plot on. If `None`, a new figure and axes is
+            created.
+
+        scatter_kwargs : dict, default=None
+            Dictionary with keywords passed to the `matplotlib.pyplot.scatter`
+            call.
+
+        line_kwargs : dict, default=None
+            Dictionary with keywords passed to the `matplotlib.pyplot.plot`
+            call to draw the optimal line.
+
+        Returns
+        -------
+        display : :class:`~sklearn.metrics.PredictionErrorDisplay`
+            Object that stores the computed values.
+
+        See Also
+        --------
+        PredictionErrorDisplay : Prediction error visualization for regression.
+        PredictionErrorDisplay.from_predictions : Prediction error visualization
+            given the true and predicted targets.
+
+        Examples
+        --------
+        >>> import matplotlib.pyplot as plt
+        >>> from sklearn.datasets import load_diabetes
+        >>> from sklearn.linear_model import Ridge
+        >>> from sklearn.metrics import PredictionErrorDisplay
+        >>> X, y = load_diabetes(return_X_y=True)
+        >>> ridge = Ridge().fit(X, y)
+        >>> disp = PredictionErrorDisplay.from_estimator(ridge, X, y)
+        >>> plt.show()
+        """
+        check_matplotlib_support(f"{cls.__name__}.from_estimator")
+
+        y_pred = estimator.predict(X)
+
+        return cls.from_predictions(
+            y_true=y,
+            y_pred=y_pred,
+            kind=kind,
+            subsample=subsample,
+            random_state=random_state,
+            ax=ax,
+            scatter_kwargs=scatter_kwargs,
+            line_kwargs=line_kwargs,
+        )
+
+    @classmethod
+    def from_predictions(
+        cls,
+        y_true,
+        y_pred,
+        *,
+        kind="residual_vs_predicted",
+        subsample=1_000,
+        random_state=None,
+        ax=None,
+        scatter_kwargs=None,
+        line_kwargs=None,
+    ):
+        """Plot the prediction error given the true and predicted targets.
+
+        For general information regarding `scikit-learn` visualization tools,
+        read more in the :ref:`Visualization Guide <visualizations>`.
+        For details regarding interpreting these plots, refer to the
+        :ref:`Model Evaluation Guide <visualization_regression_evaluation>`.
+
+        .. versionadded:: 1.2
+
+        Parameters
+        ----------
+        y_true : array-like of shape (n_samples,)
+            True target values.
+
+        y_pred : array-like of shape (n_samples,)
+            Predicted target values.
+
+        kind : {"actual_vs_predicted", "residual_vs_predicted"}, \
+                default="residual_vs_predicted"
+            The type of plot to draw:
+
+            - "actual_vs_predicted" draws the observed values (y-axis) vs.
+              the predicted values (x-axis).
+            - "residual_vs_predicted" draws the residuals, i.e. the difference
+              between the observed and the predicted values (y-axis), vs. the
+              predicted values (x-axis).
+
+        subsample : float, int or None, default=1_000
+            Subsampling strategy for the points shown on the scatter plot. If
+            `float`, it should be between 0 and 1 and represents the
+            proportion of the original dataset.
+            If `int`, it represents the number of samples displayed on the
+            scatter plot. If `None`, no subsampling will be applied. By
+            default, at most 1,000 samples will be displayed.
+
+        random_state : int or RandomState, default=None
+            Controls the randomness when `subsample` is not `None`.
+            See :term:`Glossary <random_state>` for details.
+
+        ax : matplotlib axes, default=None
+            Axes object to plot on. If `None`, a new figure and axes is
+            created.
+
+        scatter_kwargs : dict, default=None
+            Dictionary with keywords passed to the `matplotlib.pyplot.scatter`
+            call.
+
+        line_kwargs : dict, default=None
+            Dictionary with keywords passed to the `matplotlib.pyplot.plot`
+            call to draw the optimal line.
+
+        Returns
+        -------
+        display : :class:`~sklearn.metrics.PredictionErrorDisplay`
+            Object that stores the computed values.
+
+        See Also
+        --------
+        PredictionErrorDisplay : Prediction error visualization for regression.
+        PredictionErrorDisplay.from_estimator : Prediction error visualization
+            given an estimator and some data.
+
+        Examples
+        --------
+        >>> import matplotlib.pyplot as plt
+        >>> from sklearn.datasets import load_diabetes
+        >>> from sklearn.linear_model import Ridge
+        >>> from sklearn.metrics import PredictionErrorDisplay
+        >>> X, y = load_diabetes(return_X_y=True)
+        >>> ridge = Ridge().fit(X, y)
+        >>> y_pred = ridge.predict(X)
+        >>> disp = PredictionErrorDisplay.from_predictions(y_true=y, y_pred=y_pred)
+        >>> plt.show()
+        """
+        check_matplotlib_support(f"{cls.__name__}.from_predictions")
+
+        random_state = check_random_state(random_state)
+
+        n_samples = len(y_true)
+        if isinstance(subsample, numbers.Integral):
+            if subsample <= 0:
+                raise ValueError(
+                    f"When an integer, subsample={subsample} should be positive."
+                )
+        elif isinstance(subsample, numbers.Real):
+            if subsample <= 0 or subsample >= 1:
+                raise ValueError(
+                    f"When a floating-point, subsample={subsample} should"
+                    " be in the (0, 1) range."
+                )
+            subsample = int(n_samples * subsample)
+
+        if subsample is not None and subsample < n_samples:
+            indices = random_state.choice(np.arange(n_samples), size=subsample)
+            y_true = _safe_indexing(y_true, indices, axis=0)
+            y_pred = _safe_indexing(y_pred, indices, axis=0)
+
+        viz = PredictionErrorDisplay(
+            y_true=y_true,
+            y_pred=y_pred,
+        )
+
+        return viz.plot(
+            ax=ax,
+            kind=kind,
+            scatter_kwargs=scatter_kwargs,
+            line_kwargs=line_kwargs,
+        )
diff --git a/sklearn/metrics/_plot/tests/test_predict_error_display.py b/sklearn/metrics/_plot/tests/test_predict_error_display.py
new file mode 100644
index 0000000000000..3d3833d825360
--- /dev/null
+++ b/sklearn/metrics/_plot/tests/test_predict_error_display.py
@@ -0,0 +1,163 @@
+import pytest
+
+from numpy.testing import assert_allclose
+
+from sklearn.datasets import load_diabetes
+from sklearn.exceptions import NotFittedError
+from sklearn.linear_model import Ridge
+
+from sklearn.metrics import PredictionErrorDisplay
+
+X, y = load_diabetes(return_X_y=True)
+
+
+@pytest.fixture
+def regressor_fitted():
+    return Ridge().fit(X, y)
+
+
+@pytest.mark.parametrize(
+    "regressor, params, err_type, err_msg",
+    [
+        (
+            Ridge().fit(X, y),
+            {"subsample": -1},
+            ValueError,
+            "When an integer, subsample=-1 should be",
+        ),
+        (
+            Ridge().fit(X, y),
+            {"subsample": 20.0},
+            ValueError,
+            "When a floating-point, subsample=20.0 should be",
+        ),
+        (
+            Ridge().fit(X, y),
+            {"subsample": -20.0},
+            ValueError,
+            "When a floating-point, subsample=-20.0 should be",
+        ),
+        (
+            Ridge().fit(X, y),
+            {"kind": "xxx"},
+            ValueError,
+            "`kind` must be one of",
+        ),
+    ],
+)
+@pytest.mark.parametrize("class_method", ["from_estimator", "from_predictions"])
+def test_prediction_error_display_raise_error(
+    pyplot, class_method, regressor, params, err_type, err_msg
+):
+    """Check that we raise the proper error during the parameter validation."""
+    with pytest.raises(err_type, match=err_msg):
+        if class_method == "from_estimator":
+            PredictionErrorDisplay.from_estimator(regressor, X, y, **params)
+        else:
+            y_pred = regressor.predict(X)
+            PredictionErrorDisplay.from_predictions(y_true=y, y_pred=y_pred, **params)
+
+
+def test_from_estimator_not_fitted(pyplot):
+    """Check that we raise a `NotFittedError` when the passed regressor is not
+    fitted."""
+    regressor = Ridge()
+    with pytest.raises(NotFittedError, match="is not fitted yet."):
+        PredictionErrorDisplay.from_estimator(regressor, X, y)
+
+
+@pytest.mark.parametrize("class_method", ["from_estimator", "from_predictions"])
+@pytest.mark.parametrize("kind", ["actual_vs_predicted", "residual_vs_predicted"])
+def test_prediction_error_display(pyplot, regressor_fitted, class_method, kind):
+    """Check the default behaviour of the display."""
+    if class_method == "from_estimator":
+        display = PredictionErrorDisplay.from_estimator(
+            regressor_fitted, X, y, kind=kind
+        )
+    else:
+        y_pred = regressor_fitted.predict(X)
+        display = PredictionErrorDisplay.from_predictions(
+            y_true=y, y_pred=y_pred, kind=kind
+        )
+
+    if kind == "actual_vs_predicted":
+        assert_allclose(display.line_.get_xdata(), display.line_.get_ydata())
+        assert display.ax_.get_xlabel() == "Predicted values"
+        assert display.ax_.get_ylabel() == "Actual values"
+        assert display.line_ is not None
+    else:
+        assert display.ax_.get_xlabel() == "Predicted values"
+        assert display.ax_.get_ylabel() == "Residuals (actual - predicted)"
+        assert display.line_ is not None
+
+    assert display.ax_.get_legend() is None
+
+
+@pytest.mark.parametrize("class_method",
["from_estimator", "from_predictions"]) +@pytest.mark.parametrize( + "subsample, expected_size", + [(5, 5), (0.1, int(X.shape[0] * 0.1)), (None, X.shape[0])], +) +def test_plot_prediction_error_subsample( + pyplot, regressor_fitted, class_method, subsample, expected_size +): + """Check the behaviour of `subsample`.""" + if class_method == "from_estimator": + display = PredictionErrorDisplay.from_estimator( + regressor_fitted, X, y, subsample=subsample + ) + else: + y_pred = regressor_fitted.predict(X) + display = PredictionErrorDisplay.from_predictions( + y_true=y, y_pred=y_pred, subsample=subsample + ) + assert len(display.scatter_.get_offsets()) == expected_size + + +@pytest.mark.parametrize("class_method", ["from_estimator", "from_predictions"]) +def test_plot_prediction_error_ax(pyplot, regressor_fitted, class_method): + """Check that we can pass an axis to the display.""" + _, ax = pyplot.subplots() + if class_method == "from_estimator": + display = PredictionErrorDisplay.from_estimator(regressor_fitted, X, y, ax=ax) + else: + y_pred = regressor_fitted.predict(X) + display = PredictionErrorDisplay.from_predictions( + y_true=y, y_pred=y_pred, ax=ax + ) + assert display.ax_ is ax + + +@pytest.mark.parametrize("class_method", ["from_estimator", "from_predictions"]) +def test_prediction_error_custom_artist(pyplot, regressor_fitted, class_method): + """Check that we can tune the style of the lines.""" + extra_params = { + "kind": "actual_vs_predicted", + "scatter_kwargs": {"color": "red"}, + "line_kwargs": {"color": "black"}, + } + if class_method == "from_estimator": + display = PredictionErrorDisplay.from_estimator( + regressor_fitted, X, y, **extra_params + ) + else: + y_pred = regressor_fitted.predict(X) + display = PredictionErrorDisplay.from_predictions( + y_true=y, y_pred=y_pred, **extra_params + ) + + assert display.line_.get_color() == "black" + assert_allclose(display.scatter_.get_edgecolor(), [[1.0, 0.0, 0.0, 0.8]]) + + # create a display with the default values + if class_method == "from_estimator": + display = PredictionErrorDisplay.from_estimator(regressor_fitted, X, y) + else: + y_pred = regressor_fitted.predict(X) + display = PredictionErrorDisplay.from_predictions(y_true=y, y_pred=y_pred) + pyplot.close("all") + + display.plot(**extra_params) + assert display.line_.get_color() == "black" + assert_allclose(display.scatter_.get_edgecolor(), [[1.0, 0.0, 0.0, 0.8]])