diff --git a/docs/exceptions.md b/docs/exceptions.md index 9818a2045..f97f1af89 100644 --- a/docs/exceptions.md +++ b/docs/exceptions.md @@ -62,6 +62,17 @@ async def http_exception(request: Request, exc: HTTPException): ) ``` +You might also want to override how `WebSocketException` is handled: + +```python +async def websocket_exception(websocket: WebSocket, exc: WebSocketException): + await websocket.close(code=1008) + +exception_handlers = { + WebSocketException: websocket_exception +} +``` + ## Errors and handled exceptions It is important to differentiate between handled exceptions and errors. @@ -112,3 +123,11 @@ returning plain-text HTTP responses for any `HTTPException`. You should only raise `HTTPException` inside routing or endpoints. Middleware classes should instead just return appropriate responses directly. + +## WebSocketException + +You can use the `WebSocketException` class to raise errors inside of WebSocket endpoints. + +* `WebSocketException(code=1008, reason=None)` + +You can set any code valid as defined [in the specification](https://tools.ietf.org/html/rfc6455#section-7.4.1). diff --git a/starlette/exceptions.py b/starlette/exceptions.py index 2b5acddb5..87da73591 100644 --- a/starlette/exceptions.py +++ b/starlette/exceptions.py @@ -2,7 +2,7 @@ import typing import warnings -__all__ = ("HTTPException",) +__all__ = ("HTTPException", "WebSocketException") class HTTPException(Exception): @@ -23,6 +23,16 @@ def __repr__(self) -> str: return f"{class_name}(status_code={self.status_code!r}, detail={self.detail!r})" +class WebSocketException(Exception): + def __init__(self, code: int, reason: typing.Optional[str] = None) -> None: + self.code = code + self.reason = reason or "" + + def __repr__(self) -> str: + class_name = self.__class__.__name__ + return f"{class_name}(code={self.code!r}, reason={self.reason!r})" + + __deprecated__ = "ExceptionMiddleware" diff --git a/starlette/middleware/exceptions.py b/starlette/middleware/exceptions.py index 42fd41ae2..cd7294170 100644 --- a/starlette/middleware/exceptions.py +++ b/starlette/middleware/exceptions.py @@ -2,10 +2,11 @@ from starlette._utils import is_async_callable from starlette.concurrency import run_in_threadpool -from starlette.exceptions import HTTPException +from starlette.exceptions import HTTPException, WebSocketException from starlette.requests import Request from starlette.responses import PlainTextResponse, Response from starlette.types import ASGIApp, Message, Receive, Scope, Send +from starlette.websockets import WebSocket class ExceptionMiddleware: @@ -22,7 +23,10 @@ def __init__( self._status_handlers: typing.Dict[int, typing.Callable] = {} self._exception_handlers: typing.Dict[ typing.Type[Exception], typing.Callable - ] = {HTTPException: self.http_exception} + ] = { + HTTPException: self.http_exception, + WebSocketException: self.websocket_exception, + } if handlers is not None: for key, value in handlers.items(): self.add_exception_handler(key, value) @@ -47,7 +51,7 @@ def _lookup_exception_handler( return None async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: - if scope["type"] != "http": + if scope["type"] not in ("http", "websocket"): await self.app(scope, receive, send) return @@ -78,12 +82,19 @@ async def sender(message: Message) -> None: msg = "Caught handled exception, but response already started." raise RuntimeError(msg) from exc - request = Request(scope, receive=receive) - if is_async_callable(handler): - response = await handler(request, exc) - else: - response = await run_in_threadpool(handler, request, exc) - await response(scope, receive, sender) + if scope["type"] == "http": + request = Request(scope, receive=receive) + if is_async_callable(handler): + response = await handler(request, exc) + else: + response = await run_in_threadpool(handler, request, exc) + await response(scope, receive, sender) + elif scope["type"] == "websocket": + websocket = WebSocket(scope, receive=receive, send=send) + if is_async_callable(handler): + await handler(websocket, exc) + else: + await run_in_threadpool(handler, websocket, exc) def http_exception(self, request: Request, exc: HTTPException) -> Response: if exc.status_code in {204, 304}: @@ -91,3 +102,8 @@ def http_exception(self, request: Request, exc: HTTPException) -> Response: return PlainTextResponse( exc.detail, status_code=exc.status_code, headers=exc.headers ) + + async def websocket_exception( + self, websocket: WebSocket, exc: WebSocketException + ) -> None: + await websocket.close(code=exc.code, reason=exc.reason) diff --git a/tests/test_applications.py b/tests/test_applications.py index 0d0ede571..2cee601b0 100644 --- a/tests/test_applications.py +++ b/tests/test_applications.py @@ -1,16 +1,19 @@ import os from contextlib import asynccontextmanager +import anyio import pytest +from starlette import status from starlette.applications import Starlette from starlette.endpoints import HTTPEndpoint -from starlette.exceptions import HTTPException +from starlette.exceptions import HTTPException, WebSocketException from starlette.middleware import Middleware from starlette.middleware.trustedhost import TrustedHostMiddleware from starlette.responses import JSONResponse, PlainTextResponse from starlette.routing import Host, Mount, Route, Router, WebSocketRoute from starlette.staticfiles import StaticFiles +from starlette.websockets import WebSocket async def error_500(request, exc): @@ -61,6 +64,24 @@ async def websocket_endpoint(session): await session.close() +async def websocket_raise_websocket(websocket: WebSocket): + await websocket.accept() + raise WebSocketException(code=status.WS_1003_UNSUPPORTED_DATA) + + +class CustomWSException(Exception): + pass + + +async def websocket_raise_custom(websocket: WebSocket): + await websocket.accept() + raise CustomWSException() + + +def custom_ws_exception_handler(websocket: WebSocket, exc: CustomWSException): + anyio.from_thread.run(websocket.close, status.WS_1013_TRY_AGAIN_LATER) + + users = Router( routes=[ Route("/", endpoint=all_users_page), @@ -78,6 +99,7 @@ async def websocket_endpoint(session): 500: error_500, 405: method_not_allowed, HTTPException: http_exception, + CustomWSException: custom_ws_exception_handler, } middleware = [ @@ -91,6 +113,8 @@ async def websocket_endpoint(session): Route("/class", endpoint=Homepage), Route("/500", endpoint=runtime_error), WebSocketRoute("/ws", endpoint=websocket_endpoint), + WebSocketRoute("/ws-raise-websocket", endpoint=websocket_raise_websocket), + WebSocketRoute("/ws-raise-custom", endpoint=websocket_raise_custom), Mount("/users", app=users), Host("{subdomain}.example.org", app=subdomain), ], @@ -180,6 +204,26 @@ def test_500(test_client_factory): assert response.json() == {"detail": "Server Error"} +def test_websocket_raise_websocket_exception(client): + with client.websocket_connect("/ws-raise-websocket") as session: + response = session.receive() + assert response == { + "type": "websocket.close", + "code": status.WS_1003_UNSUPPORTED_DATA, + "reason": "", + } + + +def test_websocket_raise_custom_exception(client): + with client.websocket_connect("/ws-raise-custom") as session: + response = session.receive() + assert response == { + "type": "websocket.close", + "code": status.WS_1013_TRY_AGAIN_LATER, + "reason": "", + } + + def test_middleware(test_client_factory): client = test_client_factory(app, base_url="http://incorrecthost") response = client.get("/func") @@ -194,6 +238,8 @@ def test_routes(): Route("/class", endpoint=Homepage), Route("/500", endpoint=runtime_error, methods=["GET"]), WebSocketRoute("/ws", endpoint=websocket_endpoint), + WebSocketRoute("/ws-raise-websocket", endpoint=websocket_raise_websocket), + WebSocketRoute("/ws-raise-custom", endpoint=websocket_raise_custom), Mount( "/users", app=Router( diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index 9acd42154..05583a430 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -2,7 +2,7 @@ import pytest -from starlette.exceptions import HTTPException +from starlette.exceptions import HTTPException, WebSocketException from starlette.middleware.exceptions import ExceptionMiddleware from starlette.responses import PlainTextResponse from starlette.routing import Route, Router, WebSocketRoute @@ -119,7 +119,7 @@ async def app(scope, receive, send): assert response.text == "" -def test_repr(): +def test_http_repr(): assert repr(HTTPException(404)) == ( "HTTPException(status_code=404, detail='Not Found')" ) @@ -135,6 +135,20 @@ class CustomHTTPException(HTTPException): ) +def test_websocket_repr(): + assert repr(WebSocketException(1008, reason="Policy Violation")) == ( + "WebSocketException(code=1008, reason='Policy Violation')" + ) + + class CustomWebSocketException(WebSocketException): + pass + + assert ( + repr(CustomWebSocketException(1013, reason="Something custom")) + == "CustomWebSocketException(code=1013, reason='Something custom')" + ) + + def test_exception_middleware_deprecation() -> None: # this test should be removed once the deprecation shim is removed with pytest.warns(DeprecationWarning):