forked from mlflow/mlflow
-
Notifications
You must be signed in to change notification settings - Fork 0
/
models.py
77 lines (62 loc) · 2.92 KB
/
models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import urllib.parse
import mlflow.tracking
from mlflow.exceptions import MlflowException
from mlflow.utils.uri import get_databricks_profile_uri_from_artifact_uri, is_databricks_uri
from mlflow.entities.model_registry.model_version_stages import ALL_STAGES
# Literal suffix of a models:/ URI (e.g. "models:/<model_name>/latest") that selects the
# latest of all versions of a model rather than a specific version or stage.
_MODELS_URI_SUFFIX_LATEST = "latest"
def is_using_databricks_registry(uri):
    """
    Report whether the model registry associated with ``uri`` is a Databricks registry.

    A Databricks profile URI derived from ``uri`` (via
    ``get_databricks_profile_uri_from_artifact_uri``) takes precedence; when none is found,
    the globally configured registry URI from ``mlflow.get_registry_uri()`` is checked instead.

    :param uri: An artifact URI.
    :return: True if the resolved registry URI is a Databricks URI.
    """
    registry_uri = get_databricks_profile_uri_from_artifact_uri(uri)
    if not registry_uri:
        # Fall back to the process-wide registry setting only when the URI carries no profile.
        registry_uri = mlflow.get_registry_uri()
    return is_databricks_uri(registry_uri)
def _improper_model_uri_msg(uri):
    """
    Build the error message used when ``uri`` is not a valid models:/ URI.

    :param uri: The offending URI string.
    :return: A human-readable message describing the expected URI format.
    """
    template = (
        "Not a proper models:/ URI: {uri}. "
        "Models URIs must be of the form 'models:/<model_name>/suffix' "
        "where suffix is a model version, stage, or the string '{latest}'."
    )
    return template.format(uri=uri, latest=_MODELS_URI_SUFFIX_LATEST)
def _get_latest_model_version(client, name, stage):
    """
    Find the highest version number registered for the model ``name``.

    :param client: A registry client exposing ``get_latest_versions(name, stages)``.
    :param name: Registered model name.
    :param stage: If not None, only the latest version in this stage is considered;
                  if None, the latest versions across all stages are compared.
    :return: The largest version number, as an int.
    :raises MlflowException: If the client reports no matching versions.
    """
    stages = None if stage is None else [stage]
    candidates = client.get_latest_versions(name, stages)
    version_numbers = [int(model_version.version) for model_version in candidates]
    if not version_numbers:
        stage_clause = "" if stage is None else " and stage '{stage}'".format(stage=stage)
        raise MlflowException(
            "No versions of model with name '{name}'{stage_str} found".format(
                name=name, stage_str=stage_clause
            )
        )
    return max(version_numbers)
def _parse_model_uri(uri):
    """
    Returns (name, version, stage). Since a models:/ URI can only have one of
    {version, stage, 'latest'}, it will return
    - (name, version, None) to look for a specific version,
    - (name, None, stage) to look for the latest version of a stage,
    - (name, None, None) to look for the latest of all versions.

    :param uri: A URI of the form ``models:/<model_name>/<suffix>``, where suffix is a
                version number, a stage name from ALL_STAGES, or the string 'latest'.
    :return: A ``(name, version, stage)`` tuple as described above; ``version`` is an int.
    :raises MlflowException: If ``uri`` does not have the models scheme, does not have
                             exactly two non-empty-name path components, or its suffix is
                             not a version, a known stage, or 'latest'.
    """
    parsed = urllib.parse.urlparse(uri)
    if parsed.scheme != "models":
        raise MlflowException(_improper_model_uri_msg(uri))

    path = parsed.path
    if not path.startswith("/") or len(path) <= 1:
        raise MlflowException(_improper_model_uri_msg(uri))

    parts = path[1:].split("/")
    if len(parts) != 2 or parts[0].strip() == "":
        raise MlflowException(_improper_model_uri_msg(uri))

    name, suffix = parts
    # isdecimal() instead of isdigit(): isdigit() also accepts characters such as
    # superscripts ("²") that int() cannot parse, which used to escape as a raw
    # ValueError instead of the MlflowException callers expect.
    if suffix.isdecimal():
        # The suffix is a specific version, e.g. "models:/AdsModel1/123"
        return name, int(suffix), None
    if suffix == _MODELS_URI_SUFFIX_LATEST:
        # The suffix is exactly the 'latest' string, e.g. "models:/AdsModel1/latest"
        return name, None, None
    if suffix not in ALL_STAGES:
        raise MlflowException(_improper_model_uri_msg(uri))
    # The suffix is a specific stage, e.g. "models:/AdsModel1/Production"
    return name, None, suffix
def get_model_name_and_version(client, models_uri):
    """
    Resolve a models:/ URI to a concrete (model name, version) pair.

    If the URI names an explicit version it is used directly; otherwise the latest version
    (optionally restricted to the URI's stage) is looked up through ``client``.

    :param client: A registry client exposing ``get_latest_versions(name, stages)``.
    :param models_uri: A URI of the form ``models:/<model_name>/<suffix>``.
    :return: ``(model_name, version)`` where ``version`` is a string.
    """
    name, version, stage = _parse_model_uri(models_uri)
    if version is None:
        # No explicit version in the URI: ask the registry for the newest matching one.
        version = _get_latest_model_version(client, name, stage)
    return name, str(version)