Skip to content

Commit

Permalink
Fix slackapi#1110 Socket Mode disconnection issue with the aiohttp-ba…
Browse files Browse the repository at this point in the history
…sed client
  • Loading branch information
seratch committed Sep 9, 2021
1 parent 9f3240d commit ed8292d
Show file tree
Hide file tree
Showing 4 changed files with 136 additions and 63 deletions.
1 change: 1 addition & 0 deletions integration_tests/samples/socket_mode/aiohttp_example.py
Expand Up @@ -17,6 +17,7 @@ async def main():
web_client=AsyncWebClient(
token=os.environ.get("SLACK_SDK_TEST_SOCKET_MODE_BOT_TOKEN")
),
trace_enabled=True,
)

async def process(client: SocketModeClient, req: SocketModeRequest):
Expand Down
188 changes: 125 additions & 63 deletions slack_sdk/socket_mode/aiohttp/__init__.py
@@ -1,4 +1,4 @@
"""aiohttp bassd Socket Mode client
"""aiohttp based Socket Mode client
* https://api.slack.com/apis/connections/socket
* https://slack.dev/python-slack-sdk/socket-mode/
Expand All @@ -7,7 +7,8 @@
"""
import asyncio
import logging
from asyncio import Future, Lock
import time
from asyncio import Future, Lock, Task
from asyncio import Queue
from logging import Logger
from typing import Union, Optional, List, Callable, Awaitable
Expand Down Expand Up @@ -52,12 +53,16 @@ class SocketModeClient(AsyncBaseSocketModeClient):

proxy: Optional[str]
ping_interval: float
trace_enabled: bool

last_ping_pong_time: Optional[float]
current_session: Optional[ClientWebSocketResponse]
current_session_monitor: Optional[Future]

auto_reconnect_enabled: bool
default_auto_reconnect_enabled: bool
closed: bool
stale: bool
connect_operation_lock: Lock

on_message_listeners: List[Callable[[WSMessage], Awaitable[None]]]
Expand All @@ -71,7 +76,8 @@ def __init__(
web_client: Optional[AsyncWebClient] = None,
proxy: Optional[str] = None,
auto_reconnect_enabled: bool = True,
ping_interval: float = 10,
ping_interval: float = 5,
trace_enabled: bool = False,
on_message_listeners: Optional[List[Callable[[WSMessage], None]]] = None,
on_error_listeners: Optional[List[Callable[[WSMessage], None]]] = None,
on_close_listeners: Optional[List[Callable[[WSMessage], None]]] = None,
Expand All @@ -84,6 +90,7 @@ def __init__(
web_client: Web API client
auto_reconnect_enabled: True if automatic reconnection is enabled (default: True)
ping_interval: interval for ping-pong with Slack servers (seconds)
trace_enabled: True if more verbose logs to see what's happening under the hood
proxy: the HTTP proxy URL
on_message_listeners: listener functions for on_message
on_error_listeners: listener functions for on_error
Expand All @@ -93,6 +100,7 @@ def __init__(
self.logger = logger or logging.getLogger(__name__)
self.web_client = web_client or AsyncWebClient()
self.closed = False
self.stale = False
self.connect_operation_lock = Lock()
self.proxy = proxy
if self.proxy is None or len(self.proxy.strip()) == 0:
Expand All @@ -103,6 +111,8 @@ def __init__(
self.default_auto_reconnect_enabled = auto_reconnect_enabled
self.auto_reconnect_enabled = self.default_auto_reconnect_enabled
self.ping_interval = ping_interval
self.trace_enabled = trace_enabled
self.last_ping_pong_time = None

self.wss_uri = None
self.message_queue = Queue()
Expand All @@ -126,70 +136,116 @@ def __init__(
self.message_processor = asyncio.ensure_future(self.process_messages())

async def monitor_current_session(self) -> None:
while not self.closed:
await asyncio.sleep(self.ping_interval)
try:
if self.auto_reconnect_enabled and (
self.current_session is None or self.current_session.closed
):
self.logger.info(
"The session seems to be already closed. Going to reconnect..."
try:
while not self.closed:
try:
await asyncio.sleep(self.ping_interval)
if self.current_session is not None:
t = time.time()
if self.last_ping_pong_time is None:
self.last_ping_pong_time = float(t)
await self.current_session.ping(f"ping-pong:{t}")

if self.auto_reconnect_enabled:
should_reconnect = False
if self.current_session is None or self.current_session.closed:
self.logger.info(
"The session seems to be already closed. Going to reconnect..."
)
should_reconnect = True

if self.last_ping_pong_time is not None:
disconnected_seconds = int(time.time() - self.last_ping_pong_time)
if disconnected_seconds >= (self.ping_interval * 4):
self.logger.info(
"The connection seems to be stale. Disconnecting..."
f" reason: disconnected for {disconnected_seconds}+ seconds)"
)
self.stale = True
self.last_ping_pong_time = None
should_reconnect = True

if should_reconnect is True or not await self.is_connected():
await self.connect_to_new_endpoint()
self.logger.info("Reconnection done.")

except Exception as e:
self.logger.error(
"Failed to check the current session or reconnect to the server "
f"(error: {type(e).__name__}, message: {e})"
)
await self.connect_to_new_endpoint()
except Exception as e:
self.logger.error(
"Failed to check the current session or reconnect to the server "
f"(error: {type(e).__name__}, message: {e})"
)
except asyncio.CancelledError:
if self.trace_enabled:
self.logger.debug("The running monitor_current_session task is now cancelled")
raise

async def receive_messages(self) -> None:
consecutive_error_count = 0
while not self.closed:
try:
message: WSMessage = await self.current_session.receive()
if self.logger.level <= logging.DEBUG:
type = WSMsgType(message.type)
message_type = type.name if type is not None else message.type
message_data = message.data
if isinstance(message_data, bytes):
message_data = message_data.decode("utf-8")
self.logger.debug(
f"Received message (type: {message_type}, data: {message_data}, extra: {message.extra})"
)
if message is not None:
if message.type == WSMsgType.TEXT:
try:
consecutive_error_count = 0
while not self.closed:
try:
message: WSMessage = await self.current_session.receive()
if self.trace_enabled and self.logger.level <= logging.DEBUG:
type = WSMsgType(message.type)
message_type = type.name if type is not None else message.type
message_data = message.data
await self.enqueue_message(message_data)
for listener in self.on_message_listeners:
await listener(message)
elif message.type == WSMsgType.CLOSE:
if self.auto_reconnect_enabled:
self.logger.info(
"Received CLOSE event. Going to reconnect..."
if isinstance(message_data, bytes):
message_data = message_data.decode("utf-8")
if len(message_data) > 0:
# To skip the empty message that Slack server-side often sends
self.logger.debug(
f"Received message (type: {message_type}, data: {message_data}, extra: {message.extra})"
)
await self.connect_to_new_endpoint()
for listener in self.on_close_listeners:
await listener(message)
elif message.type == WSMsgType.ERROR:
for listener in self.on_error_listeners:
await listener(message)
elif message.type == WSMsgType.CLOSED:
if message is not None:
if message.type == WSMsgType.TEXT:
message_data = message.data
await self.enqueue_message(message_data)
for listener in self.on_message_listeners:
await listener(message)
elif message.type == WSMsgType.CLOSE:
if self.auto_reconnect_enabled:
self.logger.info(
"Received CLOSE event. Going to reconnect..."
)
await self.connect_to_new_endpoint()
for listener in self.on_close_listeners:
await listener(message)
elif message.type == WSMsgType.ERROR:
for listener in self.on_error_listeners:
await listener(message)
elif message.type == WSMsgType.CLOSED:
await asyncio.sleep(self.ping_interval)
continue
elif message.type == WSMsgType.PING:
await self.current_session.pong(message.data)
continue
elif message.type == WSMsgType.PONG:
elements = message.data.decode('utf-8').split(":")
if len(elements) == 2:
try:
self.last_ping_pong_time = float(elements[1])
except:
pass
continue
consecutive_error_count = 0
except Exception as e:
consecutive_error_count += 1
self.logger.error(
f"Failed to receive or enqueue a message: {type(e).__name__}, {e}"
)
if isinstance(e, ClientConnectionError):
await asyncio.sleep(self.ping_interval)
continue
consecutive_error_count = 0
except Exception as e:
consecutive_error_count += 1
self.logger.error(
f"Failed to receive or enqueue a message: {type(e).__name__}, {e}"
)
if isinstance(e, ClientConnectionError):
await asyncio.sleep(self.ping_interval)
else:
await asyncio.sleep(consecutive_error_count)
else:
await asyncio.sleep(consecutive_error_count)
except asyncio.CancelledError:
if self.trace_enabled:
self.logger.debug("The running receive_messages task is now cancelled")
raise

async def is_connected(self) -> bool:
return (
not self.closed
and not self.stale
and self.current_session is not None
and not self.current_session.closed
)
Expand All @@ -200,19 +256,25 @@ async def connect(self):
self.wss_uri = await self.issue_new_wss_url()
self.current_session = await self.aiohttp_client_session.ws_connect(
self.wss_uri,
autoping=False,
heartbeat=self.ping_interval,
proxy=self.proxy,
)
self.auto_reconnect_enabled = self.default_auto_reconnect_enabled
self.stale = False
self.logger.info("A new session has been established")

if self.current_session_monitor is None:
self.current_session_monitor = asyncio.ensure_future(
self.monitor_current_session()
)
if self.current_session_monitor is not None:
self.current_session_monitor.cancel()

self.current_session_monitor = asyncio.ensure_future(
self.monitor_current_session()
)

if self.message_receiver is not None:
self.message_receiver.cancel()

if self.message_receiver is None:
self.message_receiver = asyncio.ensure_future(self.receive_messages())
self.message_receiver = asyncio.ensure_future(self.receive_messages())

if old_session is not None:
await old_session.close()
Expand Down
5 changes: 5 additions & 0 deletions slack_sdk/socket_mode/async_client.py
Expand Up @@ -22,6 +22,7 @@ class AsyncBaseSocketModeClient:
app_token: str
wss_uri: str
auto_reconnect_enabled: bool
trace_enabled: bool
closed: bool
connect_operation_lock: Lock

Expand Down Expand Up @@ -72,12 +73,16 @@ async def disconnect(self):
async def connect_to_new_endpoint(self, force: bool = False):
try:
await self.connect_operation_lock.acquire()
if self.trace_enabled:
self.logger.debug("For reconnection, the connect_operation_lock was acquired")
if force or not await self.is_connected():
self.wss_uri = await self.issue_new_wss_url()
await self.connect()
finally:
if self.connect_operation_lock.locked() is True:
self.connect_operation_lock.release()
if self.trace_enabled:
self.logger.debug("The connect_operation_lock for reconnection was released")

async def close(self):
self.closed = True
Expand Down
5 changes: 5 additions & 0 deletions slack_sdk/socket_mode/websockets/__init__.py
Expand Up @@ -50,6 +50,8 @@ class SocketModeClient(AsyncBaseSocketModeClient):
message_processor: Future

ping_interval: float
trace_enabled: bool

current_session: Optional[WebSocketClientProtocol]
current_session_monitor: Optional[Future]

Expand All @@ -65,6 +67,7 @@ def __init__(
web_client: Optional[AsyncWebClient] = None,
auto_reconnect_enabled: bool = True,
ping_interval: float = 10,
trace_enabled: bool = False,
):
"""Socket Mode client
Expand All @@ -74,6 +77,7 @@ def __init__(
web_client: Web API client
auto_reconnect_enabled: True if automatic reconnection is enabled (default: True)
ping_interval: interval for ping-pong with Slack servers (seconds)
trace_enabled: True if more verbose logs to see what's happening under the hood
"""
self.app_token = app_token
self.logger = logger or logging.getLogger(__name__)
Expand All @@ -83,6 +87,7 @@ def __init__(
self.default_auto_reconnect_enabled = auto_reconnect_enabled
self.auto_reconnect_enabled = self.default_auto_reconnect_enabled
self.ping_interval = ping_interval
self.trace_enabled = trace_enabled
self.wss_uri = None
self.message_queue = Queue()
self.message_listeners = []
Expand Down

0 comments on commit ed8292d

Please sign in to comment.