From ec3aac3b00f43ad9f1a2573dea55e5dc12ec7320 Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Mon, 31 Oct 2022 12:06:49 +0100 Subject: [PATCH] Check if handshake is completed before sending frame on wsproto shutdown (#1737) * Check if handshake is completed before sending frame on wsproto shutdown * Add test for connection lost before handshake is completed * Add test for close on shutdown * Increase fail-under to 97.87 * Increase coverage * Apply suggestions from code review --- setup.cfg | 2 +- tests/protocols/test_websocket.py | 61 ++++++++++++++++++- .../protocols/websockets/websockets_impl.py | 2 + uvicorn/protocols/websockets/wsproto_impl.py | 29 +++++---- 4 files changed, 79 insertions(+), 15 deletions(-) diff --git a/setup.cfg b/setup.cfg index 46f4e3b99..ac52be52e 100644 --- a/setup.cfg +++ b/setup.cfg @@ -82,7 +82,7 @@ plugins = [coverage:report] precision = 2 -fail_under = 97.82 +fail_under = 97.92 show_missing = true skip_covered = true exclude_lines = diff --git a/tests/protocols/test_websocket.py b/tests/protocols/test_websocket.py index a5fe93d5b..d495f5108 100644 --- a/tests/protocols/test_websocket.py +++ b/tests/protocols/test_websocket.py @@ -528,7 +528,6 @@ async def app(scope, receive, send): while True: message = await receive() if message["type"] == "websocket.connect": - print("accepted") await send({"type": "websocket.accept"}) elif message["type"] == "websocket.disconnect": break @@ -551,6 +550,66 @@ async def app(scope, receive, send): assert got_disconnect_event_before_shutdown is True +@pytest.mark.anyio +@pytest.mark.parametrize("ws_protocol_cls", WS_PROTOCOLS) +@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS) +async def test_not_accept_on_connection_lost(ws_protocol_cls, http_protocol_cls): + send_accept_task = asyncio.Event() + + async def app(scope, receive, send): + while True: + message = await receive() + if message["type"] == "websocket.connect": + await send_accept_task.wait() + await send({"type": "websocket.accept"}) + elif message["type"] == "websocket.disconnect": + break + + async def websocket_session(uri): + async with websockets.client.connect(uri): + while True: + await asyncio.sleep(0.1) + + config = Config(app=app, ws=ws_protocol_cls, http=http_protocol_cls, lifespan="off") + async with run_server(config): + task = asyncio.create_task(websocket_session("ws://127.0.0.1:8000")) + await asyncio.sleep(0.1) + task.cancel() + send_accept_task.set() + + +@pytest.mark.anyio +@pytest.mark.parametrize("ws_protocol_cls", WS_PROTOCOLS) +@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS) +async def test_send_close_on_server_shutdown(ws_protocol_cls, http_protocol_cls): + disconnect_message = {} + + async def app(scope, receive, send): + nonlocal disconnect_message + while True: + message = await receive() + if message["type"] == "websocket.connect": + await send({"type": "websocket.accept"}) + elif message["type"] == "websocket.disconnect": + disconnect_message = message + break + + async def websocket_session(uri): + async with websockets.client.connect(uri): + while True: + await asyncio.sleep(0.1) + + config = Config(app=app, ws=ws_protocol_cls, http=http_protocol_cls, lifespan="off") + async with run_server(config): + task = asyncio.create_task(websocket_session("ws://127.0.0.1:8000")) + await asyncio.sleep(0.1) + disconnect_message_before_shutdown = disconnect_message + + assert disconnect_message_before_shutdown == {} + assert disconnect_message == {"type": "websocket.disconnect", "code": 1012} + task.cancel() + + @pytest.mark.anyio @pytest.mark.parametrize("ws_protocol_cls", WS_PROTOCOLS) @pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS) diff --git a/uvicorn/protocols/websockets/websockets_impl.py b/uvicorn/protocols/websockets/websockets_impl.py index 01133b7e2..87e7baa36 100644 --- a/uvicorn/protocols/websockets/websockets_impl.py +++ b/uvicorn/protocols/websockets/websockets_impl.py @@ -345,6 +345,8 @@ async def asgi_receive( data = await self.recv() except ConnectionClosed as exc: self.closed_event.set() + if self.ws_server.closing: + return {"type": "websocket.disconnect", "code": 1012} return {"type": "websocket.disconnect", "code": exc.code} msg: WebSocketReceiveEvent = { # type: ignore[typeddict-item] diff --git a/uvicorn/protocols/websockets/wsproto_impl.py b/uvicorn/protocols/websockets/wsproto_impl.py index 65af7d308..a97766ff5 100644 --- a/uvicorn/protocols/websockets/wsproto_impl.py +++ b/uvicorn/protocols/websockets/wsproto_impl.py @@ -122,9 +122,12 @@ def resume_writing(self): self.writable.set() def shutdown(self): - self.queue.put_nowait({"type": "websocket.disconnect", "code": 1012}) - output = self.conn.send(wsproto.events.CloseConnection(code=1012)) - self.transport.write(output) + if self.handshake_complete: + self.queue.put_nowait({"type": "websocket.disconnect", "code": 1012}) + output = self.conn.send(wsproto.events.CloseConnection(code=1012)) + self.transport.write(output) + else: + self.send_500_response() self.transport.close() def on_task_complete(self, task): @@ -219,9 +222,8 @@ def send_500_response(self): async def run_asgi(self): try: result = await self.app(self.scope, self.receive, self.send) - except BaseException as exc: - msg = "Exception in ASGI application\n" - self.logger.error(msg, exc_info=exc) + except BaseException: + self.logger.exception("Exception in ASGI application\n") if not self.handshake_complete: self.send_500_response() self.transport.close() @@ -254,14 +256,15 @@ async def send(self, message): extensions = [] if self.config.ws_per_message_deflate: extensions.append(PerMessageDeflate()) - output = self.conn.send( - wsproto.events.AcceptConnection( - subprotocol=subprotocol, - extensions=extensions, - extra_headers=extra_headers, + if not self.transport.is_closing(): + output = self.conn.send( + wsproto.events.AcceptConnection( + subprotocol=subprotocol, + extensions=extensions, + extra_headers=extra_headers, + ) ) - ) - self.transport.write(output) + self.transport.write(output) elif message_type == "websocket.close": self.queue.put_nowait({"type": "websocket.disconnect", "code": None})