From 8f3890c69a96b9e58856d5cbd340ce3347e3975a Mon Sep 17 00:00:00 2001 From: Ian Good Date: Mon, 29 Nov 2021 18:59:42 -0500 Subject: [PATCH] Fix typing on blpop (etc) timeout argument --- CHANGES/1224.bugfix | 1 + aioredis/client.py | 11 ++++++----- tests/conftest.py | 10 +++++++--- 3 files changed, 14 insertions(+), 8 deletions(-) create mode 100644 CHANGES/1224.bugfix diff --git a/CHANGES/1224.bugfix b/CHANGES/1224.bugfix new file mode 100644 index 000000000..f183fe12a --- /dev/null +++ b/CHANGES/1224.bugfix @@ -0,0 +1 @@ +Fix typing on blpop (etc) timeout argument diff --git a/aioredis/client.py b/aioredis/client.py index aa8fbd727..26c9ae3c7 100644 --- a/aioredis/client.py +++ b/aioredis/client.py @@ -65,6 +65,7 @@ ConsumerT = _StringLikeT # Consumer name StreamIdT = Union[int, _StringLikeT] ScriptTextT = _StringLikeT +TimeoutSecT = Union[int, float, _StringLikeT] # Mapping is not covariant in the key type, which prevents # Mapping[_StringLikeT, X from accepting arguments of type Dict[str, X]. Using # a TypeVar instead of a Union allows mappings with any of the permitted types @@ -2124,7 +2125,7 @@ def unlink(self, *names: KeyT) -> Awaitable: return self.execute_command("UNLINK", *names) # LIST COMMANDS - def blpop(self, keys: KeysT, timeout: int = 0) -> Awaitable: + def blpop(self, keys: KeysT, timeout: TimeoutSecT = 0) -> Awaitable: """ LPOP a value off of the first non-empty list named in the ``keys`` list. @@ -2137,7 +2138,7 @@ def blpop(self, keys: KeysT, timeout: int = 0) -> Awaitable: """ return self.execute_command("BLPOP", *list_or_args(keys, (timeout,))) - def brpop(self, keys: KeysT, timeout: int = 0) -> Awaitable: + def brpop(self, keys: KeysT, timeout: TimeoutSecT = 0) -> Awaitable: """ RPOP a value off of the first non-empty list named in the ``keys`` list. @@ -2150,7 +2151,7 @@ def brpop(self, keys: KeysT, timeout: int = 0) -> Awaitable: """ return self.execute_command("BRPOP", *list_or_args(keys, (timeout,))) - def brpoplpush(self, src: KeyT, dst: KeyT, timeout: int = 0) -> Awaitable: + def brpoplpush(self, src: KeyT, dst: KeyT, timeout: TimeoutSecT = 0) -> Awaitable: """ Pop a value off the tail of ``src``, push it on the head of ``dst`` and then return it. @@ -3140,7 +3141,7 @@ def zpopmin(self, name: KeyT, count: Optional[int] = None) -> Awaitable: options = {"withscores": True} return self.execute_command("ZPOPMIN", name, *args, **options) - def bzpopmax(self, keys: KeysT, timeout: int = 0) -> Awaitable: + def bzpopmax(self, keys: KeysT, timeout: TimeoutSecT = 0) -> Awaitable: """ ZPOPMAX a value off of the first non-empty sorted set named in the ``keys`` list. @@ -3154,7 +3155,7 @@ def bzpopmax(self, keys: KeysT, timeout: int = 0) -> Awaitable: parsed_keys = list_or_args(keys, (timeout,)) return self.execute_command("BZPOPMAX", *parsed_keys) - def bzpopmin(self, keys: KeysT, timeout: int = 0) -> Awaitable: + def bzpopmin(self, keys: KeysT, timeout: TimeoutSecT = 0) -> Awaitable: """ ZPOPMIN a value off of the first non-empty sorted set named in the ``keys`` list. diff --git a/tests/conftest.py b/tests/conftest.py index 9d8520545..3de27b05b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,6 +2,7 @@ import asyncio import random from distutils.version import StrictVersion +from typing import Callable, TypeVar from urllib.parse import urlparse import pytest @@ -25,6 +26,9 @@ REDIS_INFO = {} default_redis_url = "redis://localhost:6379/9" +_DecoratedTest = TypeVar("_DecoratedTest", bound="Callable") +_TestDecorator = Callable[[_DecoratedTest], _DecoratedTest] + # Taken from python3.9 class BooleanOptionalAction(argparse.Action): @@ -111,19 +115,19 @@ def pytest_sessionstart(session): REDIS_INFO["arch_bits"] = arch_bits -def skip_if_server_version_lt(min_version): +def skip_if_server_version_lt(min_version: str) -> _TestDecorator: redis_version = REDIS_INFO["version"] check = StrictVersion(redis_version) < StrictVersion(min_version) return pytest.mark.skipif(check, reason=f"Redis version required >= {min_version}") -def skip_if_server_version_gte(min_version): +def skip_if_server_version_gte(min_version: str) -> _TestDecorator: redis_version = REDIS_INFO["version"] check = StrictVersion(redis_version) >= StrictVersion(min_version) return pytest.mark.skipif(check, reason=f"Redis version required < {min_version}") -def skip_unless_arch_bits(arch_bits): +def skip_unless_arch_bits(arch_bits: int) -> _TestDecorator: return pytest.mark.skipif( REDIS_INFO["arch_bits"] != arch_bits, reason=f"server is not {arch_bits}-bit",