Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add WebSocketException and support for WS handlers #1263

Merged
merged 20 commits into from Sep 5, 2022
Merged
Show file tree
Hide file tree
Changes from 19 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
19 changes: 19 additions & 0 deletions docs/exceptions.md
Expand Up @@ -62,6 +62,17 @@ async def http_exception(request: Request, exc: HTTPException):
)
```

You might also want to override how `WebSocketException` is handled:

```python
async def websocket_exception(websocket: WebSocket, exc: WebSocketException):
await websocket.close(code=1008)

exception_handlers = {
WebSocketException: websocket_exception
}
```

## Errors and handled exceptions

It is important to differentiate between handled exceptions and errors.
Expand Down Expand Up @@ -112,3 +123,11 @@ returning plain-text HTTP responses for any `HTTPException`.

You should only raise `HTTPException` inside routing or endpoints. Middleware
classes should instead just return appropriate responses directly.

## WebSocketException

You can use the `WebSocketException` class to raise errors inside of WebSocket endpoints.

* `WebSocketException(code=1008, reason=None)`

You can set any code valid as defined [in the specification](https://tools.ietf.org/html/rfc6455#section-7.4.1).
12 changes: 11 additions & 1 deletion starlette/exceptions.py
Expand Up @@ -2,7 +2,7 @@
import typing
import warnings

__all__ = ("HTTPException",)
__all__ = ("HTTPException", "WebSocketException")


class HTTPException(Exception):
Expand All @@ -23,6 +23,16 @@ def __repr__(self) -> str:
return f"{class_name}(status_code={self.status_code!r}, detail={self.detail!r})"


class WebSocketException(Exception):
def __init__(self, code: int, reason: typing.Optional[str] = None) -> None:
self.code = code
self.reason = reason or ""

def __repr__(self) -> str:
class_name = self.__class__.__name__
return f"{class_name}(code={self.code!r}, reason={self.reason!r})"


__deprecated__ = "ExceptionMiddleware"


Expand Down
34 changes: 25 additions & 9 deletions starlette/middleware/exceptions.py
Expand Up @@ -2,10 +2,11 @@

from starlette._utils import is_async_callable
from starlette.concurrency import run_in_threadpool
from starlette.exceptions import HTTPException
from starlette.exceptions import HTTPException, WebSocketException
from starlette.requests import Request
from starlette.responses import PlainTextResponse, Response
from starlette.types import ASGIApp, Message, Receive, Scope, Send
from starlette.websockets import WebSocket


class ExceptionMiddleware:
Expand All @@ -22,7 +23,10 @@ def __init__(
self._status_handlers: typing.Dict[int, typing.Callable] = {}
self._exception_handlers: typing.Dict[
typing.Type[Exception], typing.Callable
] = {HTTPException: self.http_exception}
] = {
HTTPException: self.http_exception,
WebSocketException: self.websocket_exception,
}
if handlers is not None:
for key, value in handlers.items():
self.add_exception_handler(key, value)
Expand All @@ -47,7 +51,7 @@ def _lookup_exception_handler(
return None

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope["type"] != "http":
if scope["type"] not in ("http", "websocket"):
await self.app(scope, receive, send)
return

Expand Down Expand Up @@ -78,16 +82,28 @@ async def sender(message: Message) -> None:
msg = "Caught handled exception, but response already started."
raise RuntimeError(msg) from exc

request = Request(scope, receive=receive)
if is_async_callable(handler):
response = await handler(request, exc)
else:
response = await run_in_threadpool(handler, request, exc)
await response(scope, receive, sender)
if scope["type"] == "http":
request = Request(scope, receive=receive)
if is_async_callable(handler):
response = await handler(request, exc)
else:
response = await run_in_threadpool(handler, request, exc)
await response(scope, receive, sender)
elif scope["type"] == "websocket":
websocket = WebSocket(scope, receive=receive, send=send)
if is_async_callable(handler):
await handler(websocket, exc)
else:
await run_in_threadpool(handler, websocket, exc)

def http_exception(self, request: Request, exc: HTTPException) -> Response:
if exc.status_code in {204, 304}:
return Response(status_code=exc.status_code, headers=exc.headers)
return PlainTextResponse(
exc.detail, status_code=exc.status_code, headers=exc.headers
)

async def websocket_exception(
self, websocket: WebSocket, exc: WebSocketException
) -> None:
await websocket.close(code=exc.code, reason=exc.reason)
48 changes: 47 additions & 1 deletion tests/test_applications.py
@@ -1,16 +1,19 @@
import os
from contextlib import asynccontextmanager

import anyio
import pytest

from starlette import status
from starlette.applications import Starlette
from starlette.endpoints import HTTPEndpoint
from starlette.exceptions import HTTPException
from starlette.exceptions import HTTPException, WebSocketException
from starlette.middleware import Middleware
from starlette.middleware.trustedhost import TrustedHostMiddleware
from starlette.responses import JSONResponse, PlainTextResponse
from starlette.routing import Host, Mount, Route, Router, WebSocketRoute
from starlette.staticfiles import StaticFiles
from starlette.websockets import WebSocket


async def error_500(request, exc):
Expand Down Expand Up @@ -61,6 +64,24 @@ async def websocket_endpoint(session):
await session.close()


async def websocket_raise_websocket(websocket: WebSocket):
await websocket.accept()
raise WebSocketException(code=status.WS_1003_UNSUPPORTED_DATA)


class CustomWSException(Exception):
pass


async def websocket_raise_custom(websocket: WebSocket):
await websocket.accept()
raise CustomWSException()


def custom_ws_exception_handler(websocket: WebSocket, exc: CustomWSException):
anyio.from_thread.run(websocket.close, status.WS_1013_TRY_AGAIN_LATER)


users = Router(
routes=[
Route("/", endpoint=all_users_page),
Expand All @@ -78,6 +99,7 @@ async def websocket_endpoint(session):
500: error_500,
405: method_not_allowed,
HTTPException: http_exception,
CustomWSException: custom_ws_exception_handler,
}

middleware = [
Expand All @@ -91,6 +113,8 @@ async def websocket_endpoint(session):
Route("/class", endpoint=Homepage),
Route("/500", endpoint=runtime_error),
WebSocketRoute("/ws", endpoint=websocket_endpoint),
WebSocketRoute("/ws-raise-websocket", endpoint=websocket_raise_websocket),
WebSocketRoute("/ws-raise-custom", endpoint=websocket_raise_custom),
Mount("/users", app=users),
Host("{subdomain}.example.org", app=subdomain),
],
Expand Down Expand Up @@ -180,6 +204,26 @@ def test_500(test_client_factory):
assert response.json() == {"detail": "Server Error"}


def test_websocket_raise_websocket_exception(client):
with client.websocket_connect("/ws-raise-websocket") as session:
response = session.receive()
assert response == {
"type": "websocket.close",
"code": status.WS_1003_UNSUPPORTED_DATA,
"reason": "",
}


def test_websocket_raise_custom_exception(client):
with client.websocket_connect("/ws-raise-custom") as session:
response = session.receive()
assert response == {
"type": "websocket.close",
"code": status.WS_1013_TRY_AGAIN_LATER,
"reason": "",
}


def test_middleware(test_client_factory):
client = test_client_factory(app, base_url="http://incorrecthost")
response = client.get("/func")
Expand All @@ -194,6 +238,8 @@ def test_routes():
Route("/class", endpoint=Homepage),
Route("/500", endpoint=runtime_error, methods=["GET"]),
WebSocketRoute("/ws", endpoint=websocket_endpoint),
WebSocketRoute("/ws-raise-websocket", endpoint=websocket_raise_websocket),
WebSocketRoute("/ws-raise-custom", endpoint=websocket_raise_custom),
Mount(
"/users",
app=Router(
Expand Down
18 changes: 16 additions & 2 deletions tests/test_exceptions.py
Expand Up @@ -2,7 +2,7 @@

import pytest

from starlette.exceptions import HTTPException
from starlette.exceptions import HTTPException, WebSocketException
from starlette.middleware.exceptions import ExceptionMiddleware
from starlette.responses import PlainTextResponse
from starlette.routing import Route, Router, WebSocketRoute
Expand Down Expand Up @@ -119,7 +119,7 @@ async def app(scope, receive, send):
assert response.text == ""


def test_repr():
def test_http_repr():
assert repr(HTTPException(404)) == (
"HTTPException(status_code=404, detail='Not Found')"
)
Expand All @@ -135,6 +135,20 @@ class CustomHTTPException(HTTPException):
)


def test_websocket_repr():
assert repr(WebSocketException(1008, reason="Policy Violation")) == (
"WebSocketException(code=1008, reason='Policy Violation')"
)

class CustomWebSocketException(WebSocketException):
pass

assert (
repr(CustomWebSocketException(1013, reason="Something custom"))
== "CustomWebSocketException(code=1013, reason='Something custom')"
)


def test_exception_middleware_deprecation() -> None:
# this test should be removed once the deprecation shim is removed
with pytest.warns(DeprecationWarning):
Expand Down