Skip to content

Commit

Permalink
Start support for PyOpenSSL and SecureTransport
Browse files Browse the repository at this point in the history
  • Loading branch information
sethmlarson committed Jun 21, 2021
1 parent 8ab2765 commit d302f57
Show file tree
Hide file tree
Showing 9 changed files with 183 additions and 7 deletions.
10 changes: 10 additions & 0 deletions src/urllib3/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

if TYPE_CHECKING:
from typing_extensions import Literal
from ssl import TLSVersion

from .util.proxy import create_proxy_ssl_context
from .util.util import to_str
Expand Down Expand Up @@ -341,6 +342,8 @@ class HTTPSConnection(HTTPConnection):
ca_cert_dir: Optional[str] = None
ca_cert_data: Union[None, str, bytes] = None
ssl_version: Optional[Union[int, str]] = None
ssl_minimum_version: Optional["TLSVersion"] = None
ssl_maximum_version: Optional["TLSVersion"] = None
assert_fingerprint: Optional[str] = None
tls_in_tls_required: bool = False

Expand Down Expand Up @@ -379,6 +382,9 @@ def __init__(
self.key_password = key_password
self.ssl_context = ssl_context
self.server_hostname = server_hostname
self.ssl_version = None
self.ssl_minimum_version = None
self.ssl_maximum_version = None

def set_cert(
self,
Expand Down Expand Up @@ -458,6 +464,8 @@ def connect(self) -> None:
default_ssl_context = True
self.ssl_context = create_urllib3_context(
ssl_version=resolve_ssl_version(self.ssl_version),
ssl_minimum_version=self.ssl_minimum_version,
ssl_maximum_version=self.ssl_maximum_version,
cert_reqs=resolve_cert_reqs(self.cert_reqs),
)
# In some cases, we want to verify hostnames ourselves
Expand Down Expand Up @@ -513,6 +521,8 @@ def connect(self) -> None:
if (
default_ssl_context
and self.ssl_version is None
and self.ssl_minimum_version is None
and self.ssl_maximum_version is None
and tls_version is not None
and tls_version in {"TLSv1", "TLSv1.1"}
):
Expand Down
11 changes: 10 additions & 1 deletion src/urllib3/connectionpool.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,10 @@
from .util.url import _normalize_host as normalize_host
from .util.url import parse_url
from .util.util import to_str
from .util.ssl_ import _TYPE_TLS_VERSION

if TYPE_CHECKING:

from typing_extensions import Literal

log = logging.getLogger(__name__)
Expand Down Expand Up @@ -895,6 +897,8 @@ def __init__(
key_password: Optional[str] = None,
ca_certs: Optional[str] = None,
ssl_version: Optional[Union[int, str]] = None,
ssl_minimum_version: Optional[_TYPE_TLS_VERSION]=None,
ssl_maximum_version: Optional[_TYPE_TLS_VERSION]=None,
assert_hostname: Optional[Union[str, "Literal[False]"]] = None,
assert_fingerprint: Optional[str] = None,
ca_cert_dir: Optional[str] = None,
Expand All @@ -921,6 +925,8 @@ def __init__(
self.ca_certs = ca_certs
self.ca_cert_dir = ca_cert_dir
self.ssl_version = ssl_version
self.ssl_minimum_version = ssl_minimum_version
self.ssl_maximum_version = ssl_maximum_version
self.assert_hostname = assert_hostname
self.assert_fingerprint = assert_fingerprint

Expand All @@ -930,7 +936,7 @@ def _prepare_conn(self, conn: HTTPSConnection) -> HTTPConnection:
and establish the tunnel if proxy is used.
"""

if isinstance(conn, VerifiedHTTPSConnection):
if isinstance(conn, HTTPSConnection):
conn.set_cert(
key_file=self.key_file,
key_password=self.key_password,
Expand All @@ -942,6 +948,9 @@ def _prepare_conn(self, conn: HTTPSConnection) -> HTTPConnection:
assert_fingerprint=self.assert_fingerprint,
)
conn.ssl_version = self.ssl_version
conn.ssl_minimum_version = self.ssl_minimum_version
conn.ssl_maximum_version = self.ssl_maximum_version

return conn

def _prepare_proxy(self, conn: HTTPSConnection) -> None: # type: ignore
Expand Down
79 changes: 78 additions & 1 deletion src/urllib3/contrib/pyopenssl.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,53 @@ class UnsupportedExtension(Exception):
}
_openssl_to_stdlib_verify = {v: k for k, v in _stdlib_to_openssl_verify.items()}

try:
from ssl import TLSVersion
except ImportError:
TLSVersion = None # type: ignore

if TLSVersion:
# The SSLvX values are the most likely to be missing in the future
# but we check them all just to be sure.
OP_NO_SSLv2 = getattr(OpenSSL.SSL, "OP_NO_SSLv2", 0)
OP_NO_SSLv3 = getattr(OpenSSL.SSL, "OP_NO_SSLv3", 0)
OP_NO_TLSv1 = getattr(OpenSSL.SSL, "OP_NO_TLSv1", 0)
OP_NO_TLSv1_1 = getattr(OpenSSL.SSL, "OP_NO_TLSv1_1", 0)
OP_NO_TLSv1_2 = getattr(OpenSSL.SSL, "OP_NO_TLSv1_2", 0)
OP_NO_TLSv1_3 = getattr(OpenSSL.SSL, "OP_NO_TLSv1_3", 0)

_openssl_to_ssl_minimum_version = {
ssl.TLSVersion.MINIMUM_SUPPORTED: OP_NO_SSLv2,
ssl.TLSVersion.SSLv3: OP_NO_SSLv2,
ssl.TLSVersion.TLSv1: OP_NO_SSLv2 | OP_NO_SSLv3,
ssl.TLSVersion.TLSv1_1: OP_NO_SSLv2 | OP_NO_SSLv3 | OP_NO_TLSv1,
ssl.TLSVersion.TLSv1_2: OP_NO_SSLv2 | OP_NO_SSLv3 | OP_NO_TLSv1 | OP_NO_TLSv1_1,
ssl.TLSVersion.TLSv1_3: (
OP_NO_SSLv2 | OP_NO_SSLv3 | OP_NO_TLSv1 | OP_NO_TLSv1_1 | OP_NO_TLSv1_2
),
ssl.TLSVersion.MAXIMUM_SUPPORTED: (
OP_NO_SSLv2 | OP_NO_SSLv3 | OP_NO_TLSv1 | OP_NO_TLSv1_1 | OP_NO_TLSv1_2
),
}
_openssl_to_ssl_maximum_version = {
ssl.TLSVersion.MINIMUM_SUPPORTED: (
OP_NO_SSLv2 | OP_NO_TLSv1 | OP_NO_TLSv1_1 | OP_NO_TLSv1_2 | OP_NO_TLSv1_3
),
ssl.TLSVersion.SSLv3: (
OP_NO_SSLv2 | OP_NO_TLSv1 | OP_NO_TLSv1_1 | OP_NO_TLSv1_2 | OP_NO_TLSv1_3
),
ssl.TLSVersion.TLSv1: (
OP_NO_SSLv2 | OP_NO_TLSv1_1 | OP_NO_TLSv1_2 | OP_NO_TLSv1_3
),
ssl.TLSVersion.TLSv1_1: OP_NO_SSLv2 | OP_NO_TLSv1_2 | OP_NO_TLSv1_3,
ssl.TLSVersion.TLSv1_2: OP_NO_SSLv2 | OP_NO_TLSv1_3,
ssl.TLSVersion.TLSv1_3: OP_NO_SSLv2,
ssl.TLSVersion.MAXIMUM_SUPPORTED: OP_NO_SSLv2,
}
else:
_openssl_to_ssl_minimum_version = {}
_openssl_to_ssl_maximum_version = {}

# OpenSSL will only write 16K at a time
SSL_WRITE_BLOCKSIZE = 16384

Expand Down Expand Up @@ -397,6 +444,9 @@ def __init__(self, protocol):
self._ctx = OpenSSL.SSL.Context(self.protocol)
self._options = 0
self.check_hostname = False
self._minimum_version = None
self._maximum_version = None
self._ssl_version_options = 0

@property
def options(self):
Expand All @@ -405,7 +455,7 @@ def options(self):
@options.setter
def options(self, value):
self._options = value
self._ctx.set_options(value)
self._set_ctx_options()

@property
def verify_mode(self):
Expand Down Expand Up @@ -478,6 +528,33 @@ def wrap_socket(

return WrappedSocket(cnx, sock)

def _set_ctx_options(self):
self._ctx.set_options(
self._options
| _openssl_to_ssl_minimum_version.get(self._minimum_version, 0)
| _openssl_to_ssl_maximum_version.get(self._maximum_version, 0)
)

if TLSVersion:

@property
def minimum_version(self):
return self._minimum_version

@minimum_version.setter
def minimum_version(self, minimum_version):
self._minimum_version = minimum_version
self._set_ctx_options()

@property
def maximum_version(self):
return self._maximum_version

@maximum_version.setter
def maximum_version(self, maximum_version):
self._maximum_version = maximum_version
self._set_ctx_options()


def _verify_callback(cnx, x509, err_no, err_depth, return_code):
return err_no == 0
43 changes: 43 additions & 0 deletions src/urllib3/contrib/securetransport.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
import threading
import weakref
from socket import socket as socket_cls
from typing import Optional

from .. import util
from ._securetransport.bindings import CoreFoundation, Security, SecurityConst
Expand Down Expand Up @@ -144,6 +145,30 @@
SecurityConst.kTLSProtocol12,
)

try:
from ssl import TLSVersion
except ImportError:
TLSVersion = None # type: ignore


if TLSVersion:
_tls_version_to_st = {
TLSVersion.MINIMUM_SUPPORTED: SecurityConst.kTLSProtocol1,
TLSVersion.TLSv1: SecurityConst.kTLSProtocol1,
TLSVersion.TLSv1_1: SecurityConst.kTLSProtocol11,
TLSVersion.TLSv1_2: SecurityConst.kTLSProtocol12,
TLSVersion.MAXIMUM_SUPPORTED: SecurityConst.kTLSProtocol12,
}
# Reverse mapping for mapping back to TLSVersion once set.
_st_to_tls_version = {
v: k
for k, v in _tls_version_to_st.items()
if k not in (TLSVersion.MINIMUM_SUPPORTED, TLSVersion.MAXIMUM_SUPPORTED)
}
else:
_tls_version_to_st = {}
_st_to_tls_version = {}


def inject_into_urllib3():
"""
Expand Down Expand Up @@ -844,3 +869,21 @@ def wrap_socket(
self._alpn_protocols,
)
return wrapped_socket

if TLSVersion:

@property
def minimum_version(self) -> "TLSVersion":
return _st_to_tls_version[self._min_version]

@minimum_version.setter
def minimum_version(self, minimum_version: Optional["TLSVersion"]) -> None:
self._min_version = _tls_version_to_st[minimum_version]

@property
def maximum_version(self) -> "TLSVersion":
return _st_to_tls_version[self._max_version]

@maximum_version.setter
def maximum_version(self, maximum_version: Optional["TLSVersion"]) -> None:
self._max_version = _tls_version_to_st[maximum_version]
5 changes: 3 additions & 2 deletions src/urllib3/poolmanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from .util.retry import Retry
from .util.timeout import Timeout
from .util.url import Url, parse_url
from .util.ssl_ import _TYPE_TLS_VERSION

if TYPE_CHECKING:
import ssl
Expand Down Expand Up @@ -81,8 +82,8 @@ class PoolKey(NamedTuple):
key_cert_reqs: Optional[str]
key_ca_certs: Optional[str]
key_ssl_version: Optional[Union[int, str]]
key_ssl_minimum_version: Optional["ssl.TLSVersion"]
key_ssl_maximum_version: Optional["ssl.TLSVersion"]
key_ssl_minimum_version: Optional[_TYPE_TLS_VERSION]
key_ssl_maximum_version: Optional[_TYPE_TLS_VERSION]
key_ca_cert_dir: Optional[str]
key_ssl_context: Optional["ssl.SSLContext"]
key_maxsize: Optional[int]
Expand Down
1 change: 1 addition & 0 deletions src/urllib3/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ def request(
body: Optional[HTTPBody] = None,
fields: Optional[_TYPE_FIELDS] = None,
headers: Optional[Mapping[str, str]] = None,
body: Optional[HTTPBody] = None,
**urlopen_kw: Any,
) -> BaseHTTPResponse:
"""
Expand Down
3 changes: 3 additions & 0 deletions src/urllib3/util/ssl_.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,11 @@ def _is_ge_openssl_v1_1_1(

try: # Python 3.7+
from ssl import TLSVersion

_TYPE_TLS_VERSION = TLSVersion
except ImportError: # Python 3.6
TLSVersion = None # type: ignore
_TYPE_TLS_VERSION = None # type: ignore


_PCTRTT = Tuple[Tuple[str, str], ...]
Expand Down
4 changes: 1 addition & 3 deletions test/contrib/test_pyopenssl.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,10 @@ def teardown_module():

from ..test_ssl import TestSSL # noqa: E402, F401
from ..test_util import TestUtilSSL # noqa: E402, F401
from ..with_dummyserver.test_https import ( # noqa: E402, F401
from ..with_dummyserver.test_https import ( # noqa: E402, F401; TestHTTPS_TLSv1,; TestHTTPS_TLSv1_1,
TestHTTPS,
TestHTTPS_IPSAN,
TestHTTPS_IPV6SAN,
TestHTTPS_TLSv1,
TestHTTPS_TLSv1_1,
TestHTTPS_TLSv1_2,
TestHTTPS_TLSv1_3,
)
Expand Down
34 changes: 34 additions & 0 deletions test/with_dummyserver/test_https.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,13 @@ class TestHTTPS(HTTPSDummyServerTestCase):
def tls_protocol_deprecated(self):
return self.tls_protocol_name in {"TLSv1", "TLSv1.1"}

def tls_version(self):
try:
from ssl import TLSVersion
except ImportError:
return pytest.skip("ssl.TLSVersion isn't available")
return getattr(TLSVersion, self.tls_protocol_name.replace(".", "_"))

@classmethod
def setup_class(cls):
super().setup_class()
Expand Down Expand Up @@ -719,6 +726,33 @@ def test_no_tls_version_deprecation_with_ssl_context(self):

assert w == []

def test_tls_version_maximum_and_minimum(self):
if self.tls_protocol_name is None:
pytest.skip("Skipping base test class")

from ssl import TLSVersion

min_max_versions = [
(self.tls_version(), self.tls_version()),
(TLSVersion.MINIMUM_SUPPORTED, self.tls_version()),
(TLSVersion.MINIMUM_SUPPORTED, TLSVersion.MAXIMUM_SUPPORTED),
]

for minimum_version, maximum_version in min_max_versions:
with HTTPSConnectionPool(
self.host,
self.port,
ca_certs=DEFAULT_CA,
ssl_minimum_version=minimum_version,
ssl_maximum_version=maximum_version,
) as https_pool:
conn = https_pool._get_conn()
try:
conn.connect()
assert conn.sock.version() == self.tls_protocol_name
finally:
conn.close()

@pytest.mark.skipif(sys.version_info < (3, 8), reason="requires python 3.8+")
def test_sslkeylogfile(self, tmpdir, monkeypatch):
if not hasattr(util.SSLContext, "keylog_filename"):
Expand Down

0 comments on commit d302f57

Please sign in to comment.