From 1a56a1facc2a75e4125081b2c2ff18ccdddbf4a6 Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Thu, 24 Jun 2021 09:54:07 +0100 Subject: [PATCH] for lifespan task verification, use native task identity rather than anyio.abc.TaskInfo equality https://github.com/agronholm/anyio/issues/324 --- tests/test_testclient.py | 38 ++++++++++++++++++++++++++++++-------- 1 file changed, 30 insertions(+), 8 deletions(-) diff --git a/tests/test_testclient.py b/tests/test_testclient.py index 53c1bca7a..3c9038ea3 100644 --- a/tests/test_testclient.py +++ b/tests/test_testclient.py @@ -1,13 +1,22 @@ +import asyncio import itertools +import sys import anyio import pytest +import sniffio +import trio.lowlevel from starlette.applications import Starlette from starlette.middleware import Middleware from starlette.responses import JSONResponse from starlette.websockets import WebSocket, WebSocketDisconnect +if sys.version_info >= (3, 7): + from asyncio import current_task as asyncio_current_task # pragma: no cover +else: + asyncio_current_task = asyncio.Task.current_task # pragma: no cover + mock_service = Starlette() @@ -28,17 +37,32 @@ def get_identity(counter): return token +def current_task(): + # anyio's TaskInfo comparisons are invalid after their associated native + # task object is GC'd https://github.com/agronholm/anyio/issues/324 + asynclib_name = sniffio.current_async_library() + if asynclib_name == "trio": + return trio.lowlevel.current_task() + + if asynclib_name == "asyncio": + task = asyncio_current_task() + if task is None: + raise RuntimeError("must be called from a running task") # pragma: no cover + return task + raise RuntimeError(f"unsupported asynclib={asynclib_name}") # pragma: no cover + + def create_app(test_client_factory, counter=itertools.count()): app = Starlette() @app.on_event("startup") async def get_startup_thread(): - app.startup_task = anyio.get_current_task().id + app.startup_task = current_task() app.startup_loop = get_identity(counter) @app.on_event("shutdown") async def get_shutdown_thread(): - app.shutdown_task = anyio.get_current_task().id + app.shutdown_task = current_task() app.shutdown_loop = get_identity(counter) @app.route("/") @@ -93,7 +117,7 @@ def test_use_testclient_as_contextmanager(test_client_factory, anyio_backend_nam # lifespan events run in the same task, this is important because a task # group must be entered and exited in the same task. - assert app.startup_task == app.shutdown_task + assert app.startup_task is app.shutdown_task # outside the TestClient context, new requests continue to spawn in new # eventloops in new threads @@ -112,12 +136,10 @@ def test_use_testclient_as_contextmanager(test_client_factory, anyio_backend_nam assert app.shutdown_loop == 3 # lifespan events still run in the same task, with the context but... - assert app.startup_task == app.shutdown_task + assert app.startup_task is app.shutdown_task - if anyio_backend_name != "asyncio": - # https://github.com/agronholm/anyio/issues/324 - # ... the second TestClient context creates a new lifespan task. - assert first_task != app.startup_task + # ... the second TestClient context creates a new lifespan task. + assert first_task is not app.startup_task def test_error_on_startup(test_client_factory):