/
debug.py
122 lines (100 loc) · 3.47 KB
/
debug.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import html
import traceback
from typing import TYPE_CHECKING, Union
if TYPE_CHECKING:
from asgiref.typing import (
ASGI3Application,
ASGIReceiveCallable,
ASGISendCallable,
ASGISendEvent,
HTTPResponseBodyEvent,
HTTPResponseStartEvent,
WWWScope,
)
class HTMLResponse:
    """Minimal ASGI response that sends ``content`` as a UTF-8 encoded HTML page.

    Args:
        content: Markup to send as the response body.
        status_code: HTTP status code placed in the ``http.response.start`` event.
    """

    def __init__(self, content: str, status_code: int):
        self.status_code = status_code
        self.content = content

    async def __call__(
        self,
        scope: "WWWScope",
        receive: "ASGIReceiveCallable",
        send: "ASGISendCallable",
    ) -> None:
        """Emit the response start event, then the whole body in a single event.

        ``scope`` and ``receive`` are accepted to satisfy the ASGI calling
        convention but are not read.
        """
        headers = [(b"content-type", b"text/html; charset=utf-8")]
        start_message: "HTTPResponseStartEvent" = {
            "type": "http.response.start",
            "status": self.status_code,
            "headers": headers,
        }
        await send(start_message)

        # The body is encoded only after the start event has been sent,
        # and is delivered in one event (more_body=False).
        payload = self.content.encode("utf-8")
        body_message: "HTTPResponseBodyEvent" = {
            "type": "http.response.body",
            "body": payload,
            "more_body": False,
        }
        await send(body_message)
class PlainTextResponse:
    """Minimal ASGI response that sends ``content`` as UTF-8 encoded plain text.

    Args:
        content: Text to send as the response body.
        status_code: HTTP status code placed in the ``http.response.start`` event.
    """

    def __init__(self, content: str, status_code: int):
        self.content, self.status_code = content, status_code

    async def __call__(
        self,
        scope: "WWWScope",
        receive: "ASGIReceiveCallable",
        send: "ASGISendCallable",
    ) -> None:
        """Send the start event followed by one final body event via ``send``.

        ``scope`` and ``receive`` are part of the ASGI signature only; they are
        not used here.
        """
        content_type = b"text/plain; charset=utf-8"
        opening: "HTTPResponseStartEvent" = {
            "type": "http.response.start",
            "status": self.status_code,
            "headers": [(b"content-type", content_type)],
        }
        await send(opening)

        closing: "HTTPResponseBodyEvent" = {
            "type": "http.response.body",
            "body": self.content.encode("utf-8"),
            # Single-shot response: no further body chunks follow.
            "more_body": False,
        }
        await send(closing)
def get_accept_header(scope: "WWWScope") -> str:
    """Return the request's ``Accept`` header value, or ``"*/*"`` if absent.

    Header names are compared against the lowercase byte string ``b"accept"``.
    When the request carries several ``Accept`` field lines they are combined
    into one value, separated by ``", "`` as HTTP permits for list-valued fields.

    Values are decoded as latin-1 rather than ASCII. Header bytes are
    client-controlled, and latin-1 maps every possible byte, so decoding can
    never fail. This matters because the function is used while building an
    error response, where a decode error would mask the original failure.

    Args:
        scope: ASGI connection scope; its optional ``"headers"`` entry is read
            as an iterable of ``(name, value)`` byte-string pairs.

    Returns:
        The combined Accept value, or ``"*/*"`` when no Accept header is present.
    """
    values = [
        value.decode("latin-1")
        for key, value in scope.get("headers", [])
        if key == b"accept"
    ]
    if not values:
        return "*/*"
    return ", ".join(values)
class DebugMiddleware:
    """ASGI middleware that turns unhandled HTTP errors into a 500 traceback page.

    Only ``http`` scopes are intercepted; every other scope type is passed
    straight through to the wrapped application.
    """

    def __init__(self, app: "ASGI3Application"):
        self.app = app

    async def __call__(
        self,
        scope: "WWWScope",
        receive: "ASGIReceiveCallable",
        send: "ASGISendCallable",
    ) -> None:
        """Run the wrapped app, reporting any exception it raises.

        If the app raises before it has sent ``http.response.start``, a 500
        response with the formatted traceback is sent instead. The response is
        HTML when the Accept header mentions ``text/html``, plain text
        otherwise. The exception is then re-raised.

        If the response had already started, the exception is re-raised
        without sending anything, because the status line is already out.

        Raises:
            Whatever the wrapped application raised, with its original
            ``__cause__``/``__context__`` chain preserved.
        """
        if scope["type"] != "http":
            return await self.app(scope, receive, send)

        response_started = False

        async def inner_send(message: "ASGISendEvent") -> None:
            # Track whether the app has begun its response; after that point a
            # replacement 500 response can no longer be sent.
            nonlocal response_started
            if message["type"] == "http.response.start":
                response_started = True
            await send(message)

        try:
            await self.app(scope, receive, inner_send)
        except BaseException:
            # NOTE(review): BaseException also catches cancellation and
            # KeyboardInterrupt, so those also trigger a 500 attempt; kept as-is.
            if response_started:
                # A bare ``raise`` leaves the exception untouched. The previous
                # ``raise exc from None`` erased its __cause__ and suppressed
                # its context, losing the chain in server logs.
                raise

            tb_text = traceback.format_exc()
            response: Union[HTMLResponse, PlainTextResponse]
            if "text/html" in get_accept_header(scope):
                content = (
                    "<html><body><h1>500 Server Error</h1><pre>%s</pre></body></html>"
                    % html.escape(tb_text)
                )
                response = HTMLResponse(content, status_code=500)
            else:
                response = PlainTextResponse(tb_text, status_code=500)
            await response(scope, receive, send)
            # Still inside the except clause (the handled exception is restored
            # after the await), so this re-raises the app's original exception.
            raise