Skip to content

Commit

Permalink
Fix tests that leak threads (pytest 8) (#3358)
Browse files Browse the repository at this point in the history
  • Loading branch information
ecerulm committed Mar 26, 2024
1 parent ea443d2 commit 8c20886
Show file tree
Hide file tree
Showing 5 changed files with 128 additions and 29 deletions.
2 changes: 1 addition & 1 deletion dev-requirements.txt
@@ -1,7 +1,7 @@
h2==4.1.0
coverage==7.4.1
PySocks==1.7.1
pytest==7.4.4
pytest==8.0.2
pytest-timeout==2.1.0
pyOpenSSL==24.0.0
idna==3.4
Expand Down
2 changes: 2 additions & 0 deletions dummyserver/socketserver.py
Expand Up @@ -108,13 +108,15 @@ def __init__(
socket_handler: typing.Callable[[socket.socket], None],
host: str = "localhost",
ready_event: threading.Event | None = None,
quit_event: threading.Event | None = None,
) -> None:
super().__init__()
self.daemon = True

self.socket_handler = socket_handler
self.host = host
self.ready_event = ready_event
self.quit_event = quit_event

def _start_server(self) -> None:
if self.USE_IPV6:
Expand Down
76 changes: 64 additions & 12 deletions dummyserver/testcase.py
Expand Up @@ -5,6 +5,7 @@
import ssl
import threading
import typing
from test import LONG_TIMEOUT

import hypercorn
import pytest
Expand All @@ -19,11 +20,19 @@


def consume_socket(
sock: SSLTransport | socket.socket, chunks: int = 65536
sock: SSLTransport | socket.socket,
chunks: int = 65536,
quit_event: threading.Event | None = None,
) -> bytearray:
consumed = bytearray()
sock.settimeout(LONG_TIMEOUT)
while True:
b = sock.recv(chunks)
if quit_event and quit_event.is_set():
break
try:
b = sock.recv(chunks)
except (TimeoutError, socket.timeout):
continue
assert isinstance(b, bytes)
consumed += b
if b.endswith(b"\r\n\r\n"):
Expand Down Expand Up @@ -57,11 +66,16 @@ class SocketDummyServerTestCase:

@classmethod
def _start_server(
cls, socket_handler: typing.Callable[[socket.socket], None]
cls,
socket_handler: typing.Callable[[socket.socket], None],
quit_event: threading.Event | None = None,
) -> None:
ready_event = threading.Event()
cls.server_thread = SocketServerThread(
socket_handler=socket_handler, ready_event=ready_event, host=cls.host
socket_handler=socket_handler,
ready_event=ready_event,
host=cls.host,
quit_event=quit_event,
)
cls.server_thread.start()
ready_event.wait(5)
Expand All @@ -71,23 +85,41 @@ def _start_server(

@classmethod
def start_response_handler(
cls, response: bytes, num: int = 1, block_send: threading.Event | None = None
cls,
response: bytes,
num: int = 1,
block_send: threading.Event | None = None,
) -> threading.Event:
ready_event = threading.Event()
quit_event = threading.Event()

def socket_handler(listener: socket.socket) -> None:
for _ in range(num):
ready_event.set()

sock = listener.accept()[0]
consume_socket(sock)
listener.settimeout(LONG_TIMEOUT)
while True:
if quit_event.is_set():
return
try:
sock = listener.accept()[0]
break
except (TimeoutError, socket.timeout):
continue
consume_socket(sock, quit_event=quit_event)
if quit_event.is_set():
sock.close()
return
if block_send:
block_send.wait()
while not block_send.wait(LONG_TIMEOUT):
if quit_event.is_set():
sock.close()
return
block_send.clear()
sock.send(response)
sock.close()

cls._start_server(socket_handler)
cls._start_server(socket_handler, quit_event=quit_event)
return ready_event

@classmethod
Expand All @@ -100,10 +132,25 @@ def start_basic_handler(
block_send,
)

@staticmethod
def quit_server_thread(server_thread: SocketServerThread) -> None:
if server_thread.quit_event:
server_thread.quit_event.set()
# in principle the maximum time that the thread can take to notice
# the quit_event is LONG_TIMEOUT and the thread should terminate
# shortly after that, we give 5 seconds leeway just in case
server_thread.join(LONG_TIMEOUT * 2 + 5.0)
if server_thread.is_alive():
raise Exception("server_thread did not exit")

@classmethod
def teardown_class(cls) -> None:
if hasattr(cls, "server_thread"):
cls.server_thread.join(0.1)
cls.quit_server_thread(cls.server_thread)

def teardown_method(self) -> None:
if hasattr(self, "server_thread"):
self.quit_server_thread(self.server_thread)

def assert_header_received(
self,
Expand All @@ -128,11 +175,16 @@ def assert_header_received(
class IPV4SocketDummyServerTestCase(SocketDummyServerTestCase):
@classmethod
def _start_server(
cls, socket_handler: typing.Callable[[socket.socket], None]
cls,
socket_handler: typing.Callable[[socket.socket], None],
quit_event: threading.Event | None = None,
) -> None:
ready_event = threading.Event()
cls.server_thread = SocketServerThread(
socket_handler=socket_handler, ready_event=ready_event, host=cls.host
socket_handler=socket_handler,
ready_event=ready_event,
host=cls.host,
quit_event=quit_event,
)
cls.server_thread.USE_IPV6 = False
cls.server_thread.start()
Expand Down
18 changes: 14 additions & 4 deletions test/test_ssltransport.py
Expand Up @@ -4,6 +4,7 @@
import select
import socket
import ssl
import threading
import typing
from unittest import mock

Expand Down Expand Up @@ -108,20 +109,29 @@ def setup_class(cls) -> None:
cls.server_context, cls.client_context = server_client_ssl_contexts()

def start_dummy_server(
self, handler: typing.Callable[[socket.socket], None] | None = None
self,
handler: typing.Callable[[socket.socket], None] | None = None,
validate: bool = True,
) -> None:
quit_event = threading.Event()

def socket_handler(listener: socket.socket) -> None:
sock = listener.accept()[0]
try:
with self.server_context.wrap_socket(sock, server_side=True) as ssock:
request = consume_socket(ssock)
request = consume_socket(
ssock,
quit_event=quit_event,
)
if not validate:
return
validate_request(request)
ssock.send(sample_response())
except (ConnectionAbortedError, ConnectionResetError):
return

chosen_handler = handler if handler else socket_handler
self._start_server(chosen_handler)
self._start_server(chosen_handler, quit_event=quit_event)

@pytest.mark.timeout(PER_TEST_TIMEOUT)
def test_start_closed_socket(self) -> None:
Expand All @@ -135,7 +145,7 @@ def test_start_closed_socket(self) -> None:
@pytest.mark.timeout(PER_TEST_TIMEOUT)
def test_close_after_handshake(self) -> None:
"""Socket errors should be bubbled up"""
self.start_dummy_server()
self.start_dummy_server(validate=False)

sock = socket.create_connection((self.host, self.port))
with SSLTransport(
Expand Down
59 changes: 47 additions & 12 deletions test/with_dummyserver/test_socketlevel.py
Expand Up @@ -12,6 +12,7 @@
import socket
import ssl
import tempfile
import threading
import typing
import zlib
from collections import OrderedDict
Expand Down Expand Up @@ -955,7 +956,11 @@ def socket_handler(listener: socket.socket) -> None:
assert response.connection is None

def test_socket_close_socket_then_file(self) -> None:
def consume_ssl_socket(listener: socket.socket) -> None:
quit_event = threading.Event()

def consume_ssl_socket(
listener: socket.socket,
) -> None:
try:
with listener.accept()[0] as sock, original_ssl_wrap_socket(
sock,
Expand All @@ -964,11 +969,11 @@ def consume_ssl_socket(listener: socket.socket) -> None:
certfile=DEFAULT_CERTS["certfile"],
ca_certs=DEFAULT_CA,
) as ssl_sock:
consume_socket(ssl_sock)
consume_socket(ssl_sock, quit_event=quit_event)
except (ConnectionResetError, ConnectionAbortedError, OSError):
pass

self._start_server(consume_ssl_socket)
self._start_server(consume_ssl_socket, quit_event=quit_event)
with socket.create_connection(
(self.host, self.port)
) as sock, contextlib.closing(
Expand All @@ -983,6 +988,8 @@ def consume_ssl_socket(listener: socket.socket) -> None:
assert ssl_sock.fileno() == -1

def test_socket_close_stays_open_with_makefile_open(self) -> None:
quit_event = threading.Event()

def consume_ssl_socket(listener: socket.socket) -> None:
try:
with listener.accept()[0] as sock, original_ssl_wrap_socket(
Expand All @@ -992,11 +999,11 @@ def consume_ssl_socket(listener: socket.socket) -> None:
certfile=DEFAULT_CERTS["certfile"],
ca_certs=DEFAULT_CA,
) as ssl_sock:
consume_socket(ssl_sock)
consume_socket(ssl_sock, quit_event=quit_event)
except (ConnectionResetError, ConnectionAbortedError, OSError):
pass

self._start_server(consume_ssl_socket)
self._start_server(consume_ssl_socket, quit_event=quit_event)
with socket.create_connection(
(self.host, self.port)
) as sock, contextlib.closing(
Expand Down Expand Up @@ -2232,11 +2239,28 @@ def socket_handler(listener: socket.socket) -> None:

class TestMultipartResponse(SocketDummyServerTestCase):
def test_multipart_assert_header_parsing_no_defects(self) -> None:
quit_event = threading.Event()

def socket_handler(listener: socket.socket) -> None:
for _ in range(2):
sock = listener.accept()[0]
while not sock.recv(65536).endswith(b"\r\n\r\n"):
pass
listener.settimeout(LONG_TIMEOUT)

while True:
if quit_event and quit_event.is_set():
return
try:
sock = listener.accept()[0]
break
except (TimeoutError, socket.timeout):
continue

sock.settimeout(LONG_TIMEOUT)
while True:
if quit_event and quit_event.is_set():
sock.close()
return
if sock.recv(65536).endswith(b"\r\n\r\n"):
break

sock.sendall(
b"HTTP/1.1 404 Not Found\r\n"
Expand All @@ -2252,7 +2276,7 @@ def socket_handler(listener: socket.socket) -> None:
)
sock.close()

self._start_server(socket_handler)
self._start_server(socket_handler, quit_event=quit_event)
from urllib3.connectionpool import log

with mock.patch.object(log, "warning") as log_warning:
Expand Down Expand Up @@ -2308,15 +2332,26 @@ def socket_handler(listener: socket.socket) -> None:
def test_chunked_specified(
self, method: str, chunked: bool, body_type: str
) -> None:
quit_event = threading.Event()
buffer = bytearray()
expected_bytes = b"\r\n\r\na\r\nxxxxxxxxxx\r\n0\r\n\r\n"

def socket_handler(listener: socket.socket) -> None:
nonlocal buffer
sock = listener.accept()[0]
sock.settimeout(0)
listener.settimeout(LONG_TIMEOUT)
while True:
if quit_event.is_set():
return
try:
sock = listener.accept()[0]
break
except (TimeoutError, socket.timeout):
continue
sock.settimeout(LONG_TIMEOUT)

while expected_bytes not in buffer:
if quit_event.is_set():
return
with contextlib.suppress(BlockingIOError):
buffer += sock.recv(65536)

Expand All @@ -2327,7 +2362,7 @@ def socket_handler(listener: socket.socket) -> None:
)
sock.close()

self._start_server(socket_handler)
self._start_server(socket_handler, quit_event=quit_event)

body: typing.Any
if body_type == "generator":
Expand Down

0 comments on commit 8c20886

Please sign in to comment.