Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support specifying 'latest' in model URI to get the latest version of a model regardless of the stage #5027

Merged
merged 4 commits into from Jan 10, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Expand Up @@ -87,9 +87,10 @@ public void testGetLatestModelVersions() throws IOException {

client.sendPatch("model-versions/update", mapper.makeUpdateModelVersion(modelName,
"1"));
// default stages (does not include "None")
// get the latest version of all stages
List<ModelVersion> modelVersion = client.getLatestVersions(modelName);
Assert.assertEquals(modelVersion.size(), 0);
Assert.assertEquals(modelVersion.size(), 1);
validateDetailedModelVersion(modelVersion.get(0), modelName, "None", "1");
client.sendPost("model-versions/transition-stage",
mapper.makeTransitionModelVersionStage(modelName, "1", "Staging"));
modelVersion = client.getLatestVersions(modelName);
Expand Down
3 changes: 2 additions & 1 deletion mlflow/store/artifact/databricks_models_artifact_repo.py
Expand Up @@ -31,7 +31,8 @@ class DatabricksModelsArtifactRepository(ArtifactRepository):
The artifact_uri is expected to be of the form
- `models:/<model_name>/<model_version>`
- `models:/<model_name>/<stage>` (refers to the latest model version in the given stage)
- `models://<profile>/<model_name>/<model_version or stage>`
- `models:/<model_name>/latest` (refers to the latest of all model versions)
- `models://<profile>/<model_name>/<model_version or stage or 'latest'>`

Note : This artifact repository is meant is to be instantiated by the ModelsArtifactRepository
when the client is pointing to a Databricks-hosted model registry.
Expand Down
1 change: 1 addition & 0 deletions mlflow/store/artifact/models_artifact_repo.py
Expand Up @@ -19,6 +19,7 @@ class ModelsArtifactRepository(ArtifactRepository):
Handles artifacts associated with a model version in the model registry via URIs of the form:
- `models:/<model_name>/<model_version>`
- `models:/<model_name>/<stage>` (refers to the latest model version in the given stage)
- `models:/<model_name>/latest` (refers to the latest of all model versions)
It is a light wrapper that resolves the artifact path to an absolute URI then instantiates
and uses the artifact repository for that URI.
"""
Expand Down
42 changes: 31 additions & 11 deletions mlflow/store/artifact/utils/models.py
Expand Up @@ -3,6 +3,9 @@
import mlflow.tracking
from mlflow.exceptions import MlflowException
from mlflow.utils.uri import get_databricks_profile_uri_from_artifact_uri, is_databricks_uri
from mlflow.entities.model_registry.model_version_stages import ALL_STAGES

_MODELS_URI_SUFFIX_LATEST = "latest"


def is_using_databricks_registry(uri):
Expand All @@ -13,24 +16,34 @@ def is_using_databricks_registry(uri):
def _improper_model_uri_msg(uri):
return (
"Not a proper models:/ URI: %s. " % uri
+ "Models URIs must be of the form 'models:/<model_name>/<version or stage>'."
+ "Models URIs must be of the form 'models:/<model_name>/suffix' "
+ "where suffix is a model version, stage, or the string '%s'." % _MODELS_URI_SUFFIX_LATEST
)


def _get_model_version_from_stage(client, name, stage):
latest = client.get_latest_versions(name, [stage])
def _get_latest_model_version(client, name, stage):
"""
Returns the latest version of the stage if stage is not None. Otherwise return the latest of all
versions.
"""
latest = client.get_latest_versions(name, None if stage is None else [stage])
if len(latest) == 0:
stage_str = "" if stage is None else " and stage '{stage}'".format(stage=stage)
raise MlflowException(
"No versions of model with name '{name}' and "
"stage '{stage}' found".format(name=name, stage=stage)
"No versions of model with name '{name}'{stage_str} found".format(
name=name, stage_str=stage_str
)
)
return latest[0].version
return max(map(lambda x: int(x.version), latest))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just curious - why do we need to call int() here? I'm not opposed just for safety reasons, but x.version is already a number right?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually model version is str: link. So it's safer to convert it into int here.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Got it



def _parse_model_uri(uri):
"""
Returns (name, version, stage). Since a models:/ URI can only have one of {version, stage},
it will return (name, version, None) or (name, None, stage).
Returns (name, version, stage). Since a models:/ URI can only have one of
{version, stage, 'latest'}, it will return
- (name, version, None) to look for a specific version,
- (name, None, stage) to look for the latest version of a stage,
- (name, None, None) to look for the latest of all versions.
"""
parsed = urllib.parse.urlparse(uri)
if parsed.scheme != "models":
Expand All @@ -45,13 +58,20 @@ def _parse_model_uri(uri):
raise MlflowException(_improper_model_uri_msg(uri))

if parts[1].isdigit():
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Given the added complexity, it may be good to have an example of each URI type in the branch so that it's clear exactly which case maps to which tuple

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done, thanks!

# The suffix is a specific version, e.g. "models:/AdsModel1/123"
return parts[0], int(parts[1]), None
elif parts[1] == _MODELS_URI_SUFFIX_LATEST:
# The suffix is exactly the 'latest' string, e.g. "models:/AdsModel1/latest"
return parts[0], None, None
elif parts[1] not in ALL_STAGES:
raise MlflowException(_improper_model_uri_msg(uri))
else:
# The suffix is a specific stage, e.g. "models:/AdsModel1/Production"
return parts[0], None, parts[1]


def get_model_name_and_version(client, models_uri):
(model_name, model_version, model_stage) = _parse_model_uri(models_uri)
if model_stage is not None:
model_version = _get_model_version_from_stage(client, model_name, model_stage)
return model_name, str(model_version)
if model_version is not None:
return model_name, str(model_version)
return model_name, str(_get_latest_model_version(client, model_name, model_stage))
2 changes: 1 addition & 1 deletion mlflow/store/model_registry/abstract_store.py
Expand Up @@ -116,7 +116,7 @@ def get_latest_versions(self, name, stages=None):

:param name: Registered model name.
:param stages: List of desired stages. If input list is None, return latest versions for
for 'Staging' and 'Production' stages.
each stage.
:return: List of :py:class:`mlflow.entities.model_registry.ModelVersion` objects.
"""
pass
Expand Down
2 changes: 1 addition & 1 deletion mlflow/store/model_registry/rest_store.py
Expand Up @@ -184,7 +184,7 @@ def get_latest_versions(self, name, stages=None):

:param name: Registered model name.
:param stages: List of desired stages. If input list is None, return latest versions for
for 'Staging' and 'Production' stages.
each stage.
:return: List of :py:class:`mlflow.entities.model_registry.ModelVersion` objects.
"""
req_body = message_to_json(GetLatestVersions(name=name, stages=stages))
Expand Down
7 changes: 3 additions & 4 deletions mlflow/store/model_registry/sqlalchemy_store.py
Expand Up @@ -5,6 +5,7 @@

from mlflow.entities.model_registry.model_version_stages import (
get_canonical_stage,
ALL_STAGES,
DEFAULT_STAGES_FOR_GET_LATEST_VERSIONS,
STAGE_DELETED_INTERNAL,
STAGE_ARCHIVED,
Expand Down Expand Up @@ -432,17 +433,15 @@ def get_latest_versions(self, name, stages=None):

:param name: Registered model name.
:param stages: List of desired stages. If input list is None, return latest versions for
for 'Staging' and 'Production' stages.
each stage.
:return: List of :py:class:`mlflow.entities.model_registry.ModelVersion` objects.
"""
with self.ManagedSessionMaker() as session:
sql_registered_model = self._get_registered_model(session, name)
# Convert to RegisteredModel entity first and then extract latest_versions
latest_versions = sql_registered_model.to_mlflow_entity().latest_versions
if stages is None or len(stages) == 0:
expected_stages = set(
[get_canonical_stage(stage) for stage in DEFAULT_STAGES_FOR_GET_LATEST_VERSIONS]
)
expected_stages = set([get_canonical_stage(stage) for stage in ALL_STAGES])
else:
expected_stages = set([get_canonical_stage(stage) for stage in stages])
return [mv for mv in latest_versions if mv.current_stage in expected_stages]
Expand Down
72 changes: 71 additions & 1 deletion tests/store/artifact/utils/test_model_utils.py
@@ -1,7 +1,10 @@
import pytest
from unittest import mock

from mlflow.exceptions import MlflowException
from mlflow.store.artifact.utils.models import _parse_model_uri
from mlflow.store.artifact.utils.models import _parse_model_uri, get_model_name_and_version
from mlflow.tracking import MlflowClient
from mlflow.entities.model_registry import ModelVersion


@pytest.mark.parametrize(
Expand Down Expand Up @@ -35,17 +38,84 @@ def test_parse_models_uri_with_stage(uri, expected_name, expected_stage):
assert stage == expected_stage


@pytest.mark.parametrize(
"uri, expected_name",
[
("models:/AdsModel1/latest", "AdsModel1"),
("models:/Ads Model 1/latest", "Ads Model 1"),
("models://scope:key@databricks/Ads Model 1/latest", "Ads Model 1"),
],
)
def test_parse_models_uri_with_latest(uri, expected_name):
(name, version, stage) = _parse_model_uri(uri)
assert name == expected_name
assert version is None
assert stage is None


@pytest.mark.parametrize(
"uri",
[
"notmodels:/NameOfModel/12345", # wrong scheme with version
"notmodels:/NameOfModel/StageName", # wrong scheme with stage
"models:/", # no model name
"models:/Name/Stage/0", # too many specifiers
"models:/Name/production", # should be 'Production'
"models:/Name/LATEST", # not lower case 'latest'
"models:Name/Stage", # missing slash
"models://Name/Stage", # hostnames are ignored, path too short
],
)
def test_parse_models_uri_invalid_input(uri):
with pytest.raises(MlflowException):
_parse_model_uri(uri)


def test_get_model_name_and_version_with_version():
with mock.patch.object(
MlflowClient, "get_latest_versions", return_value=[]
) as mlflow_client_mock:
assert get_model_name_and_version(MlflowClient(), "models:/AdsModel1/123") == (
"AdsModel1",
"123",
)
mlflow_client_mock.assert_not_called()


def test_get_model_name_and_version_with_stage():
with mock.patch.object(
MlflowClient,
"get_latest_versions",
return_value=[
ModelVersion(
name="mv1", version="10", creation_timestamp=123, current_stage="Production"
),
ModelVersion(
name="mv2", version="15", creation_timestamp=124, current_stage="Production"
),
],
) as mlflow_client_mock:
assert get_model_name_and_version(MlflowClient(), "models:/AdsModel1/Production") == (
"AdsModel1",
"15",
)
mlflow_client_mock.assert_called_once_with("AdsModel1", ["Production"])


def test_get_model_name_and_version_with_latest():
with mock.patch.object(
MlflowClient,
"get_latest_versions",
return_value=[
ModelVersion(
name="mv1", version="10", creation_timestamp=123, current_stage="Production"
),
ModelVersion(name="mv3", version="20", creation_timestamp=125, current_stage="None"),
ModelVersion(name="mv2", version="15", creation_timestamp=124, current_stage="Staging"),
],
) as mlflow_client_mock:
assert get_model_name_and_version(MlflowClient(), "models:/AdsModel1/latest") == (
"AdsModel1",
"20",
)
mlflow_client_mock.assert_called_once_with("AdsModel1", None)
30 changes: 30 additions & 0 deletions tests/store/model_registry/test_sqlalchemy_store.py
Expand Up @@ -333,6 +333,26 @@ def test_get_latest_versions(self):
self._extract_latest_by_stage(rmd4.latest_versions),
{"None": 1, "Production": 3, "Staging": 4},
)
self.assertEqual(
self._extract_latest_by_stage(self.store.get_latest_versions(name=name, stages=None)),
{"None": 1, "Production": 3, "Staging": 4},
)
self.assertEqual(
self._extract_latest_by_stage(self.store.get_latest_versions(name=name, stages=[])),
{"None": 1, "Production": 3, "Staging": 4},
)
self.assertEqual(
self._extract_latest_by_stage(
self.store.get_latest_versions(name=name, stages=["Production"])
),
{"Production": 3},
)
self.assertEqual(
self._extract_latest_by_stage(
self.store.get_latest_versions(name=name, stages=["None", "Production"])
),
{"None": 1, "Production": 3},
)

# delete latest Production, and should point to previous one
self.store.delete_model_version(name=mv3.name, version=mv3.version)
Expand All @@ -341,6 +361,16 @@ def test_get_latest_versions(self):
self._extract_latest_by_stage(rmd5.latest_versions),
{"None": 1, "Production": 2, "Staging": 4},
)
self.assertEqual(
self._extract_latest_by_stage(self.store.get_latest_versions(name=name, stages=None)),
{"None": 1, "Production": 2, "Staging": 4},
)
self.assertEqual(
self._extract_latest_by_stage(
self.store.get_latest_versions(name=name, stages=["Production"])
),
{"Production": 2},
)

def test_set_registered_model_tag(self):
name1 = "SetRegisteredModelTag_TestMod"
Expand Down
4 changes: 2 additions & 2 deletions tests/tracking/test_model_registry.py
Expand Up @@ -484,8 +484,8 @@ def get_latest(stages):
assert {"None": "7"} == get_latest(["None"])
assert {"Staging": "6"} == get_latest(["Staging"])
assert {"None": "7", "Staging": "6"} == get_latest(["None", "Staging"])
assert {"Production": "4", "Staging": "6"} == get_latest(None)
assert {"Production": "4", "Staging": "6"} == get_latest([])
assert {"Production": "4", "Staging": "6", "Archived": "3", "None": "7"} == get_latest(None)
assert {"Production": "4", "Staging": "6", "Archived": "3", "None": "7"} == get_latest([])


def test_delete_model_version_flow(mlflow_client, backend_store_uri):
Expand Down