Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MRG] Allow plot_confusion_matrix to be called on predicted labels #15883

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
38 changes: 26 additions & 12 deletions sklearn/metrics/_plot/confusion_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from .. import confusion_matrix
from ...utils import check_matplotlib_support
from ...utils.validation import check_consistent_length
from ...base import is_classifier


Expand Down Expand Up @@ -117,26 +118,27 @@ def plot(self, include_values=True, cmap='viridis',
return self


def plot_confusion_matrix(estimator, X, y_true, labels=None,
def plot_confusion_matrix(estimator=None, X=None, y_true=None, labels=None,
sample_weight=None, normalize=None,
display_labels=None, include_values=True,
xticks_rotation='horizontal',
values_format=None,
cmap='viridis', ax=None):
cmap='viridis', ax=None, y_pred=None):
"""Plot Confusion Matrix.

Read more in the :ref:`User Guide <confusion_matrix>`.

Parameters
----------
estimator : estimator instance
Trained classifier.
estimator : estimator instance, default=None
Trained classifier. Must be None if `y_pred` is specified.

X : {array-like, sparse matrix} of shape (n_samples, n_features)
Input values.
X : {array-like, sparse matrix} of shape (n_samples, n_features),\
default=None
Input values. Must be None if `y_pred` is specified.

y : array-like of shape (n_samples,)
Target values.
y_true : array-like of shape (n_samples,)
Target values, must be specified.

labels : array-like of shape (n_classes,), default=None
List of labels to index the matrix. This may be used to reorder or
Expand Down Expand Up @@ -175,26 +177,38 @@ def plot_confusion_matrix(estimator, X, y_true, labels=None,
Axes object to plot on. If `None`, a new figure and axes is
created.

y_pred : array-like of shape (n_samples,), default=None
Predicted labels. Must be None if `estimator` and `X` are specified.

Returns
-------
display : :class:`~sklearn.metrics.ConfusionMatrixDisplay`
"""
check_matplotlib_support("plot_confusion_matrix")

if not is_classifier(estimator):
raise ValueError("plot_confusion_matrix only supports classifiers")
if estimator is not None and X is not None and y_pred is None:
if not is_classifier(estimator):
raise ValueError("plot_confusion_matrix only supports classifiers")
y_pred = estimator.predict(X)
elif estimator is None and X is None and y_pred is not None:
check_consistent_length(y_true, y_pred)
else:
raise ValueError("Either 'estimator' and 'X' must be passed to "
"plot_confusion_matrix or 'y_pred'")
jhennrich marked this conversation as resolved.
Show resolved Hide resolved

if normalize not in {'true', 'pred', 'all', None}:
raise ValueError("normalize must be one of {'true', 'pred', "
"'all', None}")

y_pred = estimator.predict(X)
cm = confusion_matrix(y_true, y_pred, sample_weight=sample_weight,
labels=labels, normalize=normalize)

if display_labels is None:
if labels is None:
display_labels = estimator.classes_
if estimator is None:
display_labels = sorted(set(y_true).union(y_pred))
else:
display_labels = estimator.classes_
else:
display_labels = labels

Expand Down
27 changes: 27 additions & 0 deletions sklearn/metrics/_plot/tests/test_plot_confusion_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,33 @@ def test_plot_confusion_matrix(pyplot, data, y_pred, n_classes, fitted_clf,
assert disp.text_ is None


def test_plot_confusion_matrix_ypred(pyplot, data, y_pred, fitted_clf):
    """Check that `plot_confusion_matrix` accepts precomputed predictions.

    The confusion matrix built from ``y_true``/``y_pred`` must match the one
    built from ``estimator``/``X``, and inconsistent argument combinations
    must raise a ``ValueError`` with the documented message.

    NOTE(review): assumes the ``y_pred`` fixture holds the predictions of
    ``fitted_clf`` on ``X`` -- confirm against the fixture definitions.
    """
    X, y = data
    cmap = 'plasma'

    # Same data through both call styles: predictions vs. estimator + X.
    ax = pyplot.gca()
    cm_pred = plot_confusion_matrix(y_true=y, y_pred=y_pred,
                                    cmap=cmap, ax=ax).confusion_matrix

    ax = pyplot.gca()
    cm_est = plot_confusion_matrix(fitted_clf, X, y,
                                   cmap=cmap, ax=ax).confusion_matrix

    assert_allclose(cm_est, cm_pred)

    err_msg = ("Either 'estimator' and 'X' must be passed to "
               "plot_confusion_matrix or 'y_pred'")

    # Over-specified: estimator/X and y_pred are mutually exclusive.
    # The message check sits outside the `with` block; inside it, the
    # assertion would never run because the call above it raises.
    with pytest.raises(ValueError) as exc_info:
        plot_confusion_matrix(fitted_clf, X, y,
                              cmap=cmap, ax=ax, y_pred=y_pred)
    assert str(exc_info.value) == err_msg

    # Under-specified: only y_true is given. It is passed by keyword so it
    # is not silently bound to the first positional parameter, `estimator`.
    with pytest.raises(ValueError) as exc_info:
        plot_confusion_matrix(y_true=y, cmap=cmap, ax=ax)
    assert str(exc_info.value) == err_msg



def test_confusion_matrix_display(pyplot, data, fitted_clf, y_pred, n_classes):
X, y = data

Expand Down