From 400e240f84f987a27aba9c4bb7d0b228b47c5e14 Mon Sep 17 00:00:00 2001 From: Aviram Hassan Date: Tue, 1 Feb 2022 16:08:24 +0200 Subject: [PATCH] Make error handler run always (#761) * Error handler call always * Add tests * Add docs * Only run response callable if response didn't start Co-authored-by: Marcelo Trylesinski --- docs/exceptions.md | 18 ++++++++++++++++++ starlette/middleware/errors.py | 26 +++++++++++++------------- tests/middleware/test_errors.py | 28 ++++++++++++++++++++++++++++ 3 files changed, 59 insertions(+), 13 deletions(-) diff --git a/docs/exceptions.md b/docs/exceptions.md index bf460d229..9818a2045 100644 --- a/docs/exceptions.md +++ b/docs/exceptions.md @@ -75,6 +75,24 @@ should bubble through the entire middleware stack as exceptions. Any error logging middleware should ensure that it re-raises the exception all the way up to the server. +In practical terms, the error handled used is `exception_handler[500]` or `exception_handler[Exception]`. +Both keys `500` and `Exception` can be used. See below: + +```python +async def handle_error(request: Request, exc: HTTPException): + # Perform some logic + return JSONResponse({"detail": exc.detail}, status_code=exc.status_code) + +exception_handlers = { + Exception: handle_error # or "500: handle_error" +} +``` + +It's important to notice that in case a [`BackgroundTask`](https://www.starlette.io/background/) raises an exception, +it will be handled by the `handle_error` function, but at that point, the response was already sent. In other words, +the response created by `handle_error` will be discarded. In case the error happens before the response was sent, then +it will use the response object - in the above example, the returned `JSONResponse`. + In order to deal with this behaviour correctly, the middleware stack of a `Starlette` application is configured like this: diff --git a/starlette/middleware/errors.py b/starlette/middleware/errors.py index 30f5570ca..474c9afc0 100644 --- a/starlette/middleware/errors.py +++ b/starlette/middleware/errors.py @@ -158,21 +158,21 @@ async def _send(message: Message) -> None: try: await self.app(scope, receive, _send) except Exception as exc: - if not response_started: - request = Request(scope) - if self.debug: - # In debug mode, return traceback responses. - response = self.debug_response(request, exc) - elif self.handler is None: - # Use our default 500 error handler. - response = self.error_response(request, exc) + request = Request(scope) + if self.debug: + # In debug mode, return traceback responses. + response = self.debug_response(request, exc) + elif self.handler is None: + # Use our default 500 error handler. + response = self.error_response(request, exc) + else: + # Use an installed 500 error handler. + if asyncio.iscoroutinefunction(self.handler): + response = await self.handler(request, exc) else: - # Use an installed 500 error handler. - if asyncio.iscoroutinefunction(self.handler): - response = await self.handler(request, exc) - else: - response = await run_in_threadpool(self.handler, request, exc) + response = await run_in_threadpool(self.handler, request, exc) + if not response_started: await response(scope, receive, send) # We always continue to raise the exception. diff --git a/tests/middleware/test_errors.py b/tests/middleware/test_errors.py index 2c926a9b2..392c2ba16 100644 --- a/tests/middleware/test_errors.py +++ b/tests/middleware/test_errors.py @@ -1,7 +1,10 @@ import pytest +from starlette.applications import Starlette +from starlette.background import BackgroundTask from starlette.middleware.errors import ServerErrorMiddleware from starlette.responses import JSONResponse, Response +from starlette.routing import Route def test_handler(test_client_factory): @@ -68,3 +71,28 @@ async def app(scope, receive, send): client = test_client_factory(app) with client.websocket_connect("/"): pass # pragma: nocover + + +def test_background_task(test_client_factory): + accessed_error_handler = False + + def error_handler(request, exc): + nonlocal accessed_error_handler + accessed_error_handler = True + + def raise_exception(): + raise Exception("Something went wrong") + + async def endpoint(request): + task = BackgroundTask(raise_exception) + return Response(status_code=204, background=task) + + app = Starlette( + routes=[Route("/", endpoint=endpoint)], + exception_handlers={Exception: error_handler}, + ) + + client = test_client_factory(app, raise_server_exceptions=False) + response = client.get("/") + assert response.status_code == 204 + assert accessed_error_handler