Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add reason to WebSocket closure #1417

Merged
merged 20 commits into from Jan 22, 2022
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/websockets.md
Expand Up @@ -75,7 +75,7 @@ Use `websocket.receive_json(data, mode="binary")` to receive JSON over binary da

### Closing the connection

* `await websocket.close(code=1000)`
* `await websocket.close(code=1000, reason=None)`

### Sending and receiving messages

Expand Down
2 changes: 1 addition & 1 deletion starlette/testclient.py
Expand Up @@ -352,7 +352,7 @@ async def _asgi_send(self, message: Message) -> None:

def _raise_on_close(self, message: Message) -> None:
if message["type"] == "websocket.close":
raise WebSocketDisconnect(message.get("code", 1000))
raise WebSocketDisconnect(message.get("code", 1000), message.get("reason"))
aminalaee marked this conversation as resolved.
Show resolved Hide resolved

def send(self, message: Message) -> None:
self._receive_queue.put(message)
Expand Down
16 changes: 11 additions & 5 deletions starlette/websockets.py
Expand Up @@ -13,8 +13,9 @@ class WebSocketState(enum.Enum):


class WebSocketDisconnect(Exception):
def __init__(self, code: int = 1000) -> None:
def __init__(self, code: int = 1000, reason: str = None) -> None:
self.code = code
self.reason = reason or ""


class WebSocket(HTTPConnection):
Expand Down Expand Up @@ -144,13 +145,18 @@ async def send_json(self, data: typing.Any, mode: str = "text") -> None:
else:
await self.send({"type": "websocket.send", "bytes": text.encode("utf-8")})

async def close(self, code: int = 1000) -> None:
await self.send({"type": "websocket.close", "code": code})
async def close(self, code: int = 1000, reason: str = None) -> None:
await self.send(
{"type": "websocket.close", "code": code, "reason": reason or ""}
)


class WebSocketClose:
def __init__(self, code: int = 1000) -> None:
def __init__(self, code: int = 1000, reason: str = None) -> None:
self.code = code
self.reason = reason or ""

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
await send({"type": "websocket.close", "code": self.code})
await send(
{"type": "websocket.close", "code": self.code, "reason": self.reason}
)
17 changes: 17 additions & 0 deletions tests/test_websockets.py
Expand Up @@ -391,3 +391,20 @@ async def mock_send(message):
assert websocket == websocket
assert websocket in {websocket}
assert {websocket} == {websocket}


def test_websocket_close_reason(test_client_factory) -> None:
def app(scope):
async def asgi(receive, send):
websocket = WebSocket(scope, receive=receive, send=send)
await websocket.accept()
await websocket.close(code=status.WS_1001_GOING_AWAY, reason="Going Away")

return asgi

client = test_client_factory(app)
with client.websocket_connect("/") as websocket:
with pytest.raises(WebSocketDisconnect) as exc:
websocket.receive_text()
assert exc.value.code == status.WS_1001_GOING_AWAY
assert exc.value.reason == "Going Away"