diff --git a/CHANGES b/CHANGES index c0b183e8e3..b16fbce464 100644 --- a/CHANGES +++ b/CHANGES @@ -29,6 +29,7 @@ * Added CredentialsProvider class to support password rotation * Enable Lock for asyncio cluster mode * Fix Sentinel.execute_command doesn't execute across the entire sentinel cluster bug (#2458) + * Added a replacement for the default cluster node in the event of failure (#2463) * 4.1.3 (Feb 8, 2022) * Fix flushdb and flushall (#1926) diff --git a/redis/asyncio/client.py b/redis/asyncio/client.py index e0ed85eb8f..abe7d67463 100644 --- a/redis/asyncio/client.py +++ b/redis/asyncio/client.py @@ -354,6 +354,7 @@ def lock( name: KeyT, timeout: Optional[float] = None, sleep: float = 0.1, + blocking: bool = True, blocking_timeout: Optional[float] = None, lock_class: Optional[Type[Lock]] = None, thread_local: bool = True, @@ -369,6 +370,12 @@ def lock( when the lock is in blocking mode and another client is currently holding the lock. + ``blocking`` indicates whether calling ``acquire`` should block until + the lock has been acquired or to fail immediately, causing ``acquire`` + to return False and the lock not being acquired. Defaults to True. + Note this value can be overridden by passing a ``blocking`` + argument to ``acquire``. + ``blocking_timeout`` indicates the maximum amount of time in seconds to spend trying to acquire the lock. A value of ``None`` indicates continue trying forever. 
``blocking_timeout`` can be specified as a @@ -411,6 +418,7 @@ def lock( name, timeout=timeout, sleep=sleep, + blocking=blocking, blocking_timeout=blocking_timeout, thread_local=thread_local, ) diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py index d5a38b2878..ac61314262 100644 --- a/redis/asyncio/cluster.py +++ b/redis/asyncio/cluster.py @@ -517,6 +517,8 @@ def set_response_callback(self, command: str, callback: ResponseCallbackT) -> No async def _determine_nodes( self, command: str, *args: Any, node_flag: Optional[str] = None ) -> List["ClusterNode"]: + # Determine which nodes the command should be executed on. + # Returns a list of target nodes. if not node_flag: # get the nodes group for this command if it was predefined node_flag = self.command_flags.get(command) @@ -654,6 +656,12 @@ async def execute_command(self, *args: EncodableT, **kwargs: Any) -> Any: for _ in range(execute_attempts): if self._initialize: await self.initialize() + if ( + len(target_nodes) == 1 + and target_nodes[0] == self.get_default_node() + ): + # Replace the default cluster node + self.replace_default_node() try: if not target_nodes_specified: # Determine the nodes to execute the command on @@ -793,6 +801,7 @@ def lock( name: KeyT, timeout: Optional[float] = None, sleep: float = 0.1, + blocking: bool = True, blocking_timeout: Optional[float] = None, lock_class: Optional[Type[Lock]] = None, thread_local: bool = True, @@ -808,6 +817,12 @@ def lock( when the lock is in blocking mode and another client is currently holding the lock. + ``blocking`` indicates whether calling ``acquire`` should block until + the lock has been acquired or to fail immediately, causing ``acquire`` + to return False and the lock not being acquired. Defaults to True. + Note this value can be overridden by passing a ``blocking`` + argument to ``acquire``. + ``blocking_timeout`` indicates the maximum amount of time in seconds to spend trying to acquire the lock. 
A value of ``None`` indicates continue trying forever. ``blocking_timeout`` can be specified as a @@ -850,6 +865,7 @@ def lock( name, timeout=timeout, sleep=sleep, + blocking=blocking, blocking_timeout=blocking_timeout, thread_local=thread_local, ) @@ -1450,7 +1466,6 @@ async def _execute( ) if len(target_nodes) > 1: raise RedisClusterException(f"Too many targets for command {cmd.args}") - node = target_nodes[0] if node.name not in nodes: nodes[node.name] = (node, []) @@ -1487,6 +1502,19 @@ async def _execute( result.args = (msg,) + result.args[1:] raise result + default_node = nodes.get(client.get_default_node().name) + if default_node is not None: + # This pipeline execution used the default node, check if we need + # to replace it. + # Note: when the error is raised we'll reset the default node in the + # caller function. + for cmd in default_node[1]: + # Check if it has a command that failed with a relevant + # exception + if type(cmd.result) in self.__class__.ERRORS_ALLOW_RETRY: + client.replace_default_node() + break + return [cmd.result for cmd in stack] def _split_command_across_slots( diff --git a/redis/cluster.py b/redis/cluster.py index 91deaead59..0b2c4f1387 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -379,6 +379,30 @@ class AbstractRedisCluster: ERRORS_ALLOW_RETRY = (ConnectionError, TimeoutError, ClusterDownError) + def replace_default_node(self, target_node: "ClusterNode" = None) -> None: + """Replace the default cluster node. + A random cluster node will be chosen if target_node isn't passed, and primaries + will be prioritized. The default node will not be changed if there are no other + nodes in the cluster. + + Args: + target_node (ClusterNode, optional): Target node to replace the default + node. Defaults to None. 
+ """ + if target_node: + self.nodes_manager.default_node = target_node + else: + curr_node = self.get_default_node() + primaries = [node for node in self.get_primaries() if node != curr_node] + if primaries: + # Choose a primary if the cluster contains different primaries + self.nodes_manager.default_node = random.choice(primaries) + else: + # Otherwise, choose a replica if the cluster has no other primaries + replicas = [node for node in self.get_replicas() if node != curr_node] + if replicas: + self.nodes_manager.default_node = random.choice(replicas) + class RedisCluster(AbstractRedisCluster, RedisClusterCommands): @classmethod @@ -811,7 +835,9 @@ def set_response_callback(self, command, callback): """Set a custom Response Callback""" self.cluster_response_callbacks[command] = callback - def _determine_nodes(self, *args, **kwargs): + def _determine_nodes(self, *args, **kwargs) -> List["ClusterNode"]: + # Determine which nodes the command should be executed on. + # Returns a list of target nodes. 
command = args[0].upper() if len(args) >= 2 and f"{args[0]} {args[1]}".upper() in self.command_flags: command = f"{args[0]} {args[1]}".upper() @@ -990,6 +1016,7 @@ def execute_command(self, *args, **kwargs): dict """ target_nodes_specified = False + is_default_node = False target_nodes = None passed_targets = kwargs.pop("target_nodes", None) if passed_targets is not None and not self._is_nodes_flag(passed_targets): @@ -1020,12 +1047,20 @@ def execute_command(self, *args, **kwargs): raise RedisClusterException( f"No targets were found to execute {args} command on" ) + if ( + len(target_nodes) == 1 + and target_nodes[0] == self.get_default_node() + ): + is_default_node = True for node in target_nodes: res[node.name] = self._execute_command(node, *args, **kwargs) # Return the processed result return self._process_result(args[0], res, **kwargs) except Exception as e: if retry_attempts > 0 and type(e) in self.__class__.ERRORS_ALLOW_RETRY: + if is_default_node: + # Replace the default cluster node + self.replace_default_node() # The nodes and slots cache were reinitialized. # Try again with the new cluster setup. retry_attempts -= 1 @@ -1883,7 +1918,7 @@ def _send_cluster_commands( # if we have to run through it again, we only retry # the commands that failed. attempt = sorted(stack, key=lambda x: x.position) - + is_default_node = False # build a list of node objects based on node names we need to nodes = {} @@ -1913,6 +1948,8 @@ def _send_cluster_commands( ) node = target_nodes[0] + if node == self.get_default_node(): + is_default_node = True # now that we know the name of the node # ( it's just a string in the form of host:port ) @@ -1926,6 +1963,8 @@ def _send_cluster_commands( # Connection retries are being handled in the node's # Retry object. Reinitialize the node -> slot table. 
self.nodes_manager.initialize() + if is_default_node: + self.replace_default_node() raise nodes[node_name] = NodeCommands( redis_node.parse_response, @@ -2007,6 +2046,8 @@ def _send_cluster_commands( self.reinitialize_counter += 1 if self._should_reinitialized(): self.nodes_manager.initialize() + if is_default_node: + self.replace_default_node() for c in attempt: try: # send each command individually like we diff --git a/redis/commands/core.py b/redis/commands/core.py index 3be2823a7a..1625e10d9c 100644 --- a/redis/commands/core.py +++ b/redis/commands/core.py @@ -5581,7 +5581,7 @@ def _geosearchgeneric( "GEOSEARCH member and longitude or latitude" " cant be set together" ) pieces.extend([b"FROMMEMBER", kwargs["member"]]) - if kwargs["longitude"] and kwargs["latitude"]: + if kwargs["longitude"] is not None and kwargs["latitude"] is not None: pieces.extend([b"FROMLONLAT", kwargs["longitude"], kwargs["latitude"]]) # BYRADIUS or BYBOX diff --git a/tests/test_asyncio/test_cluster.py b/tests/test_asyncio/test_cluster.py index 38bcaf6c00..1997c9520b 100644 --- a/tests/test_asyncio/test_cluster.py +++ b/tests/test_asyncio/test_cluster.py @@ -788,6 +788,27 @@ async def test_can_run_concurrent_commands(self, request: FixtureRequest) -> Non ) await rc.close() + def test_replace_cluster_node(self, r: RedisCluster) -> None: + prev_default_node = r.get_default_node() + r.replace_default_node() + assert r.get_default_node() != prev_default_node + r.replace_default_node(prev_default_node) + assert r.get_default_node() == prev_default_node + + async def test_default_node_is_replaced_after_exception(self, r): + curr_default_node = r.get_default_node() + # CLUSTER NODES command is being executed on the default node + nodes = await r.cluster_nodes() + assert "myself" in nodes.get(curr_default_node.name).get("flags") + # Mock connection error for the default node + mock_node_resp_exc(curr_default_node, ConnectionError("error")) + # Test that the command succeed from a different node + 
nodes = await r.cluster_nodes() + assert "myself" not in nodes.get(curr_default_node.name).get("flags") + assert r.get_default_node() != curr_default_node + # Rollback to the old default node + r.replace_default_node(curr_default_node) + class TestClusterRedisCommands: """ @@ -2591,6 +2612,25 @@ async def test_can_run_concurrent_pipelines(self, r: RedisCluster) -> None: *(self.test_multi_key_operation_with_multi_slots(r) for i in range(100)), ) + @pytest.mark.onlycluster + async def test_pipeline_with_default_node_error_command(self, create_redis): + """ + Test that the default node is being replaced when it raises a relevant exception + """ + r = await create_redis(cls=RedisCluster, flushdb=False) + curr_default_node = r.get_default_node() + err = ConnectionError("error") + cmd_count = await r.command_count() + mock_node_resp_exc(curr_default_node, err) + async with r.pipeline(transaction=False) as pipe: + pipe.command_count() + result = await pipe.execute(raise_on_error=False) + assert result[0] == err + assert r.get_default_node() != curr_default_node + pipe.command_count() + result = await pipe.execute(raise_on_error=False) + assert result[0] == cmd_count + @pytest.mark.ssl class TestSSL: diff --git a/tests/test_asyncio/test_lock.py b/tests/test_asyncio/test_lock.py index 56387fa954..15c3ec5162 100644 --- a/tests/test_asyncio/test_lock.py +++ b/tests/test_asyncio/test_lock.py @@ -97,6 +97,14 @@ async def test_float_timeout(self, r): assert 8 < (await r.pttl("foo")) <= 9500 await lock.release() + async def test_blocking(self, r): + blocking = False + lock = self.get_lock(r, "foo", blocking=blocking) + assert not lock.blocking + + lock_2 = self.get_lock(r, "foo") + assert lock_2.blocking + async def test_blocking_timeout(self, r, event_loop): lock1 = self.get_lock(r, "foo") assert await lock1.acquire(blocking=False) diff --git a/tests/test_cluster.py b/tests/test_cluster.py index d18fbbbb33..43aeb9e045 100644 --- a/tests/test_cluster.py +++ 
b/tests/test_cluster.py @@ -791,6 +791,29 @@ def test_cluster_retry_object(self, r) -> None: == retry._retries ) + def test_replace_cluster_node(self, r) -> None: + prev_default_node = r.get_default_node() + r.replace_default_node() + assert r.get_default_node() != prev_default_node + r.replace_default_node(prev_default_node) + assert r.get_default_node() == prev_default_node + + def test_default_node_is_replaced_after_exception(self, r): + curr_default_node = r.get_default_node() + # CLUSTER NODES command is being executed on the default node + nodes = r.cluster_nodes() + assert "myself" in nodes.get(curr_default_node.name).get("flags") + + def raise_connection_error(): + raise ConnectionError("error") + + # Mock connection error for the default node + mock_node_resp_func(curr_default_node, raise_connection_error) + # Test that the command succeed from a different node + nodes = r.cluster_nodes() + assert "myself" not in nodes.get(curr_default_node.name).get("flags") + assert r.get_default_node() != curr_default_node + @pytest.mark.onlycluster class TestClusterRedisCommands: