Skip to content

Commit

Permalink
Improve statsmodels autologging metrics (#4942)
Browse files Browse the repository at this point in the history
* init

Signed-off-by: Weichen Xu <weichen.xu@databricks.com>

* update

Signed-off-by: Weichen Xu <weichen.xu@databricks.com>

* update

Signed-off-by: Weichen Xu <weichen.xu@databricks.com>

* fix lint

Signed-off-by: Weichen Xu <weichen.xu@databricks.com>

* update

Signed-off-by: Weichen Xu <weichen.xu@databricks.com>

* fix lint

Signed-off-by: Weichen Xu <weichen.xu@databricks.com>

* fix lint

Signed-off-by: Weichen Xu <weichen.xu@databricks.com>

* fix lint

Signed-off-by: Weichen Xu <weichen.xu@databricks.com>

* update

Signed-off-by: Weichen Xu <weichen.xu@databricks.com>

* update

Signed-off-by: Weichen Xu <weichen.xu@databricks.com>

* update

Signed-off-by: Weichen Xu <weichen.xu@databricks.com>

* emit_warning_only in autolog

Signed-off-by: Weichen Xu <weichen.xu@databricks.com>

* add tests for bad metric warning

Signed-off-by: Weichen Xu <weichen.xu@databricks.com>

* fix

Signed-off-by: Weichen Xu <weichen.xu@databricks.com>

* update

Signed-off-by: Weichen Xu <weichen.xu@databricks.com>

* update

Signed-off-by: Weichen Xu <weichen.xu@databricks.com>

* update

Signed-off-by: Weichen Xu <weichen.xu@databricks.com>
  • Loading branch information
WeichenXu123 committed Oct 29, 2021
1 parent 1fecc62 commit 721f93f
Show file tree
Hide file tree
Showing 2 changed files with 150 additions and 68 deletions.
149 changes: 81 additions & 68 deletions mlflow/statsmodels.py
Expand Up @@ -15,7 +15,6 @@
import os
import yaml
import logging
import numpy as np

import mlflow
from mlflow import pyfunc
Expand Down Expand Up @@ -76,6 +75,12 @@ def get_default_conda_env():
return _mlflow_conda_env(additional_pip_deps=get_default_pip_requirements())


# Size threshold (in bytes) above which `save_model` warns that persisting the
# fitted model as an artifact may be slow.
_model_size_threshold_for_emitting_warning = 100 * 1024 * 1024 # 100 MB


# Flag flipped around the autologging `log_model` call so that `save_model`
# can tell it was invoked via autologging (only then is the large-model
# warning emitted; direct user calls stay silent).
_save_model_called_from_autolog = False

@format_docstring(LOG_MODEL_PARAM_DOCS.format(package_name=FLAVOR_NAME))
def save_model(
statsmodels_model,
Expand Down Expand Up @@ -140,6 +145,17 @@ def save_model(

# Save a statsmodels model
statsmodels_model.save(model_data_path, remove_data)
if _save_model_called_from_autolog and not remove_data:
saved_model_size = os.path.getsize(model_data_path)
if saved_model_size >= _model_size_threshold_for_emitting_warning:
_logger.warning(
"The fitted model is larger than "
f"{_model_size_threshold_for_emitting_warning // (1024 * 1024)} MB, "
f"saving it as artifacts is time consuming.\n"
"To reduce model size, use `mlflow.statsmodels.autolog(log_models=False)` and "
"manually log model by "
'`mlflow.statsmodels.log_model(model, remove_data=True, artifact_path="model")`'
)

pyfunc.add_to_model(
mlflow_model,
Expand Down Expand Up @@ -192,7 +208,7 @@ def log_model(
await_registration_for=DEFAULT_AWAIT_MAX_SLEEP_SECONDS,
pip_requirements=None,
extra_pip_requirements=None,
**kwargs
**kwargs,
):
"""
Log a statsmodels model as an MLflow artifact for the current run.
Expand Down Expand Up @@ -246,7 +262,7 @@ def log_model(
remove_data=remove_data,
pip_requirements=pip_requirements,
extra_pip_requirements=extra_pip_requirements,
**kwargs
**kwargs,
)


Expand Down Expand Up @@ -323,6 +339,49 @@ class AutologHelpers:
should_autolog = True


# Currently we only autolog basic metrics: the attribute names below are read
# off the fitted results object by `_get_autolog_metrics`, and only those that
# are present and numeric get logged. Everything else on the results object
# (covariance matrices, per-coefficient arrays, ...) is intentionally skipped.
_autolog_metric_allowlist = [
    "aic",
    "bic",
    "centered_tss",
    "condition_number",
    "df_model",
    "df_resid",
    "ess",
    "f_pvalue",
    "fvalue",
    "llf",
    "mse_model",
    "mse_resid",
    "mse_total",
    "rsquared",
    "rsquared_adj",
    "scale",
    "ssr",
    "uncentered_tss",
]


def _get_autolog_metrics(fitted_model):
    """Collect allowlisted numeric metrics from a fitted statsmodels results object.

    Each name in ``_autolog_metric_allowlist`` is looked up on ``fitted_model``;
    values that exist and are numeric are returned. Attributes whose evaluation
    raises are skipped and reported together in a single warning.

    :param fitted_model: fitted statsmodels results (wrapper) object.
    :return: dict mapping metric name to its numeric value.
    """
    metrics = {}
    failures = set()

    for name in _autolog_metric_allowlist:
        try:
            # Some result attributes are lazily computed properties and may
            # raise on access; treat any exception as "metric unavailable".
            if hasattr(fitted_model, name):
                value = getattr(fitted_model, name)
                if _is_numeric(value):
                    metrics[name] = value
        except Exception:
            failures.add(name)

    if failures:
        _logger.warning(f"Failed to autolog metrics: {', '.join(sorted(failures))}.")
    return metrics


@experimental
@autologging_integration(FLAVOR_NAME)
def autolog(
Expand All @@ -336,8 +395,11 @@ def autolog(
Enables (or disables) and configures automatic logging from statsmodels to MLflow.
Logs the following:
- results metrics returned by method `fit` of any subclass of statsmodels.base.model.Model
- allowlisted metrics returned by method `fit` of any subclass of
statsmodels.base.model.Model, the allowlisted metrics including: {autolog_metric_allowlist}
- trained model.
- an html artifact which shows the model summary.
:param log_models: If ``True``, trained models are logged as MLflow model artifacts.
If ``False``, trained models are not logged.
Expand Down Expand Up @@ -420,68 +482,6 @@ def patch_class_tree(klass):
for clazz, method_name, patch_impl in patches_list:
safe_patch(FLAVOR_NAME, clazz, method_name, patch_impl, manage_run=True)

def prepend_to_keys(dictionary: dict, preffix="_"):
    """Return a copy of *dictionary* whose keys are prefixed and sanitized.

    Every key gets ``preffix`` plus an ``_`` separator prepended, then runs of
    characters that are illegal in mlflow param & metric names (parentheses,
    brackets, pipes, dots) are collapsed into a single ``_``.

    :param dictionary: mapping whose keys are rewritten.
    :param preffix: string prepended to each existing key, using _ as separator.
    :return: a new dictionary with rewritten keys; the input dictionary is
        left unmodified.
    """
    import re

    sanitizer = re.compile(r"[(|)|[|\]|.]+")
    return {
        sanitizer.sub("_", preffix + "_" + key): value
        for key, value in dictionary.items()
    }

def results_to_dict(results):
    """Flatten a statsmodels ResultsWrapper into a dict of loggable values.

    Walks every public, non-callable attribute of *results* and keeps it when
    it is (a) a real number, stored under its own name, or (b) a 1-D numpy
    array with one entry per model feature, expanded into one entry per
    coefficient via ``prepend_to_keys``. Private (``_``/``__``) and
    covariance (``cov_``) attributes are excluded, and any attribute whose
    evaluation raises is silently skipped (best-effort collection).

    :param results: instance of a ResultsWrapper returned by a call to `fit`
    :return: a python dictionary with those metrics that are (a) a real
        number, or (b) an array of the same length of the number of
        coefficients
    """
    features = results.model.exog_names
    n_features = len(features) if features is not None else None

    collected = {}
    for name in dir(results):
        try:
            value = getattr(results, name)
            if callable(value) or name.startswith(("__", "_", "cov_")):
                continue

            looks_like_coefficients = (
                n_features is not None
                and isinstance(value, np.ndarray)
                and value.ndim == 1
                and value.shape[0] == n_features
            )
            if looks_like_coefficients:
                per_feature = dict(zip(features, value))
                collected.update(prepend_to_keys(per_feature, name))
            elif _is_numeric(value):
                collected[name] = value
        except Exception:
            # Best-effort: some result attributes raise on access; skip them.
            pass

    return collected

def wrapper_fit(original, self, *args, **kwargs):

should_autolog = False
Expand All @@ -500,13 +500,21 @@ def wrapper_fit(original, self, *args, **kwargs):
if should_autolog:
# Log the model
if get_autologging_config(FLAVOR_NAME, "log_models", True):
try_mlflow_log(log_model, model, artifact_path="model")
global _save_model_called_from_autolog
_save_model_called_from_autolog = True
try:
try_mlflow_log(log_model, model, artifact_path="model")
finally:
_save_model_called_from_autolog = False

# Log the most common metrics
if isinstance(model, statsmodels.base.wrapper.ResultsWrapper):
metrics_dict = results_to_dict(model)
metrics_dict = _get_autolog_metrics(model)
try_mlflow_log(mlflow.log_metrics, metrics_dict)

model_summary = model.summary().as_text()
try_mlflow_log(mlflow.log_text, model_summary, "model_summary.txt")

return model

finally:
Expand All @@ -515,3 +523,8 @@ def wrapper_fit(original, self, *args, **kwargs):
AutologHelpers.should_autolog = True

patch_class_tree(statsmodels.base.model.Model)


# Substitute the allowlist into autolog's docstring placeholder so the
# documented metric names always stay in sync with `_autolog_metric_allowlist`.
autolog.__doc__ = autolog.__doc__.format(
    autolog_metric_allowlist=", ".join(_autolog_metric_allowlist)
)
69 changes: 69 additions & 0 deletions tests/statsmodels/test_statsmodels_autolog.py
@@ -1,4 +1,5 @@
import pytest
from unittest import mock
import numpy as np
from statsmodels.tsa.base.tsa_model import TimeSeriesModel
import mlflow
Expand Down Expand Up @@ -77,6 +78,74 @@ def test_statsmodels_autolog_logs_specified_params():
mlflow.end_run()


def test_statsmodels_autolog_logs_summary_artifact():
    """Fitting with autologging enabled should log a model_summary.txt artifact."""
    mlflow.statsmodels.autolog()
    with mlflow.start_run():
        fitted = ols_model().model
        artifact_uri = mlflow.get_artifact_uri("model_summary.txt")
        local_path = artifact_uri.replace("file://", "")
        with open(local_path, "r") as f:
            logged_summary = f.read()

    # Compare only the first few lines: the full summary contains a "Time"
    # field that can change between the fit and this assertion.
    expected_head = fitted.summary().as_text().split("\n")[:4]
    assert expected_head == logged_summary.split("\n")[:4]


def test_statsmodels_autolog_emit_warning_when_model_is_large():
    """The large-model warning fires iff the saved model exceeds the threshold."""
    mlflow.statsmodels.autolog()

    def large_model_warnings(mock_warning):
        # Warnings emitted by save_model about oversized models.
        return [
            call_args
            for call_args in mock_warning.call_args_list
            if call_args[0][0].startswith("The fitted model is larger than")
        ]

    # Threshold of +inf: no model can exceed it, so no warning is expected.
    with mock.patch(
        "mlflow.statsmodels._model_size_threshold_for_emitting_warning", float("inf")
    ), mock.patch("mlflow.statsmodels._logger.warning") as mock_warning:
        ols_model()
        assert not large_model_warnings(mock_warning)

    # Threshold of 1 byte: any saved model exceeds it, so the warning must fire.
    with mock.patch(
        "mlflow.statsmodels._model_size_threshold_for_emitting_warning", 1
    ), mock.patch("mlflow.statsmodels._logger.warning") as mock_warning:
        ols_model()
        assert large_model_warnings(mock_warning)


def test_statsmodels_autolog_logs_basic_metrics():
    """An OLS fit should log exactly the allowlisted metrics."""
    mlflow.statsmodels.autolog()
    ols_model()
    logged_metrics = get_latest_run().data.metrics
    expected_names = set(mlflow.statsmodels._autolog_metric_allowlist)
    assert set(logged_metrics.keys()) == expected_names


def test_statsmodels_autolog_failed_metrics_warning():
    """Metrics whose evaluation raises are skipped and reported in one
    aggregated warning (names sorted, comma-separated)."""
    mlflow.statsmodels.autolog()

    # A property that always raises, patched over allowlisted result metrics.
    @property
    def metric_raise_error(_):
        raise RuntimeError()

    # Stand-in summary object so the summary-artifact logging path still works.
    class MockSummary:
        def as_text(self):
            return "mock summary."

    with mock.patch(
        "statsmodels.regression.linear_model.OLSResults.f_pvalue", metric_raise_error
    ), mock.patch(
        "statsmodels.regression.linear_model.OLSResults.fvalue", metric_raise_error
    ), mock.patch(
        # Prevent `OLSResults.summary` from calling `fvalue` and `f_pvalue` that raise an exception
        "statsmodels.regression.linear_model.OLSResults.summary",
        return_value=MockSummary(),
    ), mock.patch(
        "mlflow.statsmodels._logger.warning"
    ) as mock_warning:
        ols_model()
        mock_warning.assert_called_once_with("Failed to autolog metrics: f_pvalue, fvalue.")


def test_statsmodels_autolog_works_after_exception():
mlflow.statsmodels.autolog()
# We first fit a model known to raise an exception
Expand Down

0 comments on commit 721f93f

Please sign in to comment.