Skip to content

Commit

Permalink
Add compatibility with websockets 11.0. (#2609)
Browse files Browse the repository at this point in the history
Co-authored-by: Adam Hopkins <adam@amhopkins.com>
  • Loading branch information
aaugustin and ahopkins committed Nov 29, 2022
1 parent beae35f commit 4c14910
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 53 deletions.
25 changes: 18 additions & 7 deletions sanic/server/protocols/websocket_protocol.py
@@ -1,7 +1,13 @@
from typing import TYPE_CHECKING, Optional, Sequence, cast

from websockets.connection import CLOSED, CLOSING, OPEN
from websockets.server import ServerConnection

try: # websockets < 11.0
from websockets.connection import State
from websockets.server import ServerConnection as ServerProtocol
except ImportError: # websockets >= 11.0
from websockets.protocol import State # type: ignore
from websockets.server import ServerProtocol # type: ignore

from websockets.typing import Subprotocol

from sanic.exceptions import ServerError
Expand All @@ -15,6 +21,11 @@
from websockets import http11


OPEN = State.OPEN
CLOSING = State.CLOSING
CLOSED = State.CLOSED


class WebSocketProtocol(HttpProtocol):
__slots__ = (
"websocket",
Expand Down Expand Up @@ -74,7 +85,7 @@ def close_if_idle(self):
# Called by Sanic Server when shutting down
# If we've upgraded to websocket, shut it down
if self.websocket is not None:
if self.websocket.connection.state in (CLOSING, CLOSED):
if self.websocket.ws_proto.state in (CLOSING, CLOSED):
return True
elif self.websocket.loop is not None:
self.websocket.loop.create_task(self.websocket.close(1001))
Expand All @@ -90,7 +101,7 @@ async def websocket_handshake(
try:
if subprotocols is not None:
# subprotocols can be a set or frozenset,
# but ServerConnection needs a list
# but ServerProtocol needs a list
subprotocols = cast(
Optional[Sequence[Subprotocol]],
list(
Expand All @@ -100,13 +111,13 @@ async def websocket_handshake(
]
),
)
ws_conn = ServerConnection(
ws_proto = ServerProtocol(
max_size=self.websocket_max_size,
subprotocols=subprotocols,
state=OPEN,
logger=logger,
)
resp: "http11.Response" = ws_conn.accept(request)
resp: "http11.Response" = ws_proto.accept(request)
except Exception:
msg = (
"Failed to open a WebSocket connection.\n"
Expand All @@ -129,7 +140,7 @@ async def websocket_handshake(
else:
raise ServerError(resp.body, resp.status_code)
self.websocket = WebsocketImplProtocol(
ws_conn,
ws_proto,
ping_interval=self.websocket_ping_interval,
ping_timeout=self.websocket_ping_timeout,
close_timeout=self.websocket_timeout,
Expand Down
113 changes: 67 additions & 46 deletions sanic/server/websockets/impl.py
Expand Up @@ -12,25 +12,37 @@
Union,
)

from websockets.connection import CLOSED, CLOSING, OPEN, Event
from websockets.exceptions import (
ConnectionClosed,
ConnectionClosedError,
ConnectionClosedOK,
)
from websockets.frames import Frame, Opcode
from websockets.server import ServerConnection


try: # websockets < 11.0
from websockets.connection import Event, State
from websockets.server import ServerConnection as ServerProtocol
except ImportError: # websockets >= 11.0
from websockets.protocol import Event, State # type: ignore
from websockets.server import ServerProtocol # type: ignore

from websockets.typing import Data

from sanic.log import error_logger, logger
from sanic.log import deprecation, error_logger, logger
from sanic.server.protocols.base_protocol import SanicProtocol

from ...exceptions import ServerError, WebsocketClosed
from .frame import WebsocketFrameAssembler


OPEN = State.OPEN
CLOSING = State.CLOSING
CLOSED = State.CLOSED


class WebsocketImplProtocol:
connection: ServerConnection
ws_proto: ServerProtocol
io_proto: Optional[SanicProtocol]
loop: Optional[asyncio.AbstractEventLoop]
max_queue: int
Expand All @@ -56,14 +68,14 @@ class WebsocketImplProtocol:

def __init__(
self,
connection,
ws_proto,
max_queue=None,
ping_interval: Optional[float] = 20,
ping_timeout: Optional[float] = 20,
close_timeout: float = 10,
loop=None,
):
self.connection = connection
self.ws_proto = ws_proto
self.io_proto = None
self.loop = None
self.max_queue = max_queue
Expand All @@ -85,7 +97,16 @@ def __init__(

@property
def subprotocol(self):
return self.connection.subprotocol
return self.ws_proto.subprotocol

@property
def connection(self):
deprecation(
"The connection property has been deprecated and will be removed. "
"Please use the ws_proto property instead going forward.",
22.6,
)
return self.ws_proto

def pause_frames(self):
if not self.can_pause:
Expand Down Expand Up @@ -299,15 +320,15 @@ def fail_connection(self, code: int = 1006, reason: str = "") -> bool:
# Not draining the write buffer is acceptable in this context.

# clear the send buffer
_ = self.connection.data_to_send()
_ = self.ws_proto.data_to_send()
# If we're not already CLOSED or CLOSING, then send the close.
if self.connection.state is OPEN:
if self.ws_proto.state is OPEN:
if code in (1000, 1001):
self.connection.send_close(code, reason)
self.ws_proto.send_close(code, reason)
else:
self.connection.fail(code, reason)
self.ws_proto.fail(code, reason)
try:
data_to_send = self.connection.data_to_send()
data_to_send = self.ws_proto.data_to_send()
while (
len(data_to_send)
and self.io_proto
Expand All @@ -321,7 +342,7 @@ def fail_connection(self, code: int = 1006, reason: str = "") -> bool:
...
if code == 1006:
# Special case: 1006 consider the transport already closed
self.connection.state = CLOSED
self.ws_proto.state = CLOSED
if self.data_finished_fut and not self.data_finished_fut.done():
# We have a graceful auto-closer. Use it to close the connection.
self.data_finished_fut.cancel()
Expand All @@ -342,10 +363,10 @@ def end_connection(self, code=1000, reason=""):
# In Python Version 3.7: pause_reading is idempotent
# i.e. it can be called when the transport is already paused or closed.
self.io_proto.transport.pause_reading()
if self.connection.state == OPEN:
data_to_send = self.connection.data_to_send()
self.connection.send_close(code, reason)
data_to_send.extend(self.connection.data_to_send())
if self.ws_proto.state == OPEN:
data_to_send = self.ws_proto.data_to_send()
self.ws_proto.send_close(code, reason)
data_to_send.extend(self.ws_proto.data_to_send())
try:
while (
len(data_to_send)
Expand Down Expand Up @@ -454,7 +475,7 @@ def abort_pings(self) -> None:
Raise ConnectionClosed in pending keepalive pings.
They'll never receive a pong once the connection is closed.
"""
if self.connection.state is not CLOSED:
if self.ws_proto.state is not CLOSED:
raise ServerError(
"Webscoket about_pings should only be called "
"after connection state is changed to CLOSED"
Expand Down Expand Up @@ -483,9 +504,9 @@ async def close(self, code: int = 1000, reason: str = "") -> None:
self.fail_connection(code, reason)
return
async with self.conn_mutex:
if self.connection.state is OPEN:
self.connection.send_close(code, reason)
data_to_send = self.connection.data_to_send()
if self.ws_proto.state is OPEN:
self.ws_proto.send_close(code, reason)
data_to_send = self.ws_proto.data_to_send()
await self.send_data(data_to_send)

async def recv(self, timeout: Optional[float] = None) -> Optional[Data]:
Expand Down Expand Up @@ -515,7 +536,7 @@ async def recv(self, timeout: Optional[float] = None) -> Optional[Data]:
"already waiting for the next message"
)
await self.recv_lock.acquire()
if self.connection.state is CLOSED:
if self.ws_proto.state is CLOSED:
self.recv_lock.release()
raise WebsocketClosed(
"Cannot receive from websocket interface after it is closed."
Expand Down Expand Up @@ -566,7 +587,7 @@ async def recv_burst(self, max_recv=256) -> Sequence[Data]:
"for the next message"
)
await self.recv_lock.acquire()
if self.connection.state is CLOSED:
if self.ws_proto.state is CLOSED:
self.recv_lock.release()
raise WebsocketClosed(
"Cannot receive from websocket interface after it is closed."
Expand Down Expand Up @@ -625,7 +646,7 @@ async def recv_streaming(self) -> AsyncIterator[Data]:
"is already waiting for the next message"
)
await self.recv_lock.acquire()
if self.connection.state is CLOSED:
if self.ws_proto.state is CLOSED:
self.recv_lock.release()
raise WebsocketClosed(
"Cannot receive from websocket interface after it is closed."
Expand Down Expand Up @@ -666,7 +687,7 @@ async def send(self, message: Union[Data, Iterable[Data]]) -> None:
"""
async with self.conn_mutex:

if self.connection.state in (CLOSED, CLOSING):
if self.ws_proto.state in (CLOSED, CLOSING):
raise WebsocketClosed(
"Cannot write to websocket interface after it is closed."
)
Expand All @@ -679,12 +700,12 @@ async def send(self, message: Union[Data, Iterable[Data]]) -> None:
# strings and bytes-like objects are iterable.

if isinstance(message, str):
self.connection.send_text(message.encode("utf-8"))
await self.send_data(self.connection.data_to_send())
self.ws_proto.send_text(message.encode("utf-8"))
await self.send_data(self.ws_proto.data_to_send())

elif isinstance(message, (bytes, bytearray, memoryview)):
self.connection.send_binary(message)
await self.send_data(self.connection.data_to_send())
self.ws_proto.send_binary(message)
await self.send_data(self.ws_proto.data_to_send())

elif isinstance(message, Mapping):
# Catch a common mistake -- passing a dict to send().
Expand Down Expand Up @@ -713,7 +734,7 @@ async def ping(self, data: Optional[Data] = None) -> asyncio.Future:
(which will be encoded to UTF-8) or a bytes-like object.
"""
async with self.conn_mutex:
if self.connection.state in (CLOSED, CLOSING):
if self.ws_proto.state in (CLOSED, CLOSING):
raise WebsocketClosed(
"Cannot send a ping when the websocket interface "
"is closed."
Expand Down Expand Up @@ -741,8 +762,8 @@ async def ping(self, data: Optional[Data] = None) -> asyncio.Future:

self.pings[data] = self.io_proto.loop.create_future()

self.connection.send_ping(data)
await self.send_data(self.connection.data_to_send())
self.ws_proto.send_ping(data)
await self.send_data(self.ws_proto.data_to_send())

return asyncio.shield(self.pings[data])

Expand All @@ -754,15 +775,15 @@ async def pong(self, data: Data = b"") -> None:
be a string (which will be encoded to UTF-8) or a bytes-like object.
"""
async with self.conn_mutex:
if self.connection.state in (CLOSED, CLOSING):
if self.ws_proto.state in (CLOSED, CLOSING):
# Cannot send pong after transport is shutting down
return
if isinstance(data, str):
data = data.encode("utf-8")
elif isinstance(data, (bytearray, memoryview)):
data = bytes(data)
self.connection.send_pong(data)
await self.send_data(self.connection.data_to_send())
self.ws_proto.send_pong(data)
await self.send_data(self.ws_proto.data_to_send())

async def send_data(self, data_to_send):
for data in data_to_send:
Expand All @@ -784,17 +805,17 @@ async def send_data(self, data_to_send):
SanicProtocol.close(self.io_proto, timeout=1.0)

async def async_data_received(self, data_to_send, events_to_process):
if self.connection.state in (OPEN, CLOSING) and len(data_to_send) > 0:
if self.ws_proto.state in (OPEN, CLOSING) and len(data_to_send) > 0:
# receiving data can generate data to send (eg, pong for a ping)
# send connection.data_to_send()
await self.send_data(data_to_send)
if len(events_to_process) > 0:
await self.process_events(events_to_process)

def data_received(self, data):
self.connection.receive_data(data)
data_to_send = self.connection.data_to_send()
events_to_process = self.connection.events_received()
self.ws_proto.receive_data(data)
data_to_send = self.ws_proto.data_to_send()
events_to_process = self.ws_proto.events_received()
if len(data_to_send) > 0 or len(events_to_process) > 0:
asyncio.create_task(
self.async_data_received(data_to_send, events_to_process)
Expand All @@ -803,7 +824,7 @@ def data_received(self, data):
async def async_eof_received(self, data_to_send, events_to_process):
# receiving EOF can generate data to send
# send connection.data_to_send()
if self.connection.state in (OPEN, CLOSING) and len(data_to_send) > 0:
if self.ws_proto.state in (OPEN, CLOSING) and len(data_to_send) > 0:
await self.send_data(data_to_send)
if len(events_to_process) > 0:
await self.process_events(events_to_process)
Expand All @@ -823,9 +844,9 @@ async def async_eof_received(self, data_to_send, events_to_process):
SanicProtocol.close(self.io_proto, timeout=1.0)

def eof_received(self) -> Optional[bool]:
self.connection.receive_eof()
data_to_send = self.connection.data_to_send()
events_to_process = self.connection.events_received()
self.ws_proto.receive_eof()
data_to_send = self.ws_proto.data_to_send()
events_to_process = self.ws_proto.events_received()
asyncio.create_task(
self.async_eof_received(data_to_send, events_to_process)
)
Expand All @@ -835,11 +856,11 @@ def connection_lost(self, exc):
"""
The WebSocket Connection is Closed.
"""
if not self.connection.state == CLOSED:
if not self.ws_proto.state == CLOSED:
# signal to the websocket connection handler
# we've lost the connection
self.connection.fail(code=1006)
self.connection.state = CLOSED
self.ws_proto.fail(code=1006)
self.ws_proto.state = CLOSED

self.abort_pings()
if self.connection_lost_waiter:
Expand Down

0 comments on commit 4c14910

Please sign in to comment.