diff --git a/tests/protocols/test_websocket.py b/tests/protocols/test_websocket.py index d495f51086..add66aadec 100644 --- a/tests/protocols/test_websocket.py +++ b/tests/protocols/test_websocket.py @@ -555,15 +555,15 @@ async def app(scope, receive, send): @pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS) async def test_not_accept_on_connection_lost(ws_protocol_cls, http_protocol_cls): send_accept_task = asyncio.Event() + disconnect_message = {} async def app(scope, receive, send): - while True: - message = await receive() - if message["type"] == "websocket.connect": - await send_accept_task.wait() - await send({"type": "websocket.accept"}) - elif message["type"] == "websocket.disconnect": - break + nonlocal disconnect_message + message = await receive() + if message["type"] == "websocket.connect": + await send_accept_task.wait() + await send({"type": "websocket.accept"}) + disconnect_message = await receive() async def websocket_session(uri): async with websockets.client.connect(uri): @@ -577,6 +577,8 @@ async def websocket_session(uri): task.cancel() send_accept_task.set() + assert disconnect_message == {"type": "websocket.disconnect", "code": 1006} + @pytest.mark.anyio @pytest.mark.parametrize("ws_protocol_cls", WS_PROTOCOLS) @@ -729,6 +731,7 @@ async def test_server_can_read_messages_in_buffer_after_close( ws_protocol_cls, http_protocol_cls ): frames = [] + client_close_connection = asyncio.Event() class App(WebSocketResponse): async def websocket_connect(self, message): @@ -736,7 +739,7 @@ async def websocket_connect(self, message): # Ensure server doesn't start reading frames from read buffer until # after client has sent close frame, but server is still able to # read these frames - await asyncio.sleep(0.2) + await client_close_connection.wait() async def websocket_receive(self, message): frames.append(message.get("bytes")) @@ -750,6 +753,7 @@ async def send_text(url): config = Config(app=App, ws=ws_protocol_cls, http=http_protocol_cls, lifespan="off") async with run_server(config): await send_text("ws://127.0.0.1:8000") + client_close_connection.set() assert frames == [b"abc", b"abc", b"abc"] diff --git a/uvicorn/protocols/websockets/ws_impl.py b/uvicorn/protocols/websockets/ws_impl.py new file mode 100644 index 0000000000..b4bf4b8454 --- /dev/null +++ b/uvicorn/protocols/websockets/ws_impl.py @@ -0,0 +1,32 @@ +from __future__ import annotations + +import asyncio + +import websockets +import websockets.frames +from websockets.server import ServerConnection + +from uvicorn.config import Config +from uvicorn.server import ServerState + + +class WebSocketProtocol(asyncio.Protocol): + def __init__( + self, + config: Config, + server_state: ServerState, + _loop: asyncio.AbstractEventLoop | None = None, + ): + self.conn = ServerConnection() + + def data_received(self, data: bytes) -> None: + if data: + self.conn.receive_data(data) + else: + self.conn.receive_eof() + self.handle_events() + + def handle_events(self): + for event in self.conn.events_received(): + if isinstance(event, websockets.frames.Frame): + ... diff --git a/uvicorn/protocols/websockets/wsproto_impl.py b/uvicorn/protocols/websockets/wsproto_impl.py index a97766ff56..93d1d04836 100644 --- a/uvicorn/protocols/websockets/wsproto_impl.py +++ b/uvicorn/protocols/websockets/wsproto_impl.py @@ -70,7 +70,7 @@ def connection_made(self, transport): self.logger.log(TRACE_LOG_LEVEL, "%sWebSocket connection made", prefix) def connection_lost(self, exc): - self.queue.put_nowait({"type": "websocket.disconnect"}) + self.queue.put_nowait({"type": "websocket.disconnect", "code": 1005}) self.connections.remove(self) if self.logger.level <= TRACE_LOG_LEVEL: @@ -267,7 +267,7 @@ async def send(self, message): self.transport.write(output) elif message_type == "websocket.close": - self.queue.put_nowait({"type": "websocket.disconnect", "code": None}) + self.queue.put_nowait({"type": "websocket.disconnect", "code": 1006}) self.logger.info( '%s - "WebSocket %s" 403', self.scope["client"],