Skip to content

Commit

Permalink
Add type hint to websockets_impl.py (#1308)
Browse files Browse the repository at this point in the history
* Add type hint to websockets_impl

* Fix mypy issues

* Fix flake8 issues

* Add Literal import

* Add extra headers type

* readd subprotocol

* remove asgiref types from runtime

* Comply with comments
  • Loading branch information
Kludex committed Oct 29, 2022
1 parent 7424aaf commit 0378913
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 33 deletions.
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ files =
uvicorn/protocols/__init__.py,
uvicorn/protocols/http/__init__.py,
uvicorn/protocols/websockets/__init__.py,
uvicorn/protocols/websockets/websockets_impl.py,
uvicorn/protocols/http/h11_impl.py,
uvicorn/protocols/http/httptools_impl.py,
tests/middleware/test_wsgi.py,
Expand Down
119 changes: 86 additions & 33 deletions uvicorn/protocols/websockets/websockets_impl.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,63 @@
import asyncio
import http
import logging
import sys
from typing import TYPE_CHECKING, Any, List, Optional, Sequence, Tuple, Union, cast
from urllib.parse import unquote

import websockets
from websockets.datastructures import Headers
from websockets.exceptions import ConnectionClosed
from websockets.extensions.permessage_deflate import ServerPerMessageDeflateFactory
from websockets.legacy.server import HTTPResponse
from websockets.server import WebSocketServerProtocol
from websockets.typing import Subprotocol

from uvicorn.config import Config
from uvicorn.logging import TRACE_LOG_LEVEL
from uvicorn.protocols.utils import get_local_addr, get_remote_addr, is_ssl
from uvicorn.server import ServerState

if sys.version_info < (3, 8):
from typing_extensions import Literal
else:
from typing import Literal

if TYPE_CHECKING:
from asgiref.typing import (
ASGISendEvent,
WebSocketAcceptEvent,
WebSocketCloseEvent,
WebSocketConnectEvent,
WebSocketDisconnectEvent,
WebSocketReceiveEvent,
WebSocketScope,
WebSocketSendEvent,
)


class Server:
closing = False

def register(self, ws):
def register(self, ws: WebSocketServerProtocol) -> None:
pass

def unregister(self, ws):
def unregister(self, ws: WebSocketServerProtocol) -> None:
pass

def is_serving(self):
def is_serving(self) -> bool:
return not self.closing


class WebSocketProtocol(websockets.WebSocketServerProtocol):
def __init__(self, config, server_state, _loop=None):
class WebSocketProtocol(WebSocketServerProtocol):
extra_headers: List[Tuple[str, str]]

def __init__(
self,
config: Config,
server_state: ServerState,
_loop: Optional[asyncio.AbstractEventLoop] = None,
):
if not config.loaded:
config.load()

Expand All @@ -38,30 +71,30 @@ def __init__(self, config, server_state, _loop=None):
self.tasks = server_state.tasks

# Connection state
self.transport = None
self.server = None
self.client = None
self.scheme = None
self.transport: asyncio.Transport = None # type: ignore[assignment]
self.server: Optional[Tuple[str, int]] = None
self.client: Optional[Tuple[str, int]] = None
self.scheme: Literal["wss", "ws"] = None # type: ignore[assignment]

# Connection events
self.scope = None
self.scope: WebSocketScope = None # type: ignore[assignment]
self.handshake_started_event = asyncio.Event()
self.handshake_completed_event = asyncio.Event()
self.closed_event = asyncio.Event()
self.initial_response = None
self.initial_response: Optional[HTTPResponse] = None
self.connect_sent = False
self.accepted_subprotocol = None
self.transfer_data_task = None
self.accepted_subprotocol: Optional[Subprotocol] = None
self.transfer_data_task: asyncio.Task = None # type: ignore[assignment]

self.ws_server = Server()
self.ws_server: Server = Server() # type: ignore[assignment]

extensions = []
if self.config.ws_per_message_deflate:
extensions.append(ServerPerMessageDeflateFactory())

super().__init__(
ws_handler=self.ws_handler,
ws_server=self.ws_server,
ws_server=self.ws_server, # type: ignore[arg-type]
max_size=self.config.ws_max_size,
ping_interval=self.config.ws_ping_interval,
ping_timeout=self.config.ws_ping_timeout,
Expand All @@ -70,39 +103,43 @@ def __init__(self, config, server_state, _loop=None):
extra_headers=[],
)

def connection_made(self, transport):
def connection_made( # type: ignore[override]
self, transport: asyncio.Transport
) -> None:
self.connections.add(self)
self.transport = transport
self.server = get_local_addr(transport)
self.client = get_remote_addr(transport)
self.scheme = "wss" if is_ssl(transport) else "ws"

if self.logger.isEnabledFor(TRACE_LOG_LEVEL):
prefix = "%s:%d - " % tuple(self.client) if self.client else ""
prefix = "%s:%d - " % self.client if self.client else ""
self.logger.log(TRACE_LOG_LEVEL, "%sWebSocket connection made", prefix)

super().connection_made(transport)

def connection_lost(self, exc):
def connection_lost(self, exc: Optional[Exception]) -> None:
self.connections.remove(self)

if self.logger.isEnabledFor(TRACE_LOG_LEVEL):
prefix = "%s:%d - " % tuple(self.client) if self.client else ""
prefix = "%s:%d - " % self.client if self.client else ""
self.logger.log(TRACE_LOG_LEVEL, "%sWebSocket connection lost", prefix)

self.handshake_completed_event.set()
super().connection_lost(exc)
if exc is None:
self.transport.close()

def shutdown(self):
def shutdown(self) -> None:
self.ws_server.closing = True
self.transport.close()

def on_task_complete(self, task):
def on_task_complete(self, task: asyncio.Task) -> None:
self.tasks.discard(task)

async def process_request(self, path, headers):
async def process_request(
self, path: str, headers: Headers
) -> Optional[HTTPResponse]:
"""
This hook is called to determine if the websocket should return
an HTTP response and close.
Expand All @@ -124,7 +161,7 @@ async def process_request(self, path, headers):
for name, value in headers.raw_items()
]

self.scope = {
self.scope = { # type: ignore[typeddict-item]
"type": "websocket",
"asgi": {"version": self.config.asgi_version, "spec_version": "2.3"},
"http_version": "1.1",
Expand All @@ -144,14 +181,16 @@ async def process_request(self, path, headers):
await self.handshake_started_event.wait()
return self.initial_response

def process_subprotocol(self, headers, available_subprotocols):
def process_subprotocol(
self, headers: Headers, available_subprotocols: Optional[Sequence[Subprotocol]]
) -> Optional[Subprotocol]:
"""
We override the standard 'process_subprotocol' behavior here so that
we return whatever subprotocol is sent in the 'accept' message.
"""
return self.accepted_subprotocol

def send_500_response(self):
def send_500_response(self) -> None:
msg = b"Internal Server Error"
content = [
b"HTTP/1.1 500 Internal Server Error\r\n"
Expand All @@ -166,7 +205,9 @@ def send_500_response(self):
# itself (see https://github.com/encode/uvicorn/issues/920)
self.handshake_started_event.set()

async def ws_handler(self, protocol, path):
async def ws_handler( # type: ignore[override]
self, protocol: WebSocketServerProtocol, path: str
) -> Any:
"""
This is the main handler function for the 'websockets' implementation
to call into. We just wait for close then return, and instead allow
Expand All @@ -175,7 +216,7 @@ async def ws_handler(self, protocol, path):
self.handshake_completed_event.set()
await self.closed_event.wait()

async def run_asgi(self):
async def run_asgi(self) -> None:
"""
Wrapper around the ASGI callable, handling exceptions and unexpected
termination states.
Expand Down Expand Up @@ -204,18 +245,21 @@ async def run_asgi(self):
await self.handshake_completed_event.wait()
self.transport.close()

async def asgi_send(self, message):
async def asgi_send(self, message: "ASGISendEvent") -> None:
message_type = message["type"]

if not self.handshake_started_event.is_set():
if message_type == "websocket.accept":
message = cast("WebSocketAcceptEvent", message)
self.logger.info(
'%s - "WebSocket %s" [accepted]',
self.scope["client"],
self.scope["path"],
)
self.initial_response = None
self.accepted_subprotocol = message.get("subprotocol")
self.accepted_subprotocol = cast(
Optional[Subprotocol], message.get("subprotocol")
)
if "headers" in message:
self.extra_headers.extend(
# ASGI spec requires bytes
Expand All @@ -226,6 +270,7 @@ async def asgi_send(self, message):
self.handshake_started_event.set()

elif message_type == "websocket.close":
message = cast("WebSocketCloseEvent", message)
self.logger.info(
'%s - "WebSocket %s" 403',
self.scope["client"],
Expand All @@ -246,12 +291,14 @@ async def asgi_send(self, message):
await self.handshake_completed_event.wait()

if message_type == "websocket.send":
message = cast("WebSocketSendEvent", message)
bytes_data = message.get("bytes")
text_data = message.get("text")
data = text_data if bytes_data is None else bytes_data
await self.send(data)
await self.send(data) # type: ignore[arg-type]

elif message_type == "websocket.close":
message = cast("WebSocketCloseEvent", message)
code = message.get("code", 1000)
reason = message.get("reason", "") or ""
await self.close(code, reason)
Expand All @@ -268,7 +315,11 @@ async def asgi_send(self, message):
msg = "Unexpected ASGI message '%s', after sending 'websocket.close'."
raise RuntimeError(msg % message_type)

async def asgi_receive(self):
async def asgi_receive(
self,
) -> Union[
"WebSocketDisconnectEvent", "WebSocketConnectEvent", "WebSocketReceiveEvent"
]:
if not self.connect_sent:
self.connect_sent = True
return {"type": "websocket.connect"}
Expand All @@ -283,11 +334,13 @@ async def asgi_receive(self):

try:
data = await self.recv()
except websockets.ConnectionClosed as exc:
except ConnectionClosed as exc:
self.closed_event.set()
return {"type": "websocket.disconnect", "code": exc.code}

msg = {"type": "websocket.receive"}
msg: WebSocketReceiveEvent = { # type: ignore[typeddict-item]
"type": "websocket.receive"
}

if isinstance(data, str):
msg["text"] = data
Expand Down

0 comments on commit 0378913

Please sign in to comment.