From ee9532aba858a140d26254a7799469137340fd9b Mon Sep 17 00:00:00 2001
From: Junwen Yao
Date: Fri, 14 Jan 2022 12:35:19 -0800
Subject: [PATCH] Autologging functionality for scikit-learn integration with
 LightGBM (Part 2) (#5200)

* init commit, to-do: examples

Signed-off-by: Junwen Yao

* add examples, update doc

Signed-off-by: Junwen Yao

* re-start example test

Signed-off-by: Junwen Yao

* update

Signed-off-by: Junwen Yao

* check sagemaker

Signed-off-by: Junwen Yao

* [resolve conflict] update

Signed-off-by: Junwen Yao

---
 examples/lightgbm/README.md                    | 25 +---------
 .../lightgbm/{ => lightgbm_native}/MLproject   |  0
 examples/lightgbm/lightgbm_native/README.md    | 25 +++++++++++
 .../lightgbm/{ => lightgbm_native}/conda.yaml  |  0
 .../lightgbm/{ => lightgbm_native}/train.py    |  0
 examples/lightgbm/lightgbm_sklearn/MLproject   |  5 +++
 examples/lightgbm/lightgbm_sklearn/README.md   | 11 +++++
 examples/lightgbm/lightgbm_sklearn/conda.yaml  | 11 +++++
 examples/lightgbm/lightgbm_sklearn/train.py    | 39 +++++++++++++++++
 examples/lightgbm/lightgbm_sklearn/utils.py    | 26 +++++++++++
 mlflow/lightgbm.py                             | 43 +++++++++++++++----
 mlflow/sklearn/__init__.py                     | 21 ++++++---
 mlflow/sklearn/utils.py                        | 20 ++++++++-
 .../test_autologging_behaviors_integration.py | 13 +++---
 tests/examples/test_examples.py                |  6 ++-
 tests/lightgbm/test_lightgbm_autolog.py        | 29 +++++++++++++
 16 files changed, 228 insertions(+), 46 deletions(-)
 rename examples/lightgbm/{ => lightgbm_native}/MLproject (100%)
 create mode 100644 examples/lightgbm/lightgbm_native/README.md
 rename examples/lightgbm/{ => lightgbm_native}/conda.yaml (100%)
 rename examples/lightgbm/{ => lightgbm_native}/train.py (100%)
 create mode 100644 examples/lightgbm/lightgbm_sklearn/MLproject
 create mode 100644 examples/lightgbm/lightgbm_sklearn/README.md
 create mode 100644 examples/lightgbm/lightgbm_sklearn/conda.yaml
 create mode 100644 examples/lightgbm/lightgbm_sklearn/train.py
 create mode 100644 examples/lightgbm/lightgbm_sklearn/utils.py

diff --git a/examples/lightgbm/README.md b/examples/lightgbm/README.md
index bacbf7a1c33e8..4a1c43e22dff5 100644
--- a/examples/lightgbm/README.md
+++ b/examples/lightgbm/README.md
@@ -1,25 +1,4 @@
-# LightGBM Example
+# Examples for LightGBM Autologging
 
-This example trains a LightGBM classifier with the iris dataset and logs hyperparameters, metrics, and trained model.
-
-## Running the code
-
-```
-python train.py --colsample-bytree 0.8 --subsample 0.9
-```
-You can try experimenting with different parameter values like:
-```
-python train.py --learning-rate 0.4 --colsample-bytree 0.7 --subsample 0.8
-```
-
-Then you can open the MLflow UI to track the experiments and compare your runs via:
-```
-mlflow ui
-```
-
-## Running the code as a project
-
-```
-mlflow run . -P learning_rate=0.2 -P colsample_bytree=0.8 -P subsample=0.9
-
-```
+LightGBM autologging is demonstrated through two examples. The first example, in the `lightgbm_native` folder, logs a Booster model trained by `lightgbm.train()`. The second example, in the `lightgbm_sklearn` folder, shows how autologging works for LightGBM scikit-learn models. Autologging for all LightGBM models is enabled via a single call to `mlflow.lightgbm.autolog()`.
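In practice, the setup the new README describes amounts to the following minimal sketch: a single `mlflow.lightgbm.autolog()` call covers both training APIs. The dataset and hyperparameter values below are illustrative, not taken from the patch.

```python
import lightgbm as lgb
import mlflow
import mlflow.lightgbm
from sklearn.datasets import load_iris

X, y = load_iris(return_X_y=True)

# One call enables autologging for both training APIs.
mlflow.lightgbm.autolog()

# Native API: hyperparameters, metrics, and the Booster model are logged automatically.
with mlflow.start_run():
    train_set = lgb.Dataset(X, label=y)
    booster = lgb.train({"objective": "multiclass", "num_class": 3}, train_set)

# Scikit-learn API: the fitted estimator is logged through the sklearn integration
# added by this patch.
with mlflow.start_run():
    clf = lgb.LGBMClassifier(n_estimators=10).fit(X, y)
```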
diff --git a/examples/lightgbm/MLproject b/examples/lightgbm/lightgbm_native/MLproject
similarity index 100%
rename from examples/lightgbm/MLproject
rename to examples/lightgbm/lightgbm_native/MLproject
diff --git a/examples/lightgbm/lightgbm_native/README.md b/examples/lightgbm/lightgbm_native/README.md
new file mode 100644
index 0000000000000..bacbf7a1c33e8
--- /dev/null
+++ b/examples/lightgbm/lightgbm_native/README.md
@@ -0,0 +1,25 @@
+# LightGBM Example
+
+This example trains a LightGBM classifier with the iris dataset and logs hyperparameters, metrics, and trained model.
+
+## Running the code
+
+```
+python train.py --colsample-bytree 0.8 --subsample 0.9
+```
+You can try experimenting with different parameter values like:
+```
+python train.py --learning-rate 0.4 --colsample-bytree 0.7 --subsample 0.8
+```
+
+Then you can open the MLflow UI to track the experiments and compare your runs via:
+```
+mlflow ui
+```
+
+## Running the code as a project
+
+```
+mlflow run . -P learning_rate=0.2 -P colsample_bytree=0.8 -P subsample=0.9
+
+```
diff --git a/examples/lightgbm/conda.yaml b/examples/lightgbm/lightgbm_native/conda.yaml
similarity index 100%
rename from examples/lightgbm/conda.yaml
rename to examples/lightgbm/lightgbm_native/conda.yaml
diff --git a/examples/lightgbm/train.py b/examples/lightgbm/lightgbm_native/train.py
similarity index 100%
rename from examples/lightgbm/train.py
rename to examples/lightgbm/lightgbm_native/train.py
diff --git a/examples/lightgbm/lightgbm_sklearn/MLproject b/examples/lightgbm/lightgbm_sklearn/MLproject
new file mode 100644
index 0000000000000..77d8707bf4e94
--- /dev/null
+++ b/examples/lightgbm/lightgbm_sklearn/MLproject
@@ -0,0 +1,5 @@
+name: lightgbm-sklearn-example
+conda_env: conda.yaml
+entry_points:
+  main:
+    command: python train.py
diff --git a/examples/lightgbm/lightgbm_sklearn/README.md b/examples/lightgbm/lightgbm_sklearn/README.md
new file mode 100644
index 0000000000000..220b3a1458a33
--- /dev/null
+++ b/examples/lightgbm/lightgbm_sklearn/README.md
@@ -0,0 +1,11 @@
+# LightGBM Scikit-learn Model Example
+
+This example trains a [`lightgbm.LGBMClassifier`](https://lightgbm.readthedocs.io/en/latest/pythonapi/lightgbm.LGBMClassifier.html) with the iris dataset and logs hyperparameters, metrics, and the trained model.
+
+Like the other LightGBM example, we enable autologging for LightGBM scikit-learn models via `mlflow.lightgbm.autolog()`. Saving and loading models also works for LightGBM scikit-learn models.
+
+You can run this example using the following command:
+
+```
+python train.py
+```
diff --git a/examples/lightgbm/lightgbm_sklearn/conda.yaml b/examples/lightgbm/lightgbm_sklearn/conda.yaml
new file mode 100644
index 0000000000000..d79fa96855c83
--- /dev/null
+++ b/examples/lightgbm/lightgbm_sklearn/conda.yaml
@@ -0,0 +1,11 @@
+name: lightgbm-example
+channels:
+  - conda-forge
+dependencies:
+  - python=3.6
+  - pip
+  - pip:
+    - mlflow>=1.6.0
+    - matplotlib
+    - lightgbm
+    - cloudpickle>=2.0.0
diff --git a/examples/lightgbm/lightgbm_sklearn/train.py b/examples/lightgbm/lightgbm_sklearn/train.py
new file mode 100644
index 0000000000000..e301aaa8f8120
--- /dev/null
+++ b/examples/lightgbm/lightgbm_sklearn/train.py
@@ -0,0 +1,39 @@
+from pprint import pprint
+
+import lightgbm as lgb
+from sklearn.datasets import load_iris
+from sklearn.model_selection import train_test_split
+from sklearn.metrics import f1_score
+
+import mlflow
+import mlflow.lightgbm
+
+from utils import fetch_logged_data
+
+
+def main():
+    # prepare example dataset
+    X, y = load_iris(return_X_y=True, as_frame=True)
+    X_train, X_test, y_train, y_test = train_test_split(X, y)
+
+    # enable autologging
+    # this includes lightgbm.sklearn estimators
+    mlflow.lightgbm.autolog()
+
+    with mlflow.start_run() as run:
+
+        classifier = lgb.LGBMClassifier(n_estimators=20, reg_lambda=1.0)
+        classifier.fit(X_train, y_train, eval_set=[(X_test, y_test)])
+        y_pred = classifier.predict(X_test)
+        f1 = f1_score(y_test, y_pred, average="micro")
+        run_id = run.info.run_id
+        print("Logged data and model in run {}".format(run_id))
+
+    # show logged data
+    for key, data in fetch_logged_data(run.info.run_id).items():
+        print("\n---------- logged {} ----------".format(key))
+        pprint(data)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/examples/lightgbm/lightgbm_sklearn/utils.py b/examples/lightgbm/lightgbm_sklearn/utils.py
new file mode 100644
index 0000000000000..00270bb395935
--- /dev/null
+++ b/examples/lightgbm/lightgbm_sklearn/utils.py
@@ -0,0 +1,26 @@
+import mlflow
+
+
+def yield_artifacts(run_id, path=None):
+    """Yield all artifacts in the specified run"""
+    client = mlflow.tracking.MlflowClient()
+    for item in client.list_artifacts(run_id, path):
+        if item.is_dir:
+            yield from yield_artifacts(run_id, item.path)
+        else:
+            yield item.path
+
+
+def fetch_logged_data(run_id):
+    """Fetch params, metrics, tags, and artifacts in the specified run"""
+    client = mlflow.tracking.MlflowClient()
+    data = client.get_run(run_id).data
+    # Exclude system tags: https://www.mlflow.org/docs/latest/tracking.html#system-tags
+    tags = {k: v for k, v in data.tags.items() if not k.startswith("mlflow.")}
+    artifacts = list(yield_artifacts(run_id))
+    return {
+        "params": data.params,
+        "metrics": data.metrics,
+        "tags": tags,
+        "artifacts": artifacts,
+    }
diff --git a/mlflow/lightgbm.py b/mlflow/lightgbm.py
index 739b263a062de..f43c63fa4d4b1 100644
--- a/mlflow/lightgbm.py
+++ b/mlflow/lightgbm.py
@@ -103,8 +103,8 @@ def save_model(
     """
     Save a LightGBM model to a path on the local file system.
 
-    :param lgb_model: LightGBM model (an instance of `lightgbm.Booster`_) to be saved.
-                      Note that models that implement the `scikit-learn API`_ are not supported.
+    :param lgb_model: LightGBM model (an instance of `lightgbm.Booster`_) or
+                      models that implement the `scikit-learn API`_ to be saved.
     :param path: Local path where the model is to be saved.
     :param conda_env: {{ conda_env }}
     :param mlflow_model: :py:mod:`mlflow.models.Model` this flavor is being added to.
@@ -231,8 +231,8 @@ def log_model(
     """
     Log a LightGBM model as an MLflow artifact for the current run.
 
-    :param lgb_model: LightGBM model (an instance of `lightgbm.Booster`_) to be saved.
-                      Note that models that implement the `scikit-learn API`_ are not supported.
+    :param lgb_model: LightGBM model (an instance of `lightgbm.Booster`_) or
+                      models that implement the `scikit-learn API`_ to be saved.
     :param artifact_path: Run-relative artifact path.
     :param conda_env: {{ conda_env }}
    :param registered_model_name: If given, create a model version under
@@ -382,7 +382,7 @@ def autolog(
       - an example of valid input.
       - inferred signature of the inputs and outputs of the model.
 
-    Note that the `scikit-learn API`_ is not supported.
+    Note that the `scikit-learn API`_ is now supported.
 
    :param log_input_examples: If ``True``, input examples from training datasets are collected
                               and logged along with LightGBM model artifacts during training. If
@@ -439,7 +439,7 @@ def __init__(original, self, *args, **kwargs):
 
         original(self, *args, **kwargs)
 
-    def train(original, *args, **kwargs):
+    def train(_log_models, original, *args, **kwargs):
         def record_eval_results(eval_results, metrics_logger):
             """
             Create a callback function that records evaluation results.
@@ -602,7 +602,7 @@ def infer_model_signature(input_example):
             return model_signature
 
         # Whether to automatically log the trained model based on boolean flag.
-        if log_models:
+        if _log_models:
             # Will only resolve `input_example` and `signature` if `log_models` is `True`.
             input_example, signature = resolve_input_example_and_signature(
                 get_input_example,
@@ -625,5 +625,32 @@ def infer_model_signature(input_example):
 
         return model
 
-    safe_patch(FLAVOR_NAME, lightgbm, "train", train, manage_run=True)
     safe_patch(FLAVOR_NAME, lightgbm.Dataset, "__init__", __init__)
+    safe_patch(
+        FLAVOR_NAME, lightgbm, "train", functools.partial(train, log_models), manage_run=True
+    )
+    # The `train()` method logs LightGBM models as Booster objects. When using LightGBM
+    # scikit-learn models, we want to save / log models as their model classes. So we turn
+    # off the log_models functionality in the `train()` method patched to `lightgbm.sklearn`.
+    # Instead the model logging is handled in `fit_mlflow_xgboost_and_lightgbm()`
+    # in `mlflow.sklearn._autolog()`, where models are logged as LightGBM scikit-learn models
+    # after the `fit()` method returns.
+    safe_patch(
+        FLAVOR_NAME, lightgbm.sklearn, "train", functools.partial(train, False), manage_run=True
+    )
+
+    # enable LightGBM scikit-learn estimators autologging
+    import mlflow.sklearn
+
+    mlflow.sklearn._autolog(
+        flavor_name=FLAVOR_NAME,
+        log_input_examples=log_input_examples,
+        log_model_signatures=log_model_signatures,
+        log_models=log_models,
+        disable=disable,
+        exclusive=exclusive,
+        disable_for_unsupported_versions=disable_for_unsupported_versions,
+        silent=silent,
+        max_tuning_runs=None,
+        log_post_training_metrics=True,
+    )
diff --git a/mlflow/sklearn/__init__.py b/mlflow/sklearn/__init__.py
index 955555f1ec049..6d881f545a029 100644
--- a/mlflow/sklearn/__init__.py
+++ b/mlflow/sklearn/__init__.py
@@ -1200,6 +1200,7 @@ def _autolog(
         _is_supported_version,
         _get_X_y_and_sample_weight,
         _gen_xgboost_sklearn_estimators_to_patch,
+        _gen_lightgbm_sklearn_estimators_to_patch,
         _log_estimator_content,
         _all_estimators,
         _get_estimator_info_tags,
@@ -1227,12 +1228,12 @@ def _autolog(
                 stacklevel=2,
             )
 
-    def fit_mlflow_xgboost(original, self, *args, **kwargs):
+    def fit_mlflow_xgboost_and_lightgbm(original, self, *args, **kwargs):
         """
-        Autologging function for XGBoost scikit-learn models
+        Autologging function for XGBoost and LightGBM scikit-learn models
         """
-        # parameter, metric, and non-model artifact logging
-        # are done in `train()` in `mlflow.xgboost.autolog()`
+        # parameter, metric, and non-model artifact logging are done in
+        # `train()` in `mlflow.xgboost.autolog()` and `mlflow.lightgbm.autolog()`
         fit_output = original(self, *args, **kwargs)
         # log models after training
         X = _get_X_y_and_sample_weight(self.fit, args, kwargs)[0]
@@ -1244,7 +1245,12 @@ def fit_mlflow_xgboost(original, self, *args, **kwargs):
             log_model_signatures,
             _logger,
         )
-        mlflow.xgboost.log_model(
+        log_model_func = (
+            mlflow.xgboost.log_model
+            if flavor_name == mlflow.xgboost.FLAVOR_NAME
+            else mlflow.lightgbm.log_model
+        )
+        log_model_func(
             self,
             artifact_path="model",
             signature=signature,
@@ -1611,7 +1617,10 @@ def out(*args, **kwargs):
 
     if flavor_name == mlflow.xgboost.FLAVOR_NAME:
         estimators_to_patch = _gen_xgboost_sklearn_estimators_to_patch()
-        patched_fit_impl = fit_mlflow_xgboost
+        patched_fit_impl = fit_mlflow_xgboost_and_lightgbm
+    elif flavor_name == mlflow.lightgbm.FLAVOR_NAME:
+        estimators_to_patch = _gen_lightgbm_sklearn_estimators_to_patch()
+        patched_fit_impl = fit_mlflow_xgboost_and_lightgbm
     else:
         estimators_to_patch = _gen_estimators_to_patch()
         patched_fit_impl = fit_mlflow
diff --git a/mlflow/sklearn/utils.py b/mlflow/sklearn/utils.py
index e77fb2755b284..1994cdc6d2ad6 100644
--- a/mlflow/sklearn/utils.py
+++ b/mlflow/sklearn/utils.py
@@ -47,6 +47,25 @@ def _gen_xgboost_sklearn_estimators_to_patch():
     return sklearn_estimators
 
 
+def _gen_lightgbm_sklearn_estimators_to_patch():
+    import mlflow.lightgbm
+    import lightgbm as lgb
+
+    all_classes = inspect.getmembers(lgb.sklearn, inspect.isclass)
+    base_class = lgb.sklearn._LGBMModelBase
+    sklearn_estimators = []
+    for _, class_object in all_classes:
+        package_name = class_object.__module__.split(".")[0]
+        if (
+            package_name == mlflow.lightgbm.FLAVOR_NAME
+            and issubclass(class_object, base_class)
+            and class_object != base_class
+        ):
+            sklearn_estimators.append(class_object)
+
+    return sklearn_estimators
+
+
 def _get_estimator_info_tags(estimator):
     """
     :return: A dictionary of MLflow run tag keys and values
@@ -102,7 +121,6 @@ def _get_sample_weight(arg_names, args, kwargs):
         return None
 
     fit_arg_names = _get_arg_names(fit_func)
-
     # In most cases, X_var_name and y_var_name become "X" and "y", respectively.
     # However, certain sklearn models use different variable names for X and y.
     # E.g., see: https://scikit-learn.org/stable/modules/generated/sklearn.multioutput.MultiOutputClassifier.html#sklearn.multioutput.MultiOutputClassifier.fit
diff --git a/tests/autologging/test_autologging_behaviors_integration.py b/tests/autologging/test_autologging_behaviors_integration.py
index f474cf18def02..448ae1b971dbb 100644
--- a/tests/autologging/test_autologging_behaviors_integration.py
+++ b/tests/autologging/test_autologging_behaviors_integration.py
@@ -90,17 +90,18 @@ def test_autologging_integrations_use_safe_patch_for_monkey_patching(integration
     ) as gorilla_mock, mock.patch(
         integration.__name__ + ".safe_patch", wraps=safe_patch
     ) as safe_patch_mock:
-        # In `mlflow.xgboost.autolog()`, we enable autologging for XGBoost sklearn
-        # models using `mlflow.sklearn._autolog()`. So besides `safe_patch` calls in
-        # `mlflow.xgboost.autolog()`, we need to count additional `safe_patch` calls
+        # In `mlflow.xgboost.autolog()` and `mlflow.lightgbm.autolog()`,
+        # we enable autologging for XGBoost and LightGBM sklearn models
+        # using `mlflow.sklearn._autolog()`. So besides `safe_patch` calls in
+        # `autolog()`, we need to count additional `safe_patch` calls
         # in sklearn autologging routine as well.
-        if integration.__name__ == "mlflow.xgboost":
+        if integration.__name__ in ["mlflow.xgboost", "mlflow.lightgbm"]:
             with mock.patch(
                 "mlflow.sklearn.safe_patch", wraps=safe_patch
-            ) as xgb_sklearn_safe_patch_mock:
+            ) as sklearn_safe_patch_mock:
                 integration.autolog(disable=False)
                 safe_patch_call_count = (
-                    safe_patch_mock.call_count + xgb_sklearn_safe_patch_mock.call_count
+                    safe_patch_mock.call_count + sklearn_safe_patch_mock.call_count
                 )
         else:
             integration.autolog(disable=False)
diff --git a/tests/examples/test_examples.py b/tests/examples/test_examples.py
index 3b8bef3a4fa78..423e482f85d57 100644
--- a/tests/examples/test_examples.py
+++ b/tests/examples/test_examples.py
@@ -82,9 +82,10 @@ def report_free_disk_space(capsys):
     ("hyperparam", ["-e", "gpyopt", "-P", "epochs=1"]),
     ("hyperparam", ["-e", "hyperopt", "-P", "epochs=1"]),
     (
-        "lightgbm",
+        os.path.join("lightgbm", "lightgbm_native"),
         ["-P", "learning_rate=0.1", "-P", "colsample_bytree=0.8", "-P", "subsample=0.9"],
     ),
+    (os.path.join("lightgbm", "lightgbm_sklearn"), []),
     ("statsmodels", ["-P", "inverse_method=qr"]),
     ("pytorch", ["-P", "epochs=2"]),
     ("sklearn_logistic_regression", []),
@@ -140,7 +141,7 @@ def test_mlflow_run_example(directory, params, tmpdir):
     ("gluon", ["python", "train.py"]),
     ("keras", ["python", "train.py"]),
     (
-        "lightgbm",
+        os.path.join("lightgbm", "lightgbm_native"),
         [
             "python",
             "train.py",
             "--learning-rate",
             "0.2",
             "--colsample-bytree",
             "0.8",
             "--subsample",
             "0.9",
         ],
     ),
+    (os.path.join("lightgbm", "lightgbm_sklearn"), ["python", "train.py"]),
     ("statsmodels", ["python", "train.py", "--inverse-method", "qr"]),
     ("quickstart", ["python", "mlflow_tracking.py"]),
     ("remote_store", ["python", "remote_server.py"]),
diff --git a/tests/lightgbm/test_lightgbm_autolog.py b/tests/lightgbm/test_lightgbm_autolog.py
index 013df418e6605..cdb8248401c2f 100644
--- a/tests/lightgbm/test_lightgbm_autolog.py
+++ b/tests/lightgbm/test_lightgbm_autolog.py
@@ -157,6 +157,35 @@ def test_lgb_autolog_logs_specified_params(bst_params, train_set):
         assert param not in params
 
 
+@pytest.mark.large
+def test_lgb_autolog_sklearn():
+
+    mlflow.lightgbm.autolog()
+
+    X, y = datasets.load_iris(return_X_y=True)
+    params = {"n_estimators": 10, "reg_lambda": 1}
+    model = lgb.LGBMClassifier(**params)
+
+    with mlflow.start_run() as run:
+        model.fit(X, y)
+        model_uri = mlflow.get_artifact_uri("model")
+
+    client = mlflow.tracking.MlflowClient()
+    run = client.get_run(run.info.run_id)
+    assert run.data.metrics.items() <= params.items()
+    artifacts = set(x.path for x in client.list_artifacts(run.info.run_id))
+    assert artifacts >= set(
+        [
+            "feature_importance_gain.png",
+            "feature_importance_gain.json",
+            "feature_importance_split.png",
+            "feature_importance_split.json",
+        ]
+    )
+    loaded_model = mlflow.lightgbm.load_model(model_uri)
+    np.testing.assert_allclose(loaded_model.predict(X), model.predict(X))
+
+
 @pytest.mark.large
 def test_lgb_autolog_logs_metrics_with_validation_data(bst_params, train_set):
     mlflow.lightgbm.autolog()
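A note on the `functools.partial(train, log_models)` pattern introduced in `mlflow/lightgbm.py` above: `safe_patch` supplies the original, unpatched function as the patch function's first argument, so the extra `_log_models` flag has to be bound in advance. The sketch below reproduces that pattern in isolation; the toy `original_train` function and the print statement are illustrative stand-ins, not MLflow's actual patching machinery.

```python
import functools


def train(_log_models, original, *args, **kwargs):
    # `_log_models` is bound ahead of time via functools.partial;
    # `original` (the unpatched function) is supplied at call time.
    result = original(*args, **kwargs)
    if _log_models:
        print("model logging would happen here")
    return result


def original_train(x):
    return x * 2


# Bind the flag, leaving (original, *args, **kwargs) for the patch machinery:
patched_logging = functools.partial(train, True)   # like the lightgbm.train patch
patched_silent = functools.partial(train, False)   # like the lightgbm.sklearn.train patch

assert patched_logging(original_train, 3) == 6  # prints the logging message
assert patched_silent(original_train, 3) == 6   # logs nothing; handled later in fit()
```

This is why the patched `train()` gained a leading `_log_models` parameter in the diff: the same patch body serves both `lightgbm.train` (models logged as Boosters) and `lightgbm.sklearn`'s internal `train` (model logging deferred to the sklearn integration after `fit()` returns).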