diff --git a/mlflow/lightgbm.py b/mlflow/lightgbm.py index 764241f867442..b8ab5df94d88d 100644 --- a/mlflow/lightgbm.py +++ b/mlflow/lightgbm.py @@ -33,6 +33,7 @@ from mlflow.models.signature import ModelSignature from mlflow.models.utils import ModelInputExample, _save_example from mlflow.tracking.artifact_utils import _download_artifact_from_uri +from mlflow.utils import _get_fully_qualified_class_name from mlflow.utils.environment import ( _mlflow_conda_env, _validate_env_arguments, @@ -67,21 +68,24 @@ _logger = logging.getLogger(__name__) -def get_default_pip_requirements(): +def get_default_pip_requirements(include_cloudpickle=False): """ :return: A list of default pip requirements for MLflow Models produced by this flavor. Calls to :func:`save_model()` and :func:`log_model()` produce a pip environment that, at minimum, contains these requirements. """ - return [_get_pinned_requirement("lightgbm")] + pip_deps = [_get_pinned_requirement("lightgbm")] + if include_cloudpickle: + pip_deps.append(_get_pinned_requirement("cloudpickle")) + return pip_deps -def get_default_conda_env(): +def get_default_conda_env(include_cloudpickle=False): """ :return: The default Conda environment for MLflow Models produced by calls to :func:`save_model()` and :func:`log_model()`. """ - return _mlflow_conda_env(additional_pip_deps=get_default_pip_requirements()) + return _mlflow_conda_env(additional_pip_deps=get_default_pip_requirements(include_cloudpickle)) @format_docstring(LOG_MODEL_PARAM_DOCS.format(package_name=FLAVOR_NAME)) @@ -132,7 +136,7 @@ def save_model( path = os.path.abspath(path) if os.path.exists(path): raise MlflowException("Path '{}' already exists".format(path)) - model_data_subpath = "model.lgb" + model_data_subpath = "model.lgb" if isinstance(lgb_model, lgb.Booster) else "model.pkl" model_data_path = os.path.join(path, model_data_subpath) os.makedirs(path) if mlflow_model is None: @@ -143,20 +147,28 @@ def save_model( _save_example(mlflow_model, input_example, path) # Save a LightGBM model - lgb_model.save_model(model_data_path) + _save_model(lgb_model, model_data_path) + lgb_model_class = _get_fully_qualified_class_name(lgb_model) pyfunc.add_to_model( mlflow_model, loader_module="mlflow.lightgbm", data=model_data_subpath, env=_CONDA_ENV_FILE_NAME, ) - mlflow_model.add_flavor(FLAVOR_NAME, lgb_version=lgb.__version__, data=model_data_subpath) + mlflow_model.add_flavor( + FLAVOR_NAME, + lgb_version=lgb.__version__, + data=model_data_subpath, + model_class=lgb_model_class, + ) mlflow_model.save(os.path.join(path, MLMODEL_FILE_NAME)) if conda_env is None: if pip_requirements is None: - default_reqs = get_default_pip_requirements() + default_reqs = get_default_pip_requirements( + include_cloudpickle=not isinstance(lgb_model, lgb.Booster) + ) # To ensure `_load_pyfunc` can successfully load the model during the dependency # inference, `mlflow_model.save` must be called beforehand to save an MLmodel file. inferred_reqs = mlflow.models.infer_pip_requirements( @@ -186,6 +198,22 @@ def save_model( write_to(os.path.join(path, _REQUIREMENTS_FILE_NAME), "\n".join(pip_requirements)) +def _save_model(lgb_model, model_path): + """ + LightGBM Boosters are saved using the built-in method `save_model()`, + whereas LightGBM scikit-learn models are serialized using Cloudpickle. + """ + import lightgbm as lgb + + if isinstance(lgb_model, lgb.Booster): + lgb_model.save_model(model_path) + else: + import cloudpickle + + with open(model_path, "wb") as out: + cloudpickle.dump(lgb_model, out) + + @format_docstring(LOG_MODEL_PARAM_DOCS.format(package_name=FLAVOR_NAME)) def log_model( lgb_model, @@ -251,9 +279,31 @@ def log_model( def _load_model(path): - import lightgbm as lgb + """ + Load Model Implementation. + :param path: Local filesystem path to + the MLflow Model with the ``lightgbm`` flavor (MLflow < 1.23.0) or + the top-level MLflow Model directory (MLflow >= 1.23.0). + """ + + model_dir = os.path.dirname(path) if os.path.isfile(path) else path + flavor_conf = _get_flavor_configuration(model_path=model_dir, flavor_name=FLAVOR_NAME) + + model_class = flavor_conf.get("model_class", "lightgbm.basic.Booster") + lgb_model_path = os.path.join(model_dir, flavor_conf.get("data")) + + if model_class == "lightgbm.basic.Booster": + import lightgbm as lgb + + model = lgb.Booster(model_file=lgb_model_path) + else: + # LightGBM scikit-learn models are deserialized using Cloudpickle. + import cloudpickle + + with open(lgb_model_path, "rb") as f: + model = cloudpickle.load(f) - return lgb.Booster(model_file=path) + return model def _load_pyfunc(path): @@ -283,12 +333,11 @@ def load_model(model_uri, dst_path=None): This directory must already exist. If unspecified, a local output path will be created. - :return: A LightGBM model (an instance of `lightgbm.Booster`_). + :return: A LightGBM model (an instance of `lightgbm.Booster`_) or a LightGBM scikit-learn + model, depending on the saved model class specification. """ local_model_path = _download_artifact_from_uri(artifact_uri=model_uri, output_path=dst_path) - flavor_conf = _get_flavor_configuration(model_path=local_model_path, flavor_name=FLAVOR_NAME) - lgb_model_file_path = os.path.join(local_model_path, flavor_conf.get("data", "model.lgb")) - return _load_model(path=lgb_model_file_path) + return _load_model(path=local_model_path) class _LGBModelWrapper: diff --git a/tests/lightgbm/test_lightgbm_model_export.py b/tests/lightgbm/test_lightgbm_model_export.py index 8b973e1701d4f..b08ece0f67710 100644 --- a/tests/lightgbm/test_lightgbm_model_export.py +++ b/tests/lightgbm/test_lightgbm_model_export.py @@ -50,6 +50,18 @@ def lgb_model(): return ModelWithData(model=model, inference_dataframe=X) +@pytest.fixture(scope="session") +def lgb_sklearn_model(): + iris = datasets.load_iris() + X = pd.DataFrame( + iris.data[:, :2], columns=iris.feature_names[:2] # we only take the first two features. + ) + y = iris.target + model = lgb.LGBMClassifier(n_estimators=10) + model.fit(X, y) + return ModelWithData(model=model, inference_dataframe=X) + + @pytest.fixture def model_path(tmpdir): return os.path.join(str(tmpdir), "model") @@ -68,7 +80,7 @@ def test_model_save_load(lgb_model, model_path): mlflow.lightgbm.save_model(lgb_model=model, path=model_path) reloaded_model = mlflow.lightgbm.load_model(model_uri=model_path) - reloaded_pyfunc = pyfunc.load_pyfunc(model_uri=model_path) + reloaded_pyfunc = pyfunc.load_model(model_uri=model_path) np.testing.assert_array_almost_equal( model.predict(lgb_model.inference_dataframe), @@ -81,6 +93,24 @@ def test_model_save_load(lgb_model, model_path): ) +@pytest.mark.large +def test_sklearn_model_save_load(lgb_sklearn_model, model_path): + model = lgb_sklearn_model.model + mlflow.lightgbm.save_model(lgb_model=model, path=model_path) + reloaded_model = mlflow.lightgbm.load_model(model_uri=model_path) + reloaded_pyfunc = pyfunc.load_model(model_uri=model_path) + + np.testing.assert_array_almost_equal( + model.predict(lgb_sklearn_model.inference_dataframe), + reloaded_model.predict(lgb_sklearn_model.inference_dataframe), + ) + + np.testing.assert_array_almost_equal( + reloaded_model.predict(lgb_sklearn_model.inference_dataframe), + reloaded_pyfunc.predict(lgb_sklearn_model.inference_dataframe), + ) + + def test_signature_and_examples_are_saved_correctly(lgb_model): model = lgb_model.model X = lgb_model.inference_dataframe @@ -398,3 +428,49 @@ def test_pyfunc_serve_and_score_sklearn(model): ) scores = pd.read_json(resp.content, orient="records").values.squeeze() np.testing.assert_array_equal(scores, model.predict(X.head(3))) + + +@pytest.mark.large +def test_load_pyfunc_succeeds_for_older_models_with_pyfunc_data_field(lgb_model, model_path): + """ + This test verifies that LightGBM models saved in older versions of MLflow are loaded + successfully by ``mlflow.pyfunc.load_model``. These older models specify a pyfunc ``data`` + field referring directly to a LightGBM model file. Newer models also have the + ``model_class`` in LightGBM flavor. + """ + model = lgb_model.model + mlflow.lightgbm.save_model(lgb_model=model, path=model_path) + + model_conf_path = os.path.join(model_path, "MLmodel") + model_conf = Model.load(model_conf_path) + pyfunc_conf = model_conf.flavors.get(pyfunc.FLAVOR_NAME) + lgb_conf = model_conf.flavors.get(mlflow.lightgbm.FLAVOR_NAME) + assert lgb_conf is not None + assert "model_class" in lgb_conf + assert "data" in lgb_conf + assert pyfunc_conf is not None + assert "model_class" not in pyfunc_conf + assert pyfunc.DATA in pyfunc_conf + + # test old MLmodel conf + model_conf.flavors["lightgbm"] = {"lgb_version": lgb.__version__, "data": "model.lgb"} + model_conf.save(model_conf_path) + model_conf = Model.load(model_conf_path) + lgb_conf = model_conf.flavors.get(mlflow.lightgbm.FLAVOR_NAME) + assert "data" in lgb_conf + assert lgb_conf["data"] == "model.lgb" + + reloaded_pyfunc = pyfunc.load_model(model_uri=model_path) + assert isinstance(reloaded_pyfunc._model_impl.lgb_model, lgb.Booster) + reloaded_lgb = mlflow.lightgbm.load_model(model_uri=model_path) + assert isinstance(reloaded_lgb, lgb.Booster) + + np.testing.assert_array_almost_equal( + lgb_model.model.predict(lgb_model.inference_dataframe), + reloaded_pyfunc.predict(lgb_model.inference_dataframe), + ) + + np.testing.assert_array_almost_equal( + reloaded_lgb.predict(lgb_model.inference_dataframe), + reloaded_pyfunc.predict(lgb_model.inference_dataframe), + )