Skip to content

Commit

Permalink
Changed async clusterPipeline default node to be replaced only if a r…
Browse files Browse the repository at this point in the history
…elevant exception was thrown from that specific node
  • Loading branch information
barshaul committed Nov 29, 2022
1 parent 4cfa069 commit 2388569
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 12 deletions.
28 changes: 16 additions & 12 deletions redis/asyncio/cluster.py
Expand Up @@ -643,7 +643,6 @@ async def execute_command(self, *args: EncodableT, **kwargs: Any) -> Any:
command = args[0]
target_nodes = []
target_nodes_specified = False
is_default_node = False
retry_attempts = self.cluster_error_retry_attempts

passed_targets = kwargs.pop("target_nodes", None)
Expand All @@ -657,7 +656,10 @@ async def execute_command(self, *args: EncodableT, **kwargs: Any) -> Any:
for _ in range(execute_attempts):
if self._initialize:
await self.initialize()
if is_default_node:
if (
len(target_nodes) == 1
and target_nodes[0] == self.get_default_node()
):
# Replace the default cluster node
self.replace_default_node()
try:
Expand All @@ -670,11 +672,6 @@ async def execute_command(self, *args: EncodableT, **kwargs: Any) -> Any:
raise RedisClusterException(
f"No targets were found to execute {args} command on"
)
if (
len(target_nodes) == 1
and target_nodes[0] == self.get_default_node()
):
is_default_node = True

if len(target_nodes) == 1:
# Return the processed result
Expand Down Expand Up @@ -1447,7 +1444,6 @@ async def _execute(
]

nodes = {}
is_default_node = False
for cmd in todo:
passed_targets = cmd.kwargs.pop("target_nodes", None)
if passed_targets and not client._is_node_flag(passed_targets):
Expand All @@ -1463,8 +1459,6 @@ async def _execute(
if len(target_nodes) > 1:
raise RedisClusterException(f"Too many targets for command {cmd.args}")
node = target_nodes[0]
if node == client.get_default_node():
is_default_node = True
if node.name not in nodes:
nodes[node.name] = (node, [])
nodes[node.name][1].append(cmd)
Expand Down Expand Up @@ -1500,8 +1494,18 @@ async def _execute(
result.args = (msg,) + result.args[1:]
raise result

if is_default_node:
client.replace_default_node()
default_node = nodes.get(client.get_default_node().name)
if default_node is not None:
# This pipeline execution used the default node, check if we need
# to replace it.
# Note: when the error is raised we'll reset the default node in the
# caller function.
for cmd in default_node[1]:
# Check if it has a command that failed with a relevant
# exception
if type(cmd.result) in self.__class__.ERRORS_ALLOW_RETRY:
client.replace_default_node()
break

return [cmd.result for cmd in stack]

Expand Down
19 changes: 19 additions & 0 deletions tests/test_asyncio/test_cluster.py
Expand Up @@ -2612,6 +2612,25 @@ async def test_can_run_concurrent_pipelines(self, r: RedisCluster) -> None:
*(self.test_multi_key_operation_with_multi_slots(r) for i in range(100)),
)

@pytest.mark.onlycluster
async def test_cluster_pipeline_with_default_node_error_command(self, r):
"""
Test that the default node is being replaced when it raises a relevant exception
"""
curr_default_node = r.get_default_node()
err = ConnectionError("error")
cmd_count = await r.command_count()
mock_node_resp_exc(curr_default_node, err)
async with r.pipeline(transaction=False) as pipe:
pipe.command_count()
result = await pipe.execute(raise_on_error=False)

assert result[0] == err
assert r.get_default_node() != curr_default_node
pipe.command_count()
result = await pipe.execute(raise_on_error=False)
assert result[0] == cmd_count


@pytest.mark.ssl
class TestSSL:
Expand Down

0 comments on commit 2388569

Please sign in to comment.