Skip to content

Commit

Permalink
Ensure middleware executes once per request timeout (#2615)
Browse files Browse the repository at this point in the history
  • Loading branch information
ahopkins committed Dec 7, 2022
1 parent f32437b commit d404116
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 4 deletions.
6 changes: 5 additions & 1 deletion sanic/http/http1.py
Expand Up @@ -16,6 +16,7 @@
PayloadTooLarge,
RequestCancelled,
ServerError,
ServiceUnavailable,
)
from sanic.headers import format_http1_response
from sanic.helpers import has_message_body
Expand Down Expand Up @@ -428,8 +429,11 @@ async def error_response(self, exception: Exception) -> None:
if self.request is None:
self.create_empty_request()

request_middleware = not isinstance(exception, ServiceUnavailable)
try:
await app.handle_exception(self.request, exception)
await app.handle_exception(
self.request, exception, request_middleware
)
except Exception as e:
await app.handle_exception(self.request, e, False)

Expand Down
5 changes: 4 additions & 1 deletion sanic/request.py
Expand Up @@ -104,6 +104,7 @@ class Request:
"_protocol",
"_remote_addr",
"_request_middleware_started",
"_response_middleware_started",
"_scheme",
"_socket",
"_stream_id",
Expand Down Expand Up @@ -179,6 +180,7 @@ def __init__(
Tuple[bool, bool, str, str], List[Tuple[str, str]]
] = defaultdict(list)
self._request_middleware_started = False
self._response_middleware_started = False
self.responded: bool = False
self.route: Optional[Route] = None
self.stream: Optional[Stream] = None
Expand Down Expand Up @@ -337,7 +339,8 @@ async def add_header(_, response: HTTPResponse):
middleware = (
self.route and self.route.extra.response_middleware
) or self.app.response_middleware
if middleware:
if middleware and not self._response_middleware_started:
self._response_middleware_started = True
response = await self.app._run_response_middleware(
self, response, middleware
)
Expand Down
30 changes: 28 additions & 2 deletions tests/test_middleware.py
@@ -1,6 +1,6 @@
import logging

from asyncio import CancelledError
from asyncio import CancelledError, sleep
from itertools import count

from sanic.exceptions import NotFound
Expand Down Expand Up @@ -318,6 +318,32 @@ async def handler(request):
resp1 = await request.respond()
return resp1

_, response = app.test_client.get("/")
app.test_client.get("/")
assert response_middleware_run_count == 1
assert request_middleware_run_count == 1


def test_middleware_run_on_timeout(app):
app.config.RESPONSE_TIMEOUT = 0.1
response_middleware_run_count = 0
request_middleware_run_count = 0

@app.on_response
def response(_, response):
nonlocal response_middleware_run_count
response_middleware_run_count += 1

@app.on_request
def request(_):
nonlocal request_middleware_run_count
request_middleware_run_count += 1

@app.get("/")
async def handler(request):
resp1 = await request.respond()
await sleep(1)
return resp1

app.test_client.get("/")
assert request_middleware_run_count == 1
assert response_middleware_run_count == 1

0 comments on commit d404116

Please sign in to comment.