Commit
Autologging functionality for scikit-learn integration with XGBoost (Part 2) (#5078)
* new commit, resolve conflict
  Signed-off-by: Junwen Yao <jwyiao@gmail.com>
* add example
  Signed-off-by: Junwen Yao <jwyiao@gmail.com>
* fix lint
  Signed-off-by: Junwen Yao <jwyiao@gmail.com>
* address review
  Signed-off-by: Junwen Yao <jwyiao@gmail.com>
* fix build_doc
  Signed-off-by: Junwen Yao <jwyiao@gmail.com>
* Update mlflow/sklearn/__init__.py remove additional lines
  Signed-off-by: Junwen Yao <jwyiao@gmail.com>
* remove extra lines
  Signed-off-by: Junwen Yao <jwyiao@gmail.com>
* address review > TODO:(1)doc(2)test
  Signed-off-by: Junwen Yao <jwyiao@gmail.com>
* address review + add tests > TODO:doc,README
  Signed-off-by: Junwen Yao <jwyiao@gmail.com>
* address review + complete doc
  Signed-off-by: Junwen Yao <jwyiao@gmail.com>
* fix lint
  Signed-off-by: Junwen Yao <jwyiao@gmail.com>
* update examples + fix example tests
  Signed-off-by: Junwen Yao <jwyiao@gmail.com>
* address review
  Signed-off-by: Junwen Yao <jwyiao@gmail.com>
* address review: update example
  Signed-off-by: Junwen Yao <jwyiao@gmail.com>
Showing 16 changed files with 292 additions and 51 deletions.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,25 +1,3 @@ | ||
# XGBoost Example | ||
# Examples for XGBoost Autologging | ||
|
||
This example trains an XGBoost classifier with the iris dataset and logs hyperparameters, metrics, and trained model. | ||
|
||
## Running the code | ||
|
||
``` | ||
python train.py --learning-rate 0.2 --colsample-bytree 0.8 --subsample 0.9 | ||
``` | ||
You can try experimenting with different parameter values like: | ||
``` | ||
python train.py --learning-rate 0.4 --colsample-bytree 0.7 --subsample 0.8 | ||
``` | ||
|
||
Then you can open the MLflow UI to track the experiments and compare your runs via: | ||
``` | ||
mlflow ui | ||
``` | ||
|
||
|
||
## Running the code as a project | ||
|
||
``` | ||
mlflow run . -P learning_rate=0.2 -P colsample_bytree=0.8 -P subsample=0.9 | ||
``` | ||
Two examples demonstrate XGBoost autologging. The `xgboost_native` folder contains an example that logs a Booster model trained with `xgboost.train()`. The `xgboost_sklearn` folder contains an example that shows autologging for XGBoost scikit-learn models. Autologging is enabled the same way in both cases: a single call to `mlflow.xgboost.autolog()` covers all XGBoost models. |
File renamed without changes.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
# XGBoost Example | ||
|
||
This example trains an XGBoost classifier with the iris dataset and logs hyperparameters, metrics, and trained model. | ||
|
||
## Running the code | ||
|
||
``` | ||
python train.py --learning-rate 0.2 --colsample-bytree 0.8 --subsample 0.9 | ||
``` | ||
You can try experimenting with different parameter values like: | ||
``` | ||
python train.py --learning-rate 0.4 --colsample-bytree 0.7 --subsample 0.8 | ||
``` | ||
|
||
Then you can open the MLflow UI to track the experiments and compare your runs via: | ||
``` | ||
mlflow ui | ||
``` | ||
|
||
|
||
## Running the code as a project | ||
|
||
``` | ||
mlflow run . -P learning_rate=0.2 -P colsample_bytree=0.8 -P subsample=0.9 | ||
``` |
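The flags in the commands above correspond to XGBoost hyperparameters of the same names. Below is a minimal sketch of how a script such as `train.py` might parse them, assuming `argparse`. The function names `parse_args` and `to_xgb_params` and the default values are hypothetical and are not taken from this commit.

```python
import argparse


def parse_args(argv=None):
    """Parse the three hyperparameter flags used in the commands above."""
    parser = argparse.ArgumentParser(description="XGBoost example (hypothetical CLI sketch)")
    # Defaults are illustrative assumptions, not values taken from the example.
    parser.add_argument("--learning-rate", type=float, default=0.3)
    parser.add_argument("--colsample-bytree", type=float, default=1.0)
    parser.add_argument("--subsample", type=float, default=1.0)
    return parser.parse_args(argv)


def to_xgb_params(args):
    """Map the parsed flags onto the XGBoost parameter names they correspond to."""
    return {
        "learning_rate": args.learning_rate,
        "colsample_bytree": args.colsample_bytree,
        "subsample": args.subsample,
    }


# Equivalent to: python train.py --learning-rate 0.2 --colsample-bytree 0.8 --subsample 0.9
args = parse_args(["--learning-rate", "0.2", "--colsample-bytree", "0.8", "--subsample", "0.9"])
params = to_xgb_params(args)
print(params)
```

Note that argparse converts the dashes in flag names to underscores, which is why `--learning-rate` is read back as `args.learning_rate` and lines up with the `-P learning_rate=0.2` form used by `mlflow run`.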
File renamed without changes.
File renamed without changes.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
name: xgboost_sklearn_example | ||
|
||
conda_env: conda.yaml | ||
|
||
entry_points: | ||
  main: | ||
    command: "python train.py" |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
# XGBoost Scikit-learn Model Example | ||
|
||
This example trains an [`xgboost.XGBRegressor`](https://xgboost.readthedocs.io/en/stable/python/python_api.html#xgboost.XGBRegressor) on the diabetes dataset and logs the hyperparameters, metrics, and trained model. | ||
|
||
As in the native XGBoost example, autologging is enabled with `mlflow.xgboost.autolog()`. Model saving and loading also support XGBoost scikit-learn models. | ||
|
||
You can run this example using the following command: | ||
``` | ||
python train.py | ||
``` | ||
|
||
|
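The README states that loading also works for XGBoost scikit-learn models. Below is a minimal sketch of reloading an autologged model from a finished run. It assumes the model was logged under the artifact path `model`, which may differ across MLflow versions. The helper names `model_uri_for` and `load_logged_regressor` are hypothetical.

```python
def model_uri_for(run_id, artifact_path="model"):
    """Build a runs:/ URI; "model" as the artifact path is an assumption."""
    return "runs:/{}/{}".format(run_id, artifact_path)


def load_logged_regressor(run_id):
    """Reload the model that autologging stored for the given run."""
    # Imported lazily so the URI helper above can be used without MLflow installed.
    import mlflow.xgboost

    return mlflow.xgboost.load_model(model_uri_for(run_id))
```

With a real run ID, `load_logged_regressor(run_id).predict(X_test)` should then behave like the original estimator, provided the run actually contains a model at that path.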
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
channels: | ||
- conda-forge | ||
dependencies: | ||
- python=3.8.12 | ||
- pip | ||
- pip: | ||
  - mlflow | ||
  - pandas==1.3.4 | ||
  - scikit-learn==0.24.2 | ||
  - xgboost==1.5.0 | ||
name: mlflow-env |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
from pprint import pprint | ||
|
||
import xgboost as xgb | ||
from sklearn.datasets import load_diabetes | ||
from sklearn.model_selection import train_test_split | ||
from sklearn.metrics import mean_squared_error | ||
|
||
import mlflow | ||
import mlflow.xgboost | ||
|
||
from utils import fetch_logged_data | ||
|
||
|
||
def main(): | ||
    # prepare example dataset | ||
    X, y = load_diabetes(return_X_y=True, as_frame=True) | ||
    X_train, X_test, y_train, y_test = train_test_split(X, y) | ||
|
||
    # enable auto logging | ||
    # this includes xgboost.sklearn estimators | ||
    mlflow.xgboost.autolog() | ||
|
||
    with mlflow.start_run() as run: | ||
|
||
        regressor = xgb.XGBRegressor(n_estimators=20, reg_lambda=1, gamma=0, max_depth=3) | ||
        regressor.fit(X_train, y_train, eval_set=[(X_test, y_test)]) | ||
        y_pred = regressor.predict(X_test) | ||
        mse = mean_squared_error(y_test, y_pred) | ||
        run_id = run.info.run_id | ||
        print("Logged data and model in run {}".format(run_id)) | ||
|
||
    # show logged data | ||
    for key, data in fetch_logged_data(run.info.run_id).items(): | ||
        print("\n---------- logged {} ----------".format(key)) | ||
        pprint(data) | ||
|
||
|
||
if __name__ == "__main__": | ||
    main() |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
import mlflow | ||
|
||
|
||
def yield_artifacts(run_id, path=None): | ||
    """Yield all artifacts in the specified run""" | ||
    client = mlflow.tracking.MlflowClient() | ||
    for item in client.list_artifacts(run_id, path): | ||
        if item.is_dir: | ||
            yield from yield_artifacts(run_id, item.path) | ||
        else: | ||
            yield item.path | ||
|
||
|
||
def fetch_logged_data(run_id): | ||
    """Fetch params, metrics, tags, and artifacts in the specified run""" | ||
    client = mlflow.tracking.MlflowClient() | ||
    data = client.get_run(run_id).data | ||
    # Exclude system tags: https://www.mlflow.org/docs/latest/tracking.html#system-tags | ||
    tags = {k: v for k, v in data.tags.items() if not k.startswith("mlflow.")} | ||
    artifacts = list(yield_artifacts(run_id)) | ||
    return { | ||
        "params": data.params, | ||
        "metrics": data.metrics, | ||
        "tags": tags, | ||
        "artifacts": artifacts, | ||
    } |
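The helpers above rely on two small techniques: a recursive generator that descends into artifact directories, and a prefix filter that drops system tags. The sketch below reproduces both against an in-memory stand-in so they can be tried without a tracking server. `FakeClient`, `FakeFileInfo`, `drop_system_tags`, and the artifact and tag names are all made up for illustration. Only the `list_artifacts(run_id, path)` call shape and the `path` / `is_dir` attributes mirror the code in this file.

```python
from collections import namedtuple

# Stand-in for the entries returned by list_artifacts: a path plus an is_dir flag.
FakeFileInfo = namedtuple("FakeFileInfo", ["path", "is_dir"])


class FakeClient:
    """Hypothetical in-memory client exposing only list_artifacts(run_id, path)."""

    def __init__(self, tree):
        # tree maps a directory path (None for the root) to its direct children
        self._tree = tree

    def list_artifacts(self, run_id, path=None):
        return self._tree.get(path, [])


def yield_artifacts(client, run_id, path=None):
    """Depth-first walk: yield file paths and recurse into directories."""
    for item in client.list_artifacts(run_id, path):
        if item.is_dir:
            yield from yield_artifacts(client, run_id, item.path)
        else:
            yield item.path


def drop_system_tags(tags):
    """Keep only tags whose key does not start with the "mlflow." prefix."""
    return {k: v for k, v in tags.items() if not k.startswith("mlflow.")}


# Illustrative artifact layout; the names are assumptions, not real logged output.
tree = {
    None: [FakeFileInfo("model", True), FakeFileInfo("feature_importance_weight.json", False)],
    "model": [FakeFileInfo("model/MLmodel", False), FakeFileInfo("model/conda.yaml", False)],
}
artifacts = list(yield_artifacts(FakeClient(tree), run_id="example-run"))
tags = drop_system_tags({"mlflow.user": "alice", "team": "forecasting"})
print(artifacts)
print(tags)
```

Passing the client in as an argument, rather than constructing it inside the generator as the file above does, is a design choice made here only to keep the sketch testable offline.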