Skip to content

Commit

Permalink
S3 upload_file, download_file to support path-like objects (#2259)
Browse files Browse the repository at this point in the history
Co-authored-by: Jonas Neubert <jonasneu@amazon.com>
  • Loading branch information
alanyee and jonemo committed Dec 16, 2022
1 parent 3f5a1b3 commit 3442539
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 5 deletions.
5 changes: 5 additions & 0 deletions .changes/next-release/enhancement-s3-4244.json
@@ -0,0 +1,5 @@
{
"type": "enhancement",
"category": "s3",
"description": "s3.transfer methods accept path-like objects as input"
}
10 changes: 8 additions & 2 deletions boto3/s3/transfer.py
Expand Up @@ -122,6 +122,8 @@ def __call__(self, bytes_amount):
"""
from os import PathLike, fspath

from botocore.exceptions import ClientError
from s3transfer.exceptions import (
RetriesExceededError as S3TransferRetriesExceededError,
Expand Down Expand Up @@ -277,8 +279,10 @@ def upload_file(
:py:meth:`S3.Client.upload_file`
:py:meth:`S3.Client.upload_fileobj`
"""
if isinstance(filename, PathLike):
filename = fspath(filename)
if not isinstance(filename, str):
raise ValueError('Filename must be a string')
raise ValueError('Filename must be a string or a path-like object')

subscribers = self._get_subscribers(callback)
future = self._manager.upload(
Expand Down Expand Up @@ -309,8 +313,10 @@ def download_file(
:py:meth:`S3.Client.download_file`
:py:meth:`S3.Client.download_fileobj`
"""
if isinstance(filename, PathLike):
filename = fspath(filename)
if not isinstance(filename, str):
raise ValueError('Filename must be a string')
raise ValueError('Filename must be a string or a path-like object')

subscribers = self._get_subscribers(callback)
future = self._manager.download(
Expand Down
9 changes: 8 additions & 1 deletion tests/integration/test_s3.py
Expand Up @@ -20,6 +20,7 @@
import string
import tempfile
import threading
from pathlib import Path

from botocore.client import Config

Expand Down Expand Up @@ -386,6 +387,13 @@ def test_download_fileobj(self):

self.assertEqual(fileobj.getvalue(), b'beach')

def test_upload_via_path(self):
    """Upload a local file passed as a ``pathlib.Path`` and check it exists.

    The source file is handed to ``upload_file`` wrapped in ``Path`` rather
    than as a plain string; the test passes when the object can afterwards
    be found under the same key.
    """
    object_key = 'path.txt'
    s3_transfer = self.create_s3_transfer()
    local_file = self.files.create_file_with_size(object_key, filesize=1024)
    s3_transfer.upload_file(Path(local_file), self.bucket_name, object_key)
    # Register cleanup only after the upload call, as there is nothing to
    # delete if the upload itself raised.
    self.addCleanup(self.delete_object, object_key)
    self.assertTrue(self.object_exists(object_key))

def test_upload_below_threshold(self):
config = boto3.s3.transfer.TransferConfig(
multipart_threshold=2 * 1024 * 1024
Expand All @@ -396,7 +404,6 @@ def test_upload_below_threshold(self):
)
transfer.upload_file(filename, self.bucket_name, 'foo.txt')
self.addCleanup(self.delete_object, 'foo.txt')

self.assertTrue(self.object_exists('foo.txt'))

def test_upload_above_threshold(self):
Expand Down
55 changes: 53 additions & 2 deletions tests/unit/s3/test_transfer.py
Expand Up @@ -10,6 +10,9 @@
# distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
import pathlib
from tempfile import NamedTemporaryFile

import pytest
from s3transfer.futures import NonThreadedExecutor
from s3transfer.manager import TransferManager
Expand Down Expand Up @@ -126,6 +129,11 @@ def setUp(self):
self.manager = mock.Mock(TransferManager(self.client))
self.transfer = S3Transfer(manager=self.manager)
self.callback = mock.Mock()
# Use NamedTemporaryFile as the source of a path string that is valid and
# realistic for the system the tests are run on. The file is deleted as
# soon as the ``with`` block exits, so it does not exist while the tests run.
with NamedTemporaryFile("w") as tmp_file:
self.file_path_str = tmp_file.name

def assert_callback_wrapped_in_subscriber(self, call_args):
subscribers = call_args[0][4]
Expand All @@ -148,16 +156,59 @@ def test_upload_file(self):
'smallfile', 'bucket', 'key', extra_args, None
)

def test_upload_file_via_path(self):
    """``upload_file`` given a ``pathlib.Path`` calls the manager with a str."""
    acl_args = {'ACL': 'public-read'}
    source = pathlib.Path(self.file_path_str)
    self.transfer.upload_file(source, 'bucket', 'key', extra_args=acl_args)
    # The manager is expected to see the plain string path, not the Path
    # object that was passed in.
    expected_call = (self.file_path_str, 'bucket', 'key', acl_args, None)
    self.manager.upload.assert_called_with(*expected_call)

def test_upload_file_via_purepath(self):
    """``upload_file`` given a ``pathlib.PurePath`` calls the manager with a str."""
    acl_args = {'ACL': 'public-read'}
    source = pathlib.PurePath(self.file_path_str)
    self.transfer.upload_file(source, 'bucket', 'key', extra_args=acl_args)
    # The manager is expected to see the plain string path, not the
    # PurePath object that was passed in.
    expected_call = (self.file_path_str, 'bucket', 'key', acl_args, None)
    self.manager.upload.assert_called_with(*expected_call)

def test_download_file(self):
extra_args = {
'SSECustomerKey': 'foo',
'SSECustomerAlgorithm': 'AES256',
}
self.transfer.download_file(
'bucket', 'key', '/tmp/smallfile', extra_args=extra_args
'bucket', 'key', self.file_path_str, extra_args=extra_args
)
self.manager.download.assert_called_with(
'bucket', 'key', self.file_path_str, extra_args, None
)

def test_download_file_via_path(self):
extra_args = {
'SSECustomerKey': 'foo',
'SSECustomerAlgorithm': 'AES256',
}
self.transfer.download_file(
'bucket',
'key',
pathlib.Path(self.file_path_str),
extra_args=extra_args,
)
self.manager.download.assert_called_with(
'bucket', 'key', '/tmp/smallfile', extra_args, None
'bucket',
'key',
self.file_path_str,
extra_args,
None,
)

def test_upload_wraps_callback(self):
Expand Down

0 comments on commit 3442539

Please sign in to comment.