From 91b0633fbf4d7cf6215d859acd808c37b5814bda Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Tue, 1 Nov 2022 09:38:18 +0100 Subject: [PATCH] Fix all codes --- tests/protocols/test_websocket.py | 4 +++- uvicorn/protocols/websockets/websockets_impl.py | 12 +++++++++--- uvicorn/protocols/websockets/wsproto_impl.py | 6 ++++-- 3 files changed, 16 insertions(+), 6 deletions(-) diff --git a/tests/protocols/test_websocket.py b/tests/protocols/test_websocket.py index add66aade..70ced34db 100644 --- a/tests/protocols/test_websocket.py +++ b/tests/protocols/test_websocket.py @@ -553,7 +553,9 @@ async def app(scope, receive, send): @pytest.mark.anyio @pytest.mark.parametrize("ws_protocol_cls", WS_PROTOCOLS) @pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS) -async def test_not_accept_on_connection_lost(ws_protocol_cls, http_protocol_cls): +async def test_connection_lost_before_handshake_complete( + ws_protocol_cls, http_protocol_cls +): send_accept_task = asyncio.Event() disconnect_message = {} diff --git a/uvicorn/protocols/websockets/websockets_impl.py b/uvicorn/protocols/websockets/websockets_impl.py index 87e7baa36..08a47b8ac 100644 --- a/uvicorn/protocols/websockets/websockets_impl.py +++ b/uvicorn/protocols/websockets/websockets_impl.py @@ -88,6 +88,7 @@ def __init__( self.closed_event = asyncio.Event() self.initial_response: Optional[HTTPResponse] = None self.connect_sent = False + self.lost_connection_before_handshake = False self.accepted_subprotocol: Optional[Subprotocol] = None self.transfer_data_task: asyncio.Task = None # type: ignore[assignment] @@ -134,6 +135,9 @@ def connection_lost(self, exc: Optional[Exception]) -> None: prefix = "%s:%d - " % self.client if self.client else "" self.logger.log(TRACE_LOG_LEVEL, "%sWebSocket connection lost", prefix) + self.lost_connection_before_handshake = ( + not self.handshake_completed_event.is_set() + ) self.handshake_completed_event.set() super().connection_lost(exc) if exc is None: @@ -335,11 +339,13 @@ async def asgi_receive( await self.handshake_completed_event.wait() - if self.closed_event.is_set(): - # If client disconnected, use WebSocketServerProtocol.close_code property. + if self.lost_connection_before_handshake: # If the handshake failed or the app closed before handshake completion, # use 1006 Abnormal Closure. - return {"type": "websocket.disconnect", "code": self.close_code or 1006} + return {"type": "websocket.disconnect", "code": 1006} + + if self.closed_event.is_set(): + return {"type": "websocket.disconnect", "code": 1005} try: data = await self.recv() diff --git a/uvicorn/protocols/websockets/wsproto_impl.py b/uvicorn/protocols/websockets/wsproto_impl.py index 93d1d0483..9a64ffa22 100644 --- a/uvicorn/protocols/websockets/wsproto_impl.py +++ b/uvicorn/protocols/websockets/wsproto_impl.py @@ -70,13 +70,15 @@ def connection_made(self, transport): self.logger.log(TRACE_LOG_LEVEL, "%sWebSocket connection made", prefix) def connection_lost(self, exc): - self.queue.put_nowait({"type": "websocket.disconnect", "code": 1005}) + code = 1005 if self.handshake_complete else 1006 + self.queue.put_nowait({"type": "websocket.disconnect", "code": code}) self.connections.remove(self) if self.logger.level <= TRACE_LOG_LEVEL: prefix = "%s:%d - " % tuple(self.client) if self.client else "" self.logger.log(TRACE_LOG_LEVEL, "%sWebSocket connection lost", prefix) + self.handshake_complete = True if exc is None: self.transport.close() @@ -250,13 +252,13 @@ async def send(self, message): self.scope["client"], get_path_with_query_string(self.scope), ) - self.handshake_complete = True subprotocol = message.get("subprotocol") extra_headers = self.default_headers + list(message.get("headers", [])) extensions = [] if self.config.ws_per_message_deflate: extensions.append(PerMessageDeflate()) if not self.transport.is_closing(): + self.handshake_complete = True output = self.conn.send( wsproto.events.AcceptConnection( subprotocol=subprotocol,