From 71b0541b48a5c64cccfc6fc332d634ebbd29092f Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Tue, 1 Nov 2022 08:39:37 +0100 Subject: [PATCH 1/4] Use correct WebSocket error codes --- tests/protocols/test_websocket.py | 20 ++++++++++++-------- uvicorn/protocols/websockets/wsproto_impl.py | 4 ++-- 2 files changed, 14 insertions(+), 10 deletions(-) diff --git a/tests/protocols/test_websocket.py b/tests/protocols/test_websocket.py index d495f5108..add66aade 100644 --- a/tests/protocols/test_websocket.py +++ b/tests/protocols/test_websocket.py @@ -555,15 +555,15 @@ async def app(scope, receive, send): @pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS) async def test_not_accept_on_connection_lost(ws_protocol_cls, http_protocol_cls): send_accept_task = asyncio.Event() + disconnect_message = {} async def app(scope, receive, send): - while True: - message = await receive() - if message["type"] == "websocket.connect": - await send_accept_task.wait() - await send({"type": "websocket.accept"}) - elif message["type"] == "websocket.disconnect": - break + nonlocal disconnect_message + message = await receive() + if message["type"] == "websocket.connect": + await send_accept_task.wait() + await send({"type": "websocket.accept"}) + disconnect_message = await receive() async def websocket_session(uri): async with websockets.client.connect(uri): @@ -577,6 +577,8 @@ async def websocket_session(uri): task.cancel() send_accept_task.set() + assert disconnect_message == {"type": "websocket.disconnect", "code": 1006} + @pytest.mark.anyio @pytest.mark.parametrize("ws_protocol_cls", WS_PROTOCOLS) @@ -729,6 +731,7 @@ async def test_server_can_read_messages_in_buffer_after_close( ws_protocol_cls, http_protocol_cls ): frames = [] + client_close_connection = asyncio.Event() class App(WebSocketResponse): async def websocket_connect(self, message): @@ -736,7 +739,7 @@ async def websocket_connect(self, message): # Ensure server doesn't start reading frames from read buffer until # after client has sent close frame, but server is still able to # read these frames - await asyncio.sleep(0.2) + await client_close_connection.wait() async def websocket_receive(self, message): frames.append(message.get("bytes")) @@ -750,6 +753,7 @@ async def send_text(url): config = Config(app=App, ws=ws_protocol_cls, http=http_protocol_cls, lifespan="off") async with run_server(config): await send_text("ws://127.0.0.1:8000") + client_close_connection.set() assert frames == [b"abc", b"abc", b"abc"] diff --git a/uvicorn/protocols/websockets/wsproto_impl.py b/uvicorn/protocols/websockets/wsproto_impl.py index a97766ff5..93d1d0483 100644 --- a/uvicorn/protocols/websockets/wsproto_impl.py +++ b/uvicorn/protocols/websockets/wsproto_impl.py @@ -70,7 +70,7 @@ def connection_made(self, transport): self.logger.log(TRACE_LOG_LEVEL, "%sWebSocket connection made", prefix) def connection_lost(self, exc): - self.queue.put_nowait({"type": "websocket.disconnect"}) + self.queue.put_nowait({"type": "websocket.disconnect", "code": 1005}) self.connections.remove(self) if self.logger.level <= TRACE_LOG_LEVEL: @@ -267,7 +267,7 @@ async def send(self, message): self.transport.write(output) elif message_type == "websocket.close": - self.queue.put_nowait({"type": "websocket.disconnect", "code": None}) + self.queue.put_nowait({"type": "websocket.disconnect", "code": 1006}) self.logger.info( '%s - "WebSocket %s" 403', self.scope["client"], From 91b0633fbf4d7cf6215d859acd808c37b5814bda Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Tue, 1 Nov 2022 09:38:18 +0100 Subject: [PATCH 2/4] Fix all codes --- tests/protocols/test_websocket.py | 4 +++- uvicorn/protocols/websockets/websockets_impl.py | 12 +++++++++--- uvicorn/protocols/websockets/wsproto_impl.py | 6 ++++-- 3 files changed, 16 insertions(+), 6 deletions(-) diff --git a/tests/protocols/test_websocket.py b/tests/protocols/test_websocket.py index add66aade..70ced34db 100644 --- a/tests/protocols/test_websocket.py +++ b/tests/protocols/test_websocket.py @@ -553,7 +553,9 @@ async def app(scope, receive, send): @pytest.mark.anyio @pytest.mark.parametrize("ws_protocol_cls", WS_PROTOCOLS) @pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS) -async def test_not_accept_on_connection_lost(ws_protocol_cls, http_protocol_cls): +async def test_connection_lost_before_handshake_complete( + ws_protocol_cls, http_protocol_cls +): send_accept_task = asyncio.Event() disconnect_message = {} diff --git a/uvicorn/protocols/websockets/websockets_impl.py b/uvicorn/protocols/websockets/websockets_impl.py index 87e7baa36..08a47b8ac 100644 --- a/uvicorn/protocols/websockets/websockets_impl.py +++ b/uvicorn/protocols/websockets/websockets_impl.py @@ -88,6 +88,7 @@ def __init__( self.closed_event = asyncio.Event() self.initial_response: Optional[HTTPResponse] = None self.connect_sent = False + self.lost_connection_before_handshake = False self.accepted_subprotocol: Optional[Subprotocol] = None self.transfer_data_task: asyncio.Task = None # type: ignore[assignment] @@ -134,6 +135,9 @@ def connection_lost(self, exc: Optional[Exception]) -> None: prefix = "%s:%d - " % self.client if self.client else "" self.logger.log(TRACE_LOG_LEVEL, "%sWebSocket connection lost", prefix) + self.lost_connection_before_handshake = ( + not self.handshake_completed_event.is_set() + ) self.handshake_completed_event.set() super().connection_lost(exc) if exc is None: @@ -335,11 +339,13 @@ async def asgi_receive( await self.handshake_completed_event.wait() - if self.closed_event.is_set(): - # If client disconnected, use WebSocketServerProtocol.close_code property. + if self.lost_connection_before_handshake: # If the handshake failed or the app closed before handshake completion, # use 1006 Abnormal Closure. - return {"type": "websocket.disconnect", "code": self.close_code or 1006} + return {"type": "websocket.disconnect", "code": 1006} + + if self.closed_event.is_set(): + return {"type": "websocket.disconnect", "code": 1005} try: data = await self.recv() diff --git a/uvicorn/protocols/websockets/wsproto_impl.py b/uvicorn/protocols/websockets/wsproto_impl.py index 93d1d0483..9a64ffa22 100644 --- a/uvicorn/protocols/websockets/wsproto_impl.py +++ b/uvicorn/protocols/websockets/wsproto_impl.py @@ -70,13 +70,15 @@ def connection_made(self, transport): self.logger.log(TRACE_LOG_LEVEL, "%sWebSocket connection made", prefix) def connection_lost(self, exc): - self.queue.put_nowait({"type": "websocket.disconnect", "code": 1005}) + code = 1005 if self.handshake_complete else 1006 + self.queue.put_nowait({"type": "websocket.disconnect", "code": code}) self.connections.remove(self) if self.logger.level <= TRACE_LOG_LEVEL: prefix = "%s:%d - " % tuple(self.client) if self.client else "" self.logger.log(TRACE_LOG_LEVEL, "%sWebSocket connection lost", prefix) + self.handshake_complete = True if exc is None: self.transport.close() @@ -250,13 +252,13 @@ async def send(self, message): self.scope["client"], get_path_with_query_string(self.scope), ) - self.handshake_complete = True subprotocol = message.get("subprotocol") extra_headers = self.default_headers + list(message.get("headers", [])) extensions = [] if self.config.ws_per_message_deflate: extensions.append(PerMessageDeflate()) if not self.transport.is_closing(): + self.handshake_complete = True output = self.conn.send( wsproto.events.AcceptConnection( subprotocol=subprotocol, From e5626330b566801675f10a9a993b595d1d966427 Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Tue, 1 Nov 2022 09:49:47 +0100 Subject: [PATCH 3/4] Remove uncovered lines on tests --- tests/protocols/test_websocket.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/protocols/test_websocket.py b/tests/protocols/test_websocket.py index 70ced34db..366a4f0bd 100644 --- a/tests/protocols/test_websocket.py +++ b/tests/protocols/test_websocket.py @@ -568,9 +568,7 @@ async def app(scope, receive, send): disconnect_message = await receive() async def websocket_session(uri): - async with websockets.client.connect(uri): - while True: - await asyncio.sleep(0.1) + await websockets.client.connect(uri) config = Config(app=app, ws=ws_protocol_cls, http=http_protocol_cls, lifespan="off") async with run_server(config): From 3f91b005736b7cbb31e57af8bdd8eec4d46931b6 Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Tue, 1 Nov 2022 12:36:28 +0100 Subject: [PATCH 4/4] Check the disconncet code --- tests/protocols/test_websocket.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/tests/protocols/test_websocket.py b/tests/protocols/test_websocket.py index 366a4f0bd..7babe462f 100644 --- a/tests/protocols/test_websocket.py +++ b/tests/protocols/test_websocket.py @@ -731,7 +731,7 @@ async def test_server_can_read_messages_in_buffer_after_close( ws_protocol_cls, http_protocol_cls ): frames = [] - client_close_connection = asyncio.Event() + disconnect_message = {} class App(WebSocketResponse): async def websocket_connect(self, message): @@ -739,7 +739,11 @@ async def websocket_connect(self, message): # Ensure server doesn't start reading frames from read buffer until # after client has sent close frame, but server is still able to # read these frames - await client_close_connection.wait() + await asyncio.sleep(0.2) + + async def websocket_disconnect(self, message): + nonlocal disconnect_message + disconnect_message = message async def websocket_receive(self, message): frames.append(message.get("bytes")) @@ -753,9 +757,9 @@ async def send_text(url): config = Config(app=App, ws=ws_protocol_cls, http=http_protocol_cls, lifespan="off") async with run_server(config): await send_text("ws://127.0.0.1:8000") - client_close_connection.set() assert frames == [b"abc", b"abc", b"abc"] + assert disconnect_message == {"type": "websocket.disconnect", "code": 1000} @pytest.mark.anyio