From 1c066e2c91768e2e85b1c36d8a7e3290369f797e Mon Sep 17 00:00:00 2001 From: Bar Shaul Date: Sun, 20 Nov 2022 09:52:14 +0200 Subject: [PATCH] Added a replacement for the default cluster node in the event of failure. Handles failovers better. --- CHANGES | 1 + redis/asyncio/cluster.py | 37 +++++++++++++----- redis/cluster.py | 61 ++++++++++++++++++++++++------ tests/test_asyncio/test_cluster.py | 20 ++++++++++ tests/test_cluster.py | 23 +++++++++++ 5 files changed, 121 insertions(+), 21 deletions(-) diff --git a/CHANGES b/CHANGES index 883c548f38..120af7c6d2 100644 --- a/CHANGES +++ b/CHANGES @@ -28,6 +28,7 @@ * Fixed "cannot pickle '_thread.lock' object" bug (#2354, #2297) * Added CredentialsProvider class to support password rotation * Enable Lock for asyncio cluster mode + * Added a replacement for the default cluster node in the event of failure (#2463) * 4.1.3 (Feb 8, 2022) * Fix flushdb and flushall (#1926) diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py index d5a38b2878..dd06dc02f0 100644 --- a/redis/asyncio/cluster.py +++ b/redis/asyncio/cluster.py @@ -222,7 +222,7 @@ def __init__( reinitialize_steps: int = 5, cluster_error_retry_attempts: int = 3, connection_error_retry_attempts: int = 3, - max_connections: int = 2**31, + max_connections: int = 2 ** 31, # Client related kwargs db: Union[str, int] = 0, path: Optional[str] = None, @@ -516,7 +516,14 @@ def set_response_callback(self, command: str, callback: ResponseCallbackT) -> No async def _determine_nodes( self, command: str, *args: Any, node_flag: Optional[str] = None - ) -> List["ClusterNode"]: + ) -> tuple[list["ClusterNode"], bool]: + """Determine which nodes should be executed the command on + + Returns: + tuple[list[Type[ClusterNode]], bool]: + A tuple containing a list of target nodes and a bool indicating + if the return node was chosen because it is the default node + """ if not node_flag: # get the nodes group for this command if it was predefined node_flag = self.command_flags.get(command) @@ -524,19 +531,21 @@ async def _determine_nodes( if node_flag in self.node_flags: if node_flag == self.__class__.DEFAULT_NODE: # return the cluster's default node - return [self.nodes_manager.default_node] + return [self.nodes_manager.default_node], True if node_flag == self.__class__.PRIMARIES: # return all primaries - return self.nodes_manager.get_nodes_by_server_type(PRIMARY) + return self.nodes_manager.get_nodes_by_server_type(PRIMARY), False if node_flag == self.__class__.REPLICAS: # return all replicas - return self.nodes_manager.get_nodes_by_server_type(REPLICA) + return self.nodes_manager.get_nodes_by_server_type(REPLICA), False if node_flag == self.__class__.ALL_NODES: # return all nodes - return list(self.nodes_manager.nodes_cache.values()) + return list(self.nodes_manager.nodes_cache.values()), False if node_flag == self.__class__.RANDOM: # return a random node - return [random.choice(list(self.nodes_manager.nodes_cache.values()))] + return [ + random.choice(list(self.nodes_manager.nodes_cache.values())) + ], False # get the node that holds the key's slot return [ @@ -544,7 +553,7 @@ async def _determine_nodes( await self._determine_slot(command, *args), self.read_from_replicas and command in READ_COMMANDS, ) - ] + ], False async def _determine_slot(self, command: str, *args: Any) -> int: if self.command_flags.get(command) == SLOT_ID: @@ -641,6 +650,7 @@ async def execute_command(self, *args: EncodableT, **kwargs: Any) -> Any: command = args[0] target_nodes = [] target_nodes_specified = False + is_default_node = False retry_attempts = self.cluster_error_retry_attempts passed_targets = kwargs.pop("target_nodes", None) @@ -654,10 +664,13 @@ async def execute_command(self, *args: EncodableT, **kwargs: Any) -> Any: for _ in range(execute_attempts): if self._initialize: await self.initialize() + if is_default_node: + # Replace the default cluster node + self.replace_default_node() try: if not target_nodes_specified: # Determine the nodes to execute the command on - target_nodes = await self._determine_nodes( + target_nodes, is_default_node = await self._determine_nodes( *args, node_flag=passed_targets ) if not target_nodes: @@ -1436,12 +1449,13 @@ async def _execute( ] nodes = {} + is_default_node = False for cmd in todo: passed_targets = cmd.kwargs.pop("target_nodes", None) if passed_targets and not client._is_node_flag(passed_targets): target_nodes = client._parse_target_nodes(passed_targets) else: - target_nodes = await client._determine_nodes( + target_nodes, is_default_node = await client._determine_nodes( *cmd.args, node_flag=passed_targets ) if not target_nodes: @@ -1487,6 +1501,9 @@ async def _execute( result.args = (msg,) + result.args[1:] raise result + if is_default_node: + self.replace_default_node() + return [cmd.result for cmd in stack] def _split_command_across_slots( diff --git a/redis/cluster.py b/redis/cluster.py index 91deaead59..56c354b814 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -379,6 +379,30 @@ class AbstractRedisCluster: ERRORS_ALLOW_RETRY = (ConnectionError, TimeoutError, ClusterDownError) + def replace_default_node(self, target_node: "ClusterNode" = None) -> None: + """Replace the default cluster node. + A random cluster node will be chosen if target_node isn't passed, and primaries + will be prioritized. The default node will not be changed if there are no other + nodes in the cluster. + + Args: + target_node (ClusterNode, optional): Target node to replace the default + node. Defaults to None. + """ + if target_node: + self.nodes_manager.default_node = target_node + else: + curr_node = self.get_default_node() + primaries = [node for node in self.get_primaries() if node != curr_node] + if primaries: + # Choose a primary if the cluster contains different primaries + self.nodes_manager.default_node = random.choice(primaries) + else: + # Otherwise, hoose a primary if the cluster contains different primaries + replicas = [node for node in self.get_replicas() if node != curr_node] + if replicas: + self.nodes_manager.default_node = random.choice(replicas) + class RedisCluster(AbstractRedisCluster, RedisClusterCommands): @classmethod @@ -811,7 +835,14 @@ def set_response_callback(self, command, callback): """Set a custom Response Callback""" self.cluster_response_callbacks[command] = callback - def _determine_nodes(self, *args, **kwargs): + def _determine_nodes(self, *args, **kwargs) -> tuple[list["ClusterNode"], bool]: + """Determine which nodes should be executed the command on + + Returns: + tuple[list[Type[ClusterNode]], bool]: + A tuple containing a list of target nodes and a bool indicating + if the return node was chosen because it is the default node + """ command = args[0].upper() if len(args) >= 2 and f"{args[0]} {args[1]}".upper() in self.command_flags: command = f"{args[0]} {args[1]}".upper() @@ -825,28 +856,28 @@ def _determine_nodes(self, *args, **kwargs): command_flag = self.command_flags.get(command) if command_flag == self.__class__.RANDOM: # return a random node - return [self.get_random_node()] + return [self.get_random_node()], False elif command_flag == self.__class__.PRIMARIES: # return all primaries - return self.get_primaries() + return self.get_primaries(), False elif command_flag == self.__class__.REPLICAS: # return all replicas - return self.get_replicas() + return self.get_replicas(), False elif command_flag == self.__class__.ALL_NODES: # return all nodes - return self.get_nodes() + return self.get_nodes(), False elif command_flag == self.__class__.DEFAULT_NODE: # return the cluster's default node - return [self.nodes_manager.default_node] + return [self.nodes_manager.default_node], True elif command in self.__class__.SEARCH_COMMANDS[0]: - return [self.nodes_manager.default_node] + return [self.nodes_manager.default_node], True else: # get the node that holds the key's slot slot = self.determine_slot(*args) node = self.nodes_manager.get_node_from_slot( slot, self.read_from_replicas and command in READ_COMMANDS ) - return [node] + return [node], False def _should_reinitialized(self): # To reinitialize the cluster on every MOVED error, @@ -990,6 +1021,7 @@ def execute_command(self, *args, **kwargs): dict """ target_nodes_specified = False + is_default_node = False target_nodes = None passed_targets = kwargs.pop("target_nodes", None) if passed_targets is not None and not self._is_nodes_flag(passed_targets): @@ -1013,7 +1045,7 @@ def execute_command(self, *args, **kwargs): res = {} if not target_nodes_specified: # Determine the nodes to execute the command on - target_nodes = self._determine_nodes( + target_nodes, is_default_node = self._determine_nodes( *args, **kwargs, nodes_flag=passed_targets ) if not target_nodes: @@ -1025,6 +1057,9 @@ def execute_command(self, *args, **kwargs): # Return the processed result return self._process_result(args[0], res, **kwargs) except Exception as e: + if is_default_node: + # Replace the default cluster node + self.replace_default_node() if retry_attempts > 0 and type(e) in self.__class__.ERRORS_ALLOW_RETRY: # The nodes and slots cache were reinitialized. # Try again with the new cluster setup. @@ -1883,7 +1918,7 @@ def _send_cluster_commands( # if we have to run through it again, we only retry # the commands that failed. attempt = sorted(stack, key=lambda x: x.position) - + is_default_node = False # build a list of node objects based on node names we need to nodes = {} @@ -1900,7 +1935,7 @@ def _send_cluster_commands( if passed_targets and not self._is_nodes_flag(passed_targets): target_nodes = self._parse_target_nodes(passed_targets) else: - target_nodes = self._determine_nodes( + target_nodes, is_default_node = self._determine_nodes( *c.args, node_flag=passed_targets ) if not target_nodes: @@ -1926,6 +1961,8 @@ def _send_cluster_commands( # Connection retries are being handled in the node's # Retry object. Reinitialize the node -> slot table. self.nodes_manager.initialize() + if is_default_node: + self.replace_default_node() raise nodes[node_name] = NodeCommands( redis_node.parse_response, @@ -2007,6 +2044,8 @@ def _send_cluster_commands( self.reinitialize_counter += 1 if self._should_reinitialized(): self.nodes_manager.initialize() + if is_default_node: + self.replace_default_node() for c in attempt: try: # send each command individually like we diff --git a/tests/test_asyncio/test_cluster.py b/tests/test_asyncio/test_cluster.py index 38bcaf6c00..f2d29d3d41 100644 --- a/tests/test_asyncio/test_cluster.py +++ b/tests/test_asyncio/test_cluster.py @@ -788,6 +788,26 @@ async def test_can_run_concurrent_commands(self, request: FixtureRequest) -> Non ) await rc.close() + def test_replace_cluster_node(self, r: RedisCluster) -> None: + prev_default_node = r.get_default_node() + r.replace_default_node() + assert r.get_default_node() != prev_default_node + r.replace_default_node(prev_default_node) + assert r.get_default_node() == prev_default_node + + async def test_default_node_is_replaced_after_exception(self, r): + curr_default_node = r.get_default_node() + # CLUSTER NODES command is being executed on the default node + nodes = await r.cluster_nodes() + assert "myself" in nodes.get(curr_default_node.name).get("flags") + + # Mock connection error for the default node + mock_node_resp_exc(curr_default_node, ConnectionError("error")) + # Test that the command succeed from a different node + nodes = await r.cluster_nodes() + assert "myself" not in nodes.get(curr_default_node.name).get("flags") + assert r.get_default_node() != curr_default_node + class TestClusterRedisCommands: """ diff --git a/tests/test_cluster.py b/tests/test_cluster.py index d18fbbbb33..43aeb9e045 100644 --- a/tests/test_cluster.py +++ b/tests/test_cluster.py @@ -791,6 +791,29 @@ def test_cluster_retry_object(self, r) -> None: == retry._retries ) + def test_replace_cluster_node(self, r) -> None: + prev_default_node = r.get_default_node() + r.replace_default_node() + assert r.get_default_node() != prev_default_node + r.replace_default_node(prev_default_node) + assert r.get_default_node() == prev_default_node + + def test_default_node_is_replaced_after_exception(self, r): + curr_default_node = r.get_default_node() + # CLUSTER NODES command is being executed on the default node + nodes = r.cluster_nodes() + assert "myself" in nodes.get(curr_default_node.name).get("flags") + + def raise_connection_error(): + raise ConnectionError("error") + + # Mock connection error for the default node + mock_node_resp_func(curr_default_node, raise_connection_error) + # Test that the command succeed from a different node + nodes = r.cluster_nodes() + assert "myself" not in nodes.get(curr_default_node.name).get("flags") + assert r.get_default_node() != curr_default_node + @pytest.mark.onlycluster class TestClusterRedisCommands: