diff --git a/mlflow/store/artifact/s3_artifact_repo.py b/mlflow/store/artifact/s3_artifact_repo.py
index fec9c46a787fb..0cddd99cf2104 100644
--- a/mlflow/store/artifact/s3_artifact_repo.py
+++ b/mlflow/store/artifact/s3_artifact_repo.py
@@ -1,3 +1,5 @@
+from datetime import datetime
+from functools import lru_cache
 import os
 from mimetypes import guess_type
 
@@ -11,6 +13,51 @@
 from mlflow.utils.file_utils import relative_path_to_artifact_path
 
+_MAX_CACHE_SECONDS = 300
+
+
+def _get_utcnow_timestamp():
+    return datetime.utcnow().timestamp()
+
+
+@lru_cache(maxsize=64)
+def _cached_get_s3_client(
+    signature_version,
+    s3_endpoint_url,
+    verify,
+    timestamp,
+):  # pylint: disable=unused-argument
+    """Returns a boto3 client, cached to avoid extra boto3 verify calls.
+
+    This function lives outside S3ArtifactRepository because it is
+    instance-agnostic and could be reused by other callers.
+
+    `maxsize` is set to bound memory consumption in case a user has
+    dynamic endpoints (intentionally or as a bug).
+
+    In rare edge cases, a boto3 endpoint URL might expire after twelve
+    hours (the current expiration time). To ensure we raise an error on
+    verification instead of using an expired endpoint, the `timestamp`
+    parameter is used to invalidate the cache entry.
+    """
+    import boto3
+    from botocore.client import Config
+
+    # Making it possible to access public S3 buckets
+    # Workaround for https://github.com/boto/botocore/issues/2442
+    if signature_version.lower() == "unsigned":
+        from botocore import UNSIGNED
+
+        signature_version = UNSIGNED
+
+    return boto3.client(
+        "s3",
+        config=Config(signature_version=signature_version),
+        endpoint_url=s3_endpoint_url,
+        verify=verify,
+    )
+
+
 class S3ArtifactRepository(ArtifactRepository):
     """Stores artifacts on Amazon S3."""
 
@@ -36,9 +83,6 @@ def get_s3_file_upload_extra_args():
         return None
 
     def _get_s3_client(self):
-        import boto3
-        from botocore.client import Config
-
         s3_endpoint_url = os.environ.get("MLFLOW_S3_ENDPOINT_URL")
         ignore_tls = os.environ.get("MLFLOW_S3_IGNORE_TLS")
 
@@ -54,18 +98,11 @@ def _get_s3_client(self):
         # NOTE: If you need to specify this env variable, please file an issue at
         # https://github.com/mlflow/mlflow/issues so we know your use-case!
         signature_version = os.environ.get("MLFLOW_EXPERIMENTAL_S3_SIGNATURE_VERSION", "s3v4")
-        # Making it possible to access public S3 buckets
-        # Workaround for https://github.com/boto/botocore/issues/2442
-        if signature_version.lower() == "unsigned":
-            from botocore import UNSIGNED
-
-            signature_version = UNSIGNED
-        return boto3.client(
-            "s3",
-            config=Config(signature_version=signature_version),
-            endpoint_url=s3_endpoint_url,
-            verify=verify,
-        )
+
+        # Invalidate the cache every `_MAX_CACHE_SECONDS` seconds
+        timestamp = int(_get_utcnow_timestamp() / _MAX_CACHE_SECONDS)
+
+        return _cached_get_s3_client(signature_version, s3_endpoint_url, verify, timestamp)
 
     def _upload_file(self, s3_client, local_file, bucket, key):
         extra_args = dict()
diff --git a/tests/store/artifact/test_s3_artifact_repo.py b/tests/store/artifact/test_s3_artifact_repo.py
index bb8a84f0e2406..fab28309d83b3 100644
--- a/tests/store/artifact/test_s3_artifact_repo.py
+++ b/tests/store/artifact/test_s3_artifact_repo.py
@@ -1,11 +1,17 @@
 import os
 import posixpath
 import tarfile
+from datetime import datetime
+from unittest import mock
 
 import pytest
 
 from mlflow.store.artifact.artifact_repository_registry import get_artifact_repository
-from mlflow.store.artifact.s3_artifact_repo import S3ArtifactRepository
+from mlflow.store.artifact.s3_artifact_repo import (
+    S3ArtifactRepository,
+    _cached_get_s3_client,
+    _MAX_CACHE_SECONDS,
+)
 
 from tests.helper_functions import set_boto_credentials  # pylint: disable=unused-import
 from tests.helper_functions import mock_s3_bucket  # pylint: disable=unused-import
@@ -18,6 +24,11 @@
     return "s3://{bucket_name}".format(bucket_name=mock_s3_bucket)
 
 
+@pytest.fixture(autouse=True)
+def reset_cached_get_s3_client():
+    _cached_get_s3_client.cache_clear()
+
+
 def teardown_function():
     if "MLFLOW_S3_UPLOAD_EXTRA_ARGS" in os.environ:
         del os.environ["MLFLOW_S3_UPLOAD_EXTRA_ARGS"]
@@ -55,6 +66,43 @@ def test_file_artifact_is_logged_with_content_metadata(s3_artifact_root, tmpdir)
     assert response.get("ContentEncoding") is None
 
 
+def test_get_s3_client_hits_cache(s3_artifact_root):
+    # pylint: disable=no-value-for-parameter
+    repo = get_artifact_repository(posixpath.join(s3_artifact_root, "some/path"))
+    repo._get_s3_client()
+    cache_info = _cached_get_s3_client.cache_info()
+    assert cache_info.hits == 0
+    assert cache_info.misses == 1
+    assert cache_info.currsize == 1
+
+    repo._get_s3_client()
+    cache_info = _cached_get_s3_client.cache_info()
+    assert cache_info.hits == 1
+    assert cache_info.misses == 1
+    assert cache_info.currsize == 1
+
+    with mock.patch.dict(
+        "os.environ",
+        {"MLFLOW_EXPERIMENTAL_S3_SIGNATURE_VERSION": "s3v2"},
+        clear=True,
+    ):
+        repo._get_s3_client()
+        cache_info = _cached_get_s3_client.cache_info()
+        assert cache_info.hits == 1
+        assert cache_info.misses == 2
+        assert cache_info.currsize == 2
+
+    with mock.patch(
+        "mlflow.store.artifact.s3_artifact_repo._get_utcnow_timestamp",
+        return_value=datetime.utcnow().timestamp() + _MAX_CACHE_SECONDS,
+    ):
+        repo._get_s3_client()
+        cache_info = _cached_get_s3_client.cache_info()
+        assert cache_info.hits == 1
+        assert cache_info.misses == 3
+        assert cache_info.currsize == 3
+
+
 @pytest.mark.parametrize("ignore_tls_env, verify", [("", None), ("true", False), ("false", None)])
 def test_get_s3_client_verify_param_set_correctly(s3_artifact_root, ignore_tls_env, verify):
     from unittest.mock import ANY
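Review note: `functools.lru_cache` has no built-in TTL, so this change simulates one by passing an extra `timestamp` argument equal to the current epoch time integer-divided by `_MAX_CACHE_SECONDS`. Within a 300-second window the quotient is constant, so repeated calls hit the cache; once the window rolls over, the key changes and a fresh client is built, with old entries eventually evicted by `maxsize`. Below is a minimal, self-contained sketch of the same pattern; the names (`_TTL_SECONDS`, `build`, `_cached_build`) are illustrative and not part of this diff:

```python
import time
from functools import lru_cache

_TTL_SECONDS = 300  # plays the role of _MAX_CACHE_SECONDS


@lru_cache(maxsize=64)
def _cached_build(name, time_bucket):  # time_bucket exists only to vary the cache key
    # Stand-in for an expensive construction such as a boto3 client.
    return {"name": name, "bucket": time_bucket}


def build(name):
    # Integer division maps every timestamp inside a TTL window to the same
    # bucket, so the cache key (name, bucket) is stable within the window
    # and changes after it, forcing a rebuild on the next call.
    return _cached_build(name, int(time.time() / _TTL_SECONDS))


first = build("s3")
assert build("s3") is first  # same window, so this is a cache hit
print(_cached_build.cache_info())  # e.g. CacheInfo(hits=1, misses=1, maxsize=64, currsize=1)
```

Note the windows are aligned to multiples of the TTL since the epoch, so invalidation is approximate: an entry created at timestamp 599 (bucket 1 for a 300-second TTL) is already invalid at timestamp 600 (bucket 2), one second later.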
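For context on the `unsigned` branch that moved verbatim into `_cached_get_s3_client`: botocore signs requests by default, which fails against public buckets when no credentials are configured, and passing `botocore.UNSIGNED` as the signature version disables signing (the workaround for boto/botocore#2442). A sketch of the client the repository effectively builds when `MLFLOW_EXPERIMENTAL_S3_SIGNATURE_VERSION=unsigned` is set, assuming only a publicly readable bucket:

```python
import boto3
from botocore import UNSIGNED
from botocore.client import Config

# An anonymous client: requests are sent unsigned, so no AWS credentials
# are needed, but only publicly readable buckets/objects are accessible.
s3 = boto3.client("s3", config=Config(signature_version=UNSIGNED))
```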
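On the test side, the autouse `reset_cached_get_s3_client` fixture is what keeps the `cache_info()` assertions deterministic: `cache_clear()` runs before every test, so hit/miss/size counters always start from zero regardless of test order. A minimal sketch of that fixture pattern, with illustrative names not taken from the diff:

```python
from functools import lru_cache

import pytest


@lru_cache(maxsize=None)
def _cached_answer():
    return 42


@pytest.fixture(autouse=True)
def reset_cache():
    # Runs before every test in the module, so cache_info() counters
    # (hits/misses/currsize) start from zero for each test.
    _cached_answer.cache_clear()


def test_first_call_misses():
    _cached_answer()
    assert _cached_answer.cache_info().misses == 1
```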