diff --git a/mlflow/models/model.py b/mlflow/models/model.py index 1ada7241bd406..806357c10ab4a 100644 --- a/mlflow/models/model.py +++ b/mlflow/models/model.py @@ -6,7 +6,7 @@ import os import uuid -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, Union, Callable import mlflow from mlflow.exceptions import MlflowException @@ -42,7 +42,7 @@ def __init__( flavors=None, signature=None, # ModelSignature saved_input_example_info: Dict[str, Any] = None, - model_uuid=None, + model_uuid: Union[str, Callable, None] = lambda: uuid.uuid4().hex, **kwargs, ): # store model id instead of run_id and path to avoid confusion when model gets exported @@ -54,7 +54,7 @@ def __init__( self.flavors = flavors if flavors is not None else {} self.signature = signature self.saved_input_example_info = saved_input_example_info - self.model_uuid = uuid.uuid4().hex if model_uuid is None else model_uuid + self.model_uuid = model_uuid() if callable(model_uuid) else model_uuid self.__dict__.update(kwargs) def __eq__(self, other): @@ -133,10 +133,13 @@ def from_dict(cls, model_dict): from .signature import ModelSignature + model_dict = model_dict.copy() if "signature" in model_dict and isinstance(model_dict["signature"], dict): - model_dict = model_dict.copy() model_dict["signature"] = ModelSignature.from_dict(model_dict["signature"]) + if "model_uuid" not in model_dict: + model_dict["model_uuid"] = None + return cls(**model_dict) @classmethod diff --git a/tests/models/test_model.py b/tests/models/test_model.py index 3f058c765b87e..131a70004b54a 100644 --- a/tests/models/test_model.py +++ b/tests/models/test_model.py @@ -164,3 +164,31 @@ def test_model_log_with_input_example_succeeds(): # date column will get deserialized into string input_example["d"] = input_example["d"].apply(lambda x: x.isoformat()) assert x.equals(input_example) + + +def _is_valid_uuid(val): + import uuid + + try: + uuid.UUID(str(val)) + return True + except ValueError: + return False + + +def test_model_uuid(): + m = Model() + assert m.model_uuid is not None + assert _is_valid_uuid(m.model_uuid) + + m2 = Model() + assert m.model_uuid != m2.model_uuid + + m_dict = m.to_dict() + assert m_dict["model_uuid"] == m.model_uuid + m3 = Model.from_dict(m_dict) + assert m3.model_uuid == m.model_uuid + + m_dict.pop("model_uuid") + m4 = Model.from_dict(m_dict) + assert m4.model_uuid is None diff --git a/tests/pyfunc/test_model_export_with_loader_module_and_data_path.py b/tests/pyfunc/test_model_export_with_loader_module_and_data_path.py index 6e9baca1e20df..8b4f395096bc8 100644 --- a/tests/pyfunc/test_model_export_with_loader_module_and_data_path.py +++ b/tests/pyfunc/test_model_export_with_loader_module_and_data_path.py @@ -556,26 +556,6 @@ def test_column_schema_enforcement_no_col_names(): assert pyfunc_model.predict(d).equals(pd.DataFrame(d)) -def _is_valid_uuid(val): - import uuid - - try: - uuid.UUID(str(val)) - return True - except ValueError: - return False - - -def test_model_uuid(): - m = Model() - assert m.model_uuid is not None - assert _is_valid_uuid(m.model_uuid) - m_dict = m.to_dict() - assert m_dict["model_uuid"] == m.model_uuid - m2 = Model.from_dict(m_dict) - assert m2.model_uuid == m.model_uuid - - def test_tensor_schema_enforcement_no_col_names(): m = Model() input_schema = Schema([TensorSpec(np.dtype(np.float32), (-1, 3))])