Skip to content

Commit

Permalink
Improve confusion matrix plot (#5273)
Browse files Browse the repository at this point in the history
* update

Signed-off-by: Weichen Xu <weichen.xu@databricks.com>

* fix

Signed-off-by: Weichen Xu <weichen.xu@databricks.com>

* update

Signed-off-by: Weichen Xu <weichen.xu@databricks.com>
  • Loading branch information
WeichenXu123 committed Jan 17, 2022
1 parent 00921fc commit 847eb6b
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 5 deletions.
16 changes: 12 additions & 4 deletions mlflow/models/evaluation/default_evaluator.py
Expand Up @@ -599,10 +599,18 @@ def _evaluate_classifier(self):
)

def plot_confusion_matrix():
sk_metrics.ConfusionMatrixDisplay(
confusion_matrix=confusion_matrix,
display_labels=self.label_list,
).plot(cmap="Blues")
import matplotlib

with matplotlib.rc_context(
{
"font.size": min(10, 50.0 / self.num_classes),
"axes.labelsize": 10,
}
):
sk_metrics.ConfusionMatrixDisplay(
confusion_matrix=confusion_matrix,
display_labels=self.label_list,
).plot(cmap="Blues")

if hasattr(sk_metrics, "ConfusionMatrixDisplay"):
self._log_image_artifact(
Expand Down
14 changes: 13 additions & 1 deletion mlflow/sklearn/utils.py
Expand Up @@ -288,10 +288,22 @@ def _get_classifier_artifacts(fitted_estimator, prefix, X, y_true, sample_weight
if not _is_plotting_supported():
return []

def plot_confusion_matrix(*args, **kwargs):
import matplotlib

num_classes = len(set(y_true))
with matplotlib.rc_context(
{
"font.size": min(10.0, 50.0 / num_classes),
"axes.labelsize": 10.0,
}
):
return sklearn.metrics.plot_confusion_matrix(*args, **kwargs)

classifier_artifacts = [
_SklearnArtifact(
name=prefix + "confusion_matrix",
function=sklearn.metrics.plot_confusion_matrix,
function=plot_confusion_matrix,
arguments=dict(
estimator=fitted_estimator,
X=X,
Expand Down

0 comments on commit 847eb6b

Please sign in to comment.