Skip to content

Commit

Permalink
Changed determine_nodes to return only the target nodes, added a comp…
Browse files Browse the repository at this point in the history
…arison to determine whether a node is the default node instead
  • Loading branch information
barshaul committed Nov 27, 2022
1 parent b3ab42a commit 23282a1
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 39 deletions.
42 changes: 20 additions & 22 deletions redis/asyncio/cluster.py
Expand Up @@ -11,7 +11,6 @@
List,
Mapping,
Optional,
Tuple,
Type,
TypeVar,
Union,
Expand Down Expand Up @@ -223,7 +222,7 @@ def __init__(
reinitialize_steps: int = 5,
cluster_error_retry_attempts: int = 3,
connection_error_retry_attempts: int = 3,
max_connections: int = 2**31,
max_connections: int = 2 ** 31,
# Client related kwargs
db: Union[str, int] = 0,
path: Optional[str] = None,
Expand Down Expand Up @@ -517,44 +516,37 @@ def set_response_callback(self, command: str, callback: ResponseCallbackT) -> No

async def _determine_nodes(
self, command: str, *args: Any, node_flag: Optional[str] = None
) -> Tuple[List["ClusterNode"], bool]:
"""Determine which nodes should be executed the command on
Returns:
tuple[list[Type[ClusterNode]], bool]:
A tuple containing a list of target nodes and a bool indicating
if the return node was chosen because it is the default node
"""
) -> List["ClusterNode"]:
# Determine which nodes should be executed the command on.
# Returns a list of target nodes.
if not node_flag:
# get the nodes group for this command if it was predefined
node_flag = self.command_flags.get(command)

if node_flag in self.node_flags:
if node_flag == self.__class__.DEFAULT_NODE:
# return the cluster's default node
return [self.nodes_manager.default_node], True
return [self.nodes_manager.default_node]
if node_flag == self.__class__.PRIMARIES:
# return all primaries
return self.nodes_manager.get_nodes_by_server_type(PRIMARY), False
return self.nodes_manager.get_nodes_by_server_type(PRIMARY)
if node_flag == self.__class__.REPLICAS:
# return all replicas
return self.nodes_manager.get_nodes_by_server_type(REPLICA), False
return self.nodes_manager.get_nodes_by_server_type(REPLICA)
if node_flag == self.__class__.ALL_NODES:
# return all nodes
return list(self.nodes_manager.nodes_cache.values()), False
return list(self.nodes_manager.nodes_cache.values())
if node_flag == self.__class__.RANDOM:
# return a random node
return [
random.choice(list(self.nodes_manager.nodes_cache.values()))
], False
return [random.choice(list(self.nodes_manager.nodes_cache.values()))]

# get the node that holds the key's slot
return [
self.nodes_manager.get_node_from_slot(
await self._determine_slot(command, *args),
self.read_from_replicas and command in READ_COMMANDS,
)
], False
]

async def _determine_slot(self, command: str, *args: Any) -> int:
if self.command_flags.get(command) == SLOT_ID:
Expand Down Expand Up @@ -671,13 +663,18 @@ async def execute_command(self, *args: EncodableT, **kwargs: Any) -> Any:
try:
if not target_nodes_specified:
# Determine the nodes to execute the command on
target_nodes, is_default_node = await self._determine_nodes(
target_nodes = await self._determine_nodes(
*args, node_flag=passed_targets
)
if not target_nodes:
raise RedisClusterException(
f"No targets were found to execute {args} command on"
)
if (
len(target_nodes) == 1
and target_nodes[0] == self.get_default_node()
):
is_default_node = True

if len(target_nodes) == 1:
# Return the processed result
Expand Down Expand Up @@ -896,7 +893,7 @@ def __init__(
port: Union[str, int],
server_type: Optional[str] = None,
*,
max_connections: int = 2**31,
max_connections: int = 2 ** 31,
connection_class: Type[Connection] = Connection,
**connection_kwargs: Any,
) -> None:
Expand Down Expand Up @@ -1456,7 +1453,7 @@ async def _execute(
if passed_targets and not client._is_node_flag(passed_targets):
target_nodes = client._parse_target_nodes(passed_targets)
else:
target_nodes, is_default_node = await client._determine_nodes(
target_nodes = await client._determine_nodes(
*cmd.args, node_flag=passed_targets
)
if not target_nodes:
Expand All @@ -1465,8 +1462,9 @@ async def _execute(
)
if len(target_nodes) > 1:
raise RedisClusterException(f"Too many targets for command {cmd.args}")

node = target_nodes[0]
if node == client.get_default_node():
is_default_node = True
if node.name not in nodes:
nodes[node.name] = (node, [])
nodes[node.name][1].append(cmd)
Expand Down
36 changes: 19 additions & 17 deletions redis/cluster.py
Expand Up @@ -835,14 +835,9 @@ def set_response_callback(self, command, callback):
"""Set a custom Response Callback"""
self.cluster_response_callbacks[command] = callback

def _determine_nodes(self, *args, **kwargs) -> Tuple[List["ClusterNode"], bool]:
"""Determine which nodes should be executed the command on
Returns:
tuple[list[Type[ClusterNode]], bool]:
A tuple containing a list of target nodes and a bool indicating
if the return node was chosen because it is the default node
"""
def _determine_nodes(self, *args, **kwargs) -> List["ClusterNode"]:
# Determine which nodes should be executed the command on.
# Returns a list of target nodes.
command = args[0].upper()
if len(args) >= 2 and f"{args[0]} {args[1]}".upper() in self.command_flags:
command = f"{args[0]} {args[1]}".upper()
Expand All @@ -856,28 +851,28 @@ def _determine_nodes(self, *args, **kwargs) -> Tuple[List["ClusterNode"], bool]:
command_flag = self.command_flags.get(command)
if command_flag == self.__class__.RANDOM:
# return a random node
return [self.get_random_node()], False
return [self.get_random_node()]
elif command_flag == self.__class__.PRIMARIES:
# return all primaries
return self.get_primaries(), False
return self.get_primaries()
elif command_flag == self.__class__.REPLICAS:
# return all replicas
return self.get_replicas(), False
return self.get_replicas()
elif command_flag == self.__class__.ALL_NODES:
# return all nodes
return self.get_nodes(), False
return self.get_nodes()
elif command_flag == self.__class__.DEFAULT_NODE:
# return the cluster's default node
return [self.nodes_manager.default_node], True
return [self.nodes_manager.default_node]
elif command in self.__class__.SEARCH_COMMANDS[0]:
return [self.nodes_manager.default_node], True
return [self.nodes_manager.default_node]
else:
# get the node that holds the key's slot
slot = self.determine_slot(*args)
node = self.nodes_manager.get_node_from_slot(
slot, self.read_from_replicas and command in READ_COMMANDS
)
return [node], False
return [node]

def _should_reinitialized(self):
# To reinitialize the cluster on every MOVED error,
Expand Down Expand Up @@ -1045,13 +1040,18 @@ def execute_command(self, *args, **kwargs):
res = {}
if not target_nodes_specified:
# Determine the nodes to execute the command on
target_nodes, is_default_node = self._determine_nodes(
target_nodes = self._determine_nodes(
*args, **kwargs, nodes_flag=passed_targets
)
if not target_nodes:
raise RedisClusterException(
f"No targets were found to execute {args} command on"
)
if (
len(target_nodes) == 1
and target_nodes[0] == self.get_default_node()
):
is_default_node = True
for node in target_nodes:
res[node.name] = self._execute_command(node, *args, **kwargs)
# Return the processed result
Expand Down Expand Up @@ -1935,7 +1935,7 @@ def _send_cluster_commands(
if passed_targets and not self._is_nodes_flag(passed_targets):
target_nodes = self._parse_target_nodes(passed_targets)
else:
target_nodes, is_default_node = self._determine_nodes(
target_nodes = self._determine_nodes(
*c.args, node_flag=passed_targets
)
if not target_nodes:
Expand All @@ -1948,6 +1948,8 @@ def _send_cluster_commands(
)

node = target_nodes[0]
if node == self.get_default_node():
is_default_node = True

# now that we know the name of the node
# ( it's just a string in the form of host:port )
Expand Down

0 comments on commit 23282a1

Please sign in to comment.