Skip to content

Commit

Permalink
Add ALPN support
Browse files Browse the repository at this point in the history
  • Loading branch information
hodbn committed Jun 25, 2020
1 parent 8e93132 commit b6b41bd
Show file tree
Hide file tree
Showing 11 changed files with 157 additions and 2 deletions.
5 changes: 5 additions & 0 deletions dummyserver/handlers.py
Expand Up @@ -116,6 +116,11 @@ def certificate(self, request):
subject = dict((k, v) for (k, v) in [y for z in cert["subject"] for y in z])
return Response(json.dumps(subject))

def alpn_protocol(self, request):
"""Return the selected ALPN protocol."""
proto = request.connection.stream.socket.selected_alpn_protocol()
return Response(proto.encode("utf8") if proto is not None else u"")

def source_address(self, request):
"""Return the requester's IP address."""
return Response(request.remote_ip)
Expand Down
37 changes: 36 additions & 1 deletion dummyserver/server.py
Expand Up @@ -15,6 +15,7 @@
from datetime import datetime

from urllib3.exceptions import HTTPWarning
from urllib3.util import resolve_cert_reqs, resolve_ssl_version

from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
Expand Down Expand Up @@ -133,6 +134,36 @@ def run(self):
self.server = self._start_server()


def ssl_options_to_context(
keyfile=None,
certfile=None,
server_side=None,
cert_reqs=None,
ssl_version=None,
ca_certs=None,
do_handshake_on_connect=None,
suppress_ragged_eofs=None,
ciphers=None,
alpn_protocols=None,
):
"""Return an equivalent SSLContext based on ssl.wrap_socket args."""
ssl_version = resolve_ssl_version(ssl_version)
cert_none = resolve_cert_reqs("CERT_NONE")
if cert_reqs is None:
cert_reqs = cert_none
else:
cert_reqs = resolve_cert_reqs(cert_reqs)

ctx = ssl.SSLContext(ssl_version)
ctx.load_cert_chain(certfile, keyfile)
ctx.verify_mode = cert_reqs
if ctx.verify_mode != cert_none:
ctx.load_verify_locations(cafile=ca_certs)
if alpn_protocols:
ctx.set_alpn_protocols(alpn_protocols)
return ctx


def run_tornado_app(app, io_loop, certs, scheme, host):
assert io_loop == tornado.ioloop.IOLoop.current()

Expand All @@ -141,7 +172,11 @@ def run_tornado_app(app, io_loop, certs, scheme, host):
app.last_req = datetime(1970, 1, 1)

if scheme == "https":
http_server = tornado.httpserver.HTTPServer(app, ssl_options=certs)
if sys.version_info < (2, 7, 9):
ssl_opts = certs
else:
ssl_opts = ssl_options_to_context(**certs)
http_server = tornado.httpserver.HTTPServer(app, ssl_options=ssl_opts)
else:
http_server = tornado.httpserver.HTTPServer(app)

Expand Down
3 changes: 3 additions & 0 deletions src/urllib3/connection.py
Expand Up @@ -283,6 +283,7 @@ def set_cert(
assert_fingerprint=None,
ca_cert_dir=None,
ca_cert_data=None,
alpn_protocols=None,
):
"""
This method should only be called once, before the connection is used.
Expand All @@ -304,6 +305,7 @@ def set_cert(
self.ca_certs = ca_certs and os.path.expanduser(ca_certs)
self.ca_cert_dir = ca_cert_dir and os.path.expanduser(ca_cert_dir)
self.ca_cert_data = ca_cert_data
self.alpn_protocols = alpn_protocols

def connect(self):
# Add certificate verification
Expand Down Expand Up @@ -370,6 +372,7 @@ def connect(self):
ca_cert_data=self.ca_cert_data,
server_hostname=server_hostname,
ssl_context=context,
alpn_protocols=self.alpn_protocols,
)

if self.assert_fingerprint:
Expand Down
5 changes: 4 additions & 1 deletion src/urllib3/connectionpool.py
Expand Up @@ -846,7 +846,7 @@ class HTTPSConnectionPool(HTTPConnectionPool):
If ``assert_hostname`` is False, no verification is done.
The ``key_file``, ``cert_file``, ``cert_reqs``, ``ca_certs``,
``ca_cert_dir``, ``ssl_version``, ``key_password`` are only used if :mod:`ssl`
``ca_cert_dir``, ``ssl_version``, ``key_password``, ``alpn_protocols`` are only used if :mod:`ssl`
is available and are fed into :meth:`urllib3.util.ssl_wrap_socket` to upgrade
the connection socket into an SSL socket.
"""
Expand Down Expand Up @@ -875,6 +875,7 @@ def __init__(
assert_hostname=None,
assert_fingerprint=None,
ca_cert_dir=None,
alpn_protocols=None,
**conn_kw
):

Expand Down Expand Up @@ -902,6 +903,7 @@ def __init__(
self.ssl_version = ssl_version
self.assert_hostname = assert_hostname
self.assert_fingerprint = assert_fingerprint
self.alpn_protocols = alpn_protocols

def _prepare_conn(self, conn):
"""
Expand All @@ -919,6 +921,7 @@ def _prepare_conn(self, conn):
ca_cert_dir=self.ca_cert_dir,
assert_hostname=self.assert_hostname,
assert_fingerprint=self.assert_fingerprint,
alpn_protocols=self.alpn_protocols,
)
conn.ssl_version = self.ssl_version
return conn
Expand Down
12 changes: 12 additions & 0 deletions src/urllib3/contrib/pyopenssl.py
Expand Up @@ -78,6 +78,9 @@ class UnsupportedExtension(Exception):
# SNI always works.
HAS_SNI = True

# ALPN always works.
HAS_ALPN = True

# Map from urllib3 to PyOpenSSL compatible parameter-values.
_openssl_versions = {
util.PROTOCOL_TLS: OpenSSL.SSL.SSLv23_METHOD,
Expand Down Expand Up @@ -106,6 +109,7 @@ class UnsupportedExtension(Exception):
SSL_WRITE_BLOCKSIZE = 16384

orig_util_HAS_SNI = util.HAS_SNI
orig_util_HAS_ALPN = util.HAS_ALPN
orig_util_SSLContext = util.ssl_.SSLContext


Expand All @@ -121,6 +125,8 @@ def inject_into_urllib3():
util.ssl_.SSLContext = PyOpenSSLContext
util.HAS_SNI = HAS_SNI
util.ssl_.HAS_SNI = HAS_SNI
util.HAS_ALPN = HAS_ALPN
util.ssl_.HAS_ALPN = HAS_ALPN
util.IS_PYOPENSSL = True
util.ssl_.IS_PYOPENSSL = True

Expand All @@ -132,6 +138,8 @@ def extract_from_urllib3():
util.ssl_.SSLContext = orig_util_SSLContext
util.HAS_SNI = orig_util_HAS_SNI
util.ssl_.HAS_SNI = orig_util_HAS_SNI
util.HAS_ALPN = orig_util_HAS_ALPN
util.ssl_.HAS_ALPN = orig_util_HAS_ALPN
util.IS_PYOPENSSL = False
util.ssl_.IS_PYOPENSSL = False

Expand Down Expand Up @@ -465,6 +473,10 @@ def load_cert_chain(self, certfile, keyfile=None, password=None):
self._ctx.set_passwd_cb(lambda *_: password)
self._ctx.use_privatekey_file(keyfile or certfile)

def set_alpn_protocols(self, protocols):
protocols = [six.ensure_binary(p) for p in protocols]
return self._ctx.set_alpn_protos(protocols)

def wrap_socket(
self,
sock,
Expand Down
9 changes: 9 additions & 0 deletions src/urllib3/contrib/securetransport.py
Expand Up @@ -81,7 +81,12 @@
# SNI always works
HAS_SNI = True

# TODO: ALPN is currently not implemented.
# See https://developer.apple.com/documentation/security/2976269-sec_protocol_options_add_tls_app
HAS_ALPN = False

orig_util_HAS_SNI = util.HAS_SNI
orig_util_HAS_ALPN = util.HAS_ALPN
orig_util_SSLContext = util.ssl_.SSLContext

# This dictionary is used by the read callback to obtain a handle to the
Expand Down Expand Up @@ -185,6 +190,8 @@ def inject_into_urllib3():
util.ssl_.SSLContext = SecureTransportContext
util.HAS_SNI = HAS_SNI
util.ssl_.HAS_SNI = HAS_SNI
util.HAS_ALPN = HAS_ALPN
util.ssl_.HAS_ALPN = HAS_ALPN
util.IS_SECURETRANSPORT = True
util.ssl_.IS_SECURETRANSPORT = True

Expand All @@ -197,6 +204,8 @@ def extract_from_urllib3():
util.ssl_.SSLContext = orig_util_SSLContext
util.HAS_SNI = orig_util_HAS_SNI
util.ssl_.HAS_SNI = orig_util_HAS_SNI
util.HAS_ALPN = orig_util_HAS_ALPN
util.ssl_.HAS_ALPN = orig_util_HAS_ALPN
util.IS_SECURETRANSPORT = False
util.ssl_.IS_SECURETRANSPORT = False

Expand Down
2 changes: 2 additions & 0 deletions src/urllib3/poolmanager.py
Expand Up @@ -42,6 +42,7 @@ class InvalidProxyConfigurationWarning(HTTPWarning):
"ca_cert_dir",
"ssl_context",
"key_password",
"alpn_protocols",
)

# All known keyword arguments that could be provided to the pool manager, its
Expand Down Expand Up @@ -72,6 +73,7 @@ class InvalidProxyConfigurationWarning(HTTPWarning):
"key_assert_hostname", # bool or string
"key_assert_fingerprint", # str
"key_server_hostname", # str
"key_alpn_protocols", # list of str
)

#: The namedtuple class used to construct keys for the connection pool.
Expand Down
6 changes: 6 additions & 0 deletions src/urllib3/util/__init__.py
Expand Up @@ -7,13 +7,16 @@
from .ssl_ import (
SSLContext,
HAS_SNI,
HAS_ALPN,
IS_PYOPENSSL,
IS_SECURETRANSPORT,
assert_fingerprint,
resolve_cert_reqs,
resolve_ssl_version,
ssl_wrap_socket,
PROTOCOL_TLS,
DEFAULT_ALPN_PROTOCOLS,
SUPPRESS_ALPN,
)
from .timeout import current_time, Timeout

Expand All @@ -23,10 +26,13 @@

__all__ = (
"HAS_SNI",
"HAS_ALPN",
"IS_PYOPENSSL",
"IS_SECURETRANSPORT",
"SSLContext",
"PROTOCOL_TLS",
"DEFAULT_ALPN_PROTOCOLS",
"SUPPRESS_ALPN",
"Retry",
"Timeout",
"Url",
Expand Down
13 changes: 13 additions & 0 deletions src/urllib3/util/ssl_.py
Expand Up @@ -15,8 +15,12 @@

SSLContext = None
HAS_SNI = False
HAS_ALPN = False
IS_PYOPENSSL = False
IS_SECURETRANSPORT = False
DEFAULT_ALPN_PROTOCOLS = ["http/1.1"]
#: A sentinel object to suppress the default ALPN protcols
SUPPRESS_ALPN = object()

# Maps the length of a digest to a possible hash function producing this digest
HASHFUNC_MAP = {32: md5, 40: sha1, 64: sha256}
Expand All @@ -41,6 +45,7 @@ def _const_compare_digest_backport(a, b):
import ssl
from ssl import wrap_socket, CERT_REQUIRED
from ssl import HAS_SNI # Has SNI?
from ssl import HAS_ALPN # Has ALPN?
except ImportError:
pass

Expand Down Expand Up @@ -316,6 +321,7 @@ def ssl_wrap_socket(
ca_cert_dir=None,
key_password=None,
ca_cert_data=None,
alpn_protocols=None,
):
"""
All arguments except for server_hostname, ssl_context, and ca_cert_dir have
Expand All @@ -337,6 +343,8 @@ def ssl_wrap_socket(
:param ca_cert_data:
Optional string containing CA certificates in PEM format suitable for
passing as the cadata parameter to SSLContext.load_verify_locations()
:param alpn_protocols:
When ALPN is supported, the ALPN protocols to negotiate. :data:`SUPPRESS_ALPN` will suppress sending :data:`DEFAULT_ALPN_PROTOCOLS`.
"""
context = ssl_context
if context is None:
Expand Down Expand Up @@ -373,6 +381,11 @@ def ssl_wrap_socket(
else:
context.load_cert_chain(certfile, keyfile, key_password)

if HAS_ALPN and alpn_protocols is not SUPPRESS_ALPN:
if alpn_protocols is None:
alpn_protocols = DEFAULT_ALPN_PROTOCOLS
context.set_alpn_protocols(alpn_protocols)

# If we detect server_hostname is an IP address then the SNI
# extension should not be used according to RFC3546 Section 3.1
# We shouldn't warn the user if SNI isn't available but we would
Expand Down
41 changes: 41 additions & 0 deletions test/with_dummyserver/test_https.py
Expand Up @@ -805,3 +805,44 @@ def test_can_validate_ipv6_san(self, ipv6_san_server):
) as https_pool:
r = https_pool.request("GET", "/")
assert r.status == 200


class TestHTTPS_ALPN(TestHTTPS):
servers_last = "secondproto"
alpn_protos = util.DEFAULT_ALPN_PROTOCOLS + [servers_last]
servers_first = alpn_protos[0]
certs = dict(DEFAULT_CERTS, alpn_protocols=alpn_protos)

def _get_pool(self, **kwargs):
return HTTPSConnectionPool(self.host, self.port, ca_certs=DEFAULT_CA, **kwargs)

def test_alpn_custom(self):
"""Setting custom ALPN protocols chooses the right protocol."""
# choose the right protocol (server's first, client's last)
with self._get_pool(alpn_protocols=["fakeproto", self.servers_first]) as pool:
r = pool.request("GET", "/alpn_protocol")
assert r.status == 200
assert r.data.decode("utf-8") == self.servers_first
# choose the right protocol (client's first, server's last)
with self._get_pool(alpn_protocols=[self.servers_last, "fakeproto"]) as pool:
r = pool.request("GET", "/alpn_protocol")
assert r.status == 200
assert r.data.decode("utf-8") == self.servers_last
# don't choose a protocol
with self._get_pool(alpn_protocols=["fakeproto"]) as pool:
r = pool.request("GET", "/alpn_protocol", retries=0)
assert r.status == 200
assert r.data.decode("utf-8") == ""

def test_alpn_default(self):
"""Default ALPN protocols are sent by default, but can be suppressed."""
# sends default alpn protocols
with self._get_pool() as pool:
r = pool.request("GET", "/alpn_protocol", retries=0)
assert r.status == 200
assert r.data.decode("utf-8") == util.DEFAULT_ALPN_PROTOCOLS[0]
# can suppress default alpn protocols
with self._get_pool(alpn_protocols=util.SUPPRESS_ALPN) as pool:
r = pool.request("GET", "/alpn_protocol")
assert r.status == 200
assert r.data.decode("utf-8") == ""
26 changes: 26 additions & 0 deletions test/with_dummyserver/test_socketlevel.py
Expand Up @@ -112,6 +112,32 @@ def socket_handler(listener):
self.host.encode("ascii") in self.buf
), "missing hostname in SSL handshake"

def test_alpn_protocol_in_first_request_packet(self):
if not util.HAS_ALPN:
pytest.skip("ALPN-support not available")
done_receiving = Event()
self.buf = b""

def socket_handler(listener):
sock = listener.accept()[0]

self.buf = sock.recv(65536) # We only accept one packet
done_receiving.set() # let the test know it can proceed
sock.close()

self._start_server(socket_handler)
with HTTPSConnectionPool(self.host, self.port) as pool:
try:
pool.request("GET", "/", retries=0)
except MaxRetryError: # We are violating the protocol
pass
successful = done_receiving.wait(LONG_TIMEOUT)
assert successful, "Timed out waiting for connection accept"
for protocol in util.DEFAULT_ALPN_PROTOCOLS:
assert (
protocol.encode("ascii") in self.buf
), "missing ALPN protocol in SSL handshake"


class TestClientCerts(SocketDummyServerTestCase):
"""
Expand Down

0 comments on commit b6b41bd

Please sign in to comment.