diff --git a/tests/protocols/test_websocket.py b/tests/protocols/test_websocket.py index a5fe93d5b5..21905e19dd 100644 --- a/tests/protocols/test_websocket.py +++ b/tests/protocols/test_websocket.py @@ -528,7 +528,6 @@ async def app(scope, receive, send): while True: message = await receive() if message["type"] == "websocket.connect": - print("accepted") await send({"type": "websocket.accept"}) elif message["type"] == "websocket.disconnect": break @@ -551,6 +550,40 @@ async def app(scope, receive, send): assert got_disconnect_event_before_shutdown is True +@pytest.mark.anyio +@pytest.mark.parametrize("ws_protocol_cls", WS_PROTOCOLS) +@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS) +async def test_not_accept_on_connection_lost(ws_protocol_cls, http_protocol_cls): + send_accept_task = asyncio.Event() + + async def app(scope, receive, send): + while True: + message = await receive() + if message["type"] == "websocket.connect": + await send_accept_task.wait() + await send({"type": "websocket.accept"}) + elif message["type"] == "websocket.disconnect": + break + + config = Config( + app=app, + ws=ws_protocol_cls, + http=http_protocol_cls, + lifespan="off", + ) + + async def websocket_session(uri): + async with websockets.client.connect(uri): + while True: + await asyncio.sleep(0.1) + + async with run_server(config): + task = asyncio.create_task(websocket_session("ws://127.0.0.1:8000")) + await asyncio.sleep(0.1) + task.cancel() + send_accept_task.set() + + @pytest.mark.anyio @pytest.mark.parametrize("ws_protocol_cls", WS_PROTOCOLS) @pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS) diff --git a/tests/utils.py b/tests/utils.py index 909064651d..1eba7733c2 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,5 +1,6 @@ import asyncio import os +import traceback from contextlib import asynccontextmanager, contextmanager from pathlib import Path @@ -13,6 +14,8 @@ async def run_server(config: Config, sockets=None): await asyncio.sleep(0.1) try: yield server + except BaseException: + traceback.print_exc() finally: await server.shutdown() task.cancel()