forked from encode/starlette
-
Notifications
You must be signed in to change notification settings - Fork 0
/
websockets.py
156 lines (131 loc) · 5.59 KB
/
websockets.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
import enum
import json
import typing
from starlette.requests import HTTPConnection
from starlette.types import Message, Receive, Scope, Send
@enum.unique
class WebSocketState(enum.Enum):
    """Connection phases tracked separately for the client and the application."""

    # Initial phase: assigned to both sides when a WebSocket is constructed.
    CONNECTING = 0
    # Entered once the connect message is received / the accept message is sent.
    CONNECTED = 1
    # Terminal phase: a disconnect was received or a close was sent.
    DISCONNECTED = 2
class WebSocketDisconnect(Exception):
    """
    Raised when a "websocket.disconnect" message is received.

    Attributes:
        code: The close code taken from the disconnect message
            (1000 when none is supplied).
    """

    def __init__(self, code: int = 1000) -> None:
        # Only the close code is recorded on the instance.
        self.code = code
class WebSocket(HTTPConnection):
    """
    A single WebSocket connection built from an ASGI ``websocket`` scope.

    Two independent state machines are tracked, both starting in
    ``WebSocketState.CONNECTING``:

    * ``client_state`` advances as messages are received from the client.
    * ``application_state`` advances as messages are sent by the application.

    NOTE: the precondition checks in this class are ``assert`` statements, so
    they are skipped when Python runs with ``-O``. They are kept as asserts so
    that callers continue to see ``AssertionError``.
    """

    def __init__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """
        Wrap the ASGI ``scope``/``receive``/``send`` triple.

        Raises:
            AssertionError: if ``scope["type"]`` is not ``"websocket"``.
        """
        super().__init__(scope)
        assert scope["type"] == "websocket"
        self._receive = receive
        self._send = send
        self.client_state = WebSocketState.CONNECTING
        self.application_state = WebSocketState.CONNECTING

    async def receive(self) -> Message:
        """
        Receive ASGI websocket messages, ensuring valid state transitions.

        Returns:
            The raw ASGI message obtained from the underlying receive callable.

        Raises:
            AssertionError: if the message type is not valid for the current
                client state.
            RuntimeError: if called after a disconnect message was received.
        """
        if self.client_state == WebSocketState.CONNECTING:
            # The very first message must be the connection request.
            message = await self._receive()
            message_type = message["type"]
            assert message_type == "websocket.connect"
            self.client_state = WebSocketState.CONNECTED
            return message
        elif self.client_state == WebSocketState.CONNECTED:
            message = await self._receive()
            message_type = message["type"]
            assert message_type in {"websocket.receive", "websocket.disconnect"}
            if message_type == "websocket.disconnect":
                self.client_state = WebSocketState.DISCONNECTED
            return message
        else:
            raise RuntimeError(
                'Cannot call "receive" once a disconnect message has been received.'
            )

    async def send(self, message: Message) -> None:
        """
        Send ASGI websocket messages, ensuring valid state transitions.

        Raises:
            AssertionError: if the message type is not valid for the current
                application state.
            RuntimeError: if called after a close message was sent.
        """
        if self.application_state == WebSocketState.CONNECTING:
            # Before the handshake completes only accept/close are allowed.
            message_type = message["type"]
            assert message_type in {"websocket.accept", "websocket.close"}
            if message_type == "websocket.close":
                self.application_state = WebSocketState.DISCONNECTED
            else:
                self.application_state = WebSocketState.CONNECTED
            await self._send(message)
        elif self.application_state == WebSocketState.CONNECTED:
            message_type = message["type"]
            assert message_type in {"websocket.send", "websocket.close"}
            if message_type == "websocket.close":
                self.application_state = WebSocketState.DISCONNECTED
            await self._send(message)
        else:
            raise RuntimeError('Cannot call "send" once a close message has been sent.')

    async def accept(
        self,
        subprotocol: typing.Optional[str] = None,
        headers: typing.Optional[typing.Iterable[typing.Tuple[bytes, bytes]]] = None,
    ) -> None:
        """
        Accept the connection by sending a "websocket.accept" message.

        Args:
            subprotocol: Subprotocol to report in the accept message, if any.
            headers: Extra ``(name, value)`` byte pairs to include in the
                accept message. ``None`` is sent as an empty list.
        """
        if self.client_state == WebSocketState.CONNECTING:
            # If we haven't yet seen the 'connect' message, then wait for it first.
            await self.receive()
        await self.send(
            {
                "type": "websocket.accept",
                "subprotocol": subprotocol,
                # Never forward ``None`` here: the message field is meant to be
                # an iterable of header pairs, and a consumer that iterates it
                # would fail on ``None``.
                "headers": [] if headers is None else headers,
            }
        )

    def _raise_on_disconnect(self, message: Message) -> None:
        """Raise ``WebSocketDisconnect`` if ``message`` is a disconnect message."""
        if message["type"] == "websocket.disconnect":
            raise WebSocketDisconnect(message["code"])

    async def receive_text(self) -> str:
        """
        Receive one message and return its ``"text"`` payload.

        Raises:
            AssertionError: if the application has not accepted the connection.
            WebSocketDisconnect: if a disconnect message is received instead.
        """
        assert self.application_state == WebSocketState.CONNECTED
        message = await self.receive()
        self._raise_on_disconnect(message)
        return message["text"]

    async def receive_bytes(self) -> bytes:
        """
        Receive one message and return its ``"bytes"`` payload.

        Raises:
            AssertionError: if the application has not accepted the connection.
            WebSocketDisconnect: if a disconnect message is received instead.
        """
        assert self.application_state == WebSocketState.CONNECTED
        message = await self.receive()
        self._raise_on_disconnect(message)
        return message["bytes"]

    async def receive_json(self, mode: str = "text") -> typing.Any:
        """
        Receive one message and decode its payload as JSON.

        Args:
            mode: ``"text"`` reads the ``"text"`` payload; ``"binary"`` reads
                the ``"bytes"`` payload and decodes it as UTF-8 first.

        Raises:
            AssertionError: if ``mode`` is invalid or the application has not
                accepted the connection.
            WebSocketDisconnect: if a disconnect message is received instead.
        """
        assert mode in ["text", "binary"]
        assert self.application_state == WebSocketState.CONNECTED
        message = await self.receive()
        self._raise_on_disconnect(message)
        if mode == "text":
            text = message["text"]
        else:
            text = message["bytes"].decode("utf-8")
        return json.loads(text)

    async def iter_text(self) -> typing.AsyncIterator[str]:
        """Yield text payloads until a disconnect ends the iteration."""
        try:
            while True:
                yield await self.receive_text()
        except WebSocketDisconnect:
            # A disconnect is the normal way for the stream to finish.
            pass

    async def iter_bytes(self) -> typing.AsyncIterator[bytes]:
        """Yield bytes payloads until a disconnect ends the iteration."""
        try:
            while True:
                yield await self.receive_bytes()
        except WebSocketDisconnect:
            # A disconnect is the normal way for the stream to finish.
            pass

    async def iter_json(self) -> typing.AsyncIterator[typing.Any]:
        """Yield JSON-decoded text payloads until a disconnect ends the iteration."""
        try:
            while True:
                yield await self.receive_json()
        except WebSocketDisconnect:
            # A disconnect is the normal way for the stream to finish.
            pass

    async def send_text(self, data: str) -> None:
        """Send ``data`` as the ``"text"`` payload of a "websocket.send" message."""
        await self.send({"type": "websocket.send", "text": data})

    async def send_bytes(self, data: bytes) -> None:
        """Send ``data`` as the ``"bytes"`` payload of a "websocket.send" message."""
        await self.send({"type": "websocket.send", "bytes": data})

    async def send_json(self, data: typing.Any, mode: str = "text") -> None:
        """
        Serialize ``data`` with ``json.dumps`` and send it.

        Args:
            data: Any value accepted by ``json.dumps``.
            mode: ``"text"`` sends the JSON string as a text payload;
                ``"binary"`` sends it UTF-8 encoded as a bytes payload.

        Raises:
            AssertionError: if ``mode`` is neither ``"text"`` nor ``"binary"``.
        """
        assert mode in ["text", "binary"]
        text = json.dumps(data)
        if mode == "text":
            await self.send({"type": "websocket.send", "text": text})
        else:
            await self.send({"type": "websocket.send", "bytes": text.encode("utf-8")})

    async def close(self, code: int = 1000) -> None:
        """Send a "websocket.close" message carrying ``code``."""
        await self.send({"type": "websocket.close", "code": code})
class WebSocketClose:
    """
    Callable taking ``(scope, receive, send)`` that closes the WebSocket.

    When awaited it sends a single "websocket.close" message carrying the
    close code chosen at construction time; ``scope`` and ``receive`` are
    not used.
    """

    def __init__(self, code: int = 1000) -> None:
        # Close code to report in the close message.
        self.code = code

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        close_message = {"type": "websocket.close", "code": self.code}
        await send(close_message)