diff --git a/mlflow/sklearn/__init__.py b/mlflow/sklearn/__init__.py index d2d6b21f48a53..d8f561290efc9 100644 --- a/mlflow/sklearn/__init__.py +++ b/mlflow/sklearn/__init__.py @@ -1196,8 +1196,8 @@ def _autolog( _MIN_SKLEARN_VERSION, _TRAINING_PREFIX, _is_supported_version, + _get_X_y_and_sample_weight, _gen_xgboost_sklearn_estimators_to_patch, - _get_args_for_metrics, _log_estimator_content, _all_estimators, _get_estimator_info_tags, @@ -1233,7 +1233,7 @@ def fit_mlflow_xgboost(original, self, *args, **kwargs): # are done in `train()` in `mlflow.xgboost.autolog()` fit_output = original(self, *args, **kwargs) # log models after training - X = _get_args_for_metrics(self.fit, args, kwargs)[0] + X = _get_X_y_and_sample_weight(self.fit, args, kwargs)[0] if log_models: input_example, signature = resolve_input_example_and_signature( lambda: X[:INPUT_EXAMPLE_SAMPLE_ROWS], @@ -1322,7 +1322,7 @@ def infer_model_signature(input_example): return infer_signature(input_example, estimator.predict(input_example)) - (X, y_true, sample_weight) = _get_args_for_metrics(estimator.fit, args, kwargs) + (X, y_true, sample_weight) = _get_X_y_and_sample_weight(estimator.fit, args, kwargs) # log common metrics and artifacts for estimators (classifier, regressor) logged_metrics = _log_estimator_content( diff --git a/mlflow/sklearn/utils.py b/mlflow/sklearn/utils.py index 09cc44a319e21..e77fb2755b284 100644 --- a/mlflow/sklearn/utils.py +++ b/mlflow/sklearn/utils.py @@ -58,9 +58,9 @@ def _get_estimator_info_tags(estimator): } -def _get_args_for_metrics(fit_func, fit_args, fit_kwargs): +def _get_X_y_and_sample_weight(fit_func, fit_args, fit_kwargs): """ - Get arguments to pass to metric computations in the following steps. + Get a tuple of (X, y, sample_weight) in the following steps. 1. Extract X and y from fit_args and fit_kwargs. 2. If the sample_weight argument exists in fit_func,