diff --git a/sanic/server/protocols/websocket_protocol.py b/sanic/server/protocols/websocket_protocol.py index e55ebdd610..ce9c62447d 100644 --- a/sanic/server/protocols/websocket_protocol.py +++ b/sanic/server/protocols/websocket_protocol.py @@ -1,7 +1,13 @@ from typing import TYPE_CHECKING, Optional, Sequence, cast -from websockets.connection import CLOSED, CLOSING, OPEN -from websockets.server import ServerConnection + +try: # websockets < 11.0 + from websockets.connection import State + from websockets.server import ServerConnection as ServerProtocol +except ImportError: # websockets >= 11.0 + from websockets.protocol import State # type: ignore + from websockets.server import ServerProtocol # type: ignore + from websockets.typing import Subprotocol from sanic.exceptions import ServerError @@ -15,6 +21,11 @@ from websockets import http11 +OPEN = State.OPEN +CLOSING = State.CLOSING +CLOSED = State.CLOSED + + class WebSocketProtocol(HttpProtocol): __slots__ = ( "websocket", @@ -74,7 +85,7 @@ def close_if_idle(self): # Called by Sanic Server when shutting down # If we've upgraded to websocket, shut it down if self.websocket is not None: - if self.websocket.connection.state in (CLOSING, CLOSED): + if self.websocket.ws_proto.state in (CLOSING, CLOSED): return True elif self.websocket.loop is not None: self.websocket.loop.create_task(self.websocket.close(1001)) @@ -90,7 +101,7 @@ async def websocket_handshake( try: if subprotocols is not None: # subprotocols can be a set or frozenset, - # but ServerConnection needs a list + # but ServerProtocol needs a list subprotocols = cast( Optional[Sequence[Subprotocol]], list( @@ -100,13 +111,13 @@ async def websocket_handshake( ] ), ) - ws_conn = ServerConnection( + ws_proto = ServerProtocol( max_size=self.websocket_max_size, subprotocols=subprotocols, state=OPEN, logger=logger, ) - resp: "http11.Response" = ws_conn.accept(request) + resp: "http11.Response" = ws_proto.accept(request) except Exception: msg = ( "Failed to open a WebSocket connection.\n" @@ -129,7 +140,7 @@ async def websocket_handshake( else: raise ServerError(resp.body, resp.status_code) self.websocket = WebsocketImplProtocol( - ws_conn, + ws_proto, ping_interval=self.websocket_ping_interval, ping_timeout=self.websocket_ping_timeout, close_timeout=self.websocket_timeout, diff --git a/sanic/server/websockets/impl.py b/sanic/server/websockets/impl.py index 6104c1f06b..5e258e972f 100644 --- a/sanic/server/websockets/impl.py +++ b/sanic/server/websockets/impl.py @@ -12,25 +12,37 @@ Union, ) -from websockets.connection import CLOSED, CLOSING, OPEN, Event from websockets.exceptions import ( ConnectionClosed, ConnectionClosedError, ConnectionClosedOK, ) from websockets.frames import Frame, Opcode -from websockets.server import ServerConnection + + +try: # websockets < 11.0 + from websockets.connection import Event, State + from websockets.server import ServerConnection as ServerProtocol +except ImportError: # websockets >= 11.0 + from websockets.protocol import Event, State # type: ignore + from websockets.server import ServerProtocol # type: ignore + from websockets.typing import Data -from sanic.log import error_logger, logger +from sanic.log import deprecation, error_logger, logger from sanic.server.protocols.base_protocol import SanicProtocol from ...exceptions import ServerError, WebsocketClosed from .frame import WebsocketFrameAssembler +OPEN = State.OPEN +CLOSING = State.CLOSING +CLOSED = State.CLOSED + + class WebsocketImplProtocol: - connection: ServerConnection + ws_proto: ServerProtocol io_proto: Optional[SanicProtocol] loop: Optional[asyncio.AbstractEventLoop] max_queue: int @@ -56,14 +68,14 @@ class WebsocketImplProtocol: def __init__( self, - connection, + ws_proto, max_queue=None, ping_interval: Optional[float] = 20, ping_timeout: Optional[float] = 20, close_timeout: float = 10, loop=None, ): - self.connection = connection + self.ws_proto = ws_proto self.io_proto = None self.loop = None self.max_queue = max_queue @@ -85,7 +97,16 @@ def __init__( @property def subprotocol(self): - return self.connection.subprotocol + return self.ws_proto.subprotocol + + @property + def connection(self): + deprecation( + "The connection property has been deprecated and will be removed. " + "Please use the ws_proto property instead going forward.", + 22.6, + ) + return self.ws_proto def pause_frames(self): if not self.can_pause: @@ -299,15 +320,15 @@ def fail_connection(self, code: int = 1006, reason: str = "") -> bool: # Not draining the write buffer is acceptable in this context. # clear the send buffer - _ = self.connection.data_to_send() + _ = self.ws_proto.data_to_send() # If we're not already CLOSED or CLOSING, then send the close. - if self.connection.state is OPEN: + if self.ws_proto.state is OPEN: if code in (1000, 1001): - self.connection.send_close(code, reason) + self.ws_proto.send_close(code, reason) else: - self.connection.fail(code, reason) + self.ws_proto.fail(code, reason) try: - data_to_send = self.connection.data_to_send() + data_to_send = self.ws_proto.data_to_send() while ( len(data_to_send) and self.io_proto @@ -321,7 +342,7 @@ def fail_connection(self, code: int = 1006, reason: str = "") -> bool: ... if code == 1006: # Special case: 1006 consider the transport already closed - self.connection.state = CLOSED + self.ws_proto.state = CLOSED if self.data_finished_fut and not self.data_finished_fut.done(): # We have a graceful auto-closer. Use it to close the connection. self.data_finished_fut.cancel() @@ -342,10 +363,10 @@ def end_connection(self, code=1000, reason=""): # In Python Version 3.7: pause_reading is idempotent # i.e. it can be called when the transport is already paused or closed. self.io_proto.transport.pause_reading() - if self.connection.state == OPEN: - data_to_send = self.connection.data_to_send() - self.connection.send_close(code, reason) - data_to_send.extend(self.connection.data_to_send()) + if self.ws_proto.state == OPEN: + data_to_send = self.ws_proto.data_to_send() + self.ws_proto.send_close(code, reason) + data_to_send.extend(self.ws_proto.data_to_send()) try: while ( len(data_to_send) @@ -454,7 +475,7 @@ def abort_pings(self) -> None: Raise ConnectionClosed in pending keepalive pings. They'll never receive a pong once the connection is closed. """ - if self.connection.state is not CLOSED: + if self.ws_proto.state is not CLOSED: raise ServerError( "Webscoket about_pings should only be called " "after connection state is changed to CLOSED" @@ -483,9 +504,9 @@ async def close(self, code: int = 1000, reason: str = "") -> None: self.fail_connection(code, reason) return async with self.conn_mutex: - if self.connection.state is OPEN: - self.connection.send_close(code, reason) - data_to_send = self.connection.data_to_send() + if self.ws_proto.state is OPEN: + self.ws_proto.send_close(code, reason) + data_to_send = self.ws_proto.data_to_send() await self.send_data(data_to_send) async def recv(self, timeout: Optional[float] = None) -> Optional[Data]: @@ -515,7 +536,7 @@ async def recv(self, timeout: Optional[float] = None) -> Optional[Data]: "already waiting for the next message" ) await self.recv_lock.acquire() - if self.connection.state is CLOSED: + if self.ws_proto.state is CLOSED: self.recv_lock.release() raise WebsocketClosed( "Cannot receive from websocket interface after it is closed." @@ -566,7 +587,7 @@ async def recv_burst(self, max_recv=256) -> Sequence[Data]: "for the next message" ) await self.recv_lock.acquire() - if self.connection.state is CLOSED: + if self.ws_proto.state is CLOSED: self.recv_lock.release() raise WebsocketClosed( "Cannot receive from websocket interface after it is closed." @@ -625,7 +646,7 @@ async def recv_streaming(self) -> AsyncIterator[Data]: "is already waiting for the next message" ) await self.recv_lock.acquire() - if self.connection.state is CLOSED: + if self.ws_proto.state is CLOSED: self.recv_lock.release() raise WebsocketClosed( "Cannot receive from websocket interface after it is closed." @@ -666,7 +687,7 @@ async def send(self, message: Union[Data, Iterable[Data]]) -> None: """ async with self.conn_mutex: - if self.connection.state in (CLOSED, CLOSING): + if self.ws_proto.state in (CLOSED, CLOSING): raise WebsocketClosed( "Cannot write to websocket interface after it is closed." ) @@ -679,12 +700,12 @@ async def send(self, message: Union[Data, Iterable[Data]]) -> None: # strings and bytes-like objects are iterable. if isinstance(message, str): - self.connection.send_text(message.encode("utf-8")) - await self.send_data(self.connection.data_to_send()) + self.ws_proto.send_text(message.encode("utf-8")) + await self.send_data(self.ws_proto.data_to_send()) elif isinstance(message, (bytes, bytearray, memoryview)): - self.connection.send_binary(message) - await self.send_data(self.connection.data_to_send()) + self.ws_proto.send_binary(message) + await self.send_data(self.ws_proto.data_to_send()) elif isinstance(message, Mapping): # Catch a common mistake -- passing a dict to send(). @@ -713,7 +734,7 @@ async def ping(self, data: Optional[Data] = None) -> asyncio.Future: (which will be encoded to UTF-8) or a bytes-like object. """ async with self.conn_mutex: - if self.connection.state in (CLOSED, CLOSING): + if self.ws_proto.state in (CLOSED, CLOSING): raise WebsocketClosed( "Cannot send a ping when the websocket interface " "is closed." @@ -741,8 +762,8 @@ async def ping(self, data: Optional[Data] = None) -> asyncio.Future: self.pings[data] = self.io_proto.loop.create_future() - self.connection.send_ping(data) - await self.send_data(self.connection.data_to_send()) + self.ws_proto.send_ping(data) + await self.send_data(self.ws_proto.data_to_send()) return asyncio.shield(self.pings[data]) @@ -754,15 +775,15 @@ async def pong(self, data: Data = b"") -> None: be a string (which will be encoded to UTF-8) or a bytes-like object. """ async with self.conn_mutex: - if self.connection.state in (CLOSED, CLOSING): + if self.ws_proto.state in (CLOSED, CLOSING): # Cannot send pong after transport is shutting down return if isinstance(data, str): data = data.encode("utf-8") elif isinstance(data, (bytearray, memoryview)): data = bytes(data) - self.connection.send_pong(data) - await self.send_data(self.connection.data_to_send()) + self.ws_proto.send_pong(data) + await self.send_data(self.ws_proto.data_to_send()) async def send_data(self, data_to_send): for data in data_to_send: @@ -784,7 +805,7 @@ async def send_data(self, data_to_send): SanicProtocol.close(self.io_proto, timeout=1.0) async def async_data_received(self, data_to_send, events_to_process): - if self.connection.state in (OPEN, CLOSING) and len(data_to_send) > 0: + if self.ws_proto.state in (OPEN, CLOSING) and len(data_to_send) > 0: # receiving data can generate data to send (eg, pong for a ping) # send connection.data_to_send() await self.send_data(data_to_send) @@ -792,9 +813,9 @@ async def async_data_received(self, data_to_send, events_to_process): await self.process_events(events_to_process) def data_received(self, data): - self.connection.receive_data(data) - data_to_send = self.connection.data_to_send() - events_to_process = self.connection.events_received() + self.ws_proto.receive_data(data) + data_to_send = self.ws_proto.data_to_send() + events_to_process = self.ws_proto.events_received() if len(data_to_send) > 0 or len(events_to_process) > 0: asyncio.create_task( self.async_data_received(data_to_send, events_to_process) @@ -803,7 +824,7 @@ def data_received(self, data): async def async_eof_received(self, data_to_send, events_to_process): # receiving EOF can generate data to send # send connection.data_to_send() - if self.connection.state in (OPEN, CLOSING) and len(data_to_send) > 0: + if self.ws_proto.state in (OPEN, CLOSING) and len(data_to_send) > 0: await self.send_data(data_to_send) if len(events_to_process) > 0: await self.process_events(events_to_process) @@ -823,9 +844,9 @@ async def async_eof_received(self, data_to_send, events_to_process): SanicProtocol.close(self.io_proto, timeout=1.0) def eof_received(self) -> Optional[bool]: - self.connection.receive_eof() - data_to_send = self.connection.data_to_send() - events_to_process = self.connection.events_received() + self.ws_proto.receive_eof() + data_to_send = self.ws_proto.data_to_send() + events_to_process = self.ws_proto.events_received() asyncio.create_task( self.async_eof_received(data_to_send, events_to_process) ) @@ -835,11 +856,11 @@ def connection_lost(self, exc): """ The WebSocket Connection is Closed. """ - if not self.connection.state == CLOSED: + if not self.ws_proto.state == CLOSED: # signal to the websocket connection handler # we've lost the connection - self.connection.fail(code=1006) - self.connection.state = CLOSED + self.ws_proto.fail(code=1006) + self.ws_proto.state = CLOSED self.abort_pings() if self.connection_lost_waiter: