From 776831a35c6460581c5b6642795926ebcb8b8624 Mon Sep 17 00:00:00 2001 From: Ben Wilson <39283302+BenWilson2@users.noreply.github.com> Date: Wed, 24 Nov 2021 10:27:32 -0500 Subject: [PATCH] Enable mlflow-artifacts scheme as wrapper around http artifact scheme (#5070) * Enable mlflow-artifacts scheme as wrapper around http artifact scheme Signed-off-by: Ben Wilson --- NOTE(review, commentary only - not part of the commit message or of any hunk): (1) in mlflow_artifacts_repo.py, resolve_uri falls back to "{track_parse.host}:{track_parse.port}" when the artifact URI carries no host; urlparse reports the port as None when the tracking URI has no explicit port, so such a tracking URI would presumably render as "host:None" - confirm and guard. (2) the non-root branch of resolve_uri concatenates base_url, track_parse.path and the artifact path without adding a separator of its own, so it appears to rely on the tracking URI path ending in "/" (the test fixture uses "http://localhost:5000/") - confirm the behaviour for a tracking URI without a trailing slash. (3) _check_if_host_is_numeric relies on float(), which also accepts strings such as "nan", "inf" and "1e3" - confirm whether a digits-only check was intended. .../artifact/artifact_repository_registry.py | 2 + mlflow/store/artifact/http_artifact_repo.py | 3 +- .../store/artifact/mlflow_artifacts_repo.py | 96 ++++++ .../artifact/test_mlflow_artifact_repo.py | 282 ++++++++++++++++++ 4 files changed, 382 insertions(+), 1 deletion(-) create mode 100644 mlflow/store/artifact/mlflow_artifacts_repo.py create mode 100644 tests/store/artifact/test_mlflow_artifact_repo.py diff --git a/mlflow/store/artifact/artifact_repository_registry.py b/mlflow/store/artifact/artifact_repository_registry.py index bbd91d061b916..924af82a485d8 100644 --- a/mlflow/store/artifact/artifact_repository_registry.py +++ b/mlflow/store/artifact/artifact_repository_registry.py @@ -13,6 +13,7 @@ from mlflow.store.artifact.s3_artifact_repo import S3ArtifactRepository from mlflow.store.artifact.sftp_artifact_repo import SFTPArtifactRepository from mlflow.store.artifact.http_artifact_repo import HttpArtifactRepository +from mlflow.store.artifact.mlflow_artifacts_repo import MlflowArtifactsRepository from mlflow.utils.uri import get_uri_scheme @@ -88,6 +89,7 @@ def get_artifact_repository(self, artifact_uri): _artifact_repository_registry.register("models", ModelsArtifactRepository) for scheme in ["http", "https"]: _artifact_repository_registry.register(scheme, HttpArtifactRepository) +_artifact_repository_registry.register("mlflow-artifacts", MlflowArtifactsRepository) _artifact_repository_registry.register_entrypoints() diff --git a/mlflow/store/artifact/http_artifact_repo.py b/mlflow/store/artifact/http_artifact_repo.py index ccdb9b03ffb11..f2dc2b4036ac5 100644 --- 
a/mlflow/store/artifact/http_artifact_repo.py +++ b/mlflow/store/artifact/http_artifact_repo.py @@ -15,7 +15,8 @@ def __init__(self, artifact_uri): self._session = requests.Session() def __del__(self): - self._session.close() + if hasattr(self, "_session"): + self._session.close() def log_artifact(self, local_file, artifact_path=None): verify_artifact_path(artifact_path) diff --git a/mlflow/store/artifact/mlflow_artifacts_repo.py b/mlflow/store/artifact/mlflow_artifacts_repo.py new file mode 100644 index 0000000000000..1d624d5ade894 --- /dev/null +++ b/mlflow/store/artifact/mlflow_artifacts_repo.py @@ -0,0 +1,96 @@ +from urllib.parse import urlparse +from collections import namedtuple + +from mlflow.store.artifact.http_artifact_repo import HttpArtifactRepository +from mlflow.tracking._tracking_service.utils import get_tracking_uri +from mlflow.exceptions import MlflowException + + +def _parse_artifact_uri(artifact_uri): + ParsedURI = namedtuple("ParsedURI", "scheme host port path") + parsed_uri = urlparse(artifact_uri) + return ParsedURI(parsed_uri.scheme, parsed_uri.hostname, parsed_uri.port, parsed_uri.path) + + +def _check_if_host_is_numeric(hostname): + if hostname: + try: + float(hostname) + return True + except ValueError: + return False + else: + return False + + +def _validate_port_mapped_to_hostname(uri_parse): + # This check is to catch an mlflow-artifacts uri that has a port designated but no + # hostname specified. `urllib.parse.urlparse` will treat such a uri as a filesystem + # definition, mapping the provided port as a hostname value if this condition is not + # validated. + if uri_parse.host and _check_if_host_is_numeric(uri_parse.host) and not uri_parse.port: + raise MlflowException( + "The mlflow-artifacts uri was supplied with a port number: " + f"{uri_parse.host}, but no host was defined." 
+ ) + + +def _validate_uri_scheme(scheme): + allowable_schemes = {"http", "https"} + if scheme not in allowable_schemes: + raise MlflowException( + f"The configured tracking uri scheme: '{scheme}' is invalid for use with the proxy " + f"mlflow-artifact scheme. The allowed tracking schemes are: {allowable_schemes}" + ) + + +class MlflowArtifactsRepository(HttpArtifactRepository): + """Scheme wrapper around HttpArtifactRepository for mlflow-artifacts server functionality""" + + def __init__(self, artifact_uri): + + super().__init__(self.resolve_uri(artifact_uri)) + + @classmethod + def resolve_uri(cls, artifact_uri): + + base_url = "/api/2.0/mlflow-artifacts/artifacts" + tracking_uri = get_tracking_uri() + + track_parse = _parse_artifact_uri(tracking_uri) + + uri_parse = _parse_artifact_uri(artifact_uri) + + # Reject artifact URIs that supply a port but no hostname + _validate_port_mapped_to_hostname(uri_parse) + + # Check that tracking uri is http or https + _validate_uri_scheme(track_parse.scheme) + + if uri_parse.path == "/": # root directory; build simple path + resolved = f"{base_url}{uri_parse.path}" + elif uri_parse.path == base_url: # for operations like list artifacts + resolved = base_url + else: + resolved = f"{base_url}{track_parse.path}{uri_parse.path.lstrip('/')}" + + if uri_parse.host and uri_parse.port: + resolved_artifacts_uri = ( + f"{track_parse.scheme}://{uri_parse.host}:{uri_parse.port}{resolved}" + ) + elif uri_parse.host and not uri_parse.port: + resolved_artifacts_uri = f"{track_parse.scheme}://{uri_parse.host}{resolved}" + elif not uri_parse.host and not uri_parse.port and uri_parse.path == track_parse.path: + resolved_artifacts_uri = ( + f"{track_parse.scheme}://{track_parse.host}:{track_parse.port}{resolved}" + ) + elif not uri_parse.host and not uri_parse.port: + resolved_artifacts_uri = ( + f"{track_parse.scheme}://{track_parse.host}:" f"{track_parse.port}{resolved}" + ) + else: + raise MlflowException( + f"The supplied artifact uri 
{artifact_uri} could not be resolved." + ) + + return resolved_artifacts_uri.replace("///", "/").rstrip("/") diff --git a/tests/store/artifact/test_mlflow_artifact_repo.py b/tests/store/artifact/test_mlflow_artifact_repo.py new file mode 100644 index 0000000000000..2e57bfe961407 --- /dev/null +++ b/tests/store/artifact/test_mlflow_artifact_repo.py @@ -0,0 +1,282 @@ +import os +from unittest import mock +import posixpath +import pytest + +from mlflow.store.artifact.mlflow_artifacts_repo import MlflowArtifactsRepository +from mlflow.store.artifact.artifact_repository_registry import get_artifact_repository +from mlflow.exceptions import MlflowException + + +@pytest.fixture(scope="module", autouse=True) +def set_tracking_uri(): + with mock.patch( + "mlflow.store.artifact.mlflow_artifacts_repo.get_tracking_uri", + return_value="http://localhost:5000/", + ): + yield + + +def test_artifact_uri_factory(): + repo = get_artifact_repository("mlflow-artifacts://test.com") + assert isinstance(repo, MlflowArtifactsRepository) + + +def test_mlflow_artifact_uri_formats_resolved(): + base_url = "/api/2.0/mlflow-artifacts/artifacts" + base_path = "/my/artifact/path" + conditions = [ + ( + f"mlflow-artifacts://myhostname:4242{base_path}/hostport", + f"http://myhostname:4242{base_url}{base_path}/hostport", + ), + ( + f"mlflow-artifacts://myhostname{base_path}/host", + f"http://myhostname{base_url}{base_path}/host", + ), + ( + f"mlflow-artifacts:{base_path}/nohost", + f"http://localhost:5000{base_url}{base_path}/nohost", + ), + ( + f"mlflow-artifacts://{base_path}/redundant", + f"http://localhost:5000{base_url}{base_path}/redundant", + ), + ( + "mlflow-artifacts:/", + f"http://localhost:5000{base_url}", + ), + ] + failing_conditions = [f"mlflow-artifacts://5000/{base_path}", "mlflow-artifacts://5000/"] + + for submit, resolved in conditions: + artifact_repo = MlflowArtifactsRepository(submit) + assert artifact_repo.resolve_uri(submit) == resolved + for failing_condition in 
failing_conditions: + with pytest.raises( + MlflowException, + match="The mlflow-artifacts uri was supplied with a port number: 5000, but no " + "host was defined.", + ): + MlflowArtifactsRepository(failing_condition) + + +class MockResponse: + def __init__(self, data, status_code): + self.data = data + self.status_code = status_code + + def json(self): + return self.data + + def raise_for_status(self): + if self.status_code >= 400: + raise Exception("request failed") + + +class MockStreamResponse(MockResponse): + def iter_content(self, chunk_size): # pylint: disable=unused-argument + yield self.data.encode("utf-8") + + def __enter__(self): + return self + + def __exit__(self, *exc): + pass + + +class FileObjectMatcher: + def __init__(self, name, mode): + self.name = name + self.mode = mode + + def __eq__(self, other): + return self.name == other.name and self.mode == other.mode + + +@pytest.fixture +def mlflow_artifact_repo(): + artifact_uri = "mlflow-artifacts:/api/2.0/mlflow-artifacts/artifacts" + return MlflowArtifactsRepository(artifact_uri) + + +@pytest.fixture +def mlflow_artifact_repo_with_host(): + artifact_uri = "mlflow-artifacts://test.com:5000/api/2.0/mlflow-artifacts/artifacts" + return MlflowArtifactsRepository(artifact_uri) + + +@pytest.mark.parametrize("artifact_path", [None, "dir", "path/to/artifacts/storage"]) +def test_log_artifact(mlflow_artifact_repo, tmpdir, artifact_path): + tmp_path = tmpdir.join("a.txt") + tmp_path.write("0") + with mock.patch("requests.Session.put", return_value=MockResponse({}, 200)) as mock_put: + mlflow_artifact_repo.log_artifact(tmp_path, artifact_path) + paths = (artifact_path,) if artifact_path else () + expected_url = posixpath.join(mlflow_artifact_repo.artifact_uri, *paths, tmp_path.basename) + mock_put.assert_called_once_with( + expected_url, data=FileObjectMatcher(tmp_path, "rb"), timeout=mock.ANY + ) + + with mock.patch("requests.Session.put", return_value=MockResponse({}, 400)) as mock_put: + with 
pytest.raises(Exception, match="request failed"): + mlflow_artifact_repo.log_artifact(tmp_path, artifact_path) + + +@pytest.mark.parametrize("artifact_path", [None, "dir", "path/to/artifacts/storage"]) +def test_log_artifact_with_host_and_port(mlflow_artifact_repo_with_host, tmpdir, artifact_path): + tmp_path = tmpdir.join("a.txt") + tmp_path.write("0") + with mock.patch("requests.Session.put", return_value=MockResponse({}, 200)) as mock_put: + mlflow_artifact_repo_with_host.log_artifact(tmp_path, artifact_path) + paths = (artifact_path,) if artifact_path else () + expected_url = posixpath.join( + mlflow_artifact_repo_with_host.artifact_uri, *paths, tmp_path.basename + ) + mock_put.assert_called_once_with( + expected_url, data=FileObjectMatcher(tmp_path, "rb"), timeout=mock.ANY + ) + + with mock.patch("requests.Session.put", return_value=MockResponse({}, 400)) as mock_put: + with pytest.raises(Exception, match="request failed"): + mlflow_artifact_repo_with_host.log_artifact(tmp_path, artifact_path) + + +@pytest.mark.parametrize("artifact_path", [None, "dir", "path/to/artifacts/storage"]) +def test_log_artifacts(mlflow_artifact_repo, tmpdir, artifact_path): + tmp_path_a = tmpdir.join("a.txt") + tmp_path_b = tmpdir.mkdir("dir").join("b.txt") + tmp_path_a.write("0") + tmp_path_b.write("1") + + with mock.patch("requests.Session.put", return_value=MockResponse({}, 200)) as mock_put: + mlflow_artifact_repo.log_artifacts(tmpdir, artifact_path) + paths = (artifact_path,) if artifact_path else () + expected_url_1 = posixpath.join( + mlflow_artifact_repo.artifact_uri, *paths, tmp_path_a.basename + ) + expected_url_2 = posixpath.join( + mlflow_artifact_repo.artifact_uri, *paths, "dir", tmp_path_b.basename + ) + calls = [(args[0], kwargs["data"]) for args, kwargs in mock_put.call_args_list] + assert calls == [ + (expected_url_1, FileObjectMatcher(tmp_path_a, "rb")), + (expected_url_2, FileObjectMatcher(tmp_path_b, "rb")), + ] + + with mock.patch("requests.Session.put", 
return_value=MockResponse({}, 400)) as mock_put: + with pytest.raises(Exception, match="request failed"): + mlflow_artifact_repo.log_artifacts(tmpdir, artifact_path) + + +def test_list_artifacts(mlflow_artifact_repo): + with mock.patch("requests.Session.get", return_value=MockResponse({}, 200)) as mock_get: + assert mlflow_artifact_repo.list_artifacts() == [] + mock_get.assert_called_once_with( + mlflow_artifact_repo.artifact_uri, params={"path": ""}, timeout=mock.ANY + ) + + with mock.patch( + "requests.Session.get", + return_value=MockResponse( + { + "files": [ + {"path": "1.txt", "is_dir": False, "file_size": 1}, + {"path": "dir", "is_dir": True}, + ] + }, + 200, + ), + ) as mock_get: + assert [a.path for a in mlflow_artifact_repo.list_artifacts()] == ["1.txt", "dir"] + + with mock.patch( + "requests.Session.get", + return_value=MockResponse( + { + "files": [ + {"path": "1.txt", "is_dir": False, "file_size": 1}, + {"path": "dir", "is_dir": True}, + ] + }, + 200, + ), + ) as mock_get: + assert [a.path for a in mlflow_artifact_repo.list_artifacts(path="path")] == [ + "path/1.txt", + "path/dir", + ] + + with mock.patch("requests.Session.get", return_value=MockResponse({}, 400)) as mock_get: + with pytest.raises(Exception, match="request failed"): + mlflow_artifact_repo.list_artifacts() + + +def read_file(path): + with open(path) as f: + return f.read() + + +@pytest.mark.parametrize("remote_file_path", ["a.txt", "dir/b.xtx"]) +def test_download_file(mlflow_artifact_repo, tmpdir, remote_file_path): + with mock.patch( + "requests.Session.get", return_value=MockStreamResponse("data", 200) + ) as mock_get: + tmp_path = tmpdir.join(posixpath.basename(remote_file_path)) + mlflow_artifact_repo._download_file(remote_file_path, tmp_path) + expected_url = posixpath.join(mlflow_artifact_repo.artifact_uri, remote_file_path) + mock_get.assert_called_once_with(expected_url, stream=True, timeout=mock.ANY) + with open(tmp_path) as f: + assert f.read() == "data" + + with mock.patch( 
+ "requests.Session.get", return_value=MockStreamResponse("data", 400) + ) as mock_get: + with pytest.raises(Exception, match="request failed"): + mlflow_artifact_repo._download_file(remote_file_path, tmp_path) + + +def test_download_artifacts(mlflow_artifact_repo, tmpdir): + # This test simulates downloading artifacts in the following structure: + # --------- + # - a.txt + # - dir + # - b.txt + # --------- + side_effect = [ + # Response for `list_artifacts("")` called by `_is_directory("")` + MockResponse( + { + "files": [ + {"path": "a.txt", "is_dir": False, "file_size": 6}, + {"path": "dir", "is_dir": True}, + ] + }, + 200, + ), + # Response for `list_artifacts("")` + MockResponse( + { + "files": [ + {"path": "a.txt", "is_dir": False, "file_size": 6}, + {"path": "dir", "is_dir": True}, + ] + }, + 200, + ), + # Response for `_download_file("a.txt")` + MockStreamResponse("data_a", 200), + # Response for `list_artifacts("dir")` + MockResponse({"files": [{"path": "b.txt", "is_dir": False, "file_size": 1}]}, 200), + # Response for `_download_file("dir/b.txt")` + MockStreamResponse("data_b", 200), + ] + with mock.patch("requests.Session.get", side_effect=side_effect): + mlflow_artifact_repo.download_artifacts("", tmpdir) + paths = [os.path.join(root, f) for root, _, files in os.walk(tmpdir) for f in files] + assert [os.path.relpath(p, tmpdir) for p in paths] == [ + "a.txt", + os.path.join("dir", "b.txt"), + ] + assert read_file(paths[0]) == "data_a" + assert read_file(paths[1]) == "data_b"