From 38b9ec51e52fdcab11bbe322cc66392c599ca183 Mon Sep 17 00:00:00 2001 From: Mosquito Date: Sun, 11 Dec 2022 19:22:21 +0300 Subject: [PATCH] Added a configuration flag for enable request task handler cancelling when client connection closing. (#7056) ## Related to #6719 #6727. Added a configuration flag for enable request task handler cancelling when client connection closing. After changes in version 3.8.3, there is no longer any way to enable this behaviour. In our services, we want to handle protocol-level errors, for example for canceling the execution of a heavy query in the DBMS if the user's connection is broken. Now I created this PR in order to discuss my solution, of course if I did everything well I will add tests changelog, etc. ## I guess this breakdown can be solved using the configuration flag that is passed to the Server instance. Of course `AppRunner` and `SiteRunner` can pass this through `**kwargs` too. ## Related issue number #6719 ## Checklist - [ ] I think the code is well written - [ ] Unit tests for the changes exist - [ ] Documentation reflects the changes - [ ] If you provide code modification, please add yourself to `CONTRIBUTORS.txt` * The format is <Name> <Surname>. * Please keep alphabetical order, the file is sorted by names. - [ ] Add a new news fragment into the `CHANGES` folder * name it `.` for example (588.bugfix) * if you don't have an `issue_id` change it to the pr id after creating the pr * ensure type is one of the following: * `.feature`: Signifying a new feature. * `.bugfix`: Signifying a bug fix. * `.doc`: Signifying a documentation improvement. * `.removal`: Signifying a deprecation or removal of public API. * `.misc`: A ticket has been closed, but it is not of interest to users. * Make sure to use full sentences with correct case and punctuation, for example: "Fix issue with non-ascii contents in doctest text files." Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Sam Bull --- CHANGES/7056.feature | 1 + aiohttp/web.py | 4 +++ aiohttp/web_protocol.py | 6 ++++ aiohttp/web_server.py | 2 ++ docs/web_reference.rst | 11 +++++- tests/test_web_server.py | 78 ++++++++++++++++++++++++++++++++++++++++ 6 files changed, 101 insertions(+), 1 deletion(-) create mode 100644 CHANGES/7056.feature diff --git a/CHANGES/7056.feature b/CHANGES/7056.feature new file mode 100644 index 00000000000..102fb4d7938 --- /dev/null +++ b/CHANGES/7056.feature @@ -0,0 +1 @@ +Added ``handler_cancellation`` parameter to cancel web handler on client disconnection. -- by :user:`mosquito` diff --git a/aiohttp/web.py b/aiohttp/web.py index e3e75779c6f..69d46bbf49f 100644 --- a/aiohttp/web.py +++ b/aiohttp/web.py @@ -307,6 +307,7 @@ async def _run_app( handle_signals: bool = True, reuse_address: Optional[bool] = None, reuse_port: Optional[bool] = None, + handler_cancellation: bool = False, ) -> None: # An internal function to actually do all dirty job for application running if asyncio.iscoroutine(app): @@ -321,6 +322,7 @@ async def _run_app( access_log_format=access_log_format, access_log=access_log, keepalive_timeout=keepalive_timeout, + handler_cancellation=handler_cancellation, ) await runner.setup() @@ -481,6 +483,7 @@ def run_app( handle_signals: bool = True, reuse_address: Optional[bool] = None, reuse_port: Optional[bool] = None, + handler_cancellation: bool = False, loop: Optional[asyncio.AbstractEventLoop] = None, ) -> None: """Run an app locally""" @@ -513,6 +516,7 @@ def run_app( handle_signals=handle_signals, reuse_address=reuse_address, reuse_port=reuse_port, + handler_cancellation=handler_cancellation, ) ) diff --git a/aiohttp/web_protocol.py b/aiohttp/web_protocol.py index 45b6f423fc1..27c815a4461 100644 --- a/aiohttp/web_protocol.py +++ b/aiohttp/web_protocol.py @@ -313,6 +313,9 @@ def connection_lost(self, exc: Optional[BaseException]) -> None: super().connection_lost(exc) + # Grab value before setting _manager to None. + handler_cancellation = self._manager.handler_cancellation + self._manager = None self._force_close = True self._request_factory = None @@ -330,6 +333,9 @@ def connection_lost(self, exc: Optional[BaseException]) -> None: if self._waiter is not None: self._waiter.cancel() + if handler_cancellation and self._task_handler is not None: + self._task_handler.cancel() + self._task_handler = None if self._payload_parser is not None: diff --git a/aiohttp/web_server.py b/aiohttp/web_server.py index 40211463f37..a3d658afbff 100644 --- a/aiohttp/web_server.py +++ b/aiohttp/web_server.py @@ -19,6 +19,7 @@ def __init__( *, request_factory: Optional[_RequestFactory] = None, debug: Optional[bool] = None, + handler_cancellation: bool = False, **kwargs: Any, ) -> None: if debug is not None: @@ -33,6 +34,7 @@ def __init__( self.requests_count = 0 self.request_handler = handler self.request_factory = request_factory or self._make_request + self.handler_cancellation = handler_cancellation @property def connections(self) -> List[RequestHandler]: diff --git a/docs/web_reference.rst b/docs/web_reference.rst index f79dd050c3b..44d9b12fcdd 100644 --- a/docs/web_reference.rst +++ b/docs/web_reference.rst @@ -2809,7 +2809,8 @@ Utilities access_log=aiohttp.log.access_logger, \ handle_signals=True, \ reuse_address=None, \ - reuse_port=None) + reuse_port=None, \ + handler_cancellation=False) A high-level function for running an application, serving it until keyboard interrupt and performing a @@ -2905,6 +2906,9 @@ Utilities this flag when being created. This option is not supported on Windows. + :param bool handler_cancellation: cancels the web handler task if the client + drops the connection. + .. versionadded:: 3.0 Support *access_log_class* parameter. @@ -2915,6 +2919,11 @@ Utilities Accept a coroutine as *app* parameter. + .. versionadded:: 3.9 + + Support handler_cancellation parameter (this was the default behaviour + in aiohttp <3.7). + Constants --------- diff --git a/tests/test_web_server.py b/tests/test_web_server.py index b97e0fa7b64..3e7eff2ad8c 100644 --- a/tests/test_web_server.py +++ b/tests/test_web_server.py @@ -1,5 +1,6 @@ # type: ignore import asyncio +from contextlib import suppress from typing import Any from unittest import mock @@ -207,3 +208,80 @@ async def handler(request): ) logger.exception.assert_called_with("Error handling request", exc_info=exc) + + +async def test_handler_cancellation(aiohttp_unused_port) -> None: + event = asyncio.Event() + port = aiohttp_unused_port() + + async def on_request(_: web.Request) -> web.Response: + nonlocal event + try: + await asyncio.sleep(10) + except asyncio.CancelledError: + event.set() + raise + else: + raise web.HTTPInternalServerError() + + app = web.Application() + app.router.add_route("GET", "/", on_request) + + runner = web.AppRunner(app, handler_cancellation=True) + await runner.setup() + + site = web.TCPSite(runner, host="localhost", port=port) + + await site.start() + + try: + assert runner.server.handler_cancellation, "Flag was not propagated" + + async with client.ClientSession( + timeout=client.ClientTimeout(total=0.1) + ) as sess: + with pytest.raises(asyncio.TimeoutError): + await sess.get(f"http://localhost:{port}/") + + with suppress(asyncio.TimeoutError): + await asyncio.wait_for(event.wait(), timeout=1) + assert event.is_set(), "Request handler hasn't been cancelled" + finally: + await asyncio.gather(runner.shutdown(), site.stop()) + + +async def test_no_handler_cancellation(aiohttp_unused_port) -> None: + timeout_event = asyncio.Event() + done_event = asyncio.Event() + port = aiohttp_unused_port() + + async def on_request(_: web.Request) -> web.Response: + nonlocal done_event, timeout_event + await asyncio.wait_for(timeout_event.wait(), timeout=5) + done_event.set() + return web.Response() + + app = web.Application() + app.router.add_route("GET", "/", on_request) + + runner = web.AppRunner(app) + await runner.setup() + + site = web.TCPSite(runner, host="localhost", port=port) + + await site.start() + + try: + async with client.ClientSession( + timeout=client.ClientTimeout(total=0.1) + ) as sess: + with pytest.raises(asyncio.TimeoutError): + await sess.get(f"http://localhost:{port}/") + await asyncio.sleep(0.1) + timeout_event.set() + + with suppress(asyncio.TimeoutError): + await asyncio.wait_for(done_event.wait(), timeout=1) + assert done_event.is_set() + finally: + await asyncio.gather(runner.shutdown(), site.stop())