diff --git a/setup.cfg b/setup.cfg index 46f4e3b99..99c7a48c3 100644 --- a/setup.cfg +++ b/setup.cfg @@ -82,7 +82,7 @@ plugins = [coverage:report] precision = 2 -fail_under = 97.82 +fail_under = 97.97 show_missing = true skip_covered = true exclude_lines = diff --git a/tests/protocols/test_websocket.py b/tests/protocols/test_websocket.py index a5fe93d5b..bf84a8ae2 100644 --- a/tests/protocols/test_websocket.py +++ b/tests/protocols/test_websocket.py @@ -5,6 +5,7 @@ from tests.protocols.test_http import HTTP_PROTOCOLS from tests.utils import run_server +from uvicorn import Server from uvicorn.config import Config from uvicorn.protocols.websockets.wsproto_impl import WSProtocol @@ -790,3 +791,32 @@ async def open_connection(url): async with run_server(config): headers = await open_connection("ws://127.0.0.1:8000") assert headers.get_all("Server") == ["uvicorn", "over-ridden", "another-value"] + + +@pytest.mark.anyio +@pytest.mark.parametrize("ws_protocol_cls", WS_PROTOCOLS) +@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS) +async def test_server_shutdown_when_connection_active( + ws_protocol_cls, http_protocol_cls +): + class App(WebSocketResponse): + async def websocket_connect(self, message): + await self.send({"type": "websocket.accept"}) + + config = Config( + app=App, + ws=ws_protocol_cls, + http=http_protocol_cls, + lifespan="off", + ) + server = Server(config=config) + task = asyncio.create_task(server.serve(sockets=None)) + await asyncio.sleep(0.1) + async with websockets.connect("ws://127.0.0.1:8000") as websocket: + ws_conn = list(server.server_state.connections)[0] + ws_conn.shutdown() + await asyncio.sleep(0.1) + assert websocket.close_code == 1012 + assert ws_conn.transport.is_closing() + await server.shutdown() + task.cancel() diff --git a/uvicorn/protocols/websockets/websockets_impl.py b/uvicorn/protocols/websockets/websockets_impl.py index 01133b7e2..902140d80 100644 --- a/uvicorn/protocols/websockets/websockets_impl.py +++ b/uvicorn/protocols/websockets/websockets_impl.py @@ -141,7 +141,9 @@ def connection_lost(self, exc: Optional[Exception]) -> None: def shutdown(self) -> None: self.ws_server.closing = True - self.transport.close() + task = asyncio.create_task(self.close(code=1012)) + task.add_done_callback(self.on_task_complete) + self.tasks.add(task) def on_task_complete(self, task: asyncio.Task) -> None: self.tasks.discard(task)