Skip to content

Commit

Permalink
Terminate connection when custom verification fails (SecureTransport) (
Browse files Browse the repository at this point in the history
  • Loading branch information
hodbn committed Nov 1, 2020
1 parent 16b7b33 commit 5eb604f
Show file tree
Hide file tree
Showing 4 changed files with 110 additions and 21 deletions.
24 changes: 24 additions & 0 deletions src/urllib3/contrib/_securetransport/low_level.py
Expand Up @@ -13,6 +13,7 @@
import os
import re
import ssl
import struct
import tempfile

from .bindings import CFConst, CoreFoundation, Security
Expand Down Expand Up @@ -370,3 +371,26 @@ def _load_client_cert_chain(keychain, *paths):
finally:
for obj in itertools.chain(identities, certificates):
CoreFoundation.CFRelease(obj)


TLS_PROTOCOL_VERSIONS = {
"SSLv2": (0, 2),
"SSLv3": (3, 0),
"TLSv1": (3, 1),
"TLSv1.1": (3, 2),
"TLSv1.2": (3, 3),
}


def _build_tls_unknown_ca_alert(version):
"""
Builds a TLS alert record for an unknown CA.
"""
ver_maj, ver_min = TLS_PROTOCOL_VERSIONS[version]
severity_fatal = 0x02
description_unknown_ca = 0x30
msg = struct.pack(">BB", severity_fatal, description_unknown_ca)
msg_len = len(msg)
record_type_alert = 0x15
record = struct.pack(">BBBH", record_type_alert, ver_maj, ver_min, msg_len) + msg
return record
38 changes: 29 additions & 9 deletions src/urllib3/contrib/securetransport.py
Expand Up @@ -60,6 +60,7 @@
import shutil
import socket
import ssl
import struct
import threading
import weakref

Expand All @@ -69,6 +70,7 @@
from ._securetransport.bindings import CoreFoundation, Security, SecurityConst
from ._securetransport.low_level import (
_assert_no_error,
_build_tls_unknown_ca_alert,
_cert_array_from_pem,
_create_cfstring_array,
_load_client_cert_chain,
Expand Down Expand Up @@ -397,11 +399,37 @@ def _custom_validate(self, verify, trust_bundle):
Called when we have set custom validation. We do this in two cases:
first, when cert validation is entirely disabled; and second, when
using a custom trust DB.
Raises an SSLError if the connection is not trusted.
"""
# If we disabled cert validation, just say: cool.
if not verify:
return

successes = (
SecurityConst.kSecTrustResultUnspecified,
SecurityConst.kSecTrustResultProceed,
)
try:
trust_result = self._evaluate_trust(trust_bundle)
if trust_result in successes:
return
reason = "error code: %d" % (trust_result,)
except Exception as e:
# Do not trust on error
reason = "exception: %r" % (e,)

# SecureTransport does not send an alert nor shuts down the connection.
rec = _build_tls_unknown_ca_alert(self.version())
self.socket.sendall(rec)
# close the connection immediately
# l_onoff = 1, activate linger
# l_linger = 0, linger for 0 seoncds
opts = struct.pack("ii", 1, 0)
self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_LINGER, opts)
self.close()
raise ssl.SSLError("certificate verify failed, %s" % reason)

def _evaluate_trust(self, trust_bundle):
# We want data in memory, so load it up.
if os.path.isfile(trust_bundle):
with open(trust_bundle, "rb") as f:
Expand Down Expand Up @@ -439,15 +467,7 @@ def _custom_validate(self, verify, trust_bundle):
if cert_array is not None:
CoreFoundation.CFRelease(cert_array)

# Ok, now we can look at what the result was.
successes = (
SecurityConst.kSecTrustResultUnspecified,
SecurityConst.kSecTrustResultProceed,
)
if trust_result.value not in successes:
raise ssl.SSLError(
"certificate verify failed, error code: %d" % trust_result.value
)
return trust_result.value

def handshake(
self,
Expand Down
26 changes: 14 additions & 12 deletions test/with_dummyserver/test_https.py
Expand Up @@ -477,29 +477,31 @@ def test_assert_fingerprint_sha256(self):
https_pool.request("GET", "/")

def test_assert_invalid_fingerprint(self):
def _test_request(pool):
with pytest.raises(MaxRetryError) as cm:
pool.request("GET", "/", retries=0)
assert isinstance(cm.value.reason, SSLError)
return cm.value.reason

with HTTPSConnectionPool(
"127.0.0.1", self.port, cert_reqs="CERT_REQUIRED", ca_certs=DEFAULT_CA
self.host, self.port, cert_reqs="CERT_REQUIRED", ca_certs=DEFAULT_CA
) as https_pool:

https_pool.assert_fingerprint = (
"AA:AA:AA:AA:AA:AAAA:AA:AAAA:AA:AA:AA:AA:AA:AA:AA:AA:AA:AA"
)

def _test_request(pool):
with pytest.raises(MaxRetryError) as cm:
pool.request("GET", "/", retries=0)
assert isinstance(cm.value.reason, SSLError)

_test_request(https_pool)
https_pool._get_conn()
e = _test_request(https_pool)
assert "Fingerprints did not match." in str(e)

# Uneven length
https_pool.assert_fingerprint = "AA:A"
_test_request(https_pool)
https_pool._get_conn()
e = _test_request(https_pool)
assert "Fingerprint of invalid length:" in str(e)

# Invalid length
https_pool.assert_fingerprint = "AA"
_test_request(https_pool)
e = _test_request(https_pool)
assert "Fingerprint of invalid length:" in str(e)

def test_verify_none_and_bad_fingerprint(self):
with HTTPSConnectionPool(
Expand Down
43 changes: 43 additions & 0 deletions test/with_dummyserver/test_socketlevel.py
Expand Up @@ -1427,6 +1427,49 @@ def test_load_verify_locations_exception(self):
with pytest.raises(SSLError):
ssl_wrap_socket(None, ca_certs="/tmp/fake-file")

def test_ssl_custom_validation_failure_terminates(self, tmpdir):
"""
Ensure that the underlying socket is terminated if custom validation fails.
"""
server_closed = Event()

def is_closed_socket(sock):
try:
sock.settimeout(SHORT_TIMEOUT) # Python 3
sock.recv(1) # Python 2
except (OSError, socket.error):
return True
return False

def socket_handler(listener):
sock = listener.accept()[0]
try:
_ = ssl.wrap_socket(
sock,
server_side=True,
keyfile=DEFAULT_CERTS["keyfile"],
certfile=DEFAULT_CERTS["certfile"],
ca_certs=DEFAULT_CA,
)
except ssl.SSLError as e:
assert "alert unknown ca" in str(e)
if is_closed_socket(sock):
server_closed.set()

self._start_server(socket_handler)

# client uses a different ca
other_ca = trustme.CA()
other_ca_path = str(tmpdir / "ca.pem")
other_ca.cert_pem.write_to_path(other_ca_path)

with HTTPSConnectionPool(
self.host, self.port, cert_reqs="REQUIRED", ca_certs=other_ca_path
) as pool:
with pytest.raises(SSLError):
pool.request("GET", "/", retries=False, timeout=LONG_TIMEOUT)
assert server_closed.wait(LONG_TIMEOUT), "The socket was not terminated"


class TestErrorWrapping(SocketDummyServerTestCase):
def test_bad_statusline(self):
Expand Down

0 comments on commit 5eb604f

Please sign in to comment.