diff --git a/aioredis/commands/set.py b/aioredis/commands/set.py index e29ebe5f6..6c20b97b1 100644 --- a/aioredis/commands/set.py +++ b/aioredis/commands/set.py @@ -43,9 +43,12 @@ def smove(self, sourcekey, destkey, member): """Move a member from one set to another.""" return self.execute(b'SMOVE', sourcekey, destkey, member) - def spop(self, key, *, encoding=_NOTSET): - """Remove and return a random member from a set.""" - return self.execute(b'SPOP', key, encoding=encoding) + def spop(self, key, count=None, *, encoding=_NOTSET): + """Remove and return one or multiple random members from a set.""" + args = [key] + if count is not None: + args.append(count) + return self.execute(b'SPOP', *args, encoding=encoding) def srandmember(self, key, count=None, *, encoding=_NOTSET): """Get one or multiple random members from a set.""" diff --git a/tests/set_commands_test.py b/tests/set_commands_test.py index d88a12975..d3c0ca4f8 100644 --- a/tests/set_commands_test.py +++ b/tests/set_commands_test.py @@ -1,5 +1,7 @@ import pytest +from aioredis import ReplyError + async def add(redis, key, members): ok = await redis.connection.execute(b'sadd', key, members) @@ -277,6 +279,42 @@ async def test_spop(redis): await redis.spop(None) +@pytest.redis_version( + 3, 2, 0, + reason="The count argument in SPOP is available since redis>=3.2.0" +) +@pytest.mark.run_loop +async def test_spop_count(redis): + key = b'key:spop:1' + members1 = b'one', b'two', b'three' + await redis.sadd(key, *members1) + + # fetch 3 random members + test_result1 = await redis.spop(key, 3) + assert len(test_result1) == 3 + assert set(test_result1).issubset(members1) is True + + members2 = 'four', 'five', 'six' + await redis.sadd(key, *members2) + + # test with encoding, fetch 3 random members + test_result2 = await redis.spop(key, 3, encoding='utf-8') + assert len(test_result2) == 3 + assert set(test_result2).issubset(members2) is True + + # try to pop data from empty set + test_result = await redis.spop(b'not:' + key, 2) + assert len(test_result) == 0 + + # test with negative counter + with pytest.raises(ReplyError): + await redis.spop(key, -2) + + # test with counter is zero + test_result3 = await redis.spop(key, 0) + assert len(test_result3) == 0 + + @pytest.mark.run_loop async def test_srandmember(redis): key = b'key:srandmember:1'