Caching boto client to improve artifact download speed (#4695)
* Caching boto client

Signed-off-by: Samuel Hinton <sh@arenko.group>

* Adding five minute cache expiry to handle potential temp boto3 endpoint urls

Signed-off-by: Samuel Hinton <sh@arenko.group>

* timestamp is intentionally unused

* Fixing tests and adding sign off

Signed-off-by: Samuel Hinton <samuelreay@gmail.com>

* Format

Signed-off-by: dbczumar <corey.zumar@databricks.com>

* Fixing tests

* fix tests

Signed-off-by: harupy <hkawamura0130@gmail.com>

* lint

Signed-off-by: harupy <hkawamura0130@gmail.com>

* Autoformat: https://github.com/mlflow/mlflow/actions/runs/1528968619

Signed-off-by: mlflow-automation <mlflow-automation@users.noreply.github.com>

Co-authored-by: Samuel Hinton <sh@arenko.group>
Co-authored-by: dbczumar <corey.zumar@databricks.com>
Co-authored-by: harupy <hkawamura0130@gmail.com>
Co-authored-by: mlflow-automation <mlflow-automation@users.noreply.github.com>
5 people committed Dec 2, 2021
1 parent 1ac1f3a commit 19a82fe
Showing 2 changed files with 100 additions and 16 deletions.
67 changes: 52 additions & 15 deletions mlflow/store/artifact/s3_artifact_repo.py
@@ -1,3 +1,5 @@
from datetime import datetime
from functools import lru_cache
import os
from mimetypes import guess_type

@@ -11,6 +13,51 @@
from mlflow.utils.file_utils import relative_path_to_artifact_path


_MAX_CACHE_SECONDS = 300


def _get_utcnow_timestamp():
    return datetime.utcnow().timestamp()
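
Note: the current time goes through this tiny helper rather than being read inline so that tests can mock the clock; `test_get_s3_client_hits_cache` below patches `_get_utcnow_timestamp` to fast-forward past the cache window.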


@lru_cache(maxsize=64)
def _cached_get_s3_client(
    signature_version,
    s3_endpoint_url,
    verify,
    timestamp,
):  # pylint: disable=unused-argument
    """Returns a boto3 client, caching the result to avoid repeated, expensive
    `verify` calls.

    This function lives outside ``S3ArtifactRepository`` because it is
    repository-agnostic and could be reused by other instances. ``maxsize``
    is set to bound memory consumption in case a user generates dynamic
    endpoints (intentionally or through a bug). In rare edge cases, a boto3
    endpoint URL can expire after twelve hours (the current expiration
    window); the ``timestamp`` parameter periodically invalidates the cache
    so that verification raises an error instead of silently reusing an
    expired endpoint.
    """
    import boto3
    from botocore.client import Config

    # Making it possible to access public S3 buckets
    # Workaround for https://github.com/boto/botocore/issues/2442
    if signature_version.lower() == "unsigned":
        from botocore import UNSIGNED

        signature_version = UNSIGNED

    return boto3.client(
        "s3",
        config=Config(signature_version=signature_version),
        endpoint_url=s3_endpoint_url,
        verify=verify,
    )
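
`functools.lru_cache` has no native time-to-live, so the code above folds the current time into the cache key: once the time bucket rolls over, the key changes and the cache misses. A minimal standalone sketch of the same pattern (illustrative names, not MLflow code):

import time
from functools import lru_cache

_TTL_SECONDS = 300

@lru_cache(maxsize=64)
def _expensive_lookup(endpoint, time_bucket):  # time_bucket only partitions the cache
    return "client-for-%s" % endpoint  # stand-in for an expensive client build

def get_client(endpoint):
    # Calls within one 300-second window share a cache entry; once the
    # window rolls over, the bucket changes and forces a fresh lookup.
    return _expensive_lookup(endpoint, int(time.time() / _TTL_SECONDS))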


class S3ArtifactRepository(ArtifactRepository):
    """Stores artifacts on Amazon S3."""

@@ -36,9 +83,6 @@ def get_s3_file_upload_extra_args():
        return None

    def _get_s3_client(self):
        import boto3
        from botocore.client import Config

        s3_endpoint_url = os.environ.get("MLFLOW_S3_ENDPOINT_URL")
        ignore_tls = os.environ.get("MLFLOW_S3_IGNORE_TLS")

@@ -54,18 +98,11 @@ def _get_s3_client(self):
        # NOTE: If you need to specify this env variable, please file an issue at
        # https://github.com/mlflow/mlflow/issues so we know your use-case!
        signature_version = os.environ.get("MLFLOW_EXPERIMENTAL_S3_SIGNATURE_VERSION", "s3v4")
        # Making it possible to access public S3 buckets
        # Workaround for https://github.com/boto/botocore/issues/2442
        if signature_version.lower() == "unsigned":
            from botocore import UNSIGNED

            signature_version = UNSIGNED
        return boto3.client(
            "s3",
            config=Config(signature_version=signature_version),
            endpoint_url=s3_endpoint_url,
            verify=verify,
        )

        # Invalidate cache every `_MAX_CACHE_SECONDS`
        timestamp = int(_get_utcnow_timestamp() / _MAX_CACHE_SECONDS)

        return _cached_get_s3_client(signature_version, s3_endpoint_url, verify, timestamp)
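
For context, a hedged usage sketch (endpoint, bucket, and file names here are hypothetical): once the environment variables read above are set, every repository operation funnels through `_get_s3_client()`, so repeated operations reuse one cached boto3 client for up to five minutes.

import os
from mlflow.store.artifact.artifact_repository_registry import get_artifact_repository

os.environ["MLFLOW_S3_ENDPOINT_URL"] = "http://localhost:9000"  # e.g. a local MinIO
os.environ["MLFLOW_S3_IGNORE_TLS"] = "true"

repo = get_artifact_repository("s3://my-bucket/experiments/1")
repo.log_artifact("model.pkl")        # first call builds and caches the client
repo.download_artifacts("model.pkl")  # subsequent calls hit the cache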

    def _upload_file(self, s3_client, local_file, bucket, key):
        extra_args = dict()
49 changes: 48 additions & 1 deletion tests/store/artifact/test_s3_artifact_repo.py
@@ -1,11 +1,16 @@
import os
import posixpath
import tarfile
from datetime import datetime

import pytest

from mlflow.store.artifact.artifact_repository_registry import get_artifact_repository
from mlflow.store.artifact.s3_artifact_repo import S3ArtifactRepository
from mlflow.store.artifact.s3_artifact_repo import (
    S3ArtifactRepository,
    _cached_get_s3_client,
    _MAX_CACHE_SECONDS,
)

from tests.helper_functions import set_boto_credentials # pylint: disable=unused-import
from tests.helper_functions import mock_s3_bucket # pylint: disable=unused-import
@@ -18,6 +23,11 @@ def s3_artifact_root(mock_s3_bucket):
    return "s3://{bucket_name}".format(bucket_name=mock_s3_bucket)


@pytest.fixture(autouse=True)
def reset_cached_get_s3_client():
    _cached_get_s3_client.cache_clear()
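
Because `_cached_get_s3_client` is module-level, its cache would otherwise leak between tests; this autouse fixture resets it. The assertions below rely on standard `functools` introspection, e.g. (illustrative snippet, not MLflow code):

from functools import lru_cache

@lru_cache(maxsize=2)
def square(x):
    return x * x

square(2)
square(2)
print(square.cache_info())  # CacheInfo(hits=1, misses=1, maxsize=2, currsize=1)
square.cache_clear()        # empties the cache, as the fixture does before each test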


def teardown_function():
    if "MLFLOW_S3_UPLOAD_EXTRA_ARGS" in os.environ:
        del os.environ["MLFLOW_S3_UPLOAD_EXTRA_ARGS"]
@@ -55,6 +65,43 @@ def test_file_artifact_is_logged_with_content_metadata(s3_artifact_root, tmpdir)
    assert response.get("ContentEncoding") is None


def test_get_s3_client_hits_cache(s3_artifact_root):
    # pylint: disable=no-value-for-parameter
    repo = get_artifact_repository(posixpath.join(s3_artifact_root, "some/path"))
    repo._get_s3_client()
    cache_info = _cached_get_s3_client.cache_info()
    assert cache_info.hits == 0
    assert cache_info.misses == 1
    assert cache_info.currsize == 1

    repo._get_s3_client()
    cache_info = _cached_get_s3_client.cache_info()
    assert cache_info.hits == 1
    assert cache_info.misses == 1
    assert cache_info.currsize == 1

    with mock.patch.dict(
        "os.environ",
        {"MLFLOW_EXPERIMENTAL_S3_SIGNATURE_VERSION": "s3v2"},
        clear=True,
    ):
        repo._get_s3_client()
        cache_info = _cached_get_s3_client.cache_info()
        assert cache_info.hits == 1
        assert cache_info.misses == 2
        assert cache_info.currsize == 2

    with mock.patch(
        "mlflow.store.artifact.s3_artifact_repo._get_utcnow_timestamp",
        return_value=datetime.utcnow().timestamp() + _MAX_CACHE_SECONDS,
    ):
        repo._get_s3_client()
        cache_info = _cached_get_s3_client.cache_info()
        assert cache_info.hits == 1
        assert cache_info.misses == 3
        assert cache_info.currsize == 3
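
The last block forces expiry by arithmetic rather than by waiting: with the clock mocked forward by `_MAX_CACHE_SECONDS`, `int(timestamp / _MAX_CACHE_SECONDS)` necessarily lands in the next bucket, so the cache key changes and `misses` climbs to 3 while `hits` stays at 1.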


@pytest.mark.parametrize("ignore_tls_env, verify", [("", None), ("true", False), ("false", None)])
def test_get_s3_client_verify_param_set_correctly(s3_artifact_root, ignore_tls_env, verify):
    from unittest.mock import ANY
