Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix unnecessary SNI warning with SecureTransport #1903

Merged
merged 3 commits into from Aug 24, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
21 changes: 12 additions & 9 deletions src/urllib3/util/ssl_.py
Expand Up @@ -375,14 +375,13 @@ def ssl_wrap_socket(

# If we detect server_hostname is an IP address then the SNI
# extension should not be used according to RFC3546 Section 3.1
# We shouldn't warn the user if SNI isn't available but we would
# not be using SNI anyways due to IP address for server_hostname.
if (
server_hostname is not None and not is_ipaddress(server_hostname)
) or IS_SECURETRANSPORT:
if HAS_SNI and server_hostname is not None:
return context.wrap_socket(sock, server_hostname=server_hostname)

use_sni_hostname = server_hostname and not is_ipaddress(server_hostname)
# SecureTransport uses server_hostname in certificate verification.
send_sni = (use_sni_hostname and HAS_SNI) or (
IS_SECURETRANSPORT and server_hostname
)
# Do not warn the user if server_hostname is an invalid SNI hostname.
if not HAS_SNI and use_sni_hostname:
warnings.warn(
"An HTTPS request has been made, but the SNI (Server Name "
"Indication) extension to TLS is not available on this platform. "
Expand All @@ -394,7 +393,11 @@ def ssl_wrap_socket(
SNIMissingWarning,
)

return context.wrap_socket(sock)
if send_sni:
ssl_sock = context.wrap_socket(sock, server_hostname=server_hostname)
else:
ssl_sock = context.wrap_socket(sock)
return ssl_sock


def is_ipaddress(hostname):
Expand Down
1 change: 1 addition & 0 deletions test/contrib/test_pyopenssl.py
Expand Up @@ -30,6 +30,7 @@ def teardown_module():
pass


from ..test_util import TestUtilSSL # noqa: E402, F401
from ..with_dummyserver.test_https import ( # noqa: E402, F401
TestHTTPS,
TestHTTPS_TLSv1,
Expand Down
2 changes: 2 additions & 0 deletions test/contrib/test_securetransport.py
Expand Up @@ -29,6 +29,8 @@ def teardown_module():
pass


from ..test_util import TestUtilSSL # noqa: E402, F401

# SecureTransport does not support TLSv1.3
# https://github.com/urllib3/urllib3/issues/1674
from ..with_dummyserver.test_https import ( # noqa: E402, F401
Expand Down
210 changes: 117 additions & 93 deletions test/test_util.py
Expand Up @@ -10,7 +10,7 @@
from mock import patch, Mock
import pytest

from urllib3 import add_stderr_logger, disable_warnings
from urllib3 import add_stderr_logger, disable_warnings, util
from urllib3.util.request import make_headers, rewind_body, _FAILEDTELL
from urllib3.util.response import assert_header_parsing
from urllib3.util.timeout import Timeout
Expand All @@ -29,7 +29,7 @@
UnrewindableBodyError,
)
from urllib3.util.connection import allowed_gai_family, _has_ipv6
from urllib3.util import is_fp_closed, ssl_
from urllib3.util import is_fp_closed
from urllib3.packages import six

from . import clear_warnings
Expand Down Expand Up @@ -666,31 +666,6 @@ def test_timeout_elapsed(self, current_time):
current_time.return_value = TIMEOUT_EPOCH + 37
assert timeout.get_connect_duration() == 37

@pytest.mark.parametrize(
"candidate, requirements",
[
(None, ssl.CERT_REQUIRED),
(ssl.CERT_NONE, ssl.CERT_NONE),
(ssl.CERT_REQUIRED, ssl.CERT_REQUIRED),
("REQUIRED", ssl.CERT_REQUIRED),
("CERT_REQUIRED", ssl.CERT_REQUIRED),
],
)
def test_resolve_cert_reqs(self, candidate, requirements):
assert resolve_cert_reqs(candidate) == requirements

@pytest.mark.parametrize(
"candidate, version",
[
(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1),
("PROTOCOL_TLSv1", ssl.PROTOCOL_TLSv1),
("TLSv1", ssl.PROTOCOL_TLSv1),
(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv23),
],
)
def test_resolve_ssl_version(self, candidate, version):
assert resolve_ssl_version(candidate) == version

def test_is_fp_closed_object_supports_closed(self):
class ClosedFile(object):
@property
Expand Down Expand Up @@ -722,72 +697,6 @@ class NotReallyAFile(object):
with pytest.raises(ValueError):
is_fp_closed(NotReallyAFile())

def test_ssl_wrap_socket_loads_the_cert_chain(self):
socket = object()
mock_context = Mock()
ssl_wrap_socket(
ssl_context=mock_context, sock=socket, certfile="/path/to/certfile"
)

mock_context.load_cert_chain.assert_called_once_with("/path/to/certfile", None)

@patch("urllib3.util.ssl_.create_urllib3_context")
def test_ssl_wrap_socket_creates_new_context(self, create_urllib3_context):
socket = object()
ssl_wrap_socket(sock=socket, cert_reqs="CERT_REQUIRED")

create_urllib3_context.assert_called_once_with(
None, "CERT_REQUIRED", ciphers=None
)

def test_ssl_wrap_socket_loads_verify_locations(self):
socket = object()
mock_context = Mock()
ssl_wrap_socket(ssl_context=mock_context, ca_certs="/path/to/pem", sock=socket)
mock_context.load_verify_locations.assert_called_once_with(
"/path/to/pem", None, None
)

def test_ssl_wrap_socket_loads_certificate_directories(self):
socket = object()
mock_context = Mock()
ssl_wrap_socket(
ssl_context=mock_context, ca_cert_dir="/path/to/pems", sock=socket
)
mock_context.load_verify_locations.assert_called_once_with(
None, "/path/to/pems", None
)

def test_ssl_wrap_socket_loads_certificate_data(self):
socket = object()
mock_context = Mock()
ssl_wrap_socket(
ssl_context=mock_context, ca_cert_data="TOTALLY PEM DATA", sock=socket
)
mock_context.load_verify_locations.assert_called_once_with(
None, None, "TOTALLY PEM DATA"
)

def test_ssl_wrap_socket_with_no_sni_warns(self):
socket = object()
mock_context = Mock()
# Ugly preservation of original value
HAS_SNI = ssl_.HAS_SNI
ssl_.HAS_SNI = False
try:
with patch("warnings.warn") as warn:
ssl_wrap_socket(
ssl_context=mock_context,
sock=socket,
server_hostname="www.google.com",
)
mock_context.wrap_socket.assert_called_once_with(socket)
assert warn.call_count >= 1
warnings = [call[0][1] for call in warn.call_args_list]
assert SNIMissingWarning in warnings
finally:
ssl_.HAS_SNI = HAS_SNI

def test_const_compare_digest_fallback(self):
target = hashlib.sha256(b"abcdef").digest()
assert _const_compare_digest_backport(target, target)
Expand Down Expand Up @@ -838,3 +747,118 @@ def test_ip_family_ipv6_disabled(self):
def test_assert_header_parsing_throws_typeerror_with_non_headers(self, headers):
with pytest.raises(TypeError):
assert_header_parsing(headers)


class TestUtilSSL(object):
"""Test utils that use an SSL backend."""

@pytest.mark.parametrize(
"candidate, requirements",
[
(None, ssl.CERT_REQUIRED),
(ssl.CERT_NONE, ssl.CERT_NONE),
(ssl.CERT_REQUIRED, ssl.CERT_REQUIRED),
("REQUIRED", ssl.CERT_REQUIRED),
("CERT_REQUIRED", ssl.CERT_REQUIRED),
],
)
def test_resolve_cert_reqs(self, candidate, requirements):
assert resolve_cert_reqs(candidate) == requirements

@pytest.mark.parametrize(
"candidate, version",
[
(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1),
("PROTOCOL_TLSv1", ssl.PROTOCOL_TLSv1),
("TLSv1", ssl.PROTOCOL_TLSv1),
(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv23),
],
)
def test_resolve_ssl_version(self, candidate, version):
assert resolve_ssl_version(candidate) == version

def test_ssl_wrap_socket_loads_the_cert_chain(self):
socket = object()
mock_context = Mock()
ssl_wrap_socket(
ssl_context=mock_context, sock=socket, certfile="/path/to/certfile"
)

mock_context.load_cert_chain.assert_called_once_with("/path/to/certfile", None)

@patch("urllib3.util.ssl_.create_urllib3_context")
def test_ssl_wrap_socket_creates_new_context(self, create_urllib3_context):
socket = object()
ssl_wrap_socket(sock=socket, cert_reqs="CERT_REQUIRED")

create_urllib3_context.assert_called_once_with(
None, "CERT_REQUIRED", ciphers=None
)

def test_ssl_wrap_socket_loads_verify_locations(self):
socket = object()
mock_context = Mock()
ssl_wrap_socket(ssl_context=mock_context, ca_certs="/path/to/pem", sock=socket)
mock_context.load_verify_locations.assert_called_once_with(
"/path/to/pem", None, None
)

def test_ssl_wrap_socket_loads_certificate_directories(self):
socket = object()
mock_context = Mock()
ssl_wrap_socket(
ssl_context=mock_context, ca_cert_dir="/path/to/pems", sock=socket
)
mock_context.load_verify_locations.assert_called_once_with(
None, "/path/to/pems", None
)

def test_ssl_wrap_socket_loads_certificate_data(self):
socket = object()
mock_context = Mock()
ssl_wrap_socket(
ssl_context=mock_context, ca_cert_data="TOTALLY PEM DATA", sock=socket
)
mock_context.load_verify_locations.assert_called_once_with(
None, None, "TOTALLY PEM DATA"
)

def _wrap_socket_and_mock_warn(self, sock, server_hostname):
mock_context = Mock()
with patch("warnings.warn") as warn:
ssl_wrap_socket(
ssl_context=mock_context, sock=sock, server_hostname=server_hostname,
)
return mock_context, warn

def test_ssl_wrap_socket_sni_hostname_use_or_warn(self):
"""Test that either an SNI hostname is used or a warning is made."""
sock = object()
context, warn = self._wrap_socket_and_mock_warn(sock, "www.google.com")
if util.HAS_SNI:
warn.assert_not_called()
context.wrap_socket.assert_called_once_with(
sock, server_hostname="www.google.com"
)
else:
assert warn.call_count >= 1
warnings = [call[0][1] for call in warn.call_args_list]
assert SNIMissingWarning in warnings
context.wrap_socket.assert_called_once_with(sock)

def test_ssl_wrap_socket_sni_ip_address_no_warn(self):
"""Test that a warning is not made if server_hostname is an IP address."""
sock = object()
context, warn = self._wrap_socket_and_mock_warn(sock, "8.8.8.8")
if util.IS_SECURETRANSPORT:
context.wrap_socket.assert_called_once_with(sock, server_hostname="8.8.8.8")
else:
context.wrap_socket.assert_called_once_with(sock)
warn.assert_not_called()

def test_ssl_wrap_socket_sni_none_no_warn(self):
"""Test that a warning is not made if server_hostname is not given."""
sock = object()
context, warn = self._wrap_socket_and_mock_warn(sock, None)
context.wrap_socket.assert_called_once_with(sock)
warn.assert_not_called()