Skip to content

Commit

Permalink
Don't raise SNI warning on SecureTransport with server_hostname=None
Browse files Browse the repository at this point in the history
  • Loading branch information
hodbn committed Aug 24, 2020
1 parent 0210ddb commit 16dc22b
Show file tree
Hide file tree
Showing 4 changed files with 132 additions and 102 deletions.
21 changes: 12 additions & 9 deletions src/urllib3/util/ssl_.py
Expand Up @@ -382,14 +382,13 @@ def ssl_wrap_socket(

# If we detect server_hostname is an IP address then the SNI
# extension should not be used according to RFC3546 Section 3.1
# We shouldn't warn the user if SNI isn't available but we would
# not be using SNI anyways due to IP address for server_hostname.
if (
server_hostname is not None and not is_ipaddress(server_hostname)
) or IS_SECURETRANSPORT:
if HAS_SNI and server_hostname is not None:
return context.wrap_socket(sock, server_hostname=server_hostname)

use_sni_hostname = server_hostname and not is_ipaddress(server_hostname)
# SecureTransport uses server_hostname in certificate verification.
send_sni = (use_sni_hostname and HAS_SNI) or (
IS_SECURETRANSPORT and server_hostname
)
# Do not warn the user if server_hostname is an invalid SNI hostname.
if not HAS_SNI and use_sni_hostname:
warnings.warn(
"An HTTPS request has been made, but the SNI (Server Name "
"Indication) extension to TLS is not available on this platform. "
Expand All @@ -401,7 +400,11 @@ def ssl_wrap_socket(
SNIMissingWarning,
)

return context.wrap_socket(sock)
if send_sni:
ssl_sock = context.wrap_socket(sock, server_hostname=server_hostname)
else:
ssl_sock = context.wrap_socket(sock)
return ssl_sock


def is_ipaddress(hostname):
Expand Down
1 change: 1 addition & 0 deletions test/contrib/test_pyopenssl.py
Expand Up @@ -30,6 +30,7 @@ def teardown_module():
pass


from ..test_util import TestUtilSSL # noqa: E402, F401
from ..with_dummyserver.test_https import ( # noqa: E402, F401
TestHTTPS,
TestHTTPS_TLSv1,
Expand Down
2 changes: 2 additions & 0 deletions test/contrib/test_securetransport.py
Expand Up @@ -29,6 +29,8 @@ def teardown_module():
pass


from ..test_util import TestUtilSSL # noqa: E402, F401

# SecureTransport does not support TLSv1.3
# https://github.com/urllib3/urllib3/issues/1674
from ..with_dummyserver.test_https import ( # noqa: E402, F401
Expand Down
210 changes: 117 additions & 93 deletions test/test_util.py
Expand Up @@ -10,7 +10,7 @@
from mock import patch, Mock
import pytest

from urllib3 import add_stderr_logger, disable_warnings
from urllib3 import add_stderr_logger, disable_warnings, util
from urllib3.util.request import make_headers, rewind_body, _FAILEDTELL
from urllib3.util.response import assert_header_parsing
from urllib3.util.timeout import Timeout
Expand All @@ -29,7 +29,7 @@
UnrewindableBodyError,
)
from urllib3.util.connection import allowed_gai_family, _has_ipv6
from urllib3.util import is_fp_closed, ssl_
from urllib3.util import is_fp_closed
from urllib3.packages import six

from . import clear_warnings
Expand Down Expand Up @@ -666,31 +666,6 @@ def test_timeout_elapsed(self, current_time):
current_time.return_value = TIMEOUT_EPOCH + 37
assert timeout.get_connect_duration() == 37

@pytest.mark.parametrize(
"candidate, requirements",
[
(None, ssl.CERT_REQUIRED),
(ssl.CERT_NONE, ssl.CERT_NONE),
(ssl.CERT_REQUIRED, ssl.CERT_REQUIRED),
("REQUIRED", ssl.CERT_REQUIRED),
("CERT_REQUIRED", ssl.CERT_REQUIRED),
],
)
def test_resolve_cert_reqs(self, candidate, requirements):
assert resolve_cert_reqs(candidate) == requirements

@pytest.mark.parametrize(
"candidate, version",
[
(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1),
("PROTOCOL_TLSv1", ssl.PROTOCOL_TLSv1),
("TLSv1", ssl.PROTOCOL_TLSv1),
(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv23),
],
)
def test_resolve_ssl_version(self, candidate, version):
assert resolve_ssl_version(candidate) == version

def test_is_fp_closed_object_supports_closed(self):
class ClosedFile(object):
@property
Expand Down Expand Up @@ -722,72 +697,6 @@ class NotReallyAFile(object):
with pytest.raises(ValueError):
is_fp_closed(NotReallyAFile())

def test_ssl_wrap_socket_loads_the_cert_chain(self):
socket = object()
mock_context = Mock()
ssl_wrap_socket(
ssl_context=mock_context, sock=socket, certfile="/path/to/certfile"
)

mock_context.load_cert_chain.assert_called_once_with("/path/to/certfile", None)

@patch("urllib3.util.ssl_.create_urllib3_context")
def test_ssl_wrap_socket_creates_new_context(self, create_urllib3_context):
socket = object()
ssl_wrap_socket(sock=socket, cert_reqs="CERT_REQUIRED")

create_urllib3_context.assert_called_once_with(
None, "CERT_REQUIRED", ciphers=None
)

def test_ssl_wrap_socket_loads_verify_locations(self):
socket = object()
mock_context = Mock()
ssl_wrap_socket(ssl_context=mock_context, ca_certs="/path/to/pem", sock=socket)
mock_context.load_verify_locations.assert_called_once_with(
"/path/to/pem", None, None
)

def test_ssl_wrap_socket_loads_certificate_directories(self):
socket = object()
mock_context = Mock()
ssl_wrap_socket(
ssl_context=mock_context, ca_cert_dir="/path/to/pems", sock=socket
)
mock_context.load_verify_locations.assert_called_once_with(
None, "/path/to/pems", None
)

def test_ssl_wrap_socket_loads_certificate_data(self):
socket = object()
mock_context = Mock()
ssl_wrap_socket(
ssl_context=mock_context, ca_cert_data="TOTALLY PEM DATA", sock=socket
)
mock_context.load_verify_locations.assert_called_once_with(
None, None, "TOTALLY PEM DATA"
)

def test_ssl_wrap_socket_with_no_sni_warns(self):
socket = object()
mock_context = Mock()
# Ugly preservation of original value
HAS_SNI = ssl_.HAS_SNI
ssl_.HAS_SNI = False
try:
with patch("warnings.warn") as warn:
ssl_wrap_socket(
ssl_context=mock_context,
sock=socket,
server_hostname="www.google.com",
)
mock_context.wrap_socket.assert_called_once_with(socket)
assert warn.call_count >= 1
warnings = [call[0][1] for call in warn.call_args_list]
assert SNIMissingWarning in warnings
finally:
ssl_.HAS_SNI = HAS_SNI

def test_const_compare_digest_fallback(self):
target = hashlib.sha256(b"abcdef").digest()
assert _const_compare_digest_backport(target, target)
Expand Down Expand Up @@ -838,3 +747,118 @@ def test_ip_family_ipv6_disabled(self):
def test_assert_header_parsing_throws_typeerror_with_non_headers(self, headers):
with pytest.raises(TypeError):
assert_header_parsing(headers)


class TestUtilSSL(object):
"""Test utils that use an SSL backend."""

@pytest.mark.parametrize(
"candidate, requirements",
[
(None, ssl.CERT_REQUIRED),
(ssl.CERT_NONE, ssl.CERT_NONE),
(ssl.CERT_REQUIRED, ssl.CERT_REQUIRED),
("REQUIRED", ssl.CERT_REQUIRED),
("CERT_REQUIRED", ssl.CERT_REQUIRED),
],
)
def test_resolve_cert_reqs(self, candidate, requirements):
assert resolve_cert_reqs(candidate) == requirements

@pytest.mark.parametrize(
"candidate, version",
[
(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1),
("PROTOCOL_TLSv1", ssl.PROTOCOL_TLSv1),
("TLSv1", ssl.PROTOCOL_TLSv1),
(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv23),
],
)
def test_resolve_ssl_version(self, candidate, version):
assert resolve_ssl_version(candidate) == version

def test_ssl_wrap_socket_loads_the_cert_chain(self):
socket = object()
mock_context = Mock()
ssl_wrap_socket(
ssl_context=mock_context, sock=socket, certfile="/path/to/certfile"
)

mock_context.load_cert_chain.assert_called_once_with("/path/to/certfile", None)

@patch("urllib3.util.ssl_.create_urllib3_context")
def test_ssl_wrap_socket_creates_new_context(self, create_urllib3_context):
socket = object()
ssl_wrap_socket(sock=socket, cert_reqs="CERT_REQUIRED")

create_urllib3_context.assert_called_once_with(
None, "CERT_REQUIRED", ciphers=None
)

def test_ssl_wrap_socket_loads_verify_locations(self):
socket = object()
mock_context = Mock()
ssl_wrap_socket(ssl_context=mock_context, ca_certs="/path/to/pem", sock=socket)
mock_context.load_verify_locations.assert_called_once_with(
"/path/to/pem", None, None
)

def test_ssl_wrap_socket_loads_certificate_directories(self):
socket = object()
mock_context = Mock()
ssl_wrap_socket(
ssl_context=mock_context, ca_cert_dir="/path/to/pems", sock=socket
)
mock_context.load_verify_locations.assert_called_once_with(
None, "/path/to/pems", None
)

def test_ssl_wrap_socket_loads_certificate_data(self):
socket = object()
mock_context = Mock()
ssl_wrap_socket(
ssl_context=mock_context, ca_cert_data="TOTALLY PEM DATA", sock=socket
)
mock_context.load_verify_locations.assert_called_once_with(
None, None, "TOTALLY PEM DATA"
)

def _wrap_socket_and_mock_warn(self, sock, server_hostname):
mock_context = Mock()
with patch("warnings.warn") as warn:
ssl_wrap_socket(
ssl_context=mock_context, sock=sock, server_hostname=server_hostname,
)
return mock_context, warn

def test_ssl_wrap_socket_sni_hostname_use_or_warn(self):
"""Test that either an SNI hostname is used or a warning is made."""
sock = object()
context, warn = self._wrap_socket_and_mock_warn(sock, "www.google.com")
if util.HAS_SNI:
warn.assert_not_called()
context.wrap_socket.assert_called_once_with(
sock, server_hostname="www.google.com"
)
else:
assert warn.call_count >= 1
warnings = [call[0][1] for call in warn.call_args_list]
assert SNIMissingWarning in warnings
context.wrap_socket.assert_called_once_with(sock)

def test_ssl_wrap_socket_sni_ip_address_no_warn(self):
"""Test that a warning is not made if server_hostname is an IP address."""
sock = object()
context, warn = self._wrap_socket_and_mock_warn(sock, "8.8.8.8")
if util.IS_SECURETRANSPORT:
context.wrap_socket.assert_called_once_with(sock, server_hostname="8.8.8.8")
else:
context.wrap_socket.assert_called_once_with(sock)
warn.assert_not_called()

def test_ssl_wrap_socket_sni_none_no_warn(self):
"""Test that a warning is not made if server_hostname is not given."""
sock = object()
context, warn = self._wrap_socket_and_mock_warn(sock, None)
context.wrap_socket.assert_called_once_with(sock)
warn.assert_not_called()

0 comments on commit 16dc22b

Please sign in to comment.