# Commit

Autologging functionality for scikit-learn integration with LightGBM (Part 2) (#5200)

* init commit, to-do: examples
* add examples, update doc
* re-start example test
* update
* check sagemaker
* [resolve conflict] update

Signed-off-by: Junwen Yao <jwyiao@gmail.com>
Showing 16 changed files with 228 additions and 46 deletions.
**LightGBM examples README (rewritten):**

````md
# Examples for LightGBM Autologging

LightGBM autologging functionality is demonstrated through two examples. The first example, in the `lightgbm_native` folder, logs a Booster model trained by `lightgbm.train()`. The second example, in the `lightgbm_sklearn` folder, shows how autologging works for LightGBM scikit-learn models. Autologging for all LightGBM models is enabled via `mlflow.lightgbm.autolog()`.
````
File renamed without changes.
**`lightgbm_native` example README (new file):**

````md
# LightGBM Example

This example trains a LightGBM classifier with the iris dataset and logs hyperparameters, metrics, and trained model.

## Running the code

```
python train.py --colsample-bytree 0.8 --subsample 0.9
```

You can try experimenting with different parameter values like:

```
python train.py --learning-rate 0.4 --colsample-bytree 0.7 --subsample 0.8
```

Then you can open the MLflow UI to track the experiments and compare your runs via:

```
mlflow ui
```

## Running the code as a project

```
mlflow run . -P learning_rate=0.2 -P colsample_bytree=0.8 -P subsample=0.9
```
````
File renamed without changes.
File renamed without changes.
**MLproject file for the scikit-learn example (new file):**

```yaml
name: lightgbm-sklearn-example
conda_env: conda.yaml
entry_points:
  main:
    command: python train.py
```
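Unlike the native example, which is run as a project with `-P` overrides, this entry point takes no parameters. For reference, an MLproject entry point can also declare parameters that are substituted into the command; a hypothetical sketch (this example's `train.py` accepts no flags, so the parameter names below are illustrative only):

```yaml
name: lightgbm-sklearn-example
conda_env: conda.yaml
entry_points:
  main:
    parameters:
      n_estimators: {type: float, default: 20}
      reg_lambda: {type: float, default: 1.0}
    command: "python train.py --n-estimators {n_estimators} --reg-lambda {reg_lambda}"
```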
**`lightgbm_sklearn` example README (new file):**

````md
# LightGBM Scikit-learn Model Example

This example trains a [`lightgbm.LGBMClassifier`](https://lightgbm.readthedocs.io/en/latest/pythonapi/lightgbm.LGBMClassifier.html) with the iris dataset and logs hyperparameters, metrics, and trained model.

Like the other LightGBM example, we enable autologging for LightGBM scikit-learn models via `mlflow.lightgbm.autolog()`. Saving and loading models also works for LightGBM scikit-learn models.

You can run this example using the following command:

```
python train.py
```
````
**`conda.yaml` for the scikit-learn example (new file):**

```yaml
name: lightgbm-example
channels:
  - conda-forge
dependencies:
  - python=3.6
  - pip
  - pip:
      - mlflow>=1.6.0
      - matplotlib
      - lightgbm
      - cloudpickle>=2.0.0
```
**`train.py` for the scikit-learn example (new file):**

```python
from pprint import pprint

import lightgbm as lgb
from sklearn.datasets import load_iris
from sklearn.metrics import f1_score
from sklearn.model_selection import train_test_split

import mlflow
import mlflow.lightgbm

from utils import fetch_logged_data


def main():
    # prepare example dataset
    X, y = load_iris(return_X_y=True, as_frame=True)
    X_train, X_test, y_train, y_test = train_test_split(X, y)

    # enable autologging; this includes lightgbm.sklearn estimators
    mlflow.lightgbm.autolog()

    with mlflow.start_run() as run:
        # note: the original diff names this variable `regressor`,
        # although the model is a classifier
        classifier = lgb.LGBMClassifier(n_estimators=20, reg_lambda=1.0)
        classifier.fit(X_train, y_train, eval_set=[(X_test, y_test)])
        y_pred = classifier.predict(X_test)
        f1 = f1_score(y_test, y_pred, average="micro")
        print("Logged data and model in run {} (micro F1: {})".format(run.info.run_id, f1))

    # show logged data
    for key, data in fetch_logged_data(run.info.run_id).items():
        print("\n---------- logged {} ----------".format(key))
        pprint(data)


if __name__ == "__main__":
    main()
```
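The script above scores the model with micro-averaged F1. Micro averaging pools true positives, false positives, and false negatives across all classes before computing precision and recall, and for single-label classification this makes micro F1 equal to plain accuracy. A stdlib-only sketch of that identity, using made-up labels rather than output from the example run:

```python
def micro_f1(y_true, y_pred):
    # In single-label classification every wrong prediction is both a
    # false positive (for the predicted class) and a false negative
    # (for the true class), so the pooled FP and FN counts coincide.
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == p)
    fp = fn = sum(1 for t, p in zip(y_true, y_pred) if t != p)
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    return 2 * precision * recall / (precision + recall)


# hypothetical labels, not from the example run
y_true = [0, 1, 2, 1, 0]
y_pred = [0, 2, 2, 1, 0]
print(round(micro_f1(y_true, y_pred), 6))  # 0.8 -- same as accuracy (4 of 5 correct)
```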
**`utils.py` (new file):**

```python
import mlflow


def yield_artifacts(run_id, path=None):
    """Yield all artifacts in the specified run."""
    client = mlflow.tracking.MlflowClient()
    for item in client.list_artifacts(run_id, path):
        if item.is_dir:
            yield from yield_artifacts(run_id, item.path)
        else:
            yield item.path


def fetch_logged_data(run_id):
    """Fetch params, metrics, tags, and artifacts in the specified run."""
    client = mlflow.tracking.MlflowClient()
    data = client.get_run(run_id).data
    # Exclude system tags: https://www.mlflow.org/docs/latest/tracking.html#system-tags
    tags = {k: v for k, v in data.tags.items() if not k.startswith("mlflow.")}
    artifacts = list(yield_artifacts(run_id))
    return {
        "params": data.params,
        "metrics": data.metrics,
        "tags": tags,
        "artifacts": artifacts,
    }
```
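`fetch_logged_data` hides MLflow's system tags by filtering on the reserved `mlflow.` key prefix. That filtering step can be exercised on a plain dict, with hypothetical tag values standing in for `client.get_run(run_id).data.tags`:

```python
# Hypothetical tags, mimicking the shape of MlflowClient.get_run(run_id).data.tags
tags = {
    "mlflow.source.type": "LOCAL",       # system tag, dropped
    "mlflow.user": "someone",            # system tag, dropped
    "estimator_name": "LGBMClassifier",  # ordinary tag, kept
}

# the same comprehension used in fetch_logged_data
user_tags = {k: v for k, v in tags.items() if not k.startswith("mlflow.")}
print(user_tags)  # {'estimator_name': 'LGBMClassifier'}
```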