diff --git a/tests/test_trustme.py b/tests/test_trustme.py index 68d61e6c..762b41df 100644 --- a/tests/test_trustme.py +++ b/tests/test_trustme.py @@ -7,6 +7,7 @@ import socket import threading import datetime +import contextlib from concurrent.futures import ThreadPoolExecutor from ipaddress import IPv4Address, IPv6Address, IPv4Network, IPv6Network @@ -254,28 +255,24 @@ def check_connection_end_to_end(wrap_client, wrap_server): # Client side def fake_ssl_client(ca, raw_client_sock, hostname): try: - wrapped_client_sock = wrap_client(ca, raw_client_sock, hostname) - # Send and receive some data to prove the connection is good - wrapped_client_sock.send(b"x") - assert wrapped_client_sock.recv(1) == b"y" + with contextlib.closing(wrap_client(ca, raw_client_sock, hostname)) as wrapped_client_sock: + # Send and receive some data to prove the connection is good + wrapped_client_sock.send(b"x") + assert wrapped_client_sock.recv(1) == b"y" except: # pragma: no cover sys.excepthook(*sys.exc_info()) raise - finally: - raw_client_sock.close() # Server side def fake_ssl_server(server_cert, raw_server_sock): try: - wrapped_server_sock = wrap_server(server_cert, raw_server_sock) - # Prove that we're connected - assert wrapped_server_sock.recv(1) == b"x" - wrapped_server_sock.send(b"y") + with contextlib.closing(wrap_server(server_cert, raw_server_sock)) as wrapped_server_sock: + # Prove that we're connected + assert wrapped_server_sock.recv(1) == b"x" + wrapped_server_sock.send(b"y") except: # pragma: no cover sys.excepthook(*sys.exc_info()) raise - finally: - raw_server_sock.close() def doit(ca, hostname, server_cert): # socketpair and ssl don't work together on py2, because... reasons. @@ -284,14 +281,16 @@ def doit(ca, hostname, server_cert): listener.bind(("127.0.0.1", 0)) listener.listen(1) raw_client_sock = socket.socket() - raw_client_sock.connect(listener.getsockname()) - raw_server_sock, _ = listener.accept() - listener.close() - with ThreadPoolExecutor(2) as tpe: - f1 = tpe.submit(fake_ssl_client, ca, raw_client_sock, hostname) - f2 = tpe.submit(fake_ssl_server, server_cert, raw_server_sock) - f1.result() - f2.result() + with contextlib.closing(socket.socket()) as raw_client_sock: + raw_client_sock.connect(listener.getsockname()) + raw_server_sock, _ = listener.accept() + with contextlib.closing(raw_server_sock): + listener.close() + with ThreadPoolExecutor(2) as tpe: + f1 = tpe.submit(fake_ssl_client, ca, raw_client_sock, hostname) + f2 = tpe.submit(fake_ssl_server, server_cert, raw_server_sock) + f1.result() + f2.result() ca = CA() intermediate_ca = ca.create_child_ca()