Skip to content

Commit

Permalink
fix: cleanup pending route handlers on close
Browse files Browse the repository at this point in the history
  • Loading branch information
rwoll committed Jul 7, 2022
1 parent 8820f30 commit 71c80af
Show file tree
Hide file tree
Showing 7 changed files with 137 additions and 10 deletions.
14 changes: 11 additions & 3 deletions playwright/_impl/_browser_context.py
Expand Up @@ -48,6 +48,7 @@
from playwright._impl._frame import Frame
from playwright._impl._har_router import HarRouter
from playwright._impl._helper import (
BackgroundTaskTracker,
HarRecordingMetadata,
RouteFromHarNotFoundPolicy,
RouteHandler,
Expand Down Expand Up @@ -103,6 +104,7 @@ def __init__(
self._request: APIRequestContext = from_channel(
initializer["APIRequestContext"]
)
self._background_task_tracker: BackgroundTaskTracker = BackgroundTaskTracker()
self._channel.on(
"bindingCall",
lambda params: self._on_binding(from_channel(params["binding"])),
Expand All @@ -113,7 +115,7 @@ def __init__(
)
self._channel.on(
"route",
lambda params: asyncio.create_task(
lambda params: self._background_task_tracker.create_task(
self._on_route(
from_channel(params.get("route")),
from_channel(params.get("request")),
Expand Down Expand Up @@ -163,8 +165,14 @@ def __init__(
),
)
self._closed_future: asyncio.Future = asyncio.Future()

def _on_close(_: Any) -> None:
self._background_task_tracker.close()
self._closed_future.set_result(True)

self.once(
self.Events.Close, lambda context: self._closed_future.set_result(True)
self.Events.Close,
_on_close,
)

def __repr__(self) -> str:
Expand All @@ -187,7 +195,7 @@ async def _on_route(self, route: Route, request: Request) -> None:
handled = await route_handler.handle(route, request)
finally:
if len(self._routes) == 0:
asyncio.create_task(self._disable_interception())
await self._disable_interception()
if handled:
return
await route._internal_continue(is_internal=True)
Expand Down
18 changes: 18 additions & 0 deletions playwright/_impl/_helper.py
Expand Up @@ -362,3 +362,21 @@ def is_file_payload(value: Optional[Any]) -> bool:
and "mimeType" in value
and "buffer" in value
)


class BackgroundTaskTracker:
def __init__(self) -> None:
self._pending_tasks: List[asyncio.Task] = []

def create_task(self, coro: Coroutine) -> asyncio.Task:
task = asyncio.create_task(coro)
self._pending_tasks.append(task)
return task

def close(self) -> None:
try:
for task in self._pending_tasks:
if not task.done():
task.cancel()
except Exception:
pass
16 changes: 11 additions & 5 deletions playwright/_impl/_page.py
Expand Up @@ -55,6 +55,7 @@
from playwright._impl._frame import Frame
from playwright._impl._har_router import HarRouter
from playwright._impl._helper import (
BackgroundTaskTracker,
ColorScheme,
DocumentLoadState,
ForcedColors,
Expand Down Expand Up @@ -151,6 +152,7 @@ def __init__(
self._browser_context._timeout_settings
)
self._video: Optional[Video] = None
self._background_task_tracker = BackgroundTaskTracker()
self._opener = cast("Page", from_nullable_channel(initializer.get("opener")))

self._channel.on(
Expand Down Expand Up @@ -192,7 +194,7 @@ def __init__(
)
self._channel.on(
"route",
lambda params: asyncio.create_task(
lambda params: self._background_task_tracker.create_task(
self._on_route(
from_channel(params["route"]), from_channel(params["request"])
)
Expand All @@ -209,11 +211,15 @@ def __init__(
"worker", lambda params: self._on_worker(from_channel(params["worker"]))
)
self._closed_or_crashed_future: asyncio.Future = asyncio.Future()

def _on_close(_: Any) -> None:
self._background_task_tracker.close()
if not self._closed_or_crashed_future.done():
self._closed_or_crashed_future.set_result(True)

self.on(
Page.Events.Close,
lambda _: self._closed_or_crashed_future.set_result(True)
if not self._closed_or_crashed_future.done()
else None,
_on_close,
)
self.on(
Page.Events.Crash,
Expand Down Expand Up @@ -246,7 +252,7 @@ async def _on_route(self, route: Route, request: Request) -> None:
handled = await route_handler.handle(route, request)
finally:
if len(self._routes) == 0:
asyncio.create_task(self._disable_interception())
await self._disable_interception()
if handled:
return
await self._browser_context._on_route(route, request)
Expand Down
17 changes: 17 additions & 0 deletions tests/async/test_browsercontext_request_intercept.py
Expand Up @@ -174,3 +174,20 @@ async def test_should_give_access_to_the_intercepted_response_body(
route.fulfill(response=response),
eval_task,
)


async def test_should_cleanup_route_handlers_after_context_close(
context: BrowserContext, page: Page
) -> None:
async def handle(r: Route):
pass

await context.route("**", handle)
try:
await page.goto("https://example.com", timeout=700)
except Exception:
pass
await context.close()
assert len(asyncio.all_tasks()) == 2
for task in asyncio.all_tasks():
assert "_on_route" not in str(task)
36 changes: 35 additions & 1 deletion tests/async/test_request_intercept.py
Expand Up @@ -17,7 +17,7 @@

from twisted.web import http

from playwright.async_api import Page, Route
from playwright.async_api import BrowserContext, Page, Route
from tests.server import Server


Expand Down Expand Up @@ -168,3 +168,37 @@ async def test_should_give_access_to_the_intercepted_response_body(
route.fulfill(response=response),
eval_task,
)


async def test_should_cleanup_route_handlers_after_page_close(
context: BrowserContext, page: Page
) -> None:
async def handle(r: Route):
pass

await page.route("**", handle)
try:
await page.goto("https://example.com", timeout=700)
except Exception:
pass
await page.close()
assert len(asyncio.all_tasks()) == 2
for task in asyncio.all_tasks():
assert "_on_route" not in str(task)


async def test_should_cleanup_route_handlers_after_context_close(
context: BrowserContext, page: Page
) -> None:
async def handle(r: Route):
pass

await page.route("**", handle)
try:
await page.goto("https://example.com", timeout=700)
except Exception:
pass
await context.close()
assert len(asyncio.all_tasks()) == 2
for task in asyncio.all_tasks():
assert "_on_route" not in str(task)
15 changes: 15 additions & 0 deletions tests/sync/test_browsercontext_request_intercept.py
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
from pathlib import Path

from twisted.web import http
Expand Down Expand Up @@ -121,3 +122,17 @@ def handle_route(route: Route) -> None:
assert request.uri.decode() == "/title.html"
original = (assetdir / "title.html").read_text()
assert response.text() == original


def test_should_cleanup_route_handlers_after_context_close(
context: BrowserContext, page: Page
) -> None:
context.route("**", lambda r: None)
try:
page.goto("https://example.com", timeout=700)
except Exception:
pass
context.close()
assert len(asyncio.all_tasks()) == 1
for task in asyncio.all_tasks():
assert "_on_route" not in str(task)
31 changes: 30 additions & 1 deletion tests/sync/test_request_intercept.py
Expand Up @@ -12,11 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
from pathlib import Path

from twisted.web import http

from playwright.sync_api import Page, Route
from playwright.sync_api import BrowserContext, Page, Route
from tests.server import Server


Expand Down Expand Up @@ -115,3 +116,31 @@ def handle_route(route: Route) -> None:
assert request.uri.decode() == "/title.html"
original = (assetdir / "title.html").read_text()
assert response.text() == original


def test_should_cleanup_route_handlers_after_page_close(
context: BrowserContext, page: Page
) -> None:
page.route("**", lambda r: None)
try:
page.goto("https://example.com", timeout=700)
except Exception:
pass
page.close()
assert len(asyncio.all_tasks()) == 1
for task in asyncio.all_tasks():
assert "_on_route" not in str(task)


def test_should_cleanup_route_handlers_after_context_close(
context: BrowserContext, page: Page
) -> None:
page.route("**", lambda r: None)
try:
page.goto("https://example.com", timeout=700)
except Exception:
pass
context.close()
assert len(asyncio.all_tasks()) == 1
for task in asyncio.all_tasks():
assert "_on_route" not in str(task)

0 comments on commit 71c80af

Please sign in to comment.