From 688584d692f3d77f5d6beca6bb97faef15f0977b Mon Sep 17 00:00:00 2001 From: hodbn Date: Thu, 16 Jul 2020 09:05:57 -0700 Subject: [PATCH] Send "http/1.1" ALPN extension during TLS handshake --- dummyserver/handlers.py | 5 +++ dummyserver/server.py | 41 +++++++++++++++++- .../contrib/_securetransport/bindings.py | 7 +++ .../contrib/_securetransport/low_level.py | 43 +++++++++++++++++++ src/urllib3/contrib/pyopenssl.py | 4 ++ src/urllib3/contrib/securetransport.py | 33 ++++++++++++++ src/urllib3/util/__init__.py | 2 + src/urllib3/util/ssl_.py | 7 +++ test/__init__.py | 14 ++++++ test/with_dummyserver/test_https.py | 10 +++++ test/with_dummyserver/test_socketlevel.py | 33 +++++++++++++- 11 files changed, 196 insertions(+), 3 deletions(-) diff --git a/dummyserver/handlers.py b/dummyserver/handlers.py index 696dbab076..f199dee3f9 100644 --- a/dummyserver/handlers.py +++ b/dummyserver/handlers.py @@ -116,6 +116,11 @@ def certificate(self, request): subject = dict((k, v) for (k, v) in [y for z in cert["subject"] for y in z]) return Response(json.dumps(subject)) + def alpn_protocol(self, request): + """Return the selected ALPN protocol.""" + proto = request.connection.stream.socket.selected_alpn_protocol() + return Response(proto.encode("utf8") if proto is not None else u"") + def source_address(self, request): """Return the requester's IP address.""" return Response(request.remote_ip) diff --git a/dummyserver/server.py b/dummyserver/server.py index 68f383534e..f8b4dff5ce 100755 --- a/dummyserver/server.py +++ b/dummyserver/server.py @@ -15,6 +15,7 @@ from datetime import datetime from urllib3.exceptions import HTTPWarning +from urllib3.util import resolve_cert_reqs, resolve_ssl_version, ALPN_PROTOCOLS from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives import serialization @@ -33,6 +34,7 @@ "keyfile": os.path.join(CERTS_PATH, "server.key"), "cert_reqs": ssl.CERT_OPTIONAL, "ca_certs": os.path.join(CERTS_PATH, "cacert.pem"), + "alpn_protocols": ALPN_PROTOCOLS, } DEFAULT_CA = os.path.join(CERTS_PATH, "cacert.pem") DEFAULT_CA_KEY = os.path.join(CERTS_PATH, "cacert.key") @@ -133,6 +135,39 @@ def run(self): self.server = self._start_server() +def ssl_options_to_context( + keyfile=None, + certfile=None, + server_side=None, + cert_reqs=None, + ssl_version=None, + ca_certs=None, + do_handshake_on_connect=None, + suppress_ragged_eofs=None, + ciphers=None, + alpn_protocols=None, +): + """Return an equivalent SSLContext based on ssl.wrap_socket args.""" + ssl_version = resolve_ssl_version(ssl_version) + cert_none = resolve_cert_reqs("CERT_NONE") + if cert_reqs is None: + cert_reqs = cert_none + else: + cert_reqs = resolve_cert_reqs(cert_reqs) + + ctx = ssl.SSLContext(ssl_version) + ctx.load_cert_chain(certfile, keyfile) + ctx.verify_mode = cert_reqs + if ctx.verify_mode != cert_none: + ctx.load_verify_locations(cafile=ca_certs) + if alpn_protocols and hasattr(ctx, "set_alpn_protocols"): + try: + ctx.set_alpn_protocols(alpn_protocols) + except NotImplementedError: + pass + return ctx + + def run_tornado_app(app, io_loop, certs, scheme, host): assert io_loop == tornado.ioloop.IOLoop.current() @@ -141,7 +176,11 @@ def run_tornado_app(app, io_loop, certs, scheme, host): app.last_req = datetime(1970, 1, 1) if scheme == "https": - http_server = tornado.httpserver.HTTPServer(app, ssl_options=certs) + if sys.version_info < (2, 7, 9): + ssl_opts = certs + else: + ssl_opts = ssl_options_to_context(**certs) + http_server = tornado.httpserver.HTTPServer(app, ssl_options=ssl_opts) else: http_server = tornado.httpserver.HTTPServer(app) diff --git a/src/urllib3/contrib/_securetransport/bindings.py b/src/urllib3/contrib/_securetransport/bindings.py index d9b6733318..95c54a627b 100644 --- a/src/urllib3/contrib/_securetransport/bindings.py +++ b/src/urllib3/contrib/_securetransport/bindings.py @@ -276,6 +276,13 @@ Security.SSLSetProtocolVersionMax.argtypes = [SSLContextRef, SSLProtocol] Security.SSLSetProtocolVersionMax.restype = OSStatus + try: + Security.SSLSetALPNProtocols.argtypes = [SSLContextRef, CFArrayRef] + Security.SSLSetALPNProtocols.restype = OSStatus + except AttributeError: + # Supported only in 10.12+ + pass + Security.SecCopyErrorMessageString.argtypes = [OSStatus, c_void_p] Security.SecCopyErrorMessageString.restype = CFStringRef diff --git a/src/urllib3/contrib/_securetransport/low_level.py b/src/urllib3/contrib/_securetransport/low_level.py index e60168cac1..d3222e3954 100644 --- a/src/urllib3/contrib/_securetransport/low_level.py +++ b/src/urllib3/contrib/_securetransport/low_level.py @@ -56,6 +56,49 @@ def _cf_dictionary_from_tuples(tuples): ) +def _cfstr(py_bstr): + """ + Given a Python binary data, create a CFString. + The string must be CFReleased by the caller. + """ + c_str = ctypes.c_char_p(py_bstr) + cf_str = CoreFoundation.CFStringCreateWithCString( + CoreFoundation.kCFAllocatorDefault, c_str, CFConst.kCFStringEncodingUTF8, + ) + return cf_str + + +def _create_cfstring_array(lst): + """ + Given a list of Python binary data, create an associated CFMutableArray. + The array must be CFReleased by the caller. + + Raises an ssl.SSLError on failure. + """ + cf_arr = None + try: + cf_arr = CoreFoundation.CFArrayCreateMutable( + CoreFoundation.kCFAllocatorDefault, + 0, + ctypes.byref(CoreFoundation.kCFTypeArrayCallBacks), + ) + if not cf_arr: + raise MemoryError("Unable to allocate memory!") + for item in lst: + cf_str = _cfstr(item) + if not cf_str: + raise MemoryError("Unable to allocate memory!") + try: + CoreFoundation.CFArrayAppendValue(cf_arr, cf_str) + finally: + CoreFoundation.CFRelease(cf_str) + except BaseException as e: + if cf_arr: + CoreFoundation.CFRelease(cf_arr) + raise ssl.SSLError("Unable to allocate array: %s" % (e,)) + return cf_arr + + def _cf_string_to_unicode(value): """ Creates a Unicode string from a CFString object. Used entirely for error diff --git a/src/urllib3/contrib/pyopenssl.py b/src/urllib3/contrib/pyopenssl.py index 81a80651d4..43ea99677e 100644 --- a/src/urllib3/contrib/pyopenssl.py +++ b/src/urllib3/contrib/pyopenssl.py @@ -465,6 +465,10 @@ def load_cert_chain(self, certfile, keyfile=None, password=None): self._ctx.set_passwd_cb(lambda *_: password) self._ctx.use_privatekey_file(keyfile or certfile) + def set_alpn_protocols(self, protocols): + protocols = [six.ensure_binary(p) for p in protocols] + return self._ctx.set_alpn_protos(protocols) + def wrap_socket( self, sock, diff --git a/src/urllib3/contrib/securetransport.py b/src/urllib3/contrib/securetransport.py index a6b7e94ade..f10dfc3749 100644 --- a/src/urllib3/contrib/securetransport.py +++ b/src/urllib3/contrib/securetransport.py @@ -56,6 +56,7 @@ import errno import os.path import shutil +import six import socket import ssl import threading @@ -68,6 +69,7 @@ _cert_array_from_pem, _temporary_keychain, _load_client_cert_chain, + _create_cfstring_array, ) try: # Platform-specific: Python 2 @@ -374,6 +376,19 @@ def _set_ciphers(self): ) _assert_no_error(result) + def _set_alpn_protocols(self, protocols): + """ + Sets up the ALPN protocols on the context. + """ + if not protocols: + return + protocols_arr = _create_cfstring_array(protocols) + try: + result = Security.SSLSetALPNProtocols(self.context, protocols_arr) + _assert_no_error(result) + finally: + CoreFoundation.CFRelease(protocols_arr) + def _custom_validate(self, verify, trust_bundle): """ Called when we have set custom validation. We do this in two cases: @@ -441,6 +456,7 @@ def handshake( client_cert, client_key, client_key_passphrase, + alpn_protocols, ): """ Actually performs the TLS handshake. This is run automatically by @@ -481,6 +497,9 @@ def handshake( # Setup the ciphers. self._set_ciphers() + # Setup the ALPN protocols. + self._set_alpn_protocols(alpn_protocols) + # Set the minimum and maximum TLS versions. result = Security.SSLSetProtocolVersionMin(self.context, min_version) _assert_no_error(result) @@ -754,6 +773,7 @@ def __init__(self, protocol): self._client_cert = None self._client_key = None self._client_key_passphrase = None + self._alpn_protocols = None @property def check_hostname(self): @@ -831,6 +851,18 @@ def load_cert_chain(self, certfile, keyfile=None, password=None): self._client_key = keyfile self._client_cert_passphrase = password + def set_alpn_protocols(self, protocols): + """ + Sets the ALPN protocols that will later be set on the context. + + Raises a NotImplementedError if ALPN is not supported. + """ + if not hasattr(Security, "SSLSetALPNProtocols"): + raise NotImplementedError( + "SecureTransport supports ALPN only in macOS 10.12+" + ) + self._alpn_protocols = [six.ensure_binary(p) for p in protocols] + def wrap_socket( self, sock, @@ -860,5 +892,6 @@ def wrap_socket( self._client_cert, self._client_key, self._client_key_passphrase, + self._alpn_protocols, ) return wrapped_socket diff --git a/src/urllib3/util/__init__.py b/src/urllib3/util/__init__.py index 3fa98c5355..24c16a2894 100644 --- a/src/urllib3/util/__init__.py +++ b/src/urllib3/util/__init__.py @@ -14,6 +14,7 @@ resolve_ssl_version, ssl_wrap_socket, PROTOCOL_TLS, + ALPN_PROTOCOLS, ) from .timeout import current_time, Timeout @@ -27,6 +28,7 @@ "IS_SECURETRANSPORT", "SSLContext", "PROTOCOL_TLS", + "ALPN_PROTOCOLS", "Retry", "Timeout", "Url", diff --git a/src/urllib3/util/ssl_.py b/src/urllib3/util/ssl_.py index 3d89a56c08..9a8ccdaad2 100644 --- a/src/urllib3/util/ssl_.py +++ b/src/urllib3/util/ssl_.py @@ -17,6 +17,7 @@ HAS_SNI = False IS_PYOPENSSL = False IS_SECURETRANSPORT = False +ALPN_PROTOCOLS = ["http/1.1"] # Maps the length of a digest to a possible hash function producing this digest HASHFUNC_MAP = {32: md5, 40: sha1, 64: sha256} @@ -373,6 +374,12 @@ def ssl_wrap_socket( else: context.load_cert_chain(certfile, keyfile, key_password) + try: + if hasattr(context, "set_alpn_protocols"): + context.set_alpn_protocols(ALPN_PROTOCOLS) + except NotImplementedError: + pass + # If we detect server_hostname is an IP address then the SNI # extension should not be used according to RFC3546 Section 3.1 # We shouldn't warn the user if SNI isn't available but we would diff --git a/test/__init__.py b/test/__init__.py index 01f02738d0..589473211d 100644 --- a/test/__init__.py +++ b/test/__init__.py @@ -17,6 +17,7 @@ from urllib3.exceptions import HTTPWarning from urllib3.packages import six from urllib3.util import ssl_ +from urllib3 import util # We need a host that will not immediately close the connection with a TCP # Reset. @@ -56,6 +57,19 @@ def _can_resolve(host): return False +def has_alpn(ctx_cls=None): + """ Detect if ALPN support is enabled. """ + ctx_cls = ctx_cls or util.SSLContext + ctx = ctx_cls(protocol=ssl_.PROTOCOL_TLS) + try: + if hasattr(ctx, "set_alpn_protocols"): + ctx.set_alpn_protocols(ssl_.ALPN_PROTOCOLS) + return True + except NotImplementedError: + pass + return False + + # Some systems might not resolve "localhost." correctly. # See https://github.com/urllib3/urllib3/issues/1809 and # https://github.com/urllib3/urllib3/pull/1475#issuecomment-440788064. diff --git a/test/with_dummyserver/test_https.py b/test/with_dummyserver/test_https.py index 414aea49f0..95aef3eb9f 100644 --- a/test/with_dummyserver/test_https.py +++ b/test/with_dummyserver/test_https.py @@ -49,6 +49,7 @@ from urllib3.packages import six from urllib3.util.timeout import Timeout import urllib3.util as util +from .. import has_alpn # Retry failed tests pytestmark = pytest.mark.flaky @@ -717,6 +718,15 @@ def test_sslkeylogfile(self, tmpdir, monkeypatch): % str(keylog_file) ) + def test_alpn_default(self): + """Default ALPN protocols are sent by default.""" + if not has_alpn() or not has_alpn(ssl.SSLContext): + pytest.skip("ALPN-support not available") + with HTTPSConnectionPool(self.host, self.port, ca_certs=DEFAULT_CA) as pool: + r = pool.request("GET", "/alpn_protocol", retries=0) + assert r.status == 200 + assert r.data.decode("utf-8") == util.ALPN_PROTOCOLS[0] + @requiresTLSv1() class TestHTTPS_TLSv1(TestHTTPS): diff --git a/test/with_dummyserver/test_socketlevel.py b/test/with_dummyserver/test_socketlevel.py index e9d95ff4e4..9f7821a2a9 100644 --- a/test/with_dummyserver/test_socketlevel.py +++ b/test/with_dummyserver/test_socketlevel.py @@ -26,7 +26,7 @@ encrypt_key_pem, ) -from .. import onlyPy3, LogRecorder +from .. import onlyPy3, LogRecorder, has_alpn try: from mimetools import Message as MimeToolMessage @@ -102,7 +102,7 @@ def socket_handler(listener): sock.close() self._start_server(socket_handler) - with HTTPConnectionPool(self.host, self.port) as pool: + with HTTPSConnectionPool(self.host, self.port) as pool: try: pool.request("GET", "/", retries=0) except MaxRetryError: # We are violating the protocol @@ -114,6 +114,35 @@ def socket_handler(listener): ), "missing hostname in SSL handshake" +class TestALPN(SocketDummyServerTestCase): + def test_alpn_protocol_in_first_request_packet(self): + if not has_alpn(): + pytest.skip("ALPN-support not available") + + done_receiving = Event() + self.buf = b"" + + def socket_handler(listener): + sock = listener.accept()[0] + + self.buf = sock.recv(65536) # We only accept one packet + done_receiving.set() # let the test know it can proceed + sock.close() + + self._start_server(socket_handler) + with HTTPSConnectionPool(self.host, self.port) as pool: + try: + pool.request("GET", "/", retries=0) + except MaxRetryError: # We are violating the protocol + pass + successful = done_receiving.wait(LONG_TIMEOUT) + assert successful, "Timed out waiting for connection accept" + for protocol in util.ALPN_PROTOCOLS: + assert ( + protocol.encode("ascii") in self.buf + ), "missing ALPN protocol in SSL handshake" + + class TestClientCerts(SocketDummyServerTestCase): """ Tests for client certificate support.