diff --git a/sanic/server/websockets/connection.py b/sanic/server/websockets/connection.py index c53a65a58d..87881b84de 100644 --- a/sanic/server/websockets/connection.py +++ b/sanic/server/websockets/connection.py @@ -9,8 +9,10 @@ Union, ) +from sanic.exceptions import InvalidUsage -ASIMessage = MutableMapping[str, Any] + +ASGIMessage = MutableMapping[str, Any] class WebSocketConnection: @@ -25,8 +27,8 @@ class WebSocketConnection: def __init__( self, - send: Callable[[ASIMessage], Awaitable[None]], - receive: Callable[[], Awaitable[ASIMessage]], + send: Callable[[ASGIMessage], Awaitable[None]], + receive: Callable[[], Awaitable[ASGIMessage]], subprotocols: Optional[List[str]] = None, ) -> None: self._send = send @@ -47,7 +49,13 @@ async def recv(self, *args, **kwargs) -> Optional[str]: message = await self._receive() if message["type"] == "websocket.receive": - return message["text"] + try: + return message["text"] + except KeyError: + try: + return message["bytes"].decode() + except KeyError: + raise InvalidUsage("Bad ASGI message received") elif message["type"] == "websocket.disconnect": pass