diff --git a/src/urllib3/contrib/_securetransport/low_level.py b/src/urllib3/contrib/_securetransport/low_level.py index bb302fc601..ed8120190c 100644 --- a/src/urllib3/contrib/_securetransport/low_level.py +++ b/src/urllib3/contrib/_securetransport/low_level.py @@ -13,6 +13,7 @@ import os import re import ssl +import struct import tempfile from .bindings import CFConst, CoreFoundation, Security @@ -370,3 +371,26 @@ def _load_client_cert_chain(keychain, *paths): finally: for obj in itertools.chain(identities, certificates): CoreFoundation.CFRelease(obj) + + +TLS_PROTOCOL_VERSIONS = { + "SSLv2": (0, 2), + "SSLv3": (3, 0), + "TLSv1": (3, 1), + "TLSv1.1": (3, 2), + "TLSv1.2": (3, 3), +} + + +def _build_tls_unknown_ca_alert(version): + """ + Builds a TLS alert record for an unknown CA. + """ + ver_maj, ver_min = TLS_PROTOCOL_VERSIONS[version] + severity_fatal = 0x02 + description_unknown_ca = 0x30 + msg = struct.pack(">BB", severity_fatal, description_unknown_ca) + msg_len = len(msg) + record_type_alert = 0x15 + record = struct.pack(">BBBH", record_type_alert, ver_maj, ver_min, msg_len) + msg + return record diff --git a/src/urllib3/contrib/securetransport.py b/src/urllib3/contrib/securetransport.py index 866f00d46f..ab092de67a 100644 --- a/src/urllib3/contrib/securetransport.py +++ b/src/urllib3/contrib/securetransport.py @@ -60,6 +60,7 @@ import shutil import socket import ssl +import struct import threading import weakref @@ -69,6 +70,7 @@ from ._securetransport.bindings import CoreFoundation, Security, SecurityConst from ._securetransport.low_level import ( _assert_no_error, + _build_tls_unknown_ca_alert, _cert_array_from_pem, _create_cfstring_array, _load_client_cert_chain, @@ -397,11 +399,37 @@ def _custom_validate(self, verify, trust_bundle): Called when we have set custom validation. We do this in two cases: first, when cert validation is entirely disabled; and second, when using a custom trust DB. + Raises an SSLError if the connection is not trusted. """ # If we disabled cert validation, just say: cool. if not verify: return + successes = ( + SecurityConst.kSecTrustResultUnspecified, + SecurityConst.kSecTrustResultProceed, + ) + try: + trust_result = self._evaluate_trust(trust_bundle) + if trust_result in successes: + return + reason = "error code: %d" % (trust_result,) + except Exception as e: + # Do not trust on error + reason = "exception: %r" % (e,) + + # SecureTransport does not send an alert nor shuts down the connection. + rec = _build_tls_unknown_ca_alert(self.version()) + self.socket.sendall(rec) + # close the connection immediately + # l_onoff = 1, activate linger + # l_linger = 0, linger for 0 seoncds + opts = struct.pack("ii", 1, 0) + self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_LINGER, opts) + self.close() + raise ssl.SSLError("certificate verify failed, %s" % reason) + + def _evaluate_trust(self, trust_bundle): # We want data in memory, so load it up. if os.path.isfile(trust_bundle): with open(trust_bundle, "rb") as f: @@ -439,15 +467,7 @@ def _custom_validate(self, verify, trust_bundle): if cert_array is not None: CoreFoundation.CFRelease(cert_array) - # Ok, now we can look at what the result was. - successes = ( - SecurityConst.kSecTrustResultUnspecified, - SecurityConst.kSecTrustResultProceed, - ) - if trust_result.value not in successes: - raise ssl.SSLError( - "certificate verify failed, error code: %d" % trust_result.value - ) + return trust_result.value def handshake( self, diff --git a/test/with_dummyserver/test_https.py b/test/with_dummyserver/test_https.py index 69d05bf758..728b4a7c2d 100644 --- a/test/with_dummyserver/test_https.py +++ b/test/with_dummyserver/test_https.py @@ -477,29 +477,31 @@ def test_assert_fingerprint_sha256(self): https_pool.request("GET", "/") def test_assert_invalid_fingerprint(self): + def _test_request(pool): + with pytest.raises(MaxRetryError) as cm: + pool.request("GET", "/", retries=0) + assert isinstance(cm.value.reason, SSLError) + return cm.value.reason + with HTTPSConnectionPool( - "127.0.0.1", self.port, cert_reqs="CERT_REQUIRED", ca_certs=DEFAULT_CA + self.host, self.port, cert_reqs="CERT_REQUIRED", ca_certs=DEFAULT_CA ) as https_pool: + https_pool.assert_fingerprint = ( "AA:AA:AA:AA:AA:AAAA:AA:AAAA:AA:AA:AA:AA:AA:AA:AA:AA:AA:AA" ) - - def _test_request(pool): - with pytest.raises(MaxRetryError) as cm: - pool.request("GET", "/", retries=0) - assert isinstance(cm.value.reason, SSLError) - - _test_request(https_pool) - https_pool._get_conn() + e = _test_request(https_pool) + assert "Fingerprints did not match." in str(e) # Uneven length https_pool.assert_fingerprint = "AA:A" - _test_request(https_pool) - https_pool._get_conn() + e = _test_request(https_pool) + assert "Fingerprint of invalid length:" in str(e) # Invalid length https_pool.assert_fingerprint = "AA" - _test_request(https_pool) + e = _test_request(https_pool) + assert "Fingerprint of invalid length:" in str(e) def test_verify_none_and_bad_fingerprint(self): with HTTPSConnectionPool( diff --git a/test/with_dummyserver/test_socketlevel.py b/test/with_dummyserver/test_socketlevel.py index 7cab833777..7c06875439 100644 --- a/test/with_dummyserver/test_socketlevel.py +++ b/test/with_dummyserver/test_socketlevel.py @@ -1427,6 +1427,49 @@ def test_load_verify_locations_exception(self): with pytest.raises(SSLError): ssl_wrap_socket(None, ca_certs="/tmp/fake-file") + def test_ssl_custom_validation_failure_terminates(self, tmpdir): + """ + Ensure that the underlying socket is terminated if custom validation fails. + """ + server_closed = Event() + + def is_closed_socket(sock): + try: + sock.settimeout(SHORT_TIMEOUT) # Python 3 + sock.recv(1) # Python 2 + except (OSError, socket.error): + return True + return False + + def socket_handler(listener): + sock = listener.accept()[0] + try: + _ = ssl.wrap_socket( + sock, + server_side=True, + keyfile=DEFAULT_CERTS["keyfile"], + certfile=DEFAULT_CERTS["certfile"], + ca_certs=DEFAULT_CA, + ) + except ssl.SSLError as e: + assert "alert unknown ca" in str(e) + if is_closed_socket(sock): + server_closed.set() + + self._start_server(socket_handler) + + # client uses a different ca + other_ca = trustme.CA() + other_ca_path = str(tmpdir / "ca.pem") + other_ca.cert_pem.write_to_path(other_ca_path) + + with HTTPSConnectionPool( + self.host, self.port, cert_reqs="REQUIRED", ca_certs=other_ca_path + ) as pool: + with pytest.raises(SSLError): + pool.request("GET", "/", retries=False, timeout=LONG_TIMEOUT) + assert server_closed.wait(LONG_TIMEOUT), "The socket was not terminated" + class TestErrorWrapping(SocketDummyServerTestCase): def test_bad_statusline(self):