
Commit

Merge branch 'master' into ck-depupdate
chayim committed Dec 1, 2022
2 parents ac36673 + c48dc83 commit 8080fa8
Showing 8 changed files with 153 additions and 4 deletions.
1 change: 1 addition & 0 deletions CHANGES
@@ -29,6 +29,7 @@
* Added CredentialsProvider class to support password rotation
* Enable Lock for asyncio cluster mode
* Fix bug where Sentinel.execute_command did not execute across the entire sentinel cluster (#2458)
* Added a replacement for the default cluster node in the event of failure (#2463)

* 4.1.3 (Feb 8, 2022)
* Fix flushdb and flushall (#1926)
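For orientation, a minimal sketch of the default-node replacement mentioned in the last entry above, assuming a locally running test cluster (the host and port are placeholders):

from redis.cluster import RedisCluster

# Placeholder address for a local test cluster.
rc = RedisCluster(host="localhost", port=16379)

previous = rc.get_default_node()
# With no argument, another node is chosen at random, preferring primaries.
rc.replace_default_node()
assert rc.get_default_node() != previous

# An explicit node can be passed to restore (or set) the default directly.
rc.replace_default_node(previous)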
8 changes: 8 additions & 0 deletions redis/asyncio/client.py
@@ -354,6 +354,7 @@ def lock(
name: KeyT,
timeout: Optional[float] = None,
sleep: float = 0.1,
blocking: bool = True,
blocking_timeout: Optional[float] = None,
lock_class: Optional[Type[Lock]] = None,
thread_local: bool = True,
@@ -369,6 +370,12 @@ def lock(
when the lock is in blocking mode and another client is currently
holding the lock.
``blocking`` indicates whether calling ``acquire`` should block until
the lock has been acquired or fail immediately, causing ``acquire``
to return False and the lock not to be acquired. Defaults to True.
Note that this value can be overridden by passing a ``blocking``
argument to ``acquire``.
``blocking_timeout`` indicates the maximum amount of time in seconds to
spend trying to acquire the lock. A value of ``None`` indicates
continue trying forever. ``blocking_timeout`` can be specified as a
@@ -411,6 +418,7 @@ def lock(
name,
timeout=timeout,
sleep=sleep,
blocking=blocking,
blocking_timeout=blocking_timeout,
thread_local=thread_local,
)
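A minimal sketch of the new ``blocking`` argument on the asyncio client, assuming a Redis server on the default localhost:6379:

import asyncio

from redis.asyncio import Redis

async def main():
    r = Redis()  # assumes localhost:6379
    lock = r.lock("resource", blocking=False)
    # With blocking=False, acquire() returns immediately instead of polling
    # until blocking_timeout expires.
    if await lock.acquire():
        try:
            ...  # critical section
        finally:
            await lock.release()
    else:
        print("lock is held by another client; not waiting")
    await r.close()

asyncio.run(main())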
30 changes: 29 additions & 1 deletion redis/asyncio/cluster.py
@@ -517,6 +517,8 @@ def set_response_callback(self, command: str, callback: ResponseCallbackT) -> None:
async def _determine_nodes(
self, command: str, *args: Any, node_flag: Optional[str] = None
) -> List["ClusterNode"]:
# Determine which nodes the command should be executed on.
# Returns a list of target nodes.
if not node_flag:
# get the nodes group for this command if it was predefined
node_flag = self.command_flags.get(command)
@@ -654,6 +656,12 @@ async def execute_command(self, *args: EncodableT, **kwargs: Any) -> Any:
for _ in range(execute_attempts):
if self._initialize:
await self.initialize()
if (
len(target_nodes) == 1
and target_nodes[0] == self.get_default_node()
):
# Replace the default cluster node
self.replace_default_node()
try:
if not target_nodes_specified:
# Determine the nodes to execute the command on
@@ -793,6 +801,7 @@ def lock(
name: KeyT,
timeout: Optional[float] = None,
sleep: float = 0.1,
blocking: bool = True,
blocking_timeout: Optional[float] = None,
lock_class: Optional[Type[Lock]] = None,
thread_local: bool = True,
@@ -808,6 +817,12 @@ def lock(
when the lock is in blocking mode and another client is currently
holding the lock.
``blocking`` indicates whether calling ``acquire`` should block until
the lock has been acquired or fail immediately, causing ``acquire``
to return False and the lock not to be acquired. Defaults to True.
Note that this value can be overridden by passing a ``blocking``
argument to ``acquire``.
``blocking_timeout`` indicates the maximum amount of time in seconds to
spend trying to acquire the lock. A value of ``None`` indicates
continue trying forever. ``blocking_timeout`` can be specified as a
@@ -850,6 +865,7 @@ def lock(
name,
timeout=timeout,
sleep=sleep,
blocking=blocking,
blocking_timeout=blocking_timeout,
thread_local=thread_local,
)
@@ -1450,7 +1466,6 @@ async def _execute(
)
if len(target_nodes) > 1:
raise RedisClusterException(f"Too many targets for command {cmd.args}")

node = target_nodes[0]
if node.name not in nodes:
nodes[node.name] = (node, [])
@@ -1487,6 +1502,19 @@ async def _execute(
result.args = (msg,) + result.args[1:]
raise result

default_node = nodes.get(client.get_default_node().name)
if default_node is not None:
# This pipeline execution used the default node; check whether we need
# to replace it.
# Note: when the error is raised, we reset the default node in the
# calling function.
for cmd in default_node[1]:
# Check if it has a command that failed with a relevant
# exception
if type(cmd.result) in self.__class__.ERRORS_ALLOW_RETRY:
client.replace_default_node()
break

return [cmd.result for cmd in stack]

def _split_command_across_slots(
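The same replacement helper is shared by the asyncio cluster client; a minimal sketch that triggers it manually, mirroring what execute_command() above now does automatically after a retryable failure (host and port are placeholders for a test cluster):

import asyncio

from redis.asyncio.cluster import RedisCluster

async def main():
    rc = RedisCluster(host="localhost", port=16379)
    await rc.initialize()
    previous = rc.get_default_node()
    rc.replace_default_node()  # another node, preferring primaries, becomes default
    assert rc.get_default_node() != previous
    rc.replace_default_node(previous)  # restore the original default
    await rc.close()

asyncio.run(main())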
45 changes: 43 additions & 2 deletions redis/cluster.py
@@ -379,6 +379,30 @@ class AbstractRedisCluster:

ERRORS_ALLOW_RETRY = (ConnectionError, TimeoutError, ClusterDownError)

def replace_default_node(self, target_node: "ClusterNode" = None) -> None:
"""Replace the default cluster node.
A random cluster node will be chosen if target_node isn't passed, and primaries
will be prioritized. The default node will not be changed if there are no other
nodes in the cluster.
Args:
target_node (ClusterNode, optional): Target node to replace the default
node. Defaults to None.
"""
if target_node:
self.nodes_manager.default_node = target_node
else:
curr_node = self.get_default_node()
primaries = [node for node in self.get_primaries() if node != curr_node]
if primaries:
# Choose a primary if the cluster contains different primaries
self.nodes_manager.default_node = random.choice(primaries)
else:
# Otherwise, choose a replica if the cluster contains different replicas
replicas = [node for node in self.get_replicas() if node != curr_node]
if replicas:
self.nodes_manager.default_node = random.choice(replicas)


class RedisCluster(AbstractRedisCluster, RedisClusterCommands):
@classmethod
@@ -811,7 +835,9 @@ def set_response_callback(self, command, callback):
"""Set a custom Response Callback"""
self.cluster_response_callbacks[command] = callback

def _determine_nodes(self, *args, **kwargs):
def _determine_nodes(self, *args, **kwargs) -> List["ClusterNode"]:
# Determine which nodes the command should be executed on.
# Returns a list of target nodes.
command = args[0].upper()
if len(args) >= 2 and f"{args[0]} {args[1]}".upper() in self.command_flags:
command = f"{args[0]} {args[1]}".upper()
@@ -990,6 +1016,7 @@ def execute_command(self, *args, **kwargs):
dict<Any, ClusterNode>
"""
target_nodes_specified = False
is_default_node = False
target_nodes = None
passed_targets = kwargs.pop("target_nodes", None)
if passed_targets is not None and not self._is_nodes_flag(passed_targets):
@@ -1020,12 +1047,20 @@ def execute_command(self, *args, **kwargs):
raise RedisClusterException(
f"No targets were found to execute {args} command on"
)
if (
len(target_nodes) == 1
and target_nodes[0] == self.get_default_node()
):
is_default_node = True
for node in target_nodes:
res[node.name] = self._execute_command(node, *args, **kwargs)
# Return the processed result
return self._process_result(args[0], res, **kwargs)
except Exception as e:
if retry_attempts > 0 and type(e) in self.__class__.ERRORS_ALLOW_RETRY:
if is_default_node:
# Replace the default cluster node
self.replace_default_node()
# The nodes and slots cache were reinitialized.
# Try again with the new cluster setup.
retry_attempts -= 1
@@ -1883,7 +1918,7 @@ def _send_cluster_commands(
# if we have to run through it again, we only retry
# the commands that failed.
attempt = sorted(stack, key=lambda x: x.position)

is_default_node = False
# build a list of node objects based on node names we need to
nodes = {}

@@ -1913,6 +1948,8 @@
)

node = target_nodes[0]
if node == self.get_default_node():
is_default_node = True

# now that we know the name of the node
# ( it's just a string in the form of host:port )
@@ -1926,6 +1963,8 @@
# Connection retries are being handled in the node's
# Retry object. Reinitialize the node -> slot table.
self.nodes_manager.initialize()
if is_default_node:
self.replace_default_node()
raise
nodes[node_name] = NodeCommands(
redis_node.parse_response,
@@ -2007,6 +2046,8 @@
self.reinitialize_counter += 1
if self._should_reinitialized():
self.nodes_manager.initialize()
if is_default_node:
self.replace_default_node()
for c in attempt:
try:
# send each command individually like we
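The selection order documented in replace_default_node() above is: an explicit target first, then a random other primary, then a random other replica, otherwise no change. A standalone sketch of that policy; pick_new_default is a hypothetical helper, not part of the library:

import random

def pick_new_default(current, primaries, replicas, target=None):
    # Hypothetical helper mirroring the policy of replace_default_node().
    if target is not None:
        return target
    other_primaries = [n for n in primaries if n != current]
    if other_primaries:
        return random.choice(other_primaries)
    other_replicas = [n for n in replicas if n != current]
    if other_replicas:
        return random.choice(other_replicas)
    return current  # nothing else in the cluster; keep the current default

# With two other primaries available, the current default is never kept.
print(pick_new_default("node-a", ["node-a", "node-b", "node-c"], []))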
2 changes: 1 addition & 1 deletion redis/commands/core.py
@@ -5581,7 +5581,7 @@ def _geosearchgeneric(
"GEOSEARCH member and longitude or latitude" " cant be set together"
)
pieces.extend([b"FROMMEMBER", kwargs["member"]])
if kwargs["longitude"] and kwargs["latitude"]:
if kwargs["longitude"] is not None and kwargs["latitude"] is not None:
pieces.extend([b"FROMLONLAT", kwargs["longitude"], kwargs["latitude"]])

# BYRADIUS or BYBOX
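The one-line change above matters because a coordinate of exactly 0.0 is falsy: under the old truthiness check, a GEOSEARCH from a point on the equator or the prime meridian silently dropped the FROMLONLAT clause. A small illustration:

# Greenwich lies on the prime meridian, so its longitude is exactly 0.0.
longitude, latitude = 0.0, 51.4769

if longitude and latitude:  # old check
    print("old check: FROMLONLAT would be added")
else:
    print("old check: FROMLONLAT silently skipped")  # this branch runs

if longitude is not None and latitude is not None:  # new check
    print("new check: FROMLONLAT added correctly")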
40 changes: 40 additions & 0 deletions tests/test_asyncio/test_cluster.py
@@ -788,6 +788,27 @@ async def test_can_run_concurrent_commands(self, request: FixtureRequest) -> None:
)
await rc.close()

def test_replace_cluster_node(self, r: RedisCluster) -> None:
prev_default_node = r.get_default_node()
r.replace_default_node()
assert r.get_default_node() != prev_default_node
r.replace_default_node(prev_default_node)
assert r.get_default_node() == prev_default_node

async def test_default_node_is_replaced_after_exception(self, r):
curr_default_node = r.get_default_node()
# CLUSTER NODES command is being executed on the default node
nodes = await r.cluster_nodes()
assert "myself" in nodes.get(curr_default_node.name).get("flags")
# Mock connection error for the default node
mock_node_resp_exc(curr_default_node, ConnectionError("error"))
# Test that the command succeeds from a different node
nodes = await r.cluster_nodes()
assert "myself" not in nodes.get(curr_default_node.name).get("flags")
assert r.get_default_node() != curr_default_node
# Rollback to the old default node
r.replace_default_node(curr_default_node)


class TestClusterRedisCommands:
"""
@@ -2591,6 +2612,25 @@ async def test_can_run_concurrent_pipelines(self, r: RedisCluster) -> None:
*(self.test_multi_key_operation_with_multi_slots(r) for i in range(100)),
)

@pytest.mark.onlycluster
async def test_pipeline_with_default_node_error_command(self, create_redis):
"""
Test that the default node is being replaced when it raises a relevant exception
"""
r = await create_redis(cls=RedisCluster, flushdb=False)
curr_default_node = r.get_default_node()
err = ConnectionError("error")
cmd_count = await r.command_count()
mock_node_resp_exc(curr_default_node, err)
async with r.pipeline(transaction=False) as pipe:
pipe.command_count()
result = await pipe.execute(raise_on_error=False)
assert result[0] == err
assert r.get_default_node() != curr_default_node
pipe.command_count()
result = await pipe.execute(raise_on_error=False)
assert result[0] == cmd_count


@pytest.mark.ssl
class TestSSL:
8 changes: 8 additions & 0 deletions tests/test_asyncio/test_lock.py
@@ -97,6 +97,14 @@ async def test_float_timeout(self, r):
assert 8 < (await r.pttl("foo")) <= 9500
await lock.release()

async def test_blocking(self, r):
blocking = False
lock = self.get_lock(r, "foo", blocking=blocking)
assert not lock.blocking

lock_2 = self.get_lock(r, "foo")
assert lock_2.blocking

async def test_blocking_timeout(self, r, event_loop):
lock1 = self.get_lock(r, "foo")
assert await lock1.acquire(blocking=False)
23 changes: 23 additions & 0 deletions tests/test_cluster.py
@@ -791,6 +791,29 @@ def test_cluster_retry_object(self, r) -> None:
== retry._retries
)

def test_replace_cluster_node(self, r) -> None:
prev_default_node = r.get_default_node()
r.replace_default_node()
assert r.get_default_node() != prev_default_node
r.replace_default_node(prev_default_node)
assert r.get_default_node() == prev_default_node

def test_default_node_is_replaced_after_exception(self, r):
curr_default_node = r.get_default_node()
# CLUSTER NODES command is being executed on the default node
nodes = r.cluster_nodes()
assert "myself" in nodes.get(curr_default_node.name).get("flags")

def raise_connection_error():
raise ConnectionError("error")

# Mock connection error for the default node
mock_node_resp_func(curr_default_node, raise_connection_error)
# Test that the command succeeds from a different node
nodes = r.cluster_nodes()
assert "myself" not in nodes.get(curr_default_node.name).get("flags")
assert r.get_default_node() != curr_default_node


@pytest.mark.onlycluster
class TestClusterRedisCommands:
