Skip to content

Commit

Permalink
Support specifying 'latest' in model URI to get the latest version of…
Browse files Browse the repository at this point in the history
… a model regardless of the stage (#5027)

* Support specifying 'latest' in model URI to get the latest version of a model regardless of the stage

Signed-off-by: Chenran Li <chenran.li@databricks.com>
  • Loading branch information
lichenran1234 committed Jan 10, 2022
1 parent 964f5ab commit d3ddd59
Show file tree
Hide file tree
Showing 10 changed files with 145 additions and 23 deletions.
Expand Up @@ -87,9 +87,10 @@ public void testGetLatestModelVersions() throws IOException {

client.sendPatch("model-versions/update", mapper.makeUpdateModelVersion(modelName,
"1"));
// default stages (does not include "None")
// get the latest version of all stages
List<ModelVersion> modelVersion = client.getLatestVersions(modelName);
Assert.assertEquals(modelVersion.size(), 0);
Assert.assertEquals(modelVersion.size(), 1);
validateDetailedModelVersion(modelVersion.get(0), modelName, "None", "1");
client.sendPost("model-versions/transition-stage",
mapper.makeTransitionModelVersionStage(modelName, "1", "Staging"));
modelVersion = client.getLatestVersions(modelName);
Expand Down
3 changes: 2 additions & 1 deletion mlflow/store/artifact/databricks_models_artifact_repo.py
Expand Up @@ -31,7 +31,8 @@ class DatabricksModelsArtifactRepository(ArtifactRepository):
The artifact_uri is expected to be of the form
- `models:/<model_name>/<model_version>`
- `models:/<model_name>/<stage>` (refers to the latest model version in the given stage)
- `models://<profile>/<model_name>/<model_version or stage>`
- `models:/<model_name>/latest` (refers to the latest of all model versions)
- `models://<profile>/<model_name>/<model_version or stage or 'latest'>`
Note : This artifact repository is meant is to be instantiated by the ModelsArtifactRepository
when the client is pointing to a Databricks-hosted model registry.
Expand Down
1 change: 1 addition & 0 deletions mlflow/store/artifact/models_artifact_repo.py
Expand Up @@ -19,6 +19,7 @@ class ModelsArtifactRepository(ArtifactRepository):
Handles artifacts associated with a model version in the model registry via URIs of the form:
- `models:/<model_name>/<model_version>`
- `models:/<model_name>/<stage>` (refers to the latest model version in the given stage)
- `models:/<model_name>/latest` (refers to the latest of all model versions)
It is a light wrapper that resolves the artifact path to an absolute URI then instantiates
and uses the artifact repository for that URI.
"""
Expand Down
42 changes: 31 additions & 11 deletions mlflow/store/artifact/utils/models.py
Expand Up @@ -3,6 +3,9 @@
import mlflow.tracking
from mlflow.exceptions import MlflowException
from mlflow.utils.uri import get_databricks_profile_uri_from_artifact_uri, is_databricks_uri
from mlflow.entities.model_registry.model_version_stages import ALL_STAGES

_MODELS_URI_SUFFIX_LATEST = "latest"


def is_using_databricks_registry(uri):
Expand All @@ -13,24 +16,34 @@ def is_using_databricks_registry(uri):
def _improper_model_uri_msg(uri):
return (
"Not a proper models:/ URI: %s. " % uri
+ "Models URIs must be of the form 'models:/<model_name>/<version or stage>'."
+ "Models URIs must be of the form 'models:/<model_name>/suffix' "
+ "where suffix is a model version, stage, or the string '%s'." % _MODELS_URI_SUFFIX_LATEST
)


def _get_model_version_from_stage(client, name, stage):
latest = client.get_latest_versions(name, [stage])
def _get_latest_model_version(client, name, stage):
"""
Returns the latest version of the stage if stage is not None. Otherwise return the latest of all
versions.
"""
latest = client.get_latest_versions(name, None if stage is None else [stage])
if len(latest) == 0:
stage_str = "" if stage is None else " and stage '{stage}'".format(stage=stage)
raise MlflowException(
"No versions of model with name '{name}' and "
"stage '{stage}' found".format(name=name, stage=stage)
"No versions of model with name '{name}'{stage_str} found".format(
name=name, stage_str=stage_str
)
)
return latest[0].version
return max(map(lambda x: int(x.version), latest))


def _parse_model_uri(uri):
"""
Returns (name, version, stage). Since a models:/ URI can only have one of {version, stage},
it will return (name, version, None) or (name, None, stage).
Returns (name, version, stage). Since a models:/ URI can only have one of
{version, stage, 'latest'}, it will return
- (name, version, None) to look for a specific version,
- (name, None, stage) to look for the latest version of a stage,
- (name, None, None) to look for the latest of all versions.
"""
parsed = urllib.parse.urlparse(uri)
if parsed.scheme != "models":
Expand All @@ -45,13 +58,20 @@ def _parse_model_uri(uri):
raise MlflowException(_improper_model_uri_msg(uri))

if parts[1].isdigit():
# The suffix is a specific version, e.g. "models:/AdsModel1/123"
return parts[0], int(parts[1]), None
elif parts[1] == _MODELS_URI_SUFFIX_LATEST:
# The suffix is exactly the 'latest' string, e.g. "models:/AdsModel1/latest"
return parts[0], None, None
elif parts[1] not in ALL_STAGES:
raise MlflowException(_improper_model_uri_msg(uri))
else:
# The suffix is a specific stage, e.g. "models:/AdsModel1/Production"
return parts[0], None, parts[1]


def get_model_name_and_version(client, models_uri):
(model_name, model_version, model_stage) = _parse_model_uri(models_uri)
if model_stage is not None:
model_version = _get_model_version_from_stage(client, model_name, model_stage)
return model_name, str(model_version)
if model_version is not None:
return model_name, str(model_version)
return model_name, str(_get_latest_model_version(client, model_name, model_stage))
2 changes: 1 addition & 1 deletion mlflow/store/model_registry/abstract_store.py
Expand Up @@ -116,7 +116,7 @@ def get_latest_versions(self, name, stages=None):
:param name: Registered model name.
:param stages: List of desired stages. If input list is None, return latest versions for
for 'Staging' and 'Production' stages.
each stage.
:return: List of :py:class:`mlflow.entities.model_registry.ModelVersion` objects.
"""
pass
Expand Down
2 changes: 1 addition & 1 deletion mlflow/store/model_registry/rest_store.py
Expand Up @@ -190,7 +190,7 @@ def get_latest_versions(self, name, stages=None):
:param name: Registered model name.
:param stages: List of desired stages. If input list is None, return latest versions for
for 'Staging' and 'Production' stages.
each stage.
:return: List of :py:class:`mlflow.entities.model_registry.ModelVersion` objects.
"""
req_body = message_to_json(GetLatestVersions(name=name, stages=stages))
Expand Down
7 changes: 3 additions & 4 deletions mlflow/store/model_registry/sqlalchemy_store.py
Expand Up @@ -5,6 +5,7 @@

from mlflow.entities.model_registry.model_version_stages import (
get_canonical_stage,
ALL_STAGES,
DEFAULT_STAGES_FOR_GET_LATEST_VERSIONS,
STAGE_DELETED_INTERNAL,
STAGE_ARCHIVED,
Expand Down Expand Up @@ -432,17 +433,15 @@ def get_latest_versions(self, name, stages=None):
:param name: Registered model name.
:param stages: List of desired stages. If input list is None, return latest versions for
for 'Staging' and 'Production' stages.
each stage.
:return: List of :py:class:`mlflow.entities.model_registry.ModelVersion` objects.
"""
with self.ManagedSessionMaker() as session:
sql_registered_model = self._get_registered_model(session, name)
# Convert to RegisteredModel entity first and then extract latest_versions
latest_versions = sql_registered_model.to_mlflow_entity().latest_versions
if stages is None or len(stages) == 0:
expected_stages = set(
[get_canonical_stage(stage) for stage in DEFAULT_STAGES_FOR_GET_LATEST_VERSIONS]
)
expected_stages = set([get_canonical_stage(stage) for stage in ALL_STAGES])
else:
expected_stages = set([get_canonical_stage(stage) for stage in stages])
return [mv for mv in latest_versions if mv.current_stage in expected_stages]
Expand Down
72 changes: 71 additions & 1 deletion tests/store/artifact/utils/test_model_utils.py
@@ -1,7 +1,10 @@
import pytest
from unittest import mock

from mlflow.exceptions import MlflowException
from mlflow.store.artifact.utils.models import _parse_model_uri
from mlflow.store.artifact.utils.models import _parse_model_uri, get_model_name_and_version
from mlflow.tracking import MlflowClient
from mlflow.entities.model_registry import ModelVersion


@pytest.mark.parametrize(
Expand Down Expand Up @@ -35,17 +38,84 @@ def test_parse_models_uri_with_stage(uri, expected_name, expected_stage):
assert stage == expected_stage


@pytest.mark.parametrize(
"uri, expected_name",
[
("models:/AdsModel1/latest", "AdsModel1"),
("models:/Ads Model 1/latest", "Ads Model 1"),
("models://scope:key@databricks/Ads Model 1/latest", "Ads Model 1"),
],
)
def test_parse_models_uri_with_latest(uri, expected_name):
(name, version, stage) = _parse_model_uri(uri)
assert name == expected_name
assert version is None
assert stage is None


@pytest.mark.parametrize(
"uri",
[
"notmodels:/NameOfModel/12345", # wrong scheme with version
"notmodels:/NameOfModel/StageName", # wrong scheme with stage
"models:/", # no model name
"models:/Name/Stage/0", # too many specifiers
"models:/Name/production", # should be 'Production'
"models:/Name/LATEST", # not lower case 'latest'
"models:Name/Stage", # missing slash
"models://Name/Stage", # hostnames are ignored, path too short
],
)
def test_parse_models_uri_invalid_input(uri):
with pytest.raises(MlflowException, match="Not a proper models"):
_parse_model_uri(uri)


def test_get_model_name_and_version_with_version():
with mock.patch.object(
MlflowClient, "get_latest_versions", return_value=[]
) as mlflow_client_mock:
assert get_model_name_and_version(MlflowClient(), "models:/AdsModel1/123") == (
"AdsModel1",
"123",
)
mlflow_client_mock.assert_not_called()


def test_get_model_name_and_version_with_stage():
with mock.patch.object(
MlflowClient,
"get_latest_versions",
return_value=[
ModelVersion(
name="mv1", version="10", creation_timestamp=123, current_stage="Production"
),
ModelVersion(
name="mv2", version="15", creation_timestamp=124, current_stage="Production"
),
],
) as mlflow_client_mock:
assert get_model_name_and_version(MlflowClient(), "models:/AdsModel1/Production") == (
"AdsModel1",
"15",
)
mlflow_client_mock.assert_called_once_with("AdsModel1", ["Production"])


def test_get_model_name_and_version_with_latest():
with mock.patch.object(
MlflowClient,
"get_latest_versions",
return_value=[
ModelVersion(
name="mv1", version="10", creation_timestamp=123, current_stage="Production"
),
ModelVersion(name="mv3", version="20", creation_timestamp=125, current_stage="None"),
ModelVersion(name="mv2", version="15", creation_timestamp=124, current_stage="Staging"),
],
) as mlflow_client_mock:
assert get_model_name_and_version(MlflowClient(), "models:/AdsModel1/latest") == (
"AdsModel1",
"20",
)
mlflow_client_mock.assert_called_once_with("AdsModel1", None)
30 changes: 30 additions & 0 deletions tests/store/model_registry/test_sqlalchemy_store.py
Expand Up @@ -333,6 +333,26 @@ def test_get_latest_versions(self):
self._extract_latest_by_stage(rmd4.latest_versions),
{"None": 1, "Production": 3, "Staging": 4},
)
self.assertEqual(
self._extract_latest_by_stage(self.store.get_latest_versions(name=name, stages=None)),
{"None": 1, "Production": 3, "Staging": 4},
)
self.assertEqual(
self._extract_latest_by_stage(self.store.get_latest_versions(name=name, stages=[])),
{"None": 1, "Production": 3, "Staging": 4},
)
self.assertEqual(
self._extract_latest_by_stage(
self.store.get_latest_versions(name=name, stages=["Production"])
),
{"Production": 3},
)
self.assertEqual(
self._extract_latest_by_stage(
self.store.get_latest_versions(name=name, stages=["None", "Production"])
),
{"None": 1, "Production": 3},
)

# delete latest Production, and should point to previous one
self.store.delete_model_version(name=mv3.name, version=mv3.version)
Expand All @@ -341,6 +361,16 @@ def test_get_latest_versions(self):
self._extract_latest_by_stage(rmd5.latest_versions),
{"None": 1, "Production": 2, "Staging": 4},
)
self.assertEqual(
self._extract_latest_by_stage(self.store.get_latest_versions(name=name, stages=None)),
{"None": 1, "Production": 2, "Staging": 4},
)
self.assertEqual(
self._extract_latest_by_stage(
self.store.get_latest_versions(name=name, stages=["Production"])
),
{"Production": 2},
)

def test_set_registered_model_tag(self):
name1 = "SetRegisteredModelTag_TestMod"
Expand Down
4 changes: 2 additions & 2 deletions tests/tracking/test_model_registry.py
Expand Up @@ -489,8 +489,8 @@ def get_latest(stages):
assert {"None": "7"} == get_latest(["None"])
assert {"Staging": "6"} == get_latest(["Staging"])
assert {"None": "7", "Staging": "6"} == get_latest(["None", "Staging"])
assert {"Production": "4", "Staging": "6"} == get_latest(None)
assert {"Production": "4", "Staging": "6"} == get_latest([])
assert {"Production": "4", "Staging": "6", "Archived": "3", "None": "7"} == get_latest(None)
assert {"Production": "4", "Staging": "6", "Archived": "3", "None": "7"} == get_latest([])


def test_delete_model_version_flow(mlflow_client, backend_store_uri):
Expand Down

0 comments on commit d3ddd59

Please sign in to comment.