/
base.py
62 lines (47 loc) · 2.23 KB
/
base.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import typing
import anyio
from starlette.requests import Request
from starlette.responses import Response, StreamingResponse
from starlette.types import ASGIApp, Receive, Scope, Send
# Type of the ``call_next`` coroutine function handed to a dispatch hook:
# given the request, it resolves to the wrapped application's response.
RequestResponseEndpoint = typing.Callable[
    [Request],
    typing.Awaitable[Response],
]

# Type of a dispatch hook: it receives the request together with
# ``call_next`` and resolves to the response that is sent to the client.
DispatchFunction = typing.Callable[
    [Request, RequestResponseEndpoint],
    typing.Awaitable[Response],
]
class BaseHTTPMiddleware:
    """Base class for ASGI middleware expressed as a single dispatch coroutine.

    Subclasses override :meth:`dispatch`, or a callable with the same
    signature is passed as ``dispatch`` to the constructor. For every scope
    whose type is ``"http"``, the dispatch hook receives a :class:`Request`
    and a ``call_next`` coroutine function. ``call_next`` runs the wrapped
    application in a background task and returns its response as a
    :class:`StreamingResponse`. All other scope types are forwarded to the
    wrapped application unchanged.
    """

    def __init__(
        self, app: ASGIApp, dispatch: typing.Optional[DispatchFunction] = None
    ) -> None:
        """Wrap ``app``.

        :param app: the downstream ASGI application.
        :param dispatch: optional dispatch callable. When omitted, this
            instance's own :meth:`dispatch` method is used.
        """
        self.app = app
        self.dispatch_func = self.dispatch if dispatch is None else dispatch

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """ASGI entry point: run the dispatch hook for HTTP scopes only."""
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        async def call_next(request: Request) -> Response:
            """Run the wrapped app in the background and stream its response.

            :param request: the request to forward to the wrapped app.
            :returns: a streaming response fed by the app's body messages.
            :raises Exception: whatever the wrapped app raised, if it failed
                before sending ``http.response.start``.
            :raises RuntimeError: if the app finished without sending any
                message.
            """
            # Exception raised by the wrapped app, if any. It is captured and
            # re-raised in this task so the real error is not replaced by
            # "No response returned." or tangled into a task-group failure.
            app_exc: typing.Optional[Exception] = None
            send_stream, recv_stream = anyio.create_memory_object_stream()

            async def coro() -> None:
                nonlocal app_exc
                # Leaving this context closes ``send_stream``, which ends
                # ``recv_stream`` for the consumer, even on failure.
                async with send_stream:
                    try:
                        await self.app(scope, request.receive, send_stream.send)
                    except Exception as exc:
                        app_exc = exc

            task_group.start_soon(coro)

            try:
                message = await recv_stream.receive()
            except anyio.EndOfStream:
                # The stream closed before the app sent anything.
                if app_exc is not None:
                    raise app_exc
                raise RuntimeError("No response returned.")
            assert message["type"] == "http.response.start"

            async def body_stream() -> typing.AsyncGenerator[bytes, None]:
                async with recv_stream:
                    async for body_message in recv_stream:
                        assert body_message["type"] == "http.response.body"
                        yield body_message.get("body", b"")
                # A failure after the response started surfaces here, once the
                # app's task has closed the stream.
                if app_exc is not None:
                    raise app_exc

            response = StreamingResponse(
                status_code=message["status"], content=body_stream()
            )
            # "headers" is optional on http.response.start per the ASGI spec.
            response.raw_headers = message.get("headers", [])
            return response

        async with anyio.create_task_group() as task_group:
            request = Request(scope, receive=receive)
            response = await self.dispatch_func(request, call_next)
            await response(scope, receive, send)
            # The response has been sent; cancel any background task that
            # is still running so the task group can exit.
            task_group.cancel_scope.cancel()

    async def dispatch(
        self, request: Request, call_next: RequestResponseEndpoint
    ) -> Response:
        """Handle one HTTP request. Subclasses must override this.

        :param request: the incoming HTTP request.
        :param call_next: coroutine function that forwards a request to the
            wrapped app and returns its response.
        :returns: the response to send to the client.
        """
        raise NotImplementedError()  # pragma: no cover