diff --git a/mlflow/store/artifact/utils/models.py b/mlflow/store/artifact/utils/models.py index 43f31610a41f4..171732585ee8e 100644 --- a/mlflow/store/artifact/utils/models.py +++ b/mlflow/store/artifact/utils/models.py @@ -3,7 +3,6 @@ import mlflow.tracking from mlflow.exceptions import MlflowException from mlflow.utils.uri import get_databricks_profile_uri_from_artifact_uri, is_databricks_uri -from mlflow.entities.model_registry.model_version_stages import ALL_STAGES _MODELS_URI_SUFFIX_LATEST = "latest" @@ -60,13 +59,11 @@ def _parse_model_uri(uri): if parts[1].isdigit(): # The suffix is a specific version, e.g. "models:/AdsModel1/123" return parts[0], int(parts[1]), None - elif parts[1] == _MODELS_URI_SUFFIX_LATEST: - # The suffix is exactly the 'latest' string, e.g. "models:/AdsModel1/latest" + elif parts[1].lower() == _MODELS_URI_SUFFIX_LATEST.lower(): + # The suffix is the 'latest' string (case insensitive), e.g. "models:/AdsModel1/latest" return parts[0], None, None - elif parts[1] not in ALL_STAGES: - raise MlflowException(_improper_model_uri_msg(uri)) else: - # The suffix is a specific stage, e.g. "models:/AdsModel1/Production" + # The suffix is a specific stage (case insensitive), e.g. "models:/AdsModel1/Production" return parts[0], None, parts[1] diff --git a/tests/store/artifact/utils/test_model_utils.py b/tests/store/artifact/utils/test_model_utils.py index e7631f8a28db5..39173d8004790 100644 --- a/tests/store/artifact/utils/test_model_utils.py +++ b/tests/store/artifact/utils/test_model_utils.py @@ -27,6 +27,8 @@ def test_parse_models_uri_with_version(uri, expected_name, expected_version): "uri, expected_name, expected_stage", [ ("models:/AdsModel1/Production", "AdsModel1", "Production"), + ("models:/AdsModel1/production", "AdsModel1", "production"), # case insensitive + ("models:/AdsModel1/pROduction", "AdsModel1", "pROduction"), # case insensitive ("models:/Ads Model 1/None", "Ads Model 1", "None"), ("models://scope:key@databricks/Ads Model 1/None", "Ads Model 1", "None"), ], @@ -42,6 +44,8 @@ def test_parse_models_uri_with_stage(uri, expected_name, expected_stage): "uri, expected_name", [ ("models:/AdsModel1/latest", "AdsModel1"), + ("models:/AdsModel1/Latest", "AdsModel1"), # case insensitive + ("models:/AdsModel1/LATEST", "AdsModel1"), # case insensitive ("models:/Ads Model 1/latest", "Ads Model 1"), ("models://scope:key@databricks/Ads Model 1/latest", "Ads Model 1"), ], @@ -60,8 +64,6 @@ def test_parse_models_uri_with_latest(uri, expected_name): "notmodels:/NameOfModel/StageName", # wrong scheme with stage "models:/", # no model name "models:/Name/Stage/0", # too many specifiers - "models:/Name/production", # should be 'Production' - "models:/Name/LATEST", # not lower case 'latest' "models:Name/Stage", # missing slash "models://Name/Stage", # hostnames are ignored, path too short ], @@ -119,3 +121,8 @@ def test_get_model_name_and_version_with_latest(): "20", ) mlflow_client_mock.assert_called_once_with("AdsModel1", None) + # Check that "latest" is case insensitive. + assert get_model_name_and_version(MlflowClient(), "models:/AdsModel1/lATest") == ( + "AdsModel1", + "20", + ) diff --git a/tests/store/model_registry/test_sqlalchemy_store.py b/tests/store/model_registry/test_sqlalchemy_store.py index e285639aab966..b11a7dc90bca4 100644 --- a/tests/store/model_registry/test_sqlalchemy_store.py +++ b/tests/store/model_registry/test_sqlalchemy_store.py @@ -347,6 +347,18 @@ def test_get_latest_versions(self): ), {"Production": 3}, ) + self.assertEqual( + self._extract_latest_by_stage( + self.store.get_latest_versions(name=name, stages=["production"]) + ), + {"Production": 3}, + ) # The stages are case insensitive. + self.assertEqual( + self._extract_latest_by_stage( + self.store.get_latest_versions(name=name, stages=["pROduction"]) + ), + {"Production": 3}, + ) # The stages are case insensitive. self.assertEqual( self._extract_latest_by_stage( self.store.get_latest_versions(name=name, stages=["None", "Production"])