New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Autologging functionality for scikit-learn integration with XGBoost (Part 2) #5055
Conversation
Signed-off-by: Junwen Yao <jwyiao@gmail.com>
Signed-off-by: Junwen Yao <jwyiao@gmail.com>
Signed-off-by: Junwen Yao <jwyiao@gmail.com>
Signed-off-by: Junwen Yao <jwyiao@gmail.com>
Signed-off-by: Junwen Yao <jwyiao@gmail.com>
@@ -365,7 +378,7 @@ def log_model( | |||
# log model | |||
mlflow.sklearn.log_model(sk_model, "sk_models") | |||
""" | |||
return Model.log( | |||
Model.log( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It seems Model.log()
doesn't return any value. Maybe we can remove return
.
|
||
|
||
def _autolog( | ||
flavor_name=FLAVOR_NAME, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Internal API for sklearn autologging. The flavor_name
field allows mlflow.xgboost
to specify the xgboost_sklearn
flavor, preventing flavor conflict with mlflow.sklearn
.
|
||
def _mlflow_xgboost_logging( | ||
importance_types, autologging_client, logger, original, sklearn_estimator, *args, **kwargs, | ||
): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Re-organize early stopping call backs and feature importance plot. This function is re-used in mlflow.sklearn
for logging XGBoost sklearn estimators.
|
||
safe_patch_call_count = ( | ||
safe_patch_mock.call_count + xgb_sklearn_safe_patch_mock.call_count | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since mlflow.sklearn._autolog()
is called inside mlflow.xgboost
, we need to count safe_patch
called due to enabling sklearn autologging.
Hi @harupy @dbczumar, I made a new PR to complete the autologging functionality for XGBoost sklearn estimators. It is based on our previous discussion #4885. I left a few comments in the PR to highlight the changes:
Please correct me if I missed anything. Also please let me know your feedback and suggestions! Thanks a lot! |
Regarding the tests, I was trying to integrate XGBoost sklearn estimators tests with the existing tests: change # current
expected_params = {"num_boost_round": 20, "early_stopping_rounds": 5, "verbose_eval": False}
xgb.train(bst_params, dtrain, evals=[(dtrain, "train")], **expected_params) to something like # new
def xgb_train(mode, bst_params, data, other_kwargs):
if mode == "xgboost_sklearn":
# return XGBoost sklearn model using bst_params and other_kwargs
else:
# mode == "xgboost"
# return xgb.train(...)
# insider a test function
xgb_train(mode, bst_params, data, other_kwargs) but the integration could be messy in this way. Not all parameters passed to xgboost.train(bst_params, dtrain, **kwargs) but xgb_sklearn_model = xgboost.XGBClassifier(**bst_params, **kwargs)
xgb_sklearn_model.fit(X, y) # X, y from dtrain generally is not error proof. Here is an example:
xgb_classifier = xgb.XGBClassifier(objective="multi:softprob", num_class=3, n_estimators=20)
xgb_classifier.fit(X, y, eval_metric=["merror", "mlogloss"], eval_set=[(X1,y1),(X2,y2)]) @harupy @dbczumar Should we keep doing the integration approach? Or is it a better idea to create new separate tests for XGBoost sklearn models? What are your opinions / suggestions? Thanks! |
Signed-off-by: Junwen Yao <jwyiao@gmail.com>
What changes are proposed in this pull request?
This is the second PR to add autologging for XGBoost sklearn models using
mlflow.sklearn
autologging routine.(Previous PR: #4954)
(Draft + discussion: #4885)
How is this patch tested?
A new example is provided. Tests will be added later.
Does this PR change the documentation?
ci/circleci: build_doc
check. If it's successful, proceed to thenext step, otherwise fix it.
Details
on the right to open the job page of CircleCI.Artifacts
tab.docs/build/html/index.html
.Release Notes
Is this a user-facing change?
Success merge of this PR will enable autologging for XGBoost scikit-learn models.
What component(s), interfaces, languages, and integrations does this PR affect?
Components
area/artifacts
: Artifact stores and artifact loggingarea/build
: Build and test infrastructure for MLflowarea/docs
: MLflow documentation pagesarea/examples
: Example codearea/model-registry
: Model Registry service, APIs, and the fluent client calls for Model Registryarea/models
: MLmodel format, model serialization/deserialization, flavorsarea/projects
: MLproject format, project running backendsarea/scoring
: MLflow Model server, model deployment tools, Spark UDFsarea/server-infra
: MLflow Tracking server backendarea/tracking
: Tracking Service, tracking client APIs, autologgingInterface
area/uiux
: Front-end, user experience, plotting, JavaScript, JavaScript dev serverarea/docker
: Docker use across MLflow's components, such as MLflow Projects and MLflow Modelsarea/sqlalchemy
: Use of SQLAlchemy in the Tracking Service or Model Registryarea/windows
: Windows supportLanguage
language/r
: R APIs and clientslanguage/java
: Java APIs and clientslanguage/new
: Proposals for new client languagesIntegrations
integrations/azure
: Azure and Azure ML integrationsintegrations/sagemaker
: SageMaker integrationsintegrations/databricks
: Databricks integrationsHow should the PR be classified in the release notes? Choose one:
rn/breaking-change
- The PR will be mentioned in the "Breaking Changes" sectionrn/none
- No description will be included. The PR will be mentioned only by the PR number in the "Small Bugfixes and Documentation Updates" sectionrn/feature
- A new user-facing feature worth mentioning in the release notesrn/bug-fix
- A user-facing bug fix worth mentioning in the release notesrn/documentation
- A user-facing documentation change worth mentioning in the release notes