Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Return Model from log_model #5230

Merged
merged 12 commits into from Jan 12, 2022
2 changes: 1 addition & 1 deletion mlflow/keras.py
Expand Up @@ -376,7 +376,7 @@ def log_model(
with mlflow.start_run() as run:
mlflow.keras.log_model(keras_model, "models")
"""
Model.log(
return Model.log(
artifact_path=artifact_path,
flavor=mlflow.keras,
keras_model=keras_model,
Expand Down
2 changes: 1 addition & 1 deletion mlflow/lightgbm.py
Expand Up @@ -264,7 +264,7 @@ def log_model(
:param extra_pip_requirements: {{ extra_pip_requirements }}
:param kwargs: kwargs to pass to `lightgbm.Booster.save_model`_ method.
"""
Model.log(
return Model.log(
artifact_path=artifact_path,
flavor=mlflow.lightgbm,
registered_model_name=registered_model_name,
Expand Down
38 changes: 36 additions & 2 deletions mlflow/models/model.py
@@ -1,3 +1,4 @@
from collections import namedtuple
from datetime import datetime
import json
import logging
Expand All @@ -22,12 +23,29 @@
"Logging model metadata to the tracking server has failed, possibly due older "
"server version. The model artifacts have been logged successfully under %s. "
"In addition to exporting model artifacts, MLflow clients 1.7.0 and above "
"attempt to record model metadata to the tracking store. If logging to a "
"mlflow server via REST, consider upgrading the server version to MLflow "
"attempt to record model metadata to the tracking store. If logging to a "
"mlflow server via REST, consider upgrading the server version to MLflow "
"1.7.0 or above."
)


ModelInfo = namedtuple(
"ModelInfo",
[
"run_id",
"artifact_path",
"model_uri",
"utc_time_created",
"flavors",
"model_uuid",
"saved_input_example_info",
"signature",
"input_schema",
"output_schema",
],
)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added example values for some attributes.


class Model(object):
"""
An MLflow Model that can support multiple model flavors. Provides APIs for implementing
Expand Down Expand Up @@ -109,6 +127,21 @@ def saved_input_example_info(self, value: Dict[str, Any]):
# pylint: disable=attribute-defined-outside-init
self._saved_input_example_info = value

def get_model_info(self) -> ModelInfo:
liangz1 marked this conversation as resolved.
Show resolved Hide resolved
"""Create a ModelInfo instance that contains the model metadata."""
return ModelInfo(
run_id=self.run_id,
artifact_path=self.artifact_path,
model_uri="runs:/{}/{}".format(self.run_id, self.artifact_path),
liangz1 marked this conversation as resolved.
Show resolved Hide resolved
utc_time_created=self.utc_time_created,
flavors=self.flavors,
model_uuid=self.model_uuid,
saved_input_example_info=self.saved_input_example_info,
signature=self.signature,
input_schema=self.get_input_schema(),
output_schema=self.get_output_schema(),
)

def to_dict(self):
"""Serialize the model to a dictionary."""
res = {k: v for k, v in self.__dict__.items() if not k.startswith("_")}
Expand Down Expand Up @@ -223,3 +256,4 @@ def log(
registered_model_name,
await_registration_for=await_registration_for,
)
return mlflow_model.get_model_info()
2 changes: 1 addition & 1 deletion mlflow/pytorch/__init__.py
Expand Up @@ -283,7 +283,7 @@ def gen_data():
PyTorch logged models
"""
pickle_module = pickle_module or mlflow_pytorch_pickle_module
Model.log(
return Model.log(
artifact_path=artifact_path,
flavor=mlflow.pytorch,
pytorch_model=pytorch_model,
Expand Down
2 changes: 1 addition & 1 deletion mlflow/sklearn/__init__.py
Expand Up @@ -371,7 +371,7 @@ def log_model(
# log model
mlflow.sklearn.log_model(sk_model, "sk_models")
"""
Model.log(
return Model.log(
artifact_path=artifact_path,
flavor=mlflow.sklearn,
sk_model=sk_model,
Expand Down
2 changes: 1 addition & 1 deletion mlflow/xgboost/__init__.py
Expand Up @@ -241,7 +241,7 @@ def log_model(
:param extra_pip_requirements: {{ extra_pip_requirements }}
:param kwargs: kwargs to pass to `xgboost.Booster.save_model`_ method.
"""
Model.log(
return Model.log(
artifact_path=artifact_path,
flavor=mlflow.xgboost,
registered_model_name=registered_model_name,
Expand Down