diff --git a/README.md b/README.md index 184eb480a..8eedea952 100644 --- a/README.md +++ b/README.md @@ -22,7 +22,7 @@ # Starlette Starlette is a lightweight [ASGI](https://asgi.readthedocs.io/en/latest/) framework/toolkit, -which is ideal for building high performance asyncio services. +which is ideal for building high performance async services. It is production-ready, and gives you the following: @@ -36,7 +36,8 @@ It is production-ready, and gives you the following: * Session and Cookie support. * 100% test coverage. * 100% type annotated codebase. -* Zero hard dependencies. +* Few hard dependencies. +* Compatible with `asyncio` and `trio` backends. ## Requirements @@ -84,10 +85,9 @@ For a more complete example, see [encode/starlette-example](https://github.com/e ## Dependencies -Starlette does not have any hard dependencies, but the following are optional: +Starlette only requires `anyio`, and the following are optional: * [`requests`][requests] - Required if you want to use the `TestClient`. -* [`aiofiles`][aiofiles] - Required if you want to use `FileResponse` or `StaticFiles`. * [`jinja2`][jinja2] - Required if you want to use `Jinja2Templates`. * [`python-multipart`][python-multipart] - Required if you want to support form parsing, with `request.form()`. * [`itsdangerous`][itsdangerous] - Required for `SessionMiddleware` support. @@ -167,7 +167,6 @@ gunicorn -k uvicorn.workers.UvicornH11Worker ...

Starlette is BSD licensed code. Designed & built in Brighton, England.

[requests]: http://docs.python-requests.org/en/master/ -[aiofiles]: https://github.com/Tinche/aiofiles [jinja2]: http://jinja.pocoo.org/ [python-multipart]: https://andrew-d.github.io/python-multipart/ [graphene]: https://graphene-python.org/ diff --git a/docs/index.md b/docs/index.md index 4ae77f0e6..b9692a1fb 100644 --- a/docs/index.md +++ b/docs/index.md @@ -32,7 +32,7 @@ It is production-ready, and gives you the following: * Session and Cookie support. * 100% test coverage. * 100% type annotated codebase. -* Zero hard dependencies. +* Few hard dependencies. ## Requirements @@ -79,10 +79,9 @@ For a more complete example, [see here](https://github.com/encode/starlette-exam ## Dependencies -Starlette does not have any hard dependencies, but the following are optional: +Starlette only requires `anyio`, and the following dependencies are optional: * [`requests`][requests] - Required if you want to use the `TestClient`. -* [`aiofiles`][aiofiles] - Required if you want to use `FileResponse` or `StaticFiles`. * [`jinja2`][jinja2] - Required if you want to use `Jinja2Templates`. * [`python-multipart`][python-multipart] - Required if you want to support form parsing, with `request.form()`. * [`itsdangerous`][itsdangerous] - Required for `SessionMiddleware` support. @@ -161,7 +160,6 @@ gunicorn -k uvicorn.workers.UvicornH11Worker ...

Starlette is BSD licensed code. Designed & built in Brighton, England.

[requests]: http://docs.python-requests.org/en/master/ -[aiofiles]: https://github.com/Tinche/aiofiles [jinja2]: http://jinja.pocoo.org/ [python-multipart]: https://andrew-d.github.io/python-multipart/ [graphene]: https://graphene-python.org/ diff --git a/docs/testclient.md b/docs/testclient.md index 61f7201c6..f37858401 100644 --- a/docs/testclient.md +++ b/docs/testclient.md @@ -31,6 +31,22 @@ application. Occasionally you might want to test the content of 500 error responses, rather than allowing client to raise the server exception. In this case you should use `client = TestClient(app, raise_server_exceptions=False)`. +### Selecting the Async backend + +`TestClient.async_backend` is a dictionary which allows you to set the options +for the backend used to run tests. These options are passed to +`anyio.start_blocking_portal()`. See the [anyio documentation](https://anyio.readthedocs.io/en/stable/basics.html#backend-options) +for more information about backend options. By default, `asyncio` is used. + +To run `Trio`, set `async_backend["backend"] = "trio"`, for example: + +```python +def test_app() + client = TestClient(app) + client.async_backend["backend"] = "trio" + ... +``` + ### Testing WebSocket sessions You can also test websocket sessions with the test client. @@ -72,6 +88,8 @@ always raised by the test client. May raise `starlette.websockets.WebSocketDisconnect` if the application does not accept the websocket connection. +`websocket_connect()` must be used as a context manager (in a `with` block). + #### Sending data * `.send_text(data)` - Send the given text to the application. diff --git a/requirements.txt b/requirements.txt index 6ec5bf09e..ae3d91f26 100644 --- a/requirements.txt +++ b/requirements.txt @@ -18,9 +18,10 @@ types-requests types-contextvars types-aiofiles types-PyYAML +types-dataclasses pytest pytest-cov -pytest-asyncio +trio # Documentation mkdocs diff --git a/setup.py b/setup.py index c48356370..a687ad861 100644 --- a/setup.py +++ b/setup.py @@ -37,9 +37,9 @@ def get_long_description(): packages=find_packages(exclude=["tests*"]), package_data={"starlette": ["py.typed"]}, include_package_data=True, + install_requires=["anyio>=3.0.0,<4"], extras_require={ "full": [ - "aiofiles", "graphene", "itsdangerous", "jinja2", diff --git a/starlette/concurrency.py b/starlette/concurrency.py index c8c5d57ac..e89d1e047 100644 --- a/starlette/concurrency.py +++ b/starlette/concurrency.py @@ -1,33 +1,32 @@ -import asyncio import functools -import sys import typing from typing import Any, AsyncGenerator, Iterator +import anyio + try: import contextvars # Python 3.7+ only or via contextvars backport. except ImportError: # pragma: no cover contextvars = None # type: ignore -if sys.version_info >= (3, 7): # pragma: no cover - from asyncio import create_task -else: # pragma: no cover - from asyncio import ensure_future as create_task T = typing.TypeVar("T") async def run_until_first_complete(*args: typing.Tuple[typing.Callable, dict]) -> None: - tasks = [create_task(handler(**kwargs)) for handler, kwargs in args] - (done, pending) = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED) - [task.cancel() for task in pending] - [task.result() for task in done] + async with anyio.create_task_group() as task_group: + + async def run(func: typing.Callable[[], typing.Coroutine]) -> None: + await func() + task_group.cancel_scope.cancel() + + for func, kwargs in args: + task_group.start_soon(run, functools.partial(func, **kwargs)) async def run_in_threadpool( func: typing.Callable[..., T], *args: typing.Any, **kwargs: typing.Any ) -> T: - loop = asyncio.get_event_loop() if contextvars is not None: # pragma: no cover # Ensure we run in the same context child = functools.partial(func, *args, **kwargs) @@ -35,9 +34,9 @@ async def run_in_threadpool( func = context.run args = (child,) elif kwargs: # pragma: no cover - # loop.run_in_executor doesn't accept 'kwargs', so bind them in here + # run_sync doesn't accept 'kwargs', so bind them in here func = functools.partial(func, **kwargs) - return await loop.run_in_executor(None, func, *args) + return await anyio.to_thread.run_sync(func, *args) class _StopIteration(Exception): @@ -57,6 +56,6 @@ def _next(iterator: Iterator) -> Any: async def iterate_in_threadpool(iterator: Iterator) -> AsyncGenerator: while True: try: - yield await run_in_threadpool(_next, iterator) + yield await anyio.to_thread.run_sync(_next, iterator) except _StopIteration: break diff --git a/starlette/graphql.py b/starlette/graphql.py index ed2274f89..6e5d6ec6a 100644 --- a/starlette/graphql.py +++ b/starlette/graphql.py @@ -31,29 +31,18 @@ class GraphQLApp: def __init__( self, schema: "graphene.Schema", - executor: typing.Any = None, executor_class: type = None, graphiql: bool = True, ) -> None: self.schema = schema self.graphiql = graphiql - if executor is None: - # New style in 0.10.0. Use 'executor_class'. - # See issue https://github.com/encode/starlette/issues/242 - self.executor = executor - self.executor_class = executor_class - self.is_async = executor_class is not None and issubclass( - executor_class, AsyncioExecutor - ) - else: - # Old style. Use 'executor'. - # We should remove this in the next median/major version bump. - self.executor = executor - self.executor_class = None - self.is_async = isinstance(executor, AsyncioExecutor) + self.executor_class = executor_class + self.is_async = executor_class is not None and issubclass( + executor_class, AsyncioExecutor + ) async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: - if self.executor is None and self.executor_class is not None: + if self.executor_class is not None: self.executor = self.executor_class() request = Request(scope, receive=receive) diff --git a/starlette/middleware/base.py b/starlette/middleware/base.py index b347a6a2d..77ba66925 100644 --- a/starlette/middleware/base.py +++ b/starlette/middleware/base.py @@ -1,9 +1,10 @@ -import asyncio import typing +import anyio + from starlette.requests import Request from starlette.responses import Response, StreamingResponse -from starlette.types import ASGIApp, Message, Receive, Scope, Send +from starlette.types import ASGIApp, Receive, Scope, Send RequestResponseEndpoint = typing.Callable[[Request], typing.Awaitable[Response]] DispatchFunction = typing.Callable[ @@ -21,45 +22,39 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: await self.app(scope, receive, send) return - request = Request(scope, receive=receive) - response = await self.dispatch_func(request, self.call_next) - await response(scope, receive, send) + async def call_next(request: Request) -> Response: + send_stream, recv_stream = anyio.create_memory_object_stream() - async def call_next(self, request: Request) -> Response: - loop = asyncio.get_event_loop() - queue: "asyncio.Queue[typing.Optional[Message]]" = asyncio.Queue() + async def coro() -> None: + async with send_stream: + await self.app(scope, request.receive, send_stream.send) - scope = request.scope - receive = request.receive - send = queue.put + task_group.start_soon(coro) - async def coro() -> None: try: - await self.app(scope, receive, send) - finally: - await queue.put(None) - - task = loop.create_task(coro()) - message = await queue.get() - if message is None: - task.result() - raise RuntimeError("No response returned.") - assert message["type"] == "http.response.start" - - async def body_stream() -> typing.AsyncGenerator[bytes, None]: - while True: - message = await queue.get() - if message is None: - break - assert message["type"] == "http.response.body" - yield message.get("body", b"") - task.result() - - response = StreamingResponse( - status_code=message["status"], content=body_stream() - ) - response.raw_headers = message["headers"] - return response + message = await recv_stream.receive() + except anyio.EndOfStream: + raise RuntimeError("No response returned.") + + assert message["type"] == "http.response.start" + + async def body_stream() -> typing.AsyncGenerator[bytes, None]: + async with recv_stream: + async for message in recv_stream: + assert message["type"] == "http.response.body" + yield message.get("body", b"") + + response = StreamingResponse( + status_code=message["status"], content=body_stream() + ) + response.raw_headers = message["headers"] + return response + + async with anyio.create_task_group() as task_group: + request = Request(scope, receive=receive) + response = await self.dispatch_func(request, call_next) + await response(scope, receive, send) + task_group.cancel_scope.cancel() async def dispatch( self, request: Request, call_next: RequestResponseEndpoint diff --git a/starlette/middleware/wsgi.py b/starlette/middleware/wsgi.py index 515cf3e76..7e69e1a6b 100644 --- a/starlette/middleware/wsgi.py +++ b/starlette/middleware/wsgi.py @@ -1,10 +1,11 @@ -import asyncio import io +import math import sys import typing -from starlette.concurrency import run_in_threadpool -from starlette.types import Message, Receive, Scope, Send +import anyio + +from starlette.types import Receive, Scope, Send def build_environ(scope: Scope, body: bytes) -> dict: @@ -69,9 +70,9 @@ def __init__(self, app: typing.Callable, scope: Scope) -> None: self.scope = scope self.status = None self.response_headers = None - self.send_event = asyncio.Event() - self.send_queue: typing.List[typing.Optional[Message]] = [] - self.loop = asyncio.get_event_loop() + self.stream_send, self.stream_receive = anyio.create_memory_object_stream( + math.inf + ) self.response_started = False self.exc_info: typing.Any = None @@ -83,31 +84,18 @@ async def __call__(self, receive: Receive, send: Send) -> None: body += message.get("body", b"") more_body = message.get("more_body", False) environ = build_environ(self.scope, body) - sender = None - try: - sender = self.loop.create_task(self.sender(send)) - await run_in_threadpool(self.wsgi, environ, self.start_response) - self.send_queue.append(None) - self.send_event.set() - await asyncio.wait_for(sender, None) - if self.exc_info is not None: - raise self.exc_info[0].with_traceback( - self.exc_info[1], self.exc_info[2] - ) - finally: - if sender and not sender.done(): - sender.cancel() # pragma: no cover + + async with anyio.create_task_group() as task_group: + task_group.start_soon(self.sender, send) + async with self.stream_send: + await anyio.to_thread.run_sync(self.wsgi, environ, self.start_response) + if self.exc_info is not None: + raise self.exc_info[0].with_traceback(self.exc_info[1], self.exc_info[2]) async def sender(self, send: Send) -> None: - while True: - if self.send_queue: - message = self.send_queue.pop(0) - if message is None: - return + async with self.stream_receive: + async for message in self.stream_receive: await send(message) - else: - await self.send_event.wait() - self.send_event.clear() def start_response( self, @@ -124,21 +112,22 @@ def start_response( (name.strip().encode("ascii").lower(), value.strip().encode("ascii")) for name, value in response_headers ] - self.send_queue.append( + anyio.from_thread.run( + self.stream_send.send, { "type": "http.response.start", "status": status_code, "headers": headers, - } + }, ) - self.loop.call_soon_threadsafe(self.send_event.set) def wsgi(self, environ: dict, start_response: typing.Callable) -> None: for chunk in self.app(environ, start_response): - self.send_queue.append( - {"type": "http.response.body", "body": chunk, "more_body": True} + anyio.from_thread.run( + self.stream_send.send, + {"type": "http.response.body", "body": chunk, "more_body": True}, ) - self.loop.call_soon_threadsafe(self.send_event.set) - self.send_queue.append({"type": "http.response.body", "body": b""}) - self.loop.call_soon_threadsafe(self.send_event.set) + anyio.from_thread.run( + self.stream_send.send, {"type": "http.response.body", "body": b""} + ) diff --git a/starlette/requests.py b/starlette/requests.py index ab6f51424..54ed8611e 100644 --- a/starlette/requests.py +++ b/starlette/requests.py @@ -1,9 +1,10 @@ -import asyncio import json import typing from collections.abc import Mapping from http import cookies as http_cookies +import anyio + from starlette.datastructures import URL, Address, FormData, Headers, QueryParams, State from starlette.formparsers import FormParser, MultiPartParser from starlette.types import Message, Receive, Scope, Send @@ -251,10 +252,12 @@ async def close(self) -> None: async def is_disconnected(self) -> bool: if not self._is_disconnected: - try: - message = await asyncio.wait_for(self._receive(), timeout=0.0000001) - except asyncio.TimeoutError: - message = {} + message: Message = {} + + # If message isn't immediately available, move on + with anyio.CancelScope() as cs: + cs.cancel() + message = await self._receive() if message.get("type") == "http.disconnect": self._is_disconnected = True diff --git a/starlette/responses.py b/starlette/responses.py index 00f6be4db..d03df2329 100644 --- a/starlette/responses.py +++ b/starlette/responses.py @@ -6,24 +6,20 @@ import sys import typing from email.utils import formatdate +from functools import partial from mimetypes import guess_type as mimetypes_guess_type from urllib.parse import quote +import anyio + from starlette.background import BackgroundTask -from starlette.concurrency import iterate_in_threadpool, run_until_first_complete +from starlette.concurrency import iterate_in_threadpool from starlette.datastructures import URL, MutableHeaders from starlette.types import Receive, Scope, Send # Workaround for adding samesite support to pre 3.8 python http.cookies.Morsel._reserved["samesite"] = "SameSite" # type: ignore -try: - import aiofiles - from aiofiles.os import stat as aio_stat -except ImportError: # pragma: nocover - aiofiles = None # type: ignore - aio_stat = None # type: ignore - # Compatibility wrapper for `mimetypes.guess_type` to support `os.PathLike` on None: await send({"type": "http.response.body", "body": b"", "more_body": False}) async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: - await run_until_first_complete( - (self.stream_response, {"send": send}), - (self.listen_for_disconnect, {"receive": receive}), - ) + async with anyio.create_task_group() as task_group: + + async def wrap(func: typing.Callable[[], typing.Coroutine]) -> None: + await func() + task_group.cancel_scope.cancel() + + task_group.start_soon(wrap, partial(self.stream_response, send)) + await wrap(partial(self.listen_for_disconnect, receive)) if self.background is not None: await self.background() @@ -244,7 +244,6 @@ def __init__( stat_result: os.stat_result = None, method: str = None, ) -> None: - assert aiofiles is not None, "'aiofiles' must be installed to use FileResponse" self.path = path self.status_code = status_code self.filename = filename @@ -280,7 +279,7 @@ def set_stat_headers(self, stat_result: os.stat_result) -> None: async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: if self.stat_result is None: try: - stat_result = await aio_stat(self.path) + stat_result = await anyio.to_thread.run_sync(os.stat, self.path) self.set_stat_headers(stat_result) except FileNotFoundError: raise RuntimeError(f"File at path {self.path} does not exist.") @@ -298,10 +297,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: if self.send_header_only: await send({"type": "http.response.body", "body": b"", "more_body": False}) else: - # Tentatively ignoring type checking failure to work around the wrong type - # definitions for aiofile that come with typeshed. See - # https://github.com/python/typeshed/pull/4650 - async with aiofiles.open(self.path, mode="rb") as file: # type: ignore + async with await anyio.open_file(self.path, mode="rb") as file: more_body = True while more_body: chunk = await file.read(self.chunk_size) diff --git a/starlette/staticfiles.py b/starlette/staticfiles.py index 15a67fe35..33ea0b033 100644 --- a/starlette/staticfiles.py +++ b/starlette/staticfiles.py @@ -4,7 +4,7 @@ import typing from email.utils import parsedate -from aiofiles.os import stat as aio_stat +import anyio from starlette.datastructures import URL, Headers from starlette.responses import ( @@ -154,7 +154,7 @@ async def lookup_path( # directory. continue try: - stat_result = await aio_stat(full_path) + stat_result = await anyio.to_thread.run_sync(os.stat, full_path) return full_path, stat_result except FileNotFoundError: pass @@ -187,7 +187,7 @@ async def check_config(self) -> None: return try: - stat_result = await aio_stat(self.directory) + stat_result = await anyio.to_thread.run_sync(os.stat, self.directory) except FileNotFoundError: raise RuntimeError( f"StaticFiles directory '{self.directory}' does not exist." diff --git a/starlette/testclient.py b/starlette/testclient.py index 77c038b17..c1c0fe165 100644 --- a/starlette/testclient.py +++ b/starlette/testclient.py @@ -1,15 +1,19 @@ import asyncio +import contextlib import http import inspect import io import json +import math import queue -import threading import types import typing +from concurrent.futures import Future from urllib.parse import unquote, urljoin, urlsplit +import anyio import requests +from anyio.streams.stapled import StapledObjectStream from starlette.types import Message, Receive, Scope, Send from starlette.websockets import WebSocketDisconnect @@ -89,11 +93,16 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: class _ASGIAdapter(requests.adapters.HTTPAdapter): def __init__( - self, app: ASGI3App, raise_server_exceptions: bool = True, root_path: str = "" + self, + app: ASGI3App, + async_backend: typing.Dict[str, typing.Any], + raise_server_exceptions: bool = True, + root_path: str = "", ) -> None: self.app = app self.raise_server_exceptions = raise_server_exceptions self.root_path = root_path + self.async_backend = async_backend def send( self, request: requests.PreparedRequest, *args: typing.Any, **kwargs: typing.Any @@ -142,7 +151,7 @@ def send( "server": [host, port], "subprotocols": subprotocols, } - session = WebSocketTestSession(self.app, scope) + session = WebSocketTestSession(self.app, scope, self.async_backend) raise _Upgrade(session) scope = { @@ -161,17 +170,17 @@ def send( request_complete = False response_started = False - response_complete = False + response_complete: anyio.Event raw_kwargs: typing.Dict[str, typing.Any] = {"body": io.BytesIO()} template = None context = None async def receive() -> Message: - nonlocal request_complete, response_complete + nonlocal request_complete if request_complete: - while not response_complete: - await asyncio.sleep(0.0001) + if not response_complete.is_set(): + await response_complete.wait() return {"type": "http.disconnect"} body = request.body @@ -195,7 +204,7 @@ async def receive() -> Message: return {"type": "http.request", "body": body_bytes} async def send(message: Message) -> None: - nonlocal raw_kwargs, response_started, response_complete, template, context + nonlocal raw_kwargs, response_started, template, context if message["type"] == "http.response.start": assert ( @@ -217,7 +226,7 @@ async def send(message: Message) -> None: response_started ), 'Received "http.response.body" without "http.response.start".' assert ( - not response_complete + not response_complete.is_set() ), 'Received "http.response.body" after response completed.' body = message.get("body", b"") more_body = message.get("more_body", False) @@ -225,19 +234,15 @@ async def send(message: Message) -> None: raw_kwargs["body"].write(body) if not more_body: raw_kwargs["body"].seek(0) - response_complete = True + response_complete.set() elif message["type"] == "http.response.template": template = message["template"] context = message["context"] try: - loop = asyncio.get_event_loop() - except RuntimeError: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - try: - loop.run_until_complete(self.app(scope, receive, send)) + with anyio.start_blocking_portal(**self.async_backend) as portal: + response_complete = portal.call(anyio.Event) + portal.call(self.app, scope, receive, send) except BaseException as exc: if self.raise_server_exceptions: raise exc @@ -264,48 +269,59 @@ async def send(message: Message) -> None: class WebSocketTestSession: - def __init__(self, app: ASGI3App, scope: Scope) -> None: + def __init__( + self, app: ASGI3App, scope: Scope, async_backend: typing.Dict[str, typing.Any] + ) -> None: self.app = app self.scope = scope self.accepted_subprotocol = None + self.async_backend = async_backend self._receive_queue: "queue.Queue[typing.Any]" = queue.Queue() self._send_queue: "queue.Queue[typing.Any]" = queue.Queue() - self._thread = threading.Thread(target=self._run) - self.send({"type": "websocket.connect"}) - self._thread.start() - message = self.receive() - self._raise_on_close(message) - self.accepted_subprotocol = message.get("subprotocol", None) def __enter__(self) -> "WebSocketTestSession": + self.exit_stack = contextlib.ExitStack() + self.portal = self.exit_stack.enter_context( + anyio.start_blocking_portal(**self.async_backend) + ) + + try: + _: "Future[None]" = self.portal.start_task_soon(self._run) + self.send({"type": "websocket.connect"}) + message = self.receive() + self._raise_on_close(message) + except Exception: + self.exit_stack.close() + raise + self.accepted_subprotocol = message.get("subprotocol", None) return self def __exit__(self, *args: typing.Any) -> None: - self.close(1000) - self._thread.join() + try: + self.close(1000) + finally: + self.exit_stack.close() while not self._send_queue.empty(): message = self._send_queue.get() if isinstance(message, BaseException): raise message - def _run(self) -> None: + async def _run(self) -> None: """ The sub-thread in which the websocket session runs. """ - loop = asyncio.new_event_loop() scope = self.scope receive = self._asgi_receive send = self._asgi_send try: - loop.run_until_complete(self.app(scope, receive, send)) + await self.app(scope, receive, send) except BaseException as exc: self._send_queue.put(exc) - finally: - loop.close() + raise async def _asgi_receive(self) -> Message: while self._receive_queue.empty(): - await asyncio.sleep(0) + await anyio.sleep(0) return self._receive_queue.get() async def _asgi_send(self, message: Message) -> None: @@ -365,6 +381,14 @@ def receive_json(self, mode: str = "text") -> typing.Any: class TestClient(requests.Session): __test__ = False # For pytest to not discover this up. + #: These options are passed to `anyio.start_blocking_portal()` + async_backend: typing.Dict[str, typing.Any] = { + "backend": "asyncio", + "backend_options": {}, + } + + task: "Future[None]" + def __init__( self, app: typing.Union[ASGI2App, ASGI3App], @@ -381,6 +405,7 @@ def __init__( asgi_app = _WrapASGI2(app) #  type: ignore adapter = _ASGIAdapter( asgi_app, + self.async_backend, raise_server_exceptions=raise_server_exceptions, root_path=root_path, ) @@ -452,27 +477,40 @@ def websocket_connect( return session def __enter__(self) -> "TestClient": - loop = asyncio.get_event_loop() - self.send_queue: "asyncio.Queue[typing.Any]" = asyncio.Queue() - self.receive_queue: "asyncio.Queue[typing.Any]" = asyncio.Queue() - self.task = loop.create_task(self.lifespan()) - loop.run_until_complete(self.wait_startup()) + self.exit_stack = contextlib.ExitStack() + self.portal = self.exit_stack.enter_context( + anyio.start_blocking_portal(**self.async_backend) + ) + self.stream_send = StapledObjectStream( + *anyio.create_memory_object_stream(math.inf) + ) + self.stream_receive = StapledObjectStream( + *anyio.create_memory_object_stream(math.inf) + ) + try: + self.task = self.portal.start_task_soon(self.lifespan) + self.portal.call(self.wait_startup) + except Exception: + self.exit_stack.close() + raise return self def __exit__(self, *args: typing.Any) -> None: - loop = asyncio.get_event_loop() - loop.run_until_complete(self.wait_shutdown()) + try: + self.portal.call(self.wait_shutdown) + finally: + self.exit_stack.close() async def lifespan(self) -> None: scope = {"type": "lifespan"} try: - await self.app(scope, self.receive_queue.get, self.send_queue.put) + await self.app(scope, self.stream_receive.receive, self.stream_send.send) finally: - await self.send_queue.put(None) + await self.stream_send.send(None) async def wait_startup(self) -> None: - await self.receive_queue.put({"type": "lifespan.startup"}) - message = await self.send_queue.get() + await self.stream_receive.send({"type": "lifespan.startup"}) + message = await self.stream_send.receive() if message is None: self.task.result() assert message["type"] in ( @@ -480,14 +518,14 @@ async def wait_startup(self) -> None: "lifespan.startup.failed", ) if message["type"] == "lifespan.startup.failed": - message = await self.send_queue.get() + message = await self.stream_send.receive() if message is None: self.task.result() async def wait_shutdown(self) -> None: - await self.receive_queue.put({"type": "lifespan.shutdown"}) - message = await self.send_queue.get() - if message is None: - self.task.result() - assert message["type"] == "lifespan.shutdown.complete" - await self.task + async with self.stream_send: + await self.stream_receive.send({"type": "lifespan.shutdown"}) + message = await self.stream_send.receive() + if message is None: + self.task.result() + assert message["type"] == "lifespan.shutdown.complete" diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 000000000..d1f3ba8e4 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,24 @@ +import pytest + +from starlette.testclient import TestClient + + +@pytest.fixture( + params=[ + pytest.param( + {"backend": "asyncio", "backend_options": {"use_uvloop": False}}, + id="asyncio", + ), + pytest.param({"backend": "trio", "backend_options": {}}, id="trio"), + ], + autouse=True, +) +def anyio_backend(request, monkeypatch): + monkeypatch.setattr(TestClient, "async_backend", request.param) + return request.param["backend"] + + +@pytest.fixture +def no_trio_support(request): + if request.keywords.get("trio"): + pytest.skip("Trio not supported (yet!)") diff --git a/tests/middleware/test_base.py b/tests/middleware/test_base.py index 048dd9ffb..df8901934 100644 --- a/tests/middleware/test_base.py +++ b/tests/middleware/test_base.py @@ -143,3 +143,18 @@ def homepage(request): def test_middleware_repr(): middleware = Middleware(CustomMiddleware) assert repr(middleware) == "Middleware(CustomMiddleware)" + + +def test_fully_evaluated_response(): + # Test for https://github.com/encode/starlette/issues/1022 + class CustomMiddleware(BaseHTTPMiddleware): + async def dispatch(self, request, call_next): + await call_next(request) + return PlainTextResponse("Custom") + + app = Starlette() + app.add_middleware(CustomMiddleware) + + client = TestClient(app) + response = client.get("/does_not_exist") + assert response.text == "Custom" diff --git a/tests/middleware/test_errors.py b/tests/middleware/test_errors.py index c178ef9da..28b2a7ba3 100644 --- a/tests/middleware/test_errors.py +++ b/tests/middleware/test_errors.py @@ -67,4 +67,5 @@ async def app(scope, receive, send): with pytest.raises(RuntimeError): client = TestClient(app) - client.websocket_connect("/") + with client.websocket_connect("/"): + pass # pragma: nocover diff --git a/tests/test_authentication.py b/tests/test_authentication.py index 3373f67c5..8ee87932a 100644 --- a/tests/test_authentication.py +++ b/tests/test_authentication.py @@ -261,10 +261,14 @@ def test_authentication_required(): def test_websocket_authentication_required(): with TestClient(app) as client: with pytest.raises(WebSocketDisconnect): - client.websocket_connect("/ws") + with client.websocket_connect("/ws"): + pass # pragma: nocover with pytest.raises(WebSocketDisconnect): - client.websocket_connect("/ws", headers={"Authorization": "basic foobar"}) + with client.websocket_connect( + "/ws", headers={"Authorization": "basic foobar"} + ): + pass # pragma: nocover with client.websocket_connect( "/ws", auth=("tomchristie", "example") @@ -273,12 +277,14 @@ def test_websocket_authentication_required(): assert data == {"authenticated": True, "user": "tomchristie"} with pytest.raises(WebSocketDisconnect): - client.websocket_connect("/ws/decorated") + with client.websocket_connect("/ws/decorated"): + pass # pragma: nocover with pytest.raises(WebSocketDisconnect): - client.websocket_connect( + with client.websocket_connect( "/ws/decorated", headers={"Authorization": "basic foobar"} - ) + ): + pass # pragma: nocover with client.websocket_connect( "/ws/decorated", auth=("tomchristie", "example") diff --git a/tests/test_concurrency.py b/tests/test_concurrency.py new file mode 100644 index 000000000..cc5eba974 --- /dev/null +++ b/tests/test_concurrency.py @@ -0,0 +1,22 @@ +import anyio +import pytest + +from starlette.concurrency import run_until_first_complete + + +@pytest.mark.anyio +async def test_run_until_first_complete(): + task1_finished = anyio.Event() + task2_finished = anyio.Event() + + async def task1(): + task1_finished.set() + + async def task2(): + await task1_finished.wait() + await anyio.sleep(0) # pragma: nocover + task2_finished.set() # pragma: nocover + + await run_until_first_complete((task1, {}), (task2, {})) + assert task1_finished.is_set() + assert not task2_finished.is_set() diff --git a/tests/test_database.py b/tests/test_database.py index 258a71ec5..f7280c2c7 100644 --- a/tests/test_database.py +++ b/tests/test_database.py @@ -19,6 +19,9 @@ ) +pytestmark = pytest.mark.usefixtures("no_trio_support") + + @pytest.fixture(autouse=True, scope="module") def create_test_database(): engine = sqlalchemy.create_engine(DATABASE_URL) diff --git a/tests/test_datastructures.py b/tests/test_datastructures.py index b0e6baf98..bb71ba870 100644 --- a/tests/test_datastructures.py +++ b/tests/test_datastructures.py @@ -217,7 +217,7 @@ class BigUploadFile(UploadFile): spool_max_size = 1024 -@pytest.mark.asyncio +@pytest.mark.anyio async def test_upload_file(): big_file = BigUploadFile("big-file") await big_file.write(b"big-data" * 512) diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index 841c9a5cf..bab6961b5 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -54,7 +54,8 @@ def test_not_modified(): def test_websockets_should_raise(): with pytest.raises(RuntimeError): - client.websocket_connect("/runtime_error") + with client.websocket_connect("/runtime_error"): + pass # pragma: nocover def test_handled_exc_after_response(): diff --git a/tests/test_graphql.py b/tests/test_graphql.py index 67f307231..b945a5cfe 100644 --- a/tests/test_graphql.py +++ b/tests/test_graphql.py @@ -1,5 +1,4 @@ import graphene -import pytest from graphql.execution.executors.asyncio import AsyncioExecutor from starlette.applications import Starlette @@ -142,27 +141,8 @@ async def resolve_hello(self, info, name): async_app = GraphQLApp(schema=async_schema, executor_class=AsyncioExecutor) -def test_graphql_async(): +def test_graphql_async(no_trio_support): client = TestClient(async_app) response = client.get("/?query={ hello }") assert response.status_code == 200 assert response.json() == {"data": {"hello": "Hello stranger"}} - - -async_schema = graphene.Schema(query=ASyncQuery) - - -@pytest.fixture -def old_style_async_app(event_loop) -> GraphQLApp: - old_style_async_app = GraphQLApp( - schema=async_schema, executor=AsyncioExecutor(loop=event_loop) - ) - return old_style_async_app - - -def test_graphql_async_old_style_executor(old_style_async_app: GraphQLApp): - # See https://github.com/encode/starlette/issues/242 - client = TestClient(old_style_async_app) - response = client.get("/?query={ hello }") - assert response.status_code == 200 - assert response.json() == {"data": {"hello": "Hello stranger"}} diff --git a/tests/test_requests.py b/tests/test_requests.py index a83a2c480..fee059ab2 100644 --- a/tests/test_requests.py +++ b/tests/test_requests.py @@ -1,5 +1,4 @@ -import asyncio - +import anyio import pytest from starlette.requests import ClientDisconnect, Request, State @@ -212,9 +211,8 @@ async def receiver(): return {"type": "http.disconnect"} scope = {"type": "http", "method": "POST", "path": "/"} - loop = asyncio.get_event_loop() with pytest.raises(ClientDisconnect): - loop.run_until_complete(app(scope, receiver, None)) + anyio.run(app, scope, receiver, None) def test_request_is_disconnected(): diff --git a/tests/test_responses.py b/tests/test_responses.py index fd2ba0e42..496e64c86 100644 --- a/tests/test_responses.py +++ b/tests/test_responses.py @@ -1,6 +1,6 @@ -import asyncio import os +import anyio import pytest from starlette import status @@ -83,7 +83,7 @@ async def numbers(minimum, maximum): yield str(i) if i != maximum: yield ", " - await asyncio.sleep(0) + await anyio.sleep(0) async def numbers_for_cleanup(start=1, stop=5): nonlocal filled_by_bg_task @@ -197,7 +197,7 @@ async def numbers(minimum, maximum): yield str(i) if i != maximum: yield ", " - await asyncio.sleep(0) + await anyio.sleep(0) async def numbers_for_cleanup(start=1, stop=5): nonlocal filled_by_bg_task diff --git a/tests/test_routing.py b/tests/test_routing.py index fff3332db..1d8eb8d95 100644 --- a/tests/test_routing.py +++ b/tests/test_routing.py @@ -286,7 +286,8 @@ def test_protocol_switch(): assert session.receive_json() == {"URL": "ws://testserver/"} with pytest.raises(WebSocketDisconnect): - client.websocket_connect("/404") + with client.websocket_connect("/404"): + pass # pragma: nocover ok = PlainTextResponse("OK") @@ -492,7 +493,8 @@ def test_standalone_ws_route_does_not_match(): app = WebSocketRoute("/", ws_helloworld) client = TestClient(app) with pytest.raises(WebSocketDisconnect): - client.websocket_connect("/invalid") + with client.websocket_connect("/invalid"): + pass # pragma: nocover def test_lifespan_async(): diff --git a/tests/test_staticfiles.py b/tests/test_staticfiles.py index 6b325071f..3c8ff240e 100644 --- a/tests/test_staticfiles.py +++ b/tests/test_staticfiles.py @@ -1,8 +1,8 @@ -import asyncio import os import pathlib import time +import anyio import pytest from starlette.applications import Starlette @@ -153,8 +153,7 @@ def test_staticfiles_prevents_breaking_out_of_directory(tmpdir): # We can't test this with 'requests', so we test the app directly here. path = app.get_path({"path": "/../example.txt"}) scope = {"method": "GET"} - loop = asyncio.get_event_loop() - response = loop.run_until_complete(app.get_response(path, scope)) + response = anyio.run(app.get_response, path, scope) assert response.status_code == 404 assert response.body == b"Not Found" diff --git a/tests/test_testclient.py b/tests/test_testclient.py index 00f4e0125..86f36e172 100644 --- a/tests/test_testclient.py +++ b/tests/test_testclient.py @@ -1,5 +1,4 @@ -import asyncio - +import anyio import pytest from starlette.applications import Starlette @@ -118,13 +117,14 @@ async def respond(websocket): async def asgi(receive, send): websocket = WebSocket(scope, receive=receive, send=send) await websocket.accept() - asyncio.ensure_future(respond(websocket)) - try: - # this will block as the client does not send us data - # it should not prevent `respond` from executing though - await websocket.receive_json() - except WebSocketDisconnect: - pass + async with anyio.create_task_group() as task_group: + task_group.start_soon(respond, websocket) + try: + # this will block as the client does not send us data + # it should not prevent `respond` from executing though + await websocket.receive_json() + except WebSocketDisconnect: + pass return asgi diff --git a/tests/test_websockets.py b/tests/test_websockets.py index ffb1a44a8..63ecd050a 100644 --- a/tests/test_websockets.py +++ b/tests/test_websockets.py @@ -1,9 +1,7 @@ -import asyncio - +import anyio import pytest from starlette import status -from starlette.concurrency import run_until_first_complete from starlette.testclient import TestClient from starlette.websockets import WebSocket, WebSocketDisconnect @@ -208,23 +206,24 @@ async def asgi(receive, send): def test_websocket_concurrency_pattern(): def app(scope): - async def reader(websocket, queue): - async for data in websocket.iter_json(): - await queue.put(data) + stream_send, stream_receive = anyio.create_memory_object_stream() - async def writer(websocket, queue): - while True: - message = await queue.get() - await websocket.send_json(message) + async def reader(websocket): + async with stream_send: + async for data in websocket.iter_json(): + await stream_send.send(data) + + async def writer(websocket): + async with stream_receive: + async for message in stream_receive: + await websocket.send_json(message) async def asgi(receive, send): websocket = WebSocket(scope, receive=receive, send=send) - queue = asyncio.Queue() await websocket.accept() - await run_until_first_complete( - (reader, {"websocket": websocket, "queue": queue}), - (writer, {"websocket": websocket, "queue": queue}), - ) + async with anyio.create_task_group() as task_group: + task_group.start_soon(reader, websocket) + await writer(websocket) await websocket.close() return asgi @@ -283,7 +282,8 @@ async def asgi(receive, send): client = TestClient(app) with pytest.raises(WebSocketDisconnect) as exc: - client.websocket_connect("/") + with client.websocket_connect("/"): + pass # pragma: nocover assert exc.value.code == status.WS_1001_GOING_AWAY @@ -311,7 +311,8 @@ async def asgi(receive, send): client = TestClient(app) with pytest.raises(AssertionError): - client.websocket_connect("/123?a=abc") + with client.websocket_connect("/123?a=abc"): + pass # pragma: nocover def test_duplicate_close(): @@ -327,7 +328,7 @@ async def asgi(receive, send): client = TestClient(app) with pytest.raises(RuntimeError): with client.websocket_connect("/"): - pass + pass # pragma: nocover def test_duplicate_disconnect():