Skip to content

Commit

Permalink
Only generate model uuid when logging model (#5167)
Browse files Browse the repository at this point in the history
* init

Signed-off-by: Weichen Xu <weichen.xu@databricks.com>

* add test

Signed-off-by: Weichen Xu <weichen.xu@databricks.com>

* update

Signed-off-by: Weichen Xu <weichen.xu@databricks.com>

* update

Signed-off-by: Weichen Xu <weichen.xu@databricks.com>

* update

Signed-off-by: Weichen Xu <weichen.xu@databricks.com>

* updates

Signed-off-by: Weichen Xu <weichen.xu@databricks.com>
  • Loading branch information
WeichenXu123 committed Dec 16, 2021
1 parent 2668809 commit f9046f9
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 24 deletions.
11 changes: 7 additions & 4 deletions mlflow/models/model.py
Expand Up @@ -6,7 +6,7 @@
import os
import uuid

from typing import Any, Dict, Optional
from typing import Any, Dict, Optional, Union, Callable

import mlflow
from mlflow.exceptions import MlflowException
Expand Down Expand Up @@ -42,7 +42,7 @@ def __init__(
flavors=None,
signature=None, # ModelSignature
saved_input_example_info: Dict[str, Any] = None,
model_uuid=None,
model_uuid: Union[str, Callable, None] = lambda: uuid.uuid4().hex,
**kwargs,
):
# store model id instead of run_id and path to avoid confusion when model gets exported
Expand All @@ -54,7 +54,7 @@ def __init__(
self.flavors = flavors if flavors is not None else {}
self.signature = signature
self.saved_input_example_info = saved_input_example_info
self.model_uuid = uuid.uuid4().hex if model_uuid is None else model_uuid
self.model_uuid = model_uuid() if callable(model_uuid) else model_uuid
self.__dict__.update(kwargs)

def __eq__(self, other):
Expand Down Expand Up @@ -133,10 +133,13 @@ def from_dict(cls, model_dict):

from .signature import ModelSignature

model_dict = model_dict.copy()
if "signature" in model_dict and isinstance(model_dict["signature"], dict):
model_dict = model_dict.copy()
model_dict["signature"] = ModelSignature.from_dict(model_dict["signature"])

if "model_uuid" not in model_dict:
model_dict["model_uuid"] = None

return cls(**model_dict)

@classmethod
Expand Down
28 changes: 28 additions & 0 deletions tests/models/test_model.py
Expand Up @@ -164,3 +164,31 @@ def test_model_log_with_input_example_succeeds():
# date column will get deserialized into string
input_example["d"] = input_example["d"].apply(lambda x: x.isoformat())
assert x.equals(input_example)


def _is_valid_uuid(val):
import uuid

try:
uuid.UUID(str(val))
return True
except ValueError:
return False


def test_model_uuid():
m = Model()
assert m.model_uuid is not None
assert _is_valid_uuid(m.model_uuid)

m2 = Model()
assert m.model_uuid != m2.model_uuid

m_dict = m.to_dict()
assert m_dict["model_uuid"] == m.model_uuid
m3 = Model.from_dict(m_dict)
assert m3.model_uuid == m.model_uuid

m_dict.pop("model_uuid")
m4 = Model.from_dict(m_dict)
assert m4.model_uuid is None
20 changes: 0 additions & 20 deletions tests/pyfunc/test_model_export_with_loader_module_and_data_path.py
Expand Up @@ -556,26 +556,6 @@ def test_column_schema_enforcement_no_col_names():
assert pyfunc_model.predict(d).equals(pd.DataFrame(d))


def _is_valid_uuid(val):
import uuid

try:
uuid.UUID(str(val))
return True
except ValueError:
return False


def test_model_uuid():
m = Model()
assert m.model_uuid is not None
assert _is_valid_uuid(m.model_uuid)
m_dict = m.to_dict()
assert m_dict["model_uuid"] == m.model_uuid
m2 = Model.from_dict(m_dict)
assert m2.model_uuid == m.model_uuid


def test_tensor_schema_enforcement_no_col_names():
m = Model()
input_schema = Schema([TensorSpec(np.dtype(np.float32), (-1, 3))])
Expand Down

0 comments on commit f9046f9

Please sign in to comment.