diff --git a/sanic/http/http1.py b/sanic/http/http1.py index ccfae75de6..0cf2fd1177 100644 --- a/sanic/http/http1.py +++ b/sanic/http/http1.py @@ -16,6 +16,7 @@ PayloadTooLarge, RequestCancelled, ServerError, + ServiceUnavailable, ) from sanic.headers import format_http1_response from sanic.helpers import has_message_body @@ -428,8 +429,11 @@ async def error_response(self, exception: Exception) -> None: if self.request is None: self.create_empty_request() + request_middleware = not isinstance(exception, ServiceUnavailable) try: - await app.handle_exception(self.request, exception) + await app.handle_exception( + self.request, exception, request_middleware + ) except Exception as e: await app.handle_exception(self.request, e, False) diff --git a/sanic/request.py b/sanic/request.py index f7b3b9993e..8b8d95305f 100644 --- a/sanic/request.py +++ b/sanic/request.py @@ -104,6 +104,7 @@ class Request: "_protocol", "_remote_addr", "_request_middleware_started", + "_response_middleware_started", "_scheme", "_socket", "_stream_id", @@ -179,6 +180,7 @@ def __init__( Tuple[bool, bool, str, str], List[Tuple[str, str]] ] = defaultdict(list) self._request_middleware_started = False + self._response_middleware_started = False self.responded: bool = False self.route: Optional[Route] = None self.stream: Optional[Stream] = None @@ -337,7 +339,8 @@ async def add_header(_, response: HTTPResponse): middleware = ( self.route and self.route.extra.response_middleware ) or self.app.response_middleware - if middleware: + if middleware and not self._response_middleware_started: + self._response_middleware_started = True response = await self.app._run_response_middleware( self, response, middleware ) diff --git a/tests/test_middleware.py b/tests/test_middleware.py index a20344504b..6589f4a44e 100644 --- a/tests/test_middleware.py +++ b/tests/test_middleware.py @@ -1,6 +1,6 @@ import logging -from asyncio import CancelledError +from asyncio import CancelledError, sleep from itertools import count from sanic.exceptions import NotFound @@ -318,6 +318,32 @@ async def handler(request): resp1 = await request.respond() return resp1 - _, response = app.test_client.get("/") + app.test_client.get("/") assert response_middleware_run_count == 1 assert request_middleware_run_count == 1 + + +def test_middleware_run_on_timeout(app): + app.config.RESPONSE_TIMEOUT = 0.1 + response_middleware_run_count = 0 + request_middleware_run_count = 0 + + @app.on_response + def response(_, response): + nonlocal response_middleware_run_count + response_middleware_run_count += 1 + + @app.on_request + def request(_): + nonlocal request_middleware_run_count + request_middleware_run_count += 1 + + @app.get("/") + async def handler(request): + resp1 = await request.respond() + await sleep(1) + return resp1 + + app.test_client.get("/") + assert request_middleware_run_count == 1 + assert response_middleware_run_count == 1