Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add SSLTransport to support TLS in TLS connections.
TLS within TLS is not supported easily within the ssl library. The SSLSocket actually takes over the existing socket (https://github.com/python/cpython/blob/master/Lib/ssl.py#L999-L1006) instead of wrapping it entirely. The only way to support to TLS within TLS is with the wrap_bio methods. This commit introduces SSLTransport which wraps a socket in TLS using the provided SSL context. Rather than taking over the socket it uses the wrap_bio methods to perform TLS on top of that socket. Signed-off-by: Jorge Lopez Silva <jalopezsilva@gmail.com>
- Loading branch information
1 parent
cbb6be7
commit 957e172
Showing
2 changed files
with
591 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,199 @@ | ||
import ssl | ||
import socket | ||
import io | ||
|
||
|
||
class SSLTransportError(Exception): | ||
"Will wrap any SSL errors received during socket IO." | ||
pass | ||
|
||
|
||
class SSLTransport: | ||
""" | ||
The SSLTransport wraps an existing socket and establishes an SSL connection. | ||
Contrary to Python's implementation of SSLSocket, it allows you to chain | ||
multiple TLS connections together. It's particularly useful if you need to | ||
implement TLS within TLS. | ||
The class supports most of the socket API operations. | ||
""" | ||
|
||
def __init__( | ||
self, socket, ssl_context, suppress_ragged_eofs=True, server_hostname=None | ||
): | ||
""" | ||
Create an SSLTransport around socket using the provided ssl_context. | ||
""" | ||
self.incoming = ssl.MemoryBIO() | ||
self.outgoing = ssl.MemoryBIO() | ||
|
||
self.suppress_ragged_eofs = suppress_ragged_eofs | ||
self.socket = socket | ||
|
||
self.sslobj = ssl_context.wrap_bio( | ||
self.incoming, self.outgoing, server_hostname=server_hostname | ||
) | ||
|
||
# Perform initial handshake. | ||
self._ssl_io_loop(self.sslobj.do_handshake) | ||
|
||
def __enter__(self): | ||
return self | ||
|
||
def __exit__(self, *args): | ||
self.close() | ||
|
||
def fileno(self): | ||
return self.socket.fileno() | ||
|
||
def read(self, len=1024, buffer=None): | ||
return self._wrap_ssl_read(len, buffer) | ||
|
||
def recv(self, len=1024, flags=0): | ||
if flags != 0: | ||
raise ValueError("non-zero flags not allowed in calls to recv") | ||
return self._wrap_ssl_read(len) | ||
|
||
def recv_into(self, buffer, nbytes=None, flags=0): | ||
if flags != 0: | ||
raise ValueError("non-zero flags not allowed in calls to recv_into") | ||
if buffer and (nbytes is None): | ||
nbytes = len(buffer) | ||
elif nbytes is None: | ||
nbytes = 1024 | ||
return self.read(nbytes, buffer) | ||
|
||
def sendall(self, data, flags=0): | ||
if flags != 0: | ||
raise ValueError("non-zero flags not allowed in calls to sendall") | ||
count = 0 | ||
with memoryview(data) as view, view.cast("B") as byte_view: | ||
amount = len(byte_view) | ||
while count < amount: | ||
v = self.send(byte_view[count:]) | ||
count += v | ||
|
||
def send(self, data, flags=0): | ||
if flags != 0: | ||
raise ValueError("non-zero flags not allowed in calls to send") | ||
response = self._ssl_io_loop(self.sslobj.write, data) | ||
return response | ||
|
||
def makefile( | ||
self, mode="r", buffering=None, encoding=None, errors=None, newline=None | ||
): | ||
""" | ||
Python's httpclient uses makefile and buffered io when reading HTTP | ||
messages and we need to support it. | ||
This is unfortunately a copy and paste of socket.py makefile with small | ||
changes to point to the socket directly. | ||
""" | ||
if not set(mode) <= {"r", "w", "b"}: | ||
raise ValueError("invalid mode %r (only r, w, b allowed)" % (mode,)) | ||
|
||
writing = "w" in mode | ||
reading = "r" in mode or not writing | ||
assert reading or writing | ||
binary = "b" in mode | ||
rawmode = "" | ||
if reading: | ||
rawmode += "r" | ||
if writing: | ||
rawmode += "w" | ||
raw = socket.SocketIO(self, rawmode) | ||
self.socket._io_refs += 1 | ||
if buffering is None: | ||
buffering = -1 | ||
if buffering < 0: | ||
buffering = io.DEFAULT_BUFFER_SIZE | ||
if buffering == 0: | ||
if not binary: | ||
raise ValueError("unbuffered streams must be binary") | ||
return raw | ||
if reading and writing: | ||
buffer = io.BufferedRWPair(raw, raw, buffering) | ||
elif reading: | ||
buffer = io.BufferedReader(raw, buffering) | ||
else: | ||
assert writing | ||
buffer = io.BufferedWriter(raw, buffering) | ||
if binary: | ||
return buffer | ||
text = io.TextIOWrapper(buffer, encoding, errors, newline) | ||
text.mode = mode | ||
return text | ||
|
||
def unwrap(self): | ||
self._ssl_io_loop(self.sslobj.unwrap) | ||
|
||
def close(self): | ||
self.socket.close() | ||
|
||
def getpeercert(self): | ||
return self.sslobj.getpeercert() | ||
|
||
def version(self): | ||
return self.sslobj.version() | ||
|
||
def cipher(self): | ||
return self.sslobj.cipher() | ||
|
||
def selected_alpn_protocol(self): | ||
return self.sslobj.selected_alpn_protocol() | ||
|
||
def selected_npn_protocol(self): | ||
return self.sslobj.selected_npn_protocol() | ||
|
||
def shared_ciphers(self): | ||
return self.sslobj.shared_ciphers() | ||
|
||
def compression(self): | ||
return self.sslobj.compression() | ||
|
||
def settimeout(self, value): | ||
self.socket.settimeout(value) | ||
|
||
def gettimeout(self): | ||
return self.socket.gettimeout() | ||
|
||
def _decref_socketios(self): | ||
self.socket._decref_socketios() | ||
|
||
def _wrap_ssl_read(self, len, buffer=None): | ||
response = None | ||
try: | ||
response = self._ssl_io_loop(self.sslobj.read, len, buffer) | ||
except SSLTransportError as e: | ||
if e.__context__.errno == ssl.SSL_ERROR_EOF and self.suppress_ragged_eofs: | ||
response = 0 # eof, return 0. | ||
else: | ||
raise | ||
return response | ||
|
||
def _ssl_io_loop(self, func, *args): | ||
""" Performs an I/O loop between incoming/outgoing and the socket.""" | ||
should_loop = True | ||
|
||
while should_loop: | ||
errno = None | ||
try: | ||
ret = func(*args) | ||
except ssl.SSLError as e: | ||
if e.errno not in (ssl.SSL_ERROR_WANT_READ, ssl.SSL_ERROR_WANT_WRITE): | ||
raise SSLTransportError(e) | ||
errno = e.errno | ||
|
||
buf = self.outgoing.read() | ||
self.socket.sendall(buf) | ||
|
||
if errno is None: | ||
should_loop = False | ||
elif errno == ssl.SSL_ERROR_WANT_READ: | ||
buf = self.socket.recv(4096) | ||
if buf: | ||
self.incoming.write(buf) | ||
else: | ||
self.incoming.write_eof() | ||
return ret |
Oops, something went wrong.