diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py index a4629f5399..e0e77c74ae 100644 --- a/redis/asyncio/cluster.py +++ b/redis/asyncio/cluster.py @@ -643,7 +643,6 @@ async def execute_command(self, *args: EncodableT, **kwargs: Any) -> Any: command = args[0] target_nodes = [] target_nodes_specified = False - is_default_node = False retry_attempts = self.cluster_error_retry_attempts passed_targets = kwargs.pop("target_nodes", None) @@ -657,7 +656,10 @@ async def execute_command(self, *args: EncodableT, **kwargs: Any) -> Any: for _ in range(execute_attempts): if self._initialize: await self.initialize() - if is_default_node: + if ( + len(target_nodes) == 1 + and target_nodes[0] == self.get_default_node() + ): # Replace the default cluster node self.replace_default_node() try: @@ -670,11 +672,6 @@ async def execute_command(self, *args: EncodableT, **kwargs: Any) -> Any: raise RedisClusterException( f"No targets were found to execute {args} command on" ) - if ( - len(target_nodes) == 1 - and target_nodes[0] == self.get_default_node() - ): - is_default_node = True if len(target_nodes) == 1: # Return the processed result @@ -1447,7 +1444,6 @@ async def _execute( ] nodes = {} - is_default_node = False for cmd in todo: passed_targets = cmd.kwargs.pop("target_nodes", None) if passed_targets and not client._is_node_flag(passed_targets): @@ -1463,8 +1459,6 @@ async def _execute( if len(target_nodes) > 1: raise RedisClusterException(f"Too many targets for command {cmd.args}") node = target_nodes[0] - if node == client.get_default_node(): - is_default_node = True if node.name not in nodes: nodes[node.name] = (node, []) nodes[node.name][1].append(cmd) @@ -1500,8 +1494,18 @@ async def _execute( result.args = (msg,) + result.args[1:] raise result - if is_default_node: - client.replace_default_node() + default_node = nodes.get(client.get_default_node().name) + if default_node is not None: + # This pipeline execution used the default node, check if we need + # to replace it. + # Note: when the error is raised we'll reset the default node in the + # caller function. + for cmd in default_node[1]: + # Check if it has a command that failed with a relevant + # exception + if type(cmd.result) in self.__class__.ERRORS_ALLOW_RETRY: + client.replace_default_node() + break return [cmd.result for cmd in stack] diff --git a/tests/test_asyncio/test_cluster.py b/tests/test_asyncio/test_cluster.py index cec2dc09a4..02efe1234e 100644 --- a/tests/test_asyncio/test_cluster.py +++ b/tests/test_asyncio/test_cluster.py @@ -2612,6 +2612,25 @@ async def test_can_run_concurrent_pipelines(self, r: RedisCluster) -> None: *(self.test_multi_key_operation_with_multi_slots(r) for i in range(100)), ) + @pytest.mark.onlycluster + async def test_cluster_pipeline_with_default_node_error_command(self, r): + """ + Test that the default node is being replaced when it raises a relevant exception + """ + curr_default_node = r.get_default_node() + err = ConnectionError("error") + cmd_count = await r.command_count() + mock_node_resp_exc(curr_default_node, err) + async with r.pipeline(transaction=False) as pipe: + pipe.command_count() + result = await pipe.execute(raise_on_error=False) + + assert result[0] == err + assert r.get_default_node() != curr_default_node + pipe.command_count() + result = await pipe.execute(raise_on_error=False) + assert result[0] == cmd_count + @pytest.mark.ssl class TestSSL: