Skip to content

Commit

Permalink
FIX use same API for CalibrationDisplay than other Display (scikit-le…
Browse files Browse the repository at this point in the history
…arn#21031)

* FIX use same API for CalibrationDisplay than other Display

* Update sklearn/calibration.py

Co-authored-by: Thomas J. Fan <thomasjpfan@gmail.com>

* iter

Co-authored-by: Thomas J. Fan <thomasjpfan@gmail.com>
  • Loading branch information
2 people authored and samronsin committed Nov 30, 2021
1 parent b5d2613 commit 59c21dd
Show file tree
Hide file tree
Showing 5 changed files with 24 additions and 22 deletions.
19 changes: 11 additions & 8 deletions sklearn/calibration.py
Original file line number Diff line number Diff line change
Expand Up @@ -980,8 +980,8 @@ class CalibrationDisplay:
y_prob : ndarray of shape (n_samples,)
Probability estimates for the positive class, for each sample.
name : str, default=None
Name for labeling curve.
estimator_name : str, default=None
Name of estimator. If None, the estimator name is not shown.
Attributes
----------
Expand Down Expand Up @@ -1022,11 +1022,11 @@ class CalibrationDisplay:
<...>
"""

def __init__(self, prob_true, prob_pred, y_prob, *, name=None):
def __init__(self, prob_true, prob_pred, y_prob, *, estimator_name=None):
self.prob_true = prob_true
self.prob_pred = prob_pred
self.y_prob = y_prob
self.name = name
self.estimator_name = estimator_name

def plot(self, *, ax=None, name=None, ref_line=True, **kwargs):
"""Plot visualization.
Expand All @@ -1041,7 +1041,8 @@ def plot(self, *, ax=None, name=None, ref_line=True, **kwargs):
created.
name : str, default=None
Name for labeling curve.
Name for labeling curve. If `None`, use `estimator_name` if
not `None`, otherwise no labeling is shown.
ref_line : bool, default=True
If `True`, plots a reference line representing a perfectly
Expand All @@ -1061,8 +1062,7 @@ def plot(self, *, ax=None, name=None, ref_line=True, **kwargs):
if ax is None:
fig, ax = plt.subplots()

name = self.name if name is None else name
self.name = name
name = self.estimator_name if name is None else name

line_kwargs = {}
if name is not None:
Expand Down Expand Up @@ -1298,6 +1298,9 @@ def from_predictions(
prob_true, prob_pred = calibration_curve(
y_true, y_prob, n_bins=n_bins, strategy=strategy
)
name = name if name is not None else "Classifier"

disp = cls(prob_true=prob_true, prob_pred=prob_pred, y_prob=y_prob, name=name)
disp = cls(
prob_true=prob_true, prob_pred=prob_pred, y_prob=y_prob, estimator_name=name
)
return disp.plot(ax=ax, ref_line=ref_line, **kwargs)
4 changes: 2 additions & 2 deletions sklearn/metrics/_plot/det_curve.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,8 +294,8 @@ def plot(self, ax=None, *, name=None, **kwargs):
created.
name : str, default=None
Name of DET curve for labeling. If `None`, use the name of the
estimator.
Name of DET curve for labeling. If `None`, use `estimator_name` if
it is not `None`, otherwise no labeling is shown.
**kwargs : dict
Additional keywords arguments passed to matplotlib `plot` function.
Expand Down
4 changes: 2 additions & 2 deletions sklearn/metrics/_plot/precision_recall_curve.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,8 @@ def plot(self, ax=None, *, name=None, **kwargs):
created.
name : str, default=None
Name of precision recall curve for labeling. If `None`, use the
name of the estimator.
Name of precision recall curve for labeling. If `None`, use
`estimator_name` if not `None`, otherwise no labeling is shown.
**kwargs : dict
Keyword arguments to be passed to matplotlib's `plot`.
Expand Down
4 changes: 2 additions & 2 deletions sklearn/metrics/_plot/roc_curve.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,8 @@ def plot(self, ax=None, *, name=None, **kwargs):
created.
name : str, default=None
Name of ROC Curve for labeling. If `None`, use the name of the
estimator.
Name of ROC Curve for labeling. If `None`, use `estimator_name` if
not `None`, otherwise no labeling is shown.
Returns
-------
Expand Down
15 changes: 7 additions & 8 deletions sklearn/tests/test_calibration.py
Original file line number Diff line number Diff line change
Expand Up @@ -693,7 +693,7 @@ def test_calibration_display_compute(pyplot, iris_data_binary, n_bins, strategy)
assert_allclose(viz.prob_pred, prob_pred)
assert_allclose(viz.y_prob, y_prob)

assert viz.name == "LogisticRegression"
assert viz.estimator_name == "LogisticRegression"

# cannot fail thanks to pyplot fixture
import matplotlib as mpl # noqa
Expand All @@ -715,7 +715,7 @@ def test_plot_calibration_curve_pipeline(pyplot, iris_data_binary):
clf.fit(X, y)
viz = CalibrationDisplay.from_estimator(clf, X, y)
assert clf.__class__.__name__ in viz.line_.get_label()
assert viz.name == clf.__class__.__name__
assert viz.estimator_name == clf.__class__.__name__


@pytest.mark.parametrize(
Expand All @@ -726,24 +726,23 @@ def test_calibration_display_default_labels(pyplot, name, expected_label):
prob_pred = np.array([0.2, 0.8, 0.8, 0.4])
y_prob = np.array([])

viz = CalibrationDisplay(prob_true, prob_pred, y_prob, name=name)
viz = CalibrationDisplay(prob_true, prob_pred, y_prob, estimator_name=name)
viz.plot()
assert viz.line_.get_label() == expected_label


def test_calibration_display_label_class_plot(pyplot):
# Checks that when instantiating `CalibrationDisplay` class then calling
# `plot`, `self.name` is the one given in `plot`
# `plot`, `self.estimator_name` is the one given in `plot`
prob_true = np.array([0, 1, 1, 0])
prob_pred = np.array([0.2, 0.8, 0.8, 0.4])
y_prob = np.array([])

name = "name one"
viz = CalibrationDisplay(prob_true, prob_pred, y_prob, name=name)
assert viz.name == name
viz = CalibrationDisplay(prob_true, prob_pred, y_prob, estimator_name=name)
assert viz.estimator_name == name
name = "name two"
viz.plot(name=name)
assert viz.name == name
assert viz.line_.get_label() == name


Expand All @@ -764,7 +763,7 @@ def test_calibration_display_name_multiple_calls(
params = (clf, X, y) if constructor_name == "from_estimator" else (y, y_prob)

viz = constructor(*params, name=clf_name)
assert viz.name == clf_name
assert viz.estimator_name == clf_name
pyplot.close("all")
viz.plot()
assert clf_name == viz.line_.get_label()
Expand Down

0 comments on commit 59c21dd

Please sign in to comment.