From f4a7f461ba7c9e6a7d383bab598bfd39b0c6acae Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Wed, 2 Nov 2022 08:42:01 +0100 Subject: [PATCH 1/7] Add type annotation to `wsproto_impl.py` --- setup.cfg | 1 + uvicorn/protocols/websockets/wsproto_impl.py | 83 +++++++++++++------- 2 files changed, 55 insertions(+), 29 deletions(-) diff --git a/setup.cfg b/setup.cfg index e6a1b432d..5358b1335 100644 --- a/setup.cfg +++ b/setup.cfg @@ -37,6 +37,7 @@ files = uvicorn/protocols/http/__init__.py, uvicorn/protocols/websockets/__init__.py, uvicorn/protocols/websockets/websockets_impl.py, + uvicorn/protocols/websockets/wsproto_impl.py, uvicorn/protocols/http/h11_impl.py, uvicorn/protocols/http/httptools_impl.py, tests/middleware/test_wsgi.py, diff --git a/uvicorn/protocols/websockets/wsproto_impl.py b/uvicorn/protocols/websockets/wsproto_impl.py index 3241a1551..e978da5b7 100644 --- a/uvicorn/protocols/websockets/wsproto_impl.py +++ b/uvicorn/protocols/websockets/wsproto_impl.py @@ -1,5 +1,6 @@ import asyncio import logging +import typing from urllib.parse import unquote import h11 @@ -9,6 +10,7 @@ from wsproto.extensions import PerMessageDeflate from wsproto.utilities import RemoteProtocolError +from uvicorn.config import Config from uvicorn.logging import TRACE_LOG_LEVEL from uvicorn.protocols.utils import ( get_local_addr, @@ -16,10 +18,31 @@ get_remote_addr, is_ssl, ) +from uvicorn.server import ServerState + +if typing.TYPE_CHECKING: + from asgiref.typing import ( + ASGISendEvent, + WebSocketConnectEvent, + WebSocketDisconnectEvent, + WebSocketReceiveEvent, + WebSocketScope, + ) + + WebSocketEvent = typing.Union[ + "WebSocketReceiveEvent", + "WebSocketDisconnectEvent", + "WebSocketConnectEvent", + ] class WSProtocol(asyncio.Protocol): - def __init__(self, config, server_state, _loop=None): + def __init__( + self, + config: Config, + server_state: ServerState, + _loop: typing.Optional[asyncio.AbstractEventLoop] = None, + ) -> None: if not config.loaded: config.load() @@ -35,14 +58,13 @@ def __init__(self, config, server_state, _loop=None): self.default_headers = server_state.default_headers # Connection state - self.transport = None - self.server = None - self.client = None - self.scheme = None + self.transport: asyncio.Transport = None # type: ignore[assignment] + self.server: typing.Optional[typing.Tuple[str, int]] = None + self.client: typing.Optional[typing.Tuple[str, int]] = None + self.scheme: typing.Literal["wss", "ws"] = None # type: ignore[assignment] # WebSocket state - self.connect_event = None - self.queue = asyncio.Queue() + self.queue: asyncio.Queue["WebSocketEvent"] = asyncio.Queue() self.handshake_complete = False self.close_sent = False @@ -58,7 +80,9 @@ def __init__(self, config, server_state, _loop=None): # Protocol interface - def connection_made(self, transport): + def connection_made( # type: ignore[override] + self, transport: asyncio.Transport + ) -> None: self.connections.add(self) self.transport = transport self.server = get_local_addr(transport) @@ -66,35 +90,36 @@ def connection_made(self, transport): self.scheme = "wss" if is_ssl(transport) else "ws" if self.logger.level <= TRACE_LOG_LEVEL: - prefix = "%s:%d - " % tuple(self.client) if self.client else "" + prefix = "%s:%d - " % self.client if self.client else "" self.logger.log(TRACE_LOG_LEVEL, "%sWebSocket connection made", prefix) - def connection_lost(self, exc): + def connection_lost(self, exc: typing.Optional[Exception]) -> None: code = 1005 if self.handshake_complete else 1006 self.queue.put_nowait({"type": "websocket.disconnect", "code": code}) self.connections.remove(self) if self.logger.level <= TRACE_LOG_LEVEL: - prefix = "%s:%d - " % tuple(self.client) if self.client else "" + prefix = "%s:%d - " % self.client if self.client else "" self.logger.log(TRACE_LOG_LEVEL, "%sWebSocket connection lost", prefix) self.handshake_complete = True if exc is None: self.transport.close() - def eof_received(self): + def eof_received(self) -> None: pass - def data_received(self, data): + def data_received(self, data: bytes) -> None: try: self.conn.receive_data(data) except RemoteProtocolError as err: - self.transport.write(self.conn.send(err.event_hint)) + # TODO: Remove `type: ignore` when wsproto fixes the type annotation. + self.transport.write(self.conn.send(err.event_hint)) # type: ignore[arg-type] # noqa: E501 self.transport.close() else: self.handle_events() - def handle_events(self): + def handle_events(self) -> None: for event in self.conn.events(): if isinstance(event, events.Request): self.handle_connect(event) @@ -107,19 +132,19 @@ def handle_events(self): elif isinstance(event, events.Ping): self.handle_ping(event) - def pause_writing(self): + def pause_writing(self) -> None: """ Called by the transport when the write buffer exceeds the high water mark. """ self.writable.clear() - def resume_writing(self): + def resume_writing(self) -> None: """ Called by the transport when the write buffer drops below the low water mark. """ self.writable.set() - def shutdown(self): + def shutdown(self) -> None: if self.handshake_complete: self.queue.put_nowait({"type": "websocket.disconnect", "code": 1012}) output = self.conn.send(wsproto.events.CloseConnection(code=1012)) @@ -128,17 +153,16 @@ def shutdown(self): self.send_500_response() self.transport.close() - def on_task_complete(self, task): + def on_task_complete(self, task: asyncio.Task) -> None: self.tasks.discard(task) # Event handlers - def handle_connect(self, event): - self.connect_event = event + def handle_connect(self, event: events.Request) -> None: headers = [(b"host", event.host.encode())] headers += [(key.lower(), value) for key, value in event.extra_headers] raw_path, _, query_string = event.target.partition("?") - self.scope = { + self.scope: "WebSocketScope" = { "type": "websocket", "asgi": {"version": self.config.asgi_version, "spec_version": "2.3"}, "http_version": "1.1", @@ -151,6 +175,7 @@ def handle_connect(self, event): "query_string": query_string.encode("ascii"), "headers": headers, "subprotocols": event.subprotocols, + "extensions": None, } self.queue.put_nowait({"type": "websocket.connect"}) task = self.loop.create_task(self.run_asgi()) @@ -166,7 +191,7 @@ def handle_text(self, event): self.read_paused = True self.transport.pause_reading() - def handle_bytes(self, event): + def handle_bytes(self, event: events.BytesMessage) -> None: self.bytes += event.data # todo: we may want to guard the size of self.bytes and self.text if event.message_finished: @@ -176,16 +201,16 @@ def handle_bytes(self, event): self.read_paused = True self.transport.pause_reading() - def handle_close(self, event): + def handle_close(self, event: events.CloseConnection) -> None: if self.conn.state == ConnectionState.REMOTE_CLOSING: self.transport.write(self.conn.send(event.response())) self.queue.put_nowait({"type": "websocket.disconnect", "code": event.code}) self.transport.close() - def handle_ping(self, event): + def handle_ping(self, event: events.Ping) -> None: self.transport.write(self.conn.send(event.response())) - def send_500_response(self): + def send_500_response(self) -> None: headers = [ (b"content-type", b"text/plain; charset=utf-8"), (b"connection", b"close"), @@ -203,7 +228,7 @@ def send_500_response(self): output += self.conn.send(msg) self.transport.write(output) - async def run_asgi(self): + async def run_asgi(self) -> None: try: result = await self.app(self.scope, self.receive, self.send) except BaseException: @@ -222,7 +247,7 @@ async def run_asgi(self): self.logger.error(msg, result) self.transport.close() - async def send(self, message): + async def send(self, message: "ASGISendEvent") -> None: await self.writable.wait() message_type = message["type"] @@ -303,7 +328,7 @@ async def send(self, message): msg = "Unexpected ASGI message '%s', after sending 'websocket.close'." raise RuntimeError(msg % message_type) - async def receive(self): + async def receive(self) -> "WebSocketEvent": message = await self.queue.get() if self.read_paused and self.queue.empty(): self.read_paused = False From 6a2dfec50d15f5d7b77f07f2225b1b5b7d0b0a73 Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Sun, 20 Nov 2022 13:47:09 +0100 Subject: [PATCH 2/7] add typpeddict ignore --- uvicorn/protocols/websockets/wsproto_impl.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/uvicorn/protocols/websockets/wsproto_impl.py b/uvicorn/protocols/websockets/wsproto_impl.py index e978da5b7..15ce6beea 100644 --- a/uvicorn/protocols/websockets/wsproto_impl.py +++ b/uvicorn/protocols/websockets/wsproto_impl.py @@ -182,10 +182,14 @@ def handle_connect(self, event: events.Request) -> None: task.add_done_callback(self.on_task_complete) self.tasks.add(task) - def handle_text(self, event): + def handle_text(self, event: events.TextMessage) -> None: self.text += event.data if event.message_finished: - self.queue.put_nowait({"type": "websocket.receive", "text": self.text}) + msg: "WebSocketReceiveEvent" = { # type: ignore[typeddict-item] + "type": "websocket.receive", + "text": self.text, + } + self.queue.put_nowait(msg) self.text = "" if not self.read_paused: self.read_paused = True @@ -195,7 +199,11 @@ def handle_bytes(self, event: events.BytesMessage) -> None: self.bytes += event.data # todo: we may want to guard the size of self.bytes and self.text if event.message_finished: - self.queue.put_nowait({"type": "websocket.receive", "bytes": self.bytes}) + msg: "WebSocketReceiveEvent" = { # type: ignore[typeddict-item] + "type": "websocket.receive", + "bytes": self.bytes, + } + self.queue.put_nowait(msg) self.bytes = b"" if not self.read_paused: self.read_paused = True From b7f49013c8848c03819d835ad53442a52cac1624 Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Sun, 20 Nov 2022 13:52:41 +0100 Subject: [PATCH 3/7] add websocket accept type --- uvicorn/protocols/websockets/wsproto_impl.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/uvicorn/protocols/websockets/wsproto_impl.py b/uvicorn/protocols/websockets/wsproto_impl.py index 15ce6beea..348e7342d 100644 --- a/uvicorn/protocols/websockets/wsproto_impl.py +++ b/uvicorn/protocols/websockets/wsproto_impl.py @@ -23,6 +23,7 @@ if typing.TYPE_CHECKING: from asgiref.typing import ( ASGISendEvent, + WebSocketAcceptEvent, WebSocketConnectEvent, WebSocketDisconnectEvent, WebSocketReceiveEvent, @@ -262,6 +263,7 @@ async def send(self, message: "ASGISendEvent") -> None: if not self.handshake_complete: if message_type == "websocket.accept": + message = typing.cast("WebSocketAcceptEvent", message) self.logger.info( '%s - "WebSocket %s" [accepted]', self.scope["client"], From b5c367d280c69fd3df2c59f192f02053ae568b8f Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Sun, 20 Nov 2022 13:55:18 +0100 Subject: [PATCH 4/7] add list extension --- uvicorn/protocols/websockets/wsproto_impl.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/uvicorn/protocols/websockets/wsproto_impl.py b/uvicorn/protocols/websockets/wsproto_impl.py index 348e7342d..0045e7954 100644 --- a/uvicorn/protocols/websockets/wsproto_impl.py +++ b/uvicorn/protocols/websockets/wsproto_impl.py @@ -7,7 +7,7 @@ import wsproto from wsproto import ConnectionType, events from wsproto.connection import ConnectionState -from wsproto.extensions import PerMessageDeflate +from wsproto.extensions import PerMessageDeflate, Extension from wsproto.utilities import RemoteProtocolError from uvicorn.config import Config @@ -271,7 +271,7 @@ async def send(self, message: "ASGISendEvent") -> None: ) subprotocol = message.get("subprotocol") extra_headers = self.default_headers + list(message.get("headers", [])) - extensions = [] + extensions: typing.List[Extension] = [] if self.config.ws_per_message_deflate: extensions.append(PerMessageDeflate()) if not self.transport.is_closing(): From 9cee0b4269792937a9753210680754ab61fcecf7 Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Sun, 20 Nov 2022 14:01:17 +0100 Subject: [PATCH 5/7] create event variable --- uvicorn/protocols/websockets/wsproto_impl.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/uvicorn/protocols/websockets/wsproto_impl.py b/uvicorn/protocols/websockets/wsproto_impl.py index 0045e7954..831d799b7 100644 --- a/uvicorn/protocols/websockets/wsproto_impl.py +++ b/uvicorn/protocols/websockets/wsproto_impl.py @@ -7,7 +7,7 @@ import wsproto from wsproto import ConnectionType, events from wsproto.connection import ConnectionState -from wsproto.extensions import PerMessageDeflate, Extension +from wsproto.extensions import Extension, PerMessageDeflate from wsproto.utilities import RemoteProtocolError from uvicorn.config import Config @@ -294,8 +294,8 @@ async def send(self, message: "ASGISendEvent") -> None: ) self.handshake_complete = True self.close_sent = True - msg = events.RejectConnection(status_code=403, headers=[]) - output = self.conn.send(msg) + event = events.RejectConnection(status_code=403, headers=[]) + output = self.conn.send(event) self.transport.write(output) self.transport.close() From 94728bd51504deec8757b74ac126750568354f21 Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Sun, 20 Nov 2022 14:07:54 +0100 Subject: [PATCH 6/7] add send and close event --- uvicorn/protocols/websockets/wsproto_impl.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/uvicorn/protocols/websockets/wsproto_impl.py b/uvicorn/protocols/websockets/wsproto_impl.py index 831d799b7..0fe9e48e3 100644 --- a/uvicorn/protocols/websockets/wsproto_impl.py +++ b/uvicorn/protocols/websockets/wsproto_impl.py @@ -24,10 +24,12 @@ from asgiref.typing import ( ASGISendEvent, WebSocketAcceptEvent, + WebSocketCloseEvent, WebSocketConnectEvent, WebSocketDisconnectEvent, WebSocketReceiveEvent, WebSocketScope, + WebSocketSendEvent, ) WebSocketEvent = typing.Union[ @@ -308,14 +310,18 @@ async def send(self, message: "ASGISendEvent") -> None: elif not self.close_sent: if message_type == "websocket.send": + message = typing.cast("WebSocketSendEvent", message) bytes_data = message.get("bytes") text_data = message.get("text") data = text_data if bytes_data is None else bytes_data - output = self.conn.send(wsproto.events.Message(data=data)) + output = self.conn.send( + wsproto.events.Message(data=data) # type: ignore[type-var] + ) if not self.transport.is_closing(): self.transport.write(output) elif message_type == "websocket.close": + message = typing.cast("WebSocketCloseEvent", message) self.close_sent = True code = message.get("code", 1000) reason = message.get("reason", "") or "" From 15d73a3748e03c9ca6f3efeb0a72e38adb2e64a1 Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Sun, 20 Nov 2022 14:21:33 +0100 Subject: [PATCH 7/7] add literal --- uvicorn/protocols/websockets/wsproto_impl.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/uvicorn/protocols/websockets/wsproto_impl.py b/uvicorn/protocols/websockets/wsproto_impl.py index 0fe9e48e3..f2677e004 100644 --- a/uvicorn/protocols/websockets/wsproto_impl.py +++ b/uvicorn/protocols/websockets/wsproto_impl.py @@ -1,5 +1,6 @@ import asyncio import logging +import sys import typing from urllib.parse import unquote @@ -38,6 +39,11 @@ "WebSocketConnectEvent", ] +if sys.version_info < (3, 8): # pragma: py-gte-38 + from typing_extensions import Literal +else: # pragma: py-lt-38 + from typing import Literal + class WSProtocol(asyncio.Protocol): def __init__( @@ -64,7 +70,7 @@ def __init__( self.transport: asyncio.Transport = None # type: ignore[assignment] self.server: typing.Optional[typing.Tuple[str, int]] = None self.client: typing.Optional[typing.Tuple[str, int]] = None - self.scheme: typing.Literal["wss", "ws"] = None # type: ignore[assignment] + self.scheme: Literal["wss", "ws"] = None # type: ignore[assignment] # WebSocket state self.queue: asyncio.Queue["WebSocketEvent"] = asyncio.Queue()