Skip to content

Commit

Permalink
Return Model from log_model (#5230)
Browse files Browse the repository at this point in the history
Provide a return value from mlflow.*.log_model() that contains model metadata such as model_uri, run_id,...
This metadata makes it easier for the user to do follow-up model loading.

Signed-off-by: Liang Zhang <liang.zhang@databricks.com>
  • Loading branch information
liangz1 committed Jan 12, 2022
1 parent c2ae0bd commit 364aca7
Show file tree
Hide file tree
Showing 39 changed files with 204 additions and 34 deletions.
2 changes: 2 additions & 0 deletions docs/source/python_api/mlflow.models.rst
Expand Up @@ -6,3 +6,5 @@ mlflow.models
:undoc-members:
:show-inheritance:

.. autoclass:: mlflow.models.model.ModelInfo
:members:
4 changes: 3 additions & 1 deletion mlflow/catboost.py
Expand Up @@ -226,8 +226,10 @@ def log_model(
:param pip_requirements: {{ pip_requirements }}
:param extra_pip_requirements: {{ extra_pip_requirements }}
:param kwargs: kwargs to pass to `CatBoost.save_model`_ method.
:return: A :py:class:`ModelInfo <mlflow.models.model.ModelInfo>` instance that contains the
metadata of the logged model.
"""
Model.log(
return Model.log(
artifact_path=artifact_path,
flavor=mlflow.catboost,
registered_model_name=registered_model_name,
Expand Down
4 changes: 3 additions & 1 deletion mlflow/fastai/__init__.py
Expand Up @@ -252,6 +252,8 @@ def log_model(
waits for five minutes. Specify 0 or None to skip waiting.
:param pip_requirements: {{ pip_requirements }}
:param extra_pip_requirements: {{ extra_pip_requirements }}
:return: A :py:class:`ModelInfo <mlflow.models.model.ModelInfo>` instance that contains the
metadata of the logged model.
.. code-block:: python
:caption: Example
Expand Down Expand Up @@ -287,7 +289,7 @@ def main(epochs=5, learning_rate=0.01):
artifacts: ['model/MLmodel', 'model/conda.yaml', 'model/model.fastai']
"""
Model.log(
return Model.log(
artifact_path=artifact_path,
flavor=mlflow.fastai,
registered_model_name=registered_model_name,
Expand Down
4 changes: 3 additions & 1 deletion mlflow/gluon/__init__.py
Expand Up @@ -307,6 +307,8 @@ def log_model(
by converting it to a list. Bytes are base64-encoded.
:param pip_requirements: {{ pip_requirements }}
:param extra_pip_requirements: {{ extra_pip_requirements }}
:return: A :py:class:`ModelInfo <mlflow.models.model.ModelInfo>` instance that contains the
metadata of the logged model.
.. code-block:: python
:caption: Example
Expand All @@ -331,7 +333,7 @@ def log_model(
est.fit(train_data=train_data, epochs=100, val_data=validation_data)
mlflow.gluon.log_model(net, "model")
"""
Model.log(
return Model.log(
artifact_path=artifact_path,
flavor=mlflow.gluon,
gluon_model=gluon_model,
Expand Down
4 changes: 3 additions & 1 deletion mlflow/h2o.py
Expand Up @@ -210,8 +210,10 @@ def log_model(
:param pip_requirements: {{ pip_requirements }}
:param extra_pip_requirements: {{ extra_pip_requirements }}
:param kwargs: kwargs to pass to ``h2o.save_model`` method.
:return: A :py:class:`ModelInfo <mlflow.models.model.ModelInfo>` instance that contains the
metadata of the logged model.
"""
Model.log(
return Model.log(
artifact_path=artifact_path,
flavor=mlflow.h2o,
registered_model_name=registered_model_name,
Expand Down
4 changes: 3 additions & 1 deletion mlflow/keras.py
Expand Up @@ -361,6 +361,8 @@ def log_model(
:param pip_requirements: {{ pip_requirements }}
:param extra_pip_requirements: {{ extra_pip_requirements }}
:param kwargs: kwargs to pass to ``keras_model.save`` method.
:return: A :py:class:`ModelInfo <mlflow.models.model.ModelInfo>` instance that contains the
metadata of the logged model.
.. code-block:: python
:caption: Example
Expand All @@ -376,7 +378,7 @@ def log_model(
with mlflow.start_run() as run:
mlflow.keras.log_model(keras_model, "models")
"""
Model.log(
return Model.log(
artifact_path=artifact_path,
flavor=mlflow.keras,
keras_model=keras_model,
Expand Down
4 changes: 3 additions & 1 deletion mlflow/lightgbm.py
Expand Up @@ -263,8 +263,10 @@ def log_model(
:param pip_requirements: {{ pip_requirements }}
:param extra_pip_requirements: {{ extra_pip_requirements }}
:param kwargs: kwargs to pass to `lightgbm.Booster.save_model`_ method.
:return: A :py:class:`ModelInfo <mlflow.models.model.ModelInfo>` instance that contains the
metadata of the logged model.
"""
Model.log(
return Model.log(
artifact_path=artifact_path,
flavor=mlflow.lightgbm,
registered_model_name=registered_model_name,
Expand Down
2 changes: 2 additions & 0 deletions mlflow/mleap.py
Expand Up @@ -74,6 +74,8 @@ def log_model(
serialized to json using the Pandas split-oriented format. Bytes are
base64-encoded.
:return: A :py:class:`ModelInfo <mlflow.models.model.ModelInfo>` instance that contains the
metadata of the logged model.
.. code-block:: python
Expand Down
78 changes: 75 additions & 3 deletions mlflow/models/model.py
Expand Up @@ -6,7 +6,7 @@
import os
import uuid

from typing import Any, Dict, Optional, Union, Callable
from typing import Any, Dict, Optional, Union, Callable, NamedTuple

import mlflow
from mlflow.exceptions import MlflowException
Expand All @@ -22,12 +22,60 @@
"Logging model metadata to the tracking server has failed, possibly due older "
"server version. The model artifacts have been logged successfully under %s. "
"In addition to exporting model artifacts, MLflow clients 1.7.0 and above "
"attempt to record model metadata to the tracking store. If logging to a "
"mlflow server via REST, consider upgrading the server version to MLflow "
"attempt to record model metadata to the tracking store. If logging to a "
"mlflow server via REST, consider upgrading the server version to MLflow "
"1.7.0 or above."
)


class ModelInfo(NamedTuple):
"""
The metadata of a logged MLflow Model.
"""

#: Run relative path identifying the logged model.
artifact_path: str

#: A dictionary mapping the flavor name to how to serve the model as that flavor. For example:
#:
#: .. code-block:: python
#:
#: {
#: "python_function": {
#: "model_path": "model.pkl",
#: "loader_module": "mlflow.sklearn",
#: "python_version": "3.8.10",
#: "env": "conda.yaml",
#: },
#: "sklearn": {
#: "pickled_model": "model.pkl",
#: "sklearn_version": "0.24.1",
#: "serialization_format": "cloudpickle",
#: },
#: }
flavors: Dict[str, Any]

#: The ``model_uri`` of the logged model in the format ``'runs:/<run_id>/<artifact_path>'``.
model_uri: str

#: The ``model_uuid`` of the logged model, e.g., ``'39ca11813cfc46b09ab83972740b80ca'``.
model_uuid: str

#: The ``run_id`` associated with the logged model, e.g., ``'8ede7df408dd42ed9fc39019ef7df309'``
run_id: str

#: A dictionary that contains the metadata of the saved input example, e.g.,
#: ``{"artifact_path": "input_example.json", "type": "dataframe", "pandas_orient": "split"}``.
saved_input_example_info: Optional[Dict[str, Any]]

#: A dictionary that describes the model input and output generated by
# :py:meth:`ModelSignature.to_dict() <mlflow.models.ModelSignature.to_dict>`.
signature_dict: Optional[Dict[str, Any]]

#: The UTC time that the logged model is created, e.g., ``'2022-01-12 05:17:31.634689'``.
utc_time_created: str


class Model:
"""
An MLflow Model that can support multiple model flavors. Provides APIs for implementing
Expand Down Expand Up @@ -102,13 +150,33 @@ def signature(self, value):

@property
def saved_input_example_info(self) -> Optional[Dict[str, Any]]:
"""
A dictionary that contains the metadata of the saved input example, e.g.,
``{"artifact_path": "input_example.json", "type": "dataframe", "pandas_orient": "split"}``.
"""
return self._saved_input_example_info

@saved_input_example_info.setter
def saved_input_example_info(self, value: Dict[str, Any]):
# pylint: disable=attribute-defined-outside-init
self._saved_input_example_info = value

def get_model_info(self):
"""
Create a :py:class:`ModelInfo <mlflow.models.model.ModelInfo>` instance that contains the
model metadata.
"""
return ModelInfo(
artifact_path=self.artifact_path,
flavors=self.flavors,
model_uri="runs:/{}/{}".format(self.run_id, self.artifact_path),
model_uuid=self.model_uuid,
run_id=self.run_id,
saved_input_example_info=self.saved_input_example_info,
signature_dict=self.signature.to_dict() if self.signature else None,
utc_time_created=self.utc_time_created,
)

def to_dict(self):
"""Serialize the model to a dictionary."""
res = {k: v for k, v in self.__dict__.items() if not k.startswith("_")}
Expand Down Expand Up @@ -203,6 +271,9 @@ def log(
waits for five minutes. Specify 0 or None to skip waiting.
:param kwargs: Extra args passed to the model flavor.
:return: A :py:class:`ModelInfo <mlflow.models.model.ModelInfo>` instance that contains the
metadata of the logged model.
"""
with TempDir() as tmp:
local_path = tmp.path("model")
Expand All @@ -223,3 +294,4 @@ def log(
registered_model_name,
await_registration_for=await_registration_for,
)
return mlflow_model.get_model_info()
4 changes: 3 additions & 1 deletion mlflow/onnx.py
Expand Up @@ -362,8 +362,10 @@ def log_model(
waits for five minutes. Specify 0 or None to skip waiting.
:param pip_requirements: {{ pip_requirements }}
:param extra_pip_requirements: {{ extra_pip_requirements }}
:return: A :py:class:`ModelInfo <mlflow.models.model.ModelInfo>` instance that contains the
metadata of the logged model.
"""
Model.log(
return Model.log(
artifact_path=artifact_path,
flavor=mlflow.onnx,
onnx_model=onnx_model,
Expand Down
2 changes: 2 additions & 0 deletions mlflow/paddle/__init__.py
Expand Up @@ -371,6 +371,8 @@ def log_model(
waits for five minutes. Specify 0 or None to skip waiting.
:param pip_requirements: {{ pip_requirements }}
:param extra_pip_requirements: {{ extra_pip_requirements }}
:return: A :py:class:`ModelInfo <mlflow.models.model.ModelInfo>` instance that contains the
metadata of the logged model.
.. code-block:: python
:caption: Example
Expand Down
4 changes: 3 additions & 1 deletion mlflow/prophet.py
Expand Up @@ -221,8 +221,10 @@ def log_model(
Specify 0 or None to skip waiting.
:param pip_requirements: {{ pip_requirements }}
:param extra_pip_requirements: {{ extra_pip_requirements }}
:return: A :py:class:`ModelInfo <mlflow.models.model.ModelInfo>` instance that contains the
metadata of the logged model.
"""
Model.log(
return Model.log(
artifact_path=artifact_path,
flavor=mlflow.prophet,
registered_model_name=registered_model_name,
Expand Down
2 changes: 2 additions & 0 deletions mlflow/pyfunc/__init__.py
Expand Up @@ -1199,6 +1199,8 @@ def log_model(
waits for five minutes. Specify 0 or None to skip waiting.
:param pip_requirements: {{ pip_requirements }}
:param extra_pip_requirements: {{ extra_pip_requirements }}
:return: A :py:class:`ModelInfo <mlflow.models.model.ModelInfo>` instance that contains the
metadata of the logged model.
"""
return Model.log(
artifact_path=artifact_path,
Expand Down
4 changes: 3 additions & 1 deletion mlflow/pytorch/__init__.py
Expand Up @@ -209,6 +209,8 @@ def log_model(
:param pip_requirements: {{ pip_requirements }}
:param extra_pip_requirements: {{ extra_pip_requirements }}
:param kwargs: kwargs to pass to ``torch.save`` method.
:return: A :py:class:`ModelInfo <mlflow.models.model.ModelInfo>` instance that contains the
metadata of the logged model.
.. code-block:: python
:caption: Example
Expand Down Expand Up @@ -283,7 +285,7 @@ def gen_data():
PyTorch logged models
"""
pickle_module = pickle_module or mlflow_pytorch_pickle_module
Model.log(
return Model.log(
artifact_path=artifact_path,
flavor=mlflow.pytorch,
pytorch_model=pytorch_model,
Expand Down
4 changes: 3 additions & 1 deletion mlflow/sklearn/__init__.py
Expand Up @@ -350,6 +350,8 @@ def log_model(
waits for five minutes. Specify 0 or None to skip waiting.
:param pip_requirements: {{ pip_requirements }}
:param extra_pip_requirements: {{ extra_pip_requirements }}
:return: A :py:class:`ModelInfo <mlflow.models.model.ModelInfo>` instance that contains the
metadata of the logged model.
.. code-block:: python
:caption: Example
Expand All @@ -371,7 +373,7 @@ def log_model(
# log model
mlflow.sklearn.log_model(sk_model, "sk_models")
"""
Model.log(
return Model.log(
artifact_path=artifact_path,
flavor=mlflow.sklearn,
sk_model=sk_model,
Expand Down
4 changes: 3 additions & 1 deletion mlflow/spacy.py
Expand Up @@ -223,8 +223,10 @@ def log_model(
:param pip_requirements: {{ pip_requirements }}
:param extra_pip_requirements: {{ extra_pip_requirements }}
:param kwargs: kwargs to pass to ``spacy.save_model`` method.
:return: A :py:class:`ModelInfo <mlflow.models.model.ModelInfo>` instance that contains the
metadata of the logged model.
"""
Model.log(
return Model.log(
artifact_path=artifact_path,
flavor=mlflow.spacy,
registered_model_name=registered_model_name,
Expand Down
3 changes: 3 additions & 0 deletions mlflow/spark.py
Expand Up @@ -170,6 +170,8 @@ def log_model(
waits for five minutes. Specify 0 or None to skip waiting.
:param pip_requirements: {{ pip_requirements }}
:param extra_pip_requirements: {{ extra_pip_requirements }}
:return: A :py:class:`ModelInfo <mlflow.models.model.ModelInfo>` instance that contains the
metadata of the logged model.
.. code-block:: python
:caption: Example
Expand Down Expand Up @@ -259,6 +261,7 @@ def log_model(
registered_model_name,
await_registration_for,
)
return mlflow_model.get_model_info()


def _tmp_path(dfs_tmp):
Expand Down
4 changes: 3 additions & 1 deletion mlflow/statsmodels.py
Expand Up @@ -251,8 +251,10 @@ def log_model(
waits for five minutes. Specify 0 or None to skip waiting.
:param pip_requirements: {{ pip_requirements }}
:param extra_pip_requirements: {{ extra_pip_requirements }}
:return: A :py:class:`ModelInfo <mlflow.models.model.ModelInfo>` instance that contains the
metadata of the logged model.
"""
Model.log(
return Model.log(
artifact_path=artifact_path,
flavor=mlflow.statsmodels,
registered_model_name=registered_model_name,
Expand Down
2 changes: 2 additions & 0 deletions mlflow/tensorflow/__init__.py
Expand Up @@ -177,6 +177,8 @@ def log_model(
waits for five minutes. Specify 0 or None to skip waiting.
:param pip_requirements: {{ pip_requirements }}
:param extra_pip_requirements: {{ extra_pip_requirements }}
:return: A :py:class:`ModelInfo <mlflow.models.model.ModelInfo>` instance that contains the
metadata of the logged model.
"""
return Model.log(
artifact_path=artifact_path,
Expand Down
4 changes: 3 additions & 1 deletion mlflow/xgboost/__init__.py
Expand Up @@ -240,8 +240,10 @@ def log_model(
:param pip_requirements: {{ pip_requirements }}
:param extra_pip_requirements: {{ extra_pip_requirements }}
:param kwargs: kwargs to pass to `xgboost.Booster.save_model`_ method.
:return: A :py:class:`ModelInfo <mlflow.models.model.ModelInfo>` instance that contains the
metadata of the logged model.
"""
Model.log(
return Model.log(
artifact_path=artifact_path,
flavor=mlflow.xgboost,
registered_model_name=registered_model_name,
Expand Down
3 changes: 2 additions & 1 deletion tests/catboost/test_catboost_model_export.py
Expand Up @@ -195,8 +195,9 @@ def test_log_model(cb_model, tmpdir):
conda_env = os.path.join(tmpdir.strpath, "conda_env.yaml")
_mlflow_conda_env(conda_env, additional_pip_deps=["catboost"])

mlflow.catboost.log_model(model, artifact_path, conda_env=conda_env)
model_info = mlflow.catboost.log_model(model, artifact_path, conda_env=conda_env)
model_uri = "runs:/{}/{}".format(mlflow.active_run().info.run_id, artifact_path)
assert model_info.model_uri == model_uri

loaded_model = mlflow.catboost.load_model(model_uri)
np.testing.assert_array_almost_equal(
Expand Down
3 changes: 2 additions & 1 deletion tests/fastai/test_fastai_model_export.py
Expand Up @@ -149,13 +149,14 @@ def test_model_log(fastai_model, model_path):
conda_env = os.path.join(tmp.path(), "conda_env.yaml")
_mlflow_conda_env(conda_env, additional_pip_deps=["fastai"])

mlflow.fastai.log_model(
model_info = mlflow.fastai.log_model(
fastai_learner=model, artifact_path=artifact_path, conda_env=conda_env
)

model_uri = "runs:/{run_id}/{artifact_path}".format(
run_id=mlflow.active_run().info.run_id, artifact_path=artifact_path
)
assert model_info.model_uri == model_uri

reloaded_model = mlflow.fastai.load_model(model_uri=model_uri)

Expand Down

0 comments on commit 364aca7

Please sign in to comment.