From 42465212c1897ea8bccf259cb3197c04bed9dc02 Mon Sep 17 00:00:00 2001
From: Guillaume Lemaitre
Date: Fri, 29 Nov 2019 04:41:31 +0100
Subject: [PATCH] FIX add pos_label when computing AP in plot_precision_recall_curve (#15739)

---
 .../metrics/_plot/precision_recall_curve.py   |  1 +
 .../_plot/tests/test_plot_precision_recall.py | 21 +++++++++++++++++++
 2 files changed, 22 insertions(+)

diff --git a/sklearn/metrics/_plot/precision_recall_curve.py b/sklearn/metrics/_plot/precision_recall_curve.py
index d2b84059c3c0e..d515b9aa86b1d 100644
--- a/sklearn/metrics/_plot/precision_recall_curve.py
+++ b/sklearn/metrics/_plot/precision_recall_curve.py
@@ -161,6 +161,7 @@ def plot_precision_recall_curve(estimator, X, y,
                                                   pos_label=pos_label,
                                                   sample_weight=sample_weight)
     average_precision = average_precision_score(y, y_pred,
+                                                pos_label=pos_label,
                                                 sample_weight=sample_weight)
     viz = PrecisionRecallDisplay(precision, recall, average_precision,
                                  estimator.__class__.__name__)
diff --git a/sklearn/metrics/_plot/tests/test_plot_precision_recall.py b/sklearn/metrics/_plot/tests/test_plot_precision_recall.py
index 1012e13027f5a..60e06ed34ad01 100644
--- a/sklearn/metrics/_plot/tests/test_plot_precision_recall.py
+++ b/sklearn/metrics/_plot/tests/test_plot_precision_recall.py
@@ -7,6 +7,7 @@
 from sklearn.metrics import average_precision_score
 from sklearn.metrics import precision_recall_curve
 from sklearn.datasets import make_classification
+from sklearn.datasets import load_breast_cancer
 from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
 from sklearn.linear_model import LogisticRegression
 from sklearn.exceptions import NotFittedError
@@ -132,3 +133,23 @@ def test_precision_recall_curve_pipeline(pyplot, clf):
     clf.fit(X, y)
     disp = plot_precision_recall_curve(clf, X, y)
     assert disp.estimator_name == clf.__class__.__name__
+
+
+def test_precision_recall_curve_string_labels(pyplot):
+    # regression test #15738
+    cancer = load_breast_cancer()
+    X = cancer.data
+    y = cancer.target_names[cancer.target]
+
+    lr = make_pipeline(StandardScaler(), LogisticRegression())
+    lr.fit(X, y)
+    for klass in cancer.target_names:
+        assert klass in lr.classes_
+    disp = plot_precision_recall_curve(lr, X, y)
+
+    y_pred = lr.predict_proba(X)[:, 1]
+    avg_prec = average_precision_score(y, y_pred,
+                                       pos_label=lr.classes_[1])
+
+    assert disp.average_precision == pytest.approx(avg_prec)
+    assert disp.estimator_name == lr.__class__.__name__