Skip to content

Commit

Permalink
Use contexthandlers when working with files in tests (#1384)
Browse files Browse the repository at this point in the history
  • Loading branch information
jschneier committed Apr 22, 2024
1 parent 74864ec commit 969528b
Show file tree
Hide file tree
Showing 3 changed files with 112 additions and 122 deletions.
8 changes: 4 additions & 4 deletions tests/test_dropbox.py
Expand Up @@ -169,17 +169,17 @@ def setUp(self, *args):
return_value=(FILE_METADATA_MOCK, RESPONSE_200_MOCK),
)
def test_read(self, *args):
file = self.storage._open("foo.txt")
self.assertEqual(file.read(), b"bar")
with self.storage.open("foo.txt") as file:
self.assertEqual(file.read(), b"bar")

@mock.patch(
"dropbox.Dropbox.files_download",
return_value=(FILE_METADATA_MOCK, RESPONSE_500_MOCK),
)
def test_server_bad_response(self, *args):
with self.assertRaises(dropbox.DropboxStorageException):
file = self.storage._open("foo.txt")
file.read()
with self.storage.open("foo.txt") as file:
file.read()


@mock.patch("dropbox.Dropbox.files_list_folder", return_value=FILES_EMPTY_MOCK)
Expand Down
24 changes: 14 additions & 10 deletions tests/test_gcloud.py
Expand Up @@ -36,23 +36,27 @@ def test_open_read(self):
"""
data = b"This is some test read data."

f = self.storage.open(self.filename)
self.storage._client.bucket.assert_called_with(self.bucket_name)
self.storage._bucket.get_blob.assert_called_with(self.filename, chunk_size=None)
with self.storage.open(self.filename) as f:
self.storage._client.bucket.assert_called_with(self.bucket_name)
self.storage._bucket.get_blob.assert_called_with(
self.filename, chunk_size=None
)

f.blob.download_to_file = lambda tmpfile: tmpfile.write(data)
self.assertEqual(f.read(), data)
f.blob.download_to_file = lambda tmpfile: tmpfile.write(data)
self.assertEqual(f.read(), data)

def test_open_read_num_bytes(self):
data = b"This is some test read data."
num_bytes = 10

f = self.storage.open(self.filename)
self.storage._client.bucket.assert_called_with(self.bucket_name)
self.storage._bucket.get_blob.assert_called_with(self.filename, chunk_size=None)
with self.storage.open(self.filename) as f:
self.storage._client.bucket.assert_called_with(self.bucket_name)
self.storage._bucket.get_blob.assert_called_with(
self.filename, chunk_size=None
)

f.blob.download_to_file = lambda tmpfile: tmpfile.write(data)
self.assertEqual(f.read(num_bytes), data[0:num_bytes])
f.blob.download_to_file = lambda tmpfile: tmpfile.write(data)
self.assertEqual(f.read(num_bytes), data[0:num_bytes])

def test_open_read_nonexistent(self):
self.storage._bucket = mock.MagicMock()
Expand Down
202 changes: 94 additions & 108 deletions tests/test_s3.py
Expand Up @@ -300,10 +300,9 @@ def test_storage_open_read_string(self):
Test opening a file in "r" mode (ie reading as string, not bytes)
"""
name = "test_open_read_string.txt"
file = self.storage.open(name, "r")
content_str = file.read()
self.assertEqual(content_str, "")
file.close()
with self.storage.open(name, "r") as file:
content_str = file.read()
self.assertEqual(content_str, "")

def test_storage_open_readlines(self):
"""
Expand All @@ -312,17 +311,15 @@ def test_storage_open_readlines(self):
name = "test_open_readlines_string.txt"
with io.BytesIO() as temp_file:
temp_file.write(b"line1\nline2")
file = self.storage.open(name, "r")
file._file = temp_file

content_lines = file.readlines()
self.assertEqual(content_lines, ["line1\n", "line2"])

temp_file.seek(0)
file = self.storage.open(name, "rb")
file._file = temp_file
content_lines = file.readlines()
self.assertEqual(content_lines, [b"line1\n", b"line2"])
with self.storage.open(name, "r") as f1:
f1._file = temp_file
content_lines = f1.readlines()
self.assertEqual(content_lines, ["line1\n", "line2"])
temp_file.seek(0)
with self.storage.open(name, "rb") as f2:
f2._file = temp_file
content_lines = f2.readlines()
self.assertEqual(content_lines, [b"line1\n", b"line2"])

def test_storage_open_write(self):
"""
Expand All @@ -338,26 +335,24 @@ def test_storage_open_write(self):
"ACL": "public-read",
}

file = self.storage.open(name, "w")
self.storage.bucket.Object.assert_called_with(name)
obj = self.storage.bucket.Object.return_value
# Set the name of the mock object
obj.key = name
with self.storage.open(name, "w") as file:
self.storage.bucket.Object.assert_called_with(name)
obj = self.storage.bucket.Object.return_value
# Set the name of the mock object
obj.key = name

multipart = obj.initiate_multipart_upload.return_value
multipart.Part.return_value.upload.side_effect = [
{"ETag": "123"},
]
file.write(content)
obj.initiate_multipart_upload.assert_called_with(
ACL="public-read",
ContentType="text/plain",
ServerSideEncryption="AES256",
StorageClass="REDUCED_REDUNDANCY",
)
multipart = obj.initiate_multipart_upload.return_value
multipart.Part.return_value.upload.side_effect = [
{"ETag": "123"},
]
file.write(content)
obj.initiate_multipart_upload.assert_called_with(
ACL="public-read",
ContentType="text/plain",
ServerSideEncryption="AES256",
StorageClass="REDUCED_REDUNDANCY",
)

# Save the internal file before closing
file.close()
multipart.Part.assert_called_with(1)
part = multipart.Part.return_value
part.upload.assert_called_with(Body=content.encode())
Expand All @@ -369,13 +364,12 @@ def test_write_bytearray(self):
"""Test that bytearray write exactly (no extra "bytearray" from stringify)."""
name = "saved_file.bin"
content = bytearray(b"content")
file = self.storage.open(name, "wb")
obj = self.storage.bucket.Object.return_value
# Set the name of the mock object
obj.key = name
bytes_written = file.write(content)
self.assertEqual(len(content), bytes_written)
file.close()
with self.storage.open(name, "wb") as file:
obj = self.storage.bucket.Object.return_value
# Set the name of the mock object
obj.key = name
bytes_written = file.write(content)
self.assertEqual(len(content), bytes_written)

def test_storage_open_no_write(self):
"""
Expand All @@ -391,18 +385,16 @@ def test_storage_open_no_write(self):
"StorageClass": "REDUCED_REDUNDANCY",
}

file = self.storage.open(name, "w")
self.storage.bucket.Object.assert_called_with(name)
obj = self.storage.bucket.Object.return_value
obj.load.side_effect = ClientError(
{"Error": {}, "ResponseMetadata": {"HTTPStatusCode": 404}}, "head_bucket"
)

# Set the name of the mock object
obj.key = name
with self.storage.open(name, "w"):
self.storage.bucket.Object.assert_called_with(name)
obj = self.storage.bucket.Object.return_value
obj.load.side_effect = ClientError(
{"Error": {}, "ResponseMetadata": {"HTTPStatusCode": 404}},
"head_bucket",
)

# Save the internal file before closing
file.close()
# Set the name of the mock object
obj.key = name

obj.load.assert_called_once_with()
obj.put.assert_called_once_with(
Expand All @@ -424,15 +416,12 @@ def test_storage_open_no_overwrite_existing(self):
"StorageClass": "REDUCED_REDUNDANCY",
}

file = self.storage.open(name, "w")
self.storage.bucket.Object.assert_called_with(name)
obj = self.storage.bucket.Object.return_value
with self.storage.open(name, "w"):
self.storage.bucket.Object.assert_called_with(name)
obj = self.storage.bucket.Object.return_value

# Set the name of the mock object
obj.key = name

# Save the internal file before closing
file.close()
# Set the name of the mock object
obj.key = name

obj.load.assert_called_once_with()
obj.put.assert_not_called()
Expand All @@ -449,39 +438,37 @@ def test_storage_write_beyond_buffer_size(self):
"StorageClass": "REDUCED_REDUNDANCY",
}

file = self.storage.open(name, "w")
self.storage.bucket.Object.assert_called_with(name)
obj = self.storage.bucket.Object.return_value
# Set the name of the mock object
obj.key = name

# Initiate the multipart upload
file.write("")
obj.initiate_multipart_upload.assert_called_with(
ContentType="text/plain",
ServerSideEncryption="AES256",
StorageClass="REDUCED_REDUNDANCY",
)
multipart = obj.initiate_multipart_upload.return_value

# Write content at least twice as long as the buffer size
written_content = ""
counter = 1
multipart.Part.return_value.upload.side_effect = [
{"ETag": "123"},
{"ETag": "456"},
]
while len(written_content) < 2 * file.buffer_size:
content = "hello, aws {counter}\n".format(counter=counter)
# Write more than just a few bytes in each iteration to keep the
# test reasonably fast
content += "*" * int(file.buffer_size / 10)
file.write(content)
written_content += content
counter += 1
with self.storage.open(name, "w") as file:
self.storage.bucket.Object.assert_called_with(name)
obj = self.storage.bucket.Object.return_value
# Set the name of the mock object
obj.key = name

# Initiate the multipart upload
file.write("")
obj.initiate_multipart_upload.assert_called_with(
ContentType="text/plain",
ServerSideEncryption="AES256",
StorageClass="REDUCED_REDUNDANCY",
)
multipart = obj.initiate_multipart_upload.return_value

# Write content at least twice as long as the buffer size
written_content = ""
counter = 1
multipart.Part.return_value.upload.side_effect = [
{"ETag": "123"},
{"ETag": "456"},
]
while len(written_content) < 2 * file.buffer_size:
content = "hello, aws {counter}\n".format(counter=counter)
# Write more than just a few bytes in each iteration to keep the
# test reasonably fast
content += "*" * int(file.buffer_size / 10)
file.write(content)
written_content += content
counter += 1

# Save the internal file before closing
file.close()
self.assertListEqual(
multipart.Part.call_args_list, [mock.call(1), mock.call(2)]
)
Expand Down Expand Up @@ -1025,23 +1012,22 @@ def test_loading_ssec(self):
)

def test_closed(self):
f = s3.S3File("test", "wb", self.storage)

with self.subTest("after init"):
self.assertFalse(f.closed)

with self.subTest("after file access"):
# Ensure _get_file has been called
f.file
self.assertFalse(f.closed)

with self.subTest("after close"):
f.close()
self.assertTrue(f.closed)

with self.subTest("reopening"):
f.file
self.assertFalse(f.closed)
with s3.S3File("test", "wb", self.storage) as f:
with self.subTest("after init"):
self.assertFalse(f.closed)

with self.subTest("after file access"):
# Ensure _get_file has been called
f.file
self.assertFalse(f.closed)

with self.subTest("after close"):
f.close()
self.assertTrue(f.closed)

with self.subTest("reopening"):
f.file
self.assertFalse(f.closed)

def test_reopening(self):
f = s3.S3File("test", "wb", self.storage)
Expand Down

0 comments on commit 969528b

Please sign in to comment.