diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 7b609de39..500a3e7af 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -39,6 +39,10 @@ jobs: uses: py-actions/py-dependency-install@v2.1.0 with: path: tests/requirements.txt + - name: Run mypy + run: | + pip install -r tests/requirements-mypy.txt + mypy - name: Run linter run: | make lint diff --git a/.mypy.ini b/.mypy.ini new file mode 100644 index 000000000..d2bc5b09e --- /dev/null +++ b/.mypy.ini @@ -0,0 +1,26 @@ +[mypy] +#, docs/examples, tests +files = aioredis +check_untyped_defs = True +follow_imports_for_stubs = True +#disallow_any_decorated = True +#disallow_any_generics = True +#disallow_incomplete_defs = True +disallow_subclassing_any = True +#disallow_untyped_calls = True +disallow_untyped_decorators = True +#disallow_untyped_defs = True +implicit_reexport = False +no_implicit_optional = True +show_error_codes = True +strict_equality = True +warn_incomplete_stub = True +warn_redundant_casts = True +warn_unreachable = True +warn_unused_ignores = True +disallow_any_unimported = True +#warn_return_any = True + +[mypy-aioredis.lock] +# TODO: Remove once locks has been rewritten +ignore_errors = True diff --git a/CHANGES/1101.misc b/CHANGES/1101.misc new file mode 100644 index 000000000..00d0b7afd --- /dev/null +++ b/CHANGES/1101.misc @@ -0,0 +1 @@ +Enable Mypy in CI. diff --git a/Makefile b/Makefile index 75e61f0d7..d1ba50491 100644 --- a/Makefile +++ b/Makefile @@ -17,7 +17,7 @@ spelling: $(MAKE) -C docs spelling mypy: - $(MYPY) aioredis --ignore-missing-imports + $(MYPY) test: $(PYTEST) @@ -38,7 +38,7 @@ init-hooks: pre-commit install-hooks lint: init-hooks - pre-commit run --all-files + pre-commit run --all-files --show-diff-on-failure devel: aioredis.egg-info init-hooks pip install -U pip diff --git a/aioredis/client.py b/aioredis/client.py index 972270b41..3d0ba7aee 100644 --- a/aioredis/client.py +++ b/aioredis/client.py @@ -5,8 +5,8 @@ import re import time as mod_time import warnings -from itertools import chain from typing import ( + AbstractSet, Any, AsyncIterator, Awaitable, @@ -15,6 +15,7 @@ Iterable, List, Mapping, + MutableMapping, NoReturn, Optional, Sequence, @@ -23,6 +24,8 @@ Type, TypeVar, Union, + ValuesView, + cast, ) from aioredis.compat import Protocol, TypedDict @@ -56,7 +59,7 @@ KeyT = _StringLikeT # Main redis key space PatternT = _StringLikeT # Patterns matched against keys, fields etc FieldT = EncodableT # Fields within hash tables, streams and geo commands -KeysT = Union[KeyT, Iterable[KeyT]] +KeysT = Union[KeyT, Sequence[KeyT]] ChannelT = _StringLikeT GroupT = _StringLikeT # Consumer group ConsumerT = _StringLikeT # Consumer name @@ -69,27 +72,37 @@ # type signature because they will all be required to be the same key type. AnyKeyT = TypeVar("AnyKeyT", bytes, str, memoryview) AnyFieldT = TypeVar("AnyFieldT", bytes, str, memoryview) -AnyChannelT = TypeVar("AnyChannelT", bytes, str, memoryview) +AnyChannelT = ChannelT +PubSubHandler = Callable[[Dict[str, str]], None] SYM_EMPTY = b"" EMPTY_RESPONSE = "EMPTY_RESPONSE" +_KeyT = TypeVar("_KeyT", bound=KeyT) +_ArgT = TypeVar("_ArgT", KeyT, EncodableT) +_RedisT = TypeVar("_RedisT", bound="Redis") +_NormalizeKeysT = TypeVar("_NormalizeKeysT", bound=Mapping[ChannelT, object]) -def list_or_args(keys: KeysT, args: Optional[KeysT]) -> KeysT: + +def list_or_args( + keys: Union[_KeyT, Iterable[_KeyT]], args: Optional[Iterable[_ArgT]] +) -> List[Union[_KeyT, _ArgT]]: # returns a single new list combining keys and args + key_list: List[Union[_KeyT, _ArgT]] try: - iter(keys) + iter(keys) # type: ignore[arg-type] + keys = cast(Iterable[_KeyT], keys) # a string or bytes instance can be iterated, but indicates # keys wasn't passed as a list if isinstance(keys, (bytes, str)): - keys = [keys] + key_list = [keys] else: - keys = list(keys) + key_list = list(keys) except TypeError: - keys = [keys] + key_list = [cast(memoryview, keys)] if args: - keys.extend(args) - return keys + key_list.extend(args) + return key_list def timestamp_to_datetime(response): @@ -161,7 +174,7 @@ def parse_object(response, infotype): def parse_info(response): """Parse the result of Redis's INFO command into a Python dict""" - info = {} + info: Dict[str, Any] = {} response = str_if_bytes(response) def get_value(value): @@ -449,7 +462,7 @@ def parse_zscan(response, **options): def parse_slowlog_get(response, **options): - space = " " if options.get("decode_responses", False) else b" " + space: Union[str, bytes] = " " if options.get("decode_responses", False) else b" " return [ { "id": item[0], @@ -510,7 +523,7 @@ def parse_georadius_generic(response, **options): # just a bunch of places return response_list - cast = { + cast: Dict[str, Callable] = { "withdist": float, "withcoord": lambda ll: (float(ll[0]), float(ll[1])), "withhash": int, @@ -770,6 +783,8 @@ class Redis: "ZSCAN": parse_zscan, } + response_callbacks: MutableMapping[Union[str, bytes], ResponseCallbackT] + @classmethod def from_url(cls, url: str, **kwargs): """ @@ -892,7 +907,7 @@ def __init__( connection_pool = ConnectionPool(**kwargs) self.connection_pool = connection_pool self.single_connection_client = single_connection_client - self.connection = None + self.connection: Optional[Connection] = None self.response_callbacks = CaseInsensitiveDict(self.__class__.RESPONSE_CALLBACKS) @@ -902,7 +917,7 @@ def __repr__(self): def __await__(self): return self.initialize().__await__() - async def initialize(self): + async def initialize(self: _RedisT) -> _RedisT: if self.single_connection_client and self.connection is None: self.connection = await self.connection_pool.get_connection("_") return self @@ -1033,7 +1048,8 @@ def client(self) -> "Redis": connection_pool=self.connection_pool, single_connection_client=True ) - __aenter__ = initialize + async def __aenter__(self: _RedisT) -> _RedisT: + return await self.initialize() async def __aexit__(self, exc_type, exc_value, traceback): await self.close() @@ -1085,6 +1101,8 @@ async def parse_response( return options[EMPTY_RESPONSE] raise if command_name in self.response_callbacks: + # Mypy bug: https://github.com/python/mypy/issues/10977 + command_name = cast(str, command_name) retval = self.response_callbacks[command_name](response, **options) return await retval if inspect.isawaitable(retval) else retval return response @@ -1261,9 +1279,9 @@ def acl_setuser( # noqa: C901 if passwords: # as most users will have only one password, allow remove_passwords # to be specified as a simple string or a list - passwords = list_or_args(passwords, []) - for i, password in enumerate(passwords): - password = encoder.encode(password) + converted_passwords = list_or_args(passwords, []) + for i, raw_password in enumerate(converted_passwords): + password = encoder.encode(raw_password) if password.startswith(b"+"): pieces.append(b">%s" % password[1:]) elif password.startswith(b"-"): @@ -1277,9 +1295,9 @@ def acl_setuser( # noqa: C901 if hashed_passwords: # as most users will have only one password, allow remove_passwords # to be specified as a simple string or a list - hashed_passwords = list_or_args(hashed_passwords, []) - for i, hashed_password in enumerate(hashed_passwords): - hashed_password = encoder.encode(hashed_password) + parsed_hashed_passwords = list_or_args(hashed_passwords, []) + for i, raw_hashed_password in enumerate(parsed_hashed_passwords): + hashed_password = encoder.encode(raw_hashed_password) if hashed_password.startswith(b"+"): pieces.append(b"#%s" % hashed_password[1:]) elif hashed_password.startswith(b"-"): @@ -1294,8 +1312,8 @@ def acl_setuser( # noqa: C901 pieces.append(b"nopass") if categories: - for category in categories: - category = encoder.encode(category) + for raw_category in categories: + category = encoder.encode(raw_category) # categories can be prefixed with one of (+@, +, -@, -) if category.startswith(b"+@"): pieces.append(category) @@ -1311,8 +1329,8 @@ def acl_setuser( # noqa: C901 'prefixed with "+" or "-"' ) if commands: - for cmd in commands: - cmd = encoder.encode(cmd) + for raw_cmd in commands: + cmd = encoder.encode(raw_cmd) if not cmd.startswith(b"+") and not cmd.startswith(b"-"): raise DataError( f'Command "{encoder.decode(cmd, force=True)}" must be ' @@ -1321,8 +1339,8 @@ def acl_setuser( # noqa: C901 pieces.append(cmd) if keys: - for key in keys: - key = encoder.encode(key) + for raw_key in keys: + key = encoder.encode(raw_key) pieces.append(b"~%s" % key) return self.execute_command("ACL SETUSER", *pieces) @@ -1367,7 +1385,7 @@ def client_kill_filter( will not get killed even if it is identified by one of the filter options. If skipme is not provided, the server defaults to skipme=True """ - args = [] + args: List[Union[bytes, str]] = [] if _type is not None: client_types = ("normal", "master", "slave", "pubsub") if str(_type).lower() not in client_types: @@ -1639,7 +1657,7 @@ def sentinel_slaves(self, service_name: str) -> Awaitable: """Returns a list of slaves for ``service_name``""" return self.execute_command("SENTINEL SLAVES", service_name) - def shutdown(self, save: bool = False, nosave: bool = False) -> Awaitable: + def shutdown(self, save: bool = False, nosave: bool = False) -> None: """Shutdown the Redis server. If Redis has persistence configured, data will be flushed before shutdown. If the "save" option is set, a data flush will be attempted even if there is no persistence @@ -1766,12 +1784,12 @@ def bitpos( raise DataError("bit must be 0 or 1") params = [key, bit] - start is not None and params.append(start) - - if start is not None and end is not None: - params.append(end) - elif start is None and end is not None: - raise DataError("start argument is not set, " "when end is specified") + if start is not None: + params.append(start) + if end is not None: + params.append(end) + elif end is not None: + raise DataError("start argument is not set, when end is specified") return self.execute_command("BITPOS", *params) def decr(self, name: KeyT, amount: int = 1) -> Awaitable: @@ -1878,11 +1896,11 @@ def mget(self, keys: KeysT, *args: EncodableT) -> Awaitable: """ Returns a list of values ordered identically to ``keys`` """ - args = list_or_args(keys, args) + encoded_args = list_or_args(keys, args) options: Dict[str, Union[EncodableT, Iterable[EncodableT]]] = {} - if not args: + if not encoded_args: options[EMPTY_RESPONSE] = [] - return self.execute_command("MGET", *args, **options) + return self.execute_command("MGET", *encoded_args, **options) def mset(self, mapping: Mapping[AnyKeyT, EncodableT]) -> Awaitable: """ @@ -2117,11 +2135,7 @@ def blpop(self, keys: KeysT, timeout: int = 0) -> Awaitable: If timeout is 0, then block indefinitely. """ - if timeout is None: - timeout = 0 - keys = list_or_args(keys, None) - keys.append(timeout) - return self.execute_command("BLPOP", *keys) + return self.execute_command("BLPOP", *list_or_args(keys, (timeout,))) def brpop(self, keys: KeysT, timeout: int = 0) -> Awaitable: """ @@ -2134,11 +2148,7 @@ def brpop(self, keys: KeysT, timeout: int = 0) -> Awaitable: If timeout is 0, then block indefinitely. """ - if timeout is None: - timeout = 0 - keys = list_or_args(keys, None) - keys.append(timeout) - return self.execute_command("BRPOP", *keys) + return self.execute_command("BRPOP", *list_or_args(keys, (timeout,))) def brpoplpush(self, src: KeyT, dst: KeyT, timeout: int = 0) -> Awaitable: """ @@ -2149,8 +2159,6 @@ def brpoplpush(self, src: KeyT, dst: KeyT, timeout: int = 0) -> Awaitable: seconds elapse, whichever is first. A ``timeout`` value of 0 blocks forever. """ - if timeout is None: - timeout = 0 return self.execute_command("BRPOPLPUSH", src, dst, timeout) def lindex(self, name: KeyT, index: int) -> Awaitable: @@ -2362,8 +2370,10 @@ def sort( "must be specified and contain at least " "two keys" ) + options: Dict[str, Optional[int]] = {"groups": len(get)} + else: + options = {"groups": None} - options = {"groups": len(get) if groups else None} return self.execute_command("SORT", *pieces, **options) # SCAN COMMANDS @@ -2417,10 +2427,10 @@ async def scan_iter( HASH, LIST, SET, STREAM, STRING, ZSET Additionally, Redis modules can expose other types as well. """ - cursor = "0" + cursor = None while cursor != 0: cursor, data = await self.scan( - cursor=cursor, match=match, count=count, _type=_type + cursor=cursor or 0, match=match, count=count, _type=_type ) for d in data: yield d @@ -2458,10 +2468,10 @@ async def sscan_iter( ``count`` allows for hint the minimum number of returns """ - cursor = "0" + cursor = None while cursor != 0: cursor, data = await self.sscan( - name, cursor=cursor, match=match, count=count + name, cursor=cursor or 0, match=match, count=count ) for d in data: yield d @@ -2499,10 +2509,10 @@ async def hscan_iter( ``count`` allows for hint the minimum number of returns """ - cursor = "0" + cursor = None while cursor != 0: cursor, data = await self.hscan( - name, cursor=cursor, match=match, count=count + name, cursor=cursor or 0, match=match, count=count ) for it in data.items(): yield it @@ -2550,11 +2560,11 @@ async def zscan_iter( ``score_cast_func`` a callable used to cast the score return value """ - cursor = "0" + cursor = None while cursor != 0: cursor, data = await self.zscan( name, - cursor=cursor, + cursor=cursor or 0, match=match, count=count, score_cast_func=score_cast_func, @@ -2573,29 +2583,29 @@ def scard(self, name: KeyT) -> Awaitable: def sdiff(self, keys: KeysT, *args: EncodableT) -> Awaitable: """Return the difference of sets specified by ``keys``""" - args = list_or_args(keys, args) - return self.execute_command("SDIFF", *args) + parsed_args = list_or_args(keys, args) + return self.execute_command("SDIFF", *parsed_args) def sdiffstore(self, dest: KeyT, keys: KeysT, *args: EncodableT) -> Awaitable: """ Store the difference of sets specified by ``keys`` into a new set named ``dest``. Returns the number of keys in the new set. """ - args = list_or_args(keys, args) - return self.execute_command("SDIFFSTORE", dest, *args) + parsed_args = list_or_args(keys, args) + return self.execute_command("SDIFFSTORE", dest, *parsed_args) def sinter(self, keys: KeysT, *args: EncodableT) -> Awaitable: """Return the intersection of sets specified by ``keys``""" - args = list_or_args(keys, args) - return self.execute_command("SINTER", *args) + parsed_args = list_or_args(keys, args) + return self.execute_command("SINTER", *parsed_args) def sinterstore(self, dest: KeyT, keys: KeysT, *args: EncodableT) -> Awaitable: """ Store the intersection of sets specified by ``keys`` into a new set named ``dest``. Returns the number of keys in the new set. """ - args = list_or_args(keys, args) - return self.execute_command("SINTERSTORE", dest, *args) + parsed_args = list_or_args(keys, args) + return self.execute_command("SINTERSTORE", dest, *parsed_args) def sismember(self, name: KeyT, value: EncodableT) -> Awaitable: """Return a boolean indicating if ``value`` is a member of set ``name``""" @@ -2631,16 +2641,16 @@ def srem(self, name: KeyT, *values: EncodableT) -> Awaitable: def sunion(self, keys: KeysT, *args: EncodableT) -> Awaitable: """Return the union of sets specified by ``keys``""" - args = list_or_args(keys, args) - return self.execute_command("SUNION", *args) + parsed_args = list_or_args(keys, args) + return self.execute_command("SUNION", *parsed_args) def sunionstore(self, dest: KeyT, keys: KeysT, *args: EncodableT) -> Awaitable: """ Store the union of sets specified by ``keys`` into a new set named ``dest``. Returns the number of keys in the new set. """ - args = list_or_args(keys, args) - return self.execute_command("SUNIONSTORE", dest, *args) + parsed_args = list_or_args(keys, args) + return self.execute_command("SUNIONSTORE", dest, *parsed_args) # STREAMS COMMANDS def xack(self, name: KeyT, groupname: GroupT, *ids: StreamIdT) -> Awaitable: @@ -2847,9 +2857,9 @@ def xpending_range( self, name: KeyT, groupname: GroupT, - min: StreamIdT, - max: StreamIdT, - count: int, + min: Optional[StreamIdT], + max: Optional[StreamIdT], + count: Optional[int], consumername: Optional[ConsumerT] = None, ) -> Awaitable: """ @@ -3141,11 +3151,8 @@ def bzpopmax(self, keys: KeysT, timeout: int = 0) -> Awaitable: If timeout is 0, then block indefinitely. """ - if timeout is None: - timeout = 0 - keys = list_or_args(keys, None) - keys.append(timeout) - return self.execute_command("BZPOPMAX", *keys) + parsed_keys = list_or_args(keys, (timeout,)) + return self.execute_command("BZPOPMAX", *parsed_keys) def bzpopmin(self, keys: KeysT, timeout: int = 0) -> Awaitable: """ @@ -3158,8 +3165,6 @@ def bzpopmin(self, keys: KeysT, timeout: int = 0) -> Awaitable: If timeout is 0, then block indefinitely. """ - if timeout is None: - timeout = 0 klist: List[EncodableT] = list_or_args(keys, None) klist.append(timeout) return self.execute_command("BZPOPMIN", *klist) @@ -3397,11 +3402,14 @@ def _zaggregate( aggregate: Optional[str] = None, ) -> Awaitable: pieces: List[EncodableT] = [command, dest, len(keys)] - if isinstance(keys, dict): - keys, weights = keys.keys(), keys.values() + key_names: Union[Sequence[KeyT], AbstractSet[AnyKeyT]] + weights: Optional[ValuesView[float]] + if isinstance(keys, Mapping): + key_names, weights = keys.keys(), keys.values() else: + key_names = keys weights = None - pieces.extend(keys) + pieces.extend(key_names) if weights: pieces.append(b"WEIGHTS") pieces.extend(weights) @@ -3476,7 +3484,7 @@ def hset( """ if key is None and not mapping: raise DataError("'hset' with no key value pairs") - items: List[EncodableT] = [] + items: List[Union[FieldT, Optional[EncodableT]]] = [] if key is not None: items.extend((key, value)) if mapping: @@ -3505,15 +3513,15 @@ def hmset(self, name: KeyT, mapping: Mapping[AnyFieldT, EncodableT]) -> Awaitabl ) if not mapping: raise DataError("'hmset' with 'mapping' of length 0") - items = [] + items: List[Union[AnyFieldT, EncodableT]] = [] for pair in mapping.items(): items.extend(pair) return self.execute_command("HMSET", name, *items) - def hmget(self, name: KeyT, keys: Sequence[FieldT], *args: FieldT) -> Awaitable: + def hmget(self, name: KeyT, keys: Sequence[KeyT], *args: FieldT) -> Awaitable: """Returns a list of values ordered identically to ``keys``""" - args = list_or_args(keys, args) - return self.execute_command("HMGET", name, *args) + parsed_args = list_or_args(keys, args) + return self.execute_command("HMGET", name, *parsed_args) def hvals(self, name: KeyT) -> Awaitable: """Return the list of values within hash ``name``""" @@ -3749,7 +3757,7 @@ def georadiusbymember( def _georadiusgeneric( self, command: str, *args: EncodableT, **kwargs: Optional[EncodableT] ) -> Awaitable: - pieces: List[EncodableT] = list(args) + pieces: List[Optional[EncodableT]] = list(args) if kwargs["unit"] and kwargs["unit"] not in ("m", "km", "mi", "ft"): raise DataError("GEORADIUS invalid unit") elif kwargs["unit"]: @@ -3844,6 +3852,7 @@ async def connect(self): async def __aenter__(self): await self.connect() + self.connection = cast(Connection, self.connection) # Connected above. await self.connection.send_command("MONITOR") # check that monitor returns 'OK', but don't return it to user response = await self.connection.read_response() @@ -3852,17 +3861,22 @@ async def __aenter__(self): return self async def __aexit__(self, *args): + assert self.connection is not None await self.connection.disconnect() await self.connection_pool.release(self.connection) async def next_command(self) -> MonitorCommandInfo: """Parse the response from a monitor command""" + if self.connection is None: + raise RedisError("Connection already closed.") await self.connect() response = await self.connection.read_response() if isinstance(response, bytes): response = self.connection.encoder.decode(response, force=True) command_time, command_data = response.split(" ", 1) m = self.monitor_re.match(command_data) + if m is None: + raise RedisError("Invalid command received.") db_id, client_info, command = m.groups() command = " ".join(self.command_re.findall(command)) # Redis escapes double quotes because each piece of the command @@ -3919,7 +3933,7 @@ def __init__( self.connection_pool = connection_pool self.shard_hint = shard_hint self.ignore_subscribe_messages = ignore_subscribe_messages - self.connection = None + self.connection: Optional[Connection] = None # we need to know the encoding options for this connection in order # to lookup channel and pattern names for callback handlers. self.encoder = self.connection_pool.get_encoder() @@ -3933,10 +3947,10 @@ def __init__( b"pong", self.encoder.encode(self.HEALTH_CHECK_MESSAGE), ] - self.channels = {} - self.pending_unsubscribe_channels = set() - self.patterns = {} - self.pending_unsubscribe_patterns = set() + self.channels: Dict[ChannelT, PubSubHandler] = {} + self.pending_unsubscribe_channels: Set[ChannelT] = set() + self.patterns: Dict[ChannelT, PubSubHandler] = {} + self.pending_unsubscribe_patterns: Set[ChannelT] = set() self._lock = asyncio.Lock() async def __aenter__(self): @@ -4056,7 +4070,7 @@ async def check_health(self): "PING", self.HEALTH_CHECK_MESSAGE, check_health=False ) - def _normalize_keys(self, data: Mapping[AnyChannelT, EncodableT]): + def _normalize_keys(self, data: _NormalizeKeysT) -> _NormalizeKeysT: """ normalize channel/pattern names to be either bytes or strings based on whether responses are automatically decoded. this saves us @@ -4064,9 +4078,9 @@ def _normalize_keys(self, data: Mapping[AnyChannelT, EncodableT]): """ encode = self.encoder.encode decode = self.encoder.decode - return {decode(encode(k)): v for k, v in data.items()} + return {decode(encode(k)): v for k, v in data.items()} # type: ignore[return-value] - async def psubscribe(self, *args: ChannelT, **kwargs: Callable): + async def psubscribe(self, *args: ChannelT, **kwargs: PubSubHandler): """ Subscribe to channel patterns. Patterns supplied as keyword arguments expect a pattern name as the key and a callable as the value. A @@ -4074,10 +4088,10 @@ async def psubscribe(self, *args: ChannelT, **kwargs: Callable): received on that pattern rather than producing a message via ``listen()``. """ - if args: - args = list_or_args(args[0], args[1:]) - new_patterns: Dict[ChannelT, Optional[Callable]] = dict.fromkeys(args) - new_patterns.update(kwargs) + parsed_args = list_or_args((args[0],), args[1:]) if args else args + new_patterns: Dict[ChannelT, PubSubHandler] = dict.fromkeys(parsed_args) + # Mypy bug: https://github.com/python/mypy/issues/10970 + new_patterns.update(kwargs) # type: ignore[arg-type] ret_val = await self.execute_command("PSUBSCRIBE", *new_patterns.keys()) # update the patterns dict AFTER we send the command. we don't want to # subscribe twice to these patterns, once for the command and again @@ -4087,18 +4101,20 @@ async def psubscribe(self, *args: ChannelT, **kwargs: Callable): self.pending_unsubscribe_patterns.difference_update(new_patterns) return ret_val - def punsubscribe(self, *args: EncodableT) -> Awaitable: + def punsubscribe(self, *args: ChannelT) -> Awaitable: """ Unsubscribe from the supplied patterns. If empty, unsubscribe from all patterns. """ + patterns: Iterable[ChannelT] if args: - args = list_or_args(args[0], args[1:]) - patterns = self._normalize_keys(dict.fromkeys(args)) + parsed_args = list_or_args((args[0],), args[1:]) + patterns = self._normalize_keys(dict.fromkeys(parsed_args)).keys() else: + parsed_args = [] patterns = self.patterns self.pending_unsubscribe_patterns.update(patterns) - return self.execute_command("PUNSUBSCRIBE", *args) + return self.execute_command("PUNSUBSCRIBE", *parsed_args) async def subscribe(self, *args: ChannelT, **kwargs: Callable): """ @@ -4108,10 +4124,10 @@ async def subscribe(self, *args: ChannelT, **kwargs: Callable): that channel rather than producing a message via ``listen()`` or ``get_message()``. """ - if args: - args = list_or_args(args[0], args[1:]) - new_channels = dict.fromkeys(args) - new_channels.update(kwargs) + parsed_args = list_or_args((args[0],), args[1:]) if args else () + new_channels = dict.fromkeys(parsed_args) + # Mypy bug: https://github.com/python/mypy/issues/10970 + new_channels.update(kwargs) # type: ignore[arg-type] ret_val = await self.execute_command("SUBSCRIBE", *new_channels.keys()) # update the channels dict AFTER we send the command. we don't want to # subscribe twice to these channels, once for the command and again @@ -4127,12 +4143,13 @@ def unsubscribe(self, *args) -> Awaitable: all channels """ if args: - args = list_or_args(args[0], args[1:]) - channels = self._normalize_keys(dict.fromkeys(args)) + parsed_args = list_or_args(args[0], args[1:]) + channels = self._normalize_keys(dict.fromkeys(parsed_args)) else: + parsed_args = [] channels = self.channels self.pending_unsubscribe_channels.update(channels) - return self.execute_command("UNSUBSCRIBE", *args) + return self.execute_command("UNSUBSCRIBE", *parsed_args) async def listen(self) -> AsyncIterator: """Listen for messages on channels this client has been subscribed to""" @@ -4225,7 +4242,7 @@ def handle_message(self, response, ignore_subscribe_messages=False): async def run( self, *, - exception_handler: "PSWorkerThreadExcHandlerT" = None, + exception_handler: Optional["PSWorkerThreadExcHandlerT"] = None, poll_timeout: float = 1.0, ) -> None: """Process pub/sub messages using registered callbacks. @@ -4309,21 +4326,21 @@ class Pipeline(Redis): # lgtm [py/init-calls-subclass] def __init__( self, connection_pool: ConnectionPool, - response_callbacks: Mapping[str, ResponseCallbackT], + response_callbacks: MutableMapping[Union[str, bytes], ResponseCallbackT], transaction: bool, shard_hint: Optional[str], ): self.connection_pool = connection_pool self.connection = None self.response_callbacks = response_callbacks - self.transaction = transaction + self.is_transaction = transaction self.shard_hint = shard_hint self.watching = False self.command_stack: CommandStackT = [] self.scripts: Set[Script] = set() self.explicit_transaction = False - async def __aenter__(self) -> "Pipeline": + async def __aenter__(self: _RedisT) -> _RedisT: return self async def __aexit__(self, exc_type, exc_value, traceback): @@ -4412,6 +4429,7 @@ async def immediate_execute_command(self, *args, **options): command_name, self.shard_hint ) self.connection = conn + conn = cast(Connection, conn) try: await conn.send_command(*args) return await self.parse_response(conn, command_name, **options) @@ -4459,12 +4477,14 @@ def pipeline_execute_command(self, *args, **options): self.command_stack.append((args, options)) return self - async def _execute_transaction( + async def _execute_transaction( # noqa: C901 self, connection: Connection, commands: CommandStackT, raise_on_error ): - cmds = chain([(("MULTI",), {})], commands, [(("EXEC",), {})]) + pre: CommandT = (("MULTI",), {}) + post: CommandT = (("EXEC",), {}) + cmds = (pre, *commands, post) all_cmds = connection.pack_commands( - [args for args, options in cmds if EMPTY_RESPONSE not in options] + args for args, options in cmds if EMPTY_RESPONSE not in options ) await connection.send_packed_command(all_cmds) errors = [] @@ -4508,7 +4528,8 @@ async def _execute_transaction( response.insert(i, e) if len(response) != len(commands): - await self.connection.disconnect() + if self.connection: + await self.connection.disconnect() raise ResponseError( "Wrong number of response items from pipeline execution" ) from None @@ -4556,7 +4577,9 @@ def raise_first_error(self, commands: CommandStackT, response: Iterable[Any]): self.annotate_exception(r, i + 1, commands[i][0]) raise r - def annotate_exception(self, exception: Exception, number: int, command: str): + def annotate_exception( + self, exception: Exception, number: int, command: Iterable[object] + ) -> None: cmd = " ".join(map(safe_str, command)) msg = f"Command # {number} ({cmd}) of pipeline caused error: {exception.args}" exception.args = (msg,) + exception.args[1:] @@ -4591,7 +4614,7 @@ async def execute(self, raise_on_error: bool = True): return [] if self.scripts: await self.load_scripts() - if self.transaction or self.explicit_transaction: + if self.is_transaction or self.explicit_transaction: execute = self._execute_transaction else: execute = self._execute_pipeline @@ -4602,6 +4625,7 @@ async def execute(self, raise_on_error: bool = True): # assign to self.connection so reset() releases the connection # back to the pool after we're done self.connection = conn + conn = cast(Connection, conn) try: return await execute(conn, stack, raise_on_error) @@ -4646,8 +4670,10 @@ def __init__(self, registered_client: Redis, script: ScriptTextT): # We need the encoding from the client in order to generate an # accurate byte representation of the script encoder = registered_client.connection_pool.get_encoder() - script = encoder.encode(script) - self.sha = hashlib.sha1(script).hexdigest() + script_bytes = encoder.encode(script) + else: + script_bytes = script + self.sha = hashlib.sha1(script_bytes).hexdigest() async def __call__( self, @@ -4681,7 +4707,9 @@ class BitFieldOperation: Command builder for BITFIELD commands. """ - def __init__(self, client: Redis, key: str, default_overflow: Optional[str] = None): + def __init__( + self, client: Redis, key: KeyT, default_overflow: Optional[str] = None + ): self.client = client self.key = key self._default_overflow = default_overflow @@ -4765,7 +4793,7 @@ def set(self, fmt: str, offset: BitfieldOffsetT, value: int): @property def command(self): - cmd = ["BITFIELD", self.key] + cmd: List[EncodableT] = ["BITFIELD", self.key] for ops in self.operations: cmd.extend(ops) return cmd diff --git a/aioredis/compat.py b/aioredis/compat.py index 881fda8ff..ae636028d 100644 --- a/aioredis/compat.py +++ b/aioredis/compat.py @@ -1,5 +1,8 @@ -# flake8: noqa -try: - from typing import Protocol, TypedDict # lgtm [py/unused-import] -except ImportError: - from typing_extensions import Protocol, TypedDict # lgtm [py/unused-import] +import sys + +if sys.version_info >= (3, 8): + from typing import Protocol, TypedDict +else: + from typing_extensions import Protocol, TypedDict + +__all__ = ("Protocol", "TypedDict") diff --git a/aioredis/connection.py b/aioredis/connection.py index 6ec3741af..93910971b 100644 --- a/aioredis/connection.py +++ b/aioredis/connection.py @@ -1,4 +1,5 @@ import asyncio +import enum import errno import inspect import io @@ -9,8 +10,10 @@ import warnings from distutils.version import StrictVersion from itertools import chain +from types import MappingProxyType from typing import ( Any, + Callable, Iterable, List, Mapping, @@ -20,6 +23,7 @@ Type, TypeVar, Union, + cast, ) from urllib.parse import ParseResult, parse_qs, unquote, urlparse @@ -78,7 +82,12 @@ SERVER_CLOSED_CONNECTION_ERROR = "Connection closed by server." -SENTINEL = object() + +class _Sentinel(enum.Enum): + sentinel = object() + + +SENTINEL = _Sentinel.sentinel MODULE_LOAD_ERROR = "Error loading the extension. Please check the server logs." NO_SUCH_MODULE_ERROR = "Error unloading module: no such module with that name" MODULE_UNLOAD_NOT_POSSIBLE_ERROR = "Error unloading module: operation not possible." @@ -93,6 +102,13 @@ EncodableT = Union[EncodedT, DecodedT] +class _HiredisReaderArgs(TypedDict, total=False): + protocolError: Callable[[str], Exception] + replyError: Callable[[str], Exception] + encoding: Optional[str] + errors: Optional[str] + + class Encoder: """Encode strings to bytes-like and decode bytes-like to strings""" @@ -117,14 +133,12 @@ def encode(self, value: EncodableT) -> EncodedT: return repr(value).encode() if not isinstance(value, str): # a value we don't know how to deal with. throw an error - typename = value.__class__.__name__ + typename = value.__class__.__name__ # type: ignore[unreachable] raise DataError( f"Invalid input of type: {typename!r}. " "Convert to a bytes, string, int or float first." ) - if isinstance(value, str): - return value.encode(self.encoding, self.encoding_errors) - return value + return value.encode(self.encoding, self.encoding_errors) def decode(self, value: EncodableT, force=False) -> EncodableT: """Return a unicode string from the bytes-like representation""" @@ -184,9 +198,11 @@ def parse_error(self, response: str) -> ResponseError: error_code = response.split(" ")[0] if error_code in self.EXCEPTION_CLASSES: response = response[len(error_code) + 1 :] - exception_class = self.EXCEPTION_CLASSES[error_code] - if isinstance(exception_class, dict): - exception_class = exception_class.get(response, ResponseError) + exception_class_or_dict = self.EXCEPTION_CLASSES[error_code] + if isinstance(exception_class_or_dict, dict): + exception_class = exception_class_or_dict.get(response, ResponseError) + else: + exception_class = exception_class_or_dict return exception_class(response) return ResponseError(response) @@ -199,7 +215,9 @@ def on_connect(self, connection: "Connection"): async def can_read(self, timeout: float) -> bool: raise NotImplementedError() - async def read_response(self) -> Union[EncodableT, ResponseError, None]: + async def read_response( + self, + ) -> Union[EncodableT, ResponseError, None, List[EncodableT]]: raise NotImplementedError() @@ -215,12 +233,12 @@ def __init__( self, stream_reader: asyncio.StreamReader, socket_read_size: int, - socket_timeout: float, + socket_timeout: Optional[float], ): - self._stream = stream_reader + self._stream: Optional[asyncio.StreamReader] = stream_reader self.socket_read_size = socket_read_size self.socket_timeout = socket_timeout - self._buffer = io.BytesIO() + self._buffer: Optional[io.BytesIO] = io.BytesIO() # number of bytes written to the buffer from the socket self.bytes_written = 0 # number of bytes read from the buffer @@ -233,10 +251,12 @@ def length(self): async def _read_from_socket( self, length: Optional[int] = None, - timeout: Optional[float] = SENTINEL, # type: ignore + timeout: Union[float, None, _Sentinel] = SENTINEL, raise_on_timeout: bool = True, ) -> bool: buf = self._buffer + if buf is None or self._stream is None: + raise RedisError("Buffer is closed.") buf.seek(self.bytes_written) marker = 0 timeout = timeout if timeout is not SENTINEL else self.socket_timeout @@ -281,6 +301,9 @@ async def read(self, length: int) -> bytes: if length > self.length: await self._read_from_socket(length - self.length) + if self._buffer is None: + raise RedisError("Buffer is closed.") + self._buffer.seek(self.bytes_read) data = self._buffer.read(length) self.bytes_read += len(data) @@ -294,6 +317,9 @@ async def read(self, length: int) -> bytes: async def readline(self) -> bytes: buf = self._buffer + if buf is None: + raise RedisError("Buffer is closed.") + buf.seek(self.bytes_read) data = buf.readline() while not data.endswith(SYM_CRLF): @@ -312,6 +338,9 @@ async def readline(self) -> bytes: return data[:-2] def purge(self): + if self._buffer is None: + raise RedisError("Buffer is closed.") + self._buffer.seek(0) self._buffer.truncate() self.bytes_written = 0 @@ -320,7 +349,7 @@ def purge(self): def close(self): try: self.purge() - self._buffer.close() + self._buffer.close() # type: ignore[union-attr] except Exception: # issue #633 suggests the purge/close somehow raised a # BadFileDescriptor error. Perhaps the client ran out of @@ -344,6 +373,9 @@ def __init__(self, socket_read_size: int): def on_connect(self, connection: "Connection"): """Called when the stream connects""" self._stream = connection._reader + if self._buffer is None or self._stream is None: + raise RedisError("Buffer is closed.") + self._buffer = SocketBuffer( self._stream, self._read_size, connection.socket_timeout ) @@ -362,7 +394,7 @@ async def can_read(self, timeout: float): return self._buffer and bool(await self._buffer.can_read(timeout)) async def read_response(self) -> Union[EncodableT, ResponseError, None]: - if not self._buffer: + if not self._buffer or not self.encoder: raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) raw = await self._buffer.readline() if not raw: @@ -414,25 +446,24 @@ class HiredisParser(BaseParser): __slots__ = BaseParser.__slots__ + ("_next_response", "_reader", "_socket_timeout") + _next_response: bool + def __init__(self, socket_read_size: int): if not HIREDIS_AVAILABLE: raise RedisError("Hiredis is not available.") super().__init__(socket_read_size=socket_read_size) - self._next_response = ... self._reader: Optional[hiredis.Reader] = None self._socket_timeout: Optional[float] = None def on_connect(self, connection: "Connection"): self._stream = connection._reader - kwargs = { + kwargs: _HiredisReaderArgs = { "protocolError": InvalidResponse, "replyError": self.parse_error, } if connection.encoder.decode_responses: - kwargs.update( - encoding=connection.encoder.encoding, - errors=connection.encoder.encoding_errors, - ) + kwargs["encoding"] = connection.encoder.encoding + kwargs["errors"] = connection.encoder.encoding_errors self._reader = hiredis.Reader(**kwargs) self._next_response = False @@ -454,8 +485,13 @@ async def can_read(self, timeout: float): return True async def read_from_socket( - self, timeout: Optional[float] = SENTINEL, raise_on_timeout: bool = True + self, + timeout: Union[float, None, _Sentinel] = SENTINEL, + raise_on_timeout: bool = True, ): + if self._stream is None or self._reader is None: + raise RedisError("Parser already closed.") + timeout = self._socket_timeout if timeout is SENTINEL else timeout try: async with async_timeout.timeout(timeout): @@ -482,11 +518,14 @@ async def read_from_socket( return False raise ConnectionError(f"Error while reading from socket: {ex.args}") - async def read_response(self) -> EncodableT: + async def read_response(self) -> Union[EncodableT, List[EncodableT]]: if not self._stream or not self._reader: self.on_disconnect() raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from None + response: Union[ + EncodableT, ConnectionError, List[Union[EncodableT, ConnectionError]] + ] # _next_response might be cached from a can_read() call if self._next_response is not False: response = self._next_response @@ -509,7 +548,8 @@ async def read_response(self) -> EncodableT: and isinstance(response[0], ConnectionError) ): raise response[0] - return response + # cast as there won't be a ConnectionError here. + return cast(Union[EncodableT, List[EncodableT]], response) DefaultParser: Type[Union[PythonParser, HiredisParser]] @@ -580,7 +620,7 @@ def __init__( decode_responses: bool = False, parser_class: Type[BaseParser] = DefaultParser, socket_read_size: int = 65536, - health_check_interval: int = 0, + health_check_interval: float = 0, client_name: Optional[str] = None, username: Optional[str] = None, encoder_class: Type[Encoder] = Encoder, @@ -599,7 +639,7 @@ def __init__( self.socket_type = socket_type self.retry_on_timeout = retry_on_timeout self.health_check_interval = health_check_interval - self.next_health_check = -1 + self.next_health_check: float = -1 self.ssl_context: Optional[RedisSSLContext] = None self.encoder = encoder_class(encoding, encoding_errors, decode_responses) self._reader: Optional[asyncio.StreamReader] = None @@ -714,10 +754,12 @@ async def on_connect(self): # if username and/or password are set, authenticate if self.username or self.password: + auth_args: Union[Tuple[str], Tuple[str, str]] if self.username: auth_args = (self.username, self.password or "") else: - auth_args = (self.password,) + # Mypy bug: https://github.com/python/mypy/issues/10944 + auth_args = (self.password or "",) # avoid checking health here -- PING will fail if we try # to check the health prior to the AUTH await self.send_command("AUTH", *auth_args, check_health=False) @@ -756,10 +798,10 @@ async def disconnect(self): return try: if os.getpid() == self.pid: - self._writer.close() + self._writer.close() # type: ignore[union-attr] # py3.6 doesn't have this method if hasattr(self._writer, "wait_closed"): - await self._writer.wait_closed() + await self._writer.wait_closed() # type: ignore[union-attr] except OSError: pass self._reader = None @@ -790,15 +832,16 @@ async def check_health(self): except BaseException as err2: raise err2 from err - async def _send_packed_command( - self, command: Union[bytes, str, Iterable[Union[bytes, str]]] - ): + async def _send_packed_command(self, command: Iterable[bytes]) -> None: + if self._writer is None: + raise RedisError("Connection already closed.") + self._writer.writelines(command) await self._writer.drain() async def send_packed_command( self, - command: Union[bytes, str, Iterable[Union[bytes, str]]], + command: Union[bytes, str, Iterable[bytes]], check_health: bool = True, ): """Send an already packed command to the Redis server""" @@ -876,6 +919,7 @@ def pack_command(self, *args: EncodableT) -> List[bytes]: # arguments to be sent separately, so split the first argument # manually. These arguments should be bytestrings so that they are # not encoded. + assert not isinstance(args[0], float) if isinstance(args[0], str): args = tuple(args[0].encode().split()) + args[1:] elif b" " in args[0]: @@ -954,7 +998,7 @@ def __init__( **kwargs, ): super().__init__(**kwargs) - self.ssl_context = RedisSSLContext( + self.ssl_context: RedisSSLContext = RedisSSLContext( keyfile=ssl_keyfile, certfile=ssl_certfile, cert_reqs=ssl_cert_reqs, @@ -1018,7 +1062,7 @@ def __init__( self.cert_reqs = CERT_REQS[cert_reqs] self.ca_certs = ca_certs self.check_hostname = check_hostname - self.context = None + self.context: Optional[ssl.SSLContext] = None def get(self) -> ssl.SSLContext: if not self.context: @@ -1102,7 +1146,7 @@ def _error_message(self, exception): FALSE_STRINGS = ("0", "F", "FALSE", "N", "NO") -def to_bool(value) -> bool: +def to_bool(value) -> Optional[bool]: if value is None or value == "": return None if isinstance(value, str) and value.upper() in FALSE_STRINGS: @@ -1110,16 +1154,18 @@ def to_bool(value) -> bool: return bool(value) -URL_QUERY_ARGUMENT_PARSERS = { - "db": int, - "socket_timeout": float, - "socket_connect_timeout": float, - "socket_keepalive": to_bool, - "retry_on_timeout": to_bool, - "max_connections": int, - "health_check_interval": int, - "ssl_check_hostname": to_bool, -} +URL_QUERY_ARGUMENT_PARSERS: Mapping[str, Callable[..., object]] = MappingProxyType( + { + "db": int, + "socket_timeout": float, + "socket_connect_timeout": float, + "socket_keepalive": to_bool, + "retry_on_timeout": to_bool, + "max_connections": int, + "health_check_interval": int, + "ssl_check_hostname": to_bool, + } +) class ConnectKwargs(TypedDict, total=False): @@ -1129,23 +1175,25 @@ class ConnectKwargs(TypedDict, total=False): host: str port: int db: int + path: str def parse_url(url: str) -> ConnectKwargs: parsed: ParseResult = urlparse(url) kwargs: ConnectKwargs = {} - for name, value in parse_qs(parsed.query).items(): - if value and len(value) > 0: - value = unquote(value[0]) + for name, value_list in parse_qs(parsed.query).items(): + if value_list and len(value_list) > 0: + value = unquote(value_list[0]) parser = URL_QUERY_ARGUMENT_PARSERS.get(name) if parser: try: - kwargs[name] = parser(value) + # We can't type this. + kwargs[name] = parser(value) # type: ignore[misc] except (TypeError, ValueError): raise ValueError(f"Invalid value for `{name}` in connection URL.") else: - kwargs[name] = value + kwargs[name] = value # type: ignore[misc] if parsed.username: kwargs["username"] = unquote(parsed.username) @@ -1183,7 +1231,7 @@ def parse_url(url: str) -> ConnectKwargs: return kwargs -_CP = TypeVar("_CP") +_CP = TypeVar("_CP", bound="ConnectionPool") class ConnectionPool: @@ -1428,7 +1476,7 @@ async def disconnect(self, inuse_connections: bool = True): self._checkpid() async with self._lock: if inuse_connections: - connections = chain( + connections: Iterable[Connection] = chain( self._available_connections, self._in_use_connections ) else: diff --git a/aioredis/sentinel.py b/aioredis/sentinel.py index 49bfa8b3e..c6434b84e 100644 --- a/aioredis/sentinel.py +++ b/aioredis/sentinel.py @@ -3,7 +3,7 @@ from typing import AsyncIterator, Iterable, Mapping, Sequence, Tuple, Type from aioredis.client import Redis -from aioredis.connection import ConnectionPool, EncodableT, SSLConnection +from aioredis.connection import Connection, ConnectionPool, EncodableT, SSLConnection from aioredis.exceptions import ( ConnectionError, ReadOnlyError, @@ -107,7 +107,7 @@ def reset(self): self.master_address = None self.slave_rr_counter = None - def owns_connection(self, connection: SentinelManagedConnection): + def owns_connection(self, connection: Connection): check = not self.is_master or ( self.is_master and self.master_address == (connection.host, connection.port) ) diff --git a/aioredis/utils.py b/aioredis/utils.py index 16b48f4c7..8a1010e52 100644 --- a/aioredis/utils.py +++ b/aioredis/utils.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, TypeVar, overload if TYPE_CHECKING: from aioredis import Redis @@ -13,6 +13,9 @@ HIREDIS_AVAILABLE = False +_T = TypeVar("_T") + + def from_url(url, **kwargs): """ Returns an active Redis client generated from the given database URL. @@ -37,11 +40,22 @@ async def __aexit__(self, exc_type, exc_val, exc_tb): del self.p -def str_if_bytes(value): +# Mypy bug: https://github.com/python/mypy/issues/11005 +@overload +def str_if_bytes(value: bytes) -> str: # type: ignore[misc] + ... + + +@overload +def str_if_bytes(value: _T) -> _T: + ... + + +def str_if_bytes(value: object) -> object: return ( value.decode("utf-8", errors="replace") if isinstance(value, bytes) else value ) -def safe_str(value): +def safe_str(value: object) -> str: return str(str_if_bytes(value))