Skip to content

Commit

Permalink
PYTHON-3614 KMS timeout errors should always have exc.timeout==True, …
Browse files Browse the repository at this point in the history
…other fixes
  • Loading branch information
ShaneHarvey committed Apr 28, 2023
1 parent f86a08d commit 05dce52
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 28 deletions.
35 changes: 21 additions & 14 deletions pymongo/encryption.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,13 @@
EncryptedCollectionError,
EncryptionError,
InvalidOperation,
PyMongoError,
ServerSelectionTimeoutError,
)
from pymongo.mongo_client import MongoClient
from pymongo.network import BLOCKING_IO_ERRORS
from pymongo.operations import UpdateOne
from pymongo.pool import PoolOptions, _configured_socket
from pymongo.pool import PoolOptions, _configured_socket, _raise_connection_failure
from pymongo.read_concern import ReadConcern
from pymongo.results import BulkWriteResult, DeleteResult
from pymongo.ssl_support import get_ssl_context
Expand Down Expand Up @@ -139,20 +140,26 @@ def kms_request(self, kms_context):
ssl_context=ctx,
)
host, port = parse_host(endpoint, _HTTPS_PORT)
conn = _configured_socket((host, port), opts)
try:
conn.sendall(message)
while kms_context.bytes_needed > 0:
# CSOT: update timeout.
conn.settimeout(max(_csot.clamp_remaining(_KMS_CONNECT_TIMEOUT), 0))
data = conn.recv(kms_context.bytes_needed)
if not data:
raise OSError("KMS connection closed")
kms_context.feed(data)
except BLOCKING_IO_ERRORS:
raise socket.timeout("timed out")
finally:
conn.close()
conn = _configured_socket((host, port), opts)
try:
conn.sendall(message)
while kms_context.bytes_needed > 0:
# CSOT: update timeout.
conn.settimeout(max(_csot.clamp_remaining(_KMS_CONNECT_TIMEOUT), 0))
data = conn.recv(kms_context.bytes_needed)
if not data:
raise OSError("KMS connection closed")
kms_context.feed(data)
except BLOCKING_IO_ERRORS:
raise socket.timeout("timed out")
finally:
conn.close()
except (PyMongoError, MongoCryptError):
raise # Propagate pymongo errors directly.
except Exception as error:
# Wrap I/O errors in PyMongo exceptions.
_raise_connection_failure((host, port), error)

def collection_info(self, database, filter):
"""Get the collection info for a namespace.
Expand Down
3 changes: 0 additions & 3 deletions test/test_encryption.py
Original file line number Diff line number Diff line change
Expand Up @@ -2206,9 +2206,6 @@ def setUp(self):
self.key1_id = self.key1_document["_id"]
self.db = self.client.test_queryable_encryption
self.client.drop_database(self.db)
self.db.command("create", self.encrypted_fields["escCollection"])
self.db.command("create", self.encrypted_fields["eccCollection"])
self.db.command("create", self.encrypted_fields["ecocCollection"])
self.db.command("create", "explicit_encryption", encryptedFields=self.encrypted_fields)
key_vault = create_key_vault(self.client.keyvault.datakeys, self.key1_document)
self.addCleanup(key_vault.drop)
Expand Down
4 changes: 3 additions & 1 deletion test/unified_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -1016,7 +1016,9 @@ def process_error(self, exception, spec):

if is_timeout_error:
self.assertIsInstance(exception, PyMongoError)
self.assertTrue(exception.timeout, msg=exception)
if not exception.timeout:
# Re-raise the exception for better diagnostics.
raise exception

if error_contains:
if isinstance(exception, BulkWriteError):
Expand Down
30 changes: 20 additions & 10 deletions test/utils_spec_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,22 +336,24 @@ def _run_op(self, sessions, collection, op, in_with_transaction):
if expect_error(op):
with self.assertRaises(self.allowable_errors(op), msg=op["name"]) as context:
out = self.run_operation(sessions, collection, op.copy())
exc = context.exception
if expect_error_message(expected_result):
if isinstance(context.exception, BulkWriteError):
errmsg = str(context.exception.details).lower()
if isinstance(exc, BulkWriteError):
errmsg = str(exc.details).lower()
else:
errmsg = str(context.exception).lower()
errmsg = str(exc).lower()
self.assertIn(expected_result["errorContains"].lower(), errmsg)
if expect_error_code(expected_result):
self.assertEqual(
expected_result["errorCodeName"], context.exception.details.get("codeName")
)
self.assertEqual(expected_result["errorCodeName"], exc.details.get("codeName"))
if expect_error_labels_contain(expected_result):
self.assertErrorLabelsContain(
context.exception, expected_result["errorLabelsContain"]
)
self.assertErrorLabelsContain(exc, expected_result["errorLabelsContain"])
if expect_error_labels_omit(expected_result):
self.assertErrorLabelsOmit(context.exception, expected_result["errorLabelsOmit"])
self.assertErrorLabelsOmit(exc, expected_result["errorLabelsOmit"])
if expect_timeout_error(expected_result):
self.assertIsInstance(exc, PyMongoError)
if not exc.timeout:
# Re-raise the exception for better diagnostics.
raise exc

# Reraise the exception if we're in the with_transaction
# callback.
Expand Down Expand Up @@ -617,6 +619,13 @@ def expect_error_labels_omit(expected_result):
return False


def expect_timeout_error(expected_result):
if isinstance(expected_result, dict):
return expected_result["isTimeoutError"]

return False


def expect_error(op):
expected_result = op.get("result")
return (
Expand All @@ -625,6 +634,7 @@ def expect_error(op):
or expect_error_code(expected_result)
or expect_error_labels_contain(expected_result)
or expect_error_labels_omit(expected_result)
or expect_timeout_error(expected_result)
)


Expand Down

0 comments on commit 05dce52

Please sign in to comment.