From d06f40c6fad37a6ec0311f1f44595ad49db8efcd Mon Sep 17 00:00:00 2001 From: Jordan Speicher Date: Fri, 26 Mar 2021 18:01:06 -0500 Subject: [PATCH 01/59] First whack at anyio integration --- README.md | 7 ++-- setup.py | 1 + starlette/applications.py | 6 +++- starlette/concurrency.py | 29 +++++++++-------- starlette/middleware/base.py | 39 +++++++++++------------ starlette/middleware/wsgi.py | 62 +++++++++++++++++++----------------- starlette/requests.py | 10 +++--- 7 files changed, 82 insertions(+), 72 deletions(-) diff --git a/README.md b/README.md index 0b4508124..1e3d1cc55 100644 --- a/README.md +++ b/README.md @@ -24,7 +24,7 @@ # Starlette Starlette is a lightweight [ASGI](https://asgi.readthedocs.io/en/latest/) framework/toolkit, -which is ideal for building high performance asyncio services. +which is ideal for building high performance async services. It is production-ready, and gives you the following: @@ -38,7 +38,8 @@ It is production-ready, and gives you the following: * Session and Cookie support. * 100% test coverage. * 100% type annotated codebase. -* Zero hard dependencies. +* Few hard dependencies. +* Compatible with `asyncio and `trio` backends. ## Requirements @@ -86,7 +87,7 @@ For a more complete example, see [encode/starlette-example](https://github.com/e ## Dependencies -Starlette does not have any hard dependencies, but the following are optional: +Starlette only requires `anyio`, and the following are optional: * [`requests`][requests] - Required if you want to use the `TestClient`. * [`aiofiles`][aiofiles] - Required if you want to use `FileResponse` or `StaticFiles`. diff --git a/setup.py b/setup.py index 37a09ea48..29e60d017 100644 --- a/setup.py +++ b/setup.py @@ -48,6 +48,7 @@ def get_packages(package): packages=get_packages("starlette"), package_data={"starlette": ["py.typed"]}, include_package_data=True, + instal_requires=["anyio<3,>=2"], extras_require={ "full": [ "aiofiles", diff --git a/starlette/applications.py b/starlette/applications.py index 34c3e38bd..10bd075a3 100644 --- a/starlette/applications.py +++ b/starlette/applications.py @@ -1,5 +1,7 @@ import typing +import anyio + from starlette.datastructures import State, URLPath from starlette.exceptions import ExceptionMiddleware from starlette.middleware import Middleware @@ -109,7 +111,9 @@ def url_path_for(self, name: str, **path_params: str) -> URLPath: async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: scope["app"] = self - await self.middleware_stack(scope, receive, send) + task_group = scope["task_group"] = anyio.create_task_group() + async with task_group: + await self.middleware_stack(scope, receive, send) # The following usages are now discouraged in favour of configuration #  during Starlette.__init__(...) diff --git a/starlette/concurrency.py b/starlette/concurrency.py index c8c5d57ac..14ee46ab3 100644 --- a/starlette/concurrency.py +++ b/starlette/concurrency.py @@ -1,33 +1,36 @@ -import asyncio import functools import sys import typing from typing import Any, AsyncGenerator, Iterator +import anyio + try: import contextvars # Python 3.7+ only or via contextvars backport. except ImportError: # pragma: no cover contextvars = None # type: ignore -if sys.version_info >= (3, 7): # pragma: no cover - from asyncio import create_task -else: # pragma: no cover - from asyncio import ensure_future as create_task T = typing.TypeVar("T") async def run_until_first_complete(*args: typing.Tuple[typing.Callable, dict]) -> None: - tasks = [create_task(handler(**kwargs)) for handler, kwargs in args] - (done, pending) = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED) - [task.cancel() for task in pending] - [task.result() for task in done] + result: Any = None + async with anyio.create_task_group() as task_group: + + async def task(_handler, _kwargs) -> Any: + nonlocal result + result = await _handler(**_kwargs) + await task_group.cancel_scope.cancel() + + for handler, kwargs in args: + await task_group.spawn(task, handler, kwargs) + return result async def run_in_threadpool( func: typing.Callable[..., T], *args: typing.Any, **kwargs: typing.Any ) -> T: - loop = asyncio.get_event_loop() if contextvars is not None: # pragma: no cover # Ensure we run in the same context child = functools.partial(func, *args, **kwargs) @@ -35,9 +38,9 @@ async def run_in_threadpool( func = context.run args = (child,) elif kwargs: # pragma: no cover - # loop.run_in_executor doesn't accept 'kwargs', so bind them in here + # run_sync doesn't accept 'kwargs', so bind them in here func = functools.partial(func, **kwargs) - return await loop.run_in_executor(None, func, *args) + return await anyio.run_sync_in_worker_thread(func, *args) class _StopIteration(Exception): @@ -57,6 +60,6 @@ def _next(iterator: Iterator) -> Any: async def iterate_in_threadpool(iterator: Iterator) -> AsyncGenerator: while True: try: - yield await run_in_threadpool(_next, iterator) + yield await anyio.run_sync_in_worker_thread(_next, iterator) except _StopIteration: break diff --git a/starlette/middleware/base.py b/starlette/middleware/base.py index b347a6a2d..a436ac288 100644 --- a/starlette/middleware/base.py +++ b/starlette/middleware/base.py @@ -1,6 +1,7 @@ -import asyncio import typing +import anyio + from starlette.requests import Request from starlette.responses import Response, StreamingResponse from starlette.types import ASGIApp, Message, Receive, Scope, Send @@ -26,34 +27,30 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: await response(scope, receive, send) async def call_next(self, request: Request) -> Response: - loop = asyncio.get_event_loop() - queue: "asyncio.Queue[typing.Optional[Message]]" = asyncio.Queue() + send_stream, recv_stream = anyio.create_memory_object_stream(0, item_type=Message) # XXX size scope = request.scope - receive = request.receive - send = queue.put + task_group = scope["task_group"] async def coro() -> None: - try: - await self.app(scope, receive, send) - finally: - await queue.put(None) - - task = loop.create_task(coro()) - message = await queue.get() - if message is None: - task.result() + async with send_stream: + await self.app(scope, recv_stream.receive, send_stream.send) + + await task_group.spawn(coro) + + try: + message = await recv_stream.receive() + except anyio.EndOfStream: + print("WHAT " * 10) raise RuntimeError("No response returned.") + assert message["type"] == "http.response.start" async def body_stream() -> typing.AsyncGenerator[bytes, None]: - while True: - message = await queue.get() - if message is None: - break - assert message["type"] == "http.response.body" - yield message.get("body", b"") - task.result() + async with recv_stream: + async for message in recv_stream: + assert message["type"] == "http.response.body" + yield message.get("body", b"") response = StreamingResponse( status_code=message["status"], content=body_stream() diff --git a/starlette/middleware/wsgi.py b/starlette/middleware/wsgi.py index a552b4180..585ffba82 100644 --- a/starlette/middleware/wsgi.py +++ b/starlette/middleware/wsgi.py @@ -1,8 +1,9 @@ -import asyncio import io import sys import typing +import anyio + from starlette.concurrency import run_in_threadpool from starlette.types import Message, Receive, Scope, Send @@ -69,9 +70,8 @@ def __init__(self, app: typing.Callable, scope: Scope) -> None: self.scope = scope self.status = None self.response_headers = None - self.send_event = asyncio.Event() + self.send_event = anyio.create_event() self.send_queue = [] # type: typing.List[typing.Optional[Message]] - self.loop = asyncio.get_event_loop() self.response_started = False self.exc_info = None # type: typing.Any @@ -83,31 +83,35 @@ async def __call__(self, receive: Receive, send: Send) -> None: body += message.get("body", b"") more_body = message.get("more_body", False) environ = build_environ(self.scope, body) - sender = None + + async with anyio.create_task_group() as task_group: + sender_finished = anyio.create_event() + try: + await task_group.spawn(self.sender, send, sender_finished) + await anyio.run_sync_in_worker_thread(self.wsgi, environ, self.start_response) + self.send_queue.append(None) + await self.send_event.set() + await sender_finished.wait() + if self.exc_info is not None: + raise self.exc_info[0].with_traceback( + self.exc_info[1], self.exc_info[2] + ) + finally: + await task_group.cancel_scope.cancel() + + async def sender(self, send: Send, finished) -> None: try: - sender = self.loop.create_task(self.sender(send)) - await run_in_threadpool(self.wsgi, environ, self.start_response) - self.send_queue.append(None) - self.send_event.set() - await asyncio.wait_for(sender, None) - if self.exc_info is not None: - raise self.exc_info[0].with_traceback( - self.exc_info[1], self.exc_info[2] - ) + while True: + if self.send_queue: + message = self.send_queue.pop(0) + if message is None: + return + await send(message) + else: + await self.send_event.wait() + self.send_event = anyio.create_event() finally: - if sender and not sender.done(): - sender.cancel() # pragma: no cover - - async def sender(self, send: Send) -> None: - while True: - if self.send_queue: - message = self.send_queue.pop(0) - if message is None: - return - await send(message) - else: - await self.send_event.wait() - self.send_event.clear() + await finished.set() def start_response( self, @@ -131,14 +135,14 @@ def start_response( "headers": headers, } ) - self.loop.call_soon_threadsafe(self.send_event.set) + anyio.run_async_from_thread(self.send_event.set) def wsgi(self, environ: dict, start_response: typing.Callable) -> None: for chunk in self.app(environ, start_response): self.send_queue.append( {"type": "http.response.body", "body": chunk, "more_body": True} ) - self.loop.call_soon_threadsafe(self.send_event.set) + anyio.run_async_from_thread(self.send_event.set) self.send_queue.append({"type": "http.response.body", "body": b""}) - self.loop.call_soon_threadsafe(self.send_event.set) + anyio.run_async_from_thread(self.send_event.set) diff --git a/starlette/requests.py b/starlette/requests.py index ab6f51424..8b34775bc 100644 --- a/starlette/requests.py +++ b/starlette/requests.py @@ -1,9 +1,10 @@ -import asyncio import json import typing from collections.abc import Mapping from http import cookies as http_cookies +import anyio + from starlette.datastructures import URL, Address, FormData, Headers, QueryParams, State from starlette.formparsers import FormParser, MultiPartParser from starlette.types import Message, Receive, Scope, Send @@ -251,10 +252,9 @@ async def close(self) -> None: async def is_disconnected(self) -> bool: if not self._is_disconnected: - try: - message = await asyncio.wait_for(self._receive(), timeout=0.0000001) - except asyncio.TimeoutError: - message = {} + message = {} + async with anyio.move_on_after(0.0000001): + message = await self._receive() if message.get("type") == "http.disconnect": self._is_disconnected = True From 75310b5f7fd55739ffe1ca4654a59a03efcea69f Mon Sep 17 00:00:00 2001 From: Jordan Speicher Date: Fri, 26 Mar 2021 18:12:08 -0500 Subject: [PATCH 02/59] Fix formatting --- starlette/concurrency.py | 1 + starlette/middleware/base.py | 4 +++- starlette/middleware/wsgi.py | 4 +++- 3 files changed, 7 insertions(+), 2 deletions(-) diff --git a/starlette/concurrency.py b/starlette/concurrency.py index 14ee46ab3..2209234cd 100644 --- a/starlette/concurrency.py +++ b/starlette/concurrency.py @@ -28,6 +28,7 @@ async def task(_handler, _kwargs) -> Any: return result + async def run_in_threadpool( func: typing.Callable[..., T], *args: typing.Any, **kwargs: typing.Any ) -> T: diff --git a/starlette/middleware/base.py b/starlette/middleware/base.py index a436ac288..46a71a34e 100644 --- a/starlette/middleware/base.py +++ b/starlette/middleware/base.py @@ -27,7 +27,9 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: await response(scope, receive, send) async def call_next(self, request: Request) -> Response: - send_stream, recv_stream = anyio.create_memory_object_stream(0, item_type=Message) # XXX size + send_stream, recv_stream = anyio.create_memory_object_stream( + 0, item_type=Message + ) # XXX size scope = request.scope task_group = scope["task_group"] diff --git a/starlette/middleware/wsgi.py b/starlette/middleware/wsgi.py index 585ffba82..94111e60b 100644 --- a/starlette/middleware/wsgi.py +++ b/starlette/middleware/wsgi.py @@ -88,7 +88,9 @@ async def __call__(self, receive: Receive, send: Send) -> None: sender_finished = anyio.create_event() try: await task_group.spawn(self.sender, send, sender_finished) - await anyio.run_sync_in_worker_thread(self.wsgi, environ, self.start_response) + await anyio.run_sync_in_worker_thread( + self.wsgi, environ, self.start_response + ) self.send_queue.append(None) await self.send_event.set() await sender_finished.wait() From a66068496da91d2c32d27c45eb05181a8c28d0b7 Mon Sep 17 00:00:00 2001 From: Jordan Speicher Date: Fri, 26 Mar 2021 18:13:35 -0500 Subject: [PATCH 03/59] Remove debug messages --- starlette/middleware/base.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/starlette/middleware/base.py b/starlette/middleware/base.py index 46a71a34e..f28b8c969 100644 --- a/starlette/middleware/base.py +++ b/starlette/middleware/base.py @@ -29,7 +29,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: async def call_next(self, request: Request) -> Response: send_stream, recv_stream = anyio.create_memory_object_stream( 0, item_type=Message - ) # XXX size + ) scope = request.scope task_group = scope["task_group"] @@ -43,7 +43,6 @@ async def coro() -> None: try: message = await recv_stream.receive() except anyio.EndOfStream: - print("WHAT " * 10) raise RuntimeError("No response returned.") assert message["type"] == "http.response.start" From 42b83cbae5111301e1a19de2c59a258ae47b44a6 Mon Sep 17 00:00:00 2001 From: Jordan Speicher Date: Fri, 26 Mar 2021 18:18:45 -0500 Subject: [PATCH 04/59] mypy fixes --- starlette/concurrency.py | 2 +- starlette/middleware/base.py | 5 +---- starlette/middleware/wsgi.py | 2 +- starlette/requests.py | 2 +- 4 files changed, 4 insertions(+), 7 deletions(-) diff --git a/starlette/concurrency.py b/starlette/concurrency.py index 2209234cd..de6b48f58 100644 --- a/starlette/concurrency.py +++ b/starlette/concurrency.py @@ -18,7 +18,7 @@ async def run_until_first_complete(*args: typing.Tuple[typing.Callable, dict]) - result: Any = None async with anyio.create_task_group() as task_group: - async def task(_handler, _kwargs) -> Any: + async def task(_handler: typing.Callable, _kwargs: dict) -> Any: nonlocal result result = await _handler(**_kwargs) await task_group.cancel_scope.cancel() diff --git a/starlette/middleware/base.py b/starlette/middleware/base.py index f28b8c969..0e58d5bd5 100644 --- a/starlette/middleware/base.py +++ b/starlette/middleware/base.py @@ -27,10 +27,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: await response(scope, receive, send) async def call_next(self, request: Request) -> Response: - send_stream, recv_stream = anyio.create_memory_object_stream( - 0, item_type=Message - ) - + send_stream, recv_stream = anyio.create_memory_object_stream() scope = request.scope task_group = scope["task_group"] diff --git a/starlette/middleware/wsgi.py b/starlette/middleware/wsgi.py index 94111e60b..0c8d44873 100644 --- a/starlette/middleware/wsgi.py +++ b/starlette/middleware/wsgi.py @@ -101,7 +101,7 @@ async def __call__(self, receive: Receive, send: Send) -> None: finally: await task_group.cancel_scope.cancel() - async def sender(self, send: Send, finished) -> None: + async def sender(self, send: Send, finished: anyio.abc.Event) -> None: try: while True: if self.send_queue: diff --git a/starlette/requests.py b/starlette/requests.py index 8b34775bc..fbd456fef 100644 --- a/starlette/requests.py +++ b/starlette/requests.py @@ -252,7 +252,7 @@ async def close(self) -> None: async def is_disconnected(self) -> bool: if not self._is_disconnected: - message = {} + message: Message = {} async with anyio.move_on_after(0.0000001): message = await self._receive() From 9870a1fa6a8a4799c9fb3c753391cb492cf19f6f Mon Sep 17 00:00:00 2001 From: Jordan Speicher Date: Fri, 26 Mar 2021 20:44:57 -0500 Subject: [PATCH 05/59] Update README.md Co-authored-by: Marcelo Trylesinski --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 1e3d1cc55..bbd7a0a3e 100644 --- a/README.md +++ b/README.md @@ -39,7 +39,7 @@ It is production-ready, and gives you the following: * 100% test coverage. * 100% type annotated codebase. * Few hard dependencies. -* Compatible with `asyncio and `trio` backends. +* Compatible with `asyncio` and `trio` backends. ## Requirements From 6997eb95a16463aed311e44be8e88560b149cbe8 Mon Sep 17 00:00:00 2001 From: Jordan Speicher Date: Sat, 27 Mar 2021 09:35:30 -0500 Subject: [PATCH 06/59] Fix install_requires typo --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 29e60d017..cdbc799d6 100644 --- a/setup.py +++ b/setup.py @@ -48,7 +48,7 @@ def get_packages(package): packages=get_packages("starlette"), package_data={"starlette": ["py.typed"]}, include_package_data=True, - instal_requires=["anyio<3,>=2"], + install_requires=["anyio<3,>=2"], extras_require={ "full": [ "aiofiles", From e1c2adb69f80ef1238afd8bd1af98126e76ceeb4 Mon Sep 17 00:00:00 2001 From: Jordan Speicher Date: Sat, 27 Mar 2021 09:35:51 -0500 Subject: [PATCH 07/59] move_on_after blocks if deadline is too small --- starlette/requests.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/starlette/requests.py b/starlette/requests.py index fbd456fef..daa33ecf5 100644 --- a/starlette/requests.py +++ b/starlette/requests.py @@ -253,7 +253,7 @@ async def close(self) -> None: async def is_disconnected(self) -> bool: if not self._is_disconnected: message: Message = {} - async with anyio.move_on_after(0.0000001): + async with anyio.move_on_after(0.001): # XXX: to small of a deadline and this blocks message = await self._receive() if message.get("type") == "http.disconnect": From de84b4a58cefc0dc9b9143665096f0930ad3ffad Mon Sep 17 00:00:00 2001 From: Jordan Speicher Date: Sat, 27 Mar 2021 09:36:03 -0500 Subject: [PATCH 08/59] Linter fixes --- starlette/concurrency.py | 1 - starlette/middleware/base.py | 2 +- starlette/middleware/wsgi.py | 1 - 3 files changed, 1 insertion(+), 3 deletions(-) diff --git a/starlette/concurrency.py b/starlette/concurrency.py index de6b48f58..dd3f280cd 100644 --- a/starlette/concurrency.py +++ b/starlette/concurrency.py @@ -1,5 +1,4 @@ import functools -import sys import typing from typing import Any, AsyncGenerator, Iterator diff --git a/starlette/middleware/base.py b/starlette/middleware/base.py index 0e58d5bd5..4241ad718 100644 --- a/starlette/middleware/base.py +++ b/starlette/middleware/base.py @@ -4,7 +4,7 @@ from starlette.requests import Request from starlette.responses import Response, StreamingResponse -from starlette.types import ASGIApp, Message, Receive, Scope, Send +from starlette.types import ASGIApp, Receive, Scope, Send RequestResponseEndpoint = typing.Callable[[Request], typing.Awaitable[Response]] DispatchFunction = typing.Callable[ diff --git a/starlette/middleware/wsgi.py b/starlette/middleware/wsgi.py index 0c8d44873..2e004c7fd 100644 --- a/starlette/middleware/wsgi.py +++ b/starlette/middleware/wsgi.py @@ -4,7 +4,6 @@ import anyio -from starlette.concurrency import run_in_threadpool from starlette.types import Message, Receive, Scope, Send From e91ec33bc023319a455f4f3bb06754065b86c8aa Mon Sep 17 00:00:00 2001 From: Jordan Speicher Date: Sat, 27 Mar 2021 09:51:11 -0500 Subject: [PATCH 09/59] Improve WSGI structured concurrency --- starlette/middleware/wsgi.py | 52 ++++++++++++++---------------------- starlette/requests.py | 4 ++- 2 files changed, 23 insertions(+), 33 deletions(-) diff --git a/starlette/middleware/wsgi.py b/starlette/middleware/wsgi.py index 2e004c7fd..dccf9b248 100644 --- a/starlette/middleware/wsgi.py +++ b/starlette/middleware/wsgi.py @@ -4,7 +4,7 @@ import anyio -from starlette.types import Message, Receive, Scope, Send +from starlette.types import Receive, Scope, Send def build_environ(scope: Scope, body: bytes) -> dict: @@ -69,8 +69,7 @@ def __init__(self, app: typing.Callable, scope: Scope) -> None: self.scope = scope self.status = None self.response_headers = None - self.send_event = anyio.create_event() - self.send_queue = [] # type: typing.List[typing.Optional[Message]] + self.stream_send, self.stream_receive = anyio.create_memory_object_stream() self.response_started = False self.exc_info = None # type: typing.Any @@ -84,15 +83,12 @@ async def __call__(self, receive: Receive, send: Send) -> None: environ = build_environ(self.scope, body) async with anyio.create_task_group() as task_group: - sender_finished = anyio.create_event() try: - await task_group.spawn(self.sender, send, sender_finished) - await anyio.run_sync_in_worker_thread( - self.wsgi, environ, self.start_response - ) - self.send_queue.append(None) - await self.send_event.set() - await sender_finished.wait() + await task_group.spawn(self.sender, send) + async with self.stream_send: + await anyio.run_sync_in_worker_thread( + self.wsgi, environ, self.start_response + ) if self.exc_info is not None: raise self.exc_info[0].with_traceback( self.exc_info[1], self.exc_info[2] @@ -100,19 +96,10 @@ async def __call__(self, receive: Receive, send: Send) -> None: finally: await task_group.cancel_scope.cancel() - async def sender(self, send: Send, finished: anyio.abc.Event) -> None: - try: - while True: - if self.send_queue: - message = self.send_queue.pop(0) - if message is None: - return - await send(message) - else: - await self.send_event.wait() - self.send_event = anyio.create_event() - finally: - await finished.set() + async def sender(self, send: Send) -> None: + async with self.stream_receive: + async for message in self.stream_receive: + await send(message) def start_response( self, @@ -129,21 +116,22 @@ def start_response( (name.strip().encode("ascii").lower(), value.strip().encode("ascii")) for name, value in response_headers ] - self.send_queue.append( + anyio.run_async_from_thread( + self.stream_send.send, { "type": "http.response.start", "status": status_code, "headers": headers, - } + }, ) - anyio.run_async_from_thread(self.send_event.set) def wsgi(self, environ: dict, start_response: typing.Callable) -> None: for chunk in self.app(environ, start_response): - self.send_queue.append( - {"type": "http.response.body", "body": chunk, "more_body": True} + anyio.run_async_from_thread( + self.stream_send.send, + {"type": "http.response.body", "body": chunk, "more_body": True}, ) - anyio.run_async_from_thread(self.send_event.set) - self.send_queue.append({"type": "http.response.body", "body": b""}) - anyio.run_async_from_thread(self.send_event.set) + anyio.run_async_from_thread( + self.stream_send.send, {"type": "http.response.body", "body": b""} + ) diff --git a/starlette/requests.py b/starlette/requests.py index daa33ecf5..206c8b559 100644 --- a/starlette/requests.py +++ b/starlette/requests.py @@ -253,7 +253,9 @@ async def close(self) -> None: async def is_disconnected(self) -> bool: if not self._is_disconnected: message: Message = {} - async with anyio.move_on_after(0.001): # XXX: to small of a deadline and this blocks + async with anyio.move_on_after( + 0.001 + ): # XXX: to small of a deadline and this blocks message = await self._receive() if message.get("type") == "http.disconnect": From 7e2cd461c5b1241d11121a0c8bc9bc1cda7859b1 Mon Sep 17 00:00:00 2001 From: Jordan Speicher Date: Sat, 27 Mar 2021 12:15:25 -0500 Subject: [PATCH 10/59] Tests use anyio --- tests/test_graphql.py | 1 + tests/test_requests.py | 6 ++---- tests/test_responses.py | 6 +++--- tests/test_staticfiles.py | 5 ++--- 4 files changed, 8 insertions(+), 10 deletions(-) diff --git a/tests/test_graphql.py b/tests/test_graphql.py index 67f307231..bc6378250 100644 --- a/tests/test_graphql.py +++ b/tests/test_graphql.py @@ -160,6 +160,7 @@ def old_style_async_app(event_loop) -> GraphQLApp: return old_style_async_app +@pytest.mark.skip("XXX") def test_graphql_async_old_style_executor(old_style_async_app: GraphQLApp): # See https://github.com/encode/starlette/issues/242 client = TestClient(old_style_async_app) diff --git a/tests/test_requests.py b/tests/test_requests.py index a83a2c480..fee059ab2 100644 --- a/tests/test_requests.py +++ b/tests/test_requests.py @@ -1,5 +1,4 @@ -import asyncio - +import anyio import pytest from starlette.requests import ClientDisconnect, Request, State @@ -212,9 +211,8 @@ async def receiver(): return {"type": "http.disconnect"} scope = {"type": "http", "method": "POST", "path": "/"} - loop = asyncio.get_event_loop() with pytest.raises(ClientDisconnect): - loop.run_until_complete(app(scope, receiver, None)) + anyio.run(app, scope, receiver, None) def test_request_is_disconnected(): diff --git a/tests/test_responses.py b/tests/test_responses.py index 10fbe673c..5610ac9d6 100644 --- a/tests/test_responses.py +++ b/tests/test_responses.py @@ -1,6 +1,6 @@ -import asyncio import os +import anyio import pytest from starlette import status @@ -69,7 +69,7 @@ async def numbers(minimum, maximum): yield str(i) if i != maximum: yield ", " - await asyncio.sleep(0) + await anyio.sleep(0) async def numbers_for_cleanup(start=1, stop=5): nonlocal filled_by_bg_task @@ -183,7 +183,7 @@ async def numbers(minimum, maximum): yield str(i) if i != maximum: yield ", " - await asyncio.sleep(0) + await anyio.sleep(0) async def numbers_for_cleanup(start=1, stop=5): nonlocal filled_by_bg_task diff --git a/tests/test_staticfiles.py b/tests/test_staticfiles.py index 6b325071f..3c8ff240e 100644 --- a/tests/test_staticfiles.py +++ b/tests/test_staticfiles.py @@ -1,8 +1,8 @@ -import asyncio import os import pathlib import time +import anyio import pytest from starlette.applications import Starlette @@ -153,8 +153,7 @@ def test_staticfiles_prevents_breaking_out_of_directory(tmpdir): # We can't test this with 'requests', so we test the app directly here. path = app.get_path({"path": "/../example.txt"}) scope = {"method": "GET"} - loop = asyncio.get_event_loop() - response = loop.run_until_complete(app.get_response(path, scope)) + response = anyio.run(app.get_response, path, scope) assert response.status_code == 404 assert response.body == b"Not Found" From 03e312ede7858c4031a5e01728e7513855d54763 Mon Sep 17 00:00:00 2001 From: Jordan Speicher Date: Sat, 27 Mar 2021 12:15:44 -0500 Subject: [PATCH 11/59] Checkin progress on testclient --- starlette/testclient.py | 86 ++++++++++++++++++++++------------------- 1 file changed, 47 insertions(+), 39 deletions(-) diff --git a/starlette/testclient.py b/starlette/testclient.py index 1d5e90dc8..51dce3ced 100644 --- a/starlette/testclient.py +++ b/starlette/testclient.py @@ -4,11 +4,11 @@ import io import json import queue -import threading import types import typing from urllib.parse import unquote, urljoin, urlsplit +import anyio import requests from starlette.types import Message, Receive, Scope, Send @@ -171,7 +171,7 @@ async def receive() -> Message: if request_complete: while not response_complete: - await asyncio.sleep(0.0001) + await anyio.sleep(0.0001) return {"type": "http.disconnect"} body = request.body @@ -231,13 +231,8 @@ async def send(message: Message) -> None: context = message["context"] try: - loop = asyncio.get_event_loop() - except RuntimeError: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - try: - loop.run_until_complete(self.app(scope, receive, send)) + with anyio.start_blocking_portal(backend_options={"debug": True}) as portal: + portal.call(self.app, scope, receive, send) except BaseException as exc: if self.raise_server_exceptions: raise exc from None @@ -268,11 +263,11 @@ def __init__(self, app: ASGI3App, scope: Scope) -> None: self.app = app self.scope = scope self.accepted_subprotocol = None + self.portal = anyio.start_blocking_portal(backend_options={"debug": True}) self._receive_queue = queue.Queue() # type: queue.Queue self._send_queue = queue.Queue() # type: queue.Queue - self._thread = threading.Thread(target=self._run) + self.portal.spawn_task(self._run) self.send({"type": "websocket.connect"}) - self._thread.start() message = self.receive() self._raise_on_close(message) self.accepted_subprotocol = message.get("subprotocol", None) @@ -281,31 +276,30 @@ def __enter__(self) -> "WebSocketTestSession": return self def __exit__(self, *args: typing.Any) -> None: - self.close(1000) - self._thread.join() + try: + self.close(1000) + finally: + self.portal.stop_from_external_thread() while not self._send_queue.empty(): message = self._send_queue.get() if isinstance(message, BaseException): raise message - def _run(self) -> None: + async def _run(self) -> None: """ The sub-thread in which the websocket session runs. """ - loop = asyncio.new_event_loop() scope = self.scope receive = self._asgi_receive send = self._asgi_send try: - loop.run_until_complete(self.app(scope, receive, send)) + await self.app(scope, receive, send) except BaseException as exc: self._send_queue.put(exc) - finally: - loop.close() async def _asgi_receive(self) -> Message: while self._receive_queue.empty(): - await asyncio.sleep(0) + await anyio.sleep(0) return self._receive_queue.get() async def _asgi_send(self, message: Message) -> None: @@ -452,42 +446,56 @@ def websocket_connect( return session def __enter__(self) -> "TestClient": - loop = asyncio.get_event_loop() - self.send_queue = asyncio.Queue() # type: asyncio.Queue - self.receive_queue = asyncio.Queue() # type: asyncio.Queue - self.task = loop.create_task(self.lifespan()) - loop.run_until_complete(self.wait_startup()) + self.stream_send, self.stream_receive = anyio.create_memory_object_stream() + self.portal = anyio.start_blocking_portal( + backend_options={"debug": True} + ) # XXX backend + self.task = self.portal.spawn_task(self.lifespan) + self.portal.call(self.wait_startup) return self def __exit__(self, *args: typing.Any) -> None: - loop = asyncio.get_event_loop() - loop.run_until_complete(self.wait_shutdown()) + try: + self.portal.call(self.wait_shutdown) + finally: + self.portal.stop_from_external_thread() async def lifespan(self) -> None: scope = {"type": "lifespan"} - try: - await self.app(scope, self.receive_queue.get, self.send_queue.put) - finally: - await self.send_queue.put(None) + async with self.stream_send: + await self.app(scope, self.stream_receive.receive, self.stream_send.send) async def wait_startup(self) -> None: - await self.receive_queue.put({"type": "lifespan.startup"}) - message = await self.send_queue.get() - if message is None: + try: + await self.stream_send.send({"type": "lifespan.startup"}) + except anyio.ClosedResourceError: + self.task.result() + return + try: + message = await self.stream_receive.receive() + except anyio.EndOfStream: self.task.result() + return assert message["type"] in ( "lifespan.startup.complete", "lifespan.startup.failed", ) if message["type"] == "lifespan.startup.failed": - message = await self.send_queue.get() - if message is None: + try: + message = await self.stream_receive.receive() + except anyio.EndOfStream: self.task.result() async def wait_shutdown(self) -> None: - await self.receive_queue.put({"type": "lifespan.shutdown"}) - message = await self.send_queue.get() - if message is None: + try: + await self.stream_send.send({"type": "lifespan.shutdown"}) + except anyio.ClosedResourceError: + self.task.result() + return + try: + message = await self.stream_receive.receive() + except anyio.EndOfStream: self.task.result() + return assert message["type"] == "lifespan.shutdown.complete" - await self.task + self.task.result() From fd4569eeaf6b81b0653ae31e48c808e9df057f5c Mon Sep 17 00:00:00 2001 From: Jordan Speicher Date: Sat, 27 Mar 2021 15:51:41 -0500 Subject: [PATCH 12/59] Prep for anyio 3 --- starlette/concurrency.py | 4 +-- starlette/middleware/base.py | 12 +++++-- starlette/middleware/wsgi.py | 19 ++++++----- starlette/requests.py | 5 ++- starlette/testclient.py | 61 ++++++++++++++++++------------------ 5 files changed, 54 insertions(+), 47 deletions(-) diff --git a/starlette/concurrency.py b/starlette/concurrency.py index dd3f280cd..645b96254 100644 --- a/starlette/concurrency.py +++ b/starlette/concurrency.py @@ -20,10 +20,10 @@ async def run_until_first_complete(*args: typing.Tuple[typing.Callable, dict]) - async def task(_handler: typing.Callable, _kwargs: dict) -> Any: nonlocal result result = await _handler(**_kwargs) - await task_group.cancel_scope.cancel() + task_group.cancel_scope.cancel() for handler, kwargs in args: - await task_group.spawn(task, handler, kwargs) + task_group.spawn(task, handler, kwargs) return result diff --git a/starlette/middleware/base.py b/starlette/middleware/base.py index 4241ad718..9d56c37ec 100644 --- a/starlette/middleware/base.py +++ b/starlette/middleware/base.py @@ -31,11 +31,17 @@ async def call_next(self, request: Request) -> Response: scope = request.scope task_group = scope["task_group"] + coro_exc = None + async def coro() -> None: async with send_stream: - await self.app(scope, recv_stream.receive, send_stream.send) + try: + await self.app(scope, recv_stream.receive, send_stream.send) + except BaseException as exc: + nonlocal coro_exc + coro_exc = exc - await task_group.spawn(coro) + task_group.spawn(coro) try: message = await recv_stream.receive() @@ -49,6 +55,8 @@ async def body_stream() -> typing.AsyncGenerator[bytes, None]: async for message in recv_stream: assert message["type"] == "http.response.body" yield message.get("body", b"") + if coro_exc is not None: + raise coro_exc response = StreamingResponse( status_code=message["status"], content=body_stream() diff --git a/starlette/middleware/wsgi.py b/starlette/middleware/wsgi.py index dccf9b248..a459f8aff 100644 --- a/starlette/middleware/wsgi.py +++ b/starlette/middleware/wsgi.py @@ -1,4 +1,5 @@ import io +import math import sys import typing @@ -69,7 +70,9 @@ def __init__(self, app: typing.Callable, scope: Scope) -> None: self.scope = scope self.status = None self.response_headers = None - self.stream_send, self.stream_receive = anyio.create_memory_object_stream() + self.stream_send, self.stream_receive = anyio.create_memory_object_stream( + math.inf + ) self.response_started = False self.exc_info = None # type: typing.Any @@ -84,7 +87,7 @@ async def __call__(self, receive: Receive, send: Send) -> None: async with anyio.create_task_group() as task_group: try: - await task_group.spawn(self.sender, send) + task_group.spawn(self.sender, send) async with self.stream_send: await anyio.run_sync_in_worker_thread( self.wsgi, environ, self.start_response @@ -94,7 +97,7 @@ async def __call__(self, receive: Receive, send: Send) -> None: self.exc_info[1], self.exc_info[2] ) finally: - await task_group.cancel_scope.cancel() + task_group.cancel_scope.cancel() async def sender(self, send: Send) -> None: async with self.stream_receive: @@ -116,8 +119,7 @@ def start_response( (name.strip().encode("ascii").lower(), value.strip().encode("ascii")) for name, value in response_headers ] - anyio.run_async_from_thread( - self.stream_send.send, + self.stream_send.send_nowait( { "type": "http.response.start", "status": status_code, @@ -127,11 +129,8 @@ def start_response( def wsgi(self, environ: dict, start_response: typing.Callable) -> None: for chunk in self.app(environ, start_response): - anyio.run_async_from_thread( - self.stream_send.send, + self.stream_send.send_nowait( {"type": "http.response.body", "body": chunk, "more_body": True}, ) - anyio.run_async_from_thread( - self.stream_send.send, {"type": "http.response.body", "body": b""} - ) + self.stream_send.send_nowait({"type": "http.response.body", "body": b""}) diff --git a/starlette/requests.py b/starlette/requests.py index 206c8b559..b71bdd1c1 100644 --- a/starlette/requests.py +++ b/starlette/requests.py @@ -253,9 +253,8 @@ async def close(self) -> None: async def is_disconnected(self) -> bool: if not self._is_disconnected: message: Message = {} - async with anyio.move_on_after( - 0.001 - ): # XXX: to small of a deadline and this blocks + with anyio.move_on_after(0.001): + # XXX: to small of a deadline and this blocks message = await self._receive() if message.get("type") == "http.disconnect": diff --git a/starlette/testclient.py b/starlette/testclient.py index 51dce3ced..dd199ac1d 100644 --- a/starlette/testclient.py +++ b/starlette/testclient.py @@ -1,8 +1,10 @@ import asyncio +import contextlib import http import inspect import io import json +import math import queue import types import typing @@ -263,7 +265,10 @@ def __init__(self, app: ASGI3App, scope: Scope) -> None: self.app = app self.scope = scope self.accepted_subprotocol = None - self.portal = anyio.start_blocking_portal(backend_options={"debug": True}) + self.exit_stack = contextlib.ExitStack() + self.portal = self.exit_stack.enter_context( + anyio.start_blocking_portal(backend_options={"debug": True}) + ) self._receive_queue = queue.Queue() # type: queue.Queue self._send_queue = queue.Queue() # type: queue.Queue self.portal.spawn_task(self._run) @@ -279,7 +284,7 @@ def __exit__(self, *args: typing.Any) -> None: try: self.close(1000) finally: - self.portal.stop_from_external_thread() + self.exit_stack.close() while not self._send_queue.empty(): message = self._send_queue.get() if isinstance(message, BaseException): @@ -446,56 +451,52 @@ def websocket_connect( return session def __enter__(self) -> "TestClient": - self.stream_send, self.stream_receive = anyio.create_memory_object_stream() - self.portal = anyio.start_blocking_portal( - backend_options={"debug": True} + self.exit_stack = contextlib.ExitStack() + self.portal = self.exit_stack.enter_context( + anyio.start_blocking_portal(backend_options={"debug": True}) ) # XXX backend + self.stream_send, self.stream_receive = anyio.create_memory_object_stream( + math.inf + ) self.task = self.portal.spawn_task(self.lifespan) - self.portal.call(self.wait_startup) + try: + self.portal.call(self.wait_startup) + except Exception: + self.exit_stack.close() + raise return self def __exit__(self, *args: typing.Any) -> None: try: self.portal.call(self.wait_shutdown) finally: - self.portal.stop_from_external_thread() + self.exit_stack.close() async def lifespan(self) -> None: scope = {"type": "lifespan"} - async with self.stream_send: + try: await self.app(scope, self.stream_receive.receive, self.stream_send.send) + finally: + await self.stream_send.send(None) async def wait_startup(self) -> None: - try: - await self.stream_send.send({"type": "lifespan.startup"}) - except anyio.ClosedResourceError: + await self.stream_send.send({"type": "lifespan.startup"}) + message = await self.stream_receive.receive() + if message is None: self.task.result() - return - try: - message = await self.stream_receive.receive() - except anyio.EndOfStream: - self.task.result() - return assert message["type"] in ( "lifespan.startup.complete", "lifespan.startup.failed", ) if message["type"] == "lifespan.startup.failed": - try: - message = await self.stream_receive.receive() - except anyio.EndOfStream: + message = await self.stream_receive.receive() + if message is None: self.task.result() async def wait_shutdown(self) -> None: - try: + async with self.stream_send: await self.stream_send.send({"type": "lifespan.shutdown"}) - except anyio.ClosedResourceError: - self.task.result() - return - try: message = await self.stream_receive.receive() - except anyio.EndOfStream: - self.task.result() - return - assert message["type"] == "lifespan.shutdown.complete" - self.task.result() + if message is None: + self.task.result() + assert message["type"] == "lifespan.shutdown.complete" From d785513a1f5c35cd831be2a05ecb599152c35476 Mon Sep 17 00:00:00 2001 From: Jordan Speicher Date: Sat, 27 Mar 2021 18:04:11 -0500 Subject: [PATCH 13/59] Remove debug backend option --- starlette/testclient.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/starlette/testclient.py b/starlette/testclient.py index dd199ac1d..4c162dce4 100644 --- a/starlette/testclient.py +++ b/starlette/testclient.py @@ -233,7 +233,7 @@ async def send(message: Message) -> None: context = message["context"] try: - with anyio.start_blocking_portal(backend_options={"debug": True}) as portal: + with anyio.start_blocking_portal() as portal: portal.call(self.app, scope, receive, send) except BaseException as exc: if self.raise_server_exceptions: @@ -266,9 +266,7 @@ def __init__(self, app: ASGI3App, scope: Scope) -> None: self.scope = scope self.accepted_subprotocol = None self.exit_stack = contextlib.ExitStack() - self.portal = self.exit_stack.enter_context( - anyio.start_blocking_portal(backend_options={"debug": True}) - ) + self.portal = self.exit_stack.enter_context(anyio.start_blocking_portal()) self._receive_queue = queue.Queue() # type: queue.Queue self._send_queue = queue.Queue() # type: queue.Queue self.portal.spawn_task(self._run) @@ -453,7 +451,7 @@ def websocket_connect( def __enter__(self) -> "TestClient": self.exit_stack = contextlib.ExitStack() self.portal = self.exit_stack.enter_context( - anyio.start_blocking_portal(backend_options={"debug": True}) + anyio.start_blocking_portal() ) # XXX backend self.stream_send, self.stream_receive = anyio.create_memory_object_stream( math.inf From 58d533181646e504d7151d26fcfb202deb526ea5 Mon Sep 17 00:00:00 2001 From: Jordan Speicher Date: Sat, 27 Mar 2021 18:04:21 -0500 Subject: [PATCH 14/59] Use anyio 3.0.0rc1 --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index cdbc799d6..759073721 100644 --- a/setup.py +++ b/setup.py @@ -48,7 +48,7 @@ def get_packages(package): packages=get_packages("starlette"), package_data={"starlette": ["py.typed"]}, include_package_data=True, - install_requires=["anyio<3,>=2"], + install_requires=["anyio>=3.0.0rc1"], extras_require={ "full": [ "aiofiles", From 268547dc30b2ba82205110c7348a630d1241c33e Mon Sep 17 00:00:00 2001 From: Jordan Speicher Date: Sat, 27 Mar 2021 18:34:28 -0500 Subject: [PATCH 15/59] Remove old style executor from GraphQLApp --- starlette/graphql.py | 21 +++++---------------- tests/test_graphql.py | 20 -------------------- 2 files changed, 5 insertions(+), 36 deletions(-) diff --git a/starlette/graphql.py b/starlette/graphql.py index 49adc2f8e..7abd124b7 100644 --- a/starlette/graphql.py +++ b/starlette/graphql.py @@ -31,29 +31,18 @@ class GraphQLApp: def __init__( self, schema: "graphene.Schema", - executor: typing.Any = None, executor_class: type = None, graphiql: bool = True, ) -> None: self.schema = schema self.graphiql = graphiql - if executor is None: - # New style in 0.10.0. Use 'executor_class'. - # See issue https://github.com/encode/starlette/issues/242 - self.executor = executor - self.executor_class = executor_class - self.is_async = executor_class is not None and issubclass( - executor_class, AsyncioExecutor - ) - else: - # Old style. Use 'executor'. - # We should remove this in the next median/major version bump. - self.executor = executor - self.executor_class = None - self.is_async = isinstance(executor, AsyncioExecutor) + self.executor_class = executor_class + self.is_async = executor_class is not None and issubclass( + executor_class, AsyncioExecutor + ) async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: - if self.executor is None and self.executor_class is not None: + if self.executor_class is not None: self.executor = self.executor_class() request = Request(scope, receive=receive) diff --git a/tests/test_graphql.py b/tests/test_graphql.py index bc6378250..7c9b7cb39 100644 --- a/tests/test_graphql.py +++ b/tests/test_graphql.py @@ -147,23 +147,3 @@ def test_graphql_async(): response = client.get("/?query={ hello }") assert response.status_code == 200 assert response.json() == {"data": {"hello": "Hello stranger"}} - - -async_schema = graphene.Schema(query=ASyncQuery) - - -@pytest.fixture -def old_style_async_app(event_loop) -> GraphQLApp: - old_style_async_app = GraphQLApp( - schema=async_schema, executor=AsyncioExecutor(loop=event_loop) - ) - return old_style_async_app - - -@pytest.mark.skip("XXX") -def test_graphql_async_old_style_executor(old_style_async_app: GraphQLApp): - # See https://github.com/encode/starlette/issues/242 - client = TestClient(old_style_async_app) - response = client.get("/?query={ hello }") - assert response.status_code == 200 - assert response.json() == {"data": {"hello": "Hello stranger"}} From 57b2f79754527d4f1ef5c8e1d67df2f827484903 Mon Sep 17 00:00:00 2001 From: Jordan Speicher Date: Sat, 27 Mar 2021 18:41:19 -0500 Subject: [PATCH 16/59] Fix extra import --- tests/test_graphql.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_graphql.py b/tests/test_graphql.py index 7c9b7cb39..1a51a5b4a 100644 --- a/tests/test_graphql.py +++ b/tests/test_graphql.py @@ -1,5 +1,4 @@ import graphene -import pytest from graphql.execution.executors.asyncio import AsyncioExecutor from starlette.applications import Starlette From 444a3acdc2ad6eca39afc27f76d7ba4fc10d6722 Mon Sep 17 00:00:00 2001 From: Jordan Speicher Date: Sat, 27 Mar 2021 18:41:42 -0500 Subject: [PATCH 17/59] Don't cancel task scope early --- starlette/middleware/wsgi.py | 21 +++++++++------------ 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/starlette/middleware/wsgi.py b/starlette/middleware/wsgi.py index a459f8aff..44c4d276b 100644 --- a/starlette/middleware/wsgi.py +++ b/starlette/middleware/wsgi.py @@ -86,18 +86,15 @@ async def __call__(self, receive: Receive, send: Send) -> None: environ = build_environ(self.scope, body) async with anyio.create_task_group() as task_group: - try: - task_group.spawn(self.sender, send) - async with self.stream_send: - await anyio.run_sync_in_worker_thread( - self.wsgi, environ, self.start_response - ) - if self.exc_info is not None: - raise self.exc_info[0].with_traceback( - self.exc_info[1], self.exc_info[2] - ) - finally: - task_group.cancel_scope.cancel() + task_group.spawn(self.sender, send) + async with self.stream_send: + await anyio.run_sync_in_worker_thread( + self.wsgi, environ, self.start_response + ) + if self.exc_info is not None: + raise self.exc_info[0].with_traceback( + self.exc_info[1], self.exc_info[2] + ) async def sender(self, send: Send) -> None: async with self.stream_receive: From 4d31a60c6b571267104a361d1ee02056aca592fd Mon Sep 17 00:00:00 2001 From: Jordan Speicher Date: Sat, 27 Mar 2021 18:45:05 -0500 Subject: [PATCH 18/59] Wait for wsgi sender to finish before exiting --- starlette/middleware/wsgi.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/starlette/middleware/wsgi.py b/starlette/middleware/wsgi.py index 44c4d276b..083380d01 100644 --- a/starlette/middleware/wsgi.py +++ b/starlette/middleware/wsgi.py @@ -85,21 +85,27 @@ async def __call__(self, receive: Receive, send: Send) -> None: more_body = message.get("more_body", False) environ = build_environ(self.scope, body) + self.send_done = anyio.Event() + async with anyio.create_task_group() as task_group: task_group.spawn(self.sender, send) async with self.stream_send: await anyio.run_sync_in_worker_thread( self.wsgi, environ, self.start_response ) + await self.send_done.wait() if self.exc_info is not None: raise self.exc_info[0].with_traceback( self.exc_info[1], self.exc_info[2] ) async def sender(self, send: Send) -> None: - async with self.stream_receive: - async for message in self.stream_receive: - await send(message) + try: + async with self.stream_receive: + async for message in self.stream_receive: + await send(message) + finally: + self.send_done.set() def start_response( self, From 681c3488f2ab8e62698f7476b1322f0c0b956271 Mon Sep 17 00:00:00 2001 From: Jordan Speicher Date: Sat, 27 Mar 2021 18:49:24 -0500 Subject: [PATCH 19/59] Use memory object streams in websocket tests --- tests/test_websockets.py | 25 +++++++++++++------------ 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/tests/test_websockets.py b/tests/test_websockets.py index ffb1a44a8..6c5d0978c 100644 --- a/tests/test_websockets.py +++ b/tests/test_websockets.py @@ -1,5 +1,4 @@ -import asyncio - +import anyio import pytest from starlette import status @@ -208,22 +207,24 @@ async def asgi(receive, send): def test_websocket_concurrency_pattern(): def app(scope): - async def reader(websocket, queue): - async for data in websocket.iter_json(): - await queue.put(data) + stream_send, stream_receive = anyio.create_memory_object_stream() + + async def reader(websocket): + async with stream_send: + async for data in websocket.iter_json(): + await stream_send.send(data) - async def writer(websocket, queue): - while True: - message = await queue.get() - await websocket.send_json(message) + async def writer(websocket): + async with stream_receive: + async for message in stream_receive: + await websocket.send_json(message) async def asgi(receive, send): websocket = WebSocket(scope, receive=receive, send=send) - queue = asyncio.Queue() await websocket.accept() await run_until_first_complete( - (reader, {"websocket": websocket, "queue": queue}), - (writer, {"websocket": websocket, "queue": queue}), + (reader, {"websocket": websocket}), + (writer, {"websocket": websocket}), ) await websocket.close() From 01dd81370a7bcd76ebd5a8146ca27988c5643aca Mon Sep 17 00:00:00 2001 From: Jordan Speicher Date: Sat, 27 Mar 2021 18:53:09 -0500 Subject: [PATCH 20/59] Test on asyncio, asyncio+uvloop, and trio --- requirements.txt | 2 +- tests/conftest.py | 9 +++++++++ tests/test_datastructures.py | 2 +- 3 files changed, 11 insertions(+), 2 deletions(-) create mode 100644 tests/conftest.py diff --git a/requirements.txt b/requirements.txt index 55fb0768b..f6ac2e623 100644 --- a/requirements.txt +++ b/requirements.txt @@ -16,7 +16,7 @@ isort==5.* mypy pytest pytest-cov -pytest-asyncio +trio # Documentation mkdocs diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 000000000..bd8a33e47 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,9 @@ +import pytest + +@pytest.fixture(params=[ + pytest.param(('asyncio', {'use_uvloop': True}), id='asyncio+uvloop'), + pytest.param(('asyncio', {'use_uvloop': False}), id='asyncio'), + pytest.param(('trio', {'restrict_keyboard_interrupt_to_checkpoints': True}), id='trio') +], autouse=True) +def anyio_backend(request): + return request.param diff --git a/tests/test_datastructures.py b/tests/test_datastructures.py index b0e6baf98..bb71ba870 100644 --- a/tests/test_datastructures.py +++ b/tests/test_datastructures.py @@ -217,7 +217,7 @@ class BigUploadFile(UploadFile): spool_max_size = 1024 -@pytest.mark.asyncio +@pytest.mark.anyio async def test_upload_file(): big_file = BigUploadFile("big-file") await big_file.write(b"big-data" * 512) From 9f76d4293df380e842090941557999fb344dbbec Mon Sep 17 00:00:00 2001 From: Jordan Speicher Date: Sat, 27 Mar 2021 18:54:59 -0500 Subject: [PATCH 21/59] Formatting fixes --- tests/conftest.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index bd8a33e47..b28f48e3e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,9 +1,15 @@ import pytest -@pytest.fixture(params=[ - pytest.param(('asyncio', {'use_uvloop': True}), id='asyncio+uvloop'), - pytest.param(('asyncio', {'use_uvloop': False}), id='asyncio'), - pytest.param(('trio', {'restrict_keyboard_interrupt_to_checkpoints': True}), id='trio') -], autouse=True) + +@pytest.fixture( + params=[ + pytest.param(("asyncio", {"use_uvloop": True}), id="asyncio+uvloop"), + pytest.param(("asyncio", {"use_uvloop": False}), id="asyncio"), + pytest.param( + ("trio", {"restrict_keyboard_interrupt_to_checkpoints": True}), id="trio" + ), + ], + autouse=True, +) def anyio_backend(request): return request.param From 5c8818de1aa934508a6e2125e6ce6c55b4f3b63c Mon Sep 17 00:00:00 2001 From: Jordan Speicher Date: Sat, 27 Mar 2021 21:33:07 -0500 Subject: [PATCH 22/59] run_until_first_complete doesn't need a return --- starlette/concurrency.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/starlette/concurrency.py b/starlette/concurrency.py index 645b96254..fd5162eb7 100644 --- a/starlette/concurrency.py +++ b/starlette/concurrency.py @@ -14,19 +14,15 @@ async def run_until_first_complete(*args: typing.Tuple[typing.Callable, dict]) -> None: - result: Any = None async with anyio.create_task_group() as task_group: async def task(_handler: typing.Callable, _kwargs: dict) -> Any: - nonlocal result - result = await _handler(**_kwargs) + await _handler(**_kwargs) task_group.cancel_scope.cancel() for handler, kwargs in args: task_group.spawn(task, handler, kwargs) - return result - async def run_in_threadpool( func: typing.Callable[..., T], *args: typing.Any, **kwargs: typing.Any From f0e4cd8818c96bc4519395d275f517ce30171e26 Mon Sep 17 00:00:00 2001 From: Jordan Speicher Date: Sat, 27 Mar 2021 22:47:41 -0500 Subject: [PATCH 23/59] Fix middleware app call --- starlette/middleware/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/starlette/middleware/base.py b/starlette/middleware/base.py index 9d56c37ec..48551ae1b 100644 --- a/starlette/middleware/base.py +++ b/starlette/middleware/base.py @@ -36,7 +36,7 @@ async def call_next(self, request: Request) -> Response: async def coro() -> None: async with send_stream: try: - await self.app(scope, recv_stream.receive, send_stream.send) + await self.app(scope, request.receive, send_stream.send) except BaseException as exc: nonlocal coro_exc coro_exc = exc From 376f9db84aad06faab091b5cf69f4effd858d36c Mon Sep 17 00:00:00 2001 From: Jordan Speicher Date: Sat, 27 Mar 2021 23:03:01 -0500 Subject: [PATCH 24/59] Simplify middleware exceptions --- starlette/middleware/base.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/starlette/middleware/base.py b/starlette/middleware/base.py index 48551ae1b..427e3ecac 100644 --- a/starlette/middleware/base.py +++ b/starlette/middleware/base.py @@ -31,15 +31,9 @@ async def call_next(self, request: Request) -> Response: scope = request.scope task_group = scope["task_group"] - coro_exc = None - async def coro() -> None: async with send_stream: - try: - await self.app(scope, request.receive, send_stream.send) - except BaseException as exc: - nonlocal coro_exc - coro_exc = exc + await self.app(scope, request.receive, send_stream.send) task_group.spawn(coro) @@ -55,8 +49,6 @@ async def body_stream() -> typing.AsyncGenerator[bytes, None]: async for message in recv_stream: assert message["type"] == "http.response.body" yield message.get("body", b"") - if coro_exc is not None: - raise coro_exc response = StreamingResponse( status_code=message["status"], content=body_stream() From 34da2b427b08ceee585014f842ab08ad6bd68a53 Mon Sep 17 00:00:00 2001 From: Jordan Speicher Date: Sun, 28 Mar 2021 11:02:34 -0500 Subject: [PATCH 25/59] Use anyio for websocket test --- tests/test_testclient.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/tests/test_testclient.py b/tests/test_testclient.py index 00f4e0125..a534c8458 100644 --- a/tests/test_testclient.py +++ b/tests/test_testclient.py @@ -1,4 +1,4 @@ -import asyncio +import anyio import pytest @@ -118,13 +118,14 @@ async def respond(websocket): async def asgi(receive, send): websocket = WebSocket(scope, receive=receive, send=send) await websocket.accept() - asyncio.ensure_future(respond(websocket)) - try: - # this will block as the client does not send us data - # it should not prevent `respond` from executing though - await websocket.receive_json() - except WebSocketDisconnect: - pass + async with anyio.create_task_group() as task_group: + task_group.spawn(respond, websocket) + try: + # this will block as the client does not send us data + # it should not prevent `respond` from executing though + await websocket.receive_json() + except WebSocketDisconnect: + pass return asgi From 31cc220062315c4f16784fd152a9eaa022b41fed Mon Sep 17 00:00:00 2001 From: Jordan Speicher Date: Sun, 28 Mar 2021 11:03:18 -0500 Subject: [PATCH 26/59] Set STARLETTE_TESTCLIENT_ASYNC_BACKEND in tests --- tests/conftest.py | 11 +++++++++++ tests/test_applications.py | 2 +- tests/test_database.py | 3 +++ tests/test_graphql.py | 2 +- tests/test_responses.py | 8 ++++---- tests/test_staticfiles.py | 3 +++ 6 files changed, 23 insertions(+), 6 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index b28f48e3e..af9e14abe 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,5 +1,9 @@ +import os + import pytest +from starlette import config + @pytest.fixture( params=[ @@ -12,4 +16,11 @@ autouse=True, ) def anyio_backend(request): + os.environ["STARLETTE_TESTCLIENT_ASYNC_BACKEND"] = request.param[0] return request.param + + +@pytest.fixture +def no_trio_support(request): + if request.keywords.get("trio"): + pytest.skip("Trio not supported (yet!)") diff --git a/tests/test_applications.py b/tests/test_applications.py index ad8504cbd..62430959c 100644 --- a/tests/test_applications.py +++ b/tests/test_applications.py @@ -194,7 +194,7 @@ def test_routes(): ] -def test_app_mount(tmpdir): +def test_app_mount(tmpdir, no_trio_support): path = os.path.join(tmpdir, "example.txt") with open(path, "w") as file: file.write("") diff --git a/tests/test_database.py b/tests/test_database.py index 258a71ec5..f7280c2c7 100644 --- a/tests/test_database.py +++ b/tests/test_database.py @@ -19,6 +19,9 @@ ) +pytestmark = pytest.mark.usefixtures("no_trio_support") + + @pytest.fixture(autouse=True, scope="module") def create_test_database(): engine = sqlalchemy.create_engine(DATABASE_URL) diff --git a/tests/test_graphql.py b/tests/test_graphql.py index 1a51a5b4a..b945a5cfe 100644 --- a/tests/test_graphql.py +++ b/tests/test_graphql.py @@ -141,7 +141,7 @@ async def resolve_hello(self, info, name): async_app = GraphQLApp(schema=async_schema, executor_class=AsyncioExecutor) -def test_graphql_async(): +def test_graphql_async(no_trio_support): client = TestClient(async_app) response = client.get("/?query={ hello }") assert response.status_code == 200 diff --git a/tests/test_responses.py b/tests/test_responses.py index 5610ac9d6..46068bfca 100644 --- a/tests/test_responses.py +++ b/tests/test_responses.py @@ -170,7 +170,7 @@ def test_response_phrase(): assert response.reason == "" -def test_file_response(tmpdir): +def test_file_response(tmpdir, no_trio_support): path = os.path.join(tmpdir, "xyz") content = b"" * 1000 with open(path, "wb") as file: @@ -212,7 +212,7 @@ async def app(scope, receive, send): assert filled_by_bg_task == "6, 7, 8, 9" -def test_file_response_with_directory_raises_error(tmpdir): +def test_file_response_with_directory_raises_error(tmpdir, no_trio_support): app = FileResponse(path=tmpdir, filename="example.png") client = TestClient(app) with pytest.raises(RuntimeError) as exc_info: @@ -220,7 +220,7 @@ def test_file_response_with_directory_raises_error(tmpdir): assert "is not a file" in str(exc_info.value) -def test_file_response_with_missing_file_raises_error(tmpdir): +def test_file_response_with_missing_file_raises_error(tmpdir, no_trio_support): path = os.path.join(tmpdir, "404.txt") app = FileResponse(path=path, filename="404.txt") client = TestClient(app) @@ -229,7 +229,7 @@ def test_file_response_with_missing_file_raises_error(tmpdir): assert "does not exist" in str(exc_info.value) -def test_file_response_with_chinese_filename(tmpdir): +def test_file_response_with_chinese_filename(tmpdir, no_trio_support): content = b"file content" filename = "你好.txt" # probably "Hello.txt" in Chinese path = os.path.join(tmpdir, filename) diff --git a/tests/test_staticfiles.py b/tests/test_staticfiles.py index 3c8ff240e..281b94013 100644 --- a/tests/test_staticfiles.py +++ b/tests/test_staticfiles.py @@ -12,6 +12,9 @@ from starlette.testclient import TestClient +pytestmark = pytest.mark.usefixtures("no_trio_support") + + def test_staticfiles(tmpdir): path = os.path.join(tmpdir, "example.txt") with open(path, "w") as file: From 73590aa84e212fe728693cca5cadec5988cc8c8c Mon Sep 17 00:00:00 2001 From: Jordan Speicher Date: Sun, 28 Mar 2021 11:13:51 -0500 Subject: [PATCH 27/59] Pass async backend to portal --- starlette/testclient.py | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/starlette/testclient.py b/starlette/testclient.py index 4c162dce4..506f96b63 100644 --- a/starlette/testclient.py +++ b/starlette/testclient.py @@ -5,6 +5,7 @@ import io import json import math +import os import queue import types import typing @@ -91,11 +92,12 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: class _ASGIAdapter(requests.adapters.HTTPAdapter): def __init__( - self, app: ASGI3App, raise_server_exceptions: bool = True, root_path: str = "" + self, app: ASGI3App, raise_server_exceptions: bool = True, root_path: str = "", async_backend: str = "" ) -> None: self.app = app self.raise_server_exceptions = raise_server_exceptions self.root_path = root_path + self.async_backend = async_backend def send( self, request: requests.PreparedRequest, *args: typing.Any, **kwargs: typing.Any @@ -144,7 +146,7 @@ def send( "server": [host, port], "subprotocols": subprotocols, } - session = WebSocketTestSession(self.app, scope) + session = WebSocketTestSession(self.app, scope, self.async_backend) raise _Upgrade(session) scope = { @@ -233,7 +235,7 @@ async def send(message: Message) -> None: context = message["context"] try: - with anyio.start_blocking_portal() as portal: + with anyio.start_blocking_portal(self.async_backend) as portal: portal.call(self.app, scope, receive, send) except BaseException as exc: if self.raise_server_exceptions: @@ -261,12 +263,12 @@ async def send(message: Message) -> None: class WebSocketTestSession: - def __init__(self, app: ASGI3App, scope: Scope) -> None: + def __init__(self, app: ASGI3App, scope: Scope, async_backend: str) -> None: self.app = app self.scope = scope self.accepted_subprotocol = None self.exit_stack = contextlib.ExitStack() - self.portal = self.exit_stack.enter_context(anyio.start_blocking_portal()) + self.portal = self.exit_stack.enter_context(anyio.start_blocking_portal(async_backend)) self._receive_queue = queue.Queue() # type: queue.Queue self._send_queue = queue.Queue() # type: queue.Queue self.portal.spawn_task(self._run) @@ -368,8 +370,14 @@ def __init__( base_url: str = "http://testserver", raise_server_exceptions: bool = True, root_path: str = "", + async_backend: str = None, ) -> None: super(TestClient, self).__init__() + if async_backend is None: + self.async_backend = os.environ.get("STARLETTE_TESTCLIENT_ASYNC_BACKEND", "asyncio") + else: + self.async_backend = async_backend + if _is_asgi3(app): app = typing.cast(ASGI3App, app) asgi_app = app @@ -380,6 +388,7 @@ def __init__( asgi_app, raise_server_exceptions=raise_server_exceptions, root_path=root_path, + async_backend=self.async_backend, ) self.mount("http://", adapter) self.mount("https://", adapter) @@ -451,7 +460,7 @@ def websocket_connect( def __enter__(self) -> "TestClient": self.exit_stack = contextlib.ExitStack() self.portal = self.exit_stack.enter_context( - anyio.start_blocking_portal() + anyio.start_blocking_portal(self.async_backend) ) # XXX backend self.stream_send, self.stream_receive = anyio.create_memory_object_stream( math.inf From 4192bf7bc9de69307690f42b82d01a3c52c381be Mon Sep 17 00:00:00 2001 From: Jordan Speicher Date: Sun, 28 Mar 2021 11:14:53 -0500 Subject: [PATCH 28/59] Formatting fixes --- starlette/testclient.py | 14 +++++++++++--- tests/conftest.py | 2 -- tests/test_staticfiles.py | 1 - tests/test_testclient.py | 1 - 4 files changed, 11 insertions(+), 7 deletions(-) diff --git a/starlette/testclient.py b/starlette/testclient.py index 506f96b63..c0dcef9e4 100644 --- a/starlette/testclient.py +++ b/starlette/testclient.py @@ -92,7 +92,11 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: class _ASGIAdapter(requests.adapters.HTTPAdapter): def __init__( - self, app: ASGI3App, raise_server_exceptions: bool = True, root_path: str = "", async_backend: str = "" + self, + app: ASGI3App, + raise_server_exceptions: bool = True, + root_path: str = "", + async_backend: str = "", ) -> None: self.app = app self.raise_server_exceptions = raise_server_exceptions @@ -268,7 +272,9 @@ def __init__(self, app: ASGI3App, scope: Scope, async_backend: str) -> None: self.scope = scope self.accepted_subprotocol = None self.exit_stack = contextlib.ExitStack() - self.portal = self.exit_stack.enter_context(anyio.start_blocking_portal(async_backend)) + self.portal = self.exit_stack.enter_context( + anyio.start_blocking_portal(async_backend) + ) self._receive_queue = queue.Queue() # type: queue.Queue self._send_queue = queue.Queue() # type: queue.Queue self.portal.spawn_task(self._run) @@ -374,7 +380,9 @@ def __init__( ) -> None: super(TestClient, self).__init__() if async_backend is None: - self.async_backend = os.environ.get("STARLETTE_TESTCLIENT_ASYNC_BACKEND", "asyncio") + self.async_backend = os.environ.get( + "STARLETTE_TESTCLIENT_ASYNC_BACKEND", "asyncio" + ) else: self.async_backend = async_backend diff --git a/tests/conftest.py b/tests/conftest.py index af9e14abe..5bc96407e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,8 +2,6 @@ import pytest -from starlette import config - @pytest.fixture( params=[ diff --git a/tests/test_staticfiles.py b/tests/test_staticfiles.py index 281b94013..180b3aba6 100644 --- a/tests/test_staticfiles.py +++ b/tests/test_staticfiles.py @@ -11,7 +11,6 @@ from starlette.staticfiles import StaticFiles from starlette.testclient import TestClient - pytestmark = pytest.mark.usefixtures("no_trio_support") diff --git a/tests/test_testclient.py b/tests/test_testclient.py index a534c8458..2c3b9422f 100644 --- a/tests/test_testclient.py +++ b/tests/test_testclient.py @@ -1,5 +1,4 @@ import anyio - import pytest from starlette.applications import Starlette From 3a4b4722eced045f1bd9fa58092ed2403e2d55ef Mon Sep 17 00:00:00 2001 From: Jordan Speicher Date: Sun, 28 Mar 2021 19:36:24 -0500 Subject: [PATCH 29/59] Bump anyio --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 759073721..8f407ee99 100644 --- a/setup.py +++ b/setup.py @@ -48,7 +48,7 @@ def get_packages(package): packages=get_packages("starlette"), package_data={"starlette": ["py.typed"]}, include_package_data=True, - install_requires=["anyio>=3.0.0rc1"], + install_requires=["anyio>=3.0.0rc2"], extras_require={ "full": [ "aiofiles", From cc3be48017740993c7dc9cf959e47d572f6af6eb Mon Sep 17 00:00:00 2001 From: Jordan Speicher Date: Mon, 29 Mar 2021 13:16:52 -0500 Subject: [PATCH 30/59] Cleanup portals and add TestClient.async_backend --- starlette/testclient.py | 71 ++++++++++++++++++++++------------------- tests/conftest.py | 16 +++++----- 2 files changed, 47 insertions(+), 40 deletions(-) diff --git a/starlette/testclient.py b/starlette/testclient.py index c0dcef9e4..71af97e13 100644 --- a/starlette/testclient.py +++ b/starlette/testclient.py @@ -5,7 +5,6 @@ import io import json import math -import os import queue import types import typing @@ -13,6 +12,7 @@ import anyio import requests +from anyio.streams.stapled import StapledObjectStream from starlette.types import Message, Receive, Scope, Send from starlette.websockets import WebSocketDisconnect @@ -94,9 +94,9 @@ class _ASGIAdapter(requests.adapters.HTTPAdapter): def __init__( self, app: ASGI3App, + async_backend: typing.Dict[str, typing.Any], raise_server_exceptions: bool = True, root_path: str = "", - async_backend: str = "", ) -> None: self.app = app self.raise_server_exceptions = raise_server_exceptions @@ -239,7 +239,7 @@ async def send(message: Message) -> None: context = message["context"] try: - with anyio.start_blocking_portal(self.async_backend) as portal: + with anyio.start_blocking_portal(**self.async_backend) as portal: portal.call(self.app, scope, receive, send) except BaseException as exc: if self.raise_server_exceptions: @@ -267,23 +267,31 @@ async def send(message: Message) -> None: class WebSocketTestSession: - def __init__(self, app: ASGI3App, scope: Scope, async_backend: str) -> None: + def __init__( + self, app: ASGI3App, scope: Scope, async_backend: typing.Dict[str, typing.Any] + ) -> None: self.app = app self.scope = scope self.accepted_subprotocol = None - self.exit_stack = contextlib.ExitStack() - self.portal = self.exit_stack.enter_context( - anyio.start_blocking_portal(async_backend) - ) + self.async_backend = async_backend self._receive_queue = queue.Queue() # type: queue.Queue self._send_queue = queue.Queue() # type: queue.Queue - self.portal.spawn_task(self._run) - self.send({"type": "websocket.connect"}) - message = self.receive() - self._raise_on_close(message) - self.accepted_subprotocol = message.get("subprotocol", None) def __enter__(self) -> "WebSocketTestSession": + self.exit_stack = contextlib.ExitStack() + self.portal = self.exit_stack.enter_context( + anyio.start_blocking_portal(**self.async_backend) + ) + + try: + self.portal.spawn_task(self._run) + self.send({"type": "websocket.connect"}) + message = self.receive() + self._raise_on_close(message) + except Exception: + self.exit_stack.close() + raise + self.accepted_subprotocol = message.get("subprotocol", None) return self def __exit__(self, *args: typing.Any) -> None: @@ -307,6 +315,7 @@ async def _run(self) -> None: await self.app(scope, receive, send) except BaseException as exc: self._send_queue.put(exc) + raise async def _asgi_receive(self) -> Message: while self._receive_queue.empty(): @@ -370,22 +379,17 @@ def receive_json(self, mode: str = "text") -> typing.Any: class TestClient(requests.Session): __test__ = False # For pytest to not discover this up. + #: These options are passed to `anyio.start_blocking_portal()` + async_backend = {"backend": "asyncio", "backend_options": {}} # type: typing.Dict[str, typing.Any] + def __init__( self, app: typing.Union[ASGI2App, ASGI3App], base_url: str = "http://testserver", raise_server_exceptions: bool = True, root_path: str = "", - async_backend: str = None, ) -> None: super(TestClient, self).__init__() - if async_backend is None: - self.async_backend = os.environ.get( - "STARLETTE_TESTCLIENT_ASYNC_BACKEND", "asyncio" - ) - else: - self.async_backend = async_backend - if _is_asgi3(app): app = typing.cast(ASGI3App, app) asgi_app = app @@ -394,9 +398,9 @@ def __init__( asgi_app = _WrapASGI2(app) #  type: ignore adapter = _ASGIAdapter( asgi_app, + self.async_backend, raise_server_exceptions=raise_server_exceptions, root_path=root_path, - async_backend=self.async_backend, ) self.mount("http://", adapter) self.mount("https://", adapter) @@ -468,13 +472,16 @@ def websocket_connect( def __enter__(self) -> "TestClient": self.exit_stack = contextlib.ExitStack() self.portal = self.exit_stack.enter_context( - anyio.start_blocking_portal(self.async_backend) - ) # XXX backend - self.stream_send, self.stream_receive = anyio.create_memory_object_stream( - math.inf + anyio.start_blocking_portal(**self.async_backend) + ) + self.stream_send = StapledObjectStream( + *anyio.create_memory_object_stream(math.inf) + ) + self.stream_receive = StapledObjectStream( + *anyio.create_memory_object_stream(math.inf) ) - self.task = self.portal.spawn_task(self.lifespan) try: + self.task = self.portal.spawn_task(self.lifespan) self.portal.call(self.wait_startup) except Exception: self.exit_stack.close() @@ -495,8 +502,8 @@ async def lifespan(self) -> None: await self.stream_send.send(None) async def wait_startup(self) -> None: - await self.stream_send.send({"type": "lifespan.startup"}) - message = await self.stream_receive.receive() + await self.stream_receive.send({"type": "lifespan.startup"}) + message = await self.stream_send.receive() if message is None: self.task.result() assert message["type"] in ( @@ -504,14 +511,14 @@ async def wait_startup(self) -> None: "lifespan.startup.failed", ) if message["type"] == "lifespan.startup.failed": - message = await self.stream_receive.receive() + message = await self.stream_send.receive() if message is None: self.task.result() async def wait_shutdown(self) -> None: async with self.stream_send: - await self.stream_send.send({"type": "lifespan.shutdown"}) - message = await self.stream_receive.receive() + await self.stream_receive.send({"type": "lifespan.shutdown"}) + message = await self.stream_send.receive() if message is None: self.task.result() assert message["type"] == "lifespan.shutdown.complete" diff --git a/tests/conftest.py b/tests/conftest.py index 5bc96407e..d1f3ba8e4 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,21 +1,21 @@ -import os - import pytest +from starlette.testclient import TestClient + @pytest.fixture( params=[ - pytest.param(("asyncio", {"use_uvloop": True}), id="asyncio+uvloop"), - pytest.param(("asyncio", {"use_uvloop": False}), id="asyncio"), pytest.param( - ("trio", {"restrict_keyboard_interrupt_to_checkpoints": True}), id="trio" + {"backend": "asyncio", "backend_options": {"use_uvloop": False}}, + id="asyncio", ), + pytest.param({"backend": "trio", "backend_options": {}}, id="trio"), ], autouse=True, ) -def anyio_backend(request): - os.environ["STARLETTE_TESTCLIENT_ASYNC_BACKEND"] = request.param[0] - return request.param +def anyio_backend(request, monkeypatch): + monkeypatch.setattr(TestClient, "async_backend", request.param) + return request.param["backend"] @pytest.fixture From 9b6e722c5ac12fa05efda705321b55d9204447ed Mon Sep 17 00:00:00 2001 From: Jordan Speicher Date: Mon, 29 Mar 2021 13:17:14 -0500 Subject: [PATCH 31/59] Use anyio.run_async_from_thread to send from worker thread --- starlette/middleware/wsgi.py | 28 ++++++++++++---------------- 1 file changed, 12 insertions(+), 16 deletions(-) diff --git a/starlette/middleware/wsgi.py b/starlette/middleware/wsgi.py index 083380d01..20aeee1f4 100644 --- a/starlette/middleware/wsgi.py +++ b/starlette/middleware/wsgi.py @@ -85,27 +85,19 @@ async def __call__(self, receive: Receive, send: Send) -> None: more_body = message.get("more_body", False) environ = build_environ(self.scope, body) - self.send_done = anyio.Event() - async with anyio.create_task_group() as task_group: task_group.spawn(self.sender, send) async with self.stream_send: await anyio.run_sync_in_worker_thread( self.wsgi, environ, self.start_response ) - await self.send_done.wait() - if self.exc_info is not None: - raise self.exc_info[0].with_traceback( - self.exc_info[1], self.exc_info[2] - ) + if self.exc_info is not None: + raise self.exc_info[0].with_traceback(self.exc_info[1], self.exc_info[2]) async def sender(self, send: Send) -> None: - try: - async with self.stream_receive: - async for message in self.stream_receive: - await send(message) - finally: - self.send_done.set() + async with self.stream_receive: + async for message in self.stream_receive: + await send(message) def start_response( self, @@ -122,7 +114,8 @@ def start_response( (name.strip().encode("ascii").lower(), value.strip().encode("ascii")) for name, value in response_headers ] - self.stream_send.send_nowait( + anyio.run_async_from_thread( + self.stream_send.send, { "type": "http.response.start", "status": status_code, @@ -132,8 +125,11 @@ def start_response( def wsgi(self, environ: dict, start_response: typing.Callable) -> None: for chunk in self.app(environ, start_response): - self.stream_send.send_nowait( + anyio.run_async_from_thread( + self.stream_send.send, {"type": "http.response.body", "body": chunk, "more_body": True}, ) - self.stream_send.send_nowait({"type": "http.response.body", "body": b""}) + anyio.run_async_from_thread( + self.stream_send.send, {"type": "http.response.body", "body": b""} + ) From b8c43cf0b9413eb65bc7d0cedd3ce2eaa29b9328 Mon Sep 17 00:00:00 2001 From: Jordan Speicher Date: Mon, 29 Mar 2021 13:17:40 -0500 Subject: [PATCH 32/59] Use websocket_connect as context manager --- tests/middleware/test_errors.py | 3 ++- tests/test_authentication.py | 16 +++++++++++----- tests/test_exceptions.py | 3 ++- tests/test_routing.py | 6 ++++-- tests/test_websockets.py | 6 ++++-- 5 files changed, 23 insertions(+), 11 deletions(-) diff --git a/tests/middleware/test_errors.py b/tests/middleware/test_errors.py index c178ef9da..768857bc5 100644 --- a/tests/middleware/test_errors.py +++ b/tests/middleware/test_errors.py @@ -67,4 +67,5 @@ async def app(scope, receive, send): with pytest.raises(RuntimeError): client = TestClient(app) - client.websocket_connect("/") + with client.websocket_connect("/"): + pass diff --git a/tests/test_authentication.py b/tests/test_authentication.py index 3373f67c5..981aa65ab 100644 --- a/tests/test_authentication.py +++ b/tests/test_authentication.py @@ -261,10 +261,14 @@ def test_authentication_required(): def test_websocket_authentication_required(): with TestClient(app) as client: with pytest.raises(WebSocketDisconnect): - client.websocket_connect("/ws") + with client.websocket_connect("/ws"): + pass with pytest.raises(WebSocketDisconnect): - client.websocket_connect("/ws", headers={"Authorization": "basic foobar"}) + with client.websocket_connect( + "/ws", headers={"Authorization": "basic foobar"} + ): + pass with client.websocket_connect( "/ws", auth=("tomchristie", "example") @@ -273,12 +277,14 @@ def test_websocket_authentication_required(): assert data == {"authenticated": True, "user": "tomchristie"} with pytest.raises(WebSocketDisconnect): - client.websocket_connect("/ws/decorated") + with client.websocket_connect("/ws/decorated"): + pass with pytest.raises(WebSocketDisconnect): - client.websocket_connect( + with client.websocket_connect( "/ws/decorated", headers={"Authorization": "basic foobar"} - ) + ): + pass with client.websocket_connect( "/ws/decorated", auth=("tomchristie", "example") diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index 841c9a5cf..1dfd437f3 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -54,7 +54,8 @@ def test_not_modified(): def test_websockets_should_raise(): with pytest.raises(RuntimeError): - client.websocket_connect("/runtime_error") + with client.websocket_connect("/runtime_error"): + pass def test_handled_exc_after_response(): diff --git a/tests/test_routing.py b/tests/test_routing.py index 8927c60cd..e33ce83fc 100644 --- a/tests/test_routing.py +++ b/tests/test_routing.py @@ -286,7 +286,8 @@ def test_protocol_switch(): assert session.receive_json() == {"URL": "ws://testserver/"} with pytest.raises(WebSocketDisconnect): - client.websocket_connect("/404") + with client.websocket_connect("/404"): + pass ok = PlainTextResponse("OK") @@ -492,7 +493,8 @@ def test_standalone_ws_route_does_not_match(): app = WebSocketRoute("/", ws_helloworld) client = TestClient(app) with pytest.raises(WebSocketDisconnect): - client.websocket_connect("/invalid") + with client.websocket_connect("/invalid"): + pass def test_lifespan_async(): diff --git a/tests/test_websockets.py b/tests/test_websockets.py index 6c5d0978c..df6421b35 100644 --- a/tests/test_websockets.py +++ b/tests/test_websockets.py @@ -284,7 +284,8 @@ async def asgi(receive, send): client = TestClient(app) with pytest.raises(WebSocketDisconnect) as exc: - client.websocket_connect("/") + with client.websocket_connect("/"): + pass assert exc.value.code == status.WS_1001_GOING_AWAY @@ -312,7 +313,8 @@ async def asgi(receive, send): client = TestClient(app) with pytest.raises(AssertionError): - client.websocket_connect("/123?a=abc") + with client.websocket_connect("/123?a=abc"): + pass def test_duplicate_close(): From d51d5ff1a39b8ff0497a11847907b73e0ebb1715 Mon Sep 17 00:00:00 2001 From: Jordan Speicher Date: Mon, 29 Mar 2021 13:17:52 -0500 Subject: [PATCH 33/59] Document changes in TestClient --- docs/testclient.md | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/docs/testclient.md b/docs/testclient.md index 61f7201c6..c21380498 100644 --- a/docs/testclient.md +++ b/docs/testclient.md @@ -31,6 +31,21 @@ application. Occasionally you might want to test the content of 500 error responses, rather than allowing client to raise the server exception. In this case you should use `client = TestClient(app, raise_server_exceptions=False)`. +### Selecting the Async backend + +`TestClient.async_backend` is a dictionary which allows you to set the options +for the backend used to run tests. These options are passed to +`anyio.start_blocking_portal()`. By default, `asyncio` is used. + +To run `Trio`, set `async_backend["backend"] = "trio"`, for example: + +```python +def test_app() + client = TestClient(app) + client.async_backend["backend"] = "trio" + ... +``` + ### Testing WebSocket sessions You can also test websocket sessions with the test client. @@ -72,6 +87,8 @@ always raised by the test client. May raise `starlette.websockets.WebSocketDisconnect` if the application does not accept the websocket connection. +`websocket_connect()` must be used as a context manager (in a `with` block). + #### Sending data * `.send_text(data)` - Send the given text to the application. From 82431f4cb84303328c12ee20b3a0df6b03c8d33f Mon Sep 17 00:00:00 2001 From: Jordan Speicher Date: Mon, 29 Mar 2021 13:18:16 -0500 Subject: [PATCH 34/59] Formatting fix --- starlette/testclient.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/starlette/testclient.py b/starlette/testclient.py index 71af97e13..e01d53b1f 100644 --- a/starlette/testclient.py +++ b/starlette/testclient.py @@ -380,7 +380,10 @@ class TestClient(requests.Session): __test__ = False # For pytest to not discover this up. #: These options are passed to `anyio.start_blocking_portal()` - async_backend = {"backend": "asyncio", "backend_options": {}} # type: typing.Dict[str, typing.Any] + async_backend = { + "backend": "asyncio", + "backend_options": {}, + } # type: typing.Dict[str, typing.Any] def __init__( self, From 250423771d8947f9a5394348bc23629aad511e58 Mon Sep 17 00:00:00 2001 From: Jordan Speicher Date: Mon, 29 Mar 2021 13:24:56 -0500 Subject: [PATCH 35/59] Fix websocket raises coverage --- tests/middleware/test_errors.py | 2 +- tests/test_authentication.py | 8 ++++---- tests/test_exceptions.py | 2 +- tests/test_routing.py | 4 ++-- tests/test_websockets.py | 6 +++--- 5 files changed, 11 insertions(+), 11 deletions(-) diff --git a/tests/middleware/test_errors.py b/tests/middleware/test_errors.py index 768857bc5..28b2a7ba3 100644 --- a/tests/middleware/test_errors.py +++ b/tests/middleware/test_errors.py @@ -68,4 +68,4 @@ async def app(scope, receive, send): with pytest.raises(RuntimeError): client = TestClient(app) with client.websocket_connect("/"): - pass + pass # pragma: nocover diff --git a/tests/test_authentication.py b/tests/test_authentication.py index 981aa65ab..8ee87932a 100644 --- a/tests/test_authentication.py +++ b/tests/test_authentication.py @@ -262,13 +262,13 @@ def test_websocket_authentication_required(): with TestClient(app) as client: with pytest.raises(WebSocketDisconnect): with client.websocket_connect("/ws"): - pass + pass # pragma: nocover with pytest.raises(WebSocketDisconnect): with client.websocket_connect( "/ws", headers={"Authorization": "basic foobar"} ): - pass + pass # pragma: nocover with client.websocket_connect( "/ws", auth=("tomchristie", "example") @@ -278,13 +278,13 @@ def test_websocket_authentication_required(): with pytest.raises(WebSocketDisconnect): with client.websocket_connect("/ws/decorated"): - pass + pass # pragma: nocover with pytest.raises(WebSocketDisconnect): with client.websocket_connect( "/ws/decorated", headers={"Authorization": "basic foobar"} ): - pass + pass # pragma: nocover with client.websocket_connect( "/ws/decorated", auth=("tomchristie", "example") diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index 1dfd437f3..bab6961b5 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -55,7 +55,7 @@ def test_not_modified(): def test_websockets_should_raise(): with pytest.raises(RuntimeError): with client.websocket_connect("/runtime_error"): - pass + pass # pragma: nocover def test_handled_exc_after_response(): diff --git a/tests/test_routing.py b/tests/test_routing.py index e33ce83fc..5ef07d113 100644 --- a/tests/test_routing.py +++ b/tests/test_routing.py @@ -287,7 +287,7 @@ def test_protocol_switch(): with pytest.raises(WebSocketDisconnect): with client.websocket_connect("/404"): - pass + pass # pragma: nocover ok = PlainTextResponse("OK") @@ -494,7 +494,7 @@ def test_standalone_ws_route_does_not_match(): client = TestClient(app) with pytest.raises(WebSocketDisconnect): with client.websocket_connect("/invalid"): - pass + pass # pragma: nocover def test_lifespan_async(): diff --git a/tests/test_websockets.py b/tests/test_websockets.py index df6421b35..305dc64d7 100644 --- a/tests/test_websockets.py +++ b/tests/test_websockets.py @@ -285,7 +285,7 @@ async def asgi(receive, send): client = TestClient(app) with pytest.raises(WebSocketDisconnect) as exc: with client.websocket_connect("/"): - pass + pass # pragma: nocover assert exc.value.code == status.WS_1001_GOING_AWAY @@ -314,7 +314,7 @@ async def asgi(receive, send): client = TestClient(app) with pytest.raises(AssertionError): with client.websocket_connect("/123?a=abc"): - pass + pass # pragma: nocover def test_duplicate_close(): @@ -330,7 +330,7 @@ async def asgi(receive, send): client = TestClient(app) with pytest.raises(RuntimeError): with client.websocket_connect("/"): - pass + pass # pragma: nocover def test_duplicate_disconnect(): From cf915bc58da66ddd2da678908b0fdae605d1577d Mon Sep 17 00:00:00 2001 From: Jordan Speicher Date: Tue, 30 Mar 2021 09:11:56 -0500 Subject: [PATCH 36/59] Update to anyio 3.0.0rc3 and replace aiofiles --- docs/index.md | 2 -- setup.py | 3 +-- starlette/concurrency.py | 4 ++-- starlette/middleware/wsgi.py | 10 ++++------ starlette/responses.py | 17 ++++------------- starlette/staticfiles.py | 6 +++--- tests/test_applications.py | 2 +- tests/test_responses.py | 8 ++++---- tests/test_staticfiles.py | 2 -- 9 files changed, 19 insertions(+), 35 deletions(-) diff --git a/docs/index.md b/docs/index.md index 4ae77f0e6..db82e0d36 100644 --- a/docs/index.md +++ b/docs/index.md @@ -82,7 +82,6 @@ For a more complete example, [see here](https://github.com/encode/starlette-exam Starlette does not have any hard dependencies, but the following are optional: * [`requests`][requests] - Required if you want to use the `TestClient`. -* [`aiofiles`][aiofiles] - Required if you want to use `FileResponse` or `StaticFiles`. * [`jinja2`][jinja2] - Required if you want to use `Jinja2Templates`. * [`python-multipart`][python-multipart] - Required if you want to support form parsing, with `request.form()`. * [`itsdangerous`][itsdangerous] - Required for `SessionMiddleware` support. @@ -161,7 +160,6 @@ gunicorn -k uvicorn.workers.UvicornH11Worker ...

Starlette is BSD licensed code. Designed & built in Brighton, England.

[requests]: http://docs.python-requests.org/en/master/ -[aiofiles]: https://github.com/Tinche/aiofiles [jinja2]: http://jinja.pocoo.org/ [python-multipart]: https://andrew-d.github.io/python-multipart/ [graphene]: https://graphene-python.org/ diff --git a/setup.py b/setup.py index 8f407ee99..b6ec21a1b 100644 --- a/setup.py +++ b/setup.py @@ -48,10 +48,9 @@ def get_packages(package): packages=get_packages("starlette"), package_data={"starlette": ["py.typed"]}, include_package_data=True, - install_requires=["anyio>=3.0.0rc2"], + install_requires=["anyio>=3.0.0rc3"], extras_require={ "full": [ - "aiofiles", "graphene", "itsdangerous", "jinja2", diff --git a/starlette/concurrency.py b/starlette/concurrency.py index fd5162eb7..b06aac86d 100644 --- a/starlette/concurrency.py +++ b/starlette/concurrency.py @@ -36,7 +36,7 @@ async def run_in_threadpool( elif kwargs: # pragma: no cover # run_sync doesn't accept 'kwargs', so bind them in here func = functools.partial(func, **kwargs) - return await anyio.run_sync_in_worker_thread(func, *args) + return await anyio.to_thread.run_sync(func, *args) class _StopIteration(Exception): @@ -56,6 +56,6 @@ def _next(iterator: Iterator) -> Any: async def iterate_in_threadpool(iterator: Iterator) -> AsyncGenerator: while True: try: - yield await anyio.run_sync_in_worker_thread(_next, iterator) + yield await anyio.to_thread.run_sync(_next, iterator) except _StopIteration: break diff --git a/starlette/middleware/wsgi.py b/starlette/middleware/wsgi.py index 20aeee1f4..cec4eb187 100644 --- a/starlette/middleware/wsgi.py +++ b/starlette/middleware/wsgi.py @@ -88,9 +88,7 @@ async def __call__(self, receive: Receive, send: Send) -> None: async with anyio.create_task_group() as task_group: task_group.spawn(self.sender, send) async with self.stream_send: - await anyio.run_sync_in_worker_thread( - self.wsgi, environ, self.start_response - ) + await anyio.to_thread.run_sync(self.wsgi, environ, self.start_response) if self.exc_info is not None: raise self.exc_info[0].with_traceback(self.exc_info[1], self.exc_info[2]) @@ -114,7 +112,7 @@ def start_response( (name.strip().encode("ascii").lower(), value.strip().encode("ascii")) for name, value in response_headers ] - anyio.run_async_from_thread( + anyio.from_thread.run( self.stream_send.send, { "type": "http.response.start", @@ -125,11 +123,11 @@ def start_response( def wsgi(self, environ: dict, start_response: typing.Callable) -> None: for chunk in self.app(environ, start_response): - anyio.run_async_from_thread( + anyio.from_thread.run( self.stream_send.send, {"type": "http.response.body", "body": chunk, "more_body": True}, ) - anyio.run_async_from_thread( + anyio.from_thread.run( self.stream_send.send, {"type": "http.response.body", "body": b""} ) diff --git a/starlette/responses.py b/starlette/responses.py index ff122fba1..a8db7e09b 100644 --- a/starlette/responses.py +++ b/starlette/responses.py @@ -9,6 +9,8 @@ from mimetypes import guess_type as mimetypes_guess_type from urllib.parse import quote, quote_plus +import anyio + from starlette.background import BackgroundTask from starlette.concurrency import iterate_in_threadpool, run_until_first_complete from starlette.datastructures import URL, MutableHeaders @@ -17,13 +19,6 @@ # Workaround for adding samesite support to pre 3.8 python http.cookies.Morsel._reserved["samesite"] = "SameSite" # type: ignore -try: - import aiofiles - from aiofiles.os import stat as aio_stat -except ImportError: # pragma: nocover - aiofiles = None # type: ignore - aio_stat = None # type: ignore - # Compatibility wrapper for `mimetypes.guess_type` to support `os.PathLike` on None: - assert aiofiles is not None, "'aiofiles' must be installed to use FileResponse" self.path = path self.status_code = status_code self.filename = filename @@ -280,7 +274,7 @@ def set_stat_headers(self, stat_result: os.stat_result) -> None: async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: if self.stat_result is None: try: - stat_result = await aio_stat(self.path) + stat_result = await anyio.to_thread.run_sync(os.stat, self.path) self.set_stat_headers(stat_result) except FileNotFoundError: raise RuntimeError(f"File at path {self.path} does not exist.") @@ -298,10 +292,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: if self.send_header_only: await send({"type": "http.response.body", "body": b"", "more_body": False}) else: - # Tentatively ignoring type checking failure to work around the wrong type - # definitions for aiofile that come with typeshed. See - # https://github.com/python/typeshed/pull/4650 - async with aiofiles.open(self.path, mode="rb") as file: # type: ignore + async with await anyio.open_file(self.path, mode="rb") as file: more_body = True while more_body: chunk = await file.read(self.chunk_size) diff --git a/starlette/staticfiles.py b/starlette/staticfiles.py index 225e45745..013009830 100644 --- a/starlette/staticfiles.py +++ b/starlette/staticfiles.py @@ -4,7 +4,7 @@ import typing from email.utils import parsedate -from aiofiles.os import stat as aio_stat +import anyio from starlette.datastructures import URL, Headers from starlette.responses import ( @@ -154,7 +154,7 @@ async def lookup_path( # directory. continue try: - stat_result = await aio_stat(full_path) + stat_result = await anyio.to_thread.run_sync(os.stat, full_path) return (full_path, stat_result) except FileNotFoundError: pass @@ -187,7 +187,7 @@ async def check_config(self) -> None: return try: - stat_result = await aio_stat(self.directory) + stat_result = await anyio.to_thread.run_sync(os.stat, self.directory) except FileNotFoundError: raise RuntimeError( f"StaticFiles directory '{self.directory}' does not exist." diff --git a/tests/test_applications.py b/tests/test_applications.py index 62430959c..ad8504cbd 100644 --- a/tests/test_applications.py +++ b/tests/test_applications.py @@ -194,7 +194,7 @@ def test_routes(): ] -def test_app_mount(tmpdir, no_trio_support): +def test_app_mount(tmpdir): path = os.path.join(tmpdir, "example.txt") with open(path, "w") as file: file.write("") diff --git a/tests/test_responses.py b/tests/test_responses.py index 46068bfca..5610ac9d6 100644 --- a/tests/test_responses.py +++ b/tests/test_responses.py @@ -170,7 +170,7 @@ def test_response_phrase(): assert response.reason == "" -def test_file_response(tmpdir, no_trio_support): +def test_file_response(tmpdir): path = os.path.join(tmpdir, "xyz") content = b"" * 1000 with open(path, "wb") as file: @@ -212,7 +212,7 @@ async def app(scope, receive, send): assert filled_by_bg_task == "6, 7, 8, 9" -def test_file_response_with_directory_raises_error(tmpdir, no_trio_support): +def test_file_response_with_directory_raises_error(tmpdir): app = FileResponse(path=tmpdir, filename="example.png") client = TestClient(app) with pytest.raises(RuntimeError) as exc_info: @@ -220,7 +220,7 @@ def test_file_response_with_directory_raises_error(tmpdir, no_trio_support): assert "is not a file" in str(exc_info.value) -def test_file_response_with_missing_file_raises_error(tmpdir, no_trio_support): +def test_file_response_with_missing_file_raises_error(tmpdir): path = os.path.join(tmpdir, "404.txt") app = FileResponse(path=path, filename="404.txt") client = TestClient(app) @@ -229,7 +229,7 @@ def test_file_response_with_missing_file_raises_error(tmpdir, no_trio_support): assert "does not exist" in str(exc_info.value) -def test_file_response_with_chinese_filename(tmpdir, no_trio_support): +def test_file_response_with_chinese_filename(tmpdir): content = b"file content" filename = "你好.txt" # probably "Hello.txt" in Chinese path = os.path.join(tmpdir, filename) diff --git a/tests/test_staticfiles.py b/tests/test_staticfiles.py index 180b3aba6..3c8ff240e 100644 --- a/tests/test_staticfiles.py +++ b/tests/test_staticfiles.py @@ -11,8 +11,6 @@ from starlette.staticfiles import StaticFiles from starlette.testclient import TestClient -pytestmark = pytest.mark.usefixtures("no_trio_support") - def test_staticfiles(tmpdir): path = os.path.join(tmpdir, "example.txt") From 72586ba2585f88da66d2372a04c47d886b709459 Mon Sep 17 00:00:00 2001 From: Jordan Speicher Date: Fri, 9 Apr 2021 19:10:09 -0500 Subject: [PATCH 37/59] Apply suggestions from code review MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Alex Grönholm --- setup.py | 2 +- starlette/concurrency.py | 2 +- starlette/middleware/base.py | 2 +- starlette/middleware/wsgi.py | 2 +- starlette/testclient.py | 4 ++-- tests/test_testclient.py | 2 +- 6 files changed, 7 insertions(+), 7 deletions(-) diff --git a/setup.py b/setup.py index b6ec21a1b..61331e89c 100644 --- a/setup.py +++ b/setup.py @@ -48,7 +48,7 @@ def get_packages(package): packages=get_packages("starlette"), package_data={"starlette": ["py.typed"]}, include_package_data=True, - install_requires=["anyio>=3.0.0rc3"], + install_requires=["anyio>=3.0.0rc4"], extras_require={ "full": [ "graphene", diff --git a/starlette/concurrency.py b/starlette/concurrency.py index b06aac86d..a32d96f2c 100644 --- a/starlette/concurrency.py +++ b/starlette/concurrency.py @@ -21,7 +21,7 @@ async def task(_handler: typing.Callable, _kwargs: dict) -> Any: task_group.cancel_scope.cancel() for handler, kwargs in args: - task_group.spawn(task, handler, kwargs) + task_group.start_soon(task, handler, kwargs) async def run_in_threadpool( diff --git a/starlette/middleware/base.py b/starlette/middleware/base.py index 427e3ecac..16d50649a 100644 --- a/starlette/middleware/base.py +++ b/starlette/middleware/base.py @@ -35,7 +35,7 @@ async def coro() -> None: async with send_stream: await self.app(scope, request.receive, send_stream.send) - task_group.spawn(coro) + task_group.start_soon(coro) try: message = await recv_stream.receive() diff --git a/starlette/middleware/wsgi.py b/starlette/middleware/wsgi.py index cec4eb187..2c2d3f96b 100644 --- a/starlette/middleware/wsgi.py +++ b/starlette/middleware/wsgi.py @@ -86,7 +86,7 @@ async def __call__(self, receive: Receive, send: Send) -> None: environ = build_environ(self.scope, body) async with anyio.create_task_group() as task_group: - task_group.spawn(self.sender, send) + task_group.start_soon(self.sender, send) async with self.stream_send: await anyio.to_thread.run_sync(self.wsgi, environ, self.start_response) if self.exc_info is not None: diff --git a/starlette/testclient.py b/starlette/testclient.py index 5bfde33de..f4a960a3f 100644 --- a/starlette/testclient.py +++ b/starlette/testclient.py @@ -284,7 +284,7 @@ def __enter__(self) -> "WebSocketTestSession": ) try: - self.portal.spawn_task(self._run) + self.portal.start_task_soon(self._run) self.send({"type": "websocket.connect"}) message = self.receive() self._raise_on_close(message) @@ -484,7 +484,7 @@ def __enter__(self) -> "TestClient": *anyio.create_memory_object_stream(math.inf) ) try: - self.task = self.portal.spawn_task(self.lifespan) + self.task = self.portal.start_task_soon(self.lifespan) self.portal.call(self.wait_startup) except Exception: self.exit_stack.close() diff --git a/tests/test_testclient.py b/tests/test_testclient.py index 2c3b9422f..86f36e172 100644 --- a/tests/test_testclient.py +++ b/tests/test_testclient.py @@ -118,7 +118,7 @@ async def asgi(receive, send): websocket = WebSocket(scope, receive=receive, send=send) await websocket.accept() async with anyio.create_task_group() as task_group: - task_group.spawn(respond, websocket) + task_group.start_soon(respond, websocket) try: # this will block as the client does not send us data # it should not prevent `respond` from executing though From 89e2dae4d911a4d131be9e959a33e9b0f722c13a Mon Sep 17 00:00:00 2001 From: Jordan Speicher Date: Wed, 21 Apr 2021 09:43:02 -0500 Subject: [PATCH 38/59] Bump to require anyio 3.0.0 final --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 3a445cc67..a687ad861 100644 --- a/setup.py +++ b/setup.py @@ -37,7 +37,7 @@ def get_long_description(): packages=find_packages(exclude=["tests*"]), package_data={"starlette": ["py.typed"]}, include_package_data=True, - install_requires=["anyio>=3.0.0rc4"], + install_requires=["anyio>=3.0.0,<4"], extras_require={ "full": [ "graphene", From f62a2ec1c2e2a55fb1302ab96c9740278e540854 Mon Sep 17 00:00:00 2001 From: Jordan Speicher Date: Wed, 21 Apr 2021 15:09:37 -0500 Subject: [PATCH 39/59] Remove mention of aiofiles in README.md --- README.md | 2 -- 1 file changed, 2 deletions(-) diff --git a/README.md b/README.md index bbd7a0a3e..3a5f08cc7 100644 --- a/README.md +++ b/README.md @@ -90,7 +90,6 @@ For a more complete example, see [encode/starlette-example](https://github.com/e Starlette only requires `anyio`, and the following are optional: * [`requests`][requests] - Required if you want to use the `TestClient`. -* [`aiofiles`][aiofiles] - Required if you want to use `FileResponse` or `StaticFiles`. * [`jinja2`][jinja2] - Required if you want to use `Jinja2Templates`. * [`python-multipart`][python-multipart] - Required if you want to support form parsing, with `request.form()`. * [`itsdangerous`][itsdangerous] - Required for `SessionMiddleware` support. @@ -170,7 +169,6 @@ gunicorn -k uvicorn.workers.UvicornH11Worker ...

Starlette is BSD licensed code. Designed & built in Brighton, England.

[requests]: http://docs.python-requests.org/en/master/ -[aiofiles]: https://github.com/Tinche/aiofiles [jinja2]: http://jinja.pocoo.org/ [python-multipart]: https://andrew-d.github.io/python-multipart/ [graphene]: https://graphene-python.org/ From fc60420e88f8d25900632e1bce1cc403bff84e38 Mon Sep 17 00:00:00 2001 From: Jordan Speicher Date: Wed, 12 May 2021 08:39:41 -0500 Subject: [PATCH 40/59] Pin jinja2 to releases before 3 due to DeprecationWarnings --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index f6ac2e623..62ca1e8f0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,7 +2,7 @@ aiofiles graphene itsdangerous -jinja2 +jinja2<3 python-multipart pyyaml requests From 27283aa5d2248dbc42222b0c884851aa1f4f6dd8 Mon Sep 17 00:00:00 2001 From: Jordan Speicher Date: Thu, 13 May 2021 09:30:06 -0500 Subject: [PATCH 41/59] Add task_group as application attribute --- starlette/applications.py | 7 +++++-- starlette/middleware/base.py | 3 +-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/starlette/applications.py b/starlette/applications.py index 10bd075a3..01a13bb0e 100644 --- a/starlette/applications.py +++ b/starlette/applications.py @@ -1,6 +1,7 @@ import typing import anyio +from anyio.abc import TaskGroup from starlette.datastructures import State, URLPath from starlette.exceptions import ExceptionMiddleware @@ -38,6 +39,8 @@ class Starlette: standard functions, or async functions. """ + task_group: TaskGroup + def __init__( self, debug: bool = False, @@ -111,8 +114,8 @@ def url_path_for(self, name: str, **path_params: str) -> URLPath: async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: scope["app"] = self - task_group = scope["task_group"] = anyio.create_task_group() - async with task_group: + self.task_group = anyio.create_task_group() + async with self.task_group: await self.middleware_stack(scope, receive, send) # The following usages are now discouraged in favour of configuration diff --git a/starlette/middleware/base.py b/starlette/middleware/base.py index 16d50649a..aa124b448 100644 --- a/starlette/middleware/base.py +++ b/starlette/middleware/base.py @@ -29,13 +29,12 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: async def call_next(self, request: Request) -> Response: send_stream, recv_stream = anyio.create_memory_object_stream() scope = request.scope - task_group = scope["task_group"] async def coro() -> None: async with send_stream: await self.app(scope, request.receive, send_stream.send) - task_group.start_soon(coro) + scope["app"].task_group.start_soon(coro) try: message = await recv_stream.receive() From 3cce6a91fec7926beead39a47d909e2aaf7674a8 Mon Sep 17 00:00:00 2001 From: Jordan Speicher Date: Thu, 13 May 2021 09:30:21 -0500 Subject: [PATCH 42/59] Remove run_until_first_complete --- starlette/concurrency.py | 11 ----------- starlette/responses.py | 15 ++++++++++----- tests/test_websockets.py | 8 +++----- 3 files changed, 13 insertions(+), 21 deletions(-) diff --git a/starlette/concurrency.py b/starlette/concurrency.py index a32d96f2c..1ef110ea6 100644 --- a/starlette/concurrency.py +++ b/starlette/concurrency.py @@ -13,17 +13,6 @@ T = typing.TypeVar("T") -async def run_until_first_complete(*args: typing.Tuple[typing.Callable, dict]) -> None: - async with anyio.create_task_group() as task_group: - - async def task(_handler: typing.Callable, _kwargs: dict) -> Any: - await _handler(**_kwargs) - task_group.cancel_scope.cancel() - - for handler, kwargs in args: - task_group.start_soon(task, handler, kwargs) - - async def run_in_threadpool( func: typing.Callable[..., T], *args: typing.Any, **kwargs: typing.Any ) -> T: diff --git a/starlette/responses.py b/starlette/responses.py index e8a68b828..71b830615 100644 --- a/starlette/responses.py +++ b/starlette/responses.py @@ -6,13 +6,14 @@ import sys import typing from email.utils import formatdate +from functools import partial from mimetypes import guess_type as mimetypes_guess_type from urllib.parse import quote import anyio from starlette.background import BackgroundTask -from starlette.concurrency import iterate_in_threadpool, run_until_first_complete +from starlette.concurrency import iterate_in_threadpool from starlette.datastructures import URL, MutableHeaders from starlette.types import Receive, Scope, Send @@ -216,10 +217,14 @@ async def stream_response(self, send: Send) -> None: await send({"type": "http.response.body", "body": b"", "more_body": False}) async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: - await run_until_first_complete( - (self.stream_response, {"send": send}), - (self.listen_for_disconnect, {"receive": receive}), - ) + async with anyio.create_task_group() as task_group: + + async def wrap(coro: typing.Callable[..., typing.Awaitable]) -> None: + await coro() + task_group.cancel_scope.cancel() + + task_group.start_soon(wrap, partial(self.stream_response, send)) + task_group.start_soon(wrap, partial(self.listen_for_disconnect, receive)) if self.background is not None: await self.background() diff --git a/tests/test_websockets.py b/tests/test_websockets.py index 305dc64d7..584bfd5f8 100644 --- a/tests/test_websockets.py +++ b/tests/test_websockets.py @@ -2,7 +2,6 @@ import pytest from starlette import status -from starlette.concurrency import run_until_first_complete from starlette.testclient import TestClient from starlette.websockets import WebSocket, WebSocketDisconnect @@ -222,10 +221,9 @@ async def writer(websocket): async def asgi(receive, send): websocket = WebSocket(scope, receive=receive, send=send) await websocket.accept() - await run_until_first_complete( - (reader, {"websocket": websocket}), - (writer, {"websocket": websocket}), - ) + async with anyio.create_task_group() as task_group: + task_group.start_soon(reader, websocket) + task_group.start_soon(writer, websocket) await websocket.close() return asgi From c4d49a70585a6ad99cfb69050d831d3eeb3c74b8 Mon Sep 17 00:00:00 2001 From: Jordan Speicher Date: Thu, 13 May 2021 09:35:42 -0500 Subject: [PATCH 43/59] Undo jinja pin --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 62ca1e8f0..f6ac2e623 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,7 +2,7 @@ aiofiles graphene itsdangerous -jinja2<3 +jinja2 python-multipart pyyaml requests From 4dd8c5de79df2ee56959dba7bc20bbcdbb4d833d Mon Sep 17 00:00:00 2001 From: Jordan Speicher Date: Thu, 13 May 2021 10:31:10 -0500 Subject: [PATCH 44/59] Refactor anyio.sleep into an event --- starlette/testclient.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/starlette/testclient.py b/starlette/testclient.py index f4a960a3f..4de2e46a2 100644 --- a/starlette/testclient.py +++ b/starlette/testclient.py @@ -169,17 +169,16 @@ def send( request_complete = False response_started = False - response_complete = False + response_complete: anyio.Event raw_kwargs = {"body": io.BytesIO()} # type: typing.Dict[str, typing.Any] template = None context = None async def receive() -> Message: - nonlocal request_complete, response_complete + nonlocal request_complete if request_complete: - while not response_complete: - await anyio.sleep(0.0001) + await response_complete.wait() return {"type": "http.disconnect"} body = request.body @@ -203,7 +202,7 @@ async def receive() -> Message: return {"type": "http.request", "body": body_bytes} async def send(message: Message) -> None: - nonlocal raw_kwargs, response_started, response_complete, template, context + nonlocal raw_kwargs, response_started, template, context if message["type"] == "http.response.start": assert ( @@ -225,7 +224,7 @@ async def send(message: Message) -> None: response_started ), 'Received "http.response.body" without "http.response.start".' assert ( - not response_complete + not response_complete.is_set() ), 'Received "http.response.body" after response completed.' body = message.get("body", b"") more_body = message.get("more_body", False) @@ -233,13 +232,14 @@ async def send(message: Message) -> None: raw_kwargs["body"].write(body) if not more_body: raw_kwargs["body"].seek(0) - response_complete = True + response_complete.set() elif message["type"] == "http.response.template": template = message["template"] context = message["context"] try: with anyio.start_blocking_portal(**self.async_backend) as portal: + response_complete = portal.call(anyio.Event) portal.call(self.app, scope, receive, send) except BaseException as exc: if self.raise_server_exceptions: From dde5079bdf59c1f974dff30d888ba1062faf9145 Mon Sep 17 00:00:00 2001 From: Jordan Speicher Date: Thu, 13 May 2021 10:43:31 -0500 Subject: [PATCH 45/59] Use one less task in test_websocket_concurrency_pattern --- tests/test_websockets.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_websockets.py b/tests/test_websockets.py index 584bfd5f8..63ecd050a 100644 --- a/tests/test_websockets.py +++ b/tests/test_websockets.py @@ -223,7 +223,7 @@ async def asgi(receive, send): await websocket.accept() async with anyio.create_task_group() as task_group: task_group.start_soon(reader, websocket) - task_group.start_soon(writer, websocket) + await writer(websocket) await websocket.close() return asgi From df53965c8fc4ecbe4aceaeb6d7248ab8a92de54b Mon Sep 17 00:00:00 2001 From: Jordan Speicher Date: Thu, 13 May 2021 14:29:33 -0500 Subject: [PATCH 46/59] Apply review suggestions --- starlette/responses.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/starlette/responses.py b/starlette/responses.py index 71b830615..ba19120e0 100644 --- a/starlette/responses.py +++ b/starlette/responses.py @@ -219,12 +219,12 @@ async def stream_response(self, send: Send) -> None: async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: async with anyio.create_task_group() as task_group: - async def wrap(coro: typing.Callable[..., typing.Awaitable]) -> None: + async def wrap(coro: typing.Callable[[], typing.Coroutine]) -> None: await coro() task_group.cancel_scope.cancel() task_group.start_soon(wrap, partial(self.stream_response, send)) - task_group.start_soon(wrap, partial(self.listen_for_disconnect, receive)) + await wrap(partial(self.listen_for_disconnect, receive)) if self.background is not None: await self.background() From 6e0f05f9a0cfd3099b8caf694385c92ff48c244c Mon Sep 17 00:00:00 2001 From: Jordan Speicher Date: Thu, 13 May 2021 14:31:21 -0500 Subject: [PATCH 47/59] Rename argument --- starlette/responses.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/starlette/responses.py b/starlette/responses.py index ba19120e0..abb8d0140 100644 --- a/starlette/responses.py +++ b/starlette/responses.py @@ -219,8 +219,8 @@ async def stream_response(self, send: Send) -> None: async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: async with anyio.create_task_group() as task_group: - async def wrap(coro: typing.Callable[[], typing.Coroutine]) -> None: - await coro() + async def wrap(func: typing.Callable[[], typing.Coroutine]) -> None: + await func() task_group.cancel_scope.cancel() task_group.start_soon(wrap, partial(self.stream_response, send)) From 3a359e39034908f925ed6c85e85bf78c80ffa6d7 Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Mon, 24 May 2021 23:31:11 +0100 Subject: [PATCH 48/59] fix start_task_soon type --- starlette/testclient.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/starlette/testclient.py b/starlette/testclient.py index 4de2e46a2..e45dc212b 100644 --- a/starlette/testclient.py +++ b/starlette/testclient.py @@ -8,6 +8,7 @@ import queue import types import typing +from concurrent.futures import Future from urllib.parse import unquote, urljoin, urlsplit import anyio @@ -284,7 +285,7 @@ def __enter__(self) -> "WebSocketTestSession": ) try: - self.portal.start_task_soon(self._run) + _: "Future[None]" = self.portal.start_task_soon(self._run) self.send({"type": "websocket.connect"}) message = self.receive() self._raise_on_close(message) @@ -385,6 +386,8 @@ class TestClient(requests.Session): "backend_options": {}, } # type: typing.Dict[str, typing.Any] + task: "Future[None]" + def __init__( self, app: typing.Union[ASGI2App, ASGI3App], From 5c77b7d1160bcbd1ce6fd265c06e4fdad81919b2 Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Mon, 24 May 2021 23:33:27 +0100 Subject: [PATCH 49/59] fix BaseHTTPMiddleware when used without Starlette --- starlette/applications.py | 9 +---- starlette/middleware/base.py | 64 ++++++++++++++++++------------------ 2 files changed, 33 insertions(+), 40 deletions(-) diff --git a/starlette/applications.py b/starlette/applications.py index 01a13bb0e..34c3e38bd 100644 --- a/starlette/applications.py +++ b/starlette/applications.py @@ -1,8 +1,5 @@ import typing -import anyio -from anyio.abc import TaskGroup - from starlette.datastructures import State, URLPath from starlette.exceptions import ExceptionMiddleware from starlette.middleware import Middleware @@ -39,8 +36,6 @@ class Starlette: standard functions, or async functions. """ - task_group: TaskGroup - def __init__( self, debug: bool = False, @@ -114,9 +109,7 @@ def url_path_for(self, name: str, **path_params: str) -> URLPath: async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: scope["app"] = self - self.task_group = anyio.create_task_group() - async with self.task_group: - await self.middleware_stack(scope, receive, send) + await self.middleware_stack(scope, receive, send) # The following usages are now discouraged in favour of configuration #  during Starlette.__init__(...) diff --git a/starlette/middleware/base.py b/starlette/middleware/base.py index aa124b448..a542dad66 100644 --- a/starlette/middleware/base.py +++ b/starlette/middleware/base.py @@ -22,38 +22,38 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: await self.app(scope, receive, send) return - request = Request(scope, receive=receive) - response = await self.dispatch_func(request, self.call_next) - await response(scope, receive, send) - - async def call_next(self, request: Request) -> Response: - send_stream, recv_stream = anyio.create_memory_object_stream() - scope = request.scope - - async def coro() -> None: - async with send_stream: - await self.app(scope, request.receive, send_stream.send) - - scope["app"].task_group.start_soon(coro) - - try: - message = await recv_stream.receive() - except anyio.EndOfStream: - raise RuntimeError("No response returned.") - - assert message["type"] == "http.response.start" - - async def body_stream() -> typing.AsyncGenerator[bytes, None]: - async with recv_stream: - async for message in recv_stream: - assert message["type"] == "http.response.body" - yield message.get("body", b"") - - response = StreamingResponse( - status_code=message["status"], content=body_stream() - ) - response.raw_headers = message["headers"] - return response + async def call_next(request: Request) -> Response: + send_stream, recv_stream = anyio.create_memory_object_stream() + + async def coro() -> None: + async with send_stream: + await self.app(scope, request.receive, send_stream.send) + + task_group.start_soon(coro) + + try: + message = await recv_stream.receive() + except anyio.EndOfStream: + raise RuntimeError("No response returned.") + + assert message["type"] == "http.response.start" + + async def body_stream() -> typing.AsyncGenerator[bytes, None]: + async with recv_stream: + async for message in recv_stream: + assert message["type"] == "http.response.body" + yield message.get("body", b"") + + response = StreamingResponse( + status_code=message["status"], content=body_stream() + ) + response.raw_headers = message["headers"] + return response + + async with anyio.create_task_group() as task_group: + request = Request(scope, receive=receive) + response = await self.dispatch_func(request, call_next) + await response(scope, receive, send) async def dispatch( self, request: Request, call_next: RequestResponseEndpoint From 6a3f94dcdb331025a5c4f9a68c17777ba8f07da5 Mon Sep 17 00:00:00 2001 From: Jordan Speicher Date: Thu, 27 May 2021 09:50:42 -0500 Subject: [PATCH 50/59] Testclient receive() is a non-trapping function if the response is already complete This allows for a zero deadline when waiting for a disconnect message --- starlette/requests.py | 4 ++-- starlette/testclient.py | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/starlette/requests.py b/starlette/requests.py index b71bdd1c1..d98c1db29 100644 --- a/starlette/requests.py +++ b/starlette/requests.py @@ -253,8 +253,8 @@ async def close(self) -> None: async def is_disconnected(self) -> bool: if not self._is_disconnected: message: Message = {} - with anyio.move_on_after(0.001): - # XXX: to small of a deadline and this blocks + + with anyio.move_on_after(0): message = await self._receive() if message.get("type") == "http.disconnect": diff --git a/starlette/testclient.py b/starlette/testclient.py index e45dc212b..92f4217b3 100644 --- a/starlette/testclient.py +++ b/starlette/testclient.py @@ -179,7 +179,8 @@ async def receive() -> Message: nonlocal request_complete if request_complete: - await response_complete.wait() + if not response_complete.is_set(): + await response_complete.wait() return {"type": "http.disconnect"} body = request.body From 5667a4bdcde24d668c4a31062984d88fdc0ec628 Mon Sep 17 00:00:00 2001 From: Jordan Speicher Date: Thu, 27 May 2021 10:04:45 -0500 Subject: [PATCH 51/59] Use variable annotation for async_backend --- starlette/testclient.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/starlette/testclient.py b/starlette/testclient.py index 92f4217b3..e8482a41f 100644 --- a/starlette/testclient.py +++ b/starlette/testclient.py @@ -382,10 +382,10 @@ class TestClient(requests.Session): __test__ = False # For pytest to not discover this up. #: These options are passed to `anyio.start_blocking_portal()` - async_backend = { + async_backend: typing.Dict[str, typing.Any] = { "backend": "asyncio", "backend_options": {}, - } # type: typing.Dict[str, typing.Any] + } task: "Future[None]" From 19685dbc0a3f736ead52837ff93b37c9fec5fc25 Mon Sep 17 00:00:00 2001 From: Jordan Speicher Date: Thu, 27 May 2021 10:09:14 -0500 Subject: [PATCH 52/59] Update docs regarding dependency on anyio --- docs/index.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/index.md b/docs/index.md index db82e0d36..b9692a1fb 100644 --- a/docs/index.md +++ b/docs/index.md @@ -32,7 +32,7 @@ It is production-ready, and gives you the following: * Session and Cookie support. * 100% test coverage. * 100% type annotated codebase. -* Zero hard dependencies. +* Few hard dependencies. ## Requirements @@ -79,7 +79,7 @@ For a more complete example, [see here](https://github.com/encode/starlette-exam ## Dependencies -Starlette does not have any hard dependencies, but the following are optional: +Starlette only requires `anyio`, and the following dependencies are optional: * [`requests`][requests] - Required if you want to use the `TestClient`. * [`jinja2`][jinja2] - Required if you want to use `Jinja2Templates`. From a1ceb355491986cea3b24b7012a83ff2261119e2 Mon Sep 17 00:00:00 2001 From: Jordan Speicher Date: Thu, 27 May 2021 14:26:48 -0500 Subject: [PATCH 53/59] Use CancelScope instead of move_on_after in request.is_disconnected --- starlette/requests.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/starlette/requests.py b/starlette/requests.py index d98c1db29..54ed8611e 100644 --- a/starlette/requests.py +++ b/starlette/requests.py @@ -254,7 +254,9 @@ async def is_disconnected(self) -> bool: if not self._is_disconnected: message: Message = {} - with anyio.move_on_after(0): + # If message isn't immediately available, move on + with anyio.CancelScope() as cs: + cs.cancel() message = await self._receive() if message.get("type") == "http.disconnect": From 63cfcb9ef029406e39852a09592199309a139f22 Mon Sep 17 00:00:00 2001 From: Jordan Speicher Date: Sun, 13 Jun 2021 10:47:17 -0500 Subject: [PATCH 54/59] Cancel task group after returning middleware response Add test for https://github.com/encode/starlette/issues/1022 --- starlette/middleware/base.py | 1 + tests/middleware/test_base.py | 15 +++++++++++++++ 2 files changed, 16 insertions(+) diff --git a/starlette/middleware/base.py b/starlette/middleware/base.py index a542dad66..77ba66925 100644 --- a/starlette/middleware/base.py +++ b/starlette/middleware/base.py @@ -54,6 +54,7 @@ async def body_stream() -> typing.AsyncGenerator[bytes, None]: request = Request(scope, receive=receive) response = await self.dispatch_func(request, call_next) await response(scope, receive, send) + task_group.cancel_scope.cancel() async def dispatch( self, request: Request, call_next: RequestResponseEndpoint diff --git a/tests/middleware/test_base.py b/tests/middleware/test_base.py index 048dd9ffb..df8901934 100644 --- a/tests/middleware/test_base.py +++ b/tests/middleware/test_base.py @@ -143,3 +143,18 @@ def homepage(request): def test_middleware_repr(): middleware = Middleware(CustomMiddleware) assert repr(middleware) == "Middleware(CustomMiddleware)" + + +def test_fully_evaluated_response(): + # Test for https://github.com/encode/starlette/issues/1022 + class CustomMiddleware(BaseHTTPMiddleware): + async def dispatch(self, request, call_next): + await call_next(request) + return PlainTextResponse("Custom") + + app = Starlette() + app.add_middleware(CustomMiddleware) + + client = TestClient(app) + response = client.get("/does_not_exist") + assert response.text == "Custom" From 6208ca5e356cb9acc605723e5b05e91bfbe1c391 Mon Sep 17 00:00:00 2001 From: Jordan Speicher Date: Sun, 13 Jun 2021 10:54:53 -0500 Subject: [PATCH 55/59] Add link to anyio backend options in testclient docs --- docs/testclient.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/testclient.md b/docs/testclient.md index c21380498..f37858401 100644 --- a/docs/testclient.md +++ b/docs/testclient.md @@ -35,7 +35,8 @@ case you should use `client = TestClient(app, raise_server_exceptions=False)`. `TestClient.async_backend` is a dictionary which allows you to set the options for the backend used to run tests. These options are passed to -`anyio.start_blocking_portal()`. By default, `asyncio` is used. +`anyio.start_blocking_portal()`. See the [anyio documentation](https://anyio.readthedocs.io/en/stable/basics.html#backend-options) +for more information about backend options. By default, `asyncio` is used. To run `Trio`, set `async_backend["backend"] = "trio"`, for example: From e0c99676119f97d1e35e995a9b33649a69d5da68 Mon Sep 17 00:00:00 2001 From: Jordan Speicher Date: Sun, 13 Jun 2021 11:06:33 -0500 Subject: [PATCH 56/59] Add types-dataclasses --- requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements.txt b/requirements.txt index 96d694145..ae3d91f26 100644 --- a/requirements.txt +++ b/requirements.txt @@ -18,6 +18,7 @@ types-requests types-contextvars types-aiofiles types-PyYAML +types-dataclasses pytest pytest-cov trio From 2b9dd22173e7ddb552e08503433c700601a417d0 Mon Sep 17 00:00:00 2001 From: Jordan Speicher Date: Thu, 17 Jun 2021 11:09:16 -0500 Subject: [PATCH 57/59] Re-implement starlette.concurrency.run_until_first_complete and add a test --- starlette/concurrency.py | 10 ++++++++++ tests/test_concurrency.py | 22 ++++++++++++++++++++++ 2 files changed, 32 insertions(+) create mode 100644 tests/test_concurrency.py diff --git a/starlette/concurrency.py b/starlette/concurrency.py index 1ef110ea6..a62a5e507 100644 --- a/starlette/concurrency.py +++ b/starlette/concurrency.py @@ -13,6 +13,16 @@ T = typing.TypeVar("T") +async def run_until_first_complete(*args: typing.Tuple[typing.Callable, dict]) -> None: + async def run(handler: typing.Callable[..., typing.Coroutine]) -> None: + await handler() + tg.cancel_scope.cancel() + + async with anyio.create_task_group() as tg: + for handler, kwargs in args: + tg.start_soon(run, functools.partial(handler, **kwargs)) + + async def run_in_threadpool( func: typing.Callable[..., T], *args: typing.Any, **kwargs: typing.Any ) -> T: diff --git a/tests/test_concurrency.py b/tests/test_concurrency.py new file mode 100644 index 000000000..cc5eba974 --- /dev/null +++ b/tests/test_concurrency.py @@ -0,0 +1,22 @@ +import anyio +import pytest + +from starlette.concurrency import run_until_first_complete + + +@pytest.mark.anyio +async def test_run_until_first_complete(): + task1_finished = anyio.Event() + task2_finished = anyio.Event() + + async def task1(): + task1_finished.set() + + async def task2(): + await task1_finished.wait() + await anyio.sleep(0) # pragma: nocover + task2_finished.set() # pragma: nocover + + await run_until_first_complete((task1, {}), (task2, {})) + assert task1_finished.is_set() + assert not task2_finished.is_set() From 643d107a88f1e26b0c0485b8ca155488f75fabf1 Mon Sep 17 00:00:00 2001 From: Jordan Speicher Date: Thu, 17 Jun 2021 11:13:04 -0500 Subject: [PATCH 58/59] Fix type on handler callable --- starlette/concurrency.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/starlette/concurrency.py b/starlette/concurrency.py index a62a5e507..f1dfb0a13 100644 --- a/starlette/concurrency.py +++ b/starlette/concurrency.py @@ -14,7 +14,7 @@ async def run_until_first_complete(*args: typing.Tuple[typing.Callable, dict]) -> None: - async def run(handler: typing.Callable[..., typing.Coroutine]) -> None: + async def run(handler: typing.Callable[[], typing.Coroutine]) -> None: await handler() tg.cancel_scope.cancel() From d0ca3f2bca14aa58d813e52b691b8ffd462a602b Mon Sep 17 00:00:00 2001 From: Jordan Speicher Date: Fri, 18 Jun 2021 09:35:10 -0500 Subject: [PATCH 59/59] Apply review comments to clarify run_until_first_complete scope --- starlette/concurrency.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/starlette/concurrency.py b/starlette/concurrency.py index f1dfb0a13..e89d1e047 100644 --- a/starlette/concurrency.py +++ b/starlette/concurrency.py @@ -14,13 +14,14 @@ async def run_until_first_complete(*args: typing.Tuple[typing.Callable, dict]) -> None: - async def run(handler: typing.Callable[[], typing.Coroutine]) -> None: - await handler() - tg.cancel_scope.cancel() + async with anyio.create_task_group() as task_group: - async with anyio.create_task_group() as tg: - for handler, kwargs in args: - tg.start_soon(run, functools.partial(handler, **kwargs)) + async def run(func: typing.Callable[[], typing.Coroutine]) -> None: + await func() + task_group.cancel_scope.cancel() + + for func, kwargs in args: + task_group.start_soon(run, functools.partial(func, **kwargs)) async def run_in_threadpool(