Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Caching boto client to improve artifact download speed #4695

Merged
merged 12 commits into from Dec 2, 2021
46 changes: 31 additions & 15 deletions mlflow/store/artifact/s3_artifact_repo.py
@@ -1,3 +1,4 @@
from functools import lru_cache
import os
from mimetypes import guess_type

Expand All @@ -11,6 +12,34 @@
from mlflow.utils.file_utils import relative_path_to_artifact_path


@lru_cache(maxsize=1024)
def _get_boto3_client(signature_version, s3_endpoint_url, verify):
"""Returns a boto3 client, caching to avoid extra boto3 verify calls.

This method is outside of the S3ArtifactRepository as it is
agnostic and could be used by other instances.

`maxsize` set to avoid excessive memory consmption in the case
a user has dynamic endpoints (intentionally or as a bug)."""

import boto3
from botocore.client import Config

# Making it possible to access public S3 buckets
# Workaround for https://github.com/boto/botocore/issues/2442
if signature_version.lower() == "unsigned":
from botocore import UNSIGNED

signature_version = UNSIGNED

return boto3.client(
"s3",
config=Config(signature_version=signature_version),
endpoint_url=s3_endpoint_url,
verify=verify,
)


class S3ArtifactRepository(ArtifactRepository):
"""Stores artifacts on Amazon S3."""

Expand All @@ -36,9 +65,6 @@ def get_s3_file_upload_extra_args():
return None

def _get_s3_client(self):
import boto3
from botocore.client import Config

s3_endpoint_url = os.environ.get("MLFLOW_S3_ENDPOINT_URL")
ignore_tls = os.environ.get("MLFLOW_S3_IGNORE_TLS")

Expand All @@ -54,18 +80,8 @@ def _get_s3_client(self):
# NOTE: If you need to specify this env variable, please file an issue at
# https://github.com/mlflow/mlflow/issues so we know your use-case!
signature_version = os.environ.get("MLFLOW_EXPERIMENTAL_S3_SIGNATURE_VERSION", "s3v4")
# Making it possible to access public S3 buckets
# Workaround for https://github.com/boto/botocore/issues/2442
if signature_version.lower() == "unsigned":
from botocore import UNSIGNED

signature_version = UNSIGNED
return boto3.client(
"s3",
config=Config(signature_version=signature_version),
endpoint_url=s3_endpoint_url,
verify=verify,
)

return _get_boto3_client(signature_version, s3_endpoint_url, verify)

def _upload_file(self, s3_client, local_file, bucket, key):
extra_args = dict()
Expand Down