Caching boto client to improve artifact download speed #4695

Merged 12 commits on Dec 2, 2021
67 changes: 52 additions & 15 deletions mlflow/store/artifact/s3_artifact_repo.py
@@ -1,3 +1,5 @@
from datetime import datetime
from functools import lru_cache
import os
from mimetypes import guess_type

@@ -11,6 +13,51 @@
from mlflow.utils.file_utils import relative_path_to_artifact_path


_MAX_CACHE_SECONDS = 300


def _get_utcnow_timestamp():
return datetime.utcnow().timestamp()


@lru_cache(maxsize=64)
def _cached_get_s3_client(
signature_version,
s3_endpoint_url,
verify,
timestamp,
): # pylint: disable=unused-argument
"""Returns a boto3 client, caching to avoid extra boto3 verify calls.

This method is outside of the S3ArtifactRepository as it is
agnostic and could be used by other instances.

`maxsize` set to avoid excessive memory consmption in the case
a user has dynamic endpoints (intentionally or as a bug).

Some of the boto3 endpoint urls, in very edge cases, might expire
after twelve hours as that is the current expiration time. To ensure
we throw an error on verification instead of using an expired endpoint
we utilise the `timestamp` parameter to invalidate cache.
"""
import boto3
from botocore.client import Config

# Making it possible to access public S3 buckets
# Workaround for https://github.com/boto/botocore/issues/2442
if signature_version.lower() == "unsigned":
from botocore import UNSIGNED

signature_version = UNSIGNED

return boto3.client(
"s3",
config=Config(signature_version=signature_version),
endpoint_url=s3_endpoint_url,
verify=verify,
)


class S3ArtifactRepository(ArtifactRepository):
"""Stores artifacts on Amazon S3."""

@@ -36,9 +83,6 @@ def get_s3_file_upload_extra_args():
return None

def _get_s3_client(self):
import boto3
from botocore.client import Config

s3_endpoint_url = os.environ.get("MLFLOW_S3_ENDPOINT_URL")
ignore_tls = os.environ.get("MLFLOW_S3_IGNORE_TLS")

@@ -54,18 +98,11 @@ def _get_s3_client(self):
# NOTE: If you need to specify this env variable, please file an issue at
# https://github.com/mlflow/mlflow/issues so we know your use-case!
signature_version = os.environ.get("MLFLOW_EXPERIMENTAL_S3_SIGNATURE_VERSION", "s3v4")
# Making it possible to access public S3 buckets
# Workaround for https://github.com/boto/botocore/issues/2442
if signature_version.lower() == "unsigned":
from botocore import UNSIGNED

signature_version = UNSIGNED
return boto3.client(
"s3",
config=Config(signature_version=signature_version),
endpoint_url=s3_endpoint_url,
verify=verify,
)

# Bucket the timestamp so the cache key (and thus the client) rotates every `_MAX_CACHE_SECONDS`
timestamp = int(_get_utcnow_timestamp() / _MAX_CACHE_SECONDS)

return _cached_get_s3_client(signature_version, s3_endpoint_url, verify, timestamp)

def _upload_file(self, s3_client, local_file, bucket, key):
extra_args = dict()
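The core trick in `_cached_get_s3_client` deserves a note: `functools.lru_cache` keys on the argument tuple, so threading a coarse time bucket through as an otherwise-unused argument expires every cached entry once the bucket rolls over. Below is a minimal standalone sketch of the same pattern; the names are illustrative and not part of MLflow's API.

import time
from functools import lru_cache

_MAX_CACHE_SECONDS = 300


@lru_cache(maxsize=64)
def _cached_build(endpoint_url, time_bucket):  # time_bucket only varies the cache key
    # Stand-in for the expensive boto3.client(...) construction.
    return {"endpoint_url": endpoint_url, "built_at": time.time()}


def get_client(endpoint_url):
    # All calls within the same 300-second window share a cache key (a hit);
    # once int(now / 300) increments, the next call is a miss and rebuilds.
    time_bucket = int(time.time() / _MAX_CACHE_SECONDS)
    return _cached_build(endpoint_url, time_bucket)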
49 changes: 48 additions & 1 deletion tests/store/artifact/test_s3_artifact_repo.py
@@ -1,11 +1,17 @@
import os
import posixpath
import tarfile
from datetime import datetime
from unittest import mock

import pytest

from mlflow.store.artifact.artifact_repository_registry import get_artifact_repository
from mlflow.store.artifact.s3_artifact_repo import S3ArtifactRepository
from mlflow.store.artifact.s3_artifact_repo import (
S3ArtifactRepository,
_cached_get_s3_client,
_MAX_CACHE_SECONDS,
)

from tests.helper_functions import set_boto_credentials # pylint: disable=unused-import
from tests.helper_functions import mock_s3_bucket # pylint: disable=unused-import
@@ -18,6 +23,11 @@ def s3_artifact_root(mock_s3_bucket):
return "s3://{bucket_name}".format(bucket_name=mock_s3_bucket)


@pytest.fixture(autouse=True)
def reset_cached_get_s3_client():
_cached_get_s3_client.cache_clear()


def teardown_function():
if "MLFLOW_S3_UPLOAD_EXTRA_ARGS" in os.environ:
del os.environ["MLFLOW_S3_UPLOAD_EXTRA_ARGS"]
@@ -55,6 +65,43 @@ def test_file_artifact_is_logged_with_content_metadata(s3_artifact_root, tmpdir)
assert response.get("ContentEncoding") is None


def test_get_s3_client_hits_cache(s3_artifact_root):
# pylint: disable=no-value-for-parameter
repo = get_artifact_repository(posixpath.join(s3_artifact_root, "some/path"))
repo._get_s3_client()
cache_info = _cached_get_s3_client.cache_info()
assert cache_info.hits == 0
assert cache_info.misses == 1
assert cache_info.currsize == 1

repo._get_s3_client()
cache_info = _cached_get_s3_client.cache_info()
assert cache_info.hits == 1
assert cache_info.misses == 1
assert cache_info.currsize == 1

with mock.patch.dict(
"os.environ",
{"MLFLOW_EXPERIMENTAL_S3_SIGNATURE_VERSION": "s3v2"},
clear=True,
):
repo._get_s3_client()
cache_info = _cached_get_s3_client.cache_info()
assert cache_info.hits == 1
assert cache_info.misses == 2
assert cache_info.currsize == 2

with mock.patch(
"mlflow.store.artifact.s3_artifact_repo._get_utcnow_timestamp",
return_value=datetime.utcnow().timestamp() + _MAX_CACHE_SECONDS,
):
repo._get_s3_client()
cache_info = _cached_get_s3_client.cache_info()
assert cache_info.hits == 1
assert cache_info.misses == 3
assert cache_info.currsize == 3


@pytest.mark.parametrize("ignore_tls_env, verify", [("", None), ("true", False), ("false", None)])
def test_get_s3_client_verify_param_set_correctly(s3_artifact_root, ignore_tls_env, verify):
from unittest.mock import ANY
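A closing note on the final assertion block in `test_get_s3_client_hits_cache`: with `_MAX_CACHE_SECONDS = 300`, the cache key uses the integer bucket `int(timestamp / 300)`, so timestamps in [0, 300) fall in bucket 0, [300, 600) in bucket 1, and so on. Patching `_get_utcnow_timestamp` to return a value exactly `_MAX_CACHE_SECONDS` in the future therefore always lands in a later bucket, guaranteeing the cache miss the test asserts. A quick illustrative check follows; the timestamp value is an arbitrary assumption, not taken from the test suite.

_MAX_CACHE_SECONDS = 300

now = 1_638_400_000.0  # arbitrary epoch timestamp for illustration
bucket_now = int(now / _MAX_CACHE_SECONDS)
bucket_later = int((now + _MAX_CACHE_SECONDS) / _MAX_CACHE_SECONDS)
# Shifting by one full bucket width always increments the integer bucket,
# which changes the lru_cache key and registers as a cache miss.
assert bucket_later == bucket_now + 1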