Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Return Model from log_model #5230

Merged
merged 12 commits into from Jan 12, 2022
2 changes: 2 additions & 0 deletions docs/source/python_api/mlflow.models.rst
Expand Up @@ -6,3 +6,5 @@ mlflow.models
:undoc-members:
:show-inheritance:

.. autoclass:: mlflow.models.model.ModelInfo
:members:
4 changes: 3 additions & 1 deletion mlflow/catboost.py
Expand Up @@ -226,8 +226,10 @@ def log_model(
:param pip_requirements: {{ pip_requirements }}
:param extra_pip_requirements: {{ extra_pip_requirements }}
:param kwargs: kwargs to pass to `CatBoost.save_model`_ method.
:return: A :py:class:`ModelInfo <mlflow.models.model.ModelInfo>` instance that contains the
metadata of the logged model.
"""
Model.log(
return Model.log(
artifact_path=artifact_path,
flavor=mlflow.catboost,
registered_model_name=registered_model_name,
Expand Down
4 changes: 3 additions & 1 deletion mlflow/fastai/__init__.py
Expand Up @@ -252,6 +252,8 @@ def log_model(
waits for five minutes. Specify 0 or None to skip waiting.
:param pip_requirements: {{ pip_requirements }}
:param extra_pip_requirements: {{ extra_pip_requirements }}
:return: A :py:class:`ModelInfo <mlflow.models.model.ModelInfo>` instance that contains the
metadata of the logged model.

.. code-block:: python
:caption: Example
Expand Down Expand Up @@ -287,7 +289,7 @@ def main(epochs=5, learning_rate=0.01):

artifacts: ['model/MLmodel', 'model/conda.yaml', 'model/model.fastai']
"""
Model.log(
return Model.log(
artifact_path=artifact_path,
flavor=mlflow.fastai,
registered_model_name=registered_model_name,
Expand Down
4 changes: 3 additions & 1 deletion mlflow/gluon/__init__.py
Expand Up @@ -307,6 +307,8 @@ def log_model(
by converting it to a list. Bytes are base64-encoded.
:param pip_requirements: {{ pip_requirements }}
:param extra_pip_requirements: {{ extra_pip_requirements }}
:return: A :py:class:`ModelInfo <mlflow.models.model.ModelInfo>` instance that contains the
metadata of the logged model.

.. code-block:: python
:caption: Example
Expand All @@ -331,7 +333,7 @@ def log_model(
est.fit(train_data=train_data, epochs=100, val_data=validation_data)
mlflow.gluon.log_model(net, "model")
"""
Model.log(
return Model.log(
artifact_path=artifact_path,
flavor=mlflow.gluon,
gluon_model=gluon_model,
Expand Down
4 changes: 3 additions & 1 deletion mlflow/h2o.py
Expand Up @@ -210,8 +210,10 @@ def log_model(
:param pip_requirements: {{ pip_requirements }}
:param extra_pip_requirements: {{ extra_pip_requirements }}
:param kwargs: kwargs to pass to ``h2o.save_model`` method.
:return: A :py:class:`ModelInfo <mlflow.models.model.ModelInfo>` instance that contains the
metadata of the logged model.
"""
Model.log(
return Model.log(
artifact_path=artifact_path,
flavor=mlflow.h2o,
registered_model_name=registered_model_name,
Expand Down
4 changes: 3 additions & 1 deletion mlflow/keras.py
Expand Up @@ -361,6 +361,8 @@ def log_model(
:param pip_requirements: {{ pip_requirements }}
:param extra_pip_requirements: {{ extra_pip_requirements }}
:param kwargs: kwargs to pass to ``keras_model.save`` method.
:return: A :py:class:`ModelInfo <mlflow.models.model.ModelInfo>` instance that contains the
metadata of the logged model.

.. code-block:: python
:caption: Example
Expand All @@ -376,7 +378,7 @@ def log_model(
with mlflow.start_run() as run:
mlflow.keras.log_model(keras_model, "models")
"""
Model.log(
return Model.log(
artifact_path=artifact_path,
flavor=mlflow.keras,
keras_model=keras_model,
Expand Down
4 changes: 3 additions & 1 deletion mlflow/lightgbm.py
Expand Up @@ -263,8 +263,10 @@ def log_model(
:param pip_requirements: {{ pip_requirements }}
:param extra_pip_requirements: {{ extra_pip_requirements }}
:param kwargs: kwargs to pass to `lightgbm.Booster.save_model`_ method.
:return: A :py:class:`ModelInfo <mlflow.models.model.ModelInfo>` instance that contains the
metadata of the logged model.
"""
Model.log(
return Model.log(
artifact_path=artifact_path,
flavor=mlflow.lightgbm,
registered_model_name=registered_model_name,
Expand Down
2 changes: 2 additions & 0 deletions mlflow/mleap.py
Expand Up @@ -74,6 +74,8 @@ def log_model(
serialized to json using the Pandas split-oriented format. Bytes are
base64-encoded.

:return: A :py:class:`ModelInfo <mlflow.models.model.ModelInfo>` instance that contains the
metadata of the logged model.


.. code-block:: python
Expand Down
78 changes: 75 additions & 3 deletions mlflow/models/model.py
Expand Up @@ -6,7 +6,7 @@
import os
import uuid

from typing import Any, Dict, Optional, Union, Callable
from typing import Any, Dict, Optional, Union, Callable, NamedTuple

import mlflow
from mlflow.exceptions import MlflowException
Expand All @@ -22,12 +22,60 @@
"Logging model metadata to the tracking server has failed, possibly due older "
"server version. The model artifacts have been logged successfully under %s. "
"In addition to exporting model artifacts, MLflow clients 1.7.0 and above "
"attempt to record model metadata to the tracking store. If logging to a "
"mlflow server via REST, consider upgrading the server version to MLflow "
"attempt to record model metadata to the tracking store. If logging to a "
"mlflow server via REST, consider upgrading the server version to MLflow "
"1.7.0 or above."
)


class ModelInfo(NamedTuple):
"""
The metadata of a logged MLflow Model.
"""

#: Run relative path identifying the logged model.
artifact_path: str

#: A dictionary mapping the flavor name to how to serve the model as that flavor. For example:
#:
#: .. code-block:: python
#:
#: {
#: "python_function": {
#: "model_path": "model.pkl",
#: "loader_module": "mlflow.sklearn",
#: "python_version": "3.8.10",
#: "env": "conda.yaml",
#: },
#: "sklearn": {
#: "pickled_model": "model.pkl",
#: "sklearn_version": "0.24.1",
#: "serialization_format": "cloudpickle",
#: },
#: }
flavors: Dict[str, Any]

#: The ``model_uri`` of the logged model in the format ``'runs:/<run_id>/<artifact_path>'``.
model_uri: str

#: The ``model_uuid`` of the logged model, e.g., ``'39ca11813cfc46b09ab83972740b80ca'``.
model_uuid: str

#: The ``run_id`` associated with the logged model, e.g., ``'8ede7df408dd42ed9fc39019ef7df309'``
run_id: str

#: A dictionary that contains the metadata of the saved input example, e.g.,
#: ``{"artifact_path": "input_example.json", "type": "dataframe", "pandas_orient": "split"}``.
saved_input_example_info: Optional[Dict[str, Any]]

#: A dictionary that describes the model input and output generated by
# :py:meth:`ModelSignature.to_dict() <mlflow.models.ModelSignature.to_dict>`.
signature_dict: Optional[Dict[str, Any]]

#: The UTC time that the logged model is created, e.g., ``'2022-01-12 05:17:31.634689'``.
utc_time_created: str

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added example values for some attributes.


class Model:
"""
An MLflow Model that can support multiple model flavors. Provides APIs for implementing
Expand Down Expand Up @@ -102,13 +150,33 @@ def signature(self, value):

@property
def saved_input_example_info(self) -> Optional[Dict[str, Any]]:
"""
A dictionary that contains the metadata of the saved input example, e.g.,
``{"artifact_path": "input_example.json", "type": "dataframe", "pandas_orient": "split"}``.
"""
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Improve the Model class property docstring.

return self._saved_input_example_info

@saved_input_example_info.setter
def saved_input_example_info(self, value: Dict[str, Any]):
# pylint: disable=attribute-defined-outside-init
self._saved_input_example_info = value

def get_model_info(self):
"""
Create a :py:class:`ModelInfo <mlflow.models.model.ModelInfo>` instance that contains the
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@harupy Also,

:py:class:`ModelInfo <mlflow.models.model.ModelInfo>`

works while the following does not:

:py:class:`ModelInfo`

Thanks for helping!

model metadata.
"""
return ModelInfo(
artifact_path=self.artifact_path,
flavors=self.flavors,
model_uri="runs:/{}/{}".format(self.run_id, self.artifact_path),
model_uuid=self.model_uuid,
run_id=self.run_id,
saved_input_example_info=self.saved_input_example_info,
signature_dict=self.signature.to_dict() if self.signature else None,
utc_time_created=self.utc_time_created,
)

def to_dict(self):
"""Serialize the model to a dictionary."""
res = {k: v for k, v in self.__dict__.items() if not k.startswith("_")}
Expand Down Expand Up @@ -203,6 +271,9 @@ def log(
waits for five minutes. Specify 0 or None to skip waiting.

:param kwargs: Extra args passed to the model flavor.

:return: A :py:class:`ModelInfo <mlflow.models.model.ModelInfo>` instance that contains the
metadata of the logged model.
"""
with TempDir() as tmp:
local_path = tmp.path("model")
Expand All @@ -223,3 +294,4 @@ def log(
registered_model_name,
await_registration_for=await_registration_for,
)
return mlflow_model.get_model_info()
4 changes: 3 additions & 1 deletion mlflow/onnx.py
Expand Up @@ -362,8 +362,10 @@ def log_model(
waits for five minutes. Specify 0 or None to skip waiting.
:param pip_requirements: {{ pip_requirements }}
:param extra_pip_requirements: {{ extra_pip_requirements }}
:return: A :py:class:`ModelInfo <mlflow.models.model.ModelInfo>` instance that contains the
metadata of the logged model.
"""
Model.log(
return Model.log(
artifact_path=artifact_path,
flavor=mlflow.onnx,
onnx_model=onnx_model,
Expand Down
2 changes: 2 additions & 0 deletions mlflow/paddle/__init__.py
Expand Up @@ -371,6 +371,8 @@ def log_model(
waits for five minutes. Specify 0 or None to skip waiting.
:param pip_requirements: {{ pip_requirements }}
:param extra_pip_requirements: {{ extra_pip_requirements }}
:return: A :py:class:`ModelInfo <mlflow.models.model.ModelInfo>` instance that contains the
metadata of the logged model.

.. code-block:: python
:caption: Example
Expand Down
4 changes: 3 additions & 1 deletion mlflow/prophet.py
Expand Up @@ -221,8 +221,10 @@ def log_model(
Specify 0 or None to skip waiting.
:param pip_requirements: {{ pip_requirements }}
:param extra_pip_requirements: {{ extra_pip_requirements }}
:return: A :py:class:`ModelInfo <mlflow.models.model.ModelInfo>` instance that contains the
metadata of the logged model.
"""
Model.log(
return Model.log(
artifact_path=artifact_path,
flavor=mlflow.prophet,
registered_model_name=registered_model_name,
Expand Down
2 changes: 2 additions & 0 deletions mlflow/pyfunc/__init__.py
Expand Up @@ -1199,6 +1199,8 @@ def log_model(
waits for five minutes. Specify 0 or None to skip waiting.
:param pip_requirements: {{ pip_requirements }}
:param extra_pip_requirements: {{ extra_pip_requirements }}
:return: A :py:class:`ModelInfo <mlflow.models.model.ModelInfo>` instance that contains the
metadata of the logged model.
"""
return Model.log(
artifact_path=artifact_path,
Expand Down
4 changes: 3 additions & 1 deletion mlflow/pytorch/__init__.py
Expand Up @@ -209,6 +209,8 @@ def log_model(
:param pip_requirements: {{ pip_requirements }}
:param extra_pip_requirements: {{ extra_pip_requirements }}
:param kwargs: kwargs to pass to ``torch.save`` method.
:return: A :py:class:`ModelInfo <mlflow.models.model.ModelInfo>` instance that contains the
metadata of the logged model.

.. code-block:: python
:caption: Example
Expand Down Expand Up @@ -283,7 +285,7 @@ def gen_data():
PyTorch logged models
"""
pickle_module = pickle_module or mlflow_pytorch_pickle_module
Model.log(
return Model.log(
artifact_path=artifact_path,
flavor=mlflow.pytorch,
pytorch_model=pytorch_model,
Expand Down
4 changes: 3 additions & 1 deletion mlflow/sklearn/__init__.py
Expand Up @@ -350,6 +350,8 @@ def log_model(
waits for five minutes. Specify 0 or None to skip waiting.
:param pip_requirements: {{ pip_requirements }}
:param extra_pip_requirements: {{ extra_pip_requirements }}
:return: A :py:class:`ModelInfo <mlflow.models.model.ModelInfo>` instance that contains the
metadata of the logged model.

.. code-block:: python
:caption: Example
Expand All @@ -371,7 +373,7 @@ def log_model(
# log model
mlflow.sklearn.log_model(sk_model, "sk_models")
"""
Model.log(
return Model.log(
artifact_path=artifact_path,
flavor=mlflow.sklearn,
sk_model=sk_model,
Expand Down
4 changes: 3 additions & 1 deletion mlflow/spacy.py
Expand Up @@ -223,8 +223,10 @@ def log_model(
:param pip_requirements: {{ pip_requirements }}
:param extra_pip_requirements: {{ extra_pip_requirements }}
:param kwargs: kwargs to pass to ``spacy.save_model`` method.
:return: A :py:class:`ModelInfo <mlflow.models.model.ModelInfo>` instance that contains the
metadata of the logged model.
"""
Model.log(
return Model.log(
artifact_path=artifact_path,
flavor=mlflow.spacy,
registered_model_name=registered_model_name,
Expand Down
3 changes: 3 additions & 0 deletions mlflow/spark.py
Expand Up @@ -170,6 +170,8 @@ def log_model(
waits for five minutes. Specify 0 or None to skip waiting.
:param pip_requirements: {{ pip_requirements }}
:param extra_pip_requirements: {{ extra_pip_requirements }}
:return: A :py:class:`ModelInfo <mlflow.models.model.ModelInfo>` instance that contains the
metadata of the logged model.

.. code-block:: python
:caption: Example
Expand Down Expand Up @@ -259,6 +261,7 @@ def log_model(
registered_model_name,
await_registration_for,
)
return mlflow_model.get_model_info()


def _tmp_path(dfs_tmp):
Expand Down
4 changes: 3 additions & 1 deletion mlflow/statsmodels.py
Expand Up @@ -251,8 +251,10 @@ def log_model(
waits for five minutes. Specify 0 or None to skip waiting.
:param pip_requirements: {{ pip_requirements }}
:param extra_pip_requirements: {{ extra_pip_requirements }}
:return: A :py:class:`ModelInfo <mlflow.models.model.ModelInfo>` instance that contains the
metadata of the logged model.
"""
Model.log(
return Model.log(
artifact_path=artifact_path,
flavor=mlflow.statsmodels,
registered_model_name=registered_model_name,
Expand Down
2 changes: 2 additions & 0 deletions mlflow/tensorflow/__init__.py
Expand Up @@ -177,6 +177,8 @@ def log_model(
waits for five minutes. Specify 0 or None to skip waiting.
:param pip_requirements: {{ pip_requirements }}
:param extra_pip_requirements: {{ extra_pip_requirements }}
:return: A :py:class:`ModelInfo <mlflow.models.model.ModelInfo>` instance that contains the
metadata of the logged model.
"""
return Model.log(
artifact_path=artifact_path,
Expand Down
4 changes: 3 additions & 1 deletion mlflow/xgboost/__init__.py
Expand Up @@ -240,8 +240,10 @@ def log_model(
:param pip_requirements: {{ pip_requirements }}
:param extra_pip_requirements: {{ extra_pip_requirements }}
:param kwargs: kwargs to pass to `xgboost.Booster.save_model`_ method.
:return: A :py:class:`ModelInfo <mlflow.models.model.ModelInfo>` instance that contains the
metadata of the logged model.
"""
Model.log(
return Model.log(
artifact_path=artifact_path,
flavor=mlflow.xgboost,
registered_model_name=registered_model_name,
Expand Down
3 changes: 2 additions & 1 deletion tests/catboost/test_catboost_model_export.py
Expand Up @@ -195,8 +195,9 @@ def test_log_model(cb_model, tmpdir):
conda_env = os.path.join(tmpdir.strpath, "conda_env.yaml")
_mlflow_conda_env(conda_env, additional_pip_deps=["catboost"])

mlflow.catboost.log_model(model, artifact_path, conda_env=conda_env)
model_info = mlflow.catboost.log_model(model, artifact_path, conda_env=conda_env)
model_uri = "runs:/{}/{}".format(mlflow.active_run().info.run_id, artifact_path)
assert model_info.model_uri == model_uri

loaded_model = mlflow.catboost.load_model(model_uri)
np.testing.assert_array_almost_equal(
Expand Down
3 changes: 2 additions & 1 deletion tests/fastai/test_fastai_model_export.py
Expand Up @@ -149,13 +149,14 @@ def test_model_log(fastai_model, model_path):
conda_env = os.path.join(tmp.path(), "conda_env.yaml")
_mlflow_conda_env(conda_env, additional_pip_deps=["fastai"])

mlflow.fastai.log_model(
model_info = mlflow.fastai.log_model(
fastai_learner=model, artifact_path=artifact_path, conda_env=conda_env
)

model_uri = "runs:/{run_id}/{artifact_path}".format(
run_id=mlflow.active_run().info.run_id, artifact_path=artifact_path
)
assert model_info.model_uri == model_uri

reloaded_model = mlflow.fastai.load_model(model_uri=model_uri)

Expand Down