Skip to content

Commit

Permalink
Enable mlflow-artifacts scheme as wrapper around http artifact scheme (
Browse files Browse the repository at this point in the history
…#5070)

* Enable mlflow-artifacts scheme as wrapper around http artifact scheme

Signed-off-by: Ben Wilson <benjamin.wilson@databricks.com>
  • Loading branch information
BenWilson2 committed Nov 24, 2021
1 parent 94db108 commit 776831a
Show file tree
Hide file tree
Showing 4 changed files with 382 additions and 1 deletion.
2 changes: 2 additions & 0 deletions mlflow/store/artifact/artifact_repository_registry.py
Expand Up @@ -13,6 +13,7 @@
from mlflow.store.artifact.s3_artifact_repo import S3ArtifactRepository
from mlflow.store.artifact.sftp_artifact_repo import SFTPArtifactRepository
from mlflow.store.artifact.http_artifact_repo import HttpArtifactRepository
from mlflow.store.artifact.mlflow_artifacts_repo import MlflowArtifactsRepository

from mlflow.utils.uri import get_uri_scheme

Expand Down Expand Up @@ -88,6 +89,7 @@ def get_artifact_repository(self, artifact_uri):
_artifact_repository_registry.register("models", ModelsArtifactRepository)
for scheme in ["http", "https"]:
_artifact_repository_registry.register(scheme, HttpArtifactRepository)
_artifact_repository_registry.register("mlflow-artifacts", MlflowArtifactsRepository)

_artifact_repository_registry.register_entrypoints()

Expand Down
3 changes: 2 additions & 1 deletion mlflow/store/artifact/http_artifact_repo.py
Expand Up @@ -15,7 +15,8 @@ def __init__(self, artifact_uri):
self._session = requests.Session()

def __del__(self):
self._session.close()
if hasattr(self, "_session"):
self._session.close()

def log_artifact(self, local_file, artifact_path=None):
verify_artifact_path(artifact_path)
Expand Down
96 changes: 96 additions & 0 deletions mlflow/store/artifact/mlflow_artifacts_repo.py
@@ -0,0 +1,96 @@
from urllib.parse import urlparse
from collections import namedtuple

from mlflow.store.artifact.http_artifact_repo import HttpArtifactRepository
from mlflow.tracking._tracking_service.utils import get_tracking_uri
from mlflow.exceptions import MlflowException


def _parse_artifact_uri(artifact_uri):
ParsedURI = namedtuple("ParsedURI", "scheme host port path")
parsed_uri = urlparse(artifact_uri)
return ParsedURI(parsed_uri.scheme, parsed_uri.hostname, parsed_uri.port, parsed_uri.path)


def _check_if_host_is_numeric(hostname):
if hostname:
try:
float(hostname)
return True
except ValueError:
return False
else:
return False


def _validate_port_mapped_to_hostname(uri_parse):
# This check is to catch an mlflow-artifacts uri that has a port designated but no
# hostname specified. `urllib.parse.urlparse` will treat such a uri as a filesystem
# definition, mapping the provided port as a hostname value if this condition is not
# validated.
if uri_parse.host and _check_if_host_is_numeric(uri_parse.host) and not uri_parse.port:
raise MlflowException(
"The mlflow-artifacts uri was supplied with a port number: "
f"{uri_parse.host}, but no host was defined."
)


def _validate_uri_scheme(scheme):
allowable_schemes = {"http", "https"}
if scheme not in allowable_schemes:
raise MlflowException(
f"The configured tracking uri scheme: '{scheme}' is invalid for use with the proxy "
f"mlflow-artifact scheme. The allowed tracking schemes are: {allowable_schemes}"
)


class MlflowArtifactsRepository(HttpArtifactRepository):
"""Scheme wrapper around HttpArtifactRepository for mlflow-artifacts server functionality"""

def __init__(self, artifact_uri):

super().__init__(self.resolve_uri(artifact_uri))

@classmethod
def resolve_uri(cls, artifact_uri):

base_url = "/api/2.0/mlflow-artifacts/artifacts"
tracking_uri = get_tracking_uri()

track_parse = _parse_artifact_uri(tracking_uri)

uri_parse = _parse_artifact_uri(artifact_uri)

# Check to ensure that a port is present with no hostname
_validate_port_mapped_to_hostname(uri_parse)

# Check that tracking uri is http or https
_validate_uri_scheme(track_parse.scheme)

if uri_parse.path == "/": # root directory; build simple path
resolved = f"{base_url}{uri_parse.path}"
elif uri_parse.path == base_url: # for operations like list artifacts
resolved = base_url
else:
resolved = f"{base_url}{track_parse.path}{uri_parse.path.lstrip('/')}"

if uri_parse.host and uri_parse.port:
resolved_artifacts_uri = (
f"{track_parse.scheme}://{uri_parse.host}:{uri_parse.port}{resolved}"
)
elif uri_parse.host and not uri_parse.port:
resolved_artifacts_uri = f"{track_parse.scheme}://{uri_parse.host}{resolved}"
elif not uri_parse.host and not uri_parse.port and uri_parse.path == track_parse.path:
resolved_artifacts_uri = (
f"{track_parse.scheme}://{track_parse.host}:{track_parse.port}{resolved}"
)
elif not uri_parse.host and not uri_parse.port:
resolved_artifacts_uri = (
f"{track_parse.scheme}://{track_parse.host}:" f"{track_parse.port}{resolved}"
)
else:
raise MlflowException(
f"The supplied artifact uri {artifact_uri} could not be resolved."
)

return resolved_artifacts_uri.replace("///", "/").rstrip("/")

0 comments on commit 776831a

Please sign in to comment.