-
Notifications
You must be signed in to change notification settings - Fork 1.2k
/
s3.py
253 lines (203 loc) · 8.37 KB
/
s3.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
import functools
import os
import threading
from collections import defaultdict
from funcy import cached_property, wrap_prop
from dvc.scheme import Schemes
from ._callback import DEFAULT_CALLBACK
from .fsspec_wrapper import ObjectFSWrapper
# Default AWS shared config file location (~/.aws/config); overridable at
# runtime via the AWS_CONFIG_FILE environment variable (see
# _load_aws_config_file below).
_AWS_CONFIG_PATH = os.path.join(os.path.expanduser("~"), ".aws", "config")
# pylint:disable=abstract-method
class BaseS3FileSystem(ObjectFSWrapper):
    """Filesystem for the ``s3://`` scheme, backed by s3fs.

    Translates DVC remote configuration (credentials, endpoint, SSE,
    ACL/grants, transfer tuning) into s3fs/botocore constructor kwargs.
    """

    scheme = Schemes.S3
    REQUIRES = {"s3fs": "s3fs", "boto3": "boto3"}
    # S3 ETags serve as the object checksum.
    PARAM_CHECKSUM = "etag"

    # DVC `grant_*` config option -> S3 API grant argument name.
    _GRANTS = {
        "grant_full_control": "GrantFullControl",
        "grant_read": "GrantRead",
        "grant_read_acp": "GrantReadACP",
        "grant_write_acp": "GrantWriteACP",
    }

    # AWS CLI `s3` section key -> boto3 TransferConfig parameter name.
    _TRANSFER_CONFIG_ALIASES = {
        "max_queue_size": "max_io_queue",
        "max_concurrent_requests": "max_concurrency",
        "multipart_threshold": "multipart_threshold",
        "multipart_chunksize": "multipart_chunksize",
    }

    def _split_s3_config(self, s3_config):
        """Splits the general s3 config into 2 different config
        objects, one for transfer.TransferConfig and other is the
        general session config.

        Side effect: stores the built TransferConfig on
        ``self._transfer_config``; returns the remaining (session)
        options as a plain dict.
        """
        from boto3.s3.transfer import TransferConfig
        from dvc.utils import conversions

        config, transfer_config = {}, {}
        for key, value in s3_config.items():
            if key in self._TRANSFER_CONFIG_ALIASES:
                if key in {"multipart_chunksize", "multipart_threshold"}:
                    # cast human readable sizes (like 24MiB) to integers
                    value = conversions.human_readable_to_bytes(value)
                else:
                    value = int(value)
                transfer_config[self._TRANSFER_CONFIG_ALIASES[key]] = value
            else:
                config[key] = value
        # pylint: disable=attribute-defined-outside-init
        self._transfer_config = TransferConfig(**transfer_config)
        return config

    def _load_aws_config_file(self, profile):
        """Load the ``s3`` section for *profile* from the AWS config file
        (``AWS_CONFIG_FILE`` env var, falling back to ``~/.aws/config``).

        Returns the session part of that section as a dict ({} when the
        file or profile is missing). Also resets
        ``self._transfer_config``, which ``_split_s3_config`` may then
        repopulate.
        """
        from botocore.configloader import load_config

        # pylint: disable=attribute-defined-outside-init
        self._transfer_config = None
        config_path = os.environ.get("AWS_CONFIG_FILE", _AWS_CONFIG_PATH)
        if not os.path.exists(config_path):
            return {}
        config = load_config(config_path)
        profile_config = config["profiles"].get(profile or "default")
        if not profile_config:
            return {}
        s3_config = profile_config.get("s3", {})
        return self._split_s3_config(s3_config)

    def _prepare_credentials(self, **config):
        """Translate DVC remote *config* options into s3fs kwargs.

        Returns a nested dict suitable for ``s3fs.S3FileSystem(**...)``
        with all ``None`` values removed.

        Raises:
            ConfigError: when mutually exclusive options are combined
                (`sse_kms_key_id` with `sse_customer_key`, or `acl`
                with any `grant_*` option).
        """
        from s3fs.utils import SSEParams

        from dvc.config import ConfigError
        from dvc.utils.flatten import flatten, unflatten

        login_info = defaultdict(dict)

        # credentials
        login_info["key"] = config.get("access_key_id")
        login_info["secret"] = config.get("secret_access_key")
        login_info["token"] = config.get("session_token")

        # session configuration
        login_info["profile"] = config.get("profile")
        login_info["use_ssl"] = config.get("use_ssl", True)

        # extra client configuration
        client = login_info["client_kwargs"]
        client["region_name"] = config.get("region")
        client["endpoint_url"] = config.get("endpointurl")
        client["verify"] = config.get("ssl_verify")

        # timeout configuration
        config_kwargs = login_info["config_kwargs"]
        config_kwargs["read_timeout"] = config.get("read_timeout")
        config_kwargs["connect_timeout"] = config.get("connect_timeout")

        # encryptions
        additional = login_info["s3_additional_kwargs"]
        sse_customer_key = None
        if config.get("sse_customer_key"):
            if config.get("sse_kms_key_id"):
                raise ConfigError(
                    "`sse_kms_key_id` and `sse_customer_key` AWS S3 config "
                    "options are mutually exclusive"
                )
            import base64

            # the configured value is base64-encoded; decode to raw bytes
            sse_customer_key = base64.b64decode(config.get("sse_customer_key"))
        sse_customer_algorithm = config.get("sse_customer_algorithm")
        if not sse_customer_algorithm:
            sse_customer_algorithm = "AES256"
        sse_params = SSEParams(
            server_side_encryption=config.get("sse"),
            sse_customer_algorithm=sse_customer_algorithm,
            sse_customer_key=sse_customer_key,
            sse_kms_key_id=config.get("sse_kms_key_id"),
        )
        additional.update(sse_params.to_kwargs())

        additional["ACL"] = config.get("acl")
        for grant_option, grant_key in self._GRANTS.items():
            if config.get(grant_option):
                if additional["ACL"]:
                    raise ConfigError(
                        "`acl` and `grant_*` AWS S3 config options "
                        "are mutually exclusive"
                    )
                additional[grant_key] = config[grant_option]

        # config kwargs
        session_config = login_info["config_kwargs"]
        session_config["s3"] = self._load_aws_config_file(
            login_info["profile"]
        )

        shared_creds = config.get("credentialpath")
        if shared_creds:
            os.environ.setdefault("AWS_SHARED_CREDENTIALS_FILE", shared_creds)

        # No region configured anywhere: let s3fs discover and cache
        # each bucket's region instead.
        if (
            client["region_name"] is None
            and session_config["s3"].get("region_name") is None
            and os.getenv("AWS_REGION") is None
        ):
            # Enable bucket region caching
            login_info["cache_regions"] = config.get("cache_regions", True)

        config_path = config.get("configpath")
        if config_path:
            os.environ.setdefault("AWS_CONFIG_FILE", config_path)

        # drop all None values so s3fs/botocore only see explicit options
        return unflatten(
            {
                key: value
                for key, value in flatten(login_info).items()
                if value is not None
            }
        )

    @wrap_prop(threading.Lock())
    @cached_property
    def fs(self):
        # Lazily-built, lock-protected underlying s3fs filesystem.
        from s3fs import S3FileSystem as _S3FileSystem

        return _S3FileSystem(**self.fs_args)

    @classmethod
    def _strip_protocol(cls, path: str) -> str:
        # Turn "s3://bucket/key" into "bucket/key".
        from fsspec.utils import infer_storage_options

        return infer_storage_options(path)["path"]

    def unstrip_protocol(self, path):
        # Inverse of _strip_protocol: re-attach the "s3://" scheme.
        return "s3://" + path.lstrip("/")
def _translate_exceptions(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
try:
func(*args, **kwargs)
except Exception as exc:
from s3fs.errors import translate_boto_error
raise translate_boto_error(exc)
return wrapper
class S3FileSystem(BaseS3FileSystem):  # pylint:disable=abstract-method
    """S3 filesystem that performs uploads/downloads through boto3's
    managed transfer machinery instead of plain s3fs streaming."""

    @wrap_prop(threading.Lock())
    @cached_property
    def s3(self):
        """boto3 S3 service resource built from the same credentials and
        endpoint settings as the s3fs instance (lock-protected, cached)."""
        import boto3

        args = self.fs_args
        client_kwargs = args.get("client_kwargs", {})

        session_opts = {
            "profile_name": args.get("profile"),
            "region_name": client_kwargs.get("region_name"),
        }
        # Forward only explicitly-configured credentials; otherwise let
        # boto3 run its normal credential resolution chain.
        for session_key, login_key in (
            ("aws_access_key_id", "key"),
            ("aws_secret_access_key", "secret"),
            ("aws_session_token", "token"),
        ):
            if login_key in args:
                session_opts[session_key] = args[login_key]

        return boto3.session.Session(**session_opts).resource(
            "s3",
            endpoint_url=client_kwargs.get("endpoint_url"),
            use_ssl=args["use_ssl"],
            verify=client_kwargs.get("verify"),
        )

    def _get_obj(self, path):
        """Return the boto3 Object for ``path`` (bucket + key)."""
        bucket_name, key, _ = self.fs.split_path(path)
        return self.s3.Bucket(bucket_name).Object(key)

    @_translate_exceptions
    def put_file(
        self, from_file, to_info, callback=DEFAULT_CALLBACK, **kwargs
    ):
        """Upload local *from_file* to *to_info* via managed transfer,
        reporting progress through *callback*."""
        callback.set_size(os.path.getsize(from_file))
        self._get_obj(to_info).upload_file(
            from_file,
            Callback=callback.relative_update,
            ExtraArgs=self.fs_args.get("s3_additional_kwargs"),
            Config=self._transfer_config,
        )
        # Drop any cached listing of the destination's parent "directory".
        self.fs.invalidate_cache(self.path.parent(to_info))

    @_translate_exceptions
    def get_file(
        self, from_info, to_info, callback=DEFAULT_CALLBACK, **kwargs
    ):
        """Download *from_info* into local file *to_info*, reporting
        progress through *callback*."""
        remote_obj = self._get_obj(from_info)
        callback.set_size(remote_obj.content_length)
        remote_obj.download_file(to_info, Callback=callback.relative_update)