Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add WebSocketException and support for WS handlers #1263

Merged
merged 20 commits into from Sep 5, 2022
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
16 changes: 16 additions & 0 deletions docs/exceptions.md
Expand Up @@ -62,6 +62,14 @@ async def http_exception(request: Request, exc: HTTPException):
)
```

You might also want to override how `WebSocketException` is handled:

```python
@app.exception_handler(WebSocketException)
async def websocket_exception(websocket, exc):
await websocket.close(code=1008)
```

## Errors and handled exceptions

It is important to differentiate between handled exceptions and errors.
Expand Down Expand Up @@ -94,3 +102,11 @@ returning plain-text HTTP responses for any `HTTPException`.

You should only raise `HTTPException` inside routing or endpoints. Middleware
classes should instead just return appropriate responses directly.

## WebSocketException

You can use the `WebSocketException` class to raise errors inside of WebSocket endpoints.

* `WebSocketException(code=1008)`

You can set any code valid as defined [in the specification](https://tools.ietf.org/html/rfc6455#section-7.4.1).
41 changes: 33 additions & 8 deletions starlette/exceptions.py
Expand Up @@ -6,6 +6,7 @@
from starlette.requests import Request
from starlette.responses import PlainTextResponse, Response
from starlette.types import ASGIApp, Message, Receive, Scope, Send
from starlette.websockets import WebSocket


class HTTPException(Exception):
Expand All @@ -23,6 +24,15 @@ def __repr__(self) -> str:
return f"{class_name}(status_code={self.status_code!r}, detail={self.detail!r})"


class WebSocketException(Exception):
def __init__(self, code: int) -> None:
Kludex marked this conversation as resolved.
Show resolved Hide resolved
self.code = code

def __repr__(self) -> str:
class_name = self.__class__.__name__
return f"{class_name}(code={self.code!r})"


class ExceptionMiddleware:
def __init__(
self,
Expand All @@ -37,7 +47,10 @@ def __init__(
self._status_handlers: typing.Dict[int, typing.Callable] = {}
self._exception_handlers: typing.Dict[
typing.Type[Exception], typing.Callable
] = {HTTPException: self.http_exception}
] = {
HTTPException: self.http_exception,
WebSocketException: self.websocket_exception,
}
if handlers is not None:
for key, value in handlers.items():
self.add_exception_handler(key, value)
Expand All @@ -62,7 +75,7 @@ def _lookup_exception_handler(
return None

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope["type"] != "http":
if scope["type"] not in {"http", "websocket"}:
await self.app(scope, receive, send)
return

Expand Down Expand Up @@ -93,16 +106,28 @@ async def sender(message: Message) -> None:
msg = "Caught handled exception, but response already started."
raise RuntimeError(msg) from exc

request = Request(scope, receive=receive)
if asyncio.iscoroutinefunction(handler):
response = await handler(request, exc)
else:
response = await run_in_threadpool(handler, request, exc)
await response(scope, receive, sender)
if scope["type"] == "http":
request = Request(scope, receive=receive)
if asyncio.iscoroutinefunction(handler):
response = await handler(request, exc)
else:
response = await run_in_threadpool(handler, request, exc)
await response(scope, receive, sender)
elif scope["type"] == "websocket":
websocket = WebSocket(scope, receive=receive, send=send)
if asyncio.iscoroutinefunction(handler):
await handler(websocket, exc)
else:
await run_in_threadpool(handler, websocket, exc)

def http_exception(self, request: Request, exc: HTTPException) -> Response:
if exc.status_code in {204, 304}:
return Response(status_code=exc.status_code, headers=exc.headers)
return PlainTextResponse(
exc.detail, status_code=exc.status_code, headers=exc.headers
)

async def websocket_exception(
self, websocket: WebSocket, exc: WebSocketException
) -> None:
await websocket.close(code=exc.code)
13 changes: 11 additions & 2 deletions starlette/middleware/errors.py
Expand Up @@ -4,10 +4,12 @@
import traceback
import typing

from starlette import status
from starlette.concurrency import run_in_threadpool
from starlette.requests import Request
from starlette.responses import HTMLResponse, PlainTextResponse, Response
from starlette.types import ASGIApp, Message, Receive, Scope, Send
from starlette.websockets import WebSocket

STYLES = """
p {
Expand Down Expand Up @@ -142,7 +144,7 @@ def __init__(
self.debug = debug

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope["type"] != "http":
if scope["type"] not in {"http", "websocket"}:
await self.app(scope, receive, send)
return

Expand All @@ -158,7 +160,7 @@ async def _send(message: Message) -> None:
try:
await self.app(scope, receive, _send)
except Exception as exc:
if not response_started:
if scope["type"] == "http" and not response_started:
request = Request(scope)
if self.debug:
# In debug mode, return traceback responses.
Expand All @@ -174,6 +176,13 @@ async def _send(message: Message) -> None:
response = await run_in_threadpool(self.handler, request, exc)

await response(scope, receive, send)
elif scope["type"] == "websocket":
websocket = WebSocket(scope, receive, send)
# https://tools.ietf.org/html/rfc6455#section-7.4.1
# 1011 indicates that a server is terminating the connection because
# it encountered an unexpected condition that prevented it from
# fulfilling the request.
await websocket.close(code=status.WS_1011_INTERNAL_ERROR)

# We always continue to raise the exception.
# This allows servers to log the error, or allows test clients
Expand Down
9 changes: 3 additions & 6 deletions tests/middleware/test_errors.py
Expand Up @@ -2,6 +2,7 @@

from starlette.middleware.errors import ServerErrorMiddleware
from starlette.responses import JSONResponse, Response
from starlette.websockets import WebSocketDisconnect


def test_handler(test_client_factory):
Expand Down Expand Up @@ -54,17 +55,13 @@ async def app(scope, receive, send):
client.get("/")


def test_debug_not_http(test_client_factory):
"""
DebugMiddleware should just pass through any non-http messages as-is.
"""

def test_debug_websocket(test_client_factory):
async def app(scope, receive, send):
raise RuntimeError("Something went wrong")

app = ServerErrorMiddleware(app)

with pytest.raises(RuntimeError):
with pytest.raises(WebSocketDisconnect):
client = test_client_factory(app)
with client.websocket_connect("/"):
pass # pragma: nocover
50 changes: 49 additions & 1 deletion tests/test_applications.py
@@ -1,15 +1,18 @@
import os
import sys

import anyio
import pytest

from starlette import status
from starlette.applications import Starlette
from starlette.endpoints import HTTPEndpoint
from starlette.exceptions import HTTPException
from starlette.exceptions import HTTPException, WebSocketException
from starlette.middleware.trustedhost import TrustedHostMiddleware
from starlette.responses import JSONResponse, PlainTextResponse
from starlette.routing import Host, Mount, Route, Router, WebSocketRoute
from starlette.staticfiles import StaticFiles
from starlette.websockets import WebSocket

if sys.version_info >= (3, 7):
from contextlib import asynccontextmanager # pragma: no cover
Expand Down Expand Up @@ -93,6 +96,27 @@ async def websocket_endpoint(session):
await session.close()


@app.websocket_route("/ws-raise-websocket")
async def websocket_raise_websocket_exception(websocket: WebSocket):
await websocket.accept()
raise WebSocketException(code=status.WS_1003_UNSUPPORTED_DATA)


class CustomWSException(Exception):
pass


@app.websocket_route("/ws-raise-custom")
async def websocket_raise_custom(websocket: WebSocket):
await websocket.accept()
raise CustomWSException()


@app.exception_handler(CustomWSException)
def custom_ws_exception_handler(websocket: WebSocket, exc: CustomWSException):
anyio.from_thread.run(websocket.close, status.WS_1013_TRY_AGAIN_LATER)


@pytest.fixture
def client(test_client_factory):
with test_client_factory(app) as client:
Expand Down Expand Up @@ -174,6 +198,26 @@ def test_500(test_client_factory):
assert response.json() == {"detail": "Server Error"}


def test_websocket_raise_websocket_exception(client):
with client.websocket_connect("/ws-raise-websocket") as session:
response = session.receive()
assert response == {
"type": "websocket.close",
"code": status.WS_1003_UNSUPPORTED_DATA,
"reason": "",
}


def test_websocket_raise_custom_exception(client):
with client.websocket_connect("/ws-raise-custom") as session:
response = session.receive()
assert response == {
"type": "websocket.close",
"code": status.WS_1013_TRY_AGAIN_LATER,
"reason": "",
}


def test_middleware(test_client_factory):
client = test_client_factory(app, base_url="http://incorrecthost")
response = client.get("/func")
Expand Down Expand Up @@ -201,6 +245,10 @@ def test_routes():
),
Route("/500", endpoint=runtime_error, methods=["GET"]),
WebSocketRoute("/ws", endpoint=websocket_endpoint),
WebSocketRoute(
"/ws-raise-websocket", endpoint=websocket_raise_websocket_exception
),
WebSocketRoute("/ws-raise-custom", endpoint=websocket_raise_custom),
]


Expand Down
13 changes: 11 additions & 2 deletions tests/test_exceptions.py
@@ -1,6 +1,6 @@
import pytest

from starlette.exceptions import ExceptionMiddleware, HTTPException
from starlette.exceptions import ExceptionMiddleware, HTTPException, WebSocketException
from starlette.responses import PlainTextResponse
from starlette.routing import Route, Router, WebSocketRoute

Expand Down Expand Up @@ -108,7 +108,7 @@ def app(scope):
assert response.text == ""


def test_repr():
def test_http_repr():
assert repr(HTTPException(404)) == (
"HTTPException(status_code=404, detail='Not Found')"
)
Expand All @@ -122,3 +122,12 @@ class CustomHTTPException(HTTPException):
assert repr(CustomHTTPException(500, detail="Something custom")) == (
"CustomHTTPException(status_code=500, detail='Something custom')"
)


def test_websocket_repr():
assert repr(WebSocketException(1008)) == ("WebSocketException(code=1008)")

class CustomWebSocketException(WebSocketException):
pass

assert repr(CustomWebSocketException(1013)) == "CustomWebSocketException(code=1013)"