Suggestion: Remove prediction from plot_confusion_matrix and just pass predicted labels #15880
Comments
We should definitely stay backward compatible. Would you want to submit a PR @jhennrich?

I submitted a PR, but I think there is currently a problem with the CI, so it has not passed yet.

I agree that we should support

Why is it impossible to keep backward compatibility? It seems to me that the proposal in #15883 is OK.

Because we do not support `**kwargs` in `plot_confusion_matrix`. @NicolasHug

Why is `**kwargs` a problem?

Hmm, so there's another annoying thing: we support `**kwargs` in `plot_roc_curve` and `plot_precision_recall_curve` (and `plot_partial_dependence`), but we do not support it in `plot_confusion_matrix`.

If we add the new parameter before `**kwargs`, we can keep backward compatibility, right?

The changes in my PR are backwards compatible and `**kwargs` can still be added. But I agree with @qinhanmin2014, a much cleaner solution would be to throw out
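A minimal sketch of the backward-compatibility point being discussed here: adding a new optional keyword parameter before `**kwargs` leaves every existing call working. The function names and bodies below are illustrative placeholders, not scikit-learn's actual implementation.

```python
# Hypothetical "before" signature: no **kwargs, as plot_confusion_matrix had.
def plot_before(estimator, X, y_true, normalize=None):
    return ("estimator", normalize)

# Hypothetical "after": optional y_pred added before **kwargs.
# Calls that worked before still work, so this is backward compatible.
def plot_after(estimator, X, y_true, normalize=None, y_pred=None, **kwargs):
    if y_pred is not None:
        return ("y_pred", normalize, kwargs)
    return ("estimator", normalize, kwargs)

# Old call styles are unchanged:
assert plot_before(None, [[0]], [0], normalize="true") == ("estimator", "true")
assert plot_after(None, [[0]], [0], normalize="true") == ("estimator", "true", {})
# New keyword usage, with extra kwargs passed through:
assert plot_after(None, None, [0], y_pred=[0], linewidth=2) == ("y_pred", None, {"linewidth": 2})
```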
yes

Unfortunately that would require a deprecation cycle (unless we make it very fast in the bugfix release, but I doubt it...). @thomasjpfan, any reason to pass the estimator as input instead of the predictions?

Thanks, let's add `y_pred` first; `**kwargs` is another issue.

This seems impossible, sigh.

I agree that we need to reconsider our API design. Also try to ping @amueller.
If a user wants to do their own plotting and provide their own confusion matrix:

```python
from sklearn.metrics import ConfusionMatrixDisplay, confusion_matrix

cm = confusion_matrix(...)
display_labels = [...]
disp = ConfusionMatrixDisplay(confusion_matrix=cm,
                              display_labels=display_labels)
disp.plot(...)
```

This can similarly be done for the other metric plotting functions.
By accepting the estimator first, there is a uniform interface for the plotting functions. For example, the

Although I am open to having a "fast path", allowing for
The computation of the predictions needed to build a PDP is quite complex. Also, these predictions are typically unusable in e.g. a scorer or a metric; they're only useful for plotting the PDP. So it makes sense in this case to only accept the estimator in `plot_partial_dependence`. OTOH for the confusion matrix, the predictions are really just

I don't think we want a uniform interface here. These are two very different use cases.

EDIT: In addition, the tree-based PDPs don't even need predictions at all.
There are other things we will run into without the estimator. For example if

This comes down to what kind of question we are answering with this API. The current interface revolves around evaluating an estimator, thus using the estimator as an argument. It is motivated by answering "how does this trained model behave with this input data?" If we accept
It's true that in this specific case, @jhennrich, you could directly be using the `ConfusionMatrixDisplay`. One drawback is that you need to specify

@thomasjpfan, do you think we could in general provide sensible defaults for the Display objects, thus still making the direct use of the Display objects practical?
For some parameters, like
One classic pattern for this kind of thing is defining:

but this is not very idiomatic to scikit-learn.
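The code for the "classic pattern" was lost in this comment, but the thread later converges on alternative constructors implemented as classmethods. A toy sketch of that shape, with a pure-Python stand-in class (the class name and matrix-building logic here are illustrative, not scikit-learn's implementation):

```python
class ConfusionMatrixDisplaySketch:
    """Toy stand-in for a Display object with two alternative constructors."""

    def __init__(self, confusion_matrix):
        self.confusion_matrix = confusion_matrix

    @classmethod
    def from_predictions(cls, y_true, y_pred):
        # Build the confusion matrix directly from precomputed labels.
        labels = sorted(set(y_true) | set(y_pred))
        index = {label: i for i, label in enumerate(labels)}
        cm = [[0] * len(labels) for _ in labels]
        for t, p in zip(y_true, y_pred):
            cm[index[t]][index[p]] += 1
        return cls(cm)

    @classmethod
    def from_estimator(cls, estimator, X, y_true):
        # Predict first, then reuse the predictions path.
        return cls.from_predictions(y_true, estimator.predict(X))


disp = ConfusionMatrixDisplaySketch.from_predictions([0, 1, 1, 0], [0, 1, 0, 0])
assert disp.confusion_matrix == [[2, 0], [1, 1]]
```

The appeal of this pattern is that both entry points produce the same object, so the plotting code lives in one place and only the construction differs.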
I'm starting to get confused. The goal of the current API is to avoid calculating multiple times if users want to plot multiple times, but if we accept `y_true` and `y_pred`, users still don't need to calculate multiple times? (I know that things are different in PDP)
@jnothman That API is pretty nice looking!

@qinhanmin2014 Passing an

The difference between them is where the calculation of the confusion matrix starts. One can think of pass

So I think
For metrics, I can see the preference toward using

```python
est = ...  # fit estimator
plot_partial_dependence(est, X, ...)

# if plot_confusion_matrix accepts `y_true, y_pred`
y_pred = est.predict(X)
plot_confusion_matrix(y_true, y_pred, ...)

# if plot_roc_curve supports `y_true, y_score`
y_score = est.predict_proba(X)[:, 1]
plot_roc_curve(y_true, y_score, ...)
plot_precision_recall_curve(y_true, y_score, ...)
```

Currently the API looks like:

```python
est = ...  # fit estimator
plot_partial_dependence(est, X, ...)
plot_confusion_matrix(est, X, y, ...)
plot_roc_curve(est, X, y, ...)
# this will call `predict_proba` again
plot_precision_recall_curve(est, X, y, ...)
```

I would prefer to have an API that supports both options (somehow).
Yes, this is what I mean.

I think this is a practical solution. An annoying thing is that we can only add `y_pred` at the end (i.e., `plot_confusion_matrix(estimator, X, y_true, ..., y_pred)`).
The blocker label is for release blockers (things that absolutely need to be fixed before a release), not for PR blockers.
Ahh good to know. |
I like the two-classmethods approach the most, especially the
Looks like there's no strong opposition to using two class methods, so let's do that. We'll need to:

This is a bit of work, but I think we can get this done before 0.24 so that #17443 and #18020 can move forward already. Any objection @thomasjpfan @jnothman @amueller @glemaitre?

@jhennrich @pzelasko, would you be interested in submitting a PR to introduce the class methods in one of the Display objects?
Thanks for making the decision @NicolasHug! I'll get onto #17443 (after waiting for objections) |
I have no objections. |
No objection as well. |
I will take care of the other classes then and advance my stalled PR. |
I'd love to contribute but I'm engaged in too many projects at this time. Thanks for listening to the suggestions! |
Sounds good :) |
Seems like there's been good progress here. Removing the milestone, but pinging y'all in case you think something needs to get in before the release. |
All the PRs regarding the plotting are either reviewed or written on my side (you can have a look at this column of my board https://github.com/scikit-learn/scikit-learn/projects/17#column-6455639).

Sent from my phone - sorry to be brief and for potential misspellings.
I am closing this issue because all the tasks in the original issue have been merged.
The signature of `plot_confusion_matrix` is currently:

```python
sklearn.metrics.plot_confusion_matrix(estimator, X, y_true, labels=None, sample_weight=None, normalize=None, display_labels=None, include_values=True, xticks_rotation='horizontal', values_format=None, cmap='viridis', ax=None)
```

The function takes an estimator and raw data and cannot be used with already predicted labels. This has some downsides:

Suggestion: allow passing predicted labels `y_pred` to `plot_confusion_matrix` that will be used instead of `estimator` and `X`. In my opinion the cleanest solution would be to remove the prediction step from the function and use a signature similar to that of `accuracy_score`, e.g. `(y_true, y_pred, labels=None, sample_weight=None, ...)`. However, in order to maintain backwards compatibility, `y_pred` can be added as an optional keyword argument.

TODO:
- Introduce the class methods for the currently existing plots:
  - `ConfusionMatrixDisplay`: ENH/DEP add class method and deprecate plot function for confusion matrix #18543
  - `PrecisionRecallDisplay`: API add from_estimator and from_preditions to PrecisionRecallDisplay #20552
  - `RocCurveDisplay`: API add from_estimator and from_predictions to RocCurveDisplay #20569
  - `DetCurveDisplay`: API deprecate plot_det_curve in favor of display class methods #19278
  - `PartialDependenceDisplay`: for this one, we don't want to introduce the `from_predictions` classmethod because it would not make sense; we only want `from_estimator`.
- For all Display objects listed above, deprecate their corresponding `plot_...` function. We don't need to deprecate `plot_det_curve` because it hasn't been released yet; we can just remove it.
- For new PRs like ENH Add CalibrationDisplay plotting class #17443 and FEA add PredictionErrorDisplay #18020 we can implement the class methods right away instead of introducing a `plot` function.
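The `(y_true, y_pred, ...)` signature proposed above mirrors metric functions that already exist; with `sklearn.metrics.confusion_matrix` the call pattern can be shown directly (toy data, values chosen only for illustration):

```python
from sklearn.metrics import accuracy_score, confusion_matrix

y_true = [0, 1, 1, 0, 1]
y_pred = [0, 1, 0, 0, 1]

# Both take (y_true, y_pred, ...) with labels/sample_weight keywords,
# the signature shape this issue proposed for plot_confusion_matrix.
acc = accuracy_score(y_true, y_pred)
cm = confusion_matrix(y_true, y_pred, labels=[0, 1])

assert acc == 0.8
assert cm.tolist() == [[2, 0], [1, 2]]
```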