Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix the bug of stages in models URI being case-sensitive #5312

Merged
merged 4 commits into from Jan 26, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
9 changes: 3 additions & 6 deletions mlflow/store/artifact/utils/models.py
Expand Up @@ -3,7 +3,6 @@
import mlflow.tracking
from mlflow.exceptions import MlflowException
from mlflow.utils.uri import get_databricks_profile_uri_from_artifact_uri, is_databricks_uri
from mlflow.entities.model_registry.model_version_stages import ALL_STAGES

_MODELS_URI_SUFFIX_LATEST = "latest"

Expand Down Expand Up @@ -60,13 +59,11 @@ def _parse_model_uri(uri):
if parts[1].isdigit():
# The suffix is a specific version, e.g. "models:/AdsModel1/123"
return parts[0], int(parts[1]), None
elif parts[1] == _MODELS_URI_SUFFIX_LATEST:
# The suffix is exactly the 'latest' string, e.g. "models:/AdsModel1/latest"
elif parts[1].lower() == _MODELS_URI_SUFFIX_LATEST.lower():
# The suffix is the 'latest' string (case insensitive), e.g. "models:/AdsModel1/latest"
return parts[0], None, None
elif parts[1] not in ALL_STAGES:
raise MlflowException(_improper_model_uri_msg(uri))
else:
# The suffix is a specific stage, e.g. "models:/AdsModel1/Production"
# The suffix is a specific stage (case insensitive), e.g. "models:/AdsModel1/Production"
return parts[0], None, parts[1]


Expand Down
11 changes: 9 additions & 2 deletions tests/store/artifact/utils/test_model_utils.py
Expand Up @@ -27,6 +27,8 @@ def test_parse_models_uri_with_version(uri, expected_name, expected_version):
"uri, expected_name, expected_stage",
[
("models:/AdsModel1/Production", "AdsModel1", "Production"),
("models:/AdsModel1/production", "AdsModel1", "production"), # case insensitive
("models:/AdsModel1/pROduction", "AdsModel1", "pROduction"), # case insensitive
("models:/Ads Model 1/None", "Ads Model 1", "None"),
("models://scope:key@databricks/Ads Model 1/None", "Ads Model 1", "None"),
],
Expand All @@ -42,6 +44,8 @@ def test_parse_models_uri_with_stage(uri, expected_name, expected_stage):
"uri, expected_name",
[
("models:/AdsModel1/latest", "AdsModel1"),
("models:/AdsModel1/Latest", "AdsModel1"), # case insensitive
("models:/AdsModel1/LATEST", "AdsModel1"), # case insensitive
("models:/Ads Model 1/latest", "Ads Model 1"),
("models://scope:key@databricks/Ads Model 1/latest", "Ads Model 1"),
],
Expand All @@ -60,8 +64,6 @@ def test_parse_models_uri_with_latest(uri, expected_name):
"notmodels:/NameOfModel/StageName", # wrong scheme with stage
"models:/", # no model name
"models:/Name/Stage/0", # too many specifiers
"models:/Name/production", # should be 'Production'
"models:/Name/LATEST", # not lower case 'latest'
"models:Name/Stage", # missing slash
"models://Name/Stage", # hostnames are ignored, path too short
],
Expand Down Expand Up @@ -119,3 +121,8 @@ def test_get_model_name_and_version_with_latest():
"20",
)
mlflow_client_mock.assert_called_once_with("AdsModel1", None)
# Check that "latest" is case insensitive.
assert get_model_name_and_version(MlflowClient(), "models:/AdsModel1/lATest") == (
"AdsModel1",
"20",
)
12 changes: 12 additions & 0 deletions tests/store/model_registry/test_sqlalchemy_store.py
Expand Up @@ -347,6 +347,18 @@ def test_get_latest_versions(self):
),
{"Production": 3},
)
self.assertEqual(
self._extract_latest_by_stage(
self.store.get_latest_versions(name=name, stages=["production"])
),
{"Production": 3},
) # The stages are case insensitive.
self.assertEqual(
self._extract_latest_by_stage(
self.store.get_latest_versions(name=name, stages=["pROduction"])
),
{"Production": 3},
) # The stages are case insensitive.
self.assertEqual(
self._extract_latest_by_stage(
self.store.get_latest_versions(name=name, stages=["None", "Production"])
Expand Down