diff --git a/docs/websockets.md b/docs/websockets.md index 43406aced..1128bce43 100644 --- a/docs/websockets.md +++ b/docs/websockets.md @@ -75,7 +75,7 @@ Use `websocket.receive_json(data, mode="binary")` to receive JSON over binary da ### Closing the connection -* `await websocket.close(code=1000)` +* `await websocket.close(code=1000, reason=None)` ### Sending and receiving messages diff --git a/starlette/testclient.py b/starlette/testclient.py index 0b4bc78d1..c951767b4 100644 --- a/starlette/testclient.py +++ b/starlette/testclient.py @@ -352,7 +352,9 @@ async def _asgi_send(self, message: Message) -> None: def _raise_on_close(self, message: Message) -> None: if message["type"] == "websocket.close": - raise WebSocketDisconnect(message.get("code", 1000)) + raise WebSocketDisconnect( + message.get("code", 1000), message.get("reason", "") + ) def send(self, message: Message) -> None: self._receive_queue.put(message) diff --git a/starlette/websockets.py b/starlette/websockets.py index bf4cca83f..da7406047 100644 --- a/starlette/websockets.py +++ b/starlette/websockets.py @@ -13,8 +13,9 @@ class WebSocketState(enum.Enum): class WebSocketDisconnect(Exception): - def __init__(self, code: int = 1000) -> None: + def __init__(self, code: int = 1000, reason: str = None) -> None: self.code = code + self.reason = reason or "" class WebSocket(HTTPConnection): @@ -146,13 +147,18 @@ async def send_json(self, data: typing.Any, mode: str = "text") -> None: else: await self.send({"type": "websocket.send", "bytes": text.encode("utf-8")}) - async def close(self, code: int = 1000) -> None: - await self.send({"type": "websocket.close", "code": code}) + async def close(self, code: int = 1000, reason: str = None) -> None: + await self.send( + {"type": "websocket.close", "code": code, "reason": reason or ""} + ) class WebSocketClose: - def __init__(self, code: int = 1000) -> None: + def __init__(self, code: int = 1000, reason: str = None) -> None: self.code = code + self.reason = reason or "" async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: - await send({"type": "websocket.close", "code": self.code}) + await send( + {"type": "websocket.close", "code": self.code, "reason": self.reason} + ) diff --git a/tests/test_websockets.py b/tests/test_websockets.py index f3242d115..b11685cbc 100644 --- a/tests/test_websockets.py +++ b/tests/test_websockets.py @@ -405,3 +405,20 @@ async def mock_send(message): assert websocket == websocket assert websocket in {websocket} assert {websocket} == {websocket} + + +def test_websocket_close_reason(test_client_factory) -> None: + def app(scope): + async def asgi(receive, send): + websocket = WebSocket(scope, receive=receive, send=send) + await websocket.accept() + await websocket.close(code=status.WS_1001_GOING_AWAY, reason="Going Away") + + return asgi + + client = test_client_factory(app) + with client.websocket_connect("/") as websocket: + with pytest.raises(WebSocketDisconnect) as exc: + websocket.receive_text() + assert exc.value.code == status.WS_1001_GOING_AWAY + assert exc.value.reason == "Going Away"