diff --git a/tests/protocols/test_websocket.py b/tests/protocols/test_websocket.py index 366a4f0bdb..23ce96c50b 100644 --- a/tests/protocols/test_websocket.py +++ b/tests/protocols/test_websocket.py @@ -732,6 +732,7 @@ async def test_server_can_read_messages_in_buffer_after_close( ): frames = [] client_close_connection = asyncio.Event() + disconnect_message = {} class App(WebSocketResponse): async def websocket_connect(self, message): @@ -741,6 +742,10 @@ async def websocket_connect(self, message): # read these frames await client_close_connection.wait() + async def websocket_disconnect(self, message): + nonlocal disconnect_message + disconnect_message = message + async def websocket_receive(self, message): frames.append(message.get("bytes")) @@ -753,9 +758,11 @@ async def send_text(url): config = Config(app=App, ws=ws_protocol_cls, http=http_protocol_cls, lifespan="off") async with run_server(config): await send_text("ws://127.0.0.1:8000") + await asyncio.sleep(0.1) client_close_connection.set() assert frames == [b"abc", b"abc", b"abc"] + assert disconnect_message == {"type": "websocket.disconnect", "code": 1000} @pytest.mark.anyio