diff --git a/dummyserver/handlers.py b/dummyserver/handlers.py index 696dbab076..f199dee3f9 100644 --- a/dummyserver/handlers.py +++ b/dummyserver/handlers.py @@ -116,6 +116,11 @@ def certificate(self, request): subject = dict((k, v) for (k, v) in [y for z in cert["subject"] for y in z]) return Response(json.dumps(subject)) + def alpn_protocol(self, request): + """Return the selected ALPN protocol.""" + proto = request.connection.stream.socket.selected_alpn_protocol() + return Response(proto.encode("utf8") if proto is not None else u"") + def source_address(self, request): """Return the requester's IP address.""" return Response(request.remote_ip) diff --git a/dummyserver/server.py b/dummyserver/server.py index 68f383534e..7564120318 100755 --- a/dummyserver/server.py +++ b/dummyserver/server.py @@ -15,6 +15,7 @@ from datetime import datetime from urllib3.exceptions import HTTPWarning +from urllib3.util import resolve_cert_reqs, resolve_ssl_version from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives import serialization @@ -133,6 +134,36 @@ def run(self): self.server = self._start_server() +def ssl_options_to_context( + keyfile=None, + certfile=None, + server_side=None, + cert_reqs=None, + ssl_version=None, + ca_certs=None, + do_handshake_on_connect=None, + suppress_ragged_eofs=None, + ciphers=None, + alpn_protocols=None, +): + """Return an equivalent SSLContext based on ssl.wrap_socket args.""" + ssl_version = resolve_ssl_version(ssl_version) + cert_none = resolve_cert_reqs("CERT_NONE") + if cert_reqs is None: + cert_reqs = cert_none + else: + cert_reqs = resolve_cert_reqs(cert_reqs) + + ctx = ssl.SSLContext(ssl_version) + ctx.load_cert_chain(certfile, keyfile) + ctx.verify_mode = cert_reqs + if ctx.verify_mode != cert_none: + ctx.load_verify_locations(cafile=ca_certs) + if alpn_protocols: + ctx.set_alpn_protocols(alpn_protocols) + return ctx + + def run_tornado_app(app, io_loop, certs, scheme, host): assert io_loop == tornado.ioloop.IOLoop.current() @@ -141,7 +172,11 @@ def run_tornado_app(app, io_loop, certs, scheme, host): app.last_req = datetime(1970, 1, 1) if scheme == "https": - http_server = tornado.httpserver.HTTPServer(app, ssl_options=certs) + if sys.version_info < (2, 7, 9): + ssl_opts = certs + else: + ssl_opts = ssl_options_to_context(**certs) + http_server = tornado.httpserver.HTTPServer(app, ssl_options=ssl_opts) else: http_server = tornado.httpserver.HTTPServer(app) diff --git a/src/urllib3/connection.py b/src/urllib3/connection.py index ce94b25640..337d81b1c0 100644 --- a/src/urllib3/connection.py +++ b/src/urllib3/connection.py @@ -283,6 +283,7 @@ def set_cert( assert_fingerprint=None, ca_cert_dir=None, ca_cert_data=None, + alpn_protocols=None, ): """ This method should only be called once, before the connection is used. @@ -304,6 +305,7 @@ def set_cert( self.ca_certs = ca_certs and os.path.expanduser(ca_certs) self.ca_cert_dir = ca_cert_dir and os.path.expanduser(ca_cert_dir) self.ca_cert_data = ca_cert_data + self.alpn_protocols = alpn_protocols def connect(self): # Add certificate verification @@ -370,6 +372,7 @@ def connect(self): ca_cert_data=self.ca_cert_data, server_hostname=server_hostname, ssl_context=context, + alpn_protocols=self.alpn_protocols, ) if self.assert_fingerprint: diff --git a/src/urllib3/connectionpool.py b/src/urllib3/connectionpool.py index 492590fb9e..05e80357cb 100644 --- a/src/urllib3/connectionpool.py +++ b/src/urllib3/connectionpool.py @@ -846,7 +846,7 @@ class HTTPSConnectionPool(HTTPConnectionPool): If ``assert_hostname`` is False, no verification is done. The ``key_file``, ``cert_file``, ``cert_reqs``, ``ca_certs``, - ``ca_cert_dir``, ``ssl_version``, ``key_password`` are only used if :mod:`ssl` + ``ca_cert_dir``, ``ssl_version``, ``key_password``, ``alpn_protocols`` are only used if :mod:`ssl` is available and are fed into :meth:`urllib3.util.ssl_wrap_socket` to upgrade the connection socket into an SSL socket. """ @@ -875,6 +875,7 @@ def __init__( assert_hostname=None, assert_fingerprint=None, ca_cert_dir=None, + alpn_protocols=None, **conn_kw ): @@ -902,6 +903,7 @@ def __init__( self.ssl_version = ssl_version self.assert_hostname = assert_hostname self.assert_fingerprint = assert_fingerprint + self.alpn_protocols = alpn_protocols def _prepare_conn(self, conn): """ @@ -919,6 +921,7 @@ def _prepare_conn(self, conn): ca_cert_dir=self.ca_cert_dir, assert_hostname=self.assert_hostname, assert_fingerprint=self.assert_fingerprint, + alpn_protocols=self.alpn_protocols, ) conn.ssl_version = self.ssl_version return conn diff --git a/src/urllib3/contrib/pyopenssl.py b/src/urllib3/contrib/pyopenssl.py index 81a80651d4..0c480b6b87 100644 --- a/src/urllib3/contrib/pyopenssl.py +++ b/src/urllib3/contrib/pyopenssl.py @@ -78,6 +78,8 @@ class UnsupportedExtension(Exception): # SNI always works. HAS_SNI = True +HAS_ALPN = hasattr(OpenSSL.SSL.Context, "set_alpn_protos") + # Map from urllib3 to PyOpenSSL compatible parameter-values. _openssl_versions = { util.PROTOCOL_TLS: OpenSSL.SSL.SSLv23_METHOD, @@ -106,6 +108,7 @@ class UnsupportedExtension(Exception): SSL_WRITE_BLOCKSIZE = 16384 orig_util_HAS_SNI = util.HAS_SNI +orig_util_HAS_ALPN = util.HAS_ALPN orig_util_SSLContext = util.ssl_.SSLContext @@ -121,6 +124,8 @@ def inject_into_urllib3(): util.ssl_.SSLContext = PyOpenSSLContext util.HAS_SNI = HAS_SNI util.ssl_.HAS_SNI = HAS_SNI + util.HAS_ALPN = HAS_ALPN + util.ssl_.HAS_ALPN = HAS_ALPN util.IS_PYOPENSSL = True util.ssl_.IS_PYOPENSSL = True @@ -132,6 +137,8 @@ def extract_from_urllib3(): util.ssl_.SSLContext = orig_util_SSLContext util.HAS_SNI = orig_util_HAS_SNI util.ssl_.HAS_SNI = orig_util_HAS_SNI + util.HAS_ALPN = orig_util_HAS_ALPN + util.ssl_.HAS_ALPN = orig_util_HAS_ALPN util.IS_PYOPENSSL = False util.ssl_.IS_PYOPENSSL = False @@ -465,6 +472,10 @@ def load_cert_chain(self, certfile, keyfile=None, password=None): self._ctx.set_passwd_cb(lambda *_: password) self._ctx.use_privatekey_file(keyfile or certfile) + def set_alpn_protocols(self, protocols): + protocols = [six.ensure_binary(p) for p in protocols] + return self._ctx.set_alpn_protos(protocols) + def wrap_socket( self, sock, diff --git a/src/urllib3/contrib/securetransport.py b/src/urllib3/contrib/securetransport.py index a6b7e94ade..5bb12c7bba 100644 --- a/src/urllib3/contrib/securetransport.py +++ b/src/urllib3/contrib/securetransport.py @@ -81,7 +81,12 @@ # SNI always works HAS_SNI = True +# TODO: ALPN is currently not implemented. +# See https://developer.apple.com/documentation/security/2976269-sec_protocol_options_add_tls_app +HAS_ALPN = False + orig_util_HAS_SNI = util.HAS_SNI +orig_util_HAS_ALPN = util.HAS_ALPN orig_util_SSLContext = util.ssl_.SSLContext # This dictionary is used by the read callback to obtain a handle to the @@ -185,6 +190,8 @@ def inject_into_urllib3(): util.ssl_.SSLContext = SecureTransportContext util.HAS_SNI = HAS_SNI util.ssl_.HAS_SNI = HAS_SNI + util.HAS_ALPN = HAS_ALPN + util.ssl_.HAS_ALPN = HAS_ALPN util.IS_SECURETRANSPORT = True util.ssl_.IS_SECURETRANSPORT = True @@ -197,6 +204,8 @@ def extract_from_urllib3(): util.ssl_.SSLContext = orig_util_SSLContext util.HAS_SNI = orig_util_HAS_SNI util.ssl_.HAS_SNI = orig_util_HAS_SNI + util.HAS_ALPN = orig_util_HAS_ALPN + util.ssl_.HAS_ALPN = orig_util_HAS_ALPN util.IS_SECURETRANSPORT = False util.ssl_.IS_SECURETRANSPORT = False diff --git a/src/urllib3/poolmanager.py b/src/urllib3/poolmanager.py index a0e5b974b9..fc21a2655d 100644 --- a/src/urllib3/poolmanager.py +++ b/src/urllib3/poolmanager.py @@ -42,6 +42,7 @@ class InvalidProxyConfigurationWarning(HTTPWarning): "ca_cert_dir", "ssl_context", "key_password", + "alpn_protocols", ) # All known keyword arguments that could be provided to the pool manager, its @@ -72,6 +73,7 @@ class InvalidProxyConfigurationWarning(HTTPWarning): "key_assert_hostname", # bool or string "key_assert_fingerprint", # str "key_server_hostname", # str + "key_alpn_protocols", # list of str ) #: The namedtuple class used to construct keys for the connection pool. diff --git a/src/urllib3/util/__init__.py b/src/urllib3/util/__init__.py index a96c73a9d8..6d197076cb 100644 --- a/src/urllib3/util/__init__.py +++ b/src/urllib3/util/__init__.py @@ -7,6 +7,7 @@ from .ssl_ import ( SSLContext, HAS_SNI, + HAS_ALPN, IS_PYOPENSSL, IS_SECURETRANSPORT, assert_fingerprint, @@ -14,6 +15,8 @@ resolve_ssl_version, ssl_wrap_socket, PROTOCOL_TLS, + DEFAULT_ALPN_PROTOCOLS, + SUPPRESS_ALPN, ) from .timeout import current_time, Timeout @@ -23,10 +26,13 @@ __all__ = ( "HAS_SNI", + "HAS_ALPN", "IS_PYOPENSSL", "IS_SECURETRANSPORT", "SSLContext", "PROTOCOL_TLS", + "DEFAULT_ALPN_PROTOCOLS", + "SUPPRESS_ALPN", "Retry", "Timeout", "Url", diff --git a/src/urllib3/util/ssl_.py b/src/urllib3/util/ssl_.py index 3d89a56c08..0dde5c1af5 100644 --- a/src/urllib3/util/ssl_.py +++ b/src/urllib3/util/ssl_.py @@ -15,8 +15,12 @@ SSLContext = None HAS_SNI = False +HAS_ALPN = False IS_PYOPENSSL = False IS_SECURETRANSPORT = False +DEFAULT_ALPN_PROTOCOLS = ["http/1.1"] +#: A sentinel object to suppress the default ALPN protcols +SUPPRESS_ALPN = object() # Maps the length of a digest to a possible hash function producing this digest HASHFUNC_MAP = {32: md5, 40: sha1, 64: sha256} @@ -41,6 +45,7 @@ def _const_compare_digest_backport(a, b): import ssl from ssl import wrap_socket, CERT_REQUIRED from ssl import HAS_SNI # Has SNI? + from ssl import HAS_ALPN # Has ALPN? except ImportError: pass @@ -316,6 +321,7 @@ def ssl_wrap_socket( ca_cert_dir=None, key_password=None, ca_cert_data=None, + alpn_protocols=None, ): """ All arguments except for server_hostname, ssl_context, and ca_cert_dir have @@ -337,6 +343,8 @@ def ssl_wrap_socket( :param ca_cert_data: Optional string containing CA certificates in PEM format suitable for passing as the cadata parameter to SSLContext.load_verify_locations() + :param alpn_protocols: + When ALPN is supported, the ALPN protocols to negotiate. :data:`SUPPRESS_ALPN` will suppress sending :data:`DEFAULT_ALPN_PROTOCOLS`. """ context = ssl_context if context is None: @@ -373,6 +381,11 @@ def ssl_wrap_socket( else: context.load_cert_chain(certfile, keyfile, key_password) + if HAS_ALPN and alpn_protocols is not SUPPRESS_ALPN: + if alpn_protocols is None: + alpn_protocols = DEFAULT_ALPN_PROTOCOLS + context.set_alpn_protocols(alpn_protocols) + # If we detect server_hostname is an IP address then the SNI # extension should not be used according to RFC3546 Section 3.1 # We shouldn't warn the user if SNI isn't available but we would diff --git a/test/with_dummyserver/test_https.py b/test/with_dummyserver/test_https.py index 414aea49f0..3c4e5f493b 100644 --- a/test/with_dummyserver/test_https.py +++ b/test/with_dummyserver/test_https.py @@ -805,3 +805,44 @@ def test_can_validate_ipv6_san(self, ipv6_san_server): ) as https_pool: r = https_pool.request("GET", "/") assert r.status == 200 + + +class TestHTTPS_ALPN(TestHTTPS): + servers_last = "secondproto" + alpn_protos = util.DEFAULT_ALPN_PROTOCOLS + [servers_last] + servers_first = alpn_protos[0] + certs = dict(DEFAULT_CERTS, alpn_protocols=alpn_protos) + + def _get_pool(self, **kwargs): + return HTTPSConnectionPool(self.host, self.port, ca_certs=DEFAULT_CA, **kwargs) + + def test_alpn_custom(self): + """Setting custom ALPN protocols chooses the right protocol.""" + # choose the right protocol (server's first, client's last) + with self._get_pool(alpn_protocols=["fakeproto", self.servers_first]) as pool: + r = pool.request("GET", "/alpn_protocol") + assert r.status == 200 + assert r.data.decode("utf-8") == self.servers_first + # choose the right protocol (client's first, server's last) + with self._get_pool(alpn_protocols=[self.servers_last, "fakeproto"]) as pool: + r = pool.request("GET", "/alpn_protocol") + assert r.status == 200 + assert r.data.decode("utf-8") == self.servers_last + # don't choose a protocol + with self._get_pool(alpn_protocols=["fakeproto"]) as pool: + r = pool.request("GET", "/alpn_protocol", retries=0) + assert r.status == 200 + assert r.data.decode("utf-8") == "" + + def test_alpn_default(self): + """Default ALPN protocols are sent by default, but can be suppressed.""" + # sends default alpn protocols + with self._get_pool() as pool: + r = pool.request("GET", "/alpn_protocol", retries=0) + assert r.status == 200 + assert r.data.decode("utf-8") == util.DEFAULT_ALPN_PROTOCOLS[0] + # can suppress default alpn protocols + with self._get_pool(alpn_protocols=util.SUPPRESS_ALPN) as pool: + r = pool.request("GET", "/alpn_protocol") + assert r.status == 200 + assert r.data.decode("utf-8") == "" diff --git a/test/with_dummyserver/test_socketlevel.py b/test/with_dummyserver/test_socketlevel.py index 76ce241832..961a904119 100644 --- a/test/with_dummyserver/test_socketlevel.py +++ b/test/with_dummyserver/test_socketlevel.py @@ -112,6 +112,32 @@ def socket_handler(listener): self.host.encode("ascii") in self.buf ), "missing hostname in SSL handshake" + def test_alpn_protocol_in_first_request_packet(self): + if not util.HAS_ALPN: + pytest.skip("ALPN-support not available") + done_receiving = Event() + self.buf = b"" + + def socket_handler(listener): + sock = listener.accept()[0] + + self.buf = sock.recv(65536) # We only accept one packet + done_receiving.set() # let the test know it can proceed + sock.close() + + self._start_server(socket_handler) + with HTTPSConnectionPool(self.host, self.port) as pool: + try: + pool.request("GET", "/", retries=0) + except MaxRetryError: # We are violating the protocol + pass + successful = done_receiving.wait(LONG_TIMEOUT) + assert successful, "Timed out waiting for connection accept" + for protocol in util.DEFAULT_ALPN_PROTOCOLS: + assert ( + protocol.encode("ascii") in self.buf + ), "missing ALPN protocol in SSL handshake" + class TestClientCerts(SocketDummyServerTestCase): """