-
Notifications
You must be signed in to change notification settings - Fork 4k
/
train.py
39 lines (28 loc) · 1.1 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
from pprint import pprint
import xgboost as xgb
from sklearn.datasets import load_diabetes
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error
import mlflow
import mlflow.xgboost
from utils import fetch_logged_data
def main():
    """Train an XGBoost regressor on the diabetes dataset with MLflow autologging.

    Splits scikit-learn's diabetes dataset into train/test partitions, enables
    ``mlflow.xgboost.autolog()``, fits an ``xgb.XGBRegressor`` inside an MLflow
    run (passing the test split as ``eval_set``), predicts on the test split,
    computes the mean squared error, and finally prints every entry returned
    by ``fetch_logged_data`` for that run.
    """
    # prepare example dataset; as_frame=True returns pandas objects
    X, y = load_diabetes(return_X_y=True, as_frame=True)
    # NOTE(review): no random_state is given, so the split (and therefore the
    # logged results) differ from one invocation to the next.
    X_train, X_test, y_train, y_test = train_test_split(X, y)

    # enable auto logging
    # this includes xgboost.sklearn estimators
    mlflow.xgboost.autolog()

    with mlflow.start_run() as run:
        regressor = xgb.XGBRegressor(n_estimators=20, reg_lambda=1, gamma=0, max_depth=3)
        # The test split is supplied as eval_set for evaluation during
        # training; no early stopping is configured here.
        regressor.fit(X_train, y_train, eval_set=[(X_test, y_test)])
        y_pred = regressor.predict(X_test)
        # The return value is intentionally discarded. The call is presumably
        # kept so that autologging records it as a post-training metric.
        # Confirm this against the MLflow autolog docs before removing it.
        mean_squared_error(y_test, y_pred)
        run_id = run.info.run_id
        print("Logged data and model in run {}".format(run_id))

    # show logged data (reuse the id captured above instead of re-reading it)
    for key, data in fetch_logged_data(run_id).items():
        print("\n---------- logged {} ----------".format(key))
        pprint(data)


if __name__ == "__main__":
    main()