Skip to content

Commit

Permalink
FIX add pos_label when computing AP in plot_precision_recall_curve (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
glemaitre authored and adrinjalali committed Dec 2, 2019
1 parent 0a56df6 commit 4246521
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 0 deletions.
1 change: 1 addition & 0 deletions sklearn/metrics/_plot/precision_recall_curve.py
Expand Up @@ -161,6 +161,7 @@ def plot_precision_recall_curve(estimator, X, y,
pos_label=pos_label,
sample_weight=sample_weight)
average_precision = average_precision_score(y, y_pred,
pos_label=pos_label,
sample_weight=sample_weight)
viz = PrecisionRecallDisplay(precision, recall, average_precision,
estimator.__class__.__name__)
Expand Down
21 changes: 21 additions & 0 deletions sklearn/metrics/_plot/tests/test_plot_precision_recall.py
Expand Up @@ -7,6 +7,7 @@
from sklearn.metrics import average_precision_score
from sklearn.metrics import precision_recall_curve
from sklearn.datasets import make_classification
from sklearn.datasets import load_breast_cancer
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
from sklearn.linear_model import LogisticRegression
from sklearn.exceptions import NotFittedError
Expand Down Expand Up @@ -132,3 +133,23 @@ def test_precision_recall_curve_pipeline(pyplot, clf):
clf.fit(X, y)
disp = plot_precision_recall_curve(clf, X, y)
assert disp.estimator_name == clf.__class__.__name__


def test_precision_recall_curve_string_labels(pyplot):
# regression test #15738
cancer = load_breast_cancer()
X = cancer.data
y = cancer.target_names[cancer.target]

lr = make_pipeline(StandardScaler(), LogisticRegression())
lr.fit(X, y)
for klass in cancer.target_names:
assert klass in lr.classes_
disp = plot_precision_recall_curve(lr, X, y)

y_pred = lr.predict_proba(X)[:, 1]
avg_prec = average_precision_score(y, y_pred,
pos_label=lr.classes_[1])

assert disp.average_precision == pytest.approx(avg_prec)
assert disp.estimator_name == lr.__class__.__name__

0 comments on commit 4246521

Please sign in to comment.