From 174cd4cbb62ef986f655e8b6cc3d01801a53e497 Mon Sep 17 00:00:00 2001 From: Weichen Xu Date: Sat, 15 Jan 2022 13:46:55 +0800 Subject: [PATCH 1/3] update Signed-off-by: Weichen Xu --- mlflow/models/evaluation/default_evaluator.py | 13 +++++++++---- mlflow/sklearn/utils.py | 11 ++++++++++- 2 files changed, 19 insertions(+), 5 deletions(-) diff --git a/mlflow/models/evaluation/default_evaluator.py b/mlflow/models/evaluation/default_evaluator.py index dc2e6d12709c9..e40400032c630 100644 --- a/mlflow/models/evaluation/default_evaluator.py +++ b/mlflow/models/evaluation/default_evaluator.py @@ -599,10 +599,15 @@ def _evaluate_classifier(self): ) def plot_confusion_matrix(): - sk_metrics.ConfusionMatrixDisplay( - confusion_matrix=confusion_matrix, - display_labels=self.label_list, - ).plot(cmap="Blues") + import matplotlib + with matplotlib.rc_context({ + 'font.size': min(10, 50.0 / self.num_classes), + 'axes.labelsize': 10, + }): + sk_metrics.ConfusionMatrixDisplay( + confusion_matrix=confusion_matrix, + display_labels=self.label_list, + ).plot(cmap="Blues") if hasattr(sk_metrics, "ConfusionMatrixDisplay"): self._log_image_artifact( diff --git a/mlflow/sklearn/utils.py b/mlflow/sklearn/utils.py index 1994cdc6d2ad6..a6452ff4adc74 100644 --- a/mlflow/sklearn/utils.py +++ b/mlflow/sklearn/utils.py @@ -288,10 +288,19 @@ def _get_classifier_artifacts(fitted_estimator, prefix, X, y_true, sample_weight if not _is_plotting_supported(): return [] + def plot_confusion_matrix(*args, **kwargs): + import matplotlib + num_classes = len(set(y_true)) + with matplotlib.rc_context({ + 'font.size': min(10.0, 50.0 / num_classes), + 'axes.labelsize': 10.0, + }): + sklearn.metrics.plot_confusion_matrix(*args, **kwargs) + classifier_artifacts = [ _SklearnArtifact( name=prefix + "confusion_matrix", - function=sklearn.metrics.plot_confusion_matrix, + function=plot_confusion_matrix, arguments=dict( estimator=fitted_estimator, X=X, From 618e979d1505d93ccb18b8c0154b8be61f4c47cc Mon Sep 17 00:00:00 2001 From: Weichen Xu Date: Sat, 15 Jan 2022 13:54:34 +0800 Subject: [PATCH 2/3] fix Signed-off-by: Weichen Xu --- mlflow/sklearn/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlflow/sklearn/utils.py b/mlflow/sklearn/utils.py index a6452ff4adc74..4dab4cb9bda66 100644 --- a/mlflow/sklearn/utils.py +++ b/mlflow/sklearn/utils.py @@ -295,7 +295,7 @@ def plot_confusion_matrix(*args, **kwargs): 'font.size': min(10.0, 50.0 / num_classes), 'axes.labelsize': 10.0, }): - sklearn.metrics.plot_confusion_matrix(*args, **kwargs) + return sklearn.metrics.plot_confusion_matrix(*args, **kwargs) classifier_artifacts = [ _SklearnArtifact( From a8bb4baeb6cde9dab9a82212201206b964af0bd2 Mon Sep 17 00:00:00 2001 From: Weichen Xu Date: Mon, 17 Jan 2022 10:45:29 +0800 Subject: [PATCH 3/3] update Signed-off-by: Weichen Xu --- mlflow/models/evaluation/default_evaluator.py | 11 +++++++---- mlflow/sklearn/utils.py | 11 +++++++---- 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/mlflow/models/evaluation/default_evaluator.py b/mlflow/models/evaluation/default_evaluator.py index e40400032c630..d80335d98f039 100644 --- a/mlflow/models/evaluation/default_evaluator.py +++ b/mlflow/models/evaluation/default_evaluator.py @@ -600,10 +600,13 @@ def _evaluate_classifier(self): def plot_confusion_matrix(): import matplotlib - with matplotlib.rc_context({ - 'font.size': min(10, 50.0 / self.num_classes), - 'axes.labelsize': 10, - }): + + with matplotlib.rc_context( + { + "font.size": min(10, 50.0 / self.num_classes), + "axes.labelsize": 10, + } + ): sk_metrics.ConfusionMatrixDisplay( confusion_matrix=confusion_matrix, display_labels=self.label_list, diff --git a/mlflow/sklearn/utils.py b/mlflow/sklearn/utils.py index 4dab4cb9bda66..6db14eed33783 100644 --- a/mlflow/sklearn/utils.py +++ b/mlflow/sklearn/utils.py @@ -290,11 +290,14 @@ def _get_classifier_artifacts(fitted_estimator, prefix, X, y_true, sample_weight def plot_confusion_matrix(*args, **kwargs): import matplotlib + num_classes = len(set(y_true)) - with matplotlib.rc_context({ - 'font.size': min(10.0, 50.0 / num_classes), - 'axes.labelsize': 10.0, - }): + with matplotlib.rc_context( + { + "font.size": min(10.0, 50.0 / num_classes), + "axes.labelsize": 10.0, + } + ): return sklearn.metrics.plot_confusion_matrix(*args, **kwargs) classifier_artifacts = [