From fb2972fd0d4e6eb9c8d6050d75a6f6c1c56427a5 Mon Sep 17 00:00:00 2001 From: Ben Wilson <39283302+BenWilson2@users.noreply.github.com> Date: Tue, 30 Nov 2021 20:44:13 -0500 Subject: [PATCH] Add server option for serving only artifacts and proxied serving mode (#5045) * Add --serve-artifacts-opt and --artifacts-only options to mlflow server Signed-off-by: Ben Wilson --- examples/mlflow_artifacts/README.md | 3 + examples/mlflow_artifacts/docker-compose.yml | 2 + examples/mlflow_artifacts/example.py | 4 + mlflow/azure/client.py | 4 +- mlflow/cli.py | 77 +++++++++------ mlflow/projects/utils.py | 4 +- mlflow/server/__init__.py | 8 ++ mlflow/server/handlers.py | 93 +++++++++++++++++++ .../artifact/databricks_artifact_repo.py | 5 +- mlflow/store/artifact/http_artifact_repo.py | 7 +- .../store/artifact/mlflow_artifacts_repo.py | 9 +- mlflow/store/tracking/__init__.py | 6 ++ mlflow/utils/cli_args.py | 12 +++ mlflow/utils/file_utils.py | 4 +- mlflow/utils/rest_utils.py | 12 +++ mlflow/utils/uri.py | 22 +++++ .../artifact/test_mlflow_artifact_repo.py | 59 ++++++------ tests/test_cli.py | 61 +++++++++++- tests/tracking/test_mlflow_artifacts.py | 6 +- 19 files changed, 326 insertions(+), 72 deletions(-) diff --git a/examples/mlflow_artifacts/README.md b/examples/mlflow_artifacts/README.md index 596446249b2b7..39105c7f1edf3 100644 --- a/examples/mlflow_artifacts/README.md +++ b/examples/mlflow_artifacts/README.md @@ -16,6 +16,7 @@ First, launch the tracking server with the artifacts service via `mlflow server` ```sh # Launch a tracking server with the artifacts service $ mlflow server \ + --serve-artifacts \ --artifacts-destination ./mlartifacts \ --default-artifact-root http://localhost:5000/api/2.0/mlflow-artifacts/artifacts/experiments \ --gunicorn-opts "--log-level debug" @@ -23,9 +24,11 @@ $ mlflow server \ Notes: +- `--serve-artifacts` enables the MLflow Artifacts service endpoints to enable proxied serving of artifacts through the REST API - `--artifacts-destination` specifies the base artifact location from which to resolve artifact upload/download/list requests. In this examples, we're using a local directory `./mlartifacts`, but it can be changed to a s3 bucket or - `--default-artifact-root` points to the `experiments` directory of the artifacts service. Therefore, the default artifact location of a newly-created experiment is set to `./mlartifacts/experiments/`. - `--gunicorn-opts "--log-level debug"` is specified to print out request logs but can be omitted if unnecessary. +- `--artifacts-only` disables all other endpoints for the tracking server apart from those involved in listing, uploading, and downloading artifacts. This makes the MLflow server a single-purpose proxy for artifact handling only. Then, run `example.py` that performs upload, download, and list operations for artifacts: diff --git a/examples/mlflow_artifacts/docker-compose.yml b/examples/mlflow_artifacts/docker-compose.yml index db951d3e1503e..a700be6107410 100644 --- a/examples/mlflow_artifacts/docker-compose.yml +++ b/examples/mlflow_artifacts/docker-compose.yml @@ -54,6 +54,8 @@ services: --port 5500 --artifacts-destination s3://bucket --gunicorn-opts "--log-level debug" + --serve-artifacts + --artifacts-only postgres: image: postgres diff --git a/examples/mlflow_artifacts/example.py b/examples/mlflow_artifacts/example.py index 16dc8f31e830a..8e9032875793a 100644 --- a/examples/mlflow_artifacts/example.py +++ b/examples/mlflow_artifacts/example.py @@ -10,6 +10,10 @@ def save_text(path, text): f.write(text) +# NOTE: ensure the tracking server has been started with --serve-artifacts to enable +# MLflow artifact serving functionality. + + def main(): assert "MLFLOW_TRACKING_URI" in os.environ diff --git a/mlflow/azure/client.py b/mlflow/azure/client.py index 81f5ffe49f3ba..e4c94077affed 100644 --- a/mlflow/azure/client.py +++ b/mlflow/azure/client.py @@ -38,7 +38,7 @@ def put_block(sas_url, block_id, data, headers): with rest_utils.cloud_storage_http_request( "put", request_url, data=data, headers=request_headers ) as response: - response.raise_for_status() + rest_utils.augmented_raise_for_status(response) def put_block_list(sas_url, block_list, headers): @@ -66,7 +66,7 @@ def put_block_list(sas_url, block_list, headers): with rest_utils.cloud_storage_http_request( "put", request_url, data=data, headers=request_headers ) as response: - response.raise_for_status() + rest_utils.augmented_raise_for_status(response) def _append_query_parameters(url, parameters): diff --git a/mlflow/cli.py b/mlflow/cli.py index d91793629a0af..55154ddda520b 100644 --- a/mlflow/cli.py +++ b/mlflow/cli.py @@ -13,14 +13,14 @@ import mlflow.runs import mlflow.store.artifact.cli from mlflow import tracking -from mlflow.store.tracking import DEFAULT_LOCAL_FILE_AND_ARTIFACT_PATH +from mlflow.store.tracking import DEFAULT_LOCAL_FILE_AND_ARTIFACT_PATH, DEFAULT_ARTIFACTS_URI from mlflow.store.artifact.artifact_repository_registry import get_artifact_repository from mlflow.tracking import _get_store from mlflow.utils import cli_args from mlflow.utils.annotations import experimental from mlflow.utils.logging_utils import eprint from mlflow.utils.process import ShellCommandException -from mlflow.utils.uri import is_local_uri +from mlflow.utils.uri import resolve_default_artifact_root from mlflow.entities.lifecycle_stage import LifecycleStage from mlflow.exceptions import MlflowException @@ -233,20 +233,27 @@ def _validate_server_args(gunicorn_opts=None, workers=None, waitress_opts=None): "SQLAlchemy-compatible database connection strings " "(e.g. 'sqlite:///path/to/file.db') or local filesystem URIs " "(e.g. 'file:///absolute/path/to/directory'). By default, data will be logged " - "to the ./mlruns directory.", + f"to {DEFAULT_LOCAL_FILE_AND_ARTIFACT_PATH}", ) @click.option( "--default-artifact-root", metavar="URI", default=None, - help="Path to local directory to store artifacts, for new experiments. " - "Note that this flag does not impact already-created experiments. " - "Default: " + DEFAULT_LOCAL_FILE_AND_ARTIFACT_PATH, + help="Directory in which to store artifacts for any new experiments created. For tracking " + "server backends that rely on SQL, this option is required in order to store artifacts. " + "Note that this flag does not impact already-created experiments with any previous " + "configuration of an MLflow server instance. " + "If the --serve-artifacts option is specified, the default artifact root is " + f"{DEFAULT_ARTIFACTS_URI}. Otherwise, the default artifact root is " + f"{DEFAULT_LOCAL_FILE_AND_ARTIFACT_PATH}.", ) +@cli_args.SERVE_ARTIFACTS @cli_args.ARTIFACTS_DESTINATION @cli_args.PORT @cli_args.HOST -def ui(backend_store_uri, default_artifact_root, artifacts_destination, port, host): +def ui( + backend_store_uri, default_artifact_root, serve_artifacts, artifacts_destination, port, host +): """ Launch the MLflow tracking UI for local viewing of run results. To launch a production server, use the "mlflow server" command instead. @@ -263,11 +270,9 @@ def ui(backend_store_uri, default_artifact_root, artifacts_destination, port, ho if not backend_store_uri: backend_store_uri = DEFAULT_LOCAL_FILE_AND_ARTIFACT_PATH - if not default_artifact_root: - if is_local_uri(backend_store_uri): - default_artifact_root = backend_store_uri - else: - default_artifact_root = DEFAULT_LOCAL_FILE_AND_ARTIFACT_PATH + default_artifact_root = resolve_default_artifact_root( + serve_artifacts, default_artifact_root, backend_store_uri, resolve_to_local=True + ) try: initialize_backend_stores(backend_store_uri, default_artifact_root) @@ -279,7 +284,15 @@ def ui(backend_store_uri, default_artifact_root, artifacts_destination, port, ho # TODO: We eventually want to disable the write path in this version of the server. try: _run_server( - backend_store_uri, default_artifact_root, artifacts_destination, host, port, None, 1 + backend_store_uri, + default_artifact_root, + serve_artifacts, + False, + artifacts_destination, + host, + port, + None, + 1, ) except ShellCommandException: eprint("Running the mlflow server failed. Please see the logs above for details.") @@ -315,10 +328,24 @@ def _validate_static_prefix(ctx, param, value): # pylint: disable=unused-argume "--default-artifact-root", metavar="URI", default=None, - help="Local or S3 URI to store artifacts, for new experiments. " - "Note that this flag does not impact already-created experiments. " - "Default: Within file store, if a file:/ URI is provided. If a sql backend is" - " used, then this option is required.", + help="Directory in which to store artifacts for any new experiments created. For tracking " + "server backends that rely on SQL, this option is required in order to store artifacts. " + "Note that this flag does not impact already-created experiments with any previous " + "configuration of an MLflow server instance. " + f"By default, data will be logged to the {DEFAULT_ARTIFACTS_URI} uri proxy if " + "the --serve-artifacts option is enabled. Otherwise, the default location will " + f"be {DEFAULT_LOCAL_FILE_AND_ARTIFACT_PATH}.", +) +@cli_args.SERVE_ARTIFACTS +@click.option( + "--artifacts-only", + is_flag=True, + default=False, + help="If specified, configures the mlflow server to be used only for proxied artifact serving. " + "With this mode enabled, functionality of the mlflow tracking service (e.g. run creation, " + "metric logging, and parameter logging) is disabled. The server will only expose " + "endpoints for uploading, downloading, and listing artifacts. " + "Default: False", ) @cli_args.ARTIFACTS_DESTINATION @cli_args.HOST @@ -348,6 +375,8 @@ def _validate_static_prefix(ctx, param, value): # pylint: disable=unused-argume def server( backend_store_uri, default_artifact_root, + serve_artifacts, + artifacts_only, artifacts_destination, host, port, @@ -374,15 +403,9 @@ def server( if not backend_store_uri: backend_store_uri = DEFAULT_LOCAL_FILE_AND_ARTIFACT_PATH - if not default_artifact_root: - if is_local_uri(backend_store_uri): - default_artifact_root = backend_store_uri - else: - eprint( - "Option 'default-artifact-root' is required, when backend store is not " - "local file based." - ) - sys.exit(1) + default_artifact_root = resolve_default_artifact_root( + serve_artifacts, default_artifact_root, backend_store_uri + ) try: initialize_backend_stores(backend_store_uri, default_artifact_root) @@ -395,6 +418,8 @@ def server( _run_server( backend_store_uri, default_artifact_root, + serve_artifacts, + artifacts_only, artifacts_destination, host, port, diff --git a/mlflow/projects/utils.py b/mlflow/projects/utils.py index fc1bfd7943e99..1b2224e523f2c 100644 --- a/mlflow/projects/utils.py +++ b/mlflow/projects/utils.py @@ -28,7 +28,7 @@ MLFLOW_PROJECT_ENTRY_POINT, MLFLOW_PARENT_RUN_ID, ) - +from mlflow.utils.rest_utils import augmented_raise_for_status # TODO: this should be restricted to just Git repos and not S3 and stuff like that _GIT_URI_REGEX = re.compile(r"^[^/]*:") @@ -214,7 +214,7 @@ def _fetch_zip_repo(uri): # https://github.com/mlflow/mlflow/issues/763. response = requests.get(uri) try: - response.raise_for_status() + augmented_raise_for_status(response) except requests.HTTPError as error: raise ExecutionException("Unable to retrieve ZIP file. Reason: %s" % str(error)) return BytesIO(response.content) diff --git a/mlflow/server/__init__.py b/mlflow/server/__init__.py index deecd40f12e46..819ed18e5313f 100644 --- a/mlflow/server/__init__.py +++ b/mlflow/server/__init__.py @@ -20,6 +20,8 @@ ARTIFACT_ROOT_ENV_VAR = "_MLFLOW_SERVER_ARTIFACT_ROOT" ARTIFACTS_DESTINATION_ENV_VAR = "_MLFLOW_SERVER_ARTIFACT_DESTINATION" PROMETHEUS_EXPORTER_ENV_VAR = "prometheus_multiproc_dir" +SERVE_ARTIFACTS_ENV_VAR = "_MLFLOW_SERVER_SERVE_ARTIFACTS" +ARTIFACTS_ONLY_ENV_VAR = "_MLFLOW_SERVER_ARTIFACTS_ONLY" REL_STATIC_DIR = "js/build" @@ -106,6 +108,8 @@ def _build_gunicorn_command(gunicorn_opts, host, port, workers): def _run_server( file_store_path, default_artifact_root, + serve_artifacts, + artifacts_only, artifacts_destination, host, port, @@ -126,6 +130,10 @@ def _run_server( env_map[BACKEND_STORE_URI_ENV_VAR] = file_store_path if default_artifact_root: env_map[ARTIFACT_ROOT_ENV_VAR] = default_artifact_root + if serve_artifacts: + env_map[SERVE_ARTIFACTS_ENV_VAR] = "true" + if artifacts_only: + env_map[ARTIFACTS_ONLY_ENV_VAR] = "true" if artifacts_destination: env_map[ARTIFACTS_DESTINATION_ENV_VAR] = artifacts_destination if static_prefix: diff --git a/mlflow/server/handlers.py b/mlflow/server/handlers.py index e3a3c2fbd4add..c8b89870aca45 100644 --- a/mlflow/server/handlers.py +++ b/mlflow/server/handlers.py @@ -262,6 +262,44 @@ def wrapper(*args, **kwargs): ] +def _disable_unless_serve_artifacts(func): + @wraps(func) + def wrapper(*args, **kwargs): + from mlflow.server import SERVE_ARTIFACTS_ENV_VAR + + if not os.environ.get(SERVE_ARTIFACTS_ENV_VAR): + return Response( + ( + f"Endpoint: {request.url_rule} disabled due to the mlflow server running " + "without `--serve-artifacts`. To enable artifacts server functionality, " + "run `mlflow server` with `--serve-artifacts`" + ), + 503, + ) + return func(*args, **kwargs) + + return wrapper + + +def _disable_if_artifacts_only(func): + @wraps(func) + def wrapper(*args, **kwargs): + from mlflow.server import ARTIFACTS_ONLY_ENV_VAR + + if os.environ.get(ARTIFACTS_ONLY_ENV_VAR): + return Response( + ( + f"Endpoint: {request.url_rule} disabled due to the mlflow server running " + "in `--artifacts-only` mode. To enable tracking server functionality, run " + "`mlflow server` without `--artifacts-only`" + ), + 503, + ) + return func(*args, **kwargs) + + return wrapper + + @catch_mlflow_exception def get_artifact_handler(): from querystring_parser import parser @@ -279,7 +317,11 @@ def _not_implemented(): return response +# Tracking Server APIs + + @catch_mlflow_exception +@_disable_if_artifacts_only def _create_experiment(): request_message = _get_request_message(CreateExperiment()) tags = [ExperimentTag(tag.key, tag.value) for tag in request_message.tags] @@ -294,6 +336,7 @@ def _create_experiment(): @catch_mlflow_exception +@_disable_if_artifacts_only def _get_experiment(): request_message = _get_request_message(GetExperiment()) response_message = GetExperiment.Response() @@ -305,6 +348,7 @@ def _get_experiment(): @catch_mlflow_exception +@_disable_if_artifacts_only def _get_experiment_by_name(): request_message = _get_request_message(GetExperimentByName()) response_message = GetExperimentByName.Response() @@ -322,6 +366,7 @@ def _get_experiment_by_name(): @catch_mlflow_exception +@_disable_if_artifacts_only def _delete_experiment(): request_message = _get_request_message(DeleteExperiment()) _get_tracking_store().delete_experiment(request_message.experiment_id) @@ -332,6 +377,7 @@ def _delete_experiment(): @catch_mlflow_exception +@_disable_if_artifacts_only def _restore_experiment(): request_message = _get_request_message(RestoreExperiment()) _get_tracking_store().restore_experiment(request_message.experiment_id) @@ -342,6 +388,7 @@ def _restore_experiment(): @catch_mlflow_exception +@_disable_if_artifacts_only def _update_experiment(): request_message = _get_request_message(UpdateExperiment()) if request_message.new_name: @@ -355,6 +402,7 @@ def _update_experiment(): @catch_mlflow_exception +@_disable_if_artifacts_only def _create_run(): request_message = _get_request_message(CreateRun()) @@ -374,6 +422,7 @@ def _create_run(): @catch_mlflow_exception +@_disable_if_artifacts_only def _update_run(): request_message = _get_request_message(UpdateRun()) run_id = request_message.run_id or request_message.run_uuid @@ -387,6 +436,7 @@ def _update_run(): @catch_mlflow_exception +@_disable_if_artifacts_only def _delete_run(): request_message = _get_request_message(DeleteRun()) _get_tracking_store().delete_run(request_message.run_id) @@ -397,6 +447,7 @@ def _delete_run(): @catch_mlflow_exception +@_disable_if_artifacts_only def _restore_run(): request_message = _get_request_message(RestoreRun()) _get_tracking_store().restore_run(request_message.run_id) @@ -407,6 +458,7 @@ def _restore_run(): @catch_mlflow_exception +@_disable_if_artifacts_only def _log_metric(): request_message = _get_request_message(LogMetric()) metric = Metric( @@ -421,6 +473,7 @@ def _log_metric(): @catch_mlflow_exception +@_disable_if_artifacts_only def _log_param(): request_message = _get_request_message(LogParam()) param = Param(request_message.key, request_message.value) @@ -433,6 +486,7 @@ def _log_param(): @catch_mlflow_exception +@_disable_if_artifacts_only def _set_experiment_tag(): request_message = _get_request_message(SetExperimentTag()) tag = ExperimentTag(request_message.key, request_message.value) @@ -444,6 +498,7 @@ def _set_experiment_tag(): @catch_mlflow_exception +@_disable_if_artifacts_only def _set_tag(): request_message = _get_request_message(SetTag()) tag = RunTag(request_message.key, request_message.value) @@ -456,6 +511,7 @@ def _set_tag(): @catch_mlflow_exception +@_disable_if_artifacts_only def _delete_tag(): request_message = _get_request_message(DeleteTag()) _get_tracking_store().delete_tag(request_message.run_id, request_message.key) @@ -466,6 +522,7 @@ def _delete_tag(): @catch_mlflow_exception +@_disable_if_artifacts_only def _get_run(): request_message = _get_request_message(GetRun()) response_message = GetRun.Response() @@ -477,6 +534,7 @@ def _get_run(): @catch_mlflow_exception +@_disable_if_artifacts_only def _search_runs(): request_message = _get_request_message(SearchRuns()) response_message = SearchRuns.Response() @@ -500,6 +558,7 @@ def _search_runs(): @catch_mlflow_exception +@_disable_if_artifacts_only def _list_artifacts(): request_message = _get_request_message(ListArtifacts()) response_message = ListArtifacts.Response() @@ -518,6 +577,7 @@ def _list_artifacts(): @catch_mlflow_exception +@_disable_if_artifacts_only def _get_metric_history(): request_message = _get_request_message(GetMetricHistory()) response_message = GetMetricHistory.Response() @@ -530,6 +590,7 @@ def _get_metric_history(): @catch_mlflow_exception +@_disable_if_artifacts_only def _list_experiments(): request_message = _get_request_message(ListExperiments()) # `ListFields` returns a list of (FieldDescriptor, value) tuples for *present* fields: @@ -547,11 +608,13 @@ def _list_experiments(): @catch_mlflow_exception +@_disable_if_artifacts_only def _get_artifact_repo(run): return get_artifact_repository(run.info.artifact_uri) @catch_mlflow_exception +@_disable_if_artifacts_only def _log_batch(): _validate_batch_log_api_req(_get_request_json()) request_message = _get_request_message(LogBatch()) @@ -568,6 +631,7 @@ def _log_batch(): @catch_mlflow_exception +@_disable_if_artifacts_only def _log_model(): request_message = _get_request_message(LogModel()) try: @@ -603,7 +667,11 @@ def _wrap_response(response_message): return response +# Model Registry APIs + + @catch_mlflow_exception +@_disable_if_artifacts_only def _create_registered_model(): request_message = _get_request_message(CreateRegisteredModel()) registered_model = _get_model_registry_store().create_registered_model( @@ -616,6 +684,7 @@ def _create_registered_model(): @catch_mlflow_exception +@_disable_if_artifacts_only def _get_registered_model(): request_message = _get_request_message(GetRegisteredModel()) registered_model = _get_model_registry_store().get_registered_model(name=request_message.name) @@ -624,6 +693,7 @@ def _get_registered_model(): @catch_mlflow_exception +@_disable_if_artifacts_only def _update_registered_model(): request_message = _get_request_message(UpdateRegisteredModel()) name = request_message.name @@ -636,6 +706,7 @@ def _update_registered_model(): @catch_mlflow_exception +@_disable_if_artifacts_only def _rename_registered_model(): request_message = _get_request_message(RenameRegisteredModel()) name = request_message.name @@ -648,6 +719,7 @@ def _rename_registered_model(): @catch_mlflow_exception +@_disable_if_artifacts_only def _delete_registered_model(): request_message = _get_request_message(DeleteRegisteredModel()) _get_model_registry_store().delete_registered_model(name=request_message.name) @@ -655,6 +727,7 @@ def _delete_registered_model(): @catch_mlflow_exception +@_disable_if_artifacts_only def _list_registered_models(): request_message = _get_request_message(ListRegisteredModels()) registered_models = _get_model_registry_store().list_registered_models( @@ -668,6 +741,7 @@ def _list_registered_models(): @catch_mlflow_exception +@_disable_if_artifacts_only def _search_registered_models(): request_message = _get_request_message(SearchRegisteredModels()) store = _get_model_registry_store() @@ -685,6 +759,7 @@ def _search_registered_models(): @catch_mlflow_exception +@_disable_if_artifacts_only def _get_latest_versions(): request_message = _get_request_message(GetLatestVersions()) latest_versions = _get_model_registry_store().get_latest_versions( @@ -696,6 +771,7 @@ def _get_latest_versions(): @catch_mlflow_exception +@_disable_if_artifacts_only def _set_registered_model_tag(): request_message = _get_request_message(SetRegisteredModelTag()) tag = RegisteredModelTag(key=request_message.key, value=request_message.value) @@ -704,6 +780,7 @@ def _set_registered_model_tag(): @catch_mlflow_exception +@_disable_if_artifacts_only def _delete_registered_model_tag(): request_message = _get_request_message(DeleteRegisteredModelTag()) _get_model_registry_store().delete_registered_model_tag( @@ -713,6 +790,7 @@ def _delete_registered_model_tag(): @catch_mlflow_exception +@_disable_if_artifacts_only def _create_model_version(): request_message = _get_request_message(CreateModelVersion()) model_version = _get_model_registry_store().create_model_version( @@ -728,6 +806,7 @@ def _create_model_version(): @catch_mlflow_exception +@_disable_if_artifacts_only def get_model_version_artifact_handler(): from querystring_parser import parser @@ -740,6 +819,7 @@ def get_model_version_artifact_handler(): @catch_mlflow_exception +@_disable_if_artifacts_only def _get_model_version(): request_message = _get_request_message(GetModelVersion()) model_version = _get_model_registry_store().get_model_version( @@ -751,6 +831,7 @@ def _get_model_version(): @catch_mlflow_exception +@_disable_if_artifacts_only def _update_model_version(): request_message = _get_request_message(UpdateModelVersion()) new_description = None @@ -763,6 +844,7 @@ def _update_model_version(): @catch_mlflow_exception +@_disable_if_artifacts_only def _transition_stage(): request_message = _get_request_message(TransitionModelVersionStage()) model_version = _get_model_registry_store().transition_model_version_stage( @@ -777,6 +859,7 @@ def _transition_stage(): @catch_mlflow_exception +@_disable_if_artifacts_only def _delete_model_version(): request_message = _get_request_message(DeleteModelVersion()) _get_model_registry_store().delete_model_version( @@ -786,6 +869,7 @@ def _delete_model_version(): @catch_mlflow_exception +@_disable_if_artifacts_only def _get_model_version_download_uri(): request_message = _get_request_message(GetModelVersionDownloadUri()) download_uri = _get_model_registry_store().get_model_version_download_uri( @@ -796,6 +880,7 @@ def _get_model_version_download_uri(): @catch_mlflow_exception +@_disable_if_artifacts_only def _search_model_versions(): request_message = _get_request_message(SearchModelVersions()) model_versions = _get_model_registry_store().search_model_versions(request_message.filter) @@ -805,6 +890,7 @@ def _search_model_versions(): @catch_mlflow_exception +@_disable_if_artifacts_only def _set_model_version_tag(): request_message = _get_request_message(SetModelVersionTag()) tag = ModelVersionTag(key=request_message.key, value=request_message.value) @@ -815,6 +901,7 @@ def _set_model_version_tag(): @catch_mlflow_exception +@_disable_if_artifacts_only def _delete_model_version_tag(): request_message = _get_request_message(DeleteModelVersionTag()) _get_model_registry_store().delete_model_version_tag( @@ -823,7 +910,11 @@ def _delete_model_version_tag(): return _wrap_response(DeleteModelVersionTag.Response()) +# MLflow Artifacts APIs + + @catch_mlflow_exception +@_disable_unless_serve_artifacts def _download_artifact(artifact_path): """ A request handler for `GET /mlflow-artifacts/artifacts/` to download an artifact @@ -850,6 +941,7 @@ def stream_and_remove_file(): @catch_mlflow_exception +@_disable_unless_serve_artifacts def _upload_artifact(artifact_path): """ A request handler for `PUT /mlflow-artifacts/artifacts/` to upload an artifact @@ -873,6 +965,7 @@ def _upload_artifact(artifact_path): @catch_mlflow_exception +@_disable_unless_serve_artifacts def _list_artifacts_mlflow_artifacts(): """ A request handler for `GET /mlflow-artifacts/artifacts?path=` to list artifacts in `path` diff --git a/mlflow/store/artifact/databricks_artifact_repo.py b/mlflow/store/artifact/databricks_artifact_repo.py index 8ddd9df401263..d3a85b3d5d18b 100644 --- a/mlflow/store/artifact/databricks_artifact_repo.py +++ b/mlflow/store/artifact/databricks_artifact_repo.py @@ -39,6 +39,7 @@ call_endpoint, extract_api_info_for_service, _REST_API_PATH_PREFIX, + augmented_raise_for_status, ) from mlflow.utils.uri import ( extract_and_normalize_path, @@ -247,13 +248,13 @@ def _signed_url_upload_file(self, credentials, local_file): with rest_utils.cloud_storage_http_request( "put", signed_write_uri, data="", headers=headers ) as response: - response.raise_for_status() + augmented_raise_for_status(response) else: with open(local_file, "rb") as file: with rest_utils.cloud_storage_http_request( "put", signed_write_uri, data=file, headers=headers ) as response: - response.raise_for_status() + augmented_raise_for_status(response) except Exception as err: raise MlflowException(err) diff --git a/mlflow/store/artifact/http_artifact_repo.py b/mlflow/store/artifact/http_artifact_repo.py index f2dc2b4036ac5..2bf41b664f7f3 100644 --- a/mlflow/store/artifact/http_artifact_repo.py +++ b/mlflow/store/artifact/http_artifact_repo.py @@ -5,6 +5,7 @@ from mlflow.entities import FileInfo from mlflow.store.artifact.artifact_repo import ArtifactRepository, verify_artifact_path from mlflow.utils.file_utils import relative_path_to_artifact_path +from mlflow.utils.rest_utils import augmented_raise_for_status class HttpArtifactRepository(ArtifactRepository): @@ -26,7 +27,7 @@ def log_artifact(self, local_file, artifact_path=None): url = posixpath.join(self.artifact_uri, *paths) with open(local_file, "rb") as f: resp = self._session.put(url, data=f, timeout=600) - resp.raise_for_status() + augmented_raise_for_status(resp) def log_artifacts(self, local_dir, artifact_path=None): local_dir = os.path.abspath(local_dir) @@ -49,7 +50,7 @@ def list_artifacts(self, path=None): root = tail.lstrip("/") params = {"path": posixpath.join(root, path) if path else root} resp = self._session.get(url, params=params, timeout=10) - resp.raise_for_status() + augmented_raise_for_status(resp) file_infos = [] for f in resp.json().get("files", []): file_info = FileInfo( @@ -64,7 +65,7 @@ def list_artifacts(self, path=None): def _download_file(self, remote_file_path, local_path): url = posixpath.join(self.artifact_uri, remote_file_path) with self._session.get(url, stream=True, timeout=10) as resp: - resp.raise_for_status() + augmented_raise_for_status(resp) with open(local_path, "wb") as f: chunk_size = 1024 * 1024 # 1 MB for chunk in resp.iter_content(chunk_size=chunk_size): diff --git a/mlflow/store/artifact/mlflow_artifacts_repo.py b/mlflow/store/artifact/mlflow_artifacts_repo.py index 1d624d5ade894..c9fd90b469b4d 100644 --- a/mlflow/store/artifact/mlflow_artifacts_repo.py +++ b/mlflow/store/artifact/mlflow_artifacts_repo.py @@ -1,5 +1,6 @@ from urllib.parse import urlparse from collections import namedtuple +import re from mlflow.store.artifact.http_artifact_repo import HttpArtifactRepository from mlflow.tracking._tracking_service.utils import get_tracking_uri @@ -49,13 +50,12 @@ class MlflowArtifactsRepository(HttpArtifactRepository): def __init__(self, artifact_uri): - super().__init__(self.resolve_uri(artifact_uri)) + super().__init__(self.resolve_uri(artifact_uri, get_tracking_uri())) @classmethod - def resolve_uri(cls, artifact_uri): + def resolve_uri(cls, artifact_uri, tracking_uri): base_url = "/api/2.0/mlflow-artifacts/artifacts" - tracking_uri = get_tracking_uri() track_parse = _parse_artifact_uri(tracking_uri) @@ -72,7 +72,8 @@ def resolve_uri(cls, artifact_uri): elif uri_parse.path == base_url: # for operations like list artifacts resolved = base_url else: - resolved = f"{base_url}{track_parse.path}{uri_parse.path.lstrip('/')}" + resolved = f"{base_url}/{track_parse.path}{uri_parse.path}" + resolved = re.sub("//+", "/", resolved) if uri_parse.host and uri_parse.port: resolved_artifacts_uri = ( diff --git a/mlflow/store/tracking/__init__.py b/mlflow/store/tracking/__init__.py index 889f5d7c43ec1..c708e39366564 100644 --- a/mlflow/store/tracking/__init__.py +++ b/mlflow/store/tracking/__init__.py @@ -10,5 +10,11 @@ # Also used as default location for artifacts, when not provided, in non local file based backends # (eg MySQL) DEFAULT_LOCAL_FILE_AND_ARTIFACT_PATH = "./mlruns" +# Used for defining the artifacts uri (`--default-artifact-root`) for the tracking server when +# configuring the server to use the option `--serve-artifacts` mode. This default can be +# overridden by specifying an override to `--default-artifact-root` for the MLflow tracking server. +# When the server is not operating in `--serve-artifacts` configuration, the default artifact +# storage location will be `DEFAULT_LOCAL_FILE_AND_ARTIFACT_PATH`. +DEFAULT_ARTIFACTS_URI = "mlflow-artifacts:/" SEARCH_MAX_RESULTS_DEFAULT = 1000 SEARCH_MAX_RESULTS_THRESHOLD = 50000 diff --git a/mlflow/utils/cli_args.py b/mlflow/utils/cli_args.py index 71ef7771e5530..30dc1fc793f8c 100644 --- a/mlflow/utils/cli_args.py +++ b/mlflow/utils/cli_args.py @@ -99,3 +99,15 @@ "artifact root location is http or mlflow-artifacts URI." ), ) + +SERVE_ARTIFACTS = click.option( + "--serve-artifacts", + is_flag=True, + default=False, + help="If specified, enables serving of artifact uploads, downloads, and list requests " + "by routing these requests to the storage location that is specified by " + "'--artifact-destination' directly through a proxy. The default location that " + "these requests are served from is a local './mlartifacts' directory which can be " + "overridden via the '--artifacts-destination' argument. " + "Default: False", +) diff --git a/mlflow/utils/file_utils.py b/mlflow/utils/file_utils.py index a5e87bedc3d47..fc16cf90704de 100644 --- a/mlflow/utils/file_utils.py +++ b/mlflow/utils/file_utils.py @@ -22,7 +22,7 @@ from mlflow.entities import FileInfo from mlflow.exceptions import MissingConfigException -from mlflow.utils.rest_utils import cloud_storage_http_request +from mlflow.utils.rest_utils import cloud_storage_http_request, augmented_raise_for_status ENCODING = "utf-8" @@ -453,7 +453,7 @@ def download_file_using_http_uri(http_uri, download_path, chunk_size=100000000): providers. """ with cloud_storage_http_request("get", http_uri, stream=True) as response: - response.raise_for_status() + augmented_raise_for_status(response) with open(download_path, "wb") as output_file: for chunk in response.iter_content(chunk_size=chunk_size): if not chunk: diff --git a/mlflow/utils/rest_utils.py b/mlflow/utils/rest_utils.py index 067868fde7fd5..f701a0e0bac89 100644 --- a/mlflow/utils/rest_utils.py +++ b/mlflow/utils/rest_utils.py @@ -6,6 +6,7 @@ from packaging.version import Version from requests.adapters import HTTPAdapter from urllib3.util import Retry +from requests.exceptions import HTTPError from mlflow import __version__ from mlflow.protos import databricks_pb2 @@ -186,6 +187,17 @@ def verify_rest_response(response, endpoint): return response +def augmented_raise_for_status(response): + """Wrap the standard `requests.response.raise_for_status()` method and return reason""" + try: + response.raise_for_status() + except HTTPError as e: + if response.text: + raise HTTPError(f"{e}. Response text: {response.text}") + else: + raise e + + def _get_path(path_prefix, endpoint_path): return "{}{}".format(path_prefix, endpoint_path) diff --git a/mlflow/utils/uri.py b/mlflow/utils/uri.py index 808585b74e7ae..55e346d1144f9 100644 --- a/mlflow/utils/uri.py +++ b/mlflow/utils/uri.py @@ -1,10 +1,13 @@ +import sys import posixpath import urllib.parse from mlflow.exceptions import MlflowException from mlflow.protos.databricks_pb2 import INVALID_PARAMETER_VALUE from mlflow.store.db.db_types import DATABASE_ENGINES +from mlflow.store.tracking import DEFAULT_LOCAL_FILE_AND_ARTIFACT_PATH, DEFAULT_ARTIFACTS_URI from mlflow.utils.validation import _validate_db_type_string +from mlflow.utils.logging_utils import eprint _INVALID_DB_URI_MSG = ( "Please refer to https://mlflow.org/docs/latest/tracking.html#storage for " @@ -294,3 +297,22 @@ def dbfs_hdfs_uri_to_fuse_path(dbfs_uri): ) return _DBFS_FUSE_PREFIX + dbfs_uri[len(_DBFS_HDFS_URI_PREFIX) :] + + +def resolve_default_artifact_root( + serve_artifacts, default_artifact_root, backend_store_uri, resolve_to_local=False +): + if serve_artifacts and not default_artifact_root: + default_artifact_root = DEFAULT_ARTIFACTS_URI + elif not serve_artifacts and not default_artifact_root: + if is_local_uri(backend_store_uri): + default_artifact_root = backend_store_uri + elif resolve_to_local: + default_artifact_root = DEFAULT_LOCAL_FILE_AND_ARTIFACT_PATH + else: + eprint( + "Option 'default-artifact-root' is required, when backend store is not " + "local file based." + ) + sys.exit(1) + return default_artifact_root diff --git a/tests/store/artifact/test_mlflow_artifact_repo.py b/tests/store/artifact/test_mlflow_artifact_repo.py index 2e57bfe961407..62a0e14cda479 100644 --- a/tests/store/artifact/test_mlflow_artifact_repo.py +++ b/tests/store/artifact/test_mlflow_artifact_repo.py @@ -22,36 +22,39 @@ def test_artifact_uri_factory(): assert isinstance(repo, MlflowArtifactsRepository) -def test_mlflow_artifact_uri_formats_resolved(): - base_url = "/api/2.0/mlflow-artifacts/artifacts" - base_path = "/my/artifact/path" - conditions = [ - ( - f"mlflow-artifacts://myhostname:4242{base_path}/hostport", - f"http://myhostname:4242{base_url}{base_path}/hostport", - ), - ( - f"mlflow-artifacts://myhostname{base_path}/host", - f"http://myhostname{base_url}{base_path}/host", - ), - ( - f"mlflow-artifacts:{base_path}/nohost", - f"http://localhost:5000{base_url}{base_path}/nohost", - ), - ( - f"mlflow-artifacts://{base_path}/redundant", - f"http://localhost:5000{base_url}{base_path}/redundant", - ), - ( - "mlflow-artifacts:/", - f"http://localhost:5000{base_url}", - ), - ] +base_url = "/api/2.0/mlflow-artifacts/artifacts" +base_path = "/my/artifact/path" +conditions = [ + ( + f"mlflow-artifacts://myhostname:4242{base_path}/hostport", + f"http://myhostname:4242{base_url}{base_path}/hostport", + ), + ( + f"mlflow-artifacts://myhostname{base_path}/host", + f"http://myhostname{base_url}{base_path}/host", + ), + ( + f"mlflow-artifacts:{base_path}/nohost", + f"http://localhost:5000{base_url}{base_path}/nohost", + ), + ( + f"mlflow-artifacts://{base_path}/redundant", + f"http://localhost:5000{base_url}{base_path}/redundant", + ), + ("mlflow-artifacts:/", f"http://localhost:5000{base_url}"), +] + + +@pytest.mark.parametrize("tracking_uri", ["http://localhost:5000", "http://localhost:5000/"]) +@pytest.mark.parametrize("artifact_uri, resolved_uri", conditions) +def test_mlflow_artifact_uri_formats_resolved(artifact_uri, resolved_uri, tracking_uri): + + assert MlflowArtifactsRepository.resolve_uri(artifact_uri, tracking_uri) == resolved_uri + + +def test_mlflow_artifact_uri_raises_with_invalid_tracking_uri(): failing_conditions = [f"mlflow-artifacts://5000/{base_path}", "mlflow-artifacts://5000/"] - for submit, resolved in conditions: - artifact_repo = MlflowArtifactsRepository(submit) - assert artifact_repo.resolve_uri(submit) == resolved for failing_condition in failing_conditions: with pytest.raises( MlflowException, diff --git a/tests/test_cli.py b/tests/test_cli.py index 507778c6482dc..e5275d9ec8e5a 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -22,6 +22,7 @@ from mlflow.store.tracking.file_store import FileStore from mlflow.exceptions import MlflowException from mlflow.entities import ViewType +from mlflow.utils.rest_utils import augmented_raise_for_status from tests.helper_functions import pyfunc_serve_and_score_model, get_safe_port from tests.tracking.integration_test_utils import _await_server_up_or_die @@ -35,7 +36,7 @@ def test_mlflow_server_command(command): try: _await_server_up_or_die(port, timeout=10) resp = requests.get(f"http://localhost:{port}/health") - resp.raise_for_status() + augmented_raise_for_status(resp) assert resp.text == "OK" finally: process.kill() @@ -58,6 +59,18 @@ def test_server_static_prefix_validation(): run_server_mock.assert_not_called() +def test_server_mlflow_artifacts_options(): + with mock.patch("mlflow.server._run_server") as run_server_mock: + CliRunner().invoke(server, ["--artifacts-only"]) + run_server_mock.assert_called_once() + with mock.patch("mlflow.server._run_server") as run_server_mock: + CliRunner().invoke(server, ["--serve-artifacts"]) + run_server_mock.assert_called_once() + with mock.patch("mlflow.server._run_server") as run_server_mock: + CliRunner().invoke(server, ["--artifacts-only", "--serve-artifacts"]) + run_server_mock.assert_called_once() + + def test_server_default_artifact_root_validation(): with mock.patch("mlflow.server._run_server") as run_server_mock: result = CliRunner().invoke(server, ["--backend-store-uri", "sqlite:///my.db"]) @@ -235,3 +248,49 @@ def predict(self, context, model_input): # pylint: disable=unused-variable assert scoring_response.status_code == 200 served_model_preds = np.array(json.loads(scoring_response.content)) np.testing.assert_array_equal(served_model_preds, model.predict(data, None)) + + +def test_mlflow_tracking_disabled_in_artifacts_only_mode(): + + port = get_safe_port() + cmd = ["mlflow", "server", "--port", str(port), "--artifacts-only"] + process = subprocess.Popen(cmd) + _await_server_up_or_die(port, timeout=10) + resp = requests.get(f"http://localhost:{port}/api/2.0/mlflow/experiments/list") + assert ( + "Endpoint: /api/2.0/mlflow/experiments/list disabled due to the mlflow server running " + "in `--artifacts-only` mode." in resp.text + ) + process.kill() + + +def test_mlflow_artifact_list_in_artifacts_only_mode(): + + port = get_safe_port() + cmd = ["mlflow", "server", "--port", str(port), "--artifacts-only", "--serve-artifacts"] + process = subprocess.Popen(cmd) + try: + _await_server_up_or_die(port, timeout=10) + resp = requests.get(f"http://localhost:{port}/api/2.0/mlflow-artifacts/artifacts") + augmented_raise_for_status(resp) + assert resp.status_code == 200 + assert resp.text == "{}" + finally: + process.kill() + + +def test_mlflow_artifact_service_unavailable_without_config(): + + port = get_safe_port() + cmd = ["mlflow", "server", "--port", str(port)] + process = subprocess.Popen(cmd) + try: + _await_server_up_or_die(port, timeout=10) + endpoint = "/api/2.0/mlflow-artifacts/artifacts" + resp = requests.get(f"http://localhost:{port}{endpoint}") + assert ( + f"Endpoint: {endpoint} disabled due to the mlflow server running without " + "`--serve-artifacts`" in resp.text + ) + finally: + process.kill() diff --git a/tests/tracking/test_mlflow_artifacts.py b/tests/tracking/test_mlflow_artifacts.py index f9194ace771f7..96beefcda33d0 100644 --- a/tests/tracking/test_mlflow_artifacts.py +++ b/tests/tracking/test_mlflow_artifacts.py @@ -3,7 +3,7 @@ import subprocess import tempfile import requests - +import pathlib import pytest import mlflow @@ -26,6 +26,7 @@ def _launch_server(host, port, backend_store_uri, default_artifact_root, artifac str(port), "--backend-store-uri", backend_store_uri, + "--serve-artifacts", "--default-artifact-root", default_artifact_root, "--artifacts-destination", @@ -227,6 +228,7 @@ def is_github_actions(): @pytest.mark.skipif(is_windows(), reason="This example doesn't work on Windows") def test_mlflow_artifacts_example(tmpdir): + root = pathlib.Path(mlflow.__file__).parents[1] # On GitHub Actions, remove generated images to save disk space rmi_option = "--rmi all" if is_github_actions() else "" cmd = f""" @@ -241,5 +243,5 @@ def test_mlflow_artifacts_example(tmpdir): subprocess.run( ["bash", script_path.strpath], check=True, - cwd=os.path.join(os.getcwd(), "examples", "mlflow_artifacts"), + cwd=os.path.join(root, "examples", "mlflow_artifacts"), )