diff --git a/playwright/_impl/_browser_context.py b/playwright/_impl/_browser_context.py index 844955322..94f9d1dee 100644 --- a/playwright/_impl/_browser_context.py +++ b/playwright/_impl/_browser_context.py @@ -48,6 +48,7 @@ from playwright._impl._frame import Frame from playwright._impl._har_router import HarRouter from playwright._impl._helper import ( + BackgroundTaskTracker, HarRecordingMetadata, RouteFromHarNotFoundPolicy, RouteHandler, @@ -103,6 +104,7 @@ def __init__( self._request: APIRequestContext = from_channel( initializer["APIRequestContext"] ) + self._background_task_tracker: BackgroundTaskTracker = BackgroundTaskTracker() self._channel.on( "bindingCall", lambda params: self._on_binding(from_channel(params["binding"])), @@ -113,7 +115,7 @@ def __init__( ) self._channel.on( "route", - lambda params: asyncio.create_task( + lambda params: self._background_task_tracker.create_task( self._on_route( from_channel(params.get("route")), from_channel(params.get("request")), @@ -163,8 +165,14 @@ def __init__( ), ) self._closed_future: asyncio.Future = asyncio.Future() + + def _on_close(_: Any) -> None: + self._background_task_tracker.close() + self._closed_future.set_result(True) + self.once( - self.Events.Close, lambda context: self._closed_future.set_result(True) + self.Events.Close, + _on_close, ) def __repr__(self) -> str: @@ -187,7 +195,7 @@ async def _on_route(self, route: Route, request: Request) -> None: handled = await route_handler.handle(route, request) finally: if len(self._routes) == 0: - asyncio.create_task(self._disable_interception()) + await self._disable_interception() if handled: return await route._internal_continue(is_internal=True) diff --git a/playwright/_impl/_helper.py b/playwright/_impl/_helper.py index fb2295298..f59ac8886 100644 --- a/playwright/_impl/_helper.py +++ b/playwright/_impl/_helper.py @@ -362,3 +362,21 @@ def is_file_payload(value: Optional[Any]) -> bool: and "mimeType" in value and "buffer" in value ) + + +class BackgroundTaskTracker: + def __init__(self) -> None: + self._pending_tasks: List[asyncio.Task] = [] + + def create_task(self, coro: Coroutine) -> asyncio.Task: + task = asyncio.create_task(coro) + self._pending_tasks.append(task) + return task + + def close(self) -> None: + try: + for task in self._pending_tasks: + if not task.done(): + task.cancel() + except Exception: + pass diff --git a/playwright/_impl/_page.py b/playwright/_impl/_page.py index ddb41aa12..318683348 100644 --- a/playwright/_impl/_page.py +++ b/playwright/_impl/_page.py @@ -55,6 +55,7 @@ from playwright._impl._frame import Frame from playwright._impl._har_router import HarRouter from playwright._impl._helper import ( + BackgroundTaskTracker, ColorScheme, DocumentLoadState, ForcedColors, @@ -151,6 +152,7 @@ def __init__( self._browser_context._timeout_settings ) self._video: Optional[Video] = None + self._background_task_tracker = BackgroundTaskTracker() self._opener = cast("Page", from_nullable_channel(initializer.get("opener"))) self._channel.on( @@ -192,7 +194,7 @@ def __init__( ) self._channel.on( "route", - lambda params: asyncio.create_task( + lambda params: self._background_task_tracker.create_task( self._on_route( from_channel(params["route"]), from_channel(params["request"]) ) @@ -209,11 +211,15 @@ def __init__( "worker", lambda params: self._on_worker(from_channel(params["worker"])) ) self._closed_or_crashed_future: asyncio.Future = asyncio.Future() + + def _on_close(_: Any) -> None: + self._background_task_tracker.close() + if not self._closed_or_crashed_future.done(): + self._closed_or_crashed_future.set_result(True) + self.on( Page.Events.Close, - lambda _: self._closed_or_crashed_future.set_result(True) - if not self._closed_or_crashed_future.done() - else None, + _on_close, ) self.on( Page.Events.Crash, @@ -246,7 +252,7 @@ async def _on_route(self, route: Route, request: Request) -> None: handled = await route_handler.handle(route, request) finally: if len(self._routes) == 0: - asyncio.create_task(self._disable_interception()) + await self._disable_interception() if handled: return await self._browser_context._on_route(route, request) diff --git a/tests/async/test_browsercontext_request_intercept.py b/tests/async/test_browsercontext_request_intercept.py index 763073df0..1d3e33945 100644 --- a/tests/async/test_browsercontext_request_intercept.py +++ b/tests/async/test_browsercontext_request_intercept.py @@ -174,3 +174,20 @@ async def test_should_give_access_to_the_intercepted_response_body( route.fulfill(response=response), eval_task, ) + + +async def test_should_cleanup_route_handlers_after_context_close( + context: BrowserContext, page: Page +) -> None: + async def handle(r: Route): + pass + + await context.route("**", handle) + try: + await page.goto("https://example.com", timeout=700) + except Exception: + pass + await context.close() + assert len(asyncio.all_tasks()) == 2 + for task in asyncio.all_tasks(): + assert "_on_route" not in str(task) diff --git a/tests/async/test_request_intercept.py b/tests/async/test_request_intercept.py index 39ccf3d3f..e90e964cc 100644 --- a/tests/async/test_request_intercept.py +++ b/tests/async/test_request_intercept.py @@ -17,7 +17,7 @@ from twisted.web import http -from playwright.async_api import Page, Route +from playwright.async_api import BrowserContext, Page, Route from tests.server import Server @@ -168,3 +168,37 @@ async def test_should_give_access_to_the_intercepted_response_body( route.fulfill(response=response), eval_task, ) + + +async def test_should_cleanup_route_handlers_after_page_close( + context: BrowserContext, page: Page +) -> None: + async def handle(r: Route): + pass + + await page.route("**", handle) + try: + await page.goto("https://example.com", timeout=700) + except Exception: + pass + await page.close() + assert len(asyncio.all_tasks()) == 2 + for task in asyncio.all_tasks(): + assert "_on_route" not in str(task) + + +async def test_should_cleanup_route_handlers_after_context_close( + context: BrowserContext, page: Page +) -> None: + async def handle(r: Route): + pass + + await page.route("**", handle) + try: + await page.goto("https://example.com", timeout=700) + except Exception: + pass + await context.close() + assert len(asyncio.all_tasks()) == 2 + for task in asyncio.all_tasks(): + assert "_on_route" not in str(task) diff --git a/tests/sync/test_browsercontext_request_intercept.py b/tests/sync/test_browsercontext_request_intercept.py index b136038ec..39510778f 100644 --- a/tests/sync/test_browsercontext_request_intercept.py +++ b/tests/sync/test_browsercontext_request_intercept.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import asyncio from pathlib import Path from twisted.web import http @@ -121,3 +122,17 @@ def handle_route(route: Route) -> None: assert request.uri.decode() == "/title.html" original = (assetdir / "title.html").read_text() assert response.text() == original + + +def test_should_cleanup_route_handlers_after_context_close( + context: BrowserContext, page: Page +) -> None: + context.route("**", lambda r: None) + try: + page.goto("https://example.com", timeout=700) + except Exception: + pass + context.close() + assert len(asyncio.all_tasks()) == 1 + for task in asyncio.all_tasks(): + assert "_on_route" not in str(task) diff --git a/tests/sync/test_request_intercept.py b/tests/sync/test_request_intercept.py index dc714e832..f82399606 100644 --- a/tests/sync/test_request_intercept.py +++ b/tests/sync/test_request_intercept.py @@ -12,11 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +import asyncio from pathlib import Path from twisted.web import http -from playwright.sync_api import Page, Route +from playwright.sync_api import BrowserContext, Page, Route from tests.server import Server @@ -115,3 +116,31 @@ def handle_route(route: Route) -> None: assert request.uri.decode() == "/title.html" original = (assetdir / "title.html").read_text() assert response.text() == original + + +def test_should_cleanup_route_handlers_after_page_close( + context: BrowserContext, page: Page +) -> None: + page.route("**", lambda r: None) + try: + page.goto("https://example.com", timeout=700) + except Exception: + pass + page.close() + assert len(asyncio.all_tasks()) == 1 + for task in asyncio.all_tasks(): + assert "_on_route" not in str(task) + + +def test_should_cleanup_route_handlers_after_context_close( + context: BrowserContext, page: Page +) -> None: + page.route("**", lambda r: None) + try: + page.goto("https://example.com", timeout=700) + except Exception: + pass + context.close() + assert len(asyncio.all_tasks()) == 1 + for task in asyncio.all_tasks(): + assert "_on_route" not in str(task)