Skip to content

Commit

Permalink
Simplify ALPN support
Browse files Browse the repository at this point in the history
  • Loading branch information
hodbn committed Jul 5, 2020
1 parent c7c4c5c commit 9457440
Show file tree
Hide file tree
Showing 10 changed files with 45 additions and 89 deletions.
10 changes: 7 additions & 3 deletions dummyserver/server.py
Expand Up @@ -15,7 +15,7 @@
from datetime import datetime

from urllib3.exceptions import HTTPWarning
from urllib3.util import resolve_cert_reqs, resolve_ssl_version
from urllib3.util import resolve_cert_reqs, resolve_ssl_version, ALPN_PROTOCOLS

from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
Expand All @@ -34,6 +34,7 @@
"keyfile": os.path.join(CERTS_PATH, "server.key"),
"cert_reqs": ssl.CERT_OPTIONAL,
"ca_certs": os.path.join(CERTS_PATH, "cacert.pem"),
"alpn_protocols": ALPN_PROTOCOLS,
}
DEFAULT_CA = os.path.join(CERTS_PATH, "cacert.pem")
DEFAULT_CA_KEY = os.path.join(CERTS_PATH, "cacert.key")
Expand Down Expand Up @@ -159,8 +160,11 @@ def ssl_options_to_context(
ctx.verify_mode = cert_reqs
if ctx.verify_mode != cert_none:
ctx.load_verify_locations(cafile=ca_certs)
if alpn_protocols:
ctx.set_alpn_protocols(alpn_protocols)
if alpn_protocols and hasattr(ctx, "set_alpn_protocols"):
try:
ctx.set_alpn_protocols(alpn_protocols)
except NotImplementedError:
pass
return ctx


Expand Down
3 changes: 0 additions & 3 deletions src/urllib3/connection.py
Expand Up @@ -283,7 +283,6 @@ def set_cert(
assert_fingerprint=None,
ca_cert_dir=None,
ca_cert_data=None,
alpn_protocols=None,
):
"""
This method should only be called once, before the connection is used.
Expand All @@ -305,7 +304,6 @@ def set_cert(
self.ca_certs = ca_certs and os.path.expanduser(ca_certs)
self.ca_cert_dir = ca_cert_dir and os.path.expanduser(ca_cert_dir)
self.ca_cert_data = ca_cert_data
self.alpn_protocols = alpn_protocols

def connect(self):
# Add certificate verification
Expand Down Expand Up @@ -372,7 +370,6 @@ def connect(self):
ca_cert_data=self.ca_cert_data,
server_hostname=server_hostname,
ssl_context=context,
alpn_protocols=self.alpn_protocols,
)

if self.assert_fingerprint:
Expand Down
5 changes: 1 addition & 4 deletions src/urllib3/connectionpool.py
Expand Up @@ -846,7 +846,7 @@ class HTTPSConnectionPool(HTTPConnectionPool):
If ``assert_hostname`` is False, no verification is done.
The ``key_file``, ``cert_file``, ``cert_reqs``, ``ca_certs``,
``ca_cert_dir``, ``ssl_version``, ``key_password``, ``alpn_protocols`` are only used if :mod:`ssl`
``ca_cert_dir``, ``ssl_version``, ``key_password`` are only used if :mod:`ssl`
is available and are fed into :meth:`urllib3.util.ssl_wrap_socket` to upgrade
the connection socket into an SSL socket.
"""
Expand Down Expand Up @@ -875,7 +875,6 @@ def __init__(
assert_hostname=None,
assert_fingerprint=None,
ca_cert_dir=None,
alpn_protocols=None,
**conn_kw
):

Expand Down Expand Up @@ -903,7 +902,6 @@ def __init__(
self.ssl_version = ssl_version
self.assert_hostname = assert_hostname
self.assert_fingerprint = assert_fingerprint
self.alpn_protocols = alpn_protocols

def _prepare_conn(self, conn):
"""
Expand All @@ -921,7 +919,6 @@ def _prepare_conn(self, conn):
ca_cert_dir=self.ca_cert_dir,
assert_hostname=self.assert_hostname,
assert_fingerprint=self.assert_fingerprint,
alpn_protocols=self.alpn_protocols,
)
conn.ssl_version = self.ssl_version
return conn
Expand Down
7 changes: 0 additions & 7 deletions src/urllib3/contrib/pyopenssl.py
Expand Up @@ -78,8 +78,6 @@ class UnsupportedExtension(Exception):
# SNI always works.
HAS_SNI = True

HAS_ALPN = hasattr(OpenSSL.SSL.Context, "set_alpn_protos")

# Map from urllib3 to PyOpenSSL compatible parameter-values.
_openssl_versions = {
util.PROTOCOL_TLS: OpenSSL.SSL.SSLv23_METHOD,
Expand Down Expand Up @@ -108,7 +106,6 @@ class UnsupportedExtension(Exception):
SSL_WRITE_BLOCKSIZE = 16384

orig_util_HAS_SNI = util.HAS_SNI
orig_util_HAS_ALPN = util.HAS_ALPN
orig_util_SSLContext = util.ssl_.SSLContext


Expand All @@ -124,8 +121,6 @@ def inject_into_urllib3():
util.ssl_.SSLContext = PyOpenSSLContext
util.HAS_SNI = HAS_SNI
util.ssl_.HAS_SNI = HAS_SNI
util.HAS_ALPN = HAS_ALPN
util.ssl_.HAS_ALPN = HAS_ALPN
util.IS_PYOPENSSL = True
util.ssl_.IS_PYOPENSSL = True

Expand All @@ -137,8 +132,6 @@ def extract_from_urllib3():
util.ssl_.SSLContext = orig_util_SSLContext
util.HAS_SNI = orig_util_HAS_SNI
util.ssl_.HAS_SNI = orig_util_HAS_SNI
util.HAS_ALPN = orig_util_HAS_ALPN
util.ssl_.HAS_ALPN = orig_util_HAS_ALPN
util.IS_PYOPENSSL = False
util.ssl_.IS_PYOPENSSL = False

Expand Down
9 changes: 0 additions & 9 deletions src/urllib3/contrib/securetransport.py
Expand Up @@ -81,12 +81,7 @@
# SNI always works
HAS_SNI = True

# TODO: ALPN is currently not implemented.
# See https://developer.apple.com/documentation/security/2976269-sec_protocol_options_add_tls_app
HAS_ALPN = False

orig_util_HAS_SNI = util.HAS_SNI
orig_util_HAS_ALPN = util.HAS_ALPN
orig_util_SSLContext = util.ssl_.SSLContext

# This dictionary is used by the read callback to obtain a handle to the
Expand Down Expand Up @@ -190,8 +185,6 @@ def inject_into_urllib3():
util.ssl_.SSLContext = SecureTransportContext
util.HAS_SNI = HAS_SNI
util.ssl_.HAS_SNI = HAS_SNI
util.HAS_ALPN = HAS_ALPN
util.ssl_.HAS_ALPN = HAS_ALPN
util.IS_SECURETRANSPORT = True
util.ssl_.IS_SECURETRANSPORT = True

Expand All @@ -204,8 +197,6 @@ def extract_from_urllib3():
util.ssl_.SSLContext = orig_util_SSLContext
util.HAS_SNI = orig_util_HAS_SNI
util.ssl_.HAS_SNI = orig_util_HAS_SNI
util.HAS_ALPN = orig_util_HAS_ALPN
util.ssl_.HAS_ALPN = orig_util_HAS_ALPN
util.IS_SECURETRANSPORT = False
util.ssl_.IS_SECURETRANSPORT = False

Expand Down
2 changes: 0 additions & 2 deletions src/urllib3/poolmanager.py
Expand Up @@ -42,7 +42,6 @@ class InvalidProxyConfigurationWarning(HTTPWarning):
"ca_cert_dir",
"ssl_context",
"key_password",
"alpn_protocols",
)

# All known keyword arguments that could be provided to the pool manager, its
Expand Down Expand Up @@ -73,7 +72,6 @@ class InvalidProxyConfigurationWarning(HTTPWarning):
"key_assert_hostname", # bool or string
"key_assert_fingerprint", # str
"key_server_hostname", # str
"key_alpn_protocols", # list of str
)

#: The namedtuple class used to construct keys for the connection pool.
Expand Down
10 changes: 4 additions & 6 deletions src/urllib3/util/__init__.py
Expand Up @@ -7,16 +7,15 @@
from .ssl_ import (
SSLContext,
HAS_SNI,
HAS_ALPN,
IS_PYOPENSSL,
IS_SECURETRANSPORT,
assert_fingerprint,
resolve_cert_reqs,
resolve_ssl_version,
ssl_wrap_socket,
has_alpn,
PROTOCOL_TLS,
DEFAULT_ALPN_PROTOCOLS,
SUPPRESS_ALPN,
ALPN_PROTOCOLS,
)
from .timeout import current_time, Timeout

Expand All @@ -26,13 +25,11 @@

__all__ = (
"HAS_SNI",
"HAS_ALPN",
"IS_PYOPENSSL",
"IS_SECURETRANSPORT",
"SSLContext",
"PROTOCOL_TLS",
"DEFAULT_ALPN_PROTOCOLS",
"SUPPRESS_ALPN",
"ALPN_PROTOCOLS",
"Retry",
"Timeout",
"Url",
Expand All @@ -47,6 +44,7 @@
"resolve_ssl_version",
"split_first",
"ssl_wrap_socket",
"has_alpn",
"wait_for_read",
"wait_for_write",
)
31 changes: 19 additions & 12 deletions src/urllib3/util/ssl_.py
Expand Up @@ -15,12 +15,9 @@

SSLContext = None
HAS_SNI = False
HAS_ALPN = False
IS_PYOPENSSL = False
IS_SECURETRANSPORT = False
DEFAULT_ALPN_PROTOCOLS = ["http/1.1"]
#: A sentinel object to suppress the default ALPN protcols
SUPPRESS_ALPN = object()
ALPN_PROTOCOLS = ["http/1.1"]

# Maps the length of a digest to a possible hash function producing this digest
HASHFUNC_MAP = {32: md5, 40: sha1, 64: sha256}
Expand All @@ -45,7 +42,6 @@ def _const_compare_digest_backport(a, b):
import ssl
from ssl import wrap_socket, CERT_REQUIRED
from ssl import HAS_SNI # Has SNI?
from ssl import HAS_ALPN # Has ALPN?
except ImportError:
pass

Expand Down Expand Up @@ -321,7 +317,6 @@ def ssl_wrap_socket(
ca_cert_dir=None,
key_password=None,
ca_cert_data=None,
alpn_protocols=None,
):
"""
All arguments except for server_hostname, ssl_context, and ca_cert_dir have
Expand All @@ -343,8 +338,6 @@ def ssl_wrap_socket(
:param ca_cert_data:
Optional string containing CA certificates in PEM format suitable for
passing as the cadata parameter to SSLContext.load_verify_locations()
:param alpn_protocols:
When ALPN is supported, the ALPN protocols to negotiate. :data:`SUPPRESS_ALPN` will suppress sending :data:`DEFAULT_ALPN_PROTOCOLS`.
"""
context = ssl_context
if context is None:
Expand Down Expand Up @@ -381,10 +374,11 @@ def ssl_wrap_socket(
else:
context.load_cert_chain(certfile, keyfile, key_password)

if HAS_ALPN and alpn_protocols is not SUPPRESS_ALPN:
if alpn_protocols is None:
alpn_protocols = DEFAULT_ALPN_PROTOCOLS
context.set_alpn_protocols(alpn_protocols)
try:
if hasattr(context, "set_alpn_protocols"):
context.set_alpn_protocols(ALPN_PROTOCOLS)
except NotImplementedError:
pass

# If we detect server_hostname is an IP address then the SNI
# extension should not be used according to RFC3546 Section 3.1
Expand All @@ -410,6 +404,19 @@ def ssl_wrap_socket(
return context.wrap_socket(sock)


def has_alpn(ctx_cls=None):
"""Detect if ALPN support is enabled."""
ctx_cls = ctx_cls or SSLContext
ctx = ctx_cls(protocol=PROTOCOL_TLS)
try:
if hasattr(ctx, "set_alpn_protocols"):
ctx.set_alpn_protocols(ALPN_PROTOCOLS)
return True
except NotImplementedError:
pass
return False


def is_ipaddress(hostname):
"""Detects whether the hostname given is an IPv4 or IPv6 address.
Also detects IPv6 addresses with Zone IDs.
Expand Down
50 changes: 9 additions & 41 deletions test/with_dummyserver/test_https.py
Expand Up @@ -717,6 +717,15 @@ def test_sslkeylogfile(self, tmpdir, monkeypatch):
% str(keylog_file)
)

def test_alpn_default(self):
"""Default ALPN protocols are sent by default."""
if not util.has_alpn() or not util.has_alpn(ssl.SSLContext):
pytest.skip("ALPN-support not available")
with HTTPSConnectionPool(self.host, self.port, ca_certs=DEFAULT_CA) as pool:
r = pool.request("GET", "/alpn_protocol", retries=0)
assert r.status == 200
assert r.data.decode("utf-8") == util.ALPN_PROTOCOLS[0]


@requiresTLSv1()
class TestHTTPS_TLSv1(TestHTTPS):
Expand Down Expand Up @@ -805,44 +814,3 @@ def test_can_validate_ipv6_san(self, ipv6_san_server):
) as https_pool:
r = https_pool.request("GET", "/")
assert r.status == 200


class TestHTTPS_ALPN(TestHTTPS):
servers_last = "secondproto"
alpn_protos = util.DEFAULT_ALPN_PROTOCOLS + [servers_last]
servers_first = alpn_protos[0]
certs = dict(DEFAULT_CERTS, alpn_protocols=alpn_protos)

def _get_pool(self, **kwargs):
return HTTPSConnectionPool(self.host, self.port, ca_certs=DEFAULT_CA, **kwargs)

def test_alpn_custom(self):
"""Setting custom ALPN protocols chooses the right protocol."""
# choose the right protocol (server's first, client's last)
with self._get_pool(alpn_protocols=["fakeproto", self.servers_first]) as pool:
r = pool.request("GET", "/alpn_protocol")
assert r.status == 200
assert r.data.decode("utf-8") == self.servers_first
# choose the right protocol (client's first, server's last)
with self._get_pool(alpn_protocols=[self.servers_last, "fakeproto"]) as pool:
r = pool.request("GET", "/alpn_protocol")
assert r.status == 200
assert r.data.decode("utf-8") == self.servers_last
# don't choose a protocol
with self._get_pool(alpn_protocols=["fakeproto"]) as pool:
r = pool.request("GET", "/alpn_protocol", retries=0)
assert r.status == 200
assert r.data.decode("utf-8") == ""

def test_alpn_default(self):
"""Default ALPN protocols are sent by default, but can be suppressed."""
# sends default alpn protocols
with self._get_pool() as pool:
r = pool.request("GET", "/alpn_protocol", retries=0)
assert r.status == 200
assert r.data.decode("utf-8") == util.DEFAULT_ALPN_PROTOCOLS[0]
# can suppress default alpn protocols
with self._get_pool(alpn_protocols=util.SUPPRESS_ALPN) as pool:
r = pool.request("GET", "/alpn_protocol")
assert r.status == 200
assert r.data.decode("utf-8") == ""
7 changes: 5 additions & 2 deletions test/with_dummyserver/test_socketlevel.py
Expand Up @@ -112,9 +112,12 @@ def socket_handler(listener):
self.host.encode("ascii") in self.buf
), "missing hostname in SSL handshake"


class TestALPN(SocketDummyServerTestCase):
def test_alpn_protocol_in_first_request_packet(self):
if not util.HAS_ALPN:
if not util.has_alpn():
pytest.skip("ALPN-support not available")

done_receiving = Event()
self.buf = b""

Expand All @@ -133,7 +136,7 @@ def socket_handler(listener):
pass
successful = done_receiving.wait(LONG_TIMEOUT)
assert successful, "Timed out waiting for connection accept"
for protocol in util.DEFAULT_ALPN_PROTOCOLS:
for protocol in util.ALPN_PROTOCOLS:
assert (
protocol.encode("ascii") in self.buf
), "missing ALPN protocol in SSL handshake"
Expand Down

0 comments on commit 9457440

Please sign in to comment.