Skip to content

Commit

Permalink
Caching boto client
Browse files Browse the repository at this point in the history
  • Loading branch information
Samuel Hinton committed Aug 12, 2021
1 parent 306c96a commit 943eb8f
Showing 1 changed file with 31 additions and 15 deletions.
46 changes: 31 additions & 15 deletions mlflow/store/artifact/s3_artifact_repo.py
@@ -1,3 +1,4 @@
from functools import lru_cache
import os
from mimetypes import guess_type

Expand All @@ -11,6 +12,34 @@
from mlflow.utils.file_utils import relative_path_to_artifact_path


@lru_cache(maxsize=1024)
def _get_boto3_client(signature_version, s3_endpoint_url, verify):
"""Returns a boto3 client, caching to avoid extra boto3 verify calls.
This method is outside of the S3ArtifactRepository as it is
agnostic and could be used by other instances.
`maxsize` set to avoid excessive memory consmption in the case
a user has dynamic endpoints (intentionally or as a bug)."""

import boto3
from botocore.client import Config

# Making it possible to access public S3 buckets
# Workaround for https://github.com/boto/botocore/issues/2442
if signature_version.lower() == "unsigned":
from botocore import UNSIGNED

signature_version = UNSIGNED

return boto3.client(
"s3",
config=Config(signature_version=signature_version),
endpoint_url=s3_endpoint_url,
verify=verify,
)


class S3ArtifactRepository(ArtifactRepository):
"""Stores artifacts on Amazon S3."""

Expand All @@ -36,9 +65,6 @@ def get_s3_file_upload_extra_args():
return None

def _get_s3_client(self):
import boto3
from botocore.client import Config

s3_endpoint_url = os.environ.get("MLFLOW_S3_ENDPOINT_URL")
ignore_tls = os.environ.get("MLFLOW_S3_IGNORE_TLS")

Expand All @@ -54,18 +80,8 @@ def _get_s3_client(self):
# NOTE: If you need to specify this env variable, please file an issue at
# https://github.com/mlflow/mlflow/issues so we know your use-case!
signature_version = os.environ.get("MLFLOW_EXPERIMENTAL_S3_SIGNATURE_VERSION", "s3v4")
# Making it possible to access public S3 buckets
# Workaround for https://github.com/boto/botocore/issues/2442
if signature_version.lower() == "unsigned":
from botocore import UNSIGNED

signature_version = UNSIGNED
return boto3.client(
"s3",
config=Config(signature_version=signature_version),
endpoint_url=s3_endpoint_url,
verify=verify,
)

return _get_boto3_client(signature_version, s3_endpoint_url, verify)

def _upload_file(self, s3_client, local_file, bucket, key):
extra_args = dict()
Expand Down

0 comments on commit 943eb8f

Please sign in to comment.