Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ensure TestClient requests run in the same EventLoop as lifespan #1213

87 changes: 54 additions & 33 deletions starlette/testclient.py
Expand Up @@ -12,17 +12,23 @@
from concurrent.futures import Future
from urllib.parse import unquote, urljoin, urlsplit

import anyio
import anyio.abc
import requests
from anyio.streams.stapled import StapledObjectStream

from starlette.types import Message, Receive, Scope, Send
from starlette.websockets import WebSocketDisconnect

if sys.version_info >= (3, 8): # pragma: no cover
from typing import TypedDict
from typing import Protocol, TypedDict
graingert marked this conversation as resolved.
Show resolved Hide resolved
else: # pragma: no cover
from typing_extensions import TypedDict
from typing_extensions import Protocol, TypedDict
graingert marked this conversation as resolved.
Show resolved Hide resolved


class _PortalFactoryType(Protocol):
def __call__(self) -> typing.ContextManager[anyio.abc.BlockingPortal]:
graingert marked this conversation as resolved.
Show resolved Hide resolved
... # pragma: no cover
graingert marked this conversation as resolved.
Show resolved Hide resolved


# Annotations for `Session.request()`
Cookies = typing.Union[
Expand Down Expand Up @@ -106,14 +112,14 @@ class _ASGIAdapter(requests.adapters.HTTPAdapter):
def __init__(
self,
app: ASGI3App,
async_backend: _AsyncBackend,
portal_factory: _PortalFactoryType,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So now _AsyncBackend is only used in one place 😅

raise_server_exceptions: bool = True,
root_path: str = "",
) -> None:
self.app = app
self.raise_server_exceptions = raise_server_exceptions
self.root_path = root_path
self.async_backend = async_backend
self.portal_factory = portal_factory

def send(
self, request: requests.PreparedRequest, *args: typing.Any, **kwargs: typing.Any
Expand Down Expand Up @@ -162,7 +168,7 @@ def send(
"server": [host, port],
"subprotocols": subprotocols,
}
session = WebSocketTestSession(self.app, scope, self.async_backend)
session = WebSocketTestSession(self.app, scope, self.portal_factory)
raise _Upgrade(session)

scope = {
Expand Down Expand Up @@ -252,7 +258,7 @@ async def send(message: Message) -> None:
context = message["context"]

try:
with anyio.start_blocking_portal(**self.async_backend) as portal:
with self.portal_factory() as portal:
response_complete = portal.call(anyio.Event)
portal.call(self.app, scope, receive, send)
except BaseException as exc:
Expand Down Expand Up @@ -285,20 +291,18 @@ def __init__(
self,
app: ASGI3App,
scope: Scope,
async_backend: _AsyncBackend,
portal_factory: _PortalFactoryType,
) -> None:
self.app = app
self.scope = scope
self.accepted_subprotocol = None
self.async_backend = async_backend
self.portal_factory = portal_factory
self._receive_queue: "queue.Queue[typing.Any]" = queue.Queue()
self._send_queue: "queue.Queue[typing.Any]" = queue.Queue()

def __enter__(self) -> "WebSocketTestSession":
self.exit_stack = contextlib.ExitStack()
self.portal = self.exit_stack.enter_context(
anyio.start_blocking_portal(**self.async_backend)
)
self.portal = self.exit_stack.enter_context(self.portal_factory())

try:
_: "Future[None]" = self.portal.start_task_soon(self._run)
Expand Down Expand Up @@ -396,6 +400,7 @@ def receive_json(self, mode: str = "text") -> typing.Any:
class TestClient(requests.Session):
__test__ = False # For pytest to not discover this up.
task: "Future[None]"
portal: typing.Optional[anyio.abc.BlockingPortal] = None

def __init__(
self,
Expand All @@ -418,7 +423,7 @@ def __init__(
asgi_app = _WrapASGI2(app) #  type: ignore
adapter = _ASGIAdapter(
asgi_app,
self.async_backend,
portal_factory=self._portal_factory,
raise_server_exceptions=raise_server_exceptions,
root_path=root_path,
)
Expand All @@ -430,6 +435,16 @@ def __init__(
self.app = asgi_app
self.base_url = base_url

@contextlib.contextmanager
def _portal_factory(
self,
) -> typing.Generator[anyio.abc.BlockingPortal, None, None]:
uSpike marked this conversation as resolved.
Show resolved Hide resolved
if self.portal is not None:
yield self.portal
else:
with anyio.start_blocking_portal(**self.async_backend) as portal:
yield portal

def request( # type: ignore
self,
method: str,
Expand Down Expand Up @@ -490,29 +505,35 @@ def websocket_connect(
return session

def __enter__(self) -> "TestClient":
self.exit_stack = contextlib.ExitStack()
self.portal = self.exit_stack.enter_context(
anyio.start_blocking_portal(**self.async_backend)
)
self.stream_send = StapledObjectStream(
*anyio.create_memory_object_stream(math.inf)
)
self.stream_receive = StapledObjectStream(
*anyio.create_memory_object_stream(math.inf)
)
try:
self.task = self.portal.start_task_soon(self.lifespan)
self.portal.call(self.wait_startup)
except Exception:
self.exit_stack.close()
raise
with contextlib.ExitStack() as stack:
self.portal = portal = stack.enter_context(
anyio.start_blocking_portal(**self.async_backend)
)
uSpike marked this conversation as resolved.
Show resolved Hide resolved
graingert marked this conversation as resolved.
Show resolved Hide resolved

self.stream_send = StapledObjectStream(
*anyio.create_memory_object_stream(math.inf)
)
self.stream_receive = StapledObjectStream(
*anyio.create_memory_object_stream(math.inf)
)
self.task = portal.start_task_soon(self.lifespan)
portal.call(self.wait_startup)
self.portal = portal
graingert marked this conversation as resolved.
Show resolved Hide resolved
graingert marked this conversation as resolved.
Show resolved Hide resolved

@stack.callback
def reset_portal() -> None:
self.portal = None

@stack.callback
def wait_shutdown() -> None:
portal.call(self.wait_shutdown)

self.exit_stack = stack.pop_all()

return self

def __exit__(self, *args: typing.Any) -> None:
try:
self.portal.call(self.wait_shutdown)
finally:
self.exit_stack.close()
self.exit_stack.close()

async def lifespan(self) -> None:
scope = {"type": "lifespan"}
Expand Down
99 changes: 95 additions & 4 deletions tests/test_testclient.py
@@ -1,11 +1,22 @@
import asyncio
import itertools
import sys

import anyio
import pytest
import sniffio
import trio.lowlevel

from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.responses import JSONResponse
from starlette.websockets import WebSocket, WebSocketDisconnect

if sys.version_info >= (3, 7):
from asyncio import current_task as asyncio_current_task # pragma: no cover
else:
asyncio_current_task = asyncio.Task.current_task # pragma: no cover
JayH5 marked this conversation as resolved.
Show resolved Hide resolved

mock_service = Starlette()


Expand All @@ -14,15 +25,56 @@ def mock_service_endpoint(request):
return JSONResponse({"mock": "example"})


def create_app(test_client_factory):
_identity_runvar: anyio.lowlevel.RunVar[int] = anyio.lowlevel.RunVar("_identity_runvar")


def get_identity(counter):
try:
return _identity_runvar.get()
except LookupError:
token = next(counter)
_identity_runvar.set(token)
return token


def current_task():
# anyio's TaskInfo comparisons are invalid after their associated native
# task object is GC'd https://github.com/agronholm/anyio/issues/324
asynclib_name = sniffio.current_async_library()
if asynclib_name == "trio":
return trio.lowlevel.current_task()

if asynclib_name == "asyncio":
task = asyncio_current_task()
if task is None:
raise RuntimeError("must be called from a running task") # pragma: no cover
return task
raise RuntimeError(f"unsupported asynclib={asynclib_name}") # pragma: no cover


def create_app(test_client_factory, counter=itertools.count()):
graingert marked this conversation as resolved.
Show resolved Hide resolved
app = Starlette()

@app.on_event("startup")
async def get_startup_thread():
graingert marked this conversation as resolved.
Show resolved Hide resolved
app.startup_task = current_task()
app.startup_loop = get_identity(counter)
graingert marked this conversation as resolved.
Show resolved Hide resolved

@app.on_event("shutdown")
async def get_shutdown_thread():
graingert marked this conversation as resolved.
Show resolved Hide resolved
app.shutdown_task = current_task()
app.shutdown_loop = get_identity(counter)

@app.route("/")
def homepage(request):
client = test_client_factory(mock_service)
response = client.get("/")
return JSONResponse(response.json())

@app.route("/loop_id")
async def loop_id(request):
return JSONResponse(get_identity(counter))

return app


Expand All @@ -46,9 +98,48 @@ def test_use_testclient_in_endpoint(test_client_factory):
assert response.json() == {"mock": "example"}


def test_use_testclient_as_contextmanager(test_client_factory):
with test_client_factory(create_app(test_client_factory)):
pass
def test_use_testclient_as_contextmanager(test_client_factory, anyio_backend_name):
"""
This test asserts a number of properties that are important for an
app level task_group
"""
app = create_app(test_client_factory, counter=itertools.count())
client = test_client_factory(app)

with client:
# within a TestClient context every async request runs in the same thread
assert client.get("/loop_id").json() == 0
assert client.get("/loop_id").json() == 0

# that thread is also the same as the lifespan thread
assert app.startup_loop == 0
assert app.shutdown_loop == 0

# lifespan events run in the same task, this is important because a task
# group must be entered and exited in the same task.
assert app.startup_task is app.shutdown_task

# outside the TestClient context, new requests continue to spawn in new
# eventloops in new threads
assert client.get("/loop_id").json() == 1
assert client.get("/loop_id").json() == 2

first_task = app.startup_task

with client:
# the TestClient context can be re-used, starting a new lifespan task
# in a new thread
assert client.get("/loop_id").json() == 3
assert client.get("/loop_id").json() == 3

assert app.startup_loop == 3
assert app.shutdown_loop == 3

# lifespan events still run in the same task, with the context but...
assert app.startup_task is app.shutdown_task

# ... the second TestClient context creates a new lifespan task.
assert first_task is not app.startup_task


def test_error_on_startup(test_client_factory):
Expand Down