diff --git a/requirements-dev.txt b/requirements-dev.txt index 2557a1343..eddcc60c4 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -115,7 +115,7 @@ sqlalchemy-utils==0.36.0 # via -r requirements-dev.in sqlalchemy==1.3.11 # via -r requirements.in, alembic, sqlalchemy-utils sqlparse==0.3.0 # via -r requirements-dev.in sshpubkeys==3.1.0 # via moto -starlette==0.13.0 # via -r requirements.in +starlette==0.14.1 # via -r requirements.in texttable==1.6.2 # via -r requirements.in toposort==1.5 # via -r requirements.in tornado==6.0.3 # via livereload, sphinx-autobuild diff --git a/requirements.txt b/requirements.txt index ed34e4062..9d44e94bf 100644 --- a/requirements.txt +++ b/requirements.txt @@ -44,7 +44,7 @@ s3transfer==0.3.3 # via boto3 setuptools-scm==3.3.3 # via -r requirements.in six==1.13.0 # via cryptography, multipledispatch, python-dateutil, python-multipart sqlalchemy==1.3.11 # via -r requirements.in, alembic -starlette==0.13.0 # via -r requirements.in +starlette==0.14.1 # via -r requirements.in texttable==1.6.2 # via -r requirements.in toposort==1.5 # via -r requirements.in tqdm==4.40.2 # via -r requirements.in diff --git a/spinta/api.py b/spinta/api.py index 96f9143cf..fdc15f25b 100644 --- a/spinta/api.py +++ b/spinta/api.py @@ -172,7 +172,7 @@ def init(context: Context): ] middleware = [ - Middleware(ContextMiddleware) + Middleware(ContextMiddleware, context=context) ] exception_handlers = { diff --git a/spinta/middlewares.py b/spinta/middlewares.py index 4779c6f64..6811ac895 100644 --- a/spinta/middlewares.py +++ b/spinta/middlewares.py @@ -1,13 +1,12 @@ -import asyncio -import typing +from starlette.types import ASGIApp +from starlette.types import Receive +from starlette.types import Scope +from starlette.types import Send -from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint -from starlette.requests import Request -from starlette.responses import Response, StreamingResponse -from starlette.types import Receive, Scope, Send +from spinta.components import Context -class ContextMiddleware(BaseHTTPMiddleware): +class ContextMiddleware: """Adds `request.state.context`. There is a global `context`, where all heavy things are preloaded as @@ -18,58 +17,15 @@ class ContextMiddleware(BaseHTTPMiddleware): modified in each request without effecting global context. """ + def __init__(self, app: ASGIApp, *, context: Context) -> None: + self.app = app + self.context = context + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: - self._orig_send = send - if scope['type'] == 'http': - with scope['app'].state.context.fork('request') as context: + if scope["type"] in ["http", "websocket"]: + with self.context.fork('request') as context: scope.setdefault('state', {}) scope['state']['context'] = context - await super().__call__(scope, receive, send) - else: - await super().__call__(scope, receive, send) - - async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response: - return await call_next(request) - - # TODO: Temporary fix for https://github.com/encode/starlette/issues/472 - async def call_next(self, request: Request) -> Response: - loop = asyncio.get_event_loop() - queue = asyncio.Queue() # type: asyncio.Queue - - scope = request.scope - receive = request.receive - send = queue.put - - async def coro() -> None: - try: await self.app(scope, receive, send) - finally: - await queue.put(None) - - task = loop.create_task(coro()) - message = await queue.get() - if message is None: - task.result() - raise RuntimeError("No response returned.") - - if "http.response.template" in scope.get("extensions", {}): - if message["type"] == "http.response.template": - await self._orig_send(message) - message = await queue.get() - - assert message["type"] == "http.response.start" - - async def body_stream() -> typing.AsyncGenerator[bytes, None]: - while True: - message = await queue.get() - if message is None: - break - assert message["type"] == "http.response.body" - yield message["body"] - task.result() - - response = StreamingResponse( - status_code=message["status"], content=body_stream() - ) - response.raw_headers = message["headers"] - return response + else: + await self.app(scope, receive, send)