-
-
Notifications
You must be signed in to change notification settings - Fork 855
/
base.py
126 lines (99 loc) Β· 4.43 KB
/
base.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import typing
import anyio
from starlette.requests import Request
from starlette.responses import Response, StreamingResponse
from starlette.types import ASGIApp, Message, Receive, Scope, Send
# The ``call_next`` callable handed to ``dispatch``: takes the request and
# resolves to the wrapped app's response.
RequestResponseEndpoint = typing.Callable[
    [Request],
    typing.Awaitable[Response],
]

# A standalone dispatch coroutine that may be passed to
# ``BaseHTTPMiddleware(app, dispatch=...)`` instead of subclassing.
DispatchFunction = typing.Callable[
    [Request, RequestResponseEndpoint],
    typing.Awaitable[Response],
]

# Result type of the awaitable run by ``_call_and_cancel``.
T = typing.TypeVar("T")
async def _call_and_cancel(
    func: typing.Callable[[], typing.Awaitable[T]], cancel_scope: anyio.CancelScope
) -> T:
    """Await ``func()``, then cancel ``cancel_scope`` and return the result.

    Used to race two awaitables inside one task group. Whichever completes
    first cancels the group's scope, abandoning the other. If ``func()``
    raises, the scope is not cancelled and the exception propagates.
    """
    outcome = await func()
    # Reached only when func() completed normally.
    cancel_scope.cancel()
    return outcome
class BaseHTTPMiddleware:
    """Base class for HTTP middleware written as a single ``dispatch`` coroutine.

    Override :meth:`dispatch`, or pass ``dispatch=`` to the constructor.
    ``dispatch`` receives the incoming :class:`Request` and a ``call_next``
    coroutine. ``call_next`` runs the wrapped ASGI app and returns its
    response as a :class:`StreamingResponse`. Non-``http`` scopes are
    forwarded to the wrapped app untouched.
    """

    def __init__(
        self, app: ASGIApp, dispatch: typing.Optional[DispatchFunction] = None
    ) -> None:
        """Wrap ``app``.

        If ``dispatch`` is given, it is used instead of the :meth:`dispatch`
        method.
        """
        self.app = app
        self.dispatch_func = self.dispatch if dispatch is None else dispatch

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """ASGI entry point.

        For ``http`` scopes this builds a :class:`Request`, awaits
        ``self.dispatch_func(request, call_next)`` and sends the returned
        response. Every other scope type is passed straight to ``self.app``.
        """
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        # Set once the middleware's response has been fully sent (see the end
        # of this method). After that, the wrapped app's receive() yields
        # ``http.disconnect`` and its further send() calls are dropped.
        response_sent = anyio.Event()

        async def call_next(request: Request) -> Response:
            """Run the wrapped app in the background and return its response.

            The status and headers come from the app's first
            (``http.response.start``) message. The body is streamed lazily
            from its subsequent ``http.response.body`` messages.

            Raises:
                Exception: whatever the wrapped app raised, either here (if it
                    failed before sending anything) or while the body is
                    being streamed.
                RuntimeError: the app finished without sending any message.
            """
            app_exc: typing.Optional[Exception] = None
            send_stream, recv_stream = anyio.create_memory_object_stream()

            async def receive_or_disconnect() -> Message:
                # receive() handed to the wrapped app: races request.receive()
                # against response_sent; whichever finishes first cancels the
                # other.
                if response_sent.is_set():
                    return {"type": "http.disconnect"}

                async with anyio.create_task_group() as receive_group:
                    # If response_sent fires first, cancel this group so the
                    # pending request.receive() below is abandoned.
                    receive_group.start_soon(
                        _call_and_cancel,
                        response_sent.wait,
                        receive_group.cancel_scope,
                    )
                    # If request.receive() returns first, cancel this group so
                    # the response_sent.wait() task is stopped and the block
                    # exits.
                    message = await _call_and_cancel(
                        request.receive, receive_group.cancel_scope
                    )

                # If response_sent won the race, ``message`` was never
                # assigned (its await was cancelled), so return before
                # reading it.
                if response_sent.is_set():
                    return {"type": "http.disconnect"}
                return message

            async def close_recv_stream_on_response_sent() -> None:
                # Closing the receiving end makes the app's later sends raise
                # BrokenResourceError, which send_no_error swallows.
                await response_sent.wait()
                recv_stream.close()

            async def send_no_error(message: Message) -> None:
                # send() handed to the wrapped app: forwards its ASGI messages
                # to call_next / body_stream through the memory stream.
                try:
                    await send_stream.send(message)
                except anyio.BrokenResourceError:
                    # recv_stream has been closed, i.e. response_sent has been set.
                    return

            async def coro() -> None:
                # Run the wrapped app, recording (not propagating) its
                # exception so it can be re-raised on the dispatch side.
                # Leaving ``async with send_stream`` closes it, which ends
                # iteration on recv_stream.
                nonlocal app_exc

                async with send_stream:
                    try:
                        await self.app(scope, receive_or_disconnect, send_no_error)
                    except Exception as exc:
                        app_exc = exc

            # ``task_group`` is the group opened at the end of ``__call__``.
            # call_next is only invoked inside it, and these tasks must outlive
            # call_next so the app can keep streaming the body.
            task_group.start_soon(close_recv_stream_on_response_sent)
            task_group.start_soon(coro)

            try:
                message = await recv_stream.receive()
            except anyio.EndOfStream:
                # The app closed its stream without sending a single message.
                if app_exc is not None:
                    raise app_exc
                raise RuntimeError("No response returned.")

            assert message["type"] == "http.response.start"

            async def body_stream() -> typing.AsyncGenerator[bytes, None]:
                # Yield non-empty body chunks until the app reports
                # more_body=False, then surface any exception the app raised.
                async with recv_stream:
                    async for body_message in recv_stream:
                        assert body_message["type"] == "http.response.body"
                        body = body_message.get("body", b"")
                        if body:
                            yield body
                        if not body_message.get("more_body", False):
                            break

                if app_exc is not None:
                    raise app_exc

            response = StreamingResponse(
                status_code=message["status"], content=body_stream()
            )
            response.raw_headers = message["headers"]
            return response

        async with anyio.create_task_group() as task_group:
            request = Request(scope, receive=receive)
            response = await self.dispatch_func(request, call_next)
            await response(scope, receive, send)
            # Must stay inside the group: the background tasks started by
            # call_next wait on this event, so the group could not exit
            # without it.
            response_sent.set()

    async def dispatch(
        self, request: Request, call_next: RequestResponseEndpoint
    ) -> Response:
        """Handle ``request`` and return the response to send.

        Call ``await call_next(request)`` to run the wrapped app. Subclasses
        must override this unless a ``dispatch`` function was passed to the
        constructor.
        """
        raise NotImplementedError()  # pragma: no cover