New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support specifying 'latest' in model URI to get the latest version of a model regardless of the stage #5027
Changes from all commits
c75afe4
3500f59
0bfbf1d
d02fe3a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,6 +3,9 @@ | |
import mlflow.tracking | ||
from mlflow.exceptions import MlflowException | ||
from mlflow.utils.uri import get_databricks_profile_uri_from_artifact_uri, is_databricks_uri | ||
from mlflow.entities.model_registry.model_version_stages import ALL_STAGES | ||
|
||
_MODELS_URI_SUFFIX_LATEST = "latest" | ||
|
||
|
||
def is_using_databricks_registry(uri): | ||
|
@@ -13,24 +16,34 @@ def is_using_databricks_registry(uri): | |
def _improper_model_uri_msg(uri): | ||
return ( | ||
"Not a proper models:/ URI: %s. " % uri | ||
+ "Models URIs must be of the form 'models:/<model_name>/<version or stage>'." | ||
+ "Models URIs must be of the form 'models:/<model_name>/suffix' " | ||
+ "where suffix is a model version, stage, or the string '%s'." % _MODELS_URI_SUFFIX_LATEST | ||
) | ||
|
||
|
||
def _get_model_version_from_stage(client, name, stage): | ||
latest = client.get_latest_versions(name, [stage]) | ||
def _get_latest_model_version(client, name, stage): | ||
""" | ||
Returns the latest version of the stage if stage is not None. Otherwise return the latest of all | ||
versions. | ||
""" | ||
latest = client.get_latest_versions(name, None if stage is None else [stage]) | ||
if len(latest) == 0: | ||
stage_str = "" if stage is None else " and stage '{stage}'".format(stage=stage) | ||
raise MlflowException( | ||
"No versions of model with name '{name}' and " | ||
"stage '{stage}' found".format(name=name, stage=stage) | ||
"No versions of model with name '{name}'{stage_str} found".format( | ||
name=name, stage_str=stage_str | ||
) | ||
) | ||
return latest[0].version | ||
return max(map(lambda x: int(x.version), latest)) | ||
|
||
|
||
def _parse_model_uri(uri): | ||
""" | ||
Returns (name, version, stage). Since a models:/ URI can only have one of {version, stage}, | ||
it will return (name, version, None) or (name, None, stage). | ||
Returns (name, version, stage). Since a models:/ URI can only have one of | ||
{version, stage, 'latest'}, it will return | ||
- (name, version, None) to look for a specific version, | ||
- (name, None, stage) to look for the latest version of a stage, | ||
- (name, None, None) to look for the latest of all versions. | ||
""" | ||
parsed = urllib.parse.urlparse(uri) | ||
if parsed.scheme != "models": | ||
|
@@ -45,13 +58,20 @@ def _parse_model_uri(uri): | |
raise MlflowException(_improper_model_uri_msg(uri)) | ||
|
||
if parts[1].isdigit(): | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Given the added complexity, it may be good to have an example of each URI type in the branch so that it's clear exactly which case maps to which tuple. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Done, thanks! |
||
# The suffix is a specific version, e.g. "models:/AdsModel1/123" | ||
return parts[0], int(parts[1]), None | ||
elif parts[1] == _MODELS_URI_SUFFIX_LATEST: | ||
# The suffix is exactly the 'latest' string, e.g. "models:/AdsModel1/latest" | ||
return parts[0], None, None | ||
elif parts[1] not in ALL_STAGES: | ||
raise MlflowException(_improper_model_uri_msg(uri)) | ||
else: | ||
# The suffix is a specific stage, e.g. "models:/AdsModel1/Production" | ||
return parts[0], None, parts[1] | ||
|
||
|
||
def get_model_name_and_version(client, models_uri): | ||
(model_name, model_version, model_stage) = _parse_model_uri(models_uri) | ||
if model_stage is not None: | ||
model_version = _get_model_version_from_stage(client, model_name, model_stage) | ||
return model_name, str(model_version) | ||
if model_version is not None: | ||
return model_name, str(model_version) | ||
return model_name, str(_get_latest_model_version(client, model_name, model_stage)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just curious - why do we need to call `int()` here? I'm not opposed, just for safety reasons, but `x.version` is already a number, right?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actually, the model version is a str (link), so it's safer to convert it to an int here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Got it