diff --git a/mlflow/java/client/src/test/java/org/mlflow/tracking/ModelRegistryMlflowClientTest.java b/mlflow/java/client/src/test/java/org/mlflow/tracking/ModelRegistryMlflowClientTest.java index 03acb548c428d..361da5ab6f11e 100644 --- a/mlflow/java/client/src/test/java/org/mlflow/tracking/ModelRegistryMlflowClientTest.java +++ b/mlflow/java/client/src/test/java/org/mlflow/tracking/ModelRegistryMlflowClientTest.java @@ -87,9 +87,10 @@ public void testGetLatestModelVersions() throws IOException { client.sendPatch("model-versions/update", mapper.makeUpdateModelVersion(modelName, "1")); - // default stages (does not include "None") + // get the latest version of all stages List modelVersion = client.getLatestVersions(modelName); - Assert.assertEquals(modelVersion.size(), 0); + Assert.assertEquals(modelVersion.size(), 1); + validateDetailedModelVersion(modelVersion.get(0), modelName, "None", "1"); client.sendPost("model-versions/transition-stage", mapper.makeTransitionModelVersionStage(modelName, "1", "Staging")); modelVersion = client.getLatestVersions(modelName); diff --git a/mlflow/store/artifact/databricks_models_artifact_repo.py b/mlflow/store/artifact/databricks_models_artifact_repo.py index 4f04e0f7f86cf..8b2fee170ffc0 100644 --- a/mlflow/store/artifact/databricks_models_artifact_repo.py +++ b/mlflow/store/artifact/databricks_models_artifact_repo.py @@ -31,7 +31,8 @@ class DatabricksModelsArtifactRepository(ArtifactRepository): The artifact_uri is expected to be of the form - `models://` - `models://` (refers to the latest model version in the given stage) - - `models:////` + - `models://latest` (refers to the latest of all model versions) + - `models:////` Note : This artifact repository is meant is to be instantiated by the ModelsArtifactRepository when the client is pointing to a Databricks-hosted model registry. diff --git a/mlflow/store/artifact/models_artifact_repo.py b/mlflow/store/artifact/models_artifact_repo.py index 3f0dcbcb708c0..50e505b64a191 100644 --- a/mlflow/store/artifact/models_artifact_repo.py +++ b/mlflow/store/artifact/models_artifact_repo.py @@ -19,6 +19,7 @@ class ModelsArtifactRepository(ArtifactRepository): Handles artifacts associated with a model version in the model registry via URIs of the form: - `models://` - `models://` (refers to the latest model version in the given stage) + - `models://latest` (refers to the latest of all model versions) It is a light wrapper that resolves the artifact path to an absolute URI then instantiates and uses the artifact repository for that URI. """ diff --git a/mlflow/store/artifact/utils/models.py b/mlflow/store/artifact/utils/models.py index 97e4e231dfd42..43f31610a41f4 100644 --- a/mlflow/store/artifact/utils/models.py +++ b/mlflow/store/artifact/utils/models.py @@ -3,6 +3,9 @@ import mlflow.tracking from mlflow.exceptions import MlflowException from mlflow.utils.uri import get_databricks_profile_uri_from_artifact_uri, is_databricks_uri +from mlflow.entities.model_registry.model_version_stages import ALL_STAGES + +_MODELS_URI_SUFFIX_LATEST = "latest" def is_using_databricks_registry(uri): @@ -13,24 +16,34 @@ def is_using_databricks_registry(uri): def _improper_model_uri_msg(uri): return ( "Not a proper models:/ URI: %s. " % uri - + "Models URIs must be of the form 'models://'." + + "Models URIs must be of the form 'models://suffix' " + + "where suffix is a model version, stage, or the string '%s'." % _MODELS_URI_SUFFIX_LATEST ) -def _get_model_version_from_stage(client, name, stage): - latest = client.get_latest_versions(name, [stage]) +def _get_latest_model_version(client, name, stage): + """ + Returns the latest version of the stage if stage is not None. Otherwise return the latest of all + versions. + """ + latest = client.get_latest_versions(name, None if stage is None else [stage]) if len(latest) == 0: + stage_str = "" if stage is None else " and stage '{stage}'".format(stage=stage) raise MlflowException( - "No versions of model with name '{name}' and " - "stage '{stage}' found".format(name=name, stage=stage) + "No versions of model with name '{name}'{stage_str} found".format( + name=name, stage_str=stage_str + ) ) - return latest[0].version + return max(map(lambda x: int(x.version), latest)) def _parse_model_uri(uri): """ - Returns (name, version, stage). Since a models:/ URI can only have one of {version, stage}, - it will return (name, version, None) or (name, None, stage). + Returns (name, version, stage). Since a models:/ URI can only have one of + {version, stage, 'latest'}, it will return + - (name, version, None) to look for a specific version, + - (name, None, stage) to look for the latest version of a stage, + - (name, None, None) to look for the latest of all versions. """ parsed = urllib.parse.urlparse(uri) if parsed.scheme != "models": @@ -45,13 +58,20 @@ def _parse_model_uri(uri): raise MlflowException(_improper_model_uri_msg(uri)) if parts[1].isdigit(): + # The suffix is a specific version, e.g. "models:/AdsModel1/123" return parts[0], int(parts[1]), None + elif parts[1] == _MODELS_URI_SUFFIX_LATEST: + # The suffix is exactly the 'latest' string, e.g. "models:/AdsModel1/latest" + return parts[0], None, None + elif parts[1] not in ALL_STAGES: + raise MlflowException(_improper_model_uri_msg(uri)) else: + # The suffix is a specific stage, e.g. "models:/AdsModel1/Production" return parts[0], None, parts[1] def get_model_name_and_version(client, models_uri): (model_name, model_version, model_stage) = _parse_model_uri(models_uri) - if model_stage is not None: - model_version = _get_model_version_from_stage(client, model_name, model_stage) - return model_name, str(model_version) + if model_version is not None: + return model_name, str(model_version) + return model_name, str(_get_latest_model_version(client, model_name, model_stage)) diff --git a/mlflow/store/model_registry/abstract_store.py b/mlflow/store/model_registry/abstract_store.py index 2348e7319179e..6c21f4419a40c 100644 --- a/mlflow/store/model_registry/abstract_store.py +++ b/mlflow/store/model_registry/abstract_store.py @@ -116,7 +116,7 @@ def get_latest_versions(self, name, stages=None): :param name: Registered model name. :param stages: List of desired stages. If input list is None, return latest versions for - for 'Staging' and 'Production' stages. + each stage. :return: List of :py:class:`mlflow.entities.model_registry.ModelVersion` objects. """ pass diff --git a/mlflow/store/model_registry/rest_store.py b/mlflow/store/model_registry/rest_store.py index d6215066f2960..de383c8576e01 100644 --- a/mlflow/store/model_registry/rest_store.py +++ b/mlflow/store/model_registry/rest_store.py @@ -184,7 +184,7 @@ def get_latest_versions(self, name, stages=None): :param name: Registered model name. :param stages: List of desired stages. If input list is None, return latest versions for - for 'Staging' and 'Production' stages. + each stage. :return: List of :py:class:`mlflow.entities.model_registry.ModelVersion` objects. """ req_body = message_to_json(GetLatestVersions(name=name, stages=stages)) diff --git a/mlflow/store/model_registry/sqlalchemy_store.py b/mlflow/store/model_registry/sqlalchemy_store.py index bc39eccb4d6e3..c9534b1324712 100644 --- a/mlflow/store/model_registry/sqlalchemy_store.py +++ b/mlflow/store/model_registry/sqlalchemy_store.py @@ -5,6 +5,7 @@ from mlflow.entities.model_registry.model_version_stages import ( get_canonical_stage, + ALL_STAGES, DEFAULT_STAGES_FOR_GET_LATEST_VERSIONS, STAGE_DELETED_INTERNAL, STAGE_ARCHIVED, @@ -432,7 +433,7 @@ def get_latest_versions(self, name, stages=None): :param name: Registered model name. :param stages: List of desired stages. If input list is None, return latest versions for - for 'Staging' and 'Production' stages. + each stage. :return: List of :py:class:`mlflow.entities.model_registry.ModelVersion` objects. """ with self.ManagedSessionMaker() as session: @@ -440,9 +441,7 @@ def get_latest_versions(self, name, stages=None): # Convert to RegisteredModel entity first and then extract latest_versions latest_versions = sql_registered_model.to_mlflow_entity().latest_versions if stages is None or len(stages) == 0: - expected_stages = set( - [get_canonical_stage(stage) for stage in DEFAULT_STAGES_FOR_GET_LATEST_VERSIONS] - ) + expected_stages = set([get_canonical_stage(stage) for stage in ALL_STAGES]) else: expected_stages = set([get_canonical_stage(stage) for stage in stages]) return [mv for mv in latest_versions if mv.current_stage in expected_stages] diff --git a/tests/store/artifact/utils/test_model_utils.py b/tests/store/artifact/utils/test_model_utils.py index 496aaaadd2319..015b18252ab1d 100644 --- a/tests/store/artifact/utils/test_model_utils.py +++ b/tests/store/artifact/utils/test_model_utils.py @@ -1,7 +1,10 @@ import pytest +from unittest import mock from mlflow.exceptions import MlflowException -from mlflow.store.artifact.utils.models import _parse_model_uri +from mlflow.store.artifact.utils.models import _parse_model_uri, get_model_name_and_version +from mlflow.tracking import MlflowClient +from mlflow.entities.model_registry import ModelVersion @pytest.mark.parametrize( @@ -35,6 +38,21 @@ def test_parse_models_uri_with_stage(uri, expected_name, expected_stage): assert stage == expected_stage +@pytest.mark.parametrize( + "uri, expected_name", + [ + ("models:/AdsModel1/latest", "AdsModel1"), + ("models:/Ads Model 1/latest", "Ads Model 1"), + ("models://scope:key@databricks/Ads Model 1/latest", "Ads Model 1"), + ], +) +def test_parse_models_uri_with_latest(uri, expected_name): + (name, version, stage) = _parse_model_uri(uri) + assert name == expected_name + assert version is None + assert stage is None + + @pytest.mark.parametrize( "uri", [ @@ -42,6 +60,8 @@ def test_parse_models_uri_with_stage(uri, expected_name, expected_stage): "notmodels:/NameOfModel/StageName", # wrong scheme with stage "models:/", # no model name "models:/Name/Stage/0", # too many specifiers + "models:/Name/production", # should be 'Production' + "models:/Name/LATEST", # not lower case 'latest' "models:Name/Stage", # missing slash "models://Name/Stage", # hostnames are ignored, path too short ], @@ -49,3 +69,53 @@ def test_parse_models_uri_with_stage(uri, expected_name, expected_stage): def test_parse_models_uri_invalid_input(uri): with pytest.raises(MlflowException): _parse_model_uri(uri) + + +def test_get_model_name_and_version_with_version(): + with mock.patch.object( + MlflowClient, "get_latest_versions", return_value=[] + ) as mlflow_client_mock: + assert get_model_name_and_version(MlflowClient(), "models:/AdsModel1/123") == ( + "AdsModel1", + "123", + ) + mlflow_client_mock.assert_not_called() + + +def test_get_model_name_and_version_with_stage(): + with mock.patch.object( + MlflowClient, + "get_latest_versions", + return_value=[ + ModelVersion( + name="mv1", version="10", creation_timestamp=123, current_stage="Production" + ), + ModelVersion( + name="mv2", version="15", creation_timestamp=124, current_stage="Production" + ), + ], + ) as mlflow_client_mock: + assert get_model_name_and_version(MlflowClient(), "models:/AdsModel1/Production") == ( + "AdsModel1", + "15", + ) + mlflow_client_mock.assert_called_once_with("AdsModel1", ["Production"]) + + +def test_get_model_name_and_version_with_latest(): + with mock.patch.object( + MlflowClient, + "get_latest_versions", + return_value=[ + ModelVersion( + name="mv1", version="10", creation_timestamp=123, current_stage="Production" + ), + ModelVersion(name="mv3", version="20", creation_timestamp=125, current_stage="None"), + ModelVersion(name="mv2", version="15", creation_timestamp=124, current_stage="Staging"), + ], + ) as mlflow_client_mock: + assert get_model_name_and_version(MlflowClient(), "models:/AdsModel1/latest") == ( + "AdsModel1", + "20", + ) + mlflow_client_mock.assert_called_once_with("AdsModel1", None) diff --git a/tests/store/model_registry/test_sqlalchemy_store.py b/tests/store/model_registry/test_sqlalchemy_store.py index 23b9118f32055..e285639aab966 100644 --- a/tests/store/model_registry/test_sqlalchemy_store.py +++ b/tests/store/model_registry/test_sqlalchemy_store.py @@ -333,6 +333,26 @@ def test_get_latest_versions(self): self._extract_latest_by_stage(rmd4.latest_versions), {"None": 1, "Production": 3, "Staging": 4}, ) + self.assertEqual( + self._extract_latest_by_stage(self.store.get_latest_versions(name=name, stages=None)), + {"None": 1, "Production": 3, "Staging": 4}, + ) + self.assertEqual( + self._extract_latest_by_stage(self.store.get_latest_versions(name=name, stages=[])), + {"None": 1, "Production": 3, "Staging": 4}, + ) + self.assertEqual( + self._extract_latest_by_stage( + self.store.get_latest_versions(name=name, stages=["Production"]) + ), + {"Production": 3}, + ) + self.assertEqual( + self._extract_latest_by_stage( + self.store.get_latest_versions(name=name, stages=["None", "Production"]) + ), + {"None": 1, "Production": 3}, + ) # delete latest Production, and should point to previous one self.store.delete_model_version(name=mv3.name, version=mv3.version) @@ -341,6 +361,16 @@ def test_get_latest_versions(self): self._extract_latest_by_stage(rmd5.latest_versions), {"None": 1, "Production": 2, "Staging": 4}, ) + self.assertEqual( + self._extract_latest_by_stage(self.store.get_latest_versions(name=name, stages=None)), + {"None": 1, "Production": 2, "Staging": 4}, + ) + self.assertEqual( + self._extract_latest_by_stage( + self.store.get_latest_versions(name=name, stages=["Production"]) + ), + {"Production": 2}, + ) def test_set_registered_model_tag(self): name1 = "SetRegisteredModelTag_TestMod" diff --git a/tests/tracking/test_model_registry.py b/tests/tracking/test_model_registry.py index fc18e3bf12b61..648fe42ab1b7c 100644 --- a/tests/tracking/test_model_registry.py +++ b/tests/tracking/test_model_registry.py @@ -484,8 +484,8 @@ def get_latest(stages): assert {"None": "7"} == get_latest(["None"]) assert {"Staging": "6"} == get_latest(["Staging"]) assert {"None": "7", "Staging": "6"} == get_latest(["None", "Staging"]) - assert {"Production": "4", "Staging": "6"} == get_latest(None) - assert {"Production": "4", "Staging": "6"} == get_latest([]) + assert {"Production": "4", "Staging": "6", "Archived": "3", "None": "7"} == get_latest(None) + assert {"Production": "4", "Staging": "6", "Archived": "3", "None": "7"} == get_latest([]) def test_delete_model_version_flow(mlflow_client, backend_store_uri):