diff --git a/eventlet/green/ssl.py b/eventlet/green/ssl.py index 10cff21d46..c49e872a10 100644 --- a/eventlet/green/ssl.py +++ b/eventlet/green/ssl.py @@ -404,14 +404,11 @@ def accept(self): new_ssl = type(self)( newsock, - keyfile=self.keyfile, - certfile=self.certfile, server_side=True, - cert_reqs=self.cert_reqs, - ssl_version=self.ssl_version, - ca_certs=self.ca_certs, do_handshake_on_connect=False, - suppress_ragged_eofs=self.suppress_ragged_eofs) + suppress_ragged_eofs=self.suppress_ragged_eofs, + _context=self._context, + ) return (new_ssl, addr) def dup(self): diff --git a/tests/ssl_test.py b/tests/ssl_test.py index d3e378068b..ea0cc858e8 100644 --- a/tests/ssl_test.py +++ b/tests/ssl_test.py @@ -1,4 +1,5 @@ import contextlib +import random import socket import warnings @@ -325,3 +326,48 @@ def accept(listener): server_to_client.close() listener.close() + + def test_context_wrapped_accept(self): + context = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + context.load_cert_chain(tests.certificate_file, tests.private_key_file) + expected = "success:{}".format(random.random()).encode() + + def client(addr): + client_tls = ssl.wrap_socket( + eventlet.connect(addr), + cert_reqs=ssl.CERT_REQUIRED, + ca_certs=tests.certificate_file, + ) + client_tls.send(expected) + + server_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + server_sock.bind(('localhost', 0)) + server_sock.listen(1) + eventlet.spawn(client, server_sock.getsockname()) + server_tls = context.wrap_socket(server_sock, server_side=True) + peer, _ = server_tls.accept() + assert peer.recv(64) == expected + peer.close() + + def test_explicit_keys_accept(self): + expected = "success:{}".format(random.random()).encode() + + def client(addr): + client_tls = ssl.wrap_socket( + eventlet.connect(addr), + cert_reqs=ssl.CERT_REQUIRED, + ca_certs=tests.certificate_file, + ) + client_tls.send(expected) + + server_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + server_sock.bind(('localhost', 0)) + server_sock.listen(1) + eventlet.spawn(client, server_sock.getsockname()) + server_tls = ssl.wrap_socket( + server_sock, server_side=True, + keyfile=tests.private_key_file, certfile=tests.certificate_file, + ) + peer, _ = server_tls.accept() + assert peer.recv(64) == expected + peer.close()