Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MRG] Allow plot_confusion_matrix to be called on predicted labels #15883

Closed
Closed
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
36 changes: 25 additions & 11 deletions sklearn/metrics/_plot/confusion_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from .. import confusion_matrix
from ...utils import check_matplotlib_support
from ...utils.validation import check_consistent_length
from ...base import is_classifier


Expand Down Expand Up @@ -117,26 +118,29 @@ def plot(self, include_values=True, cmap='viridis',
return self


def plot_confusion_matrix(estimator, X, y_true, labels=None,
def plot_confusion_matrix(estimator=None, X=None, y_true=None, labels=None,
sample_weight=None, normalize=None,
display_labels=None, include_values=True,
xticks_rotation='horizontal',
values_format=None,
cmap='viridis', ax=None):
cmap='viridis', ax=None, y_pred=None):
"""Plot Confusion Matrix.

Read more in the :ref:`User Guide <confusion_matrix>`.

Parameters
----------
estimator : estimator instance
Trained classifier.
estimator : estimator instance, default=None
Trained classifier. Either `estimator` and `X` must be specified, or
jhennrich marked this conversation as resolved.
Show resolved Hide resolved
`y_pred`.

X : {array-like, sparse matrix} of shape (n_samples, n_features)
Input values.
X : {array-like, sparse matrix} of shape (n_samples, n_features),
default=None
jhennrich marked this conversation as resolved.
Show resolved Hide resolved
Input values. Either `estimator` and `X` must be specified, or
jhennrich marked this conversation as resolved.
Show resolved Hide resolved
`y_pred`.

y : array-like of shape (n_samples,)
Target values.
y_true : array-like of shape (n_samples,), default=None
jhennrich marked this conversation as resolved.
Show resolved Hide resolved
Target values, must be specified.

labels : array-like of shape (n_classes,), default=None
List of labels to index the matrix. This may be used to reorder or
Expand Down Expand Up @@ -175,20 +179,30 @@ def plot_confusion_matrix(estimator, X, y_true, labels=None,
Axes object to plot on. If `None`, a new figure and axes is
created.

y_pred : array-like of shape (n_samples,), default=None
Predicted labels. Either `estimator` and `X` must be specified, or
jhennrich marked this conversation as resolved.
Show resolved Hide resolved
`y_pred`.

Returns
-------
display : :class:`~sklearn.metrics.ConfusionMatrixDisplay`
"""
check_matplotlib_support("plot_confusion_matrix")

if not is_classifier(estimator):
raise ValueError("plot_confusion_matrix only supports classifiers")
if estimator is not None and X is not None and y_pred is None:
if not is_classifier(estimator):
raise ValueError("plot_confusion_matrix only supports classifiers")
y_pred = estimator.predict(X)
elif estimator is None and X is None and y_pred is not None:
check_consistent_length(y_true, y_pred)
else:
raise ValueError("Either 'estimator' and 'X' must be passed to "
"plot_confusion_matrix or 'y_pred'")
jhennrich marked this conversation as resolved.
Show resolved Hide resolved

if normalize not in {'true', 'pred', 'all', None}:
raise ValueError("normalize must be one of {'true', 'pred', "
"'all', None}")

y_pred = estimator.predict(X)
cm = confusion_matrix(y_true, y_pred, sample_weight=sample_weight,
labels=labels, normalize=normalize)

Expand Down