Autologging functionality for scikit-learn integration with XGBoost (Part 2) #5078

Merged (15 commits, Nov 29, 2021)
48 changes: 48 additions & 0 deletions examples/xgboost_sklearn/train_sklearn.py
@@ -0,0 +1,48 @@
from pprint import pprint

Collaborator:
Awesome example! Can we add a brief README to this directory explaining what this example covers? E.g., usage of XGBoost's scikit-learn integration with MLflow Tracking, particularly autologging?
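For illustration, a minimal README along the lines requested above might look like the following sketch (wording assumed; not the README that was ultimately added):

```
# Example: XGBoost scikit-learn estimators with MLflow autologging

This example trains an xgboost.XGBRegressor on the scikit-learn wine dataset and
shows how mlflow.xgboost.autolog() captures parameters, metrics, tags, and the
trained model for XGBoost's scikit-learn estimators.

Run it from this directory:

    python train_sklearn.py
```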

import pandas as pd
import xgboost as xgb
from sklearn.datasets import load_wine
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error

import numpy as np
import mlflow
import mlflow.xgboost

from utils import fetch_logged_data


def main():
    # prepare example dataset
    wine = load_wine()
    X = pd.DataFrame(wine.data, columns=wine.feature_names)
    y = pd.Series(wine.target)
    X_train, X_test, y_train, y_test = train_test_split(X, y)

    # enable auto logging
    # this includes xgboost.sklearn estimators
    mlflow.xgboost.autolog()

    with mlflow.start_run() as run:

        regressor = xgb.XGBRegressor(n_estimators=100, reg_lambda=1, gamma=0, max_depth=3)
        regressor.fit(X_train, y_train)
        y_pred = regressor.predict(X_test)
        mse = mean_squared_error(y_test, y_pred)
        run_id = run.info.run_id
        print("Logged data and model in run {}".format(run_id))
        mlflow.xgboost.log_model(regressor, artifact_path="log_model")

    # show logged data
    for key, data in fetch_logged_data(run.info.run_id).items():
        print("\n---------- logged {} ----------".format(key))
        pprint(data)

    mlflow.xgboost.save_model(regressor, "trained_model/")
    reload_model = mlflow.pyfunc.load_model("trained_model/")
    np.testing.assert_array_almost_equal(y_pred, reload_model.predict(X_test))


if __name__ == "__main__":
    main()
26 changes: 26 additions & 0 deletions examples/xgboost_sklearn/utils.py
@@ -0,0 +1,26 @@
import mlflow


def yield_artifacts(run_id, path=None):
    """Yield all artifacts in the specified run"""
    client = mlflow.tracking.MlflowClient()
    for item in client.list_artifacts(run_id, path):
        if item.is_dir:
            yield from yield_artifacts(run_id, item.path)
        else:
            yield item.path


def fetch_logged_data(run_id):
    """Fetch params, metrics, tags, and artifacts in the specified run"""
    client = mlflow.tracking.MlflowClient()
    data = client.get_run(run_id).data
    # Exclude system tags: https://www.mlflow.org/docs/latest/tracking.html#system-tags
    tags = {k: v for k, v in data.tags.items() if not k.startswith("mlflow.")}
    artifacts = list(yield_artifacts(run_id))
    return {
        "params": data.params,
        "metrics": data.metrics,
        "tags": tags,
        "artifacts": artifacts,
    }
86 changes: 71 additions & 15 deletions mlflow/sklearn/__init__.py
@@ -111,6 +111,19 @@ def _gen_estimators_to_patch():
]


def _gen_xgboost_sklearn_estimators_to_patch():
import xgboost as xgb

all_classes = inspect.getmembers(xgb.sklearn, inspect.isclass)
base_class = xgb.sklearn.XGBModel
sklearn_estimators = []
for _, class_object in all_classes:
if issubclass(class_object, base_class) and class_object != base_class:
sklearn_estimators.append(class_object)

return sklearn_estimators
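As an aside (a standalone sketch, not part of this diff), the discovery logic above finds XGBoost's scikit-learn estimator classes; the exact class list depends on the installed xgboost version:

```python
import inspect

import xgboost as xgb


def discover_xgb_sklearn_estimators():
    # Same idea as _gen_xgboost_sklearn_estimators_to_patch above: collect every
    # scikit-learn-compatible estimator class exported by xgboost.sklearn,
    # i.e. all subclasses of XGBModel except the base class itself.
    classes = inspect.getmembers(xgb.sklearn, inspect.isclass)
    base = xgb.sklearn.XGBModel
    return [cls for _, cls in classes if issubclass(cls, base) and cls is not base]


if __name__ == "__main__":
    # Typically prints names such as XGBClassifier, XGBRegressor, XGBRFClassifier,
    # XGBRFRegressor, and XGBRanker.
    for estimator_cls in discover_xgb_sklearn_estimators():
        print(estimator_cls.__name__)
```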


def get_default_pip_requirements(include_cloudpickle=False):
"""
:return: A list of default pip requirements for MLflow Models produced by this flavor.
@@ -371,7 +384,7 @@ def log_model(
# log model
mlflow.sklearn.log_model(sk_model, "sk_models")
"""
return Model.log(
Model.log(
artifact_path=artifact_path,
flavor=mlflow.sklearn,
sk_model=sk_model,
@@ -1152,6 +1165,40 @@ def fetch_logged_data(run_id):
``True``. See the `post training metrics`_ section for more
details.
"""
_autolog(
flavor_name=FLAVOR_NAME,
log_input_examples=log_input_examples,
log_model_signatures=log_model_signatures,
log_models=log_models,
disable=disable,
exclusive=exclusive,
disable_for_unsupported_versions=disable_for_unsupported_versions,
silent=silent,
max_tuning_runs=max_tuning_runs,
log_post_training_metrics=log_post_training_metrics,
)


def _autolog(
flavor_name=FLAVOR_NAME,
log_input_examples=False,
log_model_signatures=True,
log_models=True,
disable=False,
exclusive=False,
disable_for_unsupported_versions=False,
silent=False,
max_tuning_runs=5,
log_post_training_metrics=True,
): # pylint: disable=unused-argument
"""
Internal autologging function for scikit-learn models.

:param flavor_name: The flavor for which to enable this autologging routine. By default
                    it enables autologging for original scikit-learn models, as
                    ``mlflow.sklearn.autolog()`` does. If the argument is ``xgboost``,
                    autologging is enabled for XGBoost scikit-learn models.
"""
import pandas as pd
import sklearn
import sklearn.metrics
@@ -1200,8 +1247,10 @@ def fit_mlflow(original, self, *args, **kwargs):
_log_pretraining_metadata(autologging_client, self, *args, **kwargs)
params_logging_future = autologging_client.flush(synchronous=False)
fit_output = original(self, *args, **kwargs)
_log_posttraining_metadata(autologging_client, self, *args, **kwargs)
autologging_client.flush(synchronous=True)
# params of xgboost sklearn models are logged in train() in mlflow.xgboost.autolog()
if flavor_name == FLAVOR_NAME:
_log_posttraining_metadata(autologging_client, self, *args, **kwargs)
autologging_client.flush(synchronous=True)
Collaborator:
@jwyyy Instead of special casing XGBoost logic in fit_mlflow, can we define a new method called fit_mlflow_xgboost that just calls original(self, *args, **kwargs) and then logs self using mlflow.xgboost.log_model()? This will also allow us to revert changes to XGBoost autologging's train() method, since we can control how the model gets logged here.

We can then add a parameter to patched_fit (signature: def patched_fit(original, self, *args, **kwargs)) to specify either fit_mlflow (for sklearn models) or fit_mlflow_xgboost (for xgboost sklearn models). Perhaps we can call this parameter fit_fn.

Let me know if you have questions here!
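A minimal sketch of the suggested fit_mlflow_xgboost hook, assuming the patched-fit calling convention shown elsewhere in this diff (original passed as the first argument); this illustrates the suggestion and is not the code that was merged:

```python
import mlflow.xgboost


def fit_mlflow_xgboost(original, self, *args, **kwargs):
    # Train the XGBoost scikit-learn estimator exactly as the user requested.
    fit_output = original(self, *args, **kwargs)

    # Log the fitted estimator itself (rather than its underlying Booster) with
    # the xgboost flavor, so parameter/metric logging can stay in the patched
    # train() while model logging is controlled here.
    mlflow.xgboost.log_model(self, artifact_path="model")
    return fit_output
```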

Contributor (author):
@dbczumar Thank you for your suggestion! There is a small issue with model logging in the train() method when fit() is called. The fit() method calls train() internally and assigns a Booster object to the internal _Booster attribute of an XGBoost sklearn model [see L1331]. The current train() in mlflow.xgboost.autolog() logs the model before returning it, which means (1) the logged model is a Booster object, and (2) we cannot log XGBoost sklearn models before the trained Booster is assigned to _Booster. The changes in mlflow.xgboost.autolog() therefore try to log the sklearn model directly. We could certainly log sklearn models in fit_mlflow_xgboost(), but it would be extra work: the model is already logged when train() is called, and fit_mlflow_xgboost() would just log new information to replace the old. That said, I agree your suggestion makes the code logic easier to read. Please let me know which solution sounds better to you. Thank you!
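To make the ordering issue concrete, here is a toy wrapper (not xgboost's actual code) showing why the scikit-learn-style model only becomes loggable after train() returns and the Booster is assigned:

```python
import xgboost as xgb
from sklearn.datasets import load_wine


class TinyXGBWrapper:
    # Toy stand-in for xgboost.sklearn.XGBModel: fit() delegates to xgboost.train()
    # and only then stores the trained Booster on the wrapper.
    def __init__(self, **params):
        self.params = params
        self._Booster = None  # unset until fit() completes

    def fit(self, X, y):
        dtrain = xgb.DMatrix(X, label=y)
        self._Booster = xgb.train(self.params, dtrain, num_boost_round=10)
        return self


X, y = load_wine(return_X_y=True)
model = TinyXGBWrapper(max_depth=3)
assert model._Booster is None      # logging the wrapper here would miss the Booster
model.fit(X, y)
assert model._Booster is not None  # now the wrapper is a complete, loggable model
```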

Collaborator (@dbczumar, Nov 23, 2021):
@jwyyy Ah, thanks for letting me know! Can we decompose train() into two methods - one for parameter, metric, & non-model artifact logging, and one for model logging? We can then use the former method to patch xgboost.sklearn.train().

Contributor (author):
Sounds good to me! I will make some adjustments.
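One possible shape for that decomposition, as a sketch with assumed helper names (the merged code may differ):

```python
import mlflow.xgboost


def _train_and_log_metadata(original, *args, **kwargs):
    # Hypothetical helper: perform the real training call and log params, metrics,
    # and non-model artifacts (e.g. the feature importance plot); details elided.
    model = original(*args, **kwargs)
    return model


def train(original, *args, **kwargs):
    # Patch for xgboost.train(): log metadata, then log the returned Booster.
    model = _train_and_log_metadata(original, *args, **kwargs)
    mlflow.xgboost.log_model(model, artifact_path="model")
    return model


def sklearn_train(original, *args, **kwargs):
    # Patch for xgboost.sklearn.train() (the call made inside fit()): log metadata
    # only; the fitted estimator is logged later by the sklearn-side fit hook,
    # once the Booster has been assigned to the estimator.
    return _train_and_log_metadata(original, *args, **kwargs)
```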

params_logging_future.await_completion()
return fit_output

@@ -1229,10 +1278,12 @@ def _log_pretraining_metadata(
# for these seed estimators.
should_log_params_deeply = not _is_parameter_search_estimator(estimator)
run_id = mlflow.active_run().info.run_id
autologging_client.log_params(
run_id=mlflow.active_run().info.run_id,
params=estimator.get_params(deep=should_log_params_deeply),
)
# params of xgboost sklearn models are logged in train() in mlflow.xgboost.autolog()
if flavor_name == FLAVOR_NAME:
autologging_client.log_params(
run_id=mlflow.active_run().info.run_id,
params=estimator.get_params(deep=should_log_params_deeply),
)
autologging_client.set_tags(
run_id=run_id,
tags=_get_estimator_info_tags(estimator),
@@ -1340,7 +1391,7 @@ def _log_model_with_except_handling(*args, **kwargs):
# Fetch environment-specific tags (e.g., user and source) to ensure that lineage
# information is consistent with the parent run
child_tags = context_registry.resolve_tags()
child_tags.update({MLFLOW_AUTOLOGGING: FLAVOR_NAME})
child_tags.update({MLFLOW_AUTOLOGGING: flavor_name})
_create_child_runs_for_parameter_search(
autologging_client=autologging_client,
cv_estimator=estimator,
@@ -1547,11 +1598,16 @@ def out(*args, **kwargs):

_apply_sklearn_descriptor_unbound_method_call_fix()

for class_def in _gen_estimators_to_patch():
if flavor_name == mlflow.xgboost.FLAVOR_NAME:
estimators_to_patch = _gen_xgboost_sklearn_estimators_to_patch()
else:
estimators_to_patch = _gen_estimators_to_patch()

for class_def in estimators_to_patch:
# Patch fitting methods
for func_name in ["fit", "fit_transform", "fit_predict"]:
_patch_estimator_method_if_available(
FLAVOR_NAME,
flavor_name,
class_def,
func_name,
patched_fit,
@@ -1561,7 +1617,7 @@ def out(*args, **kwargs):
# Patch inference methods
for func_name in ["predict", "predict_proba", "transform", "predict_log_proba"]:
_patch_estimator_method_if_available(
FLAVOR_NAME,
flavor_name,
class_def,
func_name,
patched_predict,
@@ -1570,7 +1626,7 @@

# Patch scoring methods
_patch_estimator_method_if_available(
FLAVOR_NAME,
flavor_name,
class_def,
"score",
patched_model_score,
@@ -1580,19 +1636,19 @@
if log_post_training_metrics:
for metric_name in _get_metric_name_list():
safe_patch(
FLAVOR_NAME, sklearn.metrics, metric_name, patched_metric_api, manage_run=False
flavor_name, sklearn.metrics, metric_name, patched_metric_api, manage_run=False
)

for scorer in sklearn.metrics.SCORERS.values():
safe_patch(FLAVOR_NAME, scorer, "_score_func", patched_metric_api, manage_run=False)
safe_patch(flavor_name, scorer, "_score_func", patched_metric_api, manage_run=False)

def patched_fn_with_autolog_disabled(original, *args, **kwargs):
with disable_autologging():
return original(*args, **kwargs)

for disable_autolog_func_name in _apis_autologging_disabled:
safe_patch(
FLAVOR_NAME,
flavor_name,
sklearn.model_selection,
disable_autolog_func_name,
patched_fn_with_autolog_disabled,
32 changes: 30 additions & 2 deletions mlflow/xgboost/__init__.py
@@ -413,6 +413,15 @@ def __init__(original, self, *args, **kwargs):
original(self, *args, **kwargs)

def train(original, *args, **kwargs):
def _get_xgb_caller_info():
import inspect

xgb_caller = inspect.stack()[4]
is_caller_fit = xgb_caller[3] == "fit"
return xgb_caller[0].f_locals["self"] if is_caller_fit else None

xgb_caller = _get_xgb_caller_info()

def record_eval_results(eval_results, metrics_logger):
"""
Create a callback function that records evaluation results.
@@ -426,7 +435,7 @@ def record_eval_results(eval_results, metrics_logger):

# In xgboost >= 1.3.0, user-defined callbacks should inherit
# `xgboost.callback.TrainingCallback`:
# https://xgboost.readthedocs.io/en/latest/python/callbacks.html#defining-your-own-callback # noqa
# https://xgboost.readthedocs.io/en/latest/python/callbacks.html#defining-your-own-callback
return AutologCallback(metrics_logger, eval_results)
else:
from mlflow.xgboost._autolog import autolog_callback
@@ -579,6 +588,8 @@ def log_feature_importance_plot(features, importance, importance_type):

# training model
model = original(*args, **kwargs)
if xgb_caller:
xgb_caller._Booster = model

# If early_stopping_rounds is present, logging metrics at the best iteration
# as extra metrics with the max step + 1.
@@ -656,7 +667,7 @@ def infer_model_signature(input_example):
)

log_model(
model,
xgb_caller if xgb_caller else model,
artifact_path="model",
signature=signature,
input_example=input_example,
@@ -669,4 +680,21 @@ def infer_model_signature(input_example):
return model

safe_patch(FLAVOR_NAME, xgboost, "train", train, manage_run=True)
safe_patch(FLAVOR_NAME, xgboost.sklearn, "train", train, manage_run=True)
safe_patch(FLAVOR_NAME, xgboost.DMatrix, "__init__", __init__)

# enable xgboost scikit-learn estimators autologging
import mlflow.sklearn

mlflow.sklearn._autolog(
flavor_name=FLAVOR_NAME,
log_input_examples=log_input_examples,
log_model_signatures=log_model_signatures,
log_models=log_models,
disable=disable,
exclusive=exclusive,
disable_for_unsupported_versions=disable_for_unsupported_versions,
silent=silent,
max_tuning_runs=None,
log_post_training_metrics=True,
)
18 changes: 15 additions & 3 deletions tests/autologging/test_autologging_behaviors_integration.py
@@ -90,14 +90,26 @@ def test_autologging_integrations_use_safe_patch_for_monkey_patching(integration
) as gorilla_mock, mock.patch(
integration.__name__ + ".safe_patch", wraps=safe_patch
) as safe_patch_mock:
integration.autolog(disable=False)
assert safe_patch_mock.call_count > 0
if integration.__name__ == "mlflow.xgboost":
with mock.patch(
"mlflow.sklearn.safe_patch", wraps=safe_patch
) as xgb_sklearn_safe_patch_mock:
integration.autolog(disable=False)

safe_patch_call_count = (
safe_patch_mock.call_count + xgb_sklearn_safe_patch_mock.call_count
)
else:
Collaborator:
On the subject of test coverage, can we add a test case to https://github.com/mlflow/mlflow/blob/master/tests/xgboost/test_xgboost_autolog.py ensuring that autologging works as expected for XGBoost scikit-learn models? Feel free to use code from your excellent example above.

integration.autolog(disable=False)
safe_patch_call_count = safe_patch_mock.call_count

assert safe_patch_call_count > 0
# `safe_patch` leverages `gorilla.apply` in its implementation. Accordingly, we expect
# that the total number of `gorilla.apply` calls to be equivalent to the number of
# `safe_patch` calls. This verifies that autologging integrations are leveraging
# `safe_patch`, rather than calling `gorilla.apply` directly (which does not provide
# exception safety properties)
assert safe_patch_mock.call_count == gorilla_mock.call_count
assert safe_patch_call_count == gorilla_mock.call_count
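A test along the lines requested in the review comment above could look like the following sketch for tests/xgboost/test_xgboost_autolog.py (names and assertions are assumptions, not the test that was merged):

```python
import mlflow
import xgboost as xgb
from sklearn.datasets import load_wine
from sklearn.model_selection import train_test_split


def test_xgb_sklearn_autolog_logs_params_and_model():
    mlflow.xgboost.autolog()

    X, y = load_wine(return_X_y=True)
    X_train, _, y_train, _ = train_test_split(X, y, random_state=0)

    with mlflow.start_run() as run:
        regressor = xgb.XGBRegressor(n_estimators=10, max_depth=3)
        regressor.fit(X_train, y_train)

    client = mlflow.tracking.MlflowClient()
    run_data = client.get_run(run.info.run_id).data
    # Some training parameters should have been autologged.
    assert len(run_data.params) > 0
    # The fitted estimator should have been logged under the "model" artifact path.
    artifact_paths = [f.path for f in client.list_artifacts(run.info.run_id, "model")]
    assert "model/MLmodel" in artifact_paths
```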


def test_autolog_respects_exclusive_flag(setup_sklearn_model):