From 41156aaa4d6e442b72793c643c9ba8ac3c765c34 Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Sun, 20 Nov 2022 13:24:21 +0100 Subject: [PATCH] Use correct WebSocket error codes (#1753) --- tests/protocols/test_websocket.py | 30 ++++++++++++------- .../protocols/websockets/websockets_impl.py | 12 ++++++-- uvicorn/protocols/websockets/wsproto_impl.py | 8 +++-- 3 files changed, 33 insertions(+), 17 deletions(-) diff --git a/tests/protocols/test_websocket.py b/tests/protocols/test_websocket.py index d495f5108..7babe462f 100644 --- a/tests/protocols/test_websocket.py +++ b/tests/protocols/test_websocket.py @@ -553,22 +553,22 @@ async def app(scope, receive, send): @pytest.mark.anyio @pytest.mark.parametrize("ws_protocol_cls", WS_PROTOCOLS) @pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS) -async def test_not_accept_on_connection_lost(ws_protocol_cls, http_protocol_cls): +async def test_connection_lost_before_handshake_complete( + ws_protocol_cls, http_protocol_cls +): send_accept_task = asyncio.Event() + disconnect_message = {} async def app(scope, receive, send): - while True: - message = await receive() - if message["type"] == "websocket.connect": - await send_accept_task.wait() - await send({"type": "websocket.accept"}) - elif message["type"] == "websocket.disconnect": - break + nonlocal disconnect_message + message = await receive() + if message["type"] == "websocket.connect": + await send_accept_task.wait() + await send({"type": "websocket.accept"}) + disconnect_message = await receive() async def websocket_session(uri): - async with websockets.client.connect(uri): - while True: - await asyncio.sleep(0.1) + await websockets.client.connect(uri) config = Config(app=app, ws=ws_protocol_cls, http=http_protocol_cls, lifespan="off") async with run_server(config): @@ -577,6 +577,8 @@ async def websocket_session(uri): task.cancel() send_accept_task.set() + assert disconnect_message == {"type": "websocket.disconnect", "code": 1006} + @pytest.mark.anyio @pytest.mark.parametrize("ws_protocol_cls", WS_PROTOCOLS) @@ -729,6 +731,7 @@ async def test_server_can_read_messages_in_buffer_after_close( ws_protocol_cls, http_protocol_cls ): frames = [] + disconnect_message = {} class App(WebSocketResponse): async def websocket_connect(self, message): @@ -738,6 +741,10 @@ async def websocket_connect(self, message): # read these frames await asyncio.sleep(0.2) + async def websocket_disconnect(self, message): + nonlocal disconnect_message + disconnect_message = message + async def websocket_receive(self, message): frames.append(message.get("bytes")) @@ -752,6 +759,7 @@ async def send_text(url): await send_text("ws://127.0.0.1:8000") assert frames == [b"abc", b"abc", b"abc"] + assert disconnect_message == {"type": "websocket.disconnect", "code": 1000} @pytest.mark.anyio diff --git a/uvicorn/protocols/websockets/websockets_impl.py b/uvicorn/protocols/websockets/websockets_impl.py index 563a45eaa..cf1ee95fa 100644 --- a/uvicorn/protocols/websockets/websockets_impl.py +++ b/uvicorn/protocols/websockets/websockets_impl.py @@ -88,6 +88,7 @@ def __init__( self.closed_event = asyncio.Event() self.initial_response: Optional[HTTPResponse] = None self.connect_sent = False + self.lost_connection_before_handshake = False self.accepted_subprotocol: Optional[Subprotocol] = None self.transfer_data_task: asyncio.Task = None # type: ignore[assignment] @@ -134,6 +135,9 @@ def connection_lost(self, exc: Optional[Exception]) -> None: prefix = "%s:%d - " % self.client if self.client else "" self.logger.log(TRACE_LOG_LEVEL, "%sWebSocket connection lost", prefix) + self.lost_connection_before_handshake = ( + not self.handshake_completed_event.is_set() + ) self.handshake_completed_event.set() super().connection_lost(exc) if exc is None: @@ -335,11 +339,13 @@ async def asgi_receive( await self.handshake_completed_event.wait() - if self.closed_event.is_set(): - # If client disconnected, use WebSocketServerProtocol.close_code property. + if self.lost_connection_before_handshake: # If the handshake failed or the app closed before handshake completion, # use 1006 Abnormal Closure. - return {"type": "websocket.disconnect", "code": self.close_code or 1006} + return {"type": "websocket.disconnect", "code": 1006} + + if self.closed_event.is_set(): + return {"type": "websocket.disconnect", "code": 1005} try: data = await self.recv() diff --git a/uvicorn/protocols/websockets/wsproto_impl.py b/uvicorn/protocols/websockets/wsproto_impl.py index 97b31c196..3241a1551 100644 --- a/uvicorn/protocols/websockets/wsproto_impl.py +++ b/uvicorn/protocols/websockets/wsproto_impl.py @@ -70,13 +70,15 @@ def connection_made(self, transport): self.logger.log(TRACE_LOG_LEVEL, "%sWebSocket connection made", prefix) def connection_lost(self, exc): - self.queue.put_nowait({"type": "websocket.disconnect"}) + code = 1005 if self.handshake_complete else 1006 + self.queue.put_nowait({"type": "websocket.disconnect", "code": code}) self.connections.remove(self) if self.logger.level <= TRACE_LOG_LEVEL: prefix = "%s:%d - " % tuple(self.client) if self.client else "" self.logger.log(TRACE_LOG_LEVEL, "%sWebSocket connection lost", prefix) + self.handshake_complete = True if exc is None: self.transport.close() @@ -232,13 +234,13 @@ async def send(self, message): self.scope["client"], get_path_with_query_string(self.scope), ) - self.handshake_complete = True subprotocol = message.get("subprotocol") extra_headers = self.default_headers + list(message.get("headers", [])) extensions = [] if self.config.ws_per_message_deflate: extensions.append(PerMessageDeflate()) if not self.transport.is_closing(): + self.handshake_complete = True output = self.conn.send( wsproto.events.AcceptConnection( subprotocol=subprotocol, @@ -249,7 +251,7 @@ async def send(self, message): self.transport.write(output) elif message_type == "websocket.close": - self.queue.put_nowait({"type": "websocket.disconnect", "code": None}) + self.queue.put_nowait({"type": "websocket.disconnect", "code": 1006}) self.logger.info( '%s - "WebSocket %s" 403', self.scope["client"],