Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Return Model from log_model #5230

Merged
merged 12 commits into from Jan 12, 2022
6 changes: 5 additions & 1 deletion mlflow/catboost.py
Expand Up @@ -226,8 +226,12 @@ def log_model(
:param pip_requirements: {{ pip_requirements }}
:param extra_pip_requirements: {{ extra_pip_requirements }}
:param kwargs: kwargs to pass to `CatBoost.save_model`_ method.
:return: A `ModelInfo` namedtuple instance that contains the metadata of the logged model,
including: `run_id`, `artifact_path`, `model_uri`, `utc_time_created`, `flavors`,
`model_uuid`, `saved_input_example_info`, `signature`, `input_schema`, and
`output_schema`.
liangz1 marked this conversation as resolved.
Show resolved Hide resolved
"""
Model.log(
return Model.log(
artifact_path=artifact_path,
flavor=mlflow.catboost,
registered_model_name=registered_model_name,
Expand Down
6 changes: 5 additions & 1 deletion mlflow/fastai/__init__.py
Expand Up @@ -252,6 +252,10 @@ def log_model(
waits for five minutes. Specify 0 or None to skip waiting.
:param pip_requirements: {{ pip_requirements }}
:param extra_pip_requirements: {{ extra_pip_requirements }}
:return: A `ModelInfo` namedtuple instance that contains the metadata of the logged model,
including: `run_id`, `artifact_path`, `model_uri`, `utc_time_created`, `flavors`,
`model_uuid`, `saved_input_example_info`, `signature`, `input_schema`, and
`output_schema`.

.. code-block:: python
:caption: Example
Expand Down Expand Up @@ -287,7 +291,7 @@ def main(epochs=5, learning_rate=0.01):

artifacts: ['model/MLmodel', 'model/conda.yaml', 'model/model.fastai']
"""
Model.log(
return Model.log(
artifact_path=artifact_path,
flavor=mlflow.fastai,
registered_model_name=registered_model_name,
Expand Down
6 changes: 5 additions & 1 deletion mlflow/gluon/__init__.py
Expand Up @@ -307,6 +307,10 @@ def log_model(
by converting it to a list. Bytes are base64-encoded.
:param pip_requirements: {{ pip_requirements }}
:param extra_pip_requirements: {{ extra_pip_requirements }}
:return: A `ModelInfo` namedtuple instance that contains the metadata of the logged model,
including: `run_id`, `artifact_path`, `model_uri`, `utc_time_created`, `flavors`,
`model_uuid`, `saved_input_example_info`, `signature`, `input_schema`, and
`output_schema`.

.. code-block:: python
:caption: Example
Expand All @@ -331,7 +335,7 @@ def log_model(
est.fit(train_data=train_data, epochs=100, val_data=validation_data)
mlflow.gluon.log_model(net, "model")
"""
Model.log(
return Model.log(
artifact_path=artifact_path,
flavor=mlflow.gluon,
gluon_model=gluon_model,
Expand Down
6 changes: 5 additions & 1 deletion mlflow/h2o.py
Expand Up @@ -210,8 +210,12 @@ def log_model(
:param pip_requirements: {{ pip_requirements }}
:param extra_pip_requirements: {{ extra_pip_requirements }}
:param kwargs: kwargs to pass to ``h2o.save_model`` method.
:return: A `ModelInfo` namedtuple instance that contains the metadata of the logged model,
including: `run_id`, `artifact_path`, `model_uri`, `utc_time_created`, `flavors`,
`model_uuid`, `saved_input_example_info`, `signature`, `input_schema`, and
`output_schema`.
"""
Model.log(
return Model.log(
artifact_path=artifact_path,
flavor=mlflow.h2o,
registered_model_name=registered_model_name,
Expand Down
6 changes: 5 additions & 1 deletion mlflow/keras.py
Expand Up @@ -361,6 +361,10 @@ def log_model(
:param pip_requirements: {{ pip_requirements }}
:param extra_pip_requirements: {{ extra_pip_requirements }}
:param kwargs: kwargs to pass to ``keras_model.save`` method.
:return: A `ModelInfo` namedtuple instance that contains the metadata of the logged model,
including: `run_id`, `artifact_path`, `model_uri`, `utc_time_created`, `flavors`,
`model_uuid`, `saved_input_example_info`, `signature`, `input_schema`, and
`output_schema`.

.. code-block:: python
:caption: Example
Expand All @@ -376,7 +380,7 @@ def log_model(
with mlflow.start_run() as run:
mlflow.keras.log_model(keras_model, "models")
"""
Model.log(
return Model.log(
artifact_path=artifact_path,
flavor=mlflow.keras,
keras_model=keras_model,
Expand Down
6 changes: 5 additions & 1 deletion mlflow/lightgbm.py
Expand Up @@ -263,8 +263,12 @@ def log_model(
:param pip_requirements: {{ pip_requirements }}
:param extra_pip_requirements: {{ extra_pip_requirements }}
:param kwargs: kwargs to pass to `lightgbm.Booster.save_model`_ method.
:return: A `ModelInfo` namedtuple instance that contains the metadata of the logged model,
including: `run_id`, `artifact_path`, `model_uri`, `utc_time_created`, `flavors`,
`model_uuid`, `saved_input_example_info`, `signature`, `input_schema`, and
`output_schema`.
"""
Model.log(
return Model.log(
artifact_path=artifact_path,
flavor=mlflow.lightgbm,
registered_model_name=registered_model_name,
Expand Down
5 changes: 5 additions & 0 deletions mlflow/mleap.py
Expand Up @@ -74,6 +74,11 @@ def log_model(
serialized to json using the Pandas split-oriented format. Bytes are
base64-encoded.

:return: A `ModelInfo` namedtuple instance that contains the metadata of the logged model,
including: `run_id`, `artifact_path`, `model_uri`, `utc_time_created`, `flavors`,
`model_uuid`, `saved_input_example_info`, `signature`, `input_schema`, and
`output_schema`.



.. code-block:: python
Expand Down
43 changes: 41 additions & 2 deletions mlflow/models/model.py
@@ -1,3 +1,4 @@
from collections import namedtuple
from datetime import datetime
import json
import logging
Expand All @@ -22,12 +23,29 @@
"Logging model metadata to the tracking server has failed, possibly due older "
"server version. The model artifacts have been logged successfully under %s. "
"In addition to exporting model artifacts, MLflow clients 1.7.0 and above "
"attempt to record model metadata to the tracking store. If logging to a "
"mlflow server via REST, consider upgrading the server version to MLflow "
"attempt to record model metadata to the tracking store. If logging to a "
"mlflow server via REST, consider upgrading the server version to MLflow "
"1.7.0 or above."
)


ModelInfo = namedtuple(
"ModelInfo",
[
"run_id",
"artifact_path",
"model_uri",
"utc_time_created",
"flavors",
"model_uuid",
"saved_input_example_info",
"signature",
"input_schema",
"output_schema",
],
)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added example values for some attributes.


class Model:
"""
An MLflow Model that can support multiple model flavors. Provides APIs for implementing
Expand Down Expand Up @@ -109,6 +127,21 @@ def saved_input_example_info(self, value: Dict[str, Any]):
# pylint: disable=attribute-defined-outside-init
self._saved_input_example_info = value

def get_model_info(self):
"""Create a ModelInfo instance that contains the model metadata."""
return ModelInfo(
run_id=self.run_id,
artifact_path=self.artifact_path,
model_uri="runs:/{}/{}".format(self.run_id, self.artifact_path),
liangz1 marked this conversation as resolved.
Show resolved Hide resolved
utc_time_created=self.utc_time_created,
flavors=self.flavors,
model_uuid=self.model_uuid,
saved_input_example_info=self.saved_input_example_info,
signature=self.signature,
input_schema=self.get_input_schema(),
output_schema=self.get_output_schema(),
)

def to_dict(self):
"""Serialize the model to a dictionary."""
res = {k: v for k, v in self.__dict__.items() if not k.startswith("_")}
Expand Down Expand Up @@ -203,6 +236,11 @@ def log(
waits for five minutes. Specify 0 or None to skip waiting.

:param kwargs: Extra args passed to the model flavor.

:return: A `ModelInfo` namedtuple instance that contains the metadata of the logged model,
including: `run_id`, `artifact_path`, `model_uri`, `utc_time_created`, `flavors`,
`model_uuid`, `saved_input_example_info`, `signature`, `input_schema`, and
`output_schema`.
"""
with TempDir() as tmp:
local_path = tmp.path("model")
Expand All @@ -223,3 +261,4 @@ def log(
registered_model_name,
await_registration_for=await_registration_for,
)
return mlflow_model.get_model_info()
6 changes: 5 additions & 1 deletion mlflow/onnx.py
Expand Up @@ -362,8 +362,12 @@ def log_model(
waits for five minutes. Specify 0 or None to skip waiting.
:param pip_requirements: {{ pip_requirements }}
:param extra_pip_requirements: {{ extra_pip_requirements }}
:return: A `ModelInfo` namedtuple instance that contains the metadata of the logged model,
including: `run_id`, `artifact_path`, `model_uri`, `utc_time_created`, `flavors`,
`model_uuid`, `saved_input_example_info`, `signature`, `input_schema`, and
`output_schema`.
"""
Model.log(
return Model.log(
artifact_path=artifact_path,
flavor=mlflow.onnx,
onnx_model=onnx_model,
Expand Down
4 changes: 4 additions & 0 deletions mlflow/paddle/__init__.py
Expand Up @@ -371,6 +371,10 @@ def log_model(
waits for five minutes. Specify 0 or None to skip waiting.
:param pip_requirements: {{ pip_requirements }}
:param extra_pip_requirements: {{ extra_pip_requirements }}
:return: A `ModelInfo` namedtuple instance that contains the metadata of the logged model,
including: `run_id`, `artifact_path`, `model_uri`, `utc_time_created`, `flavors`,
`model_uuid`, `saved_input_example_info`, `signature`, `input_schema`, and
`output_schema`.

.. code-block:: python
:caption: Example
Expand Down
6 changes: 5 additions & 1 deletion mlflow/prophet.py
Expand Up @@ -221,8 +221,12 @@ def log_model(
Specify 0 or None to skip waiting.
:param pip_requirements: {{ pip_requirements }}
:param extra_pip_requirements: {{ extra_pip_requirements }}
:return: A `ModelInfo` namedtuple instance that contains the metadata of the logged model,
including: `run_id`, `artifact_path`, `model_uri`, `utc_time_created`, `flavors`,
`model_uuid`, `saved_input_example_info`, `signature`, `input_schema`, and
`output_schema`.
"""
Model.log(
return Model.log(
artifact_path=artifact_path,
flavor=mlflow.prophet,
registered_model_name=registered_model_name,
Expand Down
4 changes: 4 additions & 0 deletions mlflow/pyfunc/__init__.py
Expand Up @@ -1199,6 +1199,10 @@ def log_model(
waits for five minutes. Specify 0 or None to skip waiting.
:param pip_requirements: {{ pip_requirements }}
:param extra_pip_requirements: {{ extra_pip_requirements }}
:return: A `ModelInfo` namedtuple instance that contains the metadata of the logged model,
including: `run_id`, `artifact_path`, `model_uri`, `utc_time_created`, `flavors`,
`model_uuid`, `saved_input_example_info`, `signature`, `input_schema`, and
`output_schema`.
"""
return Model.log(
artifact_path=artifact_path,
Expand Down
6 changes: 5 additions & 1 deletion mlflow/pytorch/__init__.py
Expand Up @@ -209,6 +209,10 @@ def log_model(
:param pip_requirements: {{ pip_requirements }}
:param extra_pip_requirements: {{ extra_pip_requirements }}
:param kwargs: kwargs to pass to ``torch.save`` method.
:return: A `ModelInfo` namedtuple instance that contains the metadata of the logged model,
including: `run_id`, `artifact_path`, `model_uri`, `utc_time_created`, `flavors`,
`model_uuid`, `saved_input_example_info`, `signature`, `input_schema`, and
`output_schema`.

.. code-block:: python
:caption: Example
Expand Down Expand Up @@ -283,7 +287,7 @@ def gen_data():
PyTorch logged models
"""
pickle_module = pickle_module or mlflow_pytorch_pickle_module
Model.log(
return Model.log(
artifact_path=artifact_path,
flavor=mlflow.pytorch,
pytorch_model=pytorch_model,
Expand Down
6 changes: 5 additions & 1 deletion mlflow/sklearn/__init__.py
Expand Up @@ -350,6 +350,10 @@ def log_model(
waits for five minutes. Specify 0 or None to skip waiting.
:param pip_requirements: {{ pip_requirements }}
:param extra_pip_requirements: {{ extra_pip_requirements }}
:return: A `ModelInfo` namedtuple instance that contains the metadata of the logged model,
including: `run_id`, `artifact_path`, `model_uri`, `utc_time_created`, `flavors`,
`model_uuid`, `saved_input_example_info`, `signature`, `input_schema`, and
`output_schema`.

.. code-block:: python
:caption: Example
Expand All @@ -371,7 +375,7 @@ def log_model(
# log model
mlflow.sklearn.log_model(sk_model, "sk_models")
"""
Model.log(
return Model.log(
artifact_path=artifact_path,
flavor=mlflow.sklearn,
sk_model=sk_model,
Expand Down
6 changes: 5 additions & 1 deletion mlflow/spacy.py
Expand Up @@ -223,8 +223,12 @@ def log_model(
:param pip_requirements: {{ pip_requirements }}
:param extra_pip_requirements: {{ extra_pip_requirements }}
:param kwargs: kwargs to pass to ``spacy.save_model`` method.
:return: A `ModelInfo` namedtuple instance that contains the metadata of the logged model,
including: `run_id`, `artifact_path`, `model_uri`, `utc_time_created`, `flavors`,
`model_uuid`, `saved_input_example_info`, `signature`, `input_schema`, and
`output_schema`.
"""
Model.log(
return Model.log(
artifact_path=artifact_path,
flavor=mlflow.spacy,
registered_model_name=registered_model_name,
Expand Down
5 changes: 5 additions & 0 deletions mlflow/spark.py
Expand Up @@ -170,6 +170,10 @@ def log_model(
waits for five minutes. Specify 0 or None to skip waiting.
:param pip_requirements: {{ pip_requirements }}
:param extra_pip_requirements: {{ extra_pip_requirements }}
:return: A `ModelInfo` namedtuple instance that contains the metadata of the logged model,
including: `run_id`, `artifact_path`, `model_uri`, `utc_time_created`, `flavors`,
`model_uuid`, `saved_input_example_info`, `signature`, `input_schema`, and
`output_schema`.

.. code-block:: python
:caption: Example
Expand Down Expand Up @@ -259,6 +263,7 @@ def log_model(
registered_model_name,
await_registration_for,
)
return mlflow_model.get_model_info()


def _tmp_path(dfs_tmp):
Expand Down
6 changes: 5 additions & 1 deletion mlflow/statsmodels.py
Expand Up @@ -251,8 +251,12 @@ def log_model(
waits for five minutes. Specify 0 or None to skip waiting.
:param pip_requirements: {{ pip_requirements }}
:param extra_pip_requirements: {{ extra_pip_requirements }}
:return: A `ModelInfo` namedtuple instance that contains the metadata of the logged model,
including: `run_id`, `artifact_path`, `model_uri`, `utc_time_created`, `flavors`,
`model_uuid`, `saved_input_example_info`, `signature`, `input_schema`, and
`output_schema`.
"""
Model.log(
return Model.log(
artifact_path=artifact_path,
flavor=mlflow.statsmodels,
registered_model_name=registered_model_name,
Expand Down
4 changes: 4 additions & 0 deletions mlflow/tensorflow/__init__.py
Expand Up @@ -177,6 +177,10 @@ def log_model(
waits for five minutes. Specify 0 or None to skip waiting.
:param pip_requirements: {{ pip_requirements }}
:param extra_pip_requirements: {{ extra_pip_requirements }}
:return: A `ModelInfo` namedtuple instance that contains the metadata of the logged model,
including: `run_id`, `artifact_path`, `model_uri`, `utc_time_created`, `flavors`,
`model_uuid`, `saved_input_example_info`, `signature`, `input_schema`, and
`output_schema`.
"""
return Model.log(
artifact_path=artifact_path,
Expand Down
6 changes: 5 additions & 1 deletion mlflow/xgboost/__init__.py
Expand Up @@ -240,8 +240,12 @@ def log_model(
:param pip_requirements: {{ pip_requirements }}
:param extra_pip_requirements: {{ extra_pip_requirements }}
:param kwargs: kwargs to pass to `xgboost.Booster.save_model`_ method.
:return: A `ModelInfo` namedtuple instance that contains the metadata of the logged model,
including: `run_id`, `artifact_path`, `model_uri`, `utc_time_created`, `flavors`,
`model_uuid`, `saved_input_example_info`, `signature`, `input_schema`, and
`output_schema`.
"""
Model.log(
return Model.log(
artifact_path=artifact_path,
flavor=mlflow.xgboost,
registered_model_name=registered_model_name,
Expand Down
3 changes: 2 additions & 1 deletion tests/catboost/test_catboost_model_export.py
Expand Up @@ -119,8 +119,9 @@ def test_model_save_load(cb_model, model_path):
def test_log_model_logs_model_type(cb_model):
with mlflow.start_run():
artifact_path = "model"
mlflow.catboost.log_model(cb_model.model, artifact_path)
model_info = mlflow.catboost.log_model(cb_model.model, artifact_path)
model_uri = mlflow.get_artifact_uri(artifact_path)
assert model_info.model_uri == model_uri

flavor_conf = Model.load(model_uri).flavors["catboost"]
assert "model_type" in flavor_conf
Expand Down
3 changes: 2 additions & 1 deletion tests/fastai/test_fastai_model_export.py
Expand Up @@ -149,13 +149,14 @@ def test_model_log(fastai_model, model_path):
conda_env = os.path.join(tmp.path(), "conda_env.yaml")
_mlflow_conda_env(conda_env, additional_pip_deps=["fastai"])

mlflow.fastai.log_model(
model_info = mlflow.fastai.log_model(
fastai_learner=model, artifact_path=artifact_path, conda_env=conda_env
)

model_uri = "runs:/{run_id}/{artifact_path}".format(
run_id=mlflow.active_run().info.run_id, artifact_path=artifact_path
)
assert model_info.model_uri == model_uri

reloaded_model = mlflow.fastai.load_model(model_uri=model_uri)

Expand Down
3 changes: 2 additions & 1 deletion tests/gluon/test_gluon_model_export.py
Expand Up @@ -133,10 +133,11 @@ def test_model_log_load(gluon_model, model_data, model_path):

artifact_path = "model"
with mlflow.start_run():
mlflow.gluon.log_model(gluon_model, artifact_path=artifact_path)
model_info = mlflow.gluon.log_model(gluon_model, artifact_path=artifact_path)
model_uri = "runs:/{run_id}/{artifact_path}".format(
run_id=mlflow.active_run().info.run_id, artifact_path=artifact_path
)
assert model_info.model_uri == model_uri

# Loading Gluon model
model_loaded = mlflow.gluon.load_model(model_uri, ctx.cpu())
Expand Down