From 3d1c00d4f0a7012b24b091851c23f31c41ffd250 Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Wed, 23 Jun 2021 23:17:25 +0100 Subject: [PATCH] ensure TestClient requests run in the same EventLoop as lifespan --- starlette/testclient.py | 87 +++++++++++++++++++++++++--------------- tests/test_testclient.py | 81 +++++++++++++++++++++++++++++++++++-- 2 files changed, 131 insertions(+), 37 deletions(-) diff --git a/starlette/testclient.py b/starlette/testclient.py index 21f4dba90d..d58c0bc8b4 100644 --- a/starlette/testclient.py +++ b/starlette/testclient.py @@ -12,7 +12,7 @@ from concurrent.futures import Future from urllib.parse import unquote, urljoin, urlsplit -import anyio +import anyio.abc import requests from anyio.streams.stapled import StapledObjectStream @@ -20,9 +20,15 @@ from starlette.websockets import WebSocketDisconnect if sys.version_info >= (3, 8): # pragma: no cover - from typing import TypedDict # pragma: no cover + from typing import Protocol, TypedDict # pragma: no cover else: # pragma: no cover - from typing_extensions import TypedDict # pragma: no cover + from typing_extensions import Protocol, TypedDict # pragma: no cover + + +class _PortalFactoryType(Protocol): + def __call__(self) -> typing.ContextManager[anyio.abc.BlockingPortal]: + ... # pragma: no cover + # Annotations for `Session.request()` Cookies = typing.Union[ @@ -106,14 +112,14 @@ class _ASGIAdapter(requests.adapters.HTTPAdapter): def __init__( self, app: ASGI3App, - async_backend: _AsyncBackend, + portal_factory: _PortalFactoryType, raise_server_exceptions: bool = True, root_path: str = "", ) -> None: self.app = app self.raise_server_exceptions = raise_server_exceptions self.root_path = root_path - self.async_backend = async_backend + self.portal_factory = portal_factory def send( self, request: requests.PreparedRequest, *args: typing.Any, **kwargs: typing.Any @@ -162,7 +168,7 @@ def send( "server": [host, port], "subprotocols": subprotocols, } - session = WebSocketTestSession(self.app, scope, self.async_backend) + session = WebSocketTestSession(self.app, scope, self.portal_factory) raise _Upgrade(session) scope = { @@ -252,7 +258,7 @@ async def send(message: Message) -> None: context = message["context"] try: - with anyio.start_blocking_portal(**self.async_backend) as portal: + with self.portal_factory() as portal: response_complete = portal.call(anyio.Event) portal.call(self.app, scope, receive, send) except BaseException as exc: @@ -285,20 +291,18 @@ def __init__( self, app: ASGI3App, scope: Scope, - async_backend: _AsyncBackend, + portal_factory: _PortalFactoryType, ) -> None: self.app = app self.scope = scope self.accepted_subprotocol = None - self.async_backend = async_backend + self.portal_factory = portal_factory self._receive_queue: "queue.Queue[typing.Any]" = queue.Queue() self._send_queue: "queue.Queue[typing.Any]" = queue.Queue() def __enter__(self) -> "WebSocketTestSession": self.exit_stack = contextlib.ExitStack() - self.portal = self.exit_stack.enter_context( - anyio.start_blocking_portal(**self.async_backend) - ) + self.portal = self.exit_stack.enter_context(self.portal_factory()) try: _: "Future[None]" = self.portal.start_task_soon(self._run) @@ -396,6 +400,7 @@ def receive_json(self, mode: str = "text") -> typing.Any: class TestClient(requests.Session): __test__ = False # For pytest to not discover this up. task: "Future[None]" + portal: typing.Optional[anyio.abc.BlockingPortal] = None def __init__( self, @@ -418,7 +423,7 @@ def __init__( asgi_app = _WrapASGI2(app) #  type: ignore adapter = _ASGIAdapter( asgi_app, - self.async_backend, + portal_factory=self._portal_factory, raise_server_exceptions=raise_server_exceptions, root_path=root_path, ) @@ -430,6 +435,16 @@ def __init__( self.app = asgi_app self.base_url = base_url + @contextlib.contextmanager + def _portal_factory( + self, + ) -> typing.Generator[anyio.abc.BlockingPortal, None, None]: + if self.portal is not None: + yield self.portal + else: + with anyio.start_blocking_portal(**self.async_backend) as portal: + yield portal + def request( # type: ignore self, method: str, @@ -490,29 +505,35 @@ def websocket_connect( return session def __enter__(self) -> "TestClient": - self.exit_stack = contextlib.ExitStack() - self.portal = self.exit_stack.enter_context( - anyio.start_blocking_portal(**self.async_backend) - ) - self.stream_send = StapledObjectStream( - *anyio.create_memory_object_stream(math.inf) - ) - self.stream_receive = StapledObjectStream( - *anyio.create_memory_object_stream(math.inf) - ) - try: - self.task = self.portal.start_task_soon(self.lifespan) - self.portal.call(self.wait_startup) - except Exception: - self.exit_stack.close() - raise + with contextlib.ExitStack() as stack: + self.portal = portal = stack.enter_context( + anyio.start_blocking_portal(**self.async_backend) + ) + + self.stream_send = StapledObjectStream( + *anyio.create_memory_object_stream(math.inf) + ) + self.stream_receive = StapledObjectStream( + *anyio.create_memory_object_stream(math.inf) + ) + self.task = portal.start_task_soon(self.lifespan) + portal.call(self.wait_startup) + self.portal = portal + + @stack.callback + def reset_portal() -> None: + self.portal = None + + @stack.callback + def wait_shutdown() -> None: + portal.call(self.wait_shutdown) + + self.exit_stack = stack.pop_all() + return self def __exit__(self, *args: typing.Any) -> None: - try: - self.portal.call(self.wait_shutdown) - finally: - self.exit_stack.close() + self.exit_stack.close() async def lifespan(self) -> None: scope = {"type": "lifespan"} diff --git a/tests/test_testclient.py b/tests/test_testclient.py index fd96f69a7e..9c0091b073 100644 --- a/tests/test_testclient.py +++ b/tests/test_testclient.py @@ -1,3 +1,5 @@ +import itertools + import anyio import pytest @@ -14,15 +16,41 @@ def mock_service_endpoint(request): return JSONResponse({"mock": "example"}) -def create_app(test_client_factory): +_identity_runvar: anyio.lowlevel.RunVar[int] = anyio.lowlevel.RunVar("_identity_runvar") + + +def get_identity(counter): + try: + return _identity_runvar.get() + except LookupError: + token = next(counter) + _identity_runvar.set(token) + return token + + +def create_app(test_client_factory, counter=itertools.count()): app = Starlette() + @app.on_event("startup") + async def get_startup_thread(): + app.startup_task = anyio.get_current_task().id + app.startup_loop = get_identity(counter) + + @app.on_event("shutdown") + async def get_shutdown_thread(): + app.shutdown_task = anyio.get_current_task().id + app.shutdown_loop = get_identity(counter) + @app.route("/") def homepage(request): client = test_client_factory(mock_service) response = client.get("/") return JSONResponse(response.json()) + @app.route("/thread") + async def thread(request): + return JSONResponse(get_identity(counter)) + return app @@ -46,9 +74,54 @@ def test_use_testclient_in_endpoint(test_client_factory): assert response.json() == {"mock": "example"} -def test_use_testclient_as_contextmanager(test_client_factory): - with test_client_factory(create_app(test_client_factory)): - pass +def unique(*items): + return len(set(items)) == len(items) + + +def test_use_testclient_as_contextmanager(test_client_factory, anyio_backend_name): + """ + This test asserts a number of properties that are important for an + app level task_group + """ + app = create_app(test_client_factory, counter=itertools.count()) + client = test_client_factory(app) + + with client: + # within a TestClient context every async request runs in the same thread + assert client.get("/thread").json() == 0 + assert client.get("/thread").json() == 0 + + # that thread is also the same as the lifespan thread + assert app.startup_loop == 0 + assert app.shutdown_loop == 0 + + # lifespan events run in the same task, this is important because a task + # group must be entered and exited in the same task. + assert app.startup_task == app.shutdown_task + + # outside the TestClient context, new requests continue to spawn in new + # eventloops in new threads + assert client.get("/thread").json() == 1 + assert client.get("/thread").json() == 2 + + first_task = app.startup_task + + with client: + # the TestClient context can be re-used, starting a new lifespan task + # in a new thread + assert client.get("/thread").json() == 3 + assert client.get("/thread").json() == 3 + + assert app.startup_loop == 3 + assert app.shutdown_loop == 3 + + # lifespan events still run in the same task, with the context but... + assert app.startup_task == app.shutdown_task + + if anyio_backend_name != "asyncio": + # https://github.com/agronholm/anyio/issues/324 + # ... the second TestClient context creates a new lifespan task. + assert first_task != app.startup_task def test_error_on_startup(test_client_factory):