Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix the bug of stages in models URI being case-sensitive #5312

Merged
merged 4 commits into from Jan 26, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions mlflow/entities/model_registry/model_version_stages.py
Expand Up @@ -9,6 +9,7 @@
STAGE_DELETED_INTERNAL = "Deleted_Internal"

ALL_STAGES = [STAGE_NONE, STAGE_STAGING, STAGE_PRODUCTION, STAGE_ARCHIVED]
ALL_STAGES_IN_LOWER_CASE = [stage.lower() for stage in ALL_STAGES]
DEFAULT_STAGES_FOR_GET_LATEST_VERSIONS = [STAGE_STAGING, STAGE_PRODUCTION]
_CANONICAL_MAPPING = {stage.lower(): stage for stage in ALL_STAGES}

Expand Down
7 changes: 4 additions & 3 deletions mlflow/store/artifact/utils/models.py
Expand Up @@ -3,7 +3,7 @@
import mlflow.tracking
from mlflow.exceptions import MlflowException
from mlflow.utils.uri import get_databricks_profile_uri_from_artifact_uri, is_databricks_uri
from mlflow.entities.model_registry.model_version_stages import ALL_STAGES
from mlflow.entities.model_registry.model_version_stages import ALL_STAGES_IN_LOWER_CASE

_MODELS_URI_SUFFIX_LATEST = "latest"

Expand Down Expand Up @@ -63,10 +63,11 @@ def _parse_model_uri(uri):
elif parts[1] == _MODELS_URI_SUFFIX_LATEST:
# The suffix is exactly the 'latest' string, e.g. "models:/AdsModel1/latest"
return parts[0], None, None
elif parts[1] not in ALL_STAGES:
elif parts[1].lower() not in ALL_STAGES_IN_LOWER_CASE:
# Now the suffix should be a specific stage (case insensitive). If not, throw an exception.
raise MlflowException(_improper_model_uri_msg(uri))
else:
# The suffix is a specific stage, e.g. "models:/AdsModel1/Production"
# The suffix is a specific stage (case insensitive), e.g. "models:/AdsModel1/Production"
return parts[0], None, parts[1]


Expand Down
3 changes: 2 additions & 1 deletion tests/store/artifact/utils/test_model_utils.py
Expand Up @@ -27,6 +27,8 @@ def test_parse_models_uri_with_version(uri, expected_name, expected_version):
"uri, expected_name, expected_stage",
[
("models:/AdsModel1/Production", "AdsModel1", "Production"),
("models:/AdsModel1/production", "AdsModel1", "production"), # case insensitive
("models:/AdsModel1/pROduction", "AdsModel1", "pROduction"), # case insensitive
("models:/Ads Model 1/None", "Ads Model 1", "None"),
("models://scope:key@databricks/Ads Model 1/None", "Ads Model 1", "None"),
],
Expand Down Expand Up @@ -60,7 +62,6 @@ def test_parse_models_uri_with_latest(uri, expected_name):
"notmodels:/NameOfModel/StageName", # wrong scheme with stage
"models:/", # no model name
"models:/Name/Stage/0", # too many specifiers
"models:/Name/production", # should be 'Production'
"models:/Name/LATEST", # not lower case 'latest'
harupy marked this conversation as resolved.
Show resolved Hide resolved
"models:Name/Stage", # missing slash
"models://Name/Stage", # hostnames are ignored, path too short
Expand Down
12 changes: 12 additions & 0 deletions tests/store/model_registry/test_sqlalchemy_store.py
Expand Up @@ -347,6 +347,18 @@ def test_get_latest_versions(self):
),
{"Production": 3},
)
self.assertEqual(
self._extract_latest_by_stage(
self.store.get_latest_versions(name=name, stages=["production"])
),
{"Production": 3},
) # The stages are case insensitive.
self.assertEqual(
self._extract_latest_by_stage(
self.store.get_latest_versions(name=name, stages=["pROduction"])
),
{"Production": 3},
) # The stages are case insensitive.
self.assertEqual(
self._extract_latest_by_stage(
self.store.get_latest_versions(name=name, stages=["None", "Production"])
Expand Down