Skip to content

Commit

Permalink
Fix streaming issue and upgrade Starlette
Browse files Browse the repository at this point in the history
After a lot of debugging, I found out that streaming wasn't working
because of BaseHTTPMiddleware, which apparently collects the whole stream
into memory and then returns it all at once with the response.

That means that if you stream a lot of data, the request will not give any
answer until all the data has been collected into memory. This can take time
and can result in a read timeout, or can simply run out of memory.

Streaming was the main reason why I chose Starlette, and the single most
important thing didn't work. Not good. But at least I found out how to fix
it.

Related issues:

encode/starlette#1012
encode/starlette#472
  • Loading branch information
sirex committed Dec 10, 2020
1 parent a32a637 commit b71aba2
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 61 deletions.
2 changes: 1 addition & 1 deletion requirements-dev.txt
Expand Up @@ -115,7 +115,7 @@ sqlalchemy-utils==0.36.0 # via -r requirements-dev.in
sqlalchemy==1.3.11 # via -r requirements.in, alembic, sqlalchemy-utils
sqlparse==0.3.0 # via -r requirements-dev.in
sshpubkeys==3.1.0 # via moto
starlette==0.13.0 # via -r requirements.in
starlette==0.14.1 # via -r requirements.in
texttable==1.6.2 # via -r requirements.in
toposort==1.5 # via -r requirements.in
tornado==6.0.3 # via livereload, sphinx-autobuild
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Expand Up @@ -44,7 +44,7 @@ s3transfer==0.3.3 # via boto3
setuptools-scm==3.3.3 # via -r requirements.in
six==1.13.0 # via cryptography, multipledispatch, python-dateutil, python-multipart
sqlalchemy==1.3.11 # via -r requirements.in, alembic
starlette==0.13.0 # via -r requirements.in
starlette==0.14.1 # via -r requirements.in
texttable==1.6.2 # via -r requirements.in
toposort==1.5 # via -r requirements.in
tqdm==4.40.2 # via -r requirements.in
Expand Down
2 changes: 1 addition & 1 deletion spinta/api.py
Expand Up @@ -172,7 +172,7 @@ def init(context: Context):
]

middleware = [
Middleware(ContextMiddleware)
Middleware(ContextMiddleware, context=context)
]

exception_handlers = {
Expand Down
72 changes: 14 additions & 58 deletions spinta/middlewares.py
@@ -1,13 +1,12 @@
import asyncio
import typing
from starlette.types import ASGIApp
from starlette.types import Receive
from starlette.types import Scope
from starlette.types import Send

from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.requests import Request
from starlette.responses import Response, StreamingResponse
from starlette.types import Receive, Scope, Send
from spinta.components import Context


class ContextMiddleware(BaseHTTPMiddleware):
class ContextMiddleware:
"""Adds `request.state.context`.
There is a global `context`, where all heavy things are preloaded as
Expand All @@ -18,58 +17,15 @@ class ContextMiddleware(BaseHTTPMiddleware):
modified in each request without effecting global context.
"""

def __init__(self, app: ASGIApp, *, context: Context) -> None:
self.app = app
self.context = context

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
self._orig_send = send
if scope['type'] == 'http':
with scope['app'].state.context.fork('request') as context:
if scope["type"] in ["http", "websocket"]:
with self.context.fork('request') as context:
scope.setdefault('state', {})
scope['state']['context'] = context
await super().__call__(scope, receive, send)
else:
await super().__call__(scope, receive, send)

async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
return await call_next(request)

# TODO: Temporary fix for https://github.com/encode/starlette/issues/472
async def call_next(self, request: Request) -> Response:
loop = asyncio.get_event_loop()
queue = asyncio.Queue() # type: asyncio.Queue

scope = request.scope
receive = request.receive
send = queue.put

async def coro() -> None:
try:
await self.app(scope, receive, send)
finally:
await queue.put(None)

task = loop.create_task(coro())
message = await queue.get()
if message is None:
task.result()
raise RuntimeError("No response returned.")

if "http.response.template" in scope.get("extensions", {}):
if message["type"] == "http.response.template":
await self._orig_send(message)
message = await queue.get()

assert message["type"] == "http.response.start"

async def body_stream() -> typing.AsyncGenerator[bytes, None]:
while True:
message = await queue.get()
if message is None:
break
assert message["type"] == "http.response.body"
yield message["body"]
task.result()

response = StreamingResponse(
status_code=message["status"], content=body_stream()
)
response.raw_headers = message["headers"]
return response
else:
await self.app(scope, receive, send)

0 comments on commit b71aba2

Please sign in to comment.