-
-
Notifications
You must be signed in to change notification settings - Fork 25.1k
/
precision_recall_curve.py
168 lines (126 loc) · 5.3 KB
/
precision_recall_curve.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
from .base import _check_classifer_response_method
from .. import average_precision_score
from .. import precision_recall_curve
from ...utils import check_matplotlib_support
from ...base import is_classifier
class PrecisionRecallDisplay:
    """Precision Recall visualization.

    It is recommended to use
    :func:`~sklearn.metrics.plot_precision_recall_curve` to create a
    visualizer. All parameters are stored as attributes.

    Read more in the :ref:`User Guide <visualizations>`.

    Parameters
    ----------
    precision : ndarray
        Precision values.

    recall : ndarray
        Recall values.

    average_precision : float
        Average precision.

    estimator_name : str
        Name of estimator. Used as the curve label when :meth:`plot` is
        called without an explicit ``name``.

    Attributes
    ----------
    line_ : matplotlib Artist
        Precision recall curve.

    ax_ : matplotlib Axes
        Axes with precision recall curve.

    figure_ : matplotlib Figure
        Figure containing the curve.
    """

    def __init__(self, precision, recall, average_precision, estimator_name):
        self.precision = precision
        self.recall = recall
        self.average_precision = average_precision
        self.estimator_name = estimator_name

    def plot(self, ax=None, name=None, **kwargs):
        """Plot visualization.

        Extra keyword arguments will be passed to matplotlib's `plot`.

        Parameters
        ----------
        ax : Matplotlib Axes, default=None
            Axes object to plot on. If `None`, a new figure and axes is
            created.

        name : str, default=None
            Name of precision recall curve for labeling. If `None`, use the
            name of the estimator (``self.estimator_name``).

        **kwargs : dict
            Keyword arguments to be passed to matplotlib's `plot`. They take
            precedence over the default ``label`` and ``drawstyle``.

        Returns
        -------
        display : :class:`~sklearn.metrics.PrecisionRecallDisplay`
            Object that stores computed values (``line_``, ``ax_`` and
            ``figure_`` are set by this call).
        """
        check_matplotlib_support("PrecisionRecallDisplay.plot")
        import matplotlib.pyplot as plt

        if ax is None:
            # Only the axes are needed here; the figure is read back from
            # ``ax.figure`` below so both code paths set ``figure_`` alike.
            _, ax = plt.subplots()

        name = self.estimator_name if name is None else name

        line_kwargs = {
            "label": "{} (AP = {:0.2f})".format(name, self.average_precision),
            # Draw the curve as a step plot rather than linear interpolation.
            "drawstyle": "steps-post",
        }
        # Caller-supplied keyword arguments override the defaults above.
        line_kwargs.update(kwargs)

        # ``Axes.plot`` returns a list of lines; exactly one is drawn here.
        self.line_, = ax.plot(self.recall, self.precision, **line_kwargs)
        ax.set(xlabel="Recall", ylabel="Precision")
        ax.legend(loc='lower left')

        self.ax_ = ax
        self.figure_ = ax.figure
        return self
def plot_precision_recall_curve(estimator, X, y,
                                sample_weight=None, response_method="auto",
                                name=None, ax=None, **kwargs):
    """Plot Precision Recall Curve for binary classifiers.

    Extra keyword arguments will be passed to matplotlib's `plot`.

    Read more in the :ref:`User Guide <precision_recall_f_measure_metrics>`.

    Parameters
    ----------
    estimator : estimator instance
        Trained classifier.

    X : {array-like, sparse matrix} of shape (n_samples, n_features)
        Input values.

    y : array-like of shape (n_samples,)
        Binary target values.

    sample_weight : array-like of shape (n_samples,), default=None
        Sample weights.

    response_method : {'predict_proba', 'decision_function', 'auto'}, \
                      default='auto'
        Specifies whether to use :term:`predict_proba` or
        :term:`decision_function` as the target response. If set to 'auto',
        :term:`predict_proba` is tried first and if it does not exist
        :term:`decision_function` is tried next.

    name : str, default=None
        Name for labeling curve. If `None`, the name of the
        estimator is used. The chosen name is also stored as the
        returned display's ``estimator_name``.

    ax : matplotlib axes, default=None
        Axes object to plot on. If `None`, a new figure and axes is created.

    **kwargs : dict
        Keyword arguments to be passed to matplotlib's `plot`.

    Returns
    -------
    display : :class:`~sklearn.metrics.PrecisionRecallDisplay`
        Object that stores computed values.

    Raises
    ------
    ValueError
        If `estimator` is not a classifier, or if its 2-D response does not
        have exactly two columns (i.e. it is not a binary classifier).
    """
    check_matplotlib_support("plot_precision_recall_curve")

    # Message text is kept verbatim (including its spelling) because
    # callers and tests may match on it.
    classification_error = ("{} should be a binary classifer".format(
        estimator.__class__.__name__))
    if not is_classifier(estimator):
        raise ValueError(classification_error)

    prediction_method = _check_classifer_response_method(estimator,
                                                         response_method)
    y_pred = prediction_method(X)

    if y_pred.ndim != 1:
        # A 2-D response is only accepted with exactly two columns; keep the
        # column paired with ``estimator.classes_[1]``, the positive label
        # used below.
        if y_pred.shape[1] != 2:
            raise ValueError(classification_error)
        y_pred = y_pred[:, 1]

    pos_label = estimator.classes_[1]
    precision, recall, _ = precision_recall_curve(y, y_pred,
                                                  pos_label=pos_label,
                                                  sample_weight=sample_weight)
    average_precision = average_precision_score(y, y_pred,
                                                pos_label=pos_label,
                                                sample_weight=sample_weight)

    # Store the resolved label on the display so a later ``viz.plot()``
    # without arguments keeps the caller's custom name.
    name = estimator.__class__.__name__ if name is None else name
    viz = PrecisionRecallDisplay(precision, recall, average_precision, name)
    return viz.plot(ax=ax, name=name, **kwargs)