Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Send lifespan.shutdown.failed event on lifespan shutdown error #1205

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
52 changes: 35 additions & 17 deletions starlette/routing.py
Expand Up @@ -541,29 +541,47 @@ async def lifespan(self, scope: Scope, receive: Receive, send: Send) -> None:
Handle ASGI lifespan messages, which allows us to manage application
startup and shutdown events.
"""
first = True
app = scope.get("app")
await receive()

async def as_aiter(gen: typing.Generator) -> typing.AsyncGenerator:
for val in gen:
yield val

lifespan_gen = (
self.lifespan_context(app)
if inspect.isasyncgenfunction(self.lifespan_context)
else as_aiter(self.lifespan_context(app)) # type: ignore
)

await self._lifespan_iter(lifespan_gen, receive, send)

async def _lifespan_iter(
self, gen: typing.AsyncGenerator, receive: Receive, send: Send
) -> None:
try:
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think just making lifespan context an asynccontextmanager factory makes all this easier

if inspect.isasyncgenfunction(self.lifespan_context):
async for item in self.lifespan_context(app):
assert first, "Lifespan context yielded multiple times."
first = False
await send({"type": "lifespan.startup.complete"})
await receive()
else:
for item in self.lifespan_context(app): # type: ignore
assert first, "Lifespan context yielded multiple times."
first = False
await send({"type": "lifespan.startup.complete"})
await receive()
except BaseException:
if first:
exc_text = traceback.format_exc()
await send({"type": "lifespan.startup.failed", "message": exc_text})
await gen.asend(None)
except StopAsyncIteration:
raise RuntimeError("Lifespan context never yielded.")
except Exception:
exc_text = traceback.format_exc()
await send({"type": "lifespan.startup.failed", "message": exc_text})
raise
else:
await send({"type": "lifespan.startup.complete"})
await receive()

try:
await gen.asend(None)
except StopAsyncIteration:
await send({"type": "lifespan.shutdown.complete"})
except Exception:
exc_text = traceback.format_exc()
await send({"type": "lifespan.shutdown.failed", "message": exc_text})
raise
else:
await gen.aclose()
raise RuntimeError("Lifespan context yielded multiple times.")

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
"""
Expand Down
9 changes: 8 additions & 1 deletion starlette/testclient.py
Expand Up @@ -528,4 +528,11 @@ async def wait_shutdown(self) -> None:
message = await self.stream_send.receive()
if message is None:
self.task.result()
assert message["type"] == "lifespan.shutdown.complete"
assert message["type"] in (
"lifespan.shutdown.complete",
"lifespan.shutdown.failed",
)
if message["type"] == "lifespan.shutdown.failed":
message = await self.stream_send.receive()
if message is None:
self.task.result()
102 changes: 101 additions & 1 deletion tests/test_routing.py
Expand Up @@ -528,6 +528,31 @@ async def run_shutdown():
assert shutdown_complete


def test_lifespan_context_async():
startup_complete = False
shutdown_complete = False

async def hello_world(request):
return PlainTextResponse("hello, world")

async def lifespan(app):
nonlocal startup_complete, shutdown_complete
startup_complete = True
yield
shutdown_complete = True

app = Router(lifespan=lifespan, routes=[Route("/", hello_world)])

assert not startup_complete
assert not shutdown_complete
with TestClient(app) as client:
assert startup_complete
assert not shutdown_complete
client.get("/")
assert startup_complete
assert shutdown_complete


def test_lifespan_sync():
startup_complete = False
shutdown_complete = False
Expand Down Expand Up @@ -559,6 +584,31 @@ def run_shutdown():
assert shutdown_complete


def test_lifespan_context_sync():
startup_complete = False
shutdown_complete = False

def hello_world(request):
return PlainTextResponse("hello, world")

def lifespan(app):
nonlocal startup_complete, shutdown_complete
startup_complete = True
yield
shutdown_complete = True

app = Router(lifespan=lifespan, routes=[Route("/", hello_world)])

assert not startup_complete
assert not shutdown_complete
with TestClient(app) as client:
assert startup_complete
assert not shutdown_complete
client.get("/")
assert startup_complete
assert shutdown_complete


def test_raise_on_startup():
def run_startup():
raise RuntimeError()
Expand All @@ -585,11 +635,61 @@ def test_raise_on_shutdown():
def run_shutdown():
raise RuntimeError()

app = Router(on_shutdown=[run_shutdown])
router = Router(on_shutdown=[run_shutdown])

async def app(scope, receive, send):
async def _send(message):
nonlocal shutdown_failed
if message["type"] == "lifespan.shutdown.failed":
shutdown_failed = True
return await send(message)

await router(scope, receive, _send)

shutdown_failed = False
with pytest.raises(RuntimeError):
with TestClient(app):
pass # pragma: nocover
assert shutdown_failed


def test_lifespan_yield_too_many():
startup_complete = False
shutdown_complete = False

async def hello_world(request):
return PlainTextResponse("hello, world")

async def yield_too_many(app):
nonlocal startup_complete, shutdown_complete
startup_complete = True
yield
yield
shutdown_complete = True # pragma: nocover

app = Router(lifespan=yield_too_many, routes=[Route("/", hello_world)])

assert not startup_complete
assert not shutdown_complete
with pytest.raises(RuntimeError, match="yielded multiple times"):
with TestClient(app) as client:
assert startup_complete
assert not shutdown_complete
client.get("/")

assert startup_complete
assert not shutdown_complete


def test_lifespan_yield_too_few():
async def yield_too_few(app):
for _ in []:
yield _ # pragma: nocover

app = Router(lifespan=yield_too_few)
with pytest.raises(RuntimeError, match="never yielded"):
with TestClient(app):
pass # pragma: nocover


class AsyncEndpointClassMethod:
Expand Down