diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py index 1630fb741c..b21cb66153 100644 --- a/redis/asyncio/cluster.py +++ b/redis/asyncio/cluster.py @@ -11,7 +11,6 @@ List, Mapping, Optional, - Tuple, Type, TypeVar, Union, @@ -517,14 +516,9 @@ def set_response_callback(self, command: str, callback: ResponseCallbackT) -> No async def _determine_nodes( self, command: str, *args: Any, node_flag: Optional[str] = None - ) -> Tuple[List["ClusterNode"], bool]: - """Determine which nodes should be executed the command on - - Returns: - tuple[list[Type[ClusterNode]], bool]: - A tuple containing a list of target nodes and a bool indicating - if the return node was chosen because it is the default node - """ + ) -> List["ClusterNode"]: + # Determine which nodes should be executed the command on. + # Returns a list of target nodes. if not node_flag: # get the nodes group for this command if it was predefined node_flag = self.command_flags.get(command) @@ -532,21 +526,19 @@ async def _determine_nodes( if node_flag in self.node_flags: if node_flag == self.__class__.DEFAULT_NODE: # return the cluster's default node - return [self.nodes_manager.default_node], True + return [self.nodes_manager.default_node] if node_flag == self.__class__.PRIMARIES: # return all primaries - return self.nodes_manager.get_nodes_by_server_type(PRIMARY), False + return self.nodes_manager.get_nodes_by_server_type(PRIMARY) if node_flag == self.__class__.REPLICAS: # return all replicas - return self.nodes_manager.get_nodes_by_server_type(REPLICA), False + return self.nodes_manager.get_nodes_by_server_type(REPLICA) if node_flag == self.__class__.ALL_NODES: # return all nodes - return list(self.nodes_manager.nodes_cache.values()), False + return list(self.nodes_manager.nodes_cache.values()) if node_flag == self.__class__.RANDOM: # return a random node - return [ - random.choice(list(self.nodes_manager.nodes_cache.values())) - ], False + return [random.choice(list(self.nodes_manager.nodes_cache.values()))] # get the node that holds the key's slot return [ @@ -554,7 +546,7 @@ async def _determine_nodes( await self._determine_slot(command, *args), self.read_from_replicas and command in READ_COMMANDS, ) - ], False + ] async def _determine_slot(self, command: str, *args: Any) -> int: if self.command_flags.get(command) == SLOT_ID: @@ -671,13 +663,18 @@ async def execute_command(self, *args: EncodableT, **kwargs: Any) -> Any: try: if not target_nodes_specified: # Determine the nodes to execute the command on - target_nodes, is_default_node = await self._determine_nodes( + target_nodes = await self._determine_nodes( *args, node_flag=passed_targets ) if not target_nodes: raise RedisClusterException( f"No targets were found to execute {args} command on" ) + if ( + len(target_nodes) == 1 + and target_nodes[0] == self.get_default_node() + ): + is_default_node = True if len(target_nodes) == 1: # Return the processed result @@ -1456,7 +1453,7 @@ async def _execute( if passed_targets and not client._is_node_flag(passed_targets): target_nodes = client._parse_target_nodes(passed_targets) else: - target_nodes, is_default_node = await client._determine_nodes( + target_nodes = await client._determine_nodes( *cmd.args, node_flag=passed_targets ) if not target_nodes: @@ -1465,8 +1462,9 @@ async def _execute( ) if len(target_nodes) > 1: raise RedisClusterException(f"Too many targets for command {cmd.args}") - node = target_nodes[0] + if node == client.get_default_node(): + is_default_node = True if node.name not in nodes: nodes[node.name] = (node, []) nodes[node.name][1].append(cmd) diff --git a/redis/cluster.py b/redis/cluster.py index 1e8bbf8583..0b2c4f1387 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -835,14 +835,9 @@ def set_response_callback(self, command, callback): """Set a custom Response Callback""" self.cluster_response_callbacks[command] = callback - def _determine_nodes(self, *args, **kwargs) -> Tuple[List["ClusterNode"], bool]: - """Determine which nodes should be executed the command on - - Returns: - tuple[list[Type[ClusterNode]], bool]: - A tuple containing a list of target nodes and a bool indicating - if the return node was chosen because it is the default node - """ + def _determine_nodes(self, *args, **kwargs) -> List["ClusterNode"]: + # Determine which nodes should be executed the command on. + # Returns a list of target nodes. command = args[0].upper() if len(args) >= 2 and f"{args[0]} {args[1]}".upper() in self.command_flags: command = f"{args[0]} {args[1]}".upper() @@ -856,28 +851,28 @@ def _determine_nodes(self, *args, **kwargs) -> Tuple[List["ClusterNode"], bool]: command_flag = self.command_flags.get(command) if command_flag == self.__class__.RANDOM: # return a random node - return [self.get_random_node()], False + return [self.get_random_node()] elif command_flag == self.__class__.PRIMARIES: # return all primaries - return self.get_primaries(), False + return self.get_primaries() elif command_flag == self.__class__.REPLICAS: # return all replicas - return self.get_replicas(), False + return self.get_replicas() elif command_flag == self.__class__.ALL_NODES: # return all nodes - return self.get_nodes(), False + return self.get_nodes() elif command_flag == self.__class__.DEFAULT_NODE: # return the cluster's default node - return [self.nodes_manager.default_node], True + return [self.nodes_manager.default_node] elif command in self.__class__.SEARCH_COMMANDS[0]: - return [self.nodes_manager.default_node], True + return [self.nodes_manager.default_node] else: # get the node that holds the key's slot slot = self.determine_slot(*args) node = self.nodes_manager.get_node_from_slot( slot, self.read_from_replicas and command in READ_COMMANDS ) - return [node], False + return [node] def _should_reinitialized(self): # To reinitialize the cluster on every MOVED error, @@ -1045,13 +1040,18 @@ def execute_command(self, *args, **kwargs): res = {} if not target_nodes_specified: # Determine the nodes to execute the command on - target_nodes, is_default_node = self._determine_nodes( + target_nodes = self._determine_nodes( *args, **kwargs, nodes_flag=passed_targets ) if not target_nodes: raise RedisClusterException( f"No targets were found to execute {args} command on" ) + if ( + len(target_nodes) == 1 + and target_nodes[0] == self.get_default_node() + ): + is_default_node = True for node in target_nodes: res[node.name] = self._execute_command(node, *args, **kwargs) # Return the processed result @@ -1935,7 +1935,7 @@ def _send_cluster_commands( if passed_targets and not self._is_nodes_flag(passed_targets): target_nodes = self._parse_target_nodes(passed_targets) else: - target_nodes, is_default_node = self._determine_nodes( + target_nodes = self._determine_nodes( *c.args, node_flag=passed_targets ) if not target_nodes: @@ -1948,6 +1948,8 @@ def _send_cluster_commands( ) node = target_nodes[0] + if node == self.get_default_node(): + is_default_node = True # now that we know the name of the node # ( it's just a string in the form of host:port )