Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: cleanup pending route handlers on close #1412

Merged
merged 10 commits into from Jul 7, 2022
17 changes: 14 additions & 3 deletions playwright/_impl/_browser_context.py
Expand Up @@ -48,6 +48,7 @@
from playwright._impl._frame import Frame
from playwright._impl._har_router import HarRouter
from playwright._impl._helper import (
BackgroundTaskTracker,
HarRecordingMetadata,
RouteFromHarNotFoundPolicy,
RouteHandler,
Expand Down Expand Up @@ -103,6 +104,7 @@ def __init__(
self._request: APIRequestContext = from_channel(
initializer["APIRequestContext"]
)
self._background_task_tracker: BackgroundTaskTracker = BackgroundTaskTracker()
self._channel.on(
"bindingCall",
lambda params: self._on_binding(from_channel(params["binding"])),
Expand All @@ -113,7 +115,7 @@ def __init__(
)
self._channel.on(
"route",
lambda params: asyncio.create_task(
lambda params: self._background_task_tracker.create_task(
self._on_route(
from_channel(params.get("route")),
from_channel(params.get("request")),
Expand Down Expand Up @@ -163,8 +165,14 @@ def __init__(
),
)
self._closed_future: asyncio.Future = asyncio.Future()

def _on_close(_: Any) -> None:
self._background_task_tracker.close()
self._closed_future.set_result(True)

self.once(
self.Events.Close, lambda context: self._closed_future.set_result(True)
self.Events.Close,
_on_close,
)

def __repr__(self) -> str:
Expand All @@ -187,7 +195,10 @@ async def _on_route(self, route: Route, request: Request) -> None:
handled = await route_handler.handle(route, request)
finally:
if len(self._routes) == 0:
asyncio.create_task(self._disable_interception())
try:
await self._disable_interception()
except Exception:
pass
if handled:
return
await route._internal_continue(is_internal=True)
Expand Down
19 changes: 19 additions & 0 deletions playwright/_impl/_helper.py
Expand Up @@ -362,3 +362,22 @@ def is_file_payload(value: Optional[Any]) -> bool:
and "mimeType" in value
and "buffer" in value
)


class BackgroundTaskTracker:
def __init__(self) -> None:
self._pending_tasks: List[asyncio.Task] = []

def create_task(self, coro: Coroutine) -> asyncio.Task:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of creating the task here, I'd create it externally and pass it in.

Copy link
Member Author

@rwoll rwoll Jul 7, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the advantage of creating it externally? We need to tightly control background tasks, so it seems safest to encapsulate all that's needed to do so like we have done here.

i.e., don't use asyncio.create_task — this will create leaks. Use BackgroundTaskTracker.create_task and we'll handle everything for you (like adding a done callback, adding the task to the pending tasks, and eventually cancelling it).

What breaks us is calling asyncio.create_task and then not tracking the resulting task, so we replace the call you use to create the task and ensure every task is registered and cleaned up properly.

task = asyncio.create_task(coro)
task.add_done_callback(lambda task: self._pending_tasks.remove(task))
self._pending_tasks.append(task)
return task

def close(self) -> None:
    """Cancel every tracked task that has not finished yet.

    Best effort: any exception raised while walking the pending tasks is
    swallowed, which also stops the walk at that point.
    """
    try:
        # Lazy filter: each task's done() is checked right before its cancel().
        unfinished = (
            pending for pending in self._pending_tasks if not pending.done()
        )
        for pending in unfinished:
            pending.cancel()
    except Exception:
        pass
3 changes: 2 additions & 1 deletion playwright/_impl/_network.py
Expand Up @@ -223,7 +223,8 @@ def _report_handled(self, done: bool) -> None:
chain = self._handling_future
assert chain
self._handling_future = None
chain.set_result(done)
if not chain.done():
rwoll marked this conversation as resolved.
Show resolved Hide resolved
chain.set_result(done)

def _check_not_handled(self) -> None:
if not self._handling_future:
Expand Down
9 changes: 7 additions & 2 deletions playwright/_impl/_page.py
Expand Up @@ -55,6 +55,7 @@
from playwright._impl._frame import Frame
from playwright._impl._har_router import HarRouter
from playwright._impl._helper import (
BackgroundTaskTracker,
ColorScheme,
DocumentLoadState,
ForcedColors,
Expand Down Expand Up @@ -151,6 +152,7 @@ def __init__(
self._browser_context._timeout_settings
)
self._video: Optional[Video] = None
self._background_task_tracker = BackgroundTaskTracker()
self._opener = cast("Page", from_nullable_channel(initializer.get("opener")))

self._channel.on(
Expand Down Expand Up @@ -192,7 +194,7 @@ def __init__(
)
self._channel.on(
"route",
lambda params: asyncio.create_task(
lambda params: self._browser_context._background_task_tracker.create_task(
self._on_route(
from_channel(params["route"]), from_channel(params["request"])
)
Expand Down Expand Up @@ -246,7 +248,10 @@ async def _on_route(self, route: Route, request: Request) -> None:
handled = await route_handler.handle(route, request)
finally:
if len(self._routes) == 0:
asyncio.create_task(self._disable_interception())
try:
await self._disable_interception()
except Exception:
pass
if handled:
return
await self._browser_context._on_route(route, request)
Expand Down
17 changes: 17 additions & 0 deletions tests/async/test_browsercontext_request_intercept.py
Expand Up @@ -174,3 +174,20 @@ async def test_should_give_access_to_the_intercepted_response_body(
route.fulfill(response=response),
eval_task,
)


async def test_should_cleanup_route_handlers_after_context_close(
    context: BrowserContext, page: Page
) -> None:
    """After the context is closed, no asyncio task may mention ``_on_route``."""

    async def noop_handler(route: Route):
        # Intentionally does nothing with the route.
        pass

    await context.route("**", noop_handler)
    try:
        await page.goto("https://example.com", timeout=700)
    except Exception:
        # Any navigation failure is ignored; only leftover tasks are checked.
        pass
    await context.close()

    leftover = [task for task in asyncio.all_tasks() if "_on_route" in str(task)]
    assert not leftover
7 changes: 7 additions & 0 deletions tests/async/test_har.py
Expand Up @@ -503,6 +503,7 @@ async def test_should_round_trip_har_zip(
await expect(page_2.locator("body")).to_have_css(
"background-color", "rgb(255, 192, 203)"
)
await context_2.close()


async def test_should_round_trip_har_with_post_data(
Expand Down Expand Up @@ -536,6 +537,7 @@ async def test_should_round_trip_har_with_post_data(
assert await page_2.evaluate(fetch_function, "3") == "3"
with pytest.raises(Exception):
await page_2.evaluate(fetch_function, "4")
await context_2.close()


async def test_should_disambiguate_by_header(
Expand Down Expand Up @@ -578,6 +580,7 @@ async def test_should_disambiguate_by_header(
assert await page_2.evaluate(fetch_function, "baz2") == "baz2"
assert await page_2.evaluate(fetch_function, "baz3") == "baz3"
assert await page_2.evaluate(fetch_function, "baz4") == "baz1"
await context_2.close()


async def test_should_produce_extracted_zip(
Expand Down Expand Up @@ -605,6 +608,7 @@ async def test_should_produce_extracted_zip(
await expect(page_2.locator("body")).to_have_css(
"background-color", "rgb(255, 192, 203)"
)
await context_2.close()


async def test_should_update_har_zip_for_context(
Expand All @@ -627,6 +631,7 @@ async def test_should_update_har_zip_for_context(
await expect(page_2.locator("body")).to_have_css(
"background-color", "rgb(255, 192, 203)"
)
await context_2.close()


async def test_should_update_har_zip_for_page(
Expand All @@ -649,6 +654,7 @@ async def test_should_update_har_zip_for_page(
await expect(page_2.locator("body")).to_have_css(
"background-color", "rgb(255, 192, 203)"
)
await context_2.close()


async def test_should_update_extracted_har_zip_for_page(
Expand All @@ -675,3 +681,4 @@ async def test_should_update_extracted_har_zip_for_page(
await expect(page_2.locator("body")).to_have_css(
"background-color", "rgb(255, 192, 203)"
)
await context_2.close()
19 changes: 18 additions & 1 deletion tests/async/test_request_intercept.py
Expand Up @@ -17,7 +17,7 @@

from twisted.web import http

from playwright.async_api import Page, Route
from playwright.async_api import BrowserContext, Page, Route
from tests.server import Server


Expand Down Expand Up @@ -168,3 +168,20 @@ async def test_should_give_access_to_the_intercepted_response_body(
route.fulfill(response=response),
eval_task,
)


async def test_should_cleanup_route_handlers_after_context_close(
    context: BrowserContext, page: Page
) -> None:
    """After the context is closed, no asyncio task may mention ``_on_route``."""

    async def noop_handler(route: Route):
        # Intentionally does nothing with the route.
        pass

    await page.route("**", noop_handler)
    try:
        await page.goto("https://example.com", timeout=700)
    except Exception:
        # Any navigation failure is ignored; only leftover tasks are checked.
        pass
    await context.close()

    leftover = [task for task in asyncio.all_tasks() if "_on_route" in str(task)]
    assert not leftover
15 changes: 15 additions & 0 deletions tests/sync/test_browsercontext_request_intercept.py
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
from pathlib import Path

from twisted.web import http
Expand Down Expand Up @@ -121,3 +122,17 @@ def handle_route(route: Route) -> None:
assert request.uri.decode() == "/title.html"
original = (assetdir / "title.html").read_text()
assert response.text() == original


def test_should_cleanup_route_handlers_after_context_close(
    context: BrowserContext, page: Page
) -> None:
    """After the context is closed, no asyncio task may mention ``_on_route``."""
    # The handler intentionally does nothing with the route.
    context.route("**", lambda _route: None)
    try:
        page.goto("https://example.com", timeout=700)
    except Exception:
        # Any navigation failure is ignored; only leftover tasks are checked.
        pass
    context.close()

    leftover = [task for task in asyncio.all_tasks() if "_on_route" in str(task)]
    assert not leftover
6 changes: 6 additions & 0 deletions tests/sync/test_har.py
Expand Up @@ -471,6 +471,7 @@ def test_should_round_trip_har_with_post_data(
assert page_2.evaluate(fetch_function, "3") == "3"
with pytest.raises(Exception):
page_2.evaluate(fetch_function, "4")
context_2.close()


def test_should_disambiguate_by_header(
Expand Down Expand Up @@ -512,6 +513,7 @@ def test_should_disambiguate_by_header(
assert page_2.evaluate(fetch_function, "baz2") == "baz2"
assert page_2.evaluate(fetch_function, "baz3") == "baz3"
assert page_2.evaluate(fetch_function, "baz4") == "baz1"
context_2.close()


def test_should_produce_extracted_zip(
Expand All @@ -537,6 +539,7 @@ def test_should_produce_extracted_zip(
page_2.goto(server.PREFIX + "/one-style.html")
assert "hello, world!" in page_2.content()
expect(page_2.locator("body")).to_have_css("background-color", "rgb(255, 192, 203)")
context_2.close()


def test_should_update_har_zip_for_context(
Expand All @@ -557,6 +560,7 @@ def test_should_update_har_zip_for_context(
page_2.goto(server.PREFIX + "/one-style.html")
assert "hello, world!" in page_2.content()
expect(page_2.locator("body")).to_have_css("background-color", "rgb(255, 192, 203)")
context_2.close()


def test_should_update_har_zip_for_page(
Expand All @@ -577,6 +581,7 @@ def test_should_update_har_zip_for_page(
page_2.goto(server.PREFIX + "/one-style.html")
assert "hello, world!" in page_2.content()
expect(page_2.locator("body")).to_have_css("background-color", "rgb(255, 192, 203)")
context_2.close()


def test_should_update_extracted_har_zip_for_page(
Expand All @@ -601,3 +606,4 @@ def test_should_update_extracted_har_zip_for_page(
page_2.goto(server.PREFIX + "/one-style.html")
assert "hello, world!" in page_2.content()
expect(page_2.locator("body")).to_have_css("background-color", "rgb(255, 192, 203)")
context_2.close()
17 changes: 16 additions & 1 deletion tests/sync/test_request_intercept.py
Expand Up @@ -12,11 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
from pathlib import Path

from twisted.web import http

from playwright.sync_api import Page, Route
from playwright.sync_api import BrowserContext, Page, Route
from tests.server import Server


Expand Down Expand Up @@ -115,3 +116,17 @@ def handle_route(route: Route) -> None:
assert request.uri.decode() == "/title.html"
original = (assetdir / "title.html").read_text()
assert response.text() == original


def test_should_cleanup_route_handlers_after_context_close(
    context: BrowserContext, page: Page
) -> None:
    """After the context is closed, no asyncio task may mention ``_on_route``."""
    # The handler intentionally does nothing with the route.
    page.route("**", lambda _route: None)
    try:
        page.goto("https://example.com", timeout=700)
    except Exception:
        # Any navigation failure is ignored; only leftover tasks are checked.
        pass
    context.close()

    leftover = [task for task in asyncio.all_tasks() if "_on_route" in str(task)]
    assert not leftover