Skip to content

Commit

Permalink
fix: cleanup pending route handlers on close (#1412)
Browse files Browse the repository at this point in the history
Fixes #1402.
  • Loading branch information
rwoll committed Jul 7, 2022
1 parent 8820f30 commit c8d8f4a
Show file tree
Hide file tree
Showing 10 changed files with 118 additions and 8 deletions.
17 changes: 14 additions & 3 deletions playwright/_impl/_browser_context.py
Expand Up @@ -48,6 +48,7 @@
from playwright._impl._frame import Frame
from playwright._impl._har_router import HarRouter
from playwright._impl._helper import (
BackgroundTaskTracker,
HarRecordingMetadata,
RouteFromHarNotFoundPolicy,
RouteHandler,
Expand Down Expand Up @@ -103,6 +104,7 @@ def __init__(
self._request: APIRequestContext = from_channel(
initializer["APIRequestContext"]
)
self._background_task_tracker: BackgroundTaskTracker = BackgroundTaskTracker()
self._channel.on(
"bindingCall",
lambda params: self._on_binding(from_channel(params["binding"])),
Expand All @@ -113,7 +115,7 @@ def __init__(
)
self._channel.on(
"route",
lambda params: asyncio.create_task(
lambda params: self._background_task_tracker.create_task(
self._on_route(
from_channel(params.get("route")),
from_channel(params.get("request")),
Expand Down Expand Up @@ -163,8 +165,14 @@ def __init__(
),
)
self._closed_future: asyncio.Future = asyncio.Future()

def _on_close(_: Any) -> None:
self._background_task_tracker.close()
self._closed_future.set_result(True)

self.once(
self.Events.Close, lambda context: self._closed_future.set_result(True)
self.Events.Close,
_on_close,
)

def __repr__(self) -> str:
Expand All @@ -187,7 +195,10 @@ async def _on_route(self, route: Route, request: Request) -> None:
handled = await route_handler.handle(route, request)
finally:
if len(self._routes) == 0:
asyncio.create_task(self._disable_interception())
try:
await self._disable_interception()
except Exception:
pass
if handled:
return
await route._internal_continue(is_internal=True)
Expand Down
16 changes: 16 additions & 0 deletions playwright/_impl/_helper.py
Expand Up @@ -362,3 +362,19 @@ def is_file_payload(value: Optional[Any]) -> bool:
and "mimeType" in value
and "buffer" in value
)


class BackgroundTaskTracker:
def __init__(self) -> None:
self._pending_tasks: List[asyncio.Task] = []

def create_task(self, coro: Coroutine) -> asyncio.Task:
task = asyncio.create_task(coro)
task.add_done_callback(lambda task: self._pending_tasks.remove(task))
self._pending_tasks.append(task)
return task

def close(self) -> None:
for task in self._pending_tasks:
if not task.done():
task.cancel()
3 changes: 2 additions & 1 deletion playwright/_impl/_network.py
Expand Up @@ -223,7 +223,8 @@ def _report_handled(self, done: bool) -> None:
chain = self._handling_future
assert chain
self._handling_future = None
chain.set_result(done)
if not chain.done():
chain.set_result(done)

def _check_not_handled(self) -> None:
if not self._handling_future:
Expand Down
9 changes: 7 additions & 2 deletions playwright/_impl/_page.py
Expand Up @@ -55,6 +55,7 @@
from playwright._impl._frame import Frame
from playwright._impl._har_router import HarRouter
from playwright._impl._helper import (
BackgroundTaskTracker,
ColorScheme,
DocumentLoadState,
ForcedColors,
Expand Down Expand Up @@ -151,6 +152,7 @@ def __init__(
self._browser_context._timeout_settings
)
self._video: Optional[Video] = None
self._background_task_tracker = BackgroundTaskTracker()
self._opener = cast("Page", from_nullable_channel(initializer.get("opener")))

self._channel.on(
Expand Down Expand Up @@ -192,7 +194,7 @@ def __init__(
)
self._channel.on(
"route",
lambda params: asyncio.create_task(
lambda params: self._browser_context._background_task_tracker.create_task(
self._on_route(
from_channel(params["route"]), from_channel(params["request"])
)
Expand Down Expand Up @@ -246,7 +248,10 @@ async def _on_route(self, route: Route, request: Request) -> None:
handled = await route_handler.handle(route, request)
finally:
if len(self._routes) == 0:
asyncio.create_task(self._disable_interception())
try:
await self._disable_interception()
except Exception:
pass
if handled:
return
await self._browser_context._on_route(route, request)
Expand Down
17 changes: 17 additions & 0 deletions tests/async/test_browsercontext_request_intercept.py
Expand Up @@ -174,3 +174,20 @@ async def test_should_give_access_to_the_intercepted_response_body(
route.fulfill(response=response),
eval_task,
)


async def test_should_cleanup_route_handlers_after_context_close(
context: BrowserContext, page: Page
) -> None:
async def handle(r: Route):
pass

await context.route("**", handle)
try:
await page.goto("https://example.com", timeout=700)
except Exception:
pass
await context.close()

for task in asyncio.all_tasks():
assert "_on_route" not in str(task)
7 changes: 7 additions & 0 deletions tests/async/test_har.py
Expand Up @@ -503,6 +503,7 @@ async def test_should_round_trip_har_zip(
await expect(page_2.locator("body")).to_have_css(
"background-color", "rgb(255, 192, 203)"
)
await context_2.close()


async def test_should_round_trip_har_with_post_data(
Expand Down Expand Up @@ -536,6 +537,7 @@ async def test_should_round_trip_har_with_post_data(
assert await page_2.evaluate(fetch_function, "3") == "3"
with pytest.raises(Exception):
await page_2.evaluate(fetch_function, "4")
await context_2.close()


async def test_should_disambiguate_by_header(
Expand Down Expand Up @@ -578,6 +580,7 @@ async def test_should_disambiguate_by_header(
assert await page_2.evaluate(fetch_function, "baz2") == "baz2"
assert await page_2.evaluate(fetch_function, "baz3") == "baz3"
assert await page_2.evaluate(fetch_function, "baz4") == "baz1"
await context_2.close()


async def test_should_produce_extracted_zip(
Expand Down Expand Up @@ -605,6 +608,7 @@ async def test_should_produce_extracted_zip(
await expect(page_2.locator("body")).to_have_css(
"background-color", "rgb(255, 192, 203)"
)
await context_2.close()


async def test_should_update_har_zip_for_context(
Expand All @@ -627,6 +631,7 @@ async def test_should_update_har_zip_for_context(
await expect(page_2.locator("body")).to_have_css(
"background-color", "rgb(255, 192, 203)"
)
await context_2.close()


async def test_should_update_har_zip_for_page(
Expand All @@ -649,6 +654,7 @@ async def test_should_update_har_zip_for_page(
await expect(page_2.locator("body")).to_have_css(
"background-color", "rgb(255, 192, 203)"
)
await context_2.close()


async def test_should_update_extracted_har_zip_for_page(
Expand All @@ -675,3 +681,4 @@ async def test_should_update_extracted_har_zip_for_page(
await expect(page_2.locator("body")).to_have_css(
"background-color", "rgb(255, 192, 203)"
)
await context_2.close()
19 changes: 18 additions & 1 deletion tests/async/test_request_intercept.py
Expand Up @@ -17,7 +17,7 @@

from twisted.web import http

from playwright.async_api import Page, Route
from playwright.async_api import BrowserContext, Page, Route
from tests.server import Server


Expand Down Expand Up @@ -168,3 +168,20 @@ async def test_should_give_access_to_the_intercepted_response_body(
route.fulfill(response=response),
eval_task,
)


async def test_should_cleanup_route_handlers_after_context_close(
context: BrowserContext, page: Page
) -> None:
async def handle(r: Route):
pass

await page.route("**", handle)
try:
await page.goto("https://example.com", timeout=700)
except Exception:
pass
await context.close()

for task in asyncio.all_tasks():
assert "_on_route" not in str(task)
15 changes: 15 additions & 0 deletions tests/sync/test_browsercontext_request_intercept.py
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
from pathlib import Path

from twisted.web import http
Expand Down Expand Up @@ -121,3 +122,17 @@ def handle_route(route: Route) -> None:
assert request.uri.decode() == "/title.html"
original = (assetdir / "title.html").read_text()
assert response.text() == original


def test_should_cleanup_route_handlers_after_context_close(
context: BrowserContext, page: Page
) -> None:
context.route("**", lambda r: None)
try:
page.goto("https://example.com", timeout=700)
except Exception:
pass
context.close()

for task in asyncio.all_tasks():
assert "_on_route" not in str(task)
6 changes: 6 additions & 0 deletions tests/sync/test_har.py
Expand Up @@ -471,6 +471,7 @@ def test_should_round_trip_har_with_post_data(
assert page_2.evaluate(fetch_function, "3") == "3"
with pytest.raises(Exception):
page_2.evaluate(fetch_function, "4")
context_2.close()


def test_should_disambiguate_by_header(
Expand Down Expand Up @@ -512,6 +513,7 @@ def test_should_disambiguate_by_header(
assert page_2.evaluate(fetch_function, "baz2") == "baz2"
assert page_2.evaluate(fetch_function, "baz3") == "baz3"
assert page_2.evaluate(fetch_function, "baz4") == "baz1"
context_2.close()


def test_should_produce_extracted_zip(
Expand All @@ -537,6 +539,7 @@ def test_should_produce_extracted_zip(
page_2.goto(server.PREFIX + "/one-style.html")
assert "hello, world!" in page_2.content()
expect(page_2.locator("body")).to_have_css("background-color", "rgb(255, 192, 203)")
context_2.close()


def test_should_update_har_zip_for_context(
Expand All @@ -557,6 +560,7 @@ def test_should_update_har_zip_for_context(
page_2.goto(server.PREFIX + "/one-style.html")
assert "hello, world!" in page_2.content()
expect(page_2.locator("body")).to_have_css("background-color", "rgb(255, 192, 203)")
context_2.close()


def test_should_update_har_zip_for_page(
Expand All @@ -577,6 +581,7 @@ def test_should_update_har_zip_for_page(
page_2.goto(server.PREFIX + "/one-style.html")
assert "hello, world!" in page_2.content()
expect(page_2.locator("body")).to_have_css("background-color", "rgb(255, 192, 203)")
context_2.close()


def test_should_update_extracted_har_zip_for_page(
Expand All @@ -601,3 +606,4 @@ def test_should_update_extracted_har_zip_for_page(
page_2.goto(server.PREFIX + "/one-style.html")
assert "hello, world!" in page_2.content()
expect(page_2.locator("body")).to_have_css("background-color", "rgb(255, 192, 203)")
context_2.close()
17 changes: 16 additions & 1 deletion tests/sync/test_request_intercept.py
Expand Up @@ -12,11 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
from pathlib import Path

from twisted.web import http

from playwright.sync_api import Page, Route
from playwright.sync_api import BrowserContext, Page, Route
from tests.server import Server


Expand Down Expand Up @@ -115,3 +116,17 @@ def handle_route(route: Route) -> None:
assert request.uri.decode() == "/title.html"
original = (assetdir / "title.html").read_text()
assert response.text() == original


def test_should_cleanup_route_handlers_after_context_close(
context: BrowserContext, page: Page
) -> None:
page.route("**", lambda r: None)
try:
page.goto("https://example.com", timeout=700)
except Exception:
pass
context.close()

for task in asyncio.all_tasks():
assert "_on_route" not in str(task)

0 comments on commit c8d8f4a

Please sign in to comment.