Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add compatibility with websockets 11.0. #2609

Merged
merged 6 commits into from Nov 29, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
25 changes: 18 additions & 7 deletions sanic/server/protocols/websocket_protocol.py
@@ -1,7 +1,13 @@
from typing import TYPE_CHECKING, Optional, Sequence, cast

from websockets.connection import CLOSED, CLOSING, OPEN
from websockets.server import ServerConnection

try: # websockets < 11.0
from websockets.connection import State
from websockets.server import ServerConnection as ServerProtocol
except ImportError: # websockets >= 11.0
from websockets.protocol import State # type: ignore
from websockets.server import ServerProtocol # type: ignore

from websockets.typing import Subprotocol

from sanic.exceptions import ServerError
Expand All @@ -15,6 +21,11 @@
from websockets import http11


OPEN = State.OPEN
CLOSING = State.CLOSING
CLOSED = State.CLOSED


class WebSocketProtocol(HttpProtocol):
__slots__ = (
"websocket",
Expand Down Expand Up @@ -74,7 +85,7 @@ def close_if_idle(self):
# Called by Sanic Server when shutting down
# If we've upgraded to websocket, shut it down
if self.websocket is not None:
if self.websocket.connection.state in (CLOSING, CLOSED):
if self.websocket.ws_proto.state in (CLOSING, CLOSED):
return True
elif self.websocket.loop is not None:
self.websocket.loop.create_task(self.websocket.close(1001))
Expand All @@ -90,7 +101,7 @@ async def websocket_handshake(
try:
if subprotocols is not None:
# subprotocols can be a set or frozenset,
# but ServerConnection needs a list
# but ServerProtocol needs a list
subprotocols = cast(
Optional[Sequence[Subprotocol]],
list(
Expand All @@ -100,13 +111,13 @@ async def websocket_handshake(
]
),
)
ws_conn = ServerConnection(
ws_proto = ServerProtocol(
max_size=self.websocket_max_size,
subprotocols=subprotocols,
state=OPEN,
logger=logger,
)
resp: "http11.Response" = ws_conn.accept(request)
resp: "http11.Response" = ws_proto.accept(request)
except Exception:
msg = (
"Failed to open a WebSocket connection.\n"
Expand All @@ -129,7 +140,7 @@ async def websocket_handshake(
else:
raise ServerError(resp.body, resp.status_code)
self.websocket = WebsocketImplProtocol(
ws_conn,
ws_proto,
ping_interval=self.websocket_ping_interval,
ping_timeout=self.websocket_ping_timeout,
close_timeout=self.websocket_timeout,
Expand Down
113 changes: 67 additions & 46 deletions sanic/server/websockets/impl.py
Expand Up @@ -12,25 +12,37 @@
Union,
)

from websockets.connection import CLOSED, CLOSING, OPEN, Event
from websockets.exceptions import (
ConnectionClosed,
ConnectionClosedError,
ConnectionClosedOK,
)
from websockets.frames import Frame, Opcode
from websockets.server import ServerConnection


try: # websockets < 11.0
from websockets.connection import Event, State
from websockets.server import ServerConnection as ServerProtocol
except ImportError: # websockets >= 11.0
from websockets.protocol import Event, State # type: ignore
from websockets.server import ServerProtocol # type: ignore

from websockets.typing import Data

from sanic.log import error_logger, logger
from sanic.log import deprecation, error_logger, logger
from sanic.server.protocols.base_protocol import SanicProtocol

from ...exceptions import ServerError, WebsocketClosed
from .frame import WebsocketFrameAssembler


OPEN = State.OPEN
CLOSING = State.CLOSING
CLOSED = State.CLOSED


class WebsocketImplProtocol:
connection: ServerConnection
ws_proto: ServerProtocol
io_proto: Optional[SanicProtocol]
loop: Optional[asyncio.AbstractEventLoop]
max_queue: int
Expand All @@ -56,14 +68,14 @@ class WebsocketImplProtocol:

def __init__(
self,
connection,
ws_proto,
max_queue=None,
ping_interval: Optional[float] = 20,
ping_timeout: Optional[float] = 20,
close_timeout: float = 10,
loop=None,
):
self.connection = connection
self.ws_proto = ws_proto
self.io_proto = None
self.loop = None
self.max_queue = max_queue
Expand All @@ -85,7 +97,16 @@ def __init__(

@property
def subprotocol(self):
return self.connection.subprotocol
return self.ws_proto.subprotocol

@property
def connection(self):
deprecation(
"The connection property has been deprecated and will be removed. "
"Please use the ws_proto property instead going forward.",
22.6,
)
return self.ws_proto

def pause_frames(self):
if not self.can_pause:
Expand Down Expand Up @@ -299,15 +320,15 @@ def fail_connection(self, code: int = 1006, reason: str = "") -> bool:
# Not draining the write buffer is acceptable in this context.

# clear the send buffer
_ = self.connection.data_to_send()
_ = self.ws_proto.data_to_send()
# If we're not already CLOSED or CLOSING, then send the close.
if self.connection.state is OPEN:
if self.ws_proto.state is OPEN:
if code in (1000, 1001):
self.connection.send_close(code, reason)
self.ws_proto.send_close(code, reason)
else:
self.connection.fail(code, reason)
self.ws_proto.fail(code, reason)
try:
data_to_send = self.connection.data_to_send()
data_to_send = self.ws_proto.data_to_send()
while (
len(data_to_send)
and self.io_proto
Expand All @@ -321,7 +342,7 @@ def fail_connection(self, code: int = 1006, reason: str = "") -> bool:
...
if code == 1006:
# Special case: 1006 consider the transport already closed
self.connection.state = CLOSED
self.ws_proto.state = CLOSED
if self.data_finished_fut and not self.data_finished_fut.done():
# We have a graceful auto-closer. Use it to close the connection.
self.data_finished_fut.cancel()
Expand All @@ -342,10 +363,10 @@ def end_connection(self, code=1000, reason=""):
# In Python Version 3.7: pause_reading is idempotent
# i.e. it can be called when the transport is already paused or closed.
self.io_proto.transport.pause_reading()
if self.connection.state == OPEN:
data_to_send = self.connection.data_to_send()
self.connection.send_close(code, reason)
data_to_send.extend(self.connection.data_to_send())
if self.ws_proto.state == OPEN:
data_to_send = self.ws_proto.data_to_send()
self.ws_proto.send_close(code, reason)
data_to_send.extend(self.ws_proto.data_to_send())
try:
while (
len(data_to_send)
Expand Down Expand Up @@ -454,7 +475,7 @@ def abort_pings(self) -> None:
Raise ConnectionClosed in pending keepalive pings.
They'll never receive a pong once the connection is closed.
"""
if self.connection.state is not CLOSED:
if self.ws_proto.state is not CLOSED:
raise ServerError(
"Webscoket about_pings should only be called "
"after connection state is changed to CLOSED"
Expand Down Expand Up @@ -483,9 +504,9 @@ async def close(self, code: int = 1000, reason: str = "") -> None:
self.fail_connection(code, reason)
return
async with self.conn_mutex:
if self.connection.state is OPEN:
self.connection.send_close(code, reason)
data_to_send = self.connection.data_to_send()
if self.ws_proto.state is OPEN:
self.ws_proto.send_close(code, reason)
data_to_send = self.ws_proto.data_to_send()
await self.send_data(data_to_send)

async def recv(self, timeout: Optional[float] = None) -> Optional[Data]:
Expand Down Expand Up @@ -515,7 +536,7 @@ async def recv(self, timeout: Optional[float] = None) -> Optional[Data]:
"already waiting for the next message"
)
await self.recv_lock.acquire()
if self.connection.state is CLOSED:
if self.ws_proto.state is CLOSED:
self.recv_lock.release()
raise WebsocketClosed(
"Cannot receive from websocket interface after it is closed."
Expand Down Expand Up @@ -566,7 +587,7 @@ async def recv_burst(self, max_recv=256) -> Sequence[Data]:
"for the next message"
)
await self.recv_lock.acquire()
if self.connection.state is CLOSED:
if self.ws_proto.state is CLOSED:
self.recv_lock.release()
raise WebsocketClosed(
"Cannot receive from websocket interface after it is closed."
Expand Down Expand Up @@ -625,7 +646,7 @@ async def recv_streaming(self) -> AsyncIterator[Data]:
"is already waiting for the next message"
)
await self.recv_lock.acquire()
if self.connection.state is CLOSED:
if self.ws_proto.state is CLOSED:
self.recv_lock.release()
raise WebsocketClosed(
"Cannot receive from websocket interface after it is closed."
Expand Down Expand Up @@ -666,7 +687,7 @@ async def send(self, message: Union[Data, Iterable[Data]]) -> None:
"""
async with self.conn_mutex:

if self.connection.state in (CLOSED, CLOSING):
if self.ws_proto.state in (CLOSED, CLOSING):
raise WebsocketClosed(
"Cannot write to websocket interface after it is closed."
)
Expand All @@ -679,12 +700,12 @@ async def send(self, message: Union[Data, Iterable[Data]]) -> None:
# strings and bytes-like objects are iterable.

if isinstance(message, str):
self.connection.send_text(message.encode("utf-8"))
await self.send_data(self.connection.data_to_send())
self.ws_proto.send_text(message.encode("utf-8"))
await self.send_data(self.ws_proto.data_to_send())

elif isinstance(message, (bytes, bytearray, memoryview)):
self.connection.send_binary(message)
await self.send_data(self.connection.data_to_send())
self.ws_proto.send_binary(message)
await self.send_data(self.ws_proto.data_to_send())

elif isinstance(message, Mapping):
# Catch a common mistake -- passing a dict to send().
Expand Down Expand Up @@ -713,7 +734,7 @@ async def ping(self, data: Optional[Data] = None) -> asyncio.Future:
(which will be encoded to UTF-8) or a bytes-like object.
"""
async with self.conn_mutex:
if self.connection.state in (CLOSED, CLOSING):
if self.ws_proto.state in (CLOSED, CLOSING):
raise WebsocketClosed(
"Cannot send a ping when the websocket interface "
"is closed."
Expand Down Expand Up @@ -741,8 +762,8 @@ async def ping(self, data: Optional[Data] = None) -> asyncio.Future:

self.pings[data] = self.io_proto.loop.create_future()

self.connection.send_ping(data)
await self.send_data(self.connection.data_to_send())
self.ws_proto.send_ping(data)
await self.send_data(self.ws_proto.data_to_send())

return asyncio.shield(self.pings[data])

Expand All @@ -754,15 +775,15 @@ async def pong(self, data: Data = b"") -> None:
be a string (which will be encoded to UTF-8) or a bytes-like object.
"""
async with self.conn_mutex:
if self.connection.state in (CLOSED, CLOSING):
if self.ws_proto.state in (CLOSED, CLOSING):
# Cannot send pong after transport is shutting down
return
if isinstance(data, str):
data = data.encode("utf-8")
elif isinstance(data, (bytearray, memoryview)):
data = bytes(data)
self.connection.send_pong(data)
await self.send_data(self.connection.data_to_send())
self.ws_proto.send_pong(data)
await self.send_data(self.ws_proto.data_to_send())

async def send_data(self, data_to_send):
for data in data_to_send:
Expand All @@ -784,17 +805,17 @@ async def send_data(self, data_to_send):
SanicProtocol.close(self.io_proto, timeout=1.0)

async def async_data_received(self, data_to_send, events_to_process):
if self.connection.state in (OPEN, CLOSING) and len(data_to_send) > 0:
if self.ws_proto.state in (OPEN, CLOSING) and len(data_to_send) > 0:
# receiving data can generate data to send (eg, pong for a ping)
# send connection.data_to_send()
await self.send_data(data_to_send)
if len(events_to_process) > 0:
await self.process_events(events_to_process)

def data_received(self, data):
self.connection.receive_data(data)
data_to_send = self.connection.data_to_send()
events_to_process = self.connection.events_received()
self.ws_proto.receive_data(data)
data_to_send = self.ws_proto.data_to_send()
events_to_process = self.ws_proto.events_received()
if len(data_to_send) > 0 or len(events_to_process) > 0:
asyncio.create_task(
self.async_data_received(data_to_send, events_to_process)
Expand All @@ -803,7 +824,7 @@ def data_received(self, data):
async def async_eof_received(self, data_to_send, events_to_process):
# receiving EOF can generate data to send
# send connection.data_to_send()
if self.connection.state in (OPEN, CLOSING) and len(data_to_send) > 0:
if self.ws_proto.state in (OPEN, CLOSING) and len(data_to_send) > 0:
await self.send_data(data_to_send)
if len(events_to_process) > 0:
await self.process_events(events_to_process)
Expand All @@ -823,9 +844,9 @@ async def async_eof_received(self, data_to_send, events_to_process):
SanicProtocol.close(self.io_proto, timeout=1.0)

def eof_received(self) -> Optional[bool]:
self.connection.receive_eof()
data_to_send = self.connection.data_to_send()
events_to_process = self.connection.events_received()
self.ws_proto.receive_eof()
data_to_send = self.ws_proto.data_to_send()
events_to_process = self.ws_proto.events_received()
asyncio.create_task(
self.async_eof_received(data_to_send, events_to_process)
)
Expand All @@ -835,11 +856,11 @@ def connection_lost(self, exc):
"""
The WebSocket Connection is Closed.
"""
if not self.connection.state == CLOSED:
if not self.ws_proto.state == CLOSED:
# signal to the websocket connection handler
# we've lost the connection
self.connection.fail(code=1006)
self.connection.state = CLOSED
self.ws_proto.fail(code=1006)
self.ws_proto.state = CLOSED

self.abort_pings()
if self.connection_lost_waiter:
Expand Down