diff --git a/starlette/testclient.py b/starlette/testclient.py index 33bb410d0..7aa59fb9e 100644 --- a/starlette/testclient.py +++ b/starlette/testclient.py @@ -12,7 +12,7 @@ from concurrent.futures import Future from urllib.parse import unquote, urljoin, urlsplit -import anyio +import anyio.abc import requests from anyio.streams.stapled import StapledObjectStream @@ -24,6 +24,12 @@ else: # pragma: no cover from typing_extensions import TypedDict + +_PortalFactoryType = typing.Callable[ + [], typing.ContextManager[anyio.abc.BlockingPortal] +] + + # Annotations for `Session.request()` Cookies = typing.Union[ typing.MutableMapping[str, str], requests.cookies.RequestsCookieJar @@ -106,14 +112,14 @@ class _ASGIAdapter(requests.adapters.HTTPAdapter): def __init__( self, app: ASGI3App, - async_backend: _AsyncBackend, + portal_factory: _PortalFactoryType, raise_server_exceptions: bool = True, root_path: str = "", ) -> None: self.app = app self.raise_server_exceptions = raise_server_exceptions self.root_path = root_path - self.async_backend = async_backend + self.portal_factory = portal_factory def send( self, request: requests.PreparedRequest, *args: typing.Any, **kwargs: typing.Any @@ -162,7 +168,7 @@ def send( "server": [host, port], "subprotocols": subprotocols, } - session = WebSocketTestSession(self.app, scope, self.async_backend) + session = WebSocketTestSession(self.app, scope, self.portal_factory) raise _Upgrade(session) scope = { @@ -252,7 +258,7 @@ async def send(message: Message) -> None: context = message["context"] try: - with anyio.start_blocking_portal(**self.async_backend) as portal: + with self.portal_factory() as portal: response_complete = portal.call(anyio.Event) portal.call(self.app, scope, receive, send) except BaseException as exc: @@ -285,20 +291,18 @@ def __init__( self, app: ASGI3App, scope: Scope, - async_backend: _AsyncBackend, + portal_factory: _PortalFactoryType, ) -> None: self.app = app self.scope = scope self.accepted_subprotocol = None - self.async_backend = async_backend + self.portal_factory = portal_factory self._receive_queue: "queue.Queue[typing.Any]" = queue.Queue() self._send_queue: "queue.Queue[typing.Any]" = queue.Queue() def __enter__(self) -> "WebSocketTestSession": self.exit_stack = contextlib.ExitStack() - self.portal = self.exit_stack.enter_context( - anyio.start_blocking_portal(**self.async_backend) - ) + self.portal = self.exit_stack.enter_context(self.portal_factory()) try: _: "Future[None]" = self.portal.start_task_soon(self._run) @@ -396,6 +400,7 @@ def receive_json(self, mode: str = "text") -> typing.Any: class TestClient(requests.Session): __test__ = False # For pytest to not discover this up. task: "Future[None]" + portal: typing.Optional[anyio.abc.BlockingPortal] = None def __init__( self, @@ -418,7 +423,7 @@ def __init__( asgi_app = _WrapASGI2(app) #  type: ignore adapter = _ASGIAdapter( asgi_app, - self.async_backend, + portal_factory=self._portal_factory, raise_server_exceptions=raise_server_exceptions, root_path=root_path, ) @@ -430,6 +435,16 @@ def __init__( self.app = asgi_app self.base_url = base_url + @contextlib.contextmanager + def _portal_factory( + self, + ) -> typing.Generator[anyio.abc.BlockingPortal, None, None]: + if self.portal is not None: + yield self.portal + else: + with anyio.start_blocking_portal(**self.async_backend) as portal: + yield portal + def request( # type: ignore self, method: str, @@ -490,29 +505,34 @@ def websocket_connect( return session def __enter__(self) -> "TestClient": - self.exit_stack = contextlib.ExitStack() - self.portal = self.exit_stack.enter_context( - anyio.start_blocking_portal(**self.async_backend) - ) - self.stream_send = StapledObjectStream( - *anyio.create_memory_object_stream(math.inf) - ) - self.stream_receive = StapledObjectStream( - *anyio.create_memory_object_stream(math.inf) - ) - try: - self.task = self.portal.start_task_soon(self.lifespan) - self.portal.call(self.wait_startup) - except Exception: - self.exit_stack.close() - raise + with contextlib.ExitStack() as stack: + self.portal = portal = stack.enter_context( + anyio.start_blocking_portal(**self.async_backend) + ) + + @stack.callback + def reset_portal() -> None: + self.portal = None + + self.stream_send = StapledObjectStream( + *anyio.create_memory_object_stream(math.inf) + ) + self.stream_receive = StapledObjectStream( + *anyio.create_memory_object_stream(math.inf) + ) + self.task = portal.start_task_soon(self.lifespan) + portal.call(self.wait_startup) + + @stack.callback + def wait_shutdown() -> None: + portal.call(self.wait_shutdown) + + self.exit_stack = stack.pop_all() + return self def __exit__(self, *args: typing.Any) -> None: - try: - self.portal.call(self.wait_shutdown) - finally: - self.exit_stack.close() + self.exit_stack.close() async def lifespan(self) -> None: scope = {"type": "lifespan"} diff --git a/tests/test_testclient.py b/tests/test_testclient.py index fd96f69a7..57ea1c3db 100644 --- a/tests/test_testclient.py +++ b/tests/test_testclient.py @@ -1,11 +1,22 @@ +import asyncio +import itertools +import sys + import anyio import pytest +import sniffio +import trio.lowlevel from starlette.applications import Starlette from starlette.middleware import Middleware from starlette.responses import JSONResponse from starlette.websockets import WebSocket, WebSocketDisconnect +if sys.version_info >= (3, 7): + from asyncio import current_task as asyncio_current_task # pragma: no cover +else: + asyncio_current_task = asyncio.Task.current_task # pragma: no cover + mock_service = Starlette() @@ -14,16 +25,19 @@ def mock_service_endpoint(request): return JSONResponse({"mock": "example"}) -def create_app(test_client_factory): - app = Starlette() - - @app.route("/") - def homepage(request): - client = test_client_factory(mock_service) - response = client.get("/") - return JSONResponse(response.json()) +def current_task(): + # anyio's TaskInfo comparisons are invalid after their associated native + # task object is GC'd https://github.com/agronholm/anyio/issues/324 + asynclib_name = sniffio.current_async_library() + if asynclib_name == "trio": + return trio.lowlevel.current_task() - return app + if asynclib_name == "asyncio": + task = asyncio_current_task() + if task is None: + raise RuntimeError("must be called from a running task") # pragma: no cover + return task + raise RuntimeError(f"unsupported asynclib={asynclib_name}") # pragma: no cover startup_error_app = Starlette() @@ -41,14 +55,93 @@ def test_use_testclient_in_endpoint(test_client_factory): This is useful if we need to mock out other services, during tests or in development. """ - client = test_client_factory(create_app(test_client_factory)) + + app = Starlette() + + @app.route("/") + def homepage(request): + client = test_client_factory(mock_service) + response = client.get("/") + return JSONResponse(response.json()) + + client = test_client_factory(app) response = client.get("/") assert response.json() == {"mock": "example"} -def test_use_testclient_as_contextmanager(test_client_factory): - with test_client_factory(create_app(test_client_factory)): - pass +def test_use_testclient_as_contextmanager(test_client_factory, anyio_backend_name): + """ + This test asserts a number of properties that are important for an + app level task_group + """ + counter = itertools.count() + identity_runvar = anyio.lowlevel.RunVar[int]("identity_runvar") + + def get_identity(): + try: + return identity_runvar.get() + except LookupError: + token = next(counter) + identity_runvar.set(token) + return token + + startup_task = object() + startup_loop = None + shutdown_task = object() + shutdown_loop = None + + async def lifespan_context(app): + nonlocal startup_task, startup_loop, shutdown_task, shutdown_loop + + startup_task = current_task() + startup_loop = get_identity() + async with anyio.create_task_group() as app.task_group: + yield + shutdown_task = current_task() + shutdown_loop = get_identity() + + app = Starlette(lifespan=lifespan_context) + + @app.route("/loop_id") + async def loop_id(request): + return JSONResponse(get_identity()) + + client = test_client_factory(app) + + with client: + # within a TestClient context every async request runs in the same thread + assert client.get("/loop_id").json() == 0 + assert client.get("/loop_id").json() == 0 + + # that thread is also the same as the lifespan thread + assert startup_loop == 0 + assert shutdown_loop == 0 + + # lifespan events run in the same task, this is important because a task + # group must be entered and exited in the same task. + assert startup_task is shutdown_task + + # outside the TestClient context, new requests continue to spawn in new + # eventloops in new threads + assert client.get("/loop_id").json() == 1 + assert client.get("/loop_id").json() == 2 + + first_task = startup_task + + with client: + # the TestClient context can be re-used, starting a new lifespan task + # in a new thread + assert client.get("/loop_id").json() == 3 + assert client.get("/loop_id").json() == 3 + + assert startup_loop == 3 + assert shutdown_loop == 3 + + # lifespan events still run in the same task, with the context but... + assert startup_task is shutdown_task + + # ... the second TestClient context creates a new lifespan task. + assert first_task is not startup_task def test_error_on_startup(test_client_factory):