Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add WebSocketException and support for WS handlers #527

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
16 changes: 16 additions & 0 deletions docs/exceptions.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,14 @@ async def http_exception(request, exc):
return JSONResponse({"detail": exc.detail}, status_code=exc.status_code)
```

You might also want to override how `WebSocketException` is handled:

```python
@app.exception_handler(WebSocketException)
async def websocket_exception(websocket, exc):
await websocket.close(code=1008)
```

## Errors and handled exceptions

It is important to differentiate between handled exceptions and errors.
Expand Down Expand Up @@ -74,3 +82,11 @@ returning plain-text HTTP responses for any `HTTPException`.

You should only raise `HTTPException` inside routing or endpoints. Middleware
classes should instead just return appropriate responses directly.

## WebSocketException

You can use the `WebSocketException` class to raise errors inside of WebSocket endpoints.

* `WebSocketException(code=1008)`

You can set any code valid as defined [in the specification](https://tools.ietf.org/html/rfc6455#section-7.4.1).
48 changes: 40 additions & 8 deletions starlette/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@
import http
import typing

from starlette import status
from starlette.concurrency import run_in_threadpool
from starlette.requests import Request
from starlette.responses import PlainTextResponse, Response
from starlette.types import ASGIApp, Message, Receive, Scope, Send
from starlette.websockets import WebSocket, WebSocketClose


class HTTPException(Exception):
Expand All @@ -16,13 +18,31 @@ def __init__(self, status_code: int, detail: str = None) -> None:
self.detail = detail


class WebSocketException(Exception):
def __init__(self, code: int = status.WS_1008_POLICY_VIOLATION) -> None:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I reckon we should probably have code as a mandatory argument here, in line with HTTPException.

"""
`code` defaults to 1008, from the WebSocket specification:

> 1008 indicates that an endpoint is terminating the connection
> because it has received a message that violates its policy. This
> is a generic status code that can be returned when there is no
> other more suitable status code (e.g., 1003 or 1009) or if there
> is a need to hide specific details about the policy.

Set `code` to any value allowed by
[the WebSocket specification](https://tools.ietf.org/html/rfc6455#section-7.4.1).
"""
self.code = code


class ExceptionMiddleware:
def __init__(self, app: ASGIApp, debug: bool = False) -> None:
self.app = app
self.debug = debug # TODO: We ought to handle 404 cases if debug is set.
self._status_handlers = {} # type: typing.Dict[int, typing.Callable]
self._exception_handlers = {
HTTPException: self.http_exception
HTTPException: self.http_exception,
WebSocketException: self.websocket_exception,
} # type: typing.Dict[typing.Type[Exception], typing.Callable]

def add_exception_handler(
Expand All @@ -45,7 +65,7 @@ def _lookup_exception_handler(
return None

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope["type"] != "http":
if scope["type"] not in {"http", "websocket"}:
await self.app(scope, receive, send)
return

Expand Down Expand Up @@ -76,14 +96,26 @@ async def sender(message: Message) -> None:
msg = "Caught handled exception, but response already started."
raise RuntimeError(msg) from exc

request = Request(scope, receive=receive)
if asyncio.iscoroutinefunction(handler):
response = await handler(request, exc)
else:
response = await run_in_threadpool(handler, request, exc)
await response(scope, receive, sender)
if scope["type"] == "http":
request = Request(scope, receive=receive)
if asyncio.iscoroutinefunction(handler):
response = await handler(request, exc)
else:
response = await run_in_threadpool(handler, request, exc)
await response(scope, receive, sender)
elif scope["type"] == "websocket":
websocket = WebSocket(scope, receive=receive, send=send)
if asyncio.iscoroutinefunction(handler):
await handler(websocket, exc)
else:
await run_in_threadpool(handler, websocket, exc)

def http_exception(self, request: Request, exc: HTTPException) -> Response:
if exc.status_code in {204, 304}:
return Response(b"", status_code=exc.status_code)
return PlainTextResponse(exc.detail, status_code=exc.status_code)

async def websocket_exception(
self, websocket: WebSocket, exc: WebSocketException
) -> None:
await websocket.close(code=exc.code)
15 changes: 12 additions & 3 deletions starlette/middleware/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@
import traceback
import typing

from starlette import status
from starlette.concurrency import run_in_threadpool
from starlette.requests import Request
from starlette.requests import Request, empty_receive
from starlette.responses import HTMLResponse, PlainTextResponse, Response
from starlette.types import ASGIApp, Message, Receive, Scope, Send
from starlette.websockets import WebSocket, WebSocketState

STYLES = """
.traceback-container {
Expand Down Expand Up @@ -83,7 +85,7 @@ def __init__(
self.debug = debug

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope["type"] != "http":
if scope["type"] not in {"http", "websocket"}:
await self.app(scope, receive, send)
return

Expand All @@ -99,7 +101,7 @@ async def _send(message: Message) -> None:
try:
await self.app(scope, receive, _send)
except Exception as exc:
if not response_started:
if not response_started and scope["type"] == "http":
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's use if scope["type"] == "http" and not response_started

request = Request(scope)
if self.debug:
# In debug mode, return traceback responses.
Expand All @@ -115,6 +117,13 @@ async def _send(message: Message) -> None:
response = await run_in_threadpool(self.handler, request, exc)

await response(scope, receive, send)
elif scope["type"] == "websocket":
websocket = WebSocket(scope, receive, send)
# https://tools.ietf.org/html/rfc6455#section-7.4.1
# 1011 indicates that a server is terminating the connection because
# it encountered an unexpected condition that prevented it from
# fulfilling the request.
await websocket.close(code=status.WS_1011_INTERNAL_ERROR)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is interesting...

We shouldn't really have to do this, because the webserver's behavior ought to be in line with this anyway.
Ie. A higher priority task for us should be to ensure that uvicorn has equivelent handling, and will close the websocker with a 1011 code if it gets and exception, and the socket is still open.

We'd also want to check in this case that the websocket is in an open state. (The exception could have been raised after the websocket close.) I think that means we'd want a websocket_complete = False alongside the existing response_started = False, and track its state in the _send() function.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's also not really clear if we want/need this in any case - In the HTTP case, we can use the handler to issue an application-custom 500 page. In the websocket case there's really nothing available for us to do other than close the the connection, which the server ought to take care of anyways.


# We always continue to raise the exception.
# This allows servers to log the error, or allows test clients
Expand Down
9 changes: 3 additions & 6 deletions tests/middleware/test_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from starlette.middleware.errors import ServerErrorMiddleware
from starlette.responses import JSONResponse, Response
from starlette.testclient import TestClient
from starlette.websockets import WebSocket, WebSocketDisconnect


def test_handler():
Expand Down Expand Up @@ -55,16 +56,12 @@ async def app(scope, receive, send):
client.get("/")


def test_debug_not_http():
"""
DebugMiddleware should just pass through any non-http messages as-is.
"""

def test_debug_websocket():
async def app(scope, receive, send):
raise RuntimeError("Something went wrong")

app = ServerErrorMiddleware(app)

with pytest.raises(RuntimeError):
with pytest.raises(WebSocketDisconnect):
client = TestClient(app)
client.websocket_connect("/")
53 changes: 52 additions & 1 deletion tests/test_applications.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
import asyncio
import os

import pytest

from starlette import status
from starlette.applications import Starlette
from starlette.endpoints import HTTPEndpoint
from starlette.exceptions import HTTPException
from starlette.exceptions import HTTPException, WebSocketException
from starlette.middleware.trustedhost import TrustedHostMiddleware
from starlette.responses import JSONResponse, PlainTextResponse
from starlette.routing import Host, Mount, Route, Router, WebSocketRoute
from starlette.staticfiles import StaticFiles
from starlette.testclient import TestClient
from starlette.websockets import WebSocketDisconnect

app = Starlette()

Expand Down Expand Up @@ -86,6 +91,28 @@ async def websocket_endpoint(session):
await session.close()


@app.websocket_route("/ws-raise-websocket")
async def websocket_raise_websocket_exception(websocket):
await websocket.accept()
raise WebSocketException(code=status.WS_1003_UNSUPPORTED_DATA)


class CustomWSException(Exception):
pass


@app.websocket_route("/ws-raise-custom")
async def websocket_raise_custom(websocket):
await websocket.accept()
raise CustomWSException()


@app.exception_handler(CustomWSException)
def custom_ws_exception_handler(websocket, exc):
loop = asyncio.new_event_loop()
loop.run_until_complete(websocket.close(code=status.WS_1013_TRY_AGAIN_LATER))


client = TestClient(app)


Expand Down Expand Up @@ -164,6 +191,26 @@ def test_500():
assert response.json() == {"detail": "Server Error"}


def test_websocket_raise_websocket_exception():
client = TestClient(app)
with client.websocket_connect("/ws-raise-websocket") as session:
response = session.receive()
assert response == {
"type": "websocket.close",
"code": status.WS_1003_UNSUPPORTED_DATA,
}


def test_websocket_raise_custom_exception():
client = TestClient(app)
with client.websocket_connect("/ws-raise-custom") as session:
response = session.receive()
assert response == {
"type": "websocket.close",
"code": status.WS_1013_TRY_AGAIN_LATER,
}


def test_middleware():
client = TestClient(app, base_url="http://incorrecthost")
response = client.get("/func")
Expand Down Expand Up @@ -191,6 +238,10 @@ def test_routes():
),
Route("/500", endpoint=runtime_error, methods=["GET"]),
WebSocketRoute("/ws", endpoint=websocket_endpoint),
WebSocketRoute(
"/ws-raise-websocket", endpoint=websocket_raise_websocket_exception
),
WebSocketRoute("/ws-raise-custom", endpoint=websocket_raise_custom),
]


Expand Down