Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move exception handling logic to endpoints #2020

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
7 changes: 6 additions & 1 deletion starlette/middleware/exceptions.py
Expand Up @@ -55,6 +55,11 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
await self.app(scope, receive, send)
return

scope["starlette.exception_handlers"] = (
self._exception_handlers,
self._status_handlers,
)

response_started = False

async def sender(message: Message) -> None:
Expand Down Expand Up @@ -106,4 +111,4 @@ def http_exception(self, request: Request, exc: HTTPException) -> Response:
async def websocket_exception(
self, websocket: WebSocket, exc: WebSocketException
) -> None:
await websocket.close(code=exc.code, reason=exc.reason)
await websocket.close(code=exc.code, reason=exc.reason) # pragma: no cover
110 changes: 102 additions & 8 deletions starlette/routing.py
Expand Up @@ -17,7 +17,7 @@
from starlette.middleware import Middleware
from starlette.requests import Request
from starlette.responses import PlainTextResponse, RedirectResponse
from starlette.types import ASGIApp, Receive, Scope, Send
from starlette.types import ASGIApp, Message, Receive, Scope, Send
from starlette.websockets import WebSocket, WebSocketClose


Expand Down Expand Up @@ -53,19 +53,72 @@ def iscoroutinefunction_or_partial(obj: typing.Any) -> bool: # pragma: no cover
return inspect.iscoroutinefunction(obj)


def _lookup_exception_handler(
exc: Exception,
handlers: typing.Mapping[typing.Type[Exception], typing.Callable[..., typing.Any]],
) -> typing.Optional[typing.Callable[..., typing.Any]]:
for cls in type(exc).__mro__:
if cls in handlers:
return handlers[cls]
return None


def request_response(func: typing.Callable) -> ASGIApp:
"""
Takes a function or coroutine `func(request) -> response`,
and returns an ASGI application.
"""

is_coroutine = is_async_callable(func)

async def app(scope: Scope, receive: Receive, send: Send) -> None:
request = Request(scope, receive=receive, send=send)
if is_coroutine:
response = await func(request)
else:
response = await run_in_threadpool(func, request)
exception_handlers: typing.Mapping[
typing.Type[Exception], typing.Callable[..., typing.Any]
]
status_handlers: typing.Mapping[int, typing.Callable[..., typing.Any]]

try:
exception_handlers, status_handlers = scope["starlette.exception_handlers"]
except KeyError:
exception_handlers, status_handlers = {}, {}

response_started = False

async def sender(message: Message) -> None:
nonlocal response_started

if message["type"] == "http.response.start":
response_started = True
await send(message)

request = Request(scope, receive=receive, send=sender)

try:
if is_coroutine:
response = await func(request)
else:
response = await run_in_threadpool(func, request)
except Exception as exc:
handler = None

if isinstance(exc, HTTPException):
handler = status_handlers.get(exc.status_code)

if handler is None:
handler = _lookup_exception_handler(exc, exception_handlers)

if handler is None:
raise exc

if response_started:
msg = "Caught handled exception, but response already started."
raise RuntimeError(msg) from exc

if is_async_callable(handler):
response = await handler(request, exc)
else:
response = await run_in_threadpool(handler, request, exc)

await response(scope, receive, send)

return app
Expand All @@ -78,8 +131,49 @@ def websocket_session(func: typing.Callable) -> ASGIApp:
# assert asyncio.iscoroutinefunction(func), "WebSocket endpoints must be async"

async def app(scope: Scope, receive: Receive, send: Send) -> None:
session = WebSocket(scope, receive=receive, send=send)
await func(session)
exception_handlers: typing.Mapping[
typing.Type[Exception], typing.Callable[..., typing.Any]
]
status_handlers: typing.Mapping[int, typing.Callable[..., typing.Any]]

try:
exception_handlers, status_handlers = scope["starlette.exception_handlers"]
except KeyError:
exception_handlers, status_handlers = {}, {}

response_started = False

async def sender(message: Message) -> None:
nonlocal response_started

if message["type"] == "http.response.start":
response_started = True
await send(message)

session = WebSocket(scope, receive=receive, send=sender)

try:
await func(session)
except Exception as exc:
handler = None

if isinstance(exc, HTTPException):
handler = status_handlers.get(exc.status_code)

if handler is None:
handler = _lookup_exception_handler(exc, exception_handlers)

if handler is None:
raise exc

if response_started:
msg = "Caught handled exception, but response already started."
raise RuntimeError(msg) from exc

if is_async_callable(handler):
await handler(session, exc)
else:
await run_in_threadpool(handler, session, exc)

return app

Expand Down
35 changes: 33 additions & 2 deletions tests/test_exceptions.py
Expand Up @@ -4,7 +4,8 @@

from starlette.exceptions import HTTPException, WebSocketException
from starlette.middleware.exceptions import ExceptionMiddleware
from starlette.responses import PlainTextResponse
from starlette.requests import Request
from starlette.responses import JSONResponse, PlainTextResponse
from starlette.routing import Route, Router, WebSocketRoute


Expand All @@ -28,6 +29,22 @@ def with_headers(request):
raise HTTPException(status_code=200, headers={"x-potato": "always"})


class BadBodyException(HTTPException):
pass


async def read_body_and_raise_exc(request: Request):
await request.body()
raise BadBodyException(422)


async def handler_that_reads_body(
request: Request, exc: BadBodyException
) -> JSONResponse:
body = await request.body()
return JSONResponse(status_code=422, content={"body": body.decode()})


class HandledExcAfterResponse:
async def __call__(self, scope, receive, send):
response = PlainTextResponse("OK", status_code=200)
Expand All @@ -44,11 +61,19 @@ async def __call__(self, scope, receive, send):
Route("/with_headers", endpoint=with_headers),
Route("/handled_exc_after_response", endpoint=HandledExcAfterResponse()),
WebSocketRoute("/runtime_error", endpoint=raise_runtime_error),
Route(
"/consume_body_in_endpoint_and_handler",
endpoint=read_body_and_raise_exc,
methods=["POST"],
),
]
)


app = ExceptionMiddleware(router)
app = ExceptionMiddleware(
router,
handlers={BadBodyException: handler_that_reads_body}, # type: ignore[dict-item]
)


@pytest.fixture
Expand Down Expand Up @@ -160,3 +185,9 @@ def test_exception_middleware_deprecation() -> None:

with pytest.warns(DeprecationWarning):
starlette.exceptions.ExceptionMiddleware


def test_request_in_app_and_handler_is_the_same_object(client) -> None:
response = client.post("/consume_body_in_endpoint_and_handler", content=b"Hello!")
assert response.status_code == 422
assert response.json() == {"body": "Hello!"}
6 changes: 1 addition & 5 deletions tests/test_routing.py
Expand Up @@ -945,13 +945,9 @@ async def modified_send(msg: Message) -> None:
assert resp.status_code == 200, resp.content
assert "X-Mounted" in resp.headers

# this is the "surprising" behavior bit
# the middleware on the mount never runs because there
# is nothing to catch the HTTPException
# since Mount middlweare is not wrapped by ExceptionMiddleware
resp = client.get("/mount/err")
assert resp.status_code == 403, resp.content
assert "X-Mounted" not in resp.headers
assert "X-Mounted" in resp.headers


def test_route_repr() -> None:
Expand Down