Skip to content

Commit

Permalink
ensure TestClient requests run in the same EventLoop as lifespan
Browse files Browse the repository at this point in the history
  • Loading branch information
graingert committed Jun 23, 2021
1 parent 3498c9e commit 061f669
Show file tree
Hide file tree
Showing 2 changed files with 127 additions and 37 deletions.
87 changes: 54 additions & 33 deletions starlette/testclient.py
Expand Up @@ -12,17 +12,23 @@
from concurrent.futures import Future
from urllib.parse import unquote, urljoin, urlsplit

import anyio
import anyio.abc
import requests
from anyio.streams.stapled import StapledObjectStream

from starlette.types import Message, Receive, Scope, Send
from starlette.websockets import WebSocketDisconnect

if sys.version_info >= (3, 8): # pragma: no cover
from typing import TypedDict # pragma: no cover
from typing import Protocol, TypedDict # pragma: no cover
else: # pragma: no cover
from typing_extensions import TypedDict # pragma: no cover
from typing_extensions import Protocol, TypedDict # pragma: no cover


class _PortalFactoryType(Protocol):
def __call__(self) -> typing.ContextManager[anyio.abc.BlockingPortal]:
... # pragma: no cover


# Annotations for `Session.request()`
Cookies = typing.Union[
Expand Down Expand Up @@ -106,14 +112,14 @@ class _ASGIAdapter(requests.adapters.HTTPAdapter):
def __init__(
self,
app: ASGI3App,
async_backend: _AsyncBackend,
portal_factory: _PortalFactoryType,
raise_server_exceptions: bool = True,
root_path: str = "",
) -> None:
self.app = app
self.raise_server_exceptions = raise_server_exceptions
self.root_path = root_path
self.async_backend = async_backend
self.portal_factory = portal_factory

def send(
self, request: requests.PreparedRequest, *args: typing.Any, **kwargs: typing.Any
Expand Down Expand Up @@ -162,7 +168,7 @@ def send(
"server": [host, port],
"subprotocols": subprotocols,
}
session = WebSocketTestSession(self.app, scope, self.async_backend)
session = WebSocketTestSession(self.app, scope, self.portal_factory)
raise _Upgrade(session)

scope = {
Expand Down Expand Up @@ -252,7 +258,7 @@ async def send(message: Message) -> None:
context = message["context"]

try:
with anyio.start_blocking_portal(**self.async_backend) as portal:
with self.portal_factory() as portal:
response_complete = portal.call(anyio.Event)
portal.call(self.app, scope, receive, send)
except BaseException as exc:
Expand Down Expand Up @@ -285,20 +291,18 @@ def __init__(
self,
app: ASGI3App,
scope: Scope,
async_backend: _AsyncBackend,
portal_factory: _PortalFactoryType,
) -> None:
self.app = app
self.scope = scope
self.accepted_subprotocol = None
self.async_backend = async_backend
self.portal_factory = portal_factory
self._receive_queue: "queue.Queue[typing.Any]" = queue.Queue()
self._send_queue: "queue.Queue[typing.Any]" = queue.Queue()

def __enter__(self) -> "WebSocketTestSession":
self.exit_stack = contextlib.ExitStack()
self.portal = self.exit_stack.enter_context(
anyio.start_blocking_portal(**self.async_backend)
)
self.portal = self.exit_stack.enter_context(self.portal_factory())

try:
_: "Future[None]" = self.portal.start_task_soon(self._run)
Expand Down Expand Up @@ -396,6 +400,7 @@ def receive_json(self, mode: str = "text") -> typing.Any:
class TestClient(requests.Session):
__test__ = False # For pytest to not discover this up.
task: "Future[None]"
portal: typing.Optional[anyio.abc.BlockingPortal] = None

def __init__(
self,
Expand All @@ -418,7 +423,7 @@ def __init__(
asgi_app = _WrapASGI2(app) #  type: ignore
adapter = _ASGIAdapter(
asgi_app,
self.async_backend,
portal_factory=self._portal_factory,
raise_server_exceptions=raise_server_exceptions,
root_path=root_path,
)
Expand All @@ -430,6 +435,16 @@ def __init__(
self.app = asgi_app
self.base_url = base_url

@contextlib.contextmanager
def _portal_factory(
self,
) -> typing.Generator[anyio.abc.BlockingPortal, None, None]:
if self.portal is not None:
yield self.portal
else:
with anyio.start_blocking_portal(**self.async_backend) as portal:
yield portal

def request( # type: ignore
self,
method: str,
Expand Down Expand Up @@ -490,29 +505,35 @@ def websocket_connect(
return session

def __enter__(self) -> "TestClient":
self.exit_stack = contextlib.ExitStack()
self.portal = self.exit_stack.enter_context(
anyio.start_blocking_portal(**self.async_backend)
)
self.stream_send = StapledObjectStream(
*anyio.create_memory_object_stream(math.inf)
)
self.stream_receive = StapledObjectStream(
*anyio.create_memory_object_stream(math.inf)
)
try:
self.task = self.portal.start_task_soon(self.lifespan)
self.portal.call(self.wait_startup)
except Exception:
self.exit_stack.close()
raise
with contextlib.ExitStack() as stack:
self.portal = portal = stack.enter_context(
anyio.start_blocking_portal(**self.async_backend)
)

self.stream_send = StapledObjectStream(
*anyio.create_memory_object_stream(math.inf)
)
self.stream_receive = StapledObjectStream(
*anyio.create_memory_object_stream(math.inf)
)
self.task = portal.start_task_soon(self.lifespan)
portal.call(self.wait_startup)
self.portal = portal

@stack.callback
def reset_portal() -> None:
self.portal = None

@stack.callback
def wait_shutdown() -> None:
portal.call(self.wait_shutdown)

self.exit_stack = stack.pop_all()

return self

def __exit__(self, *args: typing.Any) -> None:
try:
self.portal.call(self.wait_shutdown)
finally:
self.exit_stack.close()
self.exit_stack.close()

async def lifespan(self) -> None:
scope = {"type": "lifespan"}
Expand Down
77 changes: 73 additions & 4 deletions tests/test_testclient.py
@@ -1,3 +1,5 @@
import itertools

import anyio
import pytest

Expand All @@ -14,15 +16,41 @@ def mock_service_endpoint(request):
return JSONResponse({"mock": "example"})


def create_app(test_client_factory):
_identity_runvar: anyio.lowlevel.RunVar[int] = anyio.lowlevel.RunVar("_identity_runvar")


def get_identity(counter):
try:
return _identity_runvar.get()
except LookupError:
token = next(counter)
_identity_runvar.set(token)
return token


def create_app(test_client_factory, counter=itertools.count()):
app = Starlette()

@app.on_event("startup")
async def get_startup_thread():
app.startup_task = anyio.get_current_task().id
app.startup_loop = get_identity(counter)

@app.on_event("shutdown")
async def get_shutdown_thread():
app.shutdown_task = anyio.get_current_task().id
app.shutdown_loop = get_identity(counter)

@app.route("/")
def homepage(request):
client = test_client_factory(mock_service)
response = client.get("/")
return JSONResponse(response.json())

@app.route("/thread")
async def thread(request):
return JSONResponse(get_identity(counter))

return app


Expand All @@ -46,9 +74,50 @@ def test_use_testclient_in_endpoint(test_client_factory):
assert response.json() == {"mock": "example"}


def test_use_testclient_as_contextmanager(test_client_factory):
with test_client_factory(create_app(test_client_factory)):
pass
def test_use_testclient_as_contextmanager(test_client_factory, anyio_backend_name):
"""
This test asserts a number of properties that are important for an
app level task_group
"""
app = create_app(test_client_factory, counter=itertools.count())
client = test_client_factory(app)

with client:
# within a TestClient context every async request runs in the same thread
assert client.get("/thread").json() == 0
assert client.get("/thread").json() == 0

# that thread is also the same as the lifespan thread
assert app.startup_loop == 0
assert app.shutdown_loop == 0

# lifespan events run in the same task, this is important because a task
# group must be entered and exited in the same task.
assert app.startup_task == app.shutdown_task

# outside the TestClient context, new requests continue to spawn in new
# eventloops in new threads
assert client.get("/thread").json() == 1
assert client.get("/thread").json() == 2

first_task = app.startup_task

with client:
# the TestClient context can be re-used, starting a new lifespan task
# in a new thread
assert client.get("/thread").json() == 3
assert client.get("/thread").json() == 3

assert app.startup_loop == 3
assert app.shutdown_loop == 3

# lifespan events still run in the same task, with the context but...
assert app.startup_task == app.shutdown_task

if anyio_backend_name != "asyncio":
# https://github.com/agronholm/anyio/issues/324
# ... the second TestClient context creates a new lifespan task.
assert first_task != app.startup_task


def test_error_on_startup(test_client_factory):
Expand Down

0 comments on commit 061f669

Please sign in to comment.