Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor WSMessage to use tagged unions #7319

Draft
wants to merge 2 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 2 additions & 2 deletions aiohttp/__init__.py
Expand Up @@ -45,7 +45,7 @@
HttpVersion11,
WebSocketError,
WSCloseCode,
WSMessage,
WSMessageType,
WSMsgType,
)
from .multipart import (
Expand Down Expand Up @@ -152,7 +152,7 @@
"HttpVersion11",
"WSMsgType",
"WSCloseCode",
"WSMessage",
"WSMessageType",
"WebSocketError",
# multipart
"BadContentDispositionHeader",
Expand Down
4 changes: 2 additions & 2 deletions aiohttp/client.py
Expand Up @@ -81,7 +81,7 @@
strip_auth_from_url,
)
from .http import WS_KEY, HttpVersion, WebSocketReader, WebSocketWriter
from .http_websocket import WSHandshakeError, WSMessage, ws_ext_gen, ws_ext_parse
from .http_websocket import WSHandshakeError, WSMessageType, ws_ext_gen, ws_ext_parse
from .streams import FlowControlDataQueue
from .tracing import Trace, TraceConfig
from .typedefs import JSONEncoder, LooseCookies, LooseHeaders, StrOrURL
Expand Down Expand Up @@ -876,7 +876,7 @@ async def _ws_connect(
assert conn_proto is not None
transport = conn.transport
assert transport is not None
reader: FlowControlDataQueue[WSMessage] = FlowControlDataQueue(
reader: FlowControlDataQueue[WSMessageType] = FlowControlDataQueue(
conn_proto, 2**16, loop=self._loop
)
conn_proto.set_parser(WebSocketReader(reader, max_msg_size), reader)
Expand Down
31 changes: 15 additions & 16 deletions aiohttp/client_ws.py
Expand Up @@ -11,14 +11,13 @@
from .client_reqrep import ClientResponse
from .helpers import call_later, set_result
from .http import (
WS_CLOSED_MESSAGE,
WS_CLOSING_MESSAGE,
WebSocketError,
WSCloseCode,
WSMessage,
WSMessageClosed,
WSMessageClosing,
WSMsgType,
)
from .http_websocket import WebSocketWriter # WSMessage
from .http_websocket import WebSocketWriter, WSMessageError, WSMessageType
from .streams import EofStream, FlowControlDataQueue
from .typedefs import (
DEFAULT_JSON_DECODER,
Expand All @@ -42,7 +41,7 @@ class ClientWSTimeout:
class ClientWebSocketResponse:
def __init__(
self,
reader: "FlowControlDataQueue[WSMessage]",
reader: "FlowControlDataQueue[WSMessageType]",
writer: WebSocketWriter,
protocol: Optional[str],
response: ClientResponse,
Expand Down Expand Up @@ -189,7 +188,7 @@ async def close(self, *, code: int = WSCloseCode.OK, message: bytes = b"") -> bo
# we need to break `receive()` cycle first,
# `close()` may be called from different task
if self._waiting is not None and not self._closed:
self._reader.feed_data(WS_CLOSING_MESSAGE, 0)
self._reader.feed_data(WSMessageClosing(), 0)
await self._waiting

if not self._closed:
Expand Down Expand Up @@ -225,23 +224,23 @@ async def close(self, *, code: int = WSCloseCode.OK, message: bytes = b"") -> bo
self._response.close()
return True

if msg.type == WSMsgType.CLOSE:
if msg.type is WSMsgType.CLOSE:
self._close_code = msg.data
self._response.close()
return True
else:
return False

async def receive(self, timeout: Optional[float] = None) -> WSMessage:
async def receive(self, timeout: Optional[float] = None) -> WSMessageType:
while True:
if self._waiting is not None:
raise RuntimeError("Concurrent call to receive() is not allowed")

if self._closed:
return WS_CLOSED_MESSAGE
return WSMessageClosed()
elif self._closing:
await self.close()
return WS_CLOSED_MESSAGE
return WSMessageClosed()

try:
self._waiting = self._loop.create_future()
Expand All @@ -261,23 +260,23 @@ async def receive(self, timeout: Optional[float] = None) -> WSMessage:
except EofStream:
self._close_code = WSCloseCode.OK
await self.close()
return WSMessage(WSMsgType.CLOSED, None, None)
return WSMessageClosed()
except ClientError:
self._closed = True
self._close_code = WSCloseCode.ABNORMAL_CLOSURE
return WS_CLOSED_MESSAGE
return WSMessageClosed()
except WebSocketError as exc:
self._close_code = exc.code
await self.close(code=exc.code)
return WSMessage(WSMsgType.ERROR, exc, None)
return WSMessageError(data=exc)
except Exception as exc:
self._exception = exc
self._closing = True
self._close_code = WSCloseCode.ABNORMAL_CLOSURE
await self.close()
return WSMessage(WSMsgType.ERROR, exc, None)
return WSMessageError(data=exc)

if msg.type == WSMsgType.CLOSE:
if msg.type is WSMsgType.CLOSE:
self._closing = True
self._close_code = msg.data
if not self._closed and self._autoclose:
Expand Down Expand Up @@ -316,7 +315,7 @@ async def receive_json(
def __aiter__(self) -> "ClientWebSocketResponse":
return self

async def __anext__(self) -> WSMessage:
async def __anext__(self) -> WSMessageType:
msg = await self.receive()
if msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSING, WSMsgType.CLOSED):
raise StopAsyncIteration
Expand Down
12 changes: 6 additions & 6 deletions aiohttp/http.py
Expand Up @@ -11,14 +11,14 @@
RawResponseMessage,
)
from .http_websocket import (
WS_CLOSED_MESSAGE,
WS_CLOSING_MESSAGE,
WS_KEY,
WebSocketError,
WebSocketReader,
WebSocketWriter,
WSCloseCode,
WSMessage,
WSMessageClosed,
WSMessageClosing,
WSMessageType,
WSMsgType,
ws_ext_gen,
ws_ext_parse,
Expand All @@ -41,14 +41,14 @@
"RawRequestMessage",
"RawResponseMessage",
# .http_websocket
"WS_CLOSED_MESSAGE",
"WS_CLOSING_MESSAGE",
"WSMessageClosed",
"WSMessageClosing",
"WS_KEY",
"WebSocketReader",
"WebSocketWriter",
"ws_ext_gen",
"ws_ext_parse",
"WSMessage",
"WSMessageType",
"WebSocketError",
"WSMsgType",
"WSCloseCode",
Expand Down
107 changes: 90 additions & 17 deletions aiohttp/http_websocket.py
Expand Up @@ -7,19 +7,24 @@
import re
import sys
import zlib
from dataclasses import dataclass
from enum import IntEnum
from struct import Struct
from typing import (
Any,
Callable,
Generic,
List,
Literal,
NamedTuple,
Optional,
Pattern,
Set,
Tuple,
TypeVar,
Union,
cast,
overload,
)

from typing_extensions import Final
Expand Down Expand Up @@ -90,22 +95,90 @@ class WSMsgType(IntEnum):
DEFAULT_LIMIT: Final[int] = 2**16


class WSMessage(NamedTuple):
@dataclass
class _WSMessage:
data: object
type: WSMsgType
# To type correctly, this would need some kind of tagged union for each type.
data: Any
extra: Optional[str]
extra: Optional[str] = None

def json(self, *, loads: Callable[[Any], Any] = json.loads) -> Any:
"""Return parsed JSON data.

.. versionadded:: 0.22
"""
@dataclass
class WSMessageContinuation(_WSMessage):
data: bytes
type: Literal[WSMsgType.CONTINUATION] = WSMsgType.CONTINUATION


@dataclass
class WSMessageText(_WSMessage):
data: str
type: Literal[WSMsgType.TEXT] = WSMsgType.TEXT

def json(
self, *, loads: Callable[[Union[str, bytes, bytearray]], Any] = json.loads
) -> Any:
"""Return parsed JSON data."""
return loads(self.data)


@dataclass
class WSMessageBinary(_WSMessage):
data: bytes
type: Literal[WSMsgType.BINARY] = WSMsgType.BINARY

def json(
self, *, loads: Callable[[Union[str, bytes, bytearray]], Any] = json.loads
) -> Any:
"""Return parsed JSON data."""
return loads(self.data)


WS_CLOSED_MESSAGE = WSMessage(WSMsgType.CLOSED, None, None)
WS_CLOSING_MESSAGE = WSMessage(WSMsgType.CLOSING, None, None)
@dataclass
class WSMessagePing(_WSMessage):
data: bytes
type: Literal[WSMsgType.PING] = WSMsgType.PING


@dataclass
class WSMessagePong(_WSMessage):
data: bytes
type: Literal[WSMsgType.PONG] = WSMsgType.PONG


@dataclass
class WSMessageClose(_WSMessage):
data: int
type: Literal[WSMsgType.CLOSE] = WSMsgType.CLOSE


@dataclass
class WSMessageClosing(_WSMessage):
data: None = None
type: Literal[WSMsgType.CLOSING] = WSMsgType.CLOSING


@dataclass
class WSMessageClosed(_WSMessage):
data: None = None
type: Literal[WSMsgType.CLOSED] = WSMsgType.CLOSED


@dataclass
class WSMessageError(_WSMessage):
data: Exception
type: Literal[WSMsgType.ERROR] = WSMsgType.ERROR


WSMessageType = Union[
WSMessageContinuation,
WSMessageText,
WSMessageBinary,
WSMessagePing,
WSMessagePong,
WSMessageClose,
WSMessageClosing,
WSMessageClosed,
WSMessageError,
]


class WebSocketError(Exception):
Expand Down Expand Up @@ -263,7 +336,7 @@ class WSParserState(IntEnum):

class WebSocketReader:
def __init__(
self, queue: DataQueue[WSMessage], max_msg_size: int, compress: bool = True
self, queue: DataQueue[WSMessageType], max_msg_size: int, compress: bool = True
) -> None:
self.queue = queue
self._max_msg_size = max_msg_size
Expand Down Expand Up @@ -318,25 +391,25 @@ def _feed_data(self, data: bytes) -> Tuple[bool, bytes]:
raise WebSocketError(
WSCloseCode.INVALID_TEXT, "Invalid UTF-8 text message"
) from exc
msg = WSMessage(WSMsgType.CLOSE, close_code, close_message)
msg = WSMessageClose(data=close_code, extra=close_message)
elif payload:
raise WebSocketError(
WSCloseCode.PROTOCOL_ERROR,
f"Invalid close frame: {fin} {opcode} {payload!r}",
)
else:
msg = WSMessage(WSMsgType.CLOSE, 0, "")
msg = WSMessageClose(data=0, extra="")

self.queue.feed_data(msg, 0)

elif opcode == WSMsgType.PING:
self.queue.feed_data(
WSMessage(WSMsgType.PING, payload, ""), len(payload)
WSMessagePing(data=payload, extra=""), len(payload)
)

elif opcode == WSMsgType.PONG:
self.queue.feed_data(
WSMessage(WSMsgType.PONG, payload, ""), len(payload)
WSMessagePong(data=payload, extra=""), len(payload)
)

elif (
Expand Down Expand Up @@ -410,15 +483,15 @@ def _feed_data(self, data: bytes) -> Tuple[bool, bytes]:
try:
text = payload_merged.decode("utf-8")
self.queue.feed_data(
WSMessage(WSMsgType.TEXT, text, ""), len(text)
WSMessageText(data=text, extra=""), len(text)
)
except UnicodeDecodeError as exc:
raise WebSocketError(
WSCloseCode.INVALID_TEXT, "Invalid UTF-8 text message"
) from exc
else:
self.queue.feed_data(
WSMessage(WSMsgType.BINARY, payload_merged, ""),
WSMessageBinary(data=payload_merged, extra=""),
len(payload_merged),
)

Expand Down