Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Terminate connection when custom verification fails (SecureTransport) #1977

Merged
merged 11 commits into from Nov 1, 2020
24 changes: 24 additions & 0 deletions src/urllib3/contrib/_securetransport/low_level.py
Expand Up @@ -13,6 +13,7 @@
import os
import re
import ssl
import struct
import tempfile

from .bindings import CFConst, CoreFoundation, Security
Expand Down Expand Up @@ -370,3 +371,26 @@ def _load_client_cert_chain(keychain, *paths):
finally:
for obj in itertools.chain(identities, certificates):
CoreFoundation.CFRelease(obj)


TLS_PROTOCOL_VERSIONS = {
"SSLv2": (0, 2),
"SSLv3": (3, 0),
"TLSv1": (3, 1),
"TLSv1.1": (3, 2),
"TLSv1.2": (3, 3),
}


def _build_tls_unknown_ca_alert(version):
"""
Builds a TLS alert record for an unknown CA.
"""
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there some useful documentation or even StackOverflow link for this?

ver_maj, ver_min = TLS_PROTOCOL_VERSIONS[version]
severity_fatal = 0x02
description_unknown_ca = 0x30
msg = struct.pack(">BB", severity_fatal, description_unknown_ca)
msg_len = len(msg)
record_type_alert = 0x15
record = struct.pack(">BBBH", record_type_alert, ver_maj, ver_min, msg_len) + msg
return record
38 changes: 29 additions & 9 deletions src/urllib3/contrib/securetransport.py
Expand Up @@ -60,6 +60,7 @@
import shutil
import socket
import ssl
import struct
import threading
import weakref

Expand All @@ -69,6 +70,7 @@
from ._securetransport.bindings import CoreFoundation, Security, SecurityConst
from ._securetransport.low_level import (
_assert_no_error,
_build_tls_unknown_ca_alert,
_cert_array_from_pem,
_create_cfstring_array,
_load_client_cert_chain,
Expand Down Expand Up @@ -397,11 +399,37 @@ def _custom_validate(self, verify, trust_bundle):
Called when we have set custom validation. We do this in two cases:
first, when cert validation is entirely disabled; and second, when
using a custom trust DB.
Raises an SSLError if the connection is not trusted.
"""
# If we disabled cert validation, just say: cool.
if not verify:
return

successes = (
SecurityConst.kSecTrustResultUnspecified,
SecurityConst.kSecTrustResultProceed,
)
try:
trust_result = self._evaluate_trust(trust_bundle)
if trust_result in successes:
return
reason = "error code: %d" % (trust_result,)
except Exception as e:
# Do not trust on error
reason = "exception: %r" % (e,)

# SecureTransport does not send an alert nor shuts down the connection.
rec = _build_tls_unknown_ca_alert(self.version())
self.socket.sendall(rec)
# close the connection immediately
# l_onoff = 1, activate linger
# l_linger = 0, linger for 0 seoncds
opts = struct.pack("ii", 1, 0)
self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_LINGER, opts)
hodbn marked this conversation as resolved.
Show resolved Hide resolved
self.close()
raise ssl.SSLError("certificate verify failed, %s" % reason)

def _evaluate_trust(self, trust_bundle):
# We want data in memory, so load it up.
if os.path.isfile(trust_bundle):
with open(trust_bundle, "rb") as f:
Expand Down Expand Up @@ -439,15 +467,7 @@ def _custom_validate(self, verify, trust_bundle):
if cert_array is not None:
CoreFoundation.CFRelease(cert_array)

# Ok, now we can look at what the result was.
successes = (
SecurityConst.kSecTrustResultUnspecified,
SecurityConst.kSecTrustResultProceed,
)
if trust_result.value not in successes:
raise ssl.SSLError(
"certificate verify failed, error code: %d" % trust_result.value
)
return trust_result.value

def handshake(
self,
Expand Down
26 changes: 14 additions & 12 deletions test/with_dummyserver/test_https.py
Expand Up @@ -477,29 +477,31 @@ def test_assert_fingerprint_sha256(self):
https_pool.request("GET", "/")

def test_assert_invalid_fingerprint(self):
def _test_request(pool):
with pytest.raises(MaxRetryError) as cm:
pool.request("GET", "/", retries=0)
assert isinstance(cm.value.reason, SSLError)
return cm.value.reason

with HTTPSConnectionPool(
"127.0.0.1", self.port, cert_reqs="CERT_REQUIRED", ca_certs=DEFAULT_CA
self.host, self.port, cert_reqs="CERT_REQUIRED", ca_certs=DEFAULT_CA
) as https_pool:

https_pool.assert_fingerprint = (
"AA:AA:AA:AA:AA:AAAA:AA:AAAA:AA:AA:AA:AA:AA:AA:AA:AA:AA:AA"
)

def _test_request(pool):
with pytest.raises(MaxRetryError) as cm:
pool.request("GET", "/", retries=0)
assert isinstance(cm.value.reason, SSLError)

_test_request(https_pool)
https_pool._get_conn()
e = _test_request(https_pool)
assert "Fingerprints did not match." in str(e)

# Uneven length
https_pool.assert_fingerprint = "AA:A"
_test_request(https_pool)
https_pool._get_conn()
e = _test_request(https_pool)
assert "Fingerprint of invalid length:" in str(e)

# Invalid length
https_pool.assert_fingerprint = "AA"
_test_request(https_pool)
e = _test_request(https_pool)
assert "Fingerprint of invalid length:" in str(e)

def test_verify_none_and_bad_fingerprint(self):
with HTTPSConnectionPool(
Expand Down
43 changes: 43 additions & 0 deletions test/with_dummyserver/test_socketlevel.py
Expand Up @@ -1427,6 +1427,49 @@ def test_load_verify_locations_exception(self):
with pytest.raises(SSLError):
ssl_wrap_socket(None, ca_certs="/tmp/fake-file")

def test_ssl_custom_validation_failure_terminates(self, tmpdir):
"""
Ensure that the underlying socket is terminated if custom validation fails.
"""
server_closed = Event()

def is_closed_socket(sock):
try:
sock.settimeout(SHORT_TIMEOUT) # Python 3
sock.recv(1) # Python 2
except (OSError, socket.error):
return True
return False

def socket_handler(listener):
sock = listener.accept()[0]
try:
_ = ssl.wrap_socket(
sock,
server_side=True,
keyfile=DEFAULT_CERTS["keyfile"],
certfile=DEFAULT_CERTS["certfile"],
ca_certs=DEFAULT_CA,
)
except ssl.SSLError as e:
assert "alert unknown ca" in str(e)
if is_closed_socket(sock):
server_closed.set()

self._start_server(socket_handler)

# client uses a different ca
other_ca = trustme.CA()
other_ca_path = str(tmpdir / "ca.pem")
other_ca.cert_pem.write_to_path(other_ca_path)

with HTTPSConnectionPool(
self.host, self.port, cert_reqs="REQUIRED", ca_certs=other_ca_path
) as pool:
with pytest.raises(SSLError):
pool.request("GET", "/", retries=False, timeout=LONG_TIMEOUT)
assert server_closed.wait(LONG_TIMEOUT), "The socket was not terminated"


class TestErrorWrapping(SocketDummyServerTestCase):
def test_bad_statusline(self):
Expand Down