From ed8292de1d17a335839b14cb51c9dd46ddfa577f Mon Sep 17 00:00:00 2001 From: Kazuhiro Sera Date: Thu, 9 Sep 2021 23:35:16 +0900 Subject: [PATCH] Fix #1110 Socket Mode disconnection issue with the aiohttp-based client --- .../samples/socket_mode/aiohttp_example.py | 1 + slack_sdk/socket_mode/aiohttp/__init__.py | 188 ++++++++++++------ slack_sdk/socket_mode/async_client.py | 5 + slack_sdk/socket_mode/websockets/__init__.py | 5 + 4 files changed, 136 insertions(+), 63 deletions(-) diff --git a/integration_tests/samples/socket_mode/aiohttp_example.py b/integration_tests/samples/socket_mode/aiohttp_example.py index eaa7d152b..3c16bad2a 100644 --- a/integration_tests/samples/socket_mode/aiohttp_example.py +++ b/integration_tests/samples/socket_mode/aiohttp_example.py @@ -17,6 +17,7 @@ async def main(): web_client=AsyncWebClient( token=os.environ.get("SLACK_SDK_TEST_SOCKET_MODE_BOT_TOKEN") ), + trace_enabled=True, ) async def process(client: SocketModeClient, req: SocketModeRequest): diff --git a/slack_sdk/socket_mode/aiohttp/__init__.py b/slack_sdk/socket_mode/aiohttp/__init__.py index 36e538f34..75e882f10 100644 --- a/slack_sdk/socket_mode/aiohttp/__init__.py +++ b/slack_sdk/socket_mode/aiohttp/__init__.py @@ -1,4 +1,4 @@ -"""aiohttp bassd Socket Mode client +"""aiohttp based Socket Mode client * https://api.slack.com/apis/connections/socket * https://slack.dev/python-slack-sdk/socket-mode/ @@ -7,7 +7,8 @@ """ import asyncio import logging -from asyncio import Future, Lock +import time +from asyncio import Future, Lock, Task from asyncio import Queue from logging import Logger from typing import Union, Optional, List, Callable, Awaitable @@ -52,12 +53,16 @@ class SocketModeClient(AsyncBaseSocketModeClient): proxy: Optional[str] ping_interval: float + trace_enabled: bool + + last_ping_pong_time: Optional[float] current_session: Optional[ClientWebSocketResponse] current_session_monitor: Optional[Future] auto_reconnect_enabled: bool default_auto_reconnect_enabled: bool closed: bool + stale: bool connect_operation_lock: Lock on_message_listeners: List[Callable[[WSMessage], Awaitable[None]]] @@ -71,7 +76,8 @@ def __init__( web_client: Optional[AsyncWebClient] = None, proxy: Optional[str] = None, auto_reconnect_enabled: bool = True, - ping_interval: float = 10, + ping_interval: float = 5, + trace_enabled: bool = False, on_message_listeners: Optional[List[Callable[[WSMessage], None]]] = None, on_error_listeners: Optional[List[Callable[[WSMessage], None]]] = None, on_close_listeners: Optional[List[Callable[[WSMessage], None]]] = None, @@ -84,6 +90,7 @@ def __init__( web_client: Web API client auto_reconnect_enabled: True if automatic reconnection is enabled (default: True) ping_interval: interval for ping-pong with Slack servers (seconds) + trace_enabled: True if more verbose logs to see what's happening under the hood proxy: the HTTP proxy URL on_message_listeners: listener functions for on_message on_error_listeners: listener functions for on_error @@ -93,6 +100,7 @@ def __init__( self.logger = logger or logging.getLogger(__name__) self.web_client = web_client or AsyncWebClient() self.closed = False + self.stale = False self.connect_operation_lock = Lock() self.proxy = proxy if self.proxy is None or len(self.proxy.strip()) == 0: @@ -103,6 +111,8 @@ def __init__( self.default_auto_reconnect_enabled = auto_reconnect_enabled self.auto_reconnect_enabled = self.default_auto_reconnect_enabled self.ping_interval = ping_interval + self.trace_enabled = trace_enabled + self.last_ping_pong_time = None self.wss_uri = None self.message_queue = Queue() @@ -126,70 +136,116 @@ def __init__( self.message_processor = asyncio.ensure_future(self.process_messages()) async def monitor_current_session(self) -> None: - while not self.closed: - await asyncio.sleep(self.ping_interval) - try: - if self.auto_reconnect_enabled and ( - self.current_session is None or self.current_session.closed - ): - self.logger.info( - "The session seems to be already closed. Going to reconnect..." + try: + while not self.closed: + try: + await asyncio.sleep(self.ping_interval) + if self.current_session is not None: + t = time.time() + if self.last_ping_pong_time is None: + self.last_ping_pong_time = float(t) + await self.current_session.ping(f"ping-pong:{t}") + + if self.auto_reconnect_enabled: + should_reconnect = False + if self.current_session is None or self.current_session.closed: + self.logger.info( + "The session seems to be already closed. Going to reconnect..." + ) + should_reconnect = True + + if self.last_ping_pong_time is not None: + disconnected_seconds = int(time.time() - self.last_ping_pong_time) + if disconnected_seconds >= (self.ping_interval * 4): + self.logger.info( + "The connection seems to be stale. Disconnecting..." + f" reason: disconnected for {disconnected_seconds}+ seconds)" + ) + self.stale = True + self.last_ping_pong_time = None + should_reconnect = True + + if should_reconnect is True or not await self.is_connected(): + await self.connect_to_new_endpoint() + self.logger.info("Reconnection done.") + + except Exception as e: + self.logger.error( + "Failed to check the current session or reconnect to the server " + f"(error: {type(e).__name__}, message: {e})" ) - await self.connect_to_new_endpoint() - except Exception as e: - self.logger.error( - "Failed to check the current session or reconnect to the server " - f"(error: {type(e).__name__}, message: {e})" - ) + except asyncio.CancelledError: + if self.trace_enabled: + self.logger.debug("The running monitor_current_session task is now cancelled") + raise async def receive_messages(self) -> None: - consecutive_error_count = 0 - while not self.closed: - try: - message: WSMessage = await self.current_session.receive() - if self.logger.level <= logging.DEBUG: - type = WSMsgType(message.type) - message_type = type.name if type is not None else message.type - message_data = message.data - if isinstance(message_data, bytes): - message_data = message_data.decode("utf-8") - self.logger.debug( - f"Received message (type: {message_type}, data: {message_data}, extra: {message.extra})" - ) - if message is not None: - if message.type == WSMsgType.TEXT: + try: + consecutive_error_count = 0 + while not self.closed: + try: + message: WSMessage = await self.current_session.receive() + if self.trace_enabled and self.logger.level <= logging.DEBUG: + type = WSMsgType(message.type) + message_type = type.name if type is not None else message.type message_data = message.data - await self.enqueue_message(message_data) - for listener in self.on_message_listeners: - await listener(message) - elif message.type == WSMsgType.CLOSE: - if self.auto_reconnect_enabled: - self.logger.info( - "Received CLOSE event. Going to reconnect..." + if isinstance(message_data, bytes): + message_data = message_data.decode("utf-8") + if len(message_data) > 0: + # To skip the empty message that Slack server-side often sends + self.logger.debug( + f"Received message (type: {message_type}, data: {message_data}, extra: {message.extra})" ) - await self.connect_to_new_endpoint() - for listener in self.on_close_listeners: - await listener(message) - elif message.type == WSMsgType.ERROR: - for listener in self.on_error_listeners: - await listener(message) - elif message.type == WSMsgType.CLOSED: + if message is not None: + if message.type == WSMsgType.TEXT: + message_data = message.data + await self.enqueue_message(message_data) + for listener in self.on_message_listeners: + await listener(message) + elif message.type == WSMsgType.CLOSE: + if self.auto_reconnect_enabled: + self.logger.info( + "Received CLOSE event. Going to reconnect..." + ) + await self.connect_to_new_endpoint() + for listener in self.on_close_listeners: + await listener(message) + elif message.type == WSMsgType.ERROR: + for listener in self.on_error_listeners: + await listener(message) + elif message.type == WSMsgType.CLOSED: + await asyncio.sleep(self.ping_interval) + continue + elif message.type == WSMsgType.PING: + await self.current_session.pong(message.data) + continue + elif message.type == WSMsgType.PONG: + elements = message.data.decode('utf-8').split(":") + if len(elements) == 2: + try: + self.last_ping_pong_time = float(elements[1]) + except: + pass + continue + consecutive_error_count = 0 + except Exception as e: + consecutive_error_count += 1 + self.logger.error( + f"Failed to receive or enqueue a message: {type(e).__name__}, {e}" + ) + if isinstance(e, ClientConnectionError): await asyncio.sleep(self.ping_interval) - continue - consecutive_error_count = 0 - except Exception as e: - consecutive_error_count += 1 - self.logger.error( - f"Failed to receive or enqueue a message: {type(e).__name__}, {e}" - ) - if isinstance(e, ClientConnectionError): - await asyncio.sleep(self.ping_interval) - else: - await asyncio.sleep(consecutive_error_count) + else: + await asyncio.sleep(consecutive_error_count) + except asyncio.CancelledError: + if self.trace_enabled: + self.logger.debug("The running receive_messages task is now cancelled") + raise async def is_connected(self) -> bool: return ( not self.closed + and not self.stale and self.current_session is not None and not self.current_session.closed ) @@ -200,19 +256,25 @@ async def connect(self): self.wss_uri = await self.issue_new_wss_url() self.current_session = await self.aiohttp_client_session.ws_connect( self.wss_uri, + autoping=False, heartbeat=self.ping_interval, proxy=self.proxy, ) self.auto_reconnect_enabled = self.default_auto_reconnect_enabled + self.stale = False self.logger.info("A new session has been established") - if self.current_session_monitor is None: - self.current_session_monitor = asyncio.ensure_future( - self.monitor_current_session() - ) + if self.current_session_monitor is not None: + self.current_session_monitor.cancel() + + self.current_session_monitor = asyncio.ensure_future( + self.monitor_current_session() + ) + + if self.message_receiver is not None: + self.message_receiver.cancel() - if self.message_receiver is None: - self.message_receiver = asyncio.ensure_future(self.receive_messages()) + self.message_receiver = asyncio.ensure_future(self.receive_messages()) if old_session is not None: await old_session.close() diff --git a/slack_sdk/socket_mode/async_client.py b/slack_sdk/socket_mode/async_client.py index ac58dc220..7e9f9b6f4 100644 --- a/slack_sdk/socket_mode/async_client.py +++ b/slack_sdk/socket_mode/async_client.py @@ -22,6 +22,7 @@ class AsyncBaseSocketModeClient: app_token: str wss_uri: str auto_reconnect_enabled: bool + trace_enabled: bool closed: bool connect_operation_lock: Lock @@ -72,12 +73,16 @@ async def disconnect(self): async def connect_to_new_endpoint(self, force: bool = False): try: await self.connect_operation_lock.acquire() + if self.trace_enabled: + self.logger.debug("For reconnection, the connect_operation_lock was acquired") if force or not await self.is_connected(): self.wss_uri = await self.issue_new_wss_url() await self.connect() finally: if self.connect_operation_lock.locked() is True: self.connect_operation_lock.release() + if self.trace_enabled: + self.logger.debug("The connect_operation_lock for reconnection was released") async def close(self): self.closed = True diff --git a/slack_sdk/socket_mode/websockets/__init__.py b/slack_sdk/socket_mode/websockets/__init__.py index 205c952d7..f6279c0ae 100644 --- a/slack_sdk/socket_mode/websockets/__init__.py +++ b/slack_sdk/socket_mode/websockets/__init__.py @@ -50,6 +50,8 @@ class SocketModeClient(AsyncBaseSocketModeClient): message_processor: Future ping_interval: float + trace_enabled: bool + current_session: Optional[WebSocketClientProtocol] current_session_monitor: Optional[Future] @@ -65,6 +67,7 @@ def __init__( web_client: Optional[AsyncWebClient] = None, auto_reconnect_enabled: bool = True, ping_interval: float = 10, + trace_enabled: bool = False, ): """Socket Mode client @@ -74,6 +77,7 @@ def __init__( web_client: Web API client auto_reconnect_enabled: True if automatic reconnection is enabled (default: True) ping_interval: interval for ping-pong with Slack servers (seconds) + trace_enabled: True if more verbose logs to see what's happening under the hood """ self.app_token = app_token self.logger = logger or logging.getLogger(__name__) @@ -83,6 +87,7 @@ def __init__( self.default_auto_reconnect_enabled = auto_reconnect_enabled self.auto_reconnect_enabled = self.default_auto_reconnect_enabled self.ping_interval = ping_interval + self.trace_enabled = trace_enabled self.wss_uri = None self.message_queue = Queue() self.message_listeners = []