Skip to content

Commit

Permalink
Use Starlette(lifetime=...)
Browse files Browse the repository at this point in the history
  • Loading branch information
adriangb committed Jul 19, 2021
1 parent c4cd269 commit 629e465
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 29 deletions.
27 changes: 2 additions & 25 deletions fastapi/applications.py
@@ -1,9 +1,8 @@
from typing import Any, Callable, Coroutine, Dict, List, Optional, Sequence, Tuple, Type, Union
from typing import Any, Callable, Coroutine, Dict, List, Optional, Sequence, Type, Union

from fastapi import routing
from fastapi.concurrency import AsyncExitStack
from fastapi.datastructures import Default, DefaultPlaceholder
from fastapi.dependencies.models import DependencyCacheKey
from fastapi.encoders import DictIntStrAny, SetIntStr
from fastapi.exception_handlers import (
http_exception_handler,
Expand All @@ -30,9 +29,6 @@


class FastAPI(Starlette):

app_lifespan_astack: Union[AsyncExitStack, None]

def __init__(
self,
*,
Expand Down Expand Up @@ -70,26 +66,6 @@ def __init__(
) -> None:
self._debug: bool = debug
self.state: State = State()

on_startup = [] if on_startup is None else list(on_startup)
on_shutdown = [] if on_shutdown is None else list(on_shutdown)

if AsyncExitStack:
async def initialize_app_lifespan_dependency_stack():
self.app_lifespan_astack = AsyncExitStack()
await self.app_lifespan_astack.__aenter__()
on_startup.append(initialize_app_lifespan_dependency_stack)
async def shutdown_app_lifespan_dependency_stack():
await self.app_lifespan_astack.__aexit__(None, None, None)
on_shutdown.append(shutdown_app_lifespan_dependency_stack)
else:
self.app_lifespan_astack = None

self.lifespan_dependencies: Dict[DependencyCacheKey, Any] = {}
def clear_app_lifespan_dependencies():
self.lifespan_dependencies = {}
on_shutdown.append(clear_app_lifespan_dependencies)

self.router: routing.APIRouter = routing.APIRouter(
routes=routes,
dependency_overrides_provider=self,
Expand Down Expand Up @@ -117,6 +93,7 @@ def clear_app_lifespan_dependencies():
[] if middleware is None else list(middleware)
)
self.middleware_stack: ASGIApp = self.build_middleware_stack()

self.title = title
self.description = description
self.version = version
Expand Down
2 changes: 1 addition & 1 deletion fastapi/dependencies/utils.py
Expand Up @@ -544,7 +544,7 @@ async def solve_dependencies(
if sub_dependant.lifespan == "request":
stack = request.scope.get("fastapi_astack")
else: # lifespan == "app"
stack = getattr(request.app, "app_lifespan_astack")
stack = getattr(request.app.router, "lifespan_astack") # type: ignore
if stack is None:
raise RuntimeError(
async_contextmanager_dependencies_error
Expand Down
34 changes: 31 additions & 3 deletions fastapi/routing.py
@@ -1,10 +1,12 @@
import asyncio
import contextlib
import email.message
import enum
import inspect
import json
from typing import (
Any,
AsyncGenerator,
Callable,
Coroutine,
Dict,
Expand All @@ -17,8 +19,9 @@
)

from fastapi import params
from fastapi.concurrency import AsyncExitStack
from fastapi.datastructures import Default, DefaultPlaceholder
from fastapi.dependencies.models import Dependant
from fastapi.dependencies.models import Dependant, DependencyCacheKey
from fastapi.dependencies.utils import (
get_body_field,
get_dependant,
Expand Down Expand Up @@ -208,7 +211,7 @@ async def app(request: Request) -> Response:
dependant=dependant,
body=body,
dependency_overrides_provider=dependency_overrides_provider,
lifespan_dependencies=request.app.lifespan_dependencies
lifespan_dependencies=request.app.router.lifespan_dependencies
)
values, errors, background_tasks, sub_response, _ = solved_result
if errors:
Expand Down Expand Up @@ -255,7 +258,7 @@ async def app(websocket: WebSocket) -> None:
request=websocket,
dependant=dependant,
dependency_overrides_provider=dependency_overrides_provider,
lifespan_dependencies=websocket.app.lifespan_dependencies
lifespan_dependencies=websocket.app.router.lifespan_dependencies
)
values, errors, _, _2, _3 = solved_result
if errors:
Expand Down Expand Up @@ -424,6 +427,10 @@ def get_route_handler(self) -> Callable[[Request], Coroutine[Any, Any, Response]


class APIRouter(routing.Router):

lifespan_astack: Union[AsyncExitStack, None]
lifespan_dependencies: Dict[DependencyCacheKey, Any]

def __init__(
self,
*,
Expand All @@ -440,16 +447,37 @@ def __init__(
route_class: Type[APIRoute] = APIRoute,
on_startup: Optional[Sequence[Callable[[], Any]]] = None,
on_shutdown: Optional[Sequence[Callable[[], Any]]] = None,
lifespan: Callable[[Any], AsyncGenerator] = None,
deprecated: Optional[bool] = None,
include_in_schema: bool = True,
) -> None:
self.lifespan_dependencies = {}

@contextlib.asynccontextmanager
async def dep_stack_cm() -> AsyncGenerator:
if AsyncExitStack:
async with AsyncExitStack() as self.lifespan_astack:
yield
else:
self.lifespan_astack = None
yield
self.lifespan_dependencies = {}

async def lifespan_context(app: Any) -> AsyncGenerator:
async with dep_stack_cm():
await self.startup()
yield
await self.shutdown()

super().__init__(
routes=routes, # type: ignore # in Starlette
redirect_slashes=redirect_slashes,
default=default, # type: ignore # in Starlette
on_startup=on_startup, # type: ignore # in Starlette
on_shutdown=on_shutdown, # type: ignore # in Starlette
lifespan=lifespan_context,
)

if prefix:
assert prefix.startswith("/"), "A path prefix must start with '/'"
assert not prefix.endswith(
Expand Down

0 comments on commit 629e465

Please sign in to comment.