Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Return Model from log_model #5230

Merged
merged 12 commits into from Jan 12, 2022
3 changes: 3 additions & 0 deletions docs/source/python_api/mlflow.models.rst
Expand Up @@ -6,3 +6,6 @@ mlflow.models
:undoc-members:
:show-inheritance:

.. autoclass:: mlflow.models.model.ModelInfo
:members:
:undoc-members:
4 changes: 3 additions & 1 deletion mlflow/catboost.py
Expand Up @@ -226,8 +226,10 @@ def log_model(
:param pip_requirements: {{ pip_requirements }}
:param extra_pip_requirements: {{ extra_pip_requirements }}
:param kwargs: kwargs to pass to `CatBoost.save_model`_ method.
:return: A :py:class:`ModelInfo <mlflow.models.model.ModelInfo>` instance that contains the
metadata of the logged model.
"""
Model.log(
return Model.log(
artifact_path=artifact_path,
flavor=mlflow.catboost,
registered_model_name=registered_model_name,
Expand Down
4 changes: 3 additions & 1 deletion mlflow/fastai/__init__.py
Expand Up @@ -252,6 +252,8 @@ def log_model(
waits for five minutes. Specify 0 or None to skip waiting.
:param pip_requirements: {{ pip_requirements }}
:param extra_pip_requirements: {{ extra_pip_requirements }}
:return: A :py:class:`ModelInfo <mlflow.models.model.ModelInfo>` instance that contains the
metadata of the logged model.

.. code-block:: python
:caption: Example
Expand Down Expand Up @@ -287,7 +289,7 @@ def main(epochs=5, learning_rate=0.01):

artifacts: ['model/MLmodel', 'model/conda.yaml', 'model/model.fastai']
"""
Model.log(
return Model.log(
artifact_path=artifact_path,
flavor=mlflow.fastai,
registered_model_name=registered_model_name,
Expand Down
4 changes: 3 additions & 1 deletion mlflow/gluon/__init__.py
Expand Up @@ -307,6 +307,8 @@ def log_model(
by converting it to a list. Bytes are base64-encoded.
:param pip_requirements: {{ pip_requirements }}
:param extra_pip_requirements: {{ extra_pip_requirements }}
:return: A :py:class:`ModelInfo <mlflow.models.model.ModelInfo>` instance that contains the
metadata of the logged model.

.. code-block:: python
:caption: Example
Expand All @@ -331,7 +333,7 @@ def log_model(
est.fit(train_data=train_data, epochs=100, val_data=validation_data)
mlflow.gluon.log_model(net, "model")
"""
Model.log(
return Model.log(
artifact_path=artifact_path,
flavor=mlflow.gluon,
gluon_model=gluon_model,
Expand Down
4 changes: 3 additions & 1 deletion mlflow/h2o.py
Expand Up @@ -210,8 +210,10 @@ def log_model(
:param pip_requirements: {{ pip_requirements }}
:param extra_pip_requirements: {{ extra_pip_requirements }}
:param kwargs: kwargs to pass to ``h2o.save_model`` method.
:return: A :py:class:`ModelInfo <mlflow.models.model.ModelInfo>` instance that contains the
metadata of the logged model.
"""
Model.log(
return Model.log(
artifact_path=artifact_path,
flavor=mlflow.h2o,
registered_model_name=registered_model_name,
Expand Down
4 changes: 3 additions & 1 deletion mlflow/keras.py
Expand Up @@ -361,6 +361,8 @@ def log_model(
:param pip_requirements: {{ pip_requirements }}
:param extra_pip_requirements: {{ extra_pip_requirements }}
:param kwargs: kwargs to pass to ``keras_model.save`` method.
:return: A :py:class:`ModelInfo <mlflow.models.model.ModelInfo>` instance that contains the
metadata of the logged model.

.. code-block:: python
:caption: Example
Expand All @@ -376,7 +378,7 @@ def log_model(
with mlflow.start_run() as run:
mlflow.keras.log_model(keras_model, "models")
"""
Model.log(
return Model.log(
artifact_path=artifact_path,
flavor=mlflow.keras,
keras_model=keras_model,
Expand Down
4 changes: 3 additions & 1 deletion mlflow/lightgbm.py
Expand Up @@ -263,8 +263,10 @@ def log_model(
:param pip_requirements: {{ pip_requirements }}
:param extra_pip_requirements: {{ extra_pip_requirements }}
:param kwargs: kwargs to pass to `lightgbm.Booster.save_model`_ method.
:return: A :py:class:`ModelInfo <mlflow.models.model.ModelInfo>` instance that contains the
metadata of the logged model.
"""
Model.log(
return Model.log(
artifact_path=artifact_path,
flavor=mlflow.lightgbm,
registered_model_name=registered_model_name,
Expand Down
2 changes: 2 additions & 0 deletions mlflow/mleap.py
Expand Up @@ -74,6 +74,8 @@ def log_model(
serialized to json using the Pandas split-oriented format. Bytes are
base64-encoded.

:return: A :py:class:`ModelInfo <mlflow.models.model.ModelInfo>` instance that contains the
metadata of the logged model.


.. code-block:: python
Expand Down
55 changes: 49 additions & 6 deletions mlflow/models/model.py
Expand Up @@ -6,13 +6,14 @@
import os
import uuid

from typing import Any, Dict, Optional, Union, Callable
from typing import Any, Dict, Optional, Union, Callable, NamedTuple

import mlflow
from mlflow.exceptions import MlflowException
from mlflow.utils.file_utils import TempDir
from mlflow.utils.databricks_utils import get_databricks_runtime
from mlflow.tracking._model_registry import DEFAULT_AWAIT_MAX_SLEEP_SECONDS
from mlflow.models.signature import ModelSignature

_logger = logging.getLogger(__name__)

Expand All @@ -22,12 +23,37 @@
"Logging model metadata to the tracking server has failed, possibly due older "
"server version. The model artifacts have been logged successfully under %s. "
"In addition to exporting model artifacts, MLflow clients 1.7.0 and above "
"attempt to record model metadata to the tracking store. If logging to a "
"mlflow server via REST, consider upgrading the server version to MLflow "
"attempt to record model metadata to the tracking store. If logging to a "
"mlflow server via REST, consider upgrading the server version to MLflow "
"1.7.0 or above."
)


class ModelInfo(NamedTuple):
run_id: str
artifact_path: str
model_uri: str
utc_time_created: str
flavors: Dict[str, Any]
model_uuid: str
saved_input_example_info: Dict[str, Any]
signature: ModelSignature
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
signature: ModelSignature
signature: Optional[ModelSignature]

Can we use Optional for properties that can be None?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I changed it to a native python dictionary that is ModelSignature.to_dict() to avoid importing pandas in mlflow skinny. Another reason is that for the user's information, there is no need to return the ModelSignature object. A dictionary should be sufficient for the purpose of information reading.



ModelInfo.__doc__ = "The metadata of a logged MLflow Model."
ModelInfo.run_id.__doc__ = "The ``run_id`` associated with the logged model."
ModelInfo.artifact_path.__doc__ = "Run relative path identifying the logged model."
ModelInfo.model_uri.__doc__ = "The ``model_uri`` of the logged model."
ModelInfo.utc_time_created.__doc__ = "The UTC time that the logged model is created."
ModelInfo.flavors.__doc__ = "Flavor module to save the model with."
ModelInfo.model_uuid.__doc__ = "The ``model_uuid`` of the logged model."
ModelInfo.saved_input_example_info.__doc__ = ""
ModelInfo.signature.__doc__ = (
"A :py:class:`ModelSignature <mlflow.models.ModelSignature>` that "
liangz1 marked this conversation as resolved.
Show resolved Hide resolved
liangz1 marked this conversation as resolved.
Show resolved Hide resolved
"describes the model input and output :py:class:`Schema <mlflow.types.Schema>`."
)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added example values for some attributes.


class Model:
"""
An MLflow Model that can support multiple model flavors. Provides APIs for implementing
Expand Down Expand Up @@ -109,6 +135,22 @@ def saved_input_example_info(self, value: Dict[str, Any]):
# pylint: disable=attribute-defined-outside-init
self._saved_input_example_info = value

def get_model_info(self):
"""
Create a :py:class:`ModelInfo <mlflow.models.model.ModelInfo>` instance that contains the
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@harupy Also,

:py:class:`ModelInfo <mlflow.models.model.ModelInfo>`

works while the following does not:

:py:class:`ModelInfo`

Thanks for helping!

model metadata.
"""
return ModelInfo(
run_id=self.run_id,
artifact_path=self.artifact_path,
model_uri="runs:/{}/{}".format(self.run_id, self.artifact_path),
liangz1 marked this conversation as resolved.
Show resolved Hide resolved
utc_time_created=self.utc_time_created,
flavors=self.flavors,
model_uuid=self.model_uuid,
saved_input_example_info=self.saved_input_example_info,
signature=self.signature,
)

def to_dict(self):
"""Serialize the model to a dictionary."""
res = {k: v for k, v in self.__dict__.items() if not k.startswith("_")}
Expand Down Expand Up @@ -148,9 +190,6 @@ def load(cls, path):
@classmethod
def from_dict(cls, model_dict):
"""Load a model from its YAML representation."""

from .signature import ModelSignature

model_dict = model_dict.copy()
if "signature" in model_dict and isinstance(model_dict["signature"], dict):
model_dict["signature"] = ModelSignature.from_dict(model_dict["signature"])
Expand Down Expand Up @@ -203,6 +242,9 @@ def log(
waits for five minutes. Specify 0 or None to skip waiting.

:param kwargs: Extra args passed to the model flavor.

:return: A :py:class:`ModelInfo <mlflow.models.model.ModelInfo>` instance that contains the
metadata of the logged model.
"""
with TempDir() as tmp:
local_path = tmp.path("model")
Expand All @@ -223,3 +265,4 @@ def log(
registered_model_name,
await_registration_for=await_registration_for,
)
return mlflow_model.get_model_info()
4 changes: 3 additions & 1 deletion mlflow/onnx.py
Expand Up @@ -362,8 +362,10 @@ def log_model(
waits for five minutes. Specify 0 or None to skip waiting.
:param pip_requirements: {{ pip_requirements }}
:param extra_pip_requirements: {{ extra_pip_requirements }}
:return: A :py:class:`ModelInfo <mlflow.models.model.ModelInfo>` instance that contains the
metadata of the logged model.
"""
Model.log(
return Model.log(
artifact_path=artifact_path,
flavor=mlflow.onnx,
onnx_model=onnx_model,
Expand Down
2 changes: 2 additions & 0 deletions mlflow/paddle/__init__.py
Expand Up @@ -371,6 +371,8 @@ def log_model(
waits for five minutes. Specify 0 or None to skip waiting.
:param pip_requirements: {{ pip_requirements }}
:param extra_pip_requirements: {{ extra_pip_requirements }}
:return: A :py:class:`ModelInfo <mlflow.models.model.ModelInfo>` instance that contains the
metadata of the logged model.

.. code-block:: python
:caption: Example
Expand Down
4 changes: 3 additions & 1 deletion mlflow/prophet.py
Expand Up @@ -221,8 +221,10 @@ def log_model(
Specify 0 or None to skip waiting.
:param pip_requirements: {{ pip_requirements }}
:param extra_pip_requirements: {{ extra_pip_requirements }}
:return: A :py:class:`ModelInfo <mlflow.models.model.ModelInfo>` instance that contains the
metadata of the logged model.
"""
Model.log(
return Model.log(
artifact_path=artifact_path,
flavor=mlflow.prophet,
registered_model_name=registered_model_name,
Expand Down
2 changes: 2 additions & 0 deletions mlflow/pyfunc/__init__.py
Expand Up @@ -1199,6 +1199,8 @@ def log_model(
waits for five minutes. Specify 0 or None to skip waiting.
:param pip_requirements: {{ pip_requirements }}
:param extra_pip_requirements: {{ extra_pip_requirements }}
:return: A :py:class:`ModelInfo <mlflow.models.model.ModelInfo>` instance that contains the
metadata of the logged model.
"""
return Model.log(
artifact_path=artifact_path,
Expand Down
4 changes: 3 additions & 1 deletion mlflow/pytorch/__init__.py
Expand Up @@ -209,6 +209,8 @@ def log_model(
:param pip_requirements: {{ pip_requirements }}
:param extra_pip_requirements: {{ extra_pip_requirements }}
:param kwargs: kwargs to pass to ``torch.save`` method.
:return: A :py:class:`ModelInfo <mlflow.models.model.ModelInfo>` instance that contains the
metadata of the logged model.

.. code-block:: python
:caption: Example
Expand Down Expand Up @@ -283,7 +285,7 @@ def gen_data():
PyTorch logged models
"""
pickle_module = pickle_module or mlflow_pytorch_pickle_module
Model.log(
return Model.log(
artifact_path=artifact_path,
flavor=mlflow.pytorch,
pytorch_model=pytorch_model,
Expand Down
4 changes: 3 additions & 1 deletion mlflow/sklearn/__init__.py
Expand Up @@ -350,6 +350,8 @@ def log_model(
waits for five minutes. Specify 0 or None to skip waiting.
:param pip_requirements: {{ pip_requirements }}
:param extra_pip_requirements: {{ extra_pip_requirements }}
:return: A :py:class:`ModelInfo <mlflow.models.model.ModelInfo>` instance that contains the
metadata of the logged model.

.. code-block:: python
:caption: Example
Expand All @@ -371,7 +373,7 @@ def log_model(
# log model
mlflow.sklearn.log_model(sk_model, "sk_models")
"""
Model.log(
return Model.log(
artifact_path=artifact_path,
flavor=mlflow.sklearn,
sk_model=sk_model,
Expand Down
4 changes: 3 additions & 1 deletion mlflow/spacy.py
Expand Up @@ -223,8 +223,10 @@ def log_model(
:param pip_requirements: {{ pip_requirements }}
:param extra_pip_requirements: {{ extra_pip_requirements }}
:param kwargs: kwargs to pass to ``spacy.save_model`` method.
:return: A :py:class:`ModelInfo <mlflow.models.model.ModelInfo>` instance that contains the
metadata of the logged model.
"""
Model.log(
return Model.log(
artifact_path=artifact_path,
flavor=mlflow.spacy,
registered_model_name=registered_model_name,
Expand Down
3 changes: 3 additions & 0 deletions mlflow/spark.py
Expand Up @@ -170,6 +170,8 @@ def log_model(
waits for five minutes. Specify 0 or None to skip waiting.
:param pip_requirements: {{ pip_requirements }}
:param extra_pip_requirements: {{ extra_pip_requirements }}
:return: A :py:class:`ModelInfo <mlflow.models.model.ModelInfo>` instance that contains the
metadata of the logged model.

.. code-block:: python
:caption: Example
Expand Down Expand Up @@ -259,6 +261,7 @@ def log_model(
registered_model_name,
await_registration_for,
)
return mlflow_model.get_model_info()


def _tmp_path(dfs_tmp):
Expand Down
4 changes: 3 additions & 1 deletion mlflow/statsmodels.py
Expand Up @@ -251,8 +251,10 @@ def log_model(
waits for five minutes. Specify 0 or None to skip waiting.
:param pip_requirements: {{ pip_requirements }}
:param extra_pip_requirements: {{ extra_pip_requirements }}
:return: A :py:class:`ModelInfo <mlflow.models.model.ModelInfo>` instance that contains the
metadata of the logged model.
"""
Model.log(
return Model.log(
artifact_path=artifact_path,
flavor=mlflow.statsmodels,
registered_model_name=registered_model_name,
Expand Down
2 changes: 2 additions & 0 deletions mlflow/tensorflow/__init__.py
Expand Up @@ -177,6 +177,8 @@ def log_model(
waits for five minutes. Specify 0 or None to skip waiting.
:param pip_requirements: {{ pip_requirements }}
:param extra_pip_requirements: {{ extra_pip_requirements }}
:return: A :py:class:`ModelInfo <mlflow.models.model.ModelInfo>` instance that contains the
metadata of the logged model.
"""
return Model.log(
artifact_path=artifact_path,
Expand Down
4 changes: 3 additions & 1 deletion mlflow/xgboost/__init__.py
Expand Up @@ -240,8 +240,10 @@ def log_model(
:param pip_requirements: {{ pip_requirements }}
:param extra_pip_requirements: {{ extra_pip_requirements }}
:param kwargs: kwargs to pass to `xgboost.Booster.save_model`_ method.
:return: A :py:class:`ModelInfo <mlflow.models.model.ModelInfo>` instance that contains the
metadata of the logged model.
"""
Model.log(
return Model.log(
artifact_path=artifact_path,
flavor=mlflow.xgboost,
registered_model_name=registered_model_name,
Expand Down
3 changes: 2 additions & 1 deletion tests/catboost/test_catboost_model_export.py
Expand Up @@ -195,8 +195,9 @@ def test_log_model(cb_model, tmpdir):
conda_env = os.path.join(tmpdir.strpath, "conda_env.yaml")
_mlflow_conda_env(conda_env, additional_pip_deps=["catboost"])

mlflow.catboost.log_model(model, artifact_path, conda_env=conda_env)
model_info = mlflow.catboost.log_model(model, artifact_path, conda_env=conda_env)
model_uri = "runs:/{}/{}".format(mlflow.active_run().info.run_id, artifact_path)
assert model_info.model_uri == model_uri

loaded_model = mlflow.catboost.load_model(model_uri)
np.testing.assert_array_almost_equal(
Expand Down
3 changes: 2 additions & 1 deletion tests/fastai/test_fastai_model_export.py
Expand Up @@ -149,13 +149,14 @@ def test_model_log(fastai_model, model_path):
conda_env = os.path.join(tmp.path(), "conda_env.yaml")
_mlflow_conda_env(conda_env, additional_pip_deps=["fastai"])

mlflow.fastai.log_model(
model_info = mlflow.fastai.log_model(
fastai_learner=model, artifact_path=artifact_path, conda_env=conda_env
)

model_uri = "runs:/{run_id}/{artifact_path}".format(
run_id=mlflow.active_run().info.run_id, artifact_path=artifact_path
)
assert model_info.model_uri == model_uri

reloaded_model = mlflow.fastai.load_model(model_uri=model_uri)

Expand Down
3 changes: 2 additions & 1 deletion tests/gluon/test_gluon_model_export.py
Expand Up @@ -133,10 +133,11 @@ def test_model_log_load(gluon_model, model_data, model_path):

artifact_path = "model"
with mlflow.start_run():
mlflow.gluon.log_model(gluon_model, artifact_path=artifact_path)
model_info = mlflow.gluon.log_model(gluon_model, artifact_path=artifact_path)
model_uri = "runs:/{run_id}/{artifact_path}".format(
run_id=mlflow.active_run().info.run_id, artifact_path=artifact_path
)
assert model_info.model_uri == model_uri

# Loading Gluon model
model_loaded = mlflow.gluon.load_model(model_uri, ctx.cpu())
Expand Down