diff --git a/src/urllib3/contrib/ssl.py b/src/urllib3/contrib/ssl.py new file mode 100644 index 0000000000..c289b1c24d --- /dev/null +++ b/src/urllib3/contrib/ssl.py @@ -0,0 +1,199 @@ +import ssl +import socket +import io + + +class SSLTransportError(Exception): + "Will wrap any SSL errors received during socket IO." + pass + + +class SSLTransport: + """ + The SSLTransport wraps an existing socket and establishes an SSL connection. + + Contrary to Python's implementation of SSLSocket, it allows you to chain + multiple TLS connections together. It's particularly useful if you need to + implement TLS within TLS. + + The class supports most of the socket API operations. + """ + + def __init__( + self, socket, ssl_context, suppress_ragged_eofs=True, server_hostname=None + ): + """ + Create an SSLTransport around socket using the provided ssl_context. + """ + self.incoming = ssl.MemoryBIO() + self.outgoing = ssl.MemoryBIO() + + self.suppress_ragged_eofs = suppress_ragged_eofs + self.socket = socket + + self.sslobj = ssl_context.wrap_bio( + self.incoming, self.outgoing, server_hostname=server_hostname + ) + + # Perform initial handshake. + self._ssl_io_loop(self.sslobj.do_handshake) + + def __enter__(self): + return self + + def __exit__(self, *args): + self.close() + + def fileno(self): + return self.socket.fileno() + + def read(self, len=1024, buffer=None): + return self._wrap_ssl_read(len, buffer) + + def recv(self, len=1024, flags=0): + if flags != 0: + raise ValueError("non-zero flags not allowed in calls to recv") + return self._wrap_ssl_read(len) + + def recv_into(self, buffer, nbytes=None, flags=0): + if flags != 0: + raise ValueError("non-zero flags not allowed in calls to recv_into") + if buffer and (nbytes is None): + nbytes = len(buffer) + elif nbytes is None: + nbytes = 1024 + return self.read(nbytes, buffer) + + def sendall(self, data, flags=0): + if flags != 0: + raise ValueError("non-zero flags not allowed in calls to sendall") + count = 0 + with memoryview(data) as view, view.cast("B") as byte_view: + amount = len(byte_view) + while count < amount: + v = self.send(byte_view[count:]) + count += v + + def send(self, data, flags=0): + if flags != 0: + raise ValueError("non-zero flags not allowed in calls to send") + response = self._ssl_io_loop(self.sslobj.write, data) + return response + + def makefile( + self, mode="r", buffering=None, encoding=None, errors=None, newline=None + ): + """ + Python's httpclient uses makefile and buffered io when reading HTTP + messages and we need to support it. + + This is unfortunately a copy and paste of socket.py makefile with small + changes to point to the socket directly. + """ + if not set(mode) <= {"r", "w", "b"}: + raise ValueError("invalid mode %r (only r, w, b allowed)" % (mode,)) + + writing = "w" in mode + reading = "r" in mode or not writing + assert reading or writing + binary = "b" in mode + rawmode = "" + if reading: + rawmode += "r" + if writing: + rawmode += "w" + raw = socket.SocketIO(self, rawmode) + self.socket._io_refs += 1 + if buffering is None: + buffering = -1 + if buffering < 0: + buffering = io.DEFAULT_BUFFER_SIZE + if buffering == 0: + if not binary: + raise ValueError("unbuffered streams must be binary") + return raw + if reading and writing: + buffer = io.BufferedRWPair(raw, raw, buffering) + elif reading: + buffer = io.BufferedReader(raw, buffering) + else: + assert writing + buffer = io.BufferedWriter(raw, buffering) + if binary: + return buffer + text = io.TextIOWrapper(buffer, encoding, errors, newline) + text.mode = mode + return text + + def unwrap(self): + self._ssl_io_loop(self.sslobj.unwrap) + + def close(self): + self.socket.close() + + def getpeercert(self): + return self.sslobj.getpeercert() + + def version(self): + return self.sslobj.version() + + def cipher(self): + return self.sslobj.cipher() + + def selected_alpn_protocol(self): + return self.sslobj.selected_alpn_protocol() + + def selected_npn_protocol(self): + return self.sslobj.selected_npn_protocol() + + def shared_ciphers(self): + return self.sslobj.shared_ciphers() + + def compression(self): + return self.sslobj.compression() + + def settimeout(self, value): + self.socket.settimeout(value) + + def gettimeout(self): + return self.socket.gettimeout() + + def _decref_socketios(self): + self.socket._decref_socketios() + + def _wrap_ssl_read(self, len, buffer=None): + response = None + try: + response = self._ssl_io_loop(self.sslobj.read, len, buffer) + except SSLTransportError as e: + if e.__context__.errno == ssl.SSL_ERROR_EOF and self.suppress_ragged_eofs: + response = 0 # eof, return 0. + else: + raise + return response + + def _ssl_io_loop(self, func, *args): + """ Performs an I/O loop between incoming/outgoing and the socket.""" + should_loop = True + + while should_loop: + errno = None + try: + ret = func(*args) + except ssl.SSLError as e: + if e.errno not in (ssl.SSL_ERROR_WANT_READ, ssl.SSL_ERROR_WANT_WRITE): + raise SSLTransportError(e) + errno = e.errno + + buf = self.outgoing.read() + self.socket.sendall(buf) + + if errno is None: + should_loop = False + elif errno == ssl.SSL_ERROR_WANT_READ: + buf = self.socket.recv(4096) + if buf: + self.incoming.write(buf) + else: + self.incoming.write_eof() + return ret diff --git a/test/contrib/test_ssltransport.py b/test/contrib/test_ssltransport.py new file mode 100644 index 0000000000..65063f0d92 --- /dev/null +++ b/test/contrib/test_ssltransport.py @@ -0,0 +1,393 @@ +from dummyserver.testcase import SocketDummyServerTestCase, consume_socket +from dummyserver.server import ( + DEFAULT_CERTS, + DEFAULT_CA, +) + +from urllib3.contrib.ssl import SSLTransport, SSLTransportError + +import select +import pytest +import socket +import ssl +import sys + + +def get_server_client_ssl_contexts(): + + if hasattr(ssl, "PROTOCOL_TLS_SERVER"): + server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) + else: + server_context = ssl.SSLContext(ssl.PROTOCOL_TLS) + server_context.load_cert_chain(DEFAULT_CERTS["certfile"], DEFAULT_CERTS["keyfile"]) + + if hasattr(ssl, "PROTOCOL_TLS_CLIENT"): + client_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + else: + client_context = ssl.SSLContext(ssl.PROTOCOL_TLS) + client_context.verify_mode = ssl.CERT_REQUIRED + client_context.check_hostname = True + + client_context.load_verify_locations(DEFAULT_CA) + return server_context, client_context + + +@pytest.mark.skipif(sys.version_info < (3, 5), reason="requires python3.5 or higher") +class SingleTLSLayerTestCase(SocketDummyServerTestCase): + """ + Uses the SocketDummyServer to validate a single TLS layer can be + established through the SSLTransport. + """ + + @classmethod + def setup_class(cls): + cls.server_context, cls.client_context = get_server_client_ssl_contexts() + + def start_dummy_server(self): + def socket_handler(listener): + sock = listener.accept()[0] + with self.server_context.wrap_socket(sock, server_side=True) as ssock: + request = consume_socket(ssock) + assert request is not None + assert "www.testing.com" in request.decode("utf-8") + + response = b"HTTP/1.1 200 OK\r\n" b"Content-Length: 0\r\n" b"\r\n" + ssock.send(response) + + self._start_server(socket_handler) + + def test_start_closed_socket(self): + """ Errors generated from an unconnected socket should bubble up.""" + sock = socket.socket(socket.AF_INET) + context = ssl.create_default_context() + sock.close() + with pytest.raises(OSError): + SSLTransport(sock, context) + + def test_close_after_handshake(self): + """ Socket errors should be bubbled up """ + self.start_dummy_server() + + sock = socket.create_connection((self.host, self.port)) + with SSLTransport( + sock, self.client_context, server_hostname="localhost" + ) as ssock: + ssock.close() + with pytest.raises(OSError): + ssock.send(b"blaaargh") + + def test_wrap_existing_socket(self): + """ Validates a single TLS layer can be established. """ + self.start_dummy_server() + + sock = socket.create_connection((self.host, self.port)) + with SSLTransport( + sock, self.client_context, server_hostname="localhost" + ) as ssock: + assert ssock.version() is not None + ssock.send( + b"GET http://www.testing.com/ HTTP/1.1\r\n" + b"Host: www.testing.com\r\n" + b"User-Agent: awesome-test\r\n" + b"\r\n" + ) + response = consume_socket(ssock) + assert response is not None + + def test_ssl_object_attributes(self): + """ Ensures common ssl attributes are exposed """ + self.start_dummy_server() + + sock = socket.create_connection((self.host, self.port)) + with SSLTransport( + sock, self.client_context, server_hostname="localhost" + ) as ssock: + assert ssock.cipher() is not None + assert ssock.selected_alpn_protocol() is None + assert ssock.selected_npn_protocol() is None + assert ssock.shared_ciphers() is not None + assert ssock.compression() is None + assert ssock.getpeercert() is not None + + ssock.send( + b"GET http://www.testing.com/ HTTP/1.1\r\n" + b"Host: www.testing.com\r\n" + b"User-Agent: awesome-test\r\n" + b"\r\n" + ) + response = consume_socket(ssock) + assert response is not None + + def test_socket_object_attributes(self): + """ Ensures common socket attributes are exposed """ + self.start_dummy_server() + + sock = socket.create_connection((self.host, self.port)) + with SSLTransport( + sock, self.client_context, server_hostname="localhost" + ) as ssock: + assert ssock.fileno() is not None + test_timeout = 10 + ssock.settimeout(test_timeout) + assert ssock.gettimeout() == test_timeout + ssock.send( + b"GET http://www.testing.com/ HTTP/1.1\r\n" + b"Host: www.testing.com\r\n" + b"User-Agent: awesome-test\r\n" + b"\r\n" + ) + response = consume_socket(ssock) + assert response is not None + + +class SocketProxyDummyServer(SocketDummyServerTestCase): + """ + Simulates a proxy that performs a simple I/O loop on client/server + socket. + """ + + def __init__(self, destination_server_host, destination_server_port): + self.destination_server_host = destination_server_host + self.destination_server_port = destination_server_port + self.server_context, self.client_context = get_server_client_ssl_contexts() + + def start_proxy_handler(self): + """ + Socket handler for the proxy. Terminates the first TLS layer and tunnels + any bytes needed for client <-> server communicatin. + """ + + def proxy_handler(listener): + sock = listener.accept()[0] + with self.server_context.wrap_socket(sock, server_side=True) as client_sock: + upstream_sock = socket.create_connection( + (self.destination_server_host, self.destination_server_port) + ) + self._read_write_loop(client_sock, upstream_sock) + upstream_sock.close() + + self._start_server(proxy_handler) + + def _read_write_loop(self, client_sock, server_sock, chunks=65536): + inputs = [client_sock, server_sock] + output = [client_sock, server_sock] + + while inputs: + readable, writable, exception = select.select(inputs, output, inputs) + + if exception: + # Error ocurred with either of the sockets, time to + # wrap up, parent func will close sockets. + break + + for s in readable: + read_socket, write_socket = None, None + if s == client_sock: + read_socket = client_sock + write_socket = server_sock + else: + read_socket = server_sock + write_socket = client_sock + + # Ensure buffer is not full before writting + if write_socket in writable: + try: + b = read_socket.recv(chunks) + write_socket.send(b) + except ssl.SSLEOFError: + # It's possible, depending on shutdown order, that we'll + # try to use a socket that was closed between select + # calls. + return + + +@pytest.mark.skipif(sys.version_info < (3, 5), reason="requires python3.5 or higher") +class TlsInTlsTestCase(SocketDummyServerTestCase): + """ + Creates a TLS in TLS tunnel by chaining a 'SocketProxyDummyServer' and a + `SocketDummyServerTestCase`. + + Client will first connect to the proxy, who will then proxy any bytes send + to the destination server. First TLS layer terminates at the proxy, second + TLS layer terminates at the destination server. + """ + + @classmethod + def setup_class(cls): + cls.server_context, cls.client_context = get_server_client_ssl_contexts() + + @classmethod + def start_proxy_server(cls): + # Proxy server will handle the first TLS connection and create a + # connection to the destination server. + cls.proxy_server = SocketProxyDummyServer(cls.host, cls.port) + cls.proxy_server.start_proxy_handler() + + @classmethod + def teardown_class(cls): + if hasattr(cls, "proxy_server"): + cls.proxy_server.teardown_class() + super(TlsInTlsTestCase, cls).teardown_class() + + @classmethod + def start_destination_server(cls): + """ + Socket handler for the destination_server. Terminates the second TLS + layer and send a basic HTTP response. + """ + + def socket_handler(listener): + sock = listener.accept()[0] + with cls.server_context.wrap_socket(sock, server_side=True) as ssock: + request = consume_socket(ssock) + assert request is not None + assert "www.testing.com" in request.decode("utf-8") + + response = b"HTTP/1.1 200 OK\r\n" b"Content-Length: 0\r\n" b"\r\n" + ssock.send(response) + ssock.close() + + cls._start_server(socket_handler) + + def test_tls_in_tls_tunnel(self): + """ + Basic communication over the TLS in TLS tunnel. + """ + self.start_destination_server() + self.start_proxy_server() + + sock = socket.create_connection( + (self.proxy_server.host, self.proxy_server.port) + ) + with self.client_context.wrap_socket( + sock, server_hostname="localhost" + ) as proxy_sock: + with SSLTransport( + proxy_sock, self.client_context, server_hostname="localhost" + ) as destination_sock: + assert destination_sock.version() is not None + destination_sock.send( + b"GET http://www.testing.com/ HTTP/1.1\r\n" + b"Host: www.testing.com\r\n" + b"User-Agent: awesome-test\r\n" + b"\r\n" + ) + response = consume_socket(destination_sock) + assert response is not None + assert "200" in response.decode("utf-8") + + def test_wrong_sni_hint(self): + """ + Provides a wrong sni hint to validate an exception is thrown. + """ + self.start_destination_server() + self.start_proxy_server() + + sock = socket.create_connection( + (self.proxy_server.host, self.proxy_server.port) + ) + with self.client_context.wrap_socket( + sock, server_hostname="localhost" + ) as proxy_sock: + with pytest.raises(Exception) as e: + SSLTransport( + proxy_sock, self.client_context, server_hostname="veryverywrong" + ) + # Accommodate different python3 versions + assert e.type in [SSLTransportError, ssl.CertificateError] + + def test_tls_in_tls_makefile_rw_binary(self): + """ + Uses makefile with read, write and binary modes. + """ + self.start_destination_server() + self.start_proxy_server() + + sock = socket.create_connection( + (self.proxy_server.host, self.proxy_server.port) + ) + with self.client_context.wrap_socket( + sock, server_hostname="localhost" + ) as proxy_sock: + with SSLTransport( + proxy_sock, self.client_context, server_hostname="localhost" + ) as destination_sock: + + file = destination_sock.makefile("rwb") + + file.write( + b"GET http://www.testing.com/ HTTP/1.1\r\n" + b"Host: www.testing.com\r\n" + b"User-Agent: awesome-test\r\n" + b"\r\n" + ) + file.flush() + + response = bytearray(65536) + wrote = file.readinto(response) + assert wrote is not None + assert response is not None + assert "200" in response.decode("utf-8") + file.close() + + def test_tls_in_tls_makefile_rw_text(self): + """ + Creates a separate buffer for reading and writing using text mode and + utf-8 encoding. + """ + self.start_destination_server() + self.start_proxy_server() + + sock = socket.create_connection( + (self.proxy_server.host, self.proxy_server.port) + ) + with self.client_context.wrap_socket( + sock, server_hostname="localhost" + ) as proxy_sock: + with SSLTransport( + proxy_sock, self.client_context, server_hostname="localhost" + ) as destination_sock: + + read = destination_sock.makefile("r", encoding="utf-8") + write = destination_sock.makefile("w", encoding="utf-8") + + write.write( + "GET http://www.testing.com/ HTTP/1.1\r\n" + "Host: www.testing.com\r\n" + "User-Agent: awesome-test\r\n" + "\r\n" + ) + write.flush() + + response = read.read() + assert response is not None + assert "200" in response + + def test_tls_in_tls_recv_into_sendall(self): + """ + Valides recv_into and sendall also work as expected. Other tests are + using recv/send. + """ + self.start_destination_server() + self.start_proxy_server() + + sock = socket.create_connection( + (self.proxy_server.host, self.proxy_server.port) + ) + with self.client_context.wrap_socket( + sock, server_hostname="localhost" + ) as proxy_sock: + with SSLTransport( + proxy_sock, self.client_context, server_hostname="localhost" + ) as destination_sock: + + destination_sock.sendall( + b"GET http://www.testing.com/ HTTP/1.1\r\n" + b"Host: www.testing.com\r\n" + b"User-Agent: awesome-test\r\n" + b"\r\n" + ) + response = bytearray(65536) + destination_sock.recv_into(response) + assert response is not None + assert "200" in response.decode("utf-8")