diff --git a/CHANGES/7056.feature b/CHANGES/7056.feature new file mode 100644 index 00000000000..102fb4d7938 --- /dev/null +++ b/CHANGES/7056.feature @@ -0,0 +1 @@ +Added ``handler_cancellation`` parameter to cancel web handler on client disconnection. -- by :user:`mosquito` diff --git a/aiohttp/web.py b/aiohttp/web.py index e3e75779c6f..69d46bbf49f 100644 --- a/aiohttp/web.py +++ b/aiohttp/web.py @@ -307,6 +307,7 @@ async def _run_app( handle_signals: bool = True, reuse_address: Optional[bool] = None, reuse_port: Optional[bool] = None, + handler_cancellation: bool = False, ) -> None: # An internal function to actually do all dirty job for application running if asyncio.iscoroutine(app): @@ -321,6 +322,7 @@ async def _run_app( access_log_format=access_log_format, access_log=access_log, keepalive_timeout=keepalive_timeout, + handler_cancellation=handler_cancellation, ) await runner.setup() @@ -481,6 +483,7 @@ def run_app( handle_signals: bool = True, reuse_address: Optional[bool] = None, reuse_port: Optional[bool] = None, + handler_cancellation: bool = False, loop: Optional[asyncio.AbstractEventLoop] = None, ) -> None: """Run an app locally""" @@ -513,6 +516,7 @@ def run_app( handle_signals=handle_signals, reuse_address=reuse_address, reuse_port=reuse_port, + handler_cancellation=handler_cancellation, ) ) diff --git a/aiohttp/web_protocol.py b/aiohttp/web_protocol.py index 45b6f423fc1..27c815a4461 100644 --- a/aiohttp/web_protocol.py +++ b/aiohttp/web_protocol.py @@ -313,6 +313,9 @@ def connection_lost(self, exc: Optional[BaseException]) -> None: super().connection_lost(exc) + # Grab value before setting _manager to None. + handler_cancellation = self._manager.handler_cancellation + self._manager = None self._force_close = True self._request_factory = None @@ -330,6 +333,9 @@ def connection_lost(self, exc: Optional[BaseException]) -> None: if self._waiter is not None: self._waiter.cancel() + if handler_cancellation and self._task_handler is not None: + self._task_handler.cancel() + self._task_handler = None if self._payload_parser is not None: diff --git a/aiohttp/web_server.py b/aiohttp/web_server.py index 40211463f37..a3d658afbff 100644 --- a/aiohttp/web_server.py +++ b/aiohttp/web_server.py @@ -19,6 +19,7 @@ def __init__( *, request_factory: Optional[_RequestFactory] = None, debug: Optional[bool] = None, + handler_cancellation: bool = False, **kwargs: Any, ) -> None: if debug is not None: @@ -33,6 +34,7 @@ def __init__( self.requests_count = 0 self.request_handler = handler self.request_factory = request_factory or self._make_request + self.handler_cancellation = handler_cancellation @property def connections(self) -> List[RequestHandler]: diff --git a/docs/web_reference.rst b/docs/web_reference.rst index f79dd050c3b..44d9b12fcdd 100644 --- a/docs/web_reference.rst +++ b/docs/web_reference.rst @@ -2809,7 +2809,8 @@ Utilities access_log=aiohttp.log.access_logger, \ handle_signals=True, \ reuse_address=None, \ - reuse_port=None) + reuse_port=None, \ + handler_cancellation=False) A high-level function for running an application, serving it until keyboard interrupt and performing a @@ -2905,6 +2906,9 @@ Utilities this flag when being created. This option is not supported on Windows. + :param bool handler_cancellation: cancels the web handler task if the client + drops the connection. + .. versionadded:: 3.0 Support *access_log_class* parameter. @@ -2915,6 +2919,11 @@ Utilities Accept a coroutine as *app* parameter. + .. versionadded:: 3.9 + + Support handler_cancellation parameter (this was the default behaviour + in aiohttp <3.7). + Constants --------- diff --git a/tests/test_web_server.py b/tests/test_web_server.py index b97e0fa7b64..3e7eff2ad8c 100644 --- a/tests/test_web_server.py +++ b/tests/test_web_server.py @@ -1,5 +1,6 @@ # type: ignore import asyncio +from contextlib import suppress from typing import Any from unittest import mock @@ -207,3 +208,80 @@ async def handler(request): ) logger.exception.assert_called_with("Error handling request", exc_info=exc) + + +async def test_handler_cancellation(aiohttp_unused_port) -> None: + event = asyncio.Event() + port = aiohttp_unused_port() + + async def on_request(_: web.Request) -> web.Response: + nonlocal event + try: + await asyncio.sleep(10) + except asyncio.CancelledError: + event.set() + raise + else: + raise web.HTTPInternalServerError() + + app = web.Application() + app.router.add_route("GET", "/", on_request) + + runner = web.AppRunner(app, handler_cancellation=True) + await runner.setup() + + site = web.TCPSite(runner, host="localhost", port=port) + + await site.start() + + try: + assert runner.server.handler_cancellation, "Flag was not propagated" + + async with client.ClientSession( + timeout=client.ClientTimeout(total=0.1) + ) as sess: + with pytest.raises(asyncio.TimeoutError): + await sess.get(f"http://localhost:{port}/") + + with suppress(asyncio.TimeoutError): + await asyncio.wait_for(event.wait(), timeout=1) + assert event.is_set(), "Request handler hasn't been cancelled" + finally: + await asyncio.gather(runner.shutdown(), site.stop()) + + +async def test_no_handler_cancellation(aiohttp_unused_port) -> None: + timeout_event = asyncio.Event() + done_event = asyncio.Event() + port = aiohttp_unused_port() + + async def on_request(_: web.Request) -> web.Response: + nonlocal done_event, timeout_event + await asyncio.wait_for(timeout_event.wait(), timeout=5) + done_event.set() + return web.Response() + + app = web.Application() + app.router.add_route("GET", "/", on_request) + + runner = web.AppRunner(app) + await runner.setup() + + site = web.TCPSite(runner, host="localhost", port=port) + + await site.start() + + try: + async with client.ClientSession( + timeout=client.ClientTimeout(total=0.1) + ) as sess: + with pytest.raises(asyncio.TimeoutError): + await sess.get(f"http://localhost:{port}/") + await asyncio.sleep(0.1) + timeout_event.set() + + with suppress(asyncio.TimeoutError): + await asyncio.wait_for(done_event.wait(), timeout=1) + assert done_event.is_set() + finally: + await asyncio.gather(runner.shutdown(), site.stop())