From 5dbe8c56f03bd0f60a638ca140058305a5bb400c Mon Sep 17 00:00:00 2001 From: Sam Bull Date: Sun, 8 Aug 2021 14:37:53 +0100 Subject: [PATCH 01/19] Enable mypy in CI --- .github/workflows/ci.yml | 4 ++ .mypy.ini | 30 +++++++++ Makefile | 2 +- aioredis/compat.py | 17 ++++-- aioredis/connection.py | 129 ++++++++++++++++++++++++--------------- aioredis/sentinel.py | 4 +- 6 files changed, 130 insertions(+), 56 deletions(-) create mode 100644 .mypy.ini diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 7b609de39..500a3e7af 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -39,6 +39,10 @@ jobs: uses: py-actions/py-dependency-install@v2.1.0 with: path: tests/requirements.txt + - name: Run mypy + run: | + pip install -r tests/requirements-mypy.txt + mypy - name: Run linter run: | make lint diff --git a/.mypy.ini b/.mypy.ini new file mode 100644 index 000000000..e2a429bf5 --- /dev/null +++ b/.mypy.ini @@ -0,0 +1,30 @@ +[mypy] +#, docs/examples, tests +files = aioredis +check_untyped_defs = True +follow_imports_for_stubs = True +#disallow_any_decorated = True +#disallow_any_generics = True +#disallow_incomplete_defs = True +disallow_subclassing_any = True +#disallow_untyped_calls = True +disallow_untyped_decorators = True +#disallow_untyped_defs = True +implicit_reexport = False +no_implicit_optional = True +show_error_codes = True +strict_equality = True +warn_incomplete_stub = True +warn_redundant_casts = True +warn_unreachable = True +warn_unused_ignores = True +disallow_any_unimported = True +#warn_return_any = True + +[mypy-aioredis.client] +# TODO: Fix +ignore_errors = True + +[mypy-aioredis.lock] +# TODO: Remove once locks has been rewritten +ignore_errors = True diff --git a/Makefile b/Makefile index 75e61f0d7..27de9f595 100644 --- a/Makefile +++ b/Makefile @@ -17,7 +17,7 @@ spelling: $(MAKE) -C docs spelling mypy: - $(MYPY) aioredis --ignore-missing-imports + $(MYPY) test: $(PYTEST) diff --git a/aioredis/compat.py b/aioredis/compat.py index 881fda8ff..5d164fcc4 100644 --- a/aioredis/compat.py +++ b/aioredis/compat.py @@ -1,5 +1,12 @@ -# flake8: noqa -try: - from typing import Protocol, TypedDict # lgtm [py/unused-import] -except ImportError: - from typing_extensions import Protocol, TypedDict # lgtm [py/unused-import] +import sys + +if sys.version_info >= (3, 8): + from typing import ( + Protocol as Protocol, + TypedDict as TypedDict, + ) +else: + from typing_extensions import ( + Protocol as Protocol, + TypedDict as TypedDict, + ) diff --git a/aioredis/connection.py b/aioredis/connection.py index 6ec3741af..9bcc60a2b 100644 --- a/aioredis/connection.py +++ b/aioredis/connection.py @@ -1,4 +1,5 @@ import asyncio +import enum import errno import inspect import io @@ -9,15 +10,19 @@ import warnings from distutils.version import StrictVersion from itertools import chain +from types import MappingProxyType from typing import ( Any, + Callable, Iterable, List, Mapping, + NewType, Optional, Set, Tuple, Type, + TypedDict, TypeVar, Union, ) @@ -78,7 +83,12 @@ SERVER_CLOSED_CONNECTION_ERROR = "Connection closed by server." -SENTINEL = object() + +class Sentinel(enum.Enum): + sentinel = object() + + +SENTINEL = Sentinel.sentinel MODULE_LOAD_ERROR = "Error loading the extension. Please check the server logs." NO_SUCH_MODULE_ERROR = "Error unloading module: no such module with that name" MODULE_UNLOAD_NOT_POSSIBLE_ERROR = "Error unloading module: operation not possible." 
@@ -93,6 +103,13 @@ EncodableT = Union[EncodedT, DecodedT] +class _HiredisReaderArgs(TypedDict, total=False): + protocolError: Callable[[str], Exception] + replyError: Callable[[str], Exception] + encoding: Optional[str] + errors: Optional[str] + + class Encoder: """Encode strings to bytes-like and decode bytes-like to strings""" @@ -117,14 +134,12 @@ def encode(self, value: EncodableT) -> EncodedT: return repr(value).encode() if not isinstance(value, str): # a value we don't know how to deal with. throw an error - typename = value.__class__.__name__ + typename = value.__class__.__name__ # type: ignore[unreachable] raise DataError( f"Invalid input of type: {typename!r}. " "Convert to a bytes, string, int or float first." ) - if isinstance(value, str): - return value.encode(self.encoding, self.encoding_errors) - return value + return value.encode(self.encoding, self.encoding_errors) def decode(self, value: EncodableT, force=False) -> EncodableT: """Return a unicode string from the bytes-like representation""" @@ -184,9 +199,11 @@ def parse_error(self, response: str) -> ResponseError: error_code = response.split(" ")[0] if error_code in self.EXCEPTION_CLASSES: response = response[len(error_code) + 1 :] - exception_class = self.EXCEPTION_CLASSES[error_code] - if isinstance(exception_class, dict): - exception_class = exception_class.get(response, ResponseError) + exception_class_or_dict = self.EXCEPTION_CLASSES[error_code] + if isinstance(exception_class_or_dict, dict): + exception_class = exception_class_or_dict.get(response, ResponseError) + else: + exception_class = exception_class_or_dict return exception_class(response) return ResponseError(response) @@ -215,12 +232,12 @@ def __init__( self, stream_reader: asyncio.StreamReader, socket_read_size: int, - socket_timeout: float, + socket_timeout: Optional[float], ): - self._stream = stream_reader + self._stream: Optional[asyncio.StreamReader] = stream_reader self.socket_read_size = socket_read_size self.socket_timeout = socket_timeout - self._buffer = io.BytesIO() + self._buffer: Optional[io.BytesIO] = io.BytesIO() # number of bytes written to the buffer from the socket self.bytes_written = 0 # number of bytes read from the buffer @@ -233,10 +250,12 @@ def length(self): async def _read_from_socket( self, length: Optional[int] = None, - timeout: Optional[float] = SENTINEL, # type: ignore + timeout: Union[float, None, Sentinel] = SENTINEL, raise_on_timeout: bool = True, ) -> bool: buf = self._buffer + if buf is None or self._stream is None: + raise RedisError("Buffer is closed.") buf.seek(self.bytes_written) marker = 0 timeout = timeout if timeout is not SENTINEL else self.socket_timeout @@ -281,6 +300,9 @@ async def read(self, length: int) -> bytes: if length > self.length: await self._read_from_socket(length - self.length) + if self._buffer is None: + raise RedisError("Buffer is closed.") + self._buffer.seek(self.bytes_read) data = self._buffer.read(length) self.bytes_read += len(data) @@ -294,6 +316,9 @@ async def read(self, length: int) -> bytes: async def readline(self) -> bytes: buf = self._buffer + if buf is None: + raise RedisError("Buffer is closed.") + buf.seek(self.bytes_read) data = buf.readline() while not data.endswith(SYM_CRLF): @@ -312,6 +337,9 @@ async def readline(self) -> bytes: return data[:-2] def purge(self): + if self._buffer is None: + raise RedisError("Buffer is closed.") + self._buffer.seek(0) self._buffer.truncate() self.bytes_written = 0 @@ -320,7 +348,7 @@ def purge(self): def close(self): try: self.purge() - 
self._buffer.close() + self._buffer.close() # type: ignore[union-attr] except Exception: # issue #633 suggests the purge/close somehow raised a # BadFileDescriptor error. Perhaps the client ran out of @@ -344,6 +372,9 @@ def __init__(self, socket_read_size: int): def on_connect(self, connection: "Connection"): """Called when the stream connects""" self._stream = connection._reader + if self._buffer is None or self._stream is None: + raise RedisError("Buffer is closed.") + self._buffer = SocketBuffer( self._stream, self._read_size, connection.socket_timeout ) @@ -405,6 +436,7 @@ async def read_response(self) -> Union[EncodableT, ResponseError, None]: return None response = [(await self.read_response()) for _ in range(length)] if isinstance(response, bytes): + assert self.encoder is not None response = self.encoder.decode(response) return response @@ -414,25 +446,24 @@ class HiredisParser(BaseParser): __slots__ = BaseParser.__slots__ + ("_next_response", "_reader", "_socket_timeout") + _next_response: bool + def __init__(self, socket_read_size: int): if not HIREDIS_AVAILABLE: raise RedisError("Hiredis is not available.") super().__init__(socket_read_size=socket_read_size) - self._next_response = ... self._reader: Optional[hiredis.Reader] = None self._socket_timeout: Optional[float] = None def on_connect(self, connection: "Connection"): self._stream = connection._reader - kwargs = { + kwargs: _HiredisReaderArgs = { "protocolError": InvalidResponse, "replyError": self.parse_error, } if connection.encoder.decode_responses: - kwargs.update( - encoding=connection.encoder.encoding, - errors=connection.encoder.encoding_errors, - ) + kwargs["encoding"] = connection.encoder.encoding + kwargs["errors"] = connection.encoder.encoding_errors self._reader = hiredis.Reader(**kwargs) self._next_response = False @@ -454,8 +485,11 @@ async def can_read(self, timeout: float): return True async def read_from_socket( - self, timeout: Optional[float] = SENTINEL, raise_on_timeout: bool = True + self, timeout: Union[float, None, Sentinel] = SENTINEL, raise_on_timeout: bool = True ): + if self._stream is None or self._reader is None: + raise RedisError("Parser already closed.") + timeout = self._socket_timeout if timeout is SENTINEL else timeout try: async with async_timeout.timeout(timeout): @@ -503,12 +537,6 @@ async def read_response(self) -> EncodableT: # happened if isinstance(response, ConnectionError): raise response - elif ( - isinstance(response, list) - and response - and isinstance(response[0], ConnectionError) - ): - raise response[0] return response @@ -580,7 +608,7 @@ def __init__( decode_responses: bool = False, parser_class: Type[BaseParser] = DefaultParser, socket_read_size: int = 65536, - health_check_interval: int = 0, + health_check_interval: float = 0, client_name: Optional[str] = None, username: Optional[str] = None, encoder_class: Type[Encoder] = Encoder, @@ -599,7 +627,7 @@ def __init__( self.socket_type = socket_type self.retry_on_timeout = retry_on_timeout self.health_check_interval = health_check_interval - self.next_health_check = -1 + self.next_health_check: float = -1 self.ssl_context: Optional[RedisSSLContext] = None self.encoder = encoder_class(encoding, encoding_errors, decode_responses) self._reader: Optional[asyncio.StreamReader] = None @@ -715,9 +743,10 @@ async def on_connect(self): # if username and/or password are set, authenticate if self.username or self.password: if self.username: - auth_args = (self.username, self.password or "") + auth_args: Tuple[str, ...] 
= (self.username, self.password or "") else: - auth_args = (self.password,) + # Mypy bug: https://github.com/python/mypy/issues/10944 + auth_args = (self.password or "",) # avoid checking health here -- PING will fail if we try # to check the health prior to the AUTH await self.send_command("AUTH", *auth_args, check_health=False) @@ -756,10 +785,10 @@ async def disconnect(self): return try: if os.getpid() == self.pid: - self._writer.close() + self._writer.close() # type: ignore[union-attr] # py3.6 doesn't have this method if hasattr(self._writer, "wait_closed"): - await self._writer.wait_closed() + await self._writer.wait_closed() # type: ignore[union-attr] except OSError: pass self._reader = None @@ -790,15 +819,16 @@ async def check_health(self): except BaseException as err2: raise err2 from err - async def _send_packed_command( - self, command: Union[bytes, str, Iterable[Union[bytes, str]]] - ): + async def _send_packed_command(self, command: Iterable[bytes]) -> None: + if self._writer is None: + raise RedisError("Connection already closed.") + self._writer.writelines(command) await self._writer.drain() async def send_packed_command( self, - command: Union[bytes, str, Iterable[Union[bytes, str]]], + command: Union[bytes, str, Iterable[bytes]], check_health: bool = True, ): """Send an already packed command to the Redis server""" @@ -876,6 +906,7 @@ def pack_command(self, *args: EncodableT) -> List[bytes]: # arguments to be sent separately, so split the first argument # manually. These arguments should be bytestrings so that they are # not encoded. + assert not isinstance(args[0], float) if isinstance(args[0], str): args = tuple(args[0].encode().split()) + args[1:] elif b" " in args[0]: @@ -954,7 +985,7 @@ def __init__( **kwargs, ): super().__init__(**kwargs) - self.ssl_context = RedisSSLContext( + self.ssl_context: RedisSSLContext = RedisSSLContext( keyfile=ssl_keyfile, certfile=ssl_certfile, cert_reqs=ssl_cert_reqs, @@ -1018,7 +1049,7 @@ def __init__( self.cert_reqs = CERT_REQS[cert_reqs] self.ca_certs = ca_certs self.check_hostname = check_hostname - self.context = None + self.context: Optional[ssl.SSLContext] = None def get(self) -> ssl.SSLContext: if not self.context: @@ -1102,7 +1133,7 @@ def _error_message(self, exception): FALSE_STRINGS = ("0", "F", "FALSE", "N", "NO") -def to_bool(value) -> bool: +def to_bool(value) -> Optional[bool]: if value is None or value == "": return None if isinstance(value, str) and value.upper() in FALSE_STRINGS: @@ -1110,7 +1141,7 @@ def to_bool(value) -> bool: return bool(value) -URL_QUERY_ARGUMENT_PARSERS = { +URL_QUERY_ARGUMENT_PARSERS: Mapping[str, Callable[..., object]] = MappingProxyType({ "db": int, "socket_timeout": float, "socket_connect_timeout": float, @@ -1119,7 +1150,7 @@ def to_bool(value) -> bool: "max_connections": int, "health_check_interval": int, "ssl_check_hostname": to_bool, -} +}) class ConnectKwargs(TypedDict, total=False): @@ -1129,23 +1160,25 @@ class ConnectKwargs(TypedDict, total=False): host: str port: int db: int + path: str def parse_url(url: str) -> ConnectKwargs: parsed: ParseResult = urlparse(url) kwargs: ConnectKwargs = {} - for name, value in parse_qs(parsed.query).items(): - if value and len(value) > 0: - value = unquote(value[0]) + for name, value_list in parse_qs(parsed.query).items(): + if value_list and len(value_list) > 0: + value = unquote(value_list[0]) parser = URL_QUERY_ARGUMENT_PARSERS.get(name) if parser: try: - kwargs[name] = parser(value) + # We can't type this. 
+ kwargs[name] = parser(value) # type: ignore[misc] except (TypeError, ValueError): raise ValueError(f"Invalid value for `{name}` in connection URL.") else: - kwargs[name] = value + kwargs[name] = value # type: ignore[misc] if parsed.username: kwargs["username"] = unquote(parsed.username) @@ -1183,7 +1216,7 @@ def parse_url(url: str) -> ConnectKwargs: return kwargs -_CP = TypeVar("_CP") +_CP = TypeVar("_CP", bound="ConnectionPool") class ConnectionPool: @@ -1428,7 +1461,7 @@ async def disconnect(self, inuse_connections: bool = True): self._checkpid() async with self._lock: if inuse_connections: - connections = chain( + connections: Iterable[Connection] = chain( self._available_connections, self._in_use_connections ) else: diff --git a/aioredis/sentinel.py b/aioredis/sentinel.py index 49bfa8b3e..c6434b84e 100644 --- a/aioredis/sentinel.py +++ b/aioredis/sentinel.py @@ -3,7 +3,7 @@ from typing import AsyncIterator, Iterable, Mapping, Sequence, Tuple, Type from aioredis.client import Redis -from aioredis.connection import ConnectionPool, EncodableT, SSLConnection +from aioredis.connection import Connection, ConnectionPool, EncodableT, SSLConnection from aioredis.exceptions import ( ConnectionError, ReadOnlyError, @@ -107,7 +107,7 @@ def reset(self): self.master_address = None self.slave_rr_counter = None - def owns_connection(self, connection: SentinelManagedConnection): + def owns_connection(self, connection: Connection): check = not self.is_master or ( self.is_master and self.master_address == (connection.host, connection.port) ) From 9656b421e36a87318b487f0c03ce5b3e35026bc2 Mon Sep 17 00:00:00 2001 From: Sam Bull Date: Sun, 8 Aug 2021 14:39:45 +0100 Subject: [PATCH 02/19] Add change file. --- CHANGES/1101.misc | 1 + 1 file changed, 1 insertion(+) create mode 100644 CHANGES/1101.misc diff --git a/CHANGES/1101.misc b/CHANGES/1101.misc new file mode 100644 index 000000000..00d0b7afd --- /dev/null +++ b/CHANGES/1101.misc @@ -0,0 +1 @@ +Enable Mypy in CI. From fe7547dc17077c85d0f73e7eea97e9d8dcd2e0b8 Mon Sep 17 00:00:00 2001 From: Sam Bull Date: Sun, 8 Aug 2021 14:52:03 +0100 Subject: [PATCH 03/19] Lint --- aioredis/compat.py | 12 ++++-------- aioredis/connection.py | 2 -- 2 files changed, 4 insertions(+), 10 deletions(-) diff --git a/aioredis/compat.py b/aioredis/compat.py index 5d164fcc4..ae636028d 100644 --- a/aioredis/compat.py +++ b/aioredis/compat.py @@ -1,12 +1,8 @@ import sys if sys.version_info >= (3, 8): - from typing import ( - Protocol as Protocol, - TypedDict as TypedDict, - ) + from typing import Protocol, TypedDict else: - from typing_extensions import ( - Protocol as Protocol, - TypedDict as TypedDict, - ) + from typing_extensions import Protocol, TypedDict + +__all__ = ("Protocol", "TypedDict") diff --git a/aioredis/connection.py b/aioredis/connection.py index 9bcc60a2b..c8bd94077 100644 --- a/aioredis/connection.py +++ b/aioredis/connection.py @@ -17,12 +17,10 @@ Iterable, List, Mapping, - NewType, Optional, Set, Tuple, Type, - TypedDict, TypeVar, Union, ) From 68add8bbf45c962a96736225de47e54f32ef1a13 Mon Sep 17 00:00:00 2001 From: Sam Bull Date: Sun, 8 Aug 2021 15:13:48 +0100 Subject: [PATCH 04/19] Show diff in CI. 
--- Makefile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Makefile b/Makefile index 27de9f595..d1ba50491 100644 --- a/Makefile +++ b/Makefile @@ -38,7 +38,7 @@ init-hooks: pre-commit install-hooks lint: init-hooks - pre-commit run --all-files + pre-commit run --all-files --show-diff-on-failure devel: aioredis.egg-info init-hooks pip install -U pip From f8d83df5934816c5c4cc78b6ec703ad0713038f6 Mon Sep 17 00:00:00 2001 From: Sam Bull Date: Sun, 8 Aug 2021 15:16:46 +0100 Subject: [PATCH 05/19] Black --- aioredis/connection.py | 26 +++++++++++++++----------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/aioredis/connection.py b/aioredis/connection.py index c8bd94077..0dd011461 100644 --- a/aioredis/connection.py +++ b/aioredis/connection.py @@ -483,7 +483,9 @@ async def can_read(self, timeout: float): return True async def read_from_socket( - self, timeout: Union[float, None, Sentinel] = SENTINEL, raise_on_timeout: bool = True + self, + timeout: Union[float, None, Sentinel] = SENTINEL, + raise_on_timeout: bool = True, ): if self._stream is None or self._reader is None: raise RedisError("Parser already closed.") @@ -1139,16 +1141,18 @@ def to_bool(value) -> Optional[bool]: return bool(value) -URL_QUERY_ARGUMENT_PARSERS: Mapping[str, Callable[..., object]] = MappingProxyType({ - "db": int, - "socket_timeout": float, - "socket_connect_timeout": float, - "socket_keepalive": to_bool, - "retry_on_timeout": to_bool, - "max_connections": int, - "health_check_interval": int, - "ssl_check_hostname": to_bool, -}) +URL_QUERY_ARGUMENT_PARSERS: Mapping[str, Callable[..., object]] = MappingProxyType( + { + "db": int, + "socket_timeout": float, + "socket_connect_timeout": float, + "socket_keepalive": to_bool, + "retry_on_timeout": to_bool, + "max_connections": int, + "health_check_interval": int, + "ssl_check_hostname": to_bool, + } +) class ConnectKwargs(TypedDict, total=False): From ed06034a1c1595327ae5dc36ad484209dc928799 Mon Sep 17 00:00:00 2001 From: Sam Bull Date: Sun, 8 Aug 2021 16:21:15 +0100 Subject: [PATCH 06/19] Readd "unreachable" code. --- aioredis/connection.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/aioredis/connection.py b/aioredis/connection.py index 0dd011461..4a2f182a3 100644 --- a/aioredis/connection.py +++ b/aioredis/connection.py @@ -537,6 +537,12 @@ async def read_response(self) -> EncodableT: # happened if isinstance(response, ConnectionError): raise response + elif ( + isinstance(response, list) # type: ignore[unreachable] + and response + and isinstance(response[0], ConnectionError) + ): + raise response[0] return response From bea076bfab3d2b807821fff9606e286004876111 Mon Sep 17 00:00:00 2001 From: Sam Bull Date: Sun, 8 Aug 2021 16:43:06 +0100 Subject: [PATCH 07/19] Fix typing for ConnectionError. 
--- aioredis/connection.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/aioredis/connection.py b/aioredis/connection.py index 4a2f182a3..74877bd8b 100644 --- a/aioredis/connection.py +++ b/aioredis/connection.py @@ -23,6 +23,7 @@ Type, TypeVar, Union, + cast, ) from urllib.parse import ParseResult, parse_qs, unquote, urlparse @@ -214,7 +215,7 @@ def on_connect(self, connection: "Connection"): async def can_read(self, timeout: float) -> bool: raise NotImplementedError() - async def read_response(self) -> Union[EncodableT, ResponseError, None]: + async def read_response(self) -> Union[EncodableT, ResponseError, None, List[EncodableT]]: raise NotImplementedError() @@ -516,11 +517,12 @@ async def read_from_socket( return False raise ConnectionError(f"Error while reading from socket: {ex.args}") - async def read_response(self) -> EncodableT: + async def read_response(self) -> Union[EncodableT, List[EncodableT]]: if not self._stream or not self._reader: self.on_disconnect() raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from None + response: Union[EncodableT, ConnectionError, List[Union[EncodableT, ConnectionError]]] # _next_response might be cached from a can_read() call if self._next_response is not False: response = self._next_response @@ -538,12 +540,13 @@ async def read_response(self) -> EncodableT: if isinstance(response, ConnectionError): raise response elif ( - isinstance(response, list) # type: ignore[unreachable] + isinstance(response, list) and response and isinstance(response[0], ConnectionError) ): raise response[0] - return response + # cast as there won't be a ConnectionError here. + return cast(Union[EncodableT, List[EncodableT]], response) DefaultParser: Type[Union[PythonParser, HiredisParser]] From 861753848c10a8052f871e2c4bd266c94cebda07 Mon Sep 17 00:00:00 2001 From: Sam Bull Date: Sun, 8 Aug 2021 17:05:11 +0100 Subject: [PATCH 08/19] Black --- aioredis/connection.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/aioredis/connection.py b/aioredis/connection.py index 74877bd8b..0f41bac48 100644 --- a/aioredis/connection.py +++ b/aioredis/connection.py @@ -215,7 +215,9 @@ def on_connect(self, connection: "Connection"): async def can_read(self, timeout: float) -> bool: raise NotImplementedError() - async def read_response(self) -> Union[EncodableT, ResponseError, None, List[EncodableT]]: + async def read_response( + self, + ) -> Union[EncodableT, ResponseError, None, List[EncodableT]]: raise NotImplementedError() @@ -522,7 +524,9 @@ async def read_response(self) -> Union[EncodableT, List[EncodableT]]: self.on_disconnect() raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from None - response: Union[EncodableT, ConnectionError, List[Union[EncodableT, ConnectionError]]] + response: Union[ + EncodableT, ConnectionError, List[Union[EncodableT, ConnectionError]] + ] # _next_response might be cached from a can_read() call if self._next_response is not False: response = self._next_response From e8e21a476ffda2e776eb0293881866549f52a775 Mon Sep 17 00:00:00 2001 From: Sam Bull Date: Sat, 21 Aug 2021 16:23:21 +0100 Subject: [PATCH 09/19] Add client.py --- .mypy.ini | 4 - aioredis/client.py | 276 ++++++++++++++++++++++++--------------------- 2 files changed, 149 insertions(+), 131 deletions(-) diff --git a/.mypy.ini b/.mypy.ini index e2a429bf5..d2bc5b09e 100644 --- a/.mypy.ini +++ b/.mypy.ini @@ -21,10 +21,6 @@ warn_unused_ignores = True disallow_any_unimported = True #warn_return_any = True -[mypy-aioredis.client] -# 
TODO: Fix -ignore_errors = True - [mypy-aioredis.lock] # TODO: Remove once locks has been rewritten ignore_errors = True diff --git a/aioredis/client.py b/aioredis/client.py index 972270b41..863c086e4 100644 --- a/aioredis/client.py +++ b/aioredis/client.py @@ -7,6 +7,7 @@ import warnings from itertools import chain from typing import ( + AbstractSet, Any, AsyncIterator, Awaitable, @@ -15,6 +16,7 @@ Iterable, List, Mapping, + MutableMapping, NoReturn, Optional, Sequence, @@ -23,6 +25,8 @@ Type, TypeVar, Union, + ValuesView, + cast, ) from aioredis.compat import Protocol, TypedDict @@ -56,7 +60,7 @@ KeyT = _StringLikeT # Main redis key space PatternT = _StringLikeT # Patterns matched against keys, fields etc FieldT = EncodableT # Fields within hash tables, streams and geo commands -KeysT = Union[KeyT, Iterable[KeyT]] +KeysT = Union[KeyT, Sequence[KeyT]] ChannelT = _StringLikeT GroupT = _StringLikeT # Consumer group ConsumerT = _StringLikeT # Consumer name @@ -69,27 +73,35 @@ # type signature because they will all be required to be the same key type. AnyKeyT = TypeVar("AnyKeyT", bytes, str, memoryview) AnyFieldT = TypeVar("AnyFieldT", bytes, str, memoryview) -AnyChannelT = TypeVar("AnyChannelT", bytes, str, memoryview) +AnyChannelT = ChannelT +PubSubHandler = Callable[[Dict[str, str]], None] SYM_EMPTY = b"" EMPTY_RESPONSE = "EMPTY_RESPONSE" +_KeyT = TypeVar("_KeyT", bound=KeyT) +_ArgT = TypeVar("_ArgT", KeyT, EncodableT) +_RedisT = TypeVar("_RedisT", bound="Redis") +_NormalizeKeysT = TypeVar("_NormalizeKeysT", bound=Mapping[EncodableT, object]) -def list_or_args(keys: KeysT, args: Optional[KeysT]) -> KeysT: + +def list_or_args(keys: Union[_KeyT, Iterable[_KeyT]], args: Optional[Iterable[_ArgT]]) -> List[Union[_KeyT, _ArgT]]: # returns a single new list combining keys and args + key_list: List[Union[_KeyT, _ArgT]] try: - iter(keys) + iter(keys) # type: ignore[arg-type] + keys = cast(Iterable[_KeyT], keys) # a string or bytes instance can be iterated, but indicates # keys wasn't passed as a list if isinstance(keys, (bytes, str)): - keys = [keys] + key_list = [keys] else: - keys = list(keys) + key_list = list(keys) except TypeError: - keys = [keys] + key_list = [cast(memoryview, keys)] if args: - keys.extend(args) - return keys + key_list.extend(args) + return key_list def timestamp_to_datetime(response): @@ -161,7 +173,7 @@ def parse_object(response, infotype): def parse_info(response): """Parse the result of Redis's INFO command into a Python dict""" - info = {} + info: Dict[str, Any] = {} response = str_if_bytes(response) def get_value(value): @@ -449,7 +461,7 @@ def parse_zscan(response, **options): def parse_slowlog_get(response, **options): - space = " " if options.get("decode_responses", False) else b" " + space: Union[str, bytes] = " " if options.get("decode_responses", False) else b" " return [ { "id": item[0], @@ -510,7 +522,7 @@ def parse_georadius_generic(response, **options): # just a bunch of places return response_list - cast = { + cast: Dict[str, Callable] = { "withdist": float, "withcoord": lambda ll: (float(ll[0]), float(ll[1])), "withhash": int, @@ -892,9 +904,9 @@ def __init__( connection_pool = ConnectionPool(**kwargs) self.connection_pool = connection_pool self.single_connection_client = single_connection_client - self.connection = None + self.connection: Optional[Connection] = None - self.response_callbacks = CaseInsensitiveDict(self.__class__.RESPONSE_CALLBACKS) + self.response_callbacks: MutableMapping[str, ResponseCallbackT] = 
CaseInsensitiveDict(self.__class__.RESPONSE_CALLBACKS) def __repr__(self): return f"{self.__class__.__name__}<{self.connection_pool!r}>" @@ -902,7 +914,7 @@ def __repr__(self): def __await__(self): return self.initialize().__await__() - async def initialize(self): + async def initialize(self: _RedisT) -> _RedisT: if self.single_connection_client and self.connection is None: self.connection = await self.connection_pool.get_connection("_") return self @@ -1033,7 +1045,8 @@ def client(self) -> "Redis": connection_pool=self.connection_pool, single_connection_client=True ) - __aenter__ = initialize + async def __aenter__(self: _RedisT) -> _RedisT: + return await self.initialize() async def __aexit__(self, exc_type, exc_value, traceback): await self.close() @@ -1085,6 +1098,8 @@ async def parse_response( return options[EMPTY_RESPONSE] raise if command_name in self.response_callbacks: + # Mypy bug: https://github.com/python/mypy/issues/10977 + command_name = cast(str, command_name) retval = self.response_callbacks[command_name](response, **options) return await retval if inspect.isawaitable(retval) else retval return response @@ -1261,9 +1276,9 @@ def acl_setuser( # noqa: C901 if passwords: # as most users will have only one password, allow remove_passwords # to be specified as a simple string or a list - passwords = list_or_args(passwords, []) - for i, password in enumerate(passwords): - password = encoder.encode(password) + converted_passwords = list_or_args(passwords, []) + for i, raw_password in enumerate(converted_passwords): + password = encoder.encode(raw_password) if password.startswith(b"+"): pieces.append(b">%s" % password[1:]) elif password.startswith(b"-"): @@ -1277,9 +1292,9 @@ def acl_setuser( # noqa: C901 if hashed_passwords: # as most users will have only one password, allow remove_passwords # to be specified as a simple string or a list - hashed_passwords = list_or_args(hashed_passwords, []) - for i, hashed_password in enumerate(hashed_passwords): - hashed_password = encoder.encode(hashed_password) + parsed_hashed_passwords = list_or_args(hashed_passwords, []) + for i, raw_hashed_password in enumerate(parsed_hashed_passwords): + hashed_password = encoder.encode(raw_hashed_password) if hashed_password.startswith(b"+"): pieces.append(b"#%s" % hashed_password[1:]) elif hashed_password.startswith(b"-"): @@ -1294,8 +1309,8 @@ def acl_setuser( # noqa: C901 pieces.append(b"nopass") if categories: - for category in categories: - category = encoder.encode(category) + for raw_category in categories: + category = encoder.encode(raw_category) # categories can be prefixed with one of (+@, +, -@, -) if category.startswith(b"+@"): pieces.append(category) @@ -1311,8 +1326,8 @@ def acl_setuser( # noqa: C901 'prefixed with "+" or "-"' ) if commands: - for cmd in commands: - cmd = encoder.encode(cmd) + for raw_cmd in commands: + cmd = encoder.encode(raw_cmd) if not cmd.startswith(b"+") and not cmd.startswith(b"-"): raise DataError( f'Command "{encoder.decode(cmd, force=True)}" must be ' @@ -1321,8 +1336,8 @@ def acl_setuser( # noqa: C901 pieces.append(cmd) if keys: - for key in keys: - key = encoder.encode(key) + for raw_key in keys: + key = encoder.encode(raw_key) pieces.append(b"~%s" % key) return self.execute_command("ACL SETUSER", *pieces) @@ -1367,7 +1382,7 @@ def client_kill_filter( will not get killed even if it is identified by one of the filter options. 
If skipme is not provided, the server defaults to skipme=True """ - args = [] + args: List[Union[bytes, str]] = [] if _type is not None: client_types = ("normal", "master", "slave", "pubsub") if str(_type).lower() not in client_types: @@ -1639,7 +1654,7 @@ def sentinel_slaves(self, service_name: str) -> Awaitable: """Returns a list of slaves for ``service_name``""" return self.execute_command("SENTINEL SLAVES", service_name) - def shutdown(self, save: bool = False, nosave: bool = False) -> Awaitable: + def shutdown(self, save: bool = False, nosave: bool = False) -> None: """Shutdown the Redis server. If Redis has persistence configured, data will be flushed before shutdown. If the "save" option is set, a data flush will be attempted even if there is no persistence @@ -1766,12 +1781,12 @@ def bitpos( raise DataError("bit must be 0 or 1") params = [key, bit] - start is not None and params.append(start) - - if start is not None and end is not None: - params.append(end) - elif start is None and end is not None: - raise DataError("start argument is not set, " "when end is specified") + if start is not None: + params.append(start) + if end is not None: + params.append(end) + elif end is not None: + raise DataError("start argument is not set, when end is specified") return self.execute_command("BITPOS", *params) def decr(self, name: KeyT, amount: int = 1) -> Awaitable: @@ -1878,11 +1893,11 @@ def mget(self, keys: KeysT, *args: EncodableT) -> Awaitable: """ Returns a list of values ordered identically to ``keys`` """ - args = list_or_args(keys, args) + encoded_args = list_or_args(keys, args) options: Dict[str, Union[EncodableT, Iterable[EncodableT]]] = {} - if not args: + if not encoded_args: options[EMPTY_RESPONSE] = [] - return self.execute_command("MGET", *args, **options) + return self.execute_command("MGET", *encoded_args, **options) def mset(self, mapping: Mapping[AnyKeyT, EncodableT]) -> Awaitable: """ @@ -2117,11 +2132,7 @@ def blpop(self, keys: KeysT, timeout: int = 0) -> Awaitable: If timeout is 0, then block indefinitely. """ - if timeout is None: - timeout = 0 - keys = list_or_args(keys, None) - keys.append(timeout) - return self.execute_command("BLPOP", *keys) + return self.execute_command("BLPOP", *list_or_args(keys, (timeout,))) def brpop(self, keys: KeysT, timeout: int = 0) -> Awaitable: """ @@ -2134,11 +2145,7 @@ def brpop(self, keys: KeysT, timeout: int = 0) -> Awaitable: If timeout is 0, then block indefinitely. """ - if timeout is None: - timeout = 0 - keys = list_or_args(keys, None) - keys.append(timeout) - return self.execute_command("BRPOP", *keys) + return self.execute_command("BRPOP", *list_or_args(keys, (timeout,))) def brpoplpush(self, src: KeyT, dst: KeyT, timeout: int = 0) -> Awaitable: """ @@ -2149,8 +2156,6 @@ def brpoplpush(self, src: KeyT, dst: KeyT, timeout: int = 0) -> Awaitable: seconds elapse, whichever is first. A ``timeout`` value of 0 blocks forever. """ - if timeout is None: - timeout = 0 return self.execute_command("BRPOPLPUSH", src, dst, timeout) def lindex(self, name: KeyT, index: int) -> Awaitable: @@ -2362,8 +2367,10 @@ def sort( "must be specified and contain at least " "two keys" ) + options: Dict[str, Optional[int]] = {"groups": len(get)} + else: + options = {"groups": None} - options = {"groups": len(get) if groups else None} return self.execute_command("SORT", *pieces, **options) # SCAN COMMANDS @@ -2417,10 +2424,10 @@ async def scan_iter( HASH, LIST, SET, STREAM, STRING, ZSET Additionally, Redis modules can expose other types as well. 
""" - cursor = "0" + cursor = None while cursor != 0: cursor, data = await self.scan( - cursor=cursor, match=match, count=count, _type=_type + cursor=cursor or 0, match=match, count=count, _type=_type ) for d in data: yield d @@ -2458,10 +2465,10 @@ async def sscan_iter( ``count`` allows for hint the minimum number of returns """ - cursor = "0" + cursor = None while cursor != 0: cursor, data = await self.sscan( - name, cursor=cursor, match=match, count=count + name, cursor=cursor or 0, match=match, count=count ) for d in data: yield d @@ -2499,10 +2506,10 @@ async def hscan_iter( ``count`` allows for hint the minimum number of returns """ - cursor = "0" + cursor = None while cursor != 0: cursor, data = await self.hscan( - name, cursor=cursor, match=match, count=count + name, cursor=cursor or 0, match=match, count=count ) for it in data.items(): yield it @@ -2550,11 +2557,11 @@ async def zscan_iter( ``score_cast_func`` a callable used to cast the score return value """ - cursor = "0" + cursor = None while cursor != 0: cursor, data = await self.zscan( name, - cursor=cursor, + cursor=cursor or 0, match=match, count=count, score_cast_func=score_cast_func, @@ -2573,29 +2580,29 @@ def scard(self, name: KeyT) -> Awaitable: def sdiff(self, keys: KeysT, *args: EncodableT) -> Awaitable: """Return the difference of sets specified by ``keys``""" - args = list_or_args(keys, args) - return self.execute_command("SDIFF", *args) + parsed_args = list_or_args(keys, args) + return self.execute_command("SDIFF", *parsed_args) def sdiffstore(self, dest: KeyT, keys: KeysT, *args: EncodableT) -> Awaitable: """ Store the difference of sets specified by ``keys`` into a new set named ``dest``. Returns the number of keys in the new set. """ - args = list_or_args(keys, args) - return self.execute_command("SDIFFSTORE", dest, *args) + parsed_args = list_or_args(keys, args) + return self.execute_command("SDIFFSTORE", dest, *parsed_args) def sinter(self, keys: KeysT, *args: EncodableT) -> Awaitable: """Return the intersection of sets specified by ``keys``""" - args = list_or_args(keys, args) - return self.execute_command("SINTER", *args) + parsed_args = list_or_args(keys, args) + return self.execute_command("SINTER", *parsed_args) def sinterstore(self, dest: KeyT, keys: KeysT, *args: EncodableT) -> Awaitable: """ Store the intersection of sets specified by ``keys`` into a new set named ``dest``. Returns the number of keys in the new set. """ - args = list_or_args(keys, args) - return self.execute_command("SINTERSTORE", dest, *args) + parsed_args = list_or_args(keys, args) + return self.execute_command("SINTERSTORE", dest, *parsed_args) def sismember(self, name: KeyT, value: EncodableT) -> Awaitable: """Return a boolean indicating if ``value`` is a member of set ``name``""" @@ -2631,16 +2638,16 @@ def srem(self, name: KeyT, *values: EncodableT) -> Awaitable: def sunion(self, keys: KeysT, *args: EncodableT) -> Awaitable: """Return the union of sets specified by ``keys``""" - args = list_or_args(keys, args) - return self.execute_command("SUNION", *args) + parsed_args = list_or_args(keys, args) + return self.execute_command("SUNION", *parsed_args) def sunionstore(self, dest: KeyT, keys: KeysT, *args: EncodableT) -> Awaitable: """ Store the union of sets specified by ``keys`` into a new set named ``dest``. Returns the number of keys in the new set. 
""" - args = list_or_args(keys, args) - return self.execute_command("SUNIONSTORE", dest, *args) + parsed_args = list_or_args(keys, args) + return self.execute_command("SUNIONSTORE", dest, *parsed_args) # STREAMS COMMANDS def xack(self, name: KeyT, groupname: GroupT, *ids: StreamIdT) -> Awaitable: @@ -2847,9 +2854,9 @@ def xpending_range( self, name: KeyT, groupname: GroupT, - min: StreamIdT, - max: StreamIdT, - count: int, + min: Optional[StreamIdT], + max: Optional[StreamIdT], + count: Optional[int], consumername: Optional[ConsumerT] = None, ) -> Awaitable: """ @@ -3141,11 +3148,8 @@ def bzpopmax(self, keys: KeysT, timeout: int = 0) -> Awaitable: If timeout is 0, then block indefinitely. """ - if timeout is None: - timeout = 0 - keys = list_or_args(keys, None) - keys.append(timeout) - return self.execute_command("BZPOPMAX", *keys) + parsed_keys = list_or_args(keys, (timeout,)) + return self.execute_command("BZPOPMAX", *parsed_keys) def bzpopmin(self, keys: KeysT, timeout: int = 0) -> Awaitable: """ @@ -3158,8 +3162,6 @@ def bzpopmin(self, keys: KeysT, timeout: int = 0) -> Awaitable: If timeout is 0, then block indefinitely. """ - if timeout is None: - timeout = 0 klist: List[EncodableT] = list_or_args(keys, None) klist.append(timeout) return self.execute_command("BZPOPMIN", *klist) @@ -3397,11 +3399,14 @@ def _zaggregate( aggregate: Optional[str] = None, ) -> Awaitable: pieces: List[EncodableT] = [command, dest, len(keys)] - if isinstance(keys, dict): - keys, weights = keys.keys(), keys.values() + key_names: Union[Sequence[KeyT], AbstractSet[AnyKeyT]] + weights: Optional[ValuesView[float]] + if isinstance(keys, Mapping): + key_names, weights = keys.keys(), keys.values() else: + key_names = keys weights = None - pieces.extend(keys) + pieces.extend(key_names) if weights: pieces.append(b"WEIGHTS") pieces.extend(weights) @@ -3476,7 +3481,7 @@ def hset( """ if key is None and not mapping: raise DataError("'hset' with no key value pairs") - items: List[EncodableT] = [] + items: List[Union[FieldT, Optional[EncodableT]]] = [] if key is not None: items.extend((key, value)) if mapping: @@ -3505,15 +3510,15 @@ def hmset(self, name: KeyT, mapping: Mapping[AnyFieldT, EncodableT]) -> Awaitabl ) if not mapping: raise DataError("'hmset' with 'mapping' of length 0") - items = [] + items: List[Union[AnyFieldT, EncodableT]] = [] for pair in mapping.items(): items.extend(pair) return self.execute_command("HMSET", name, *items) - def hmget(self, name: KeyT, keys: Sequence[FieldT], *args: FieldT) -> Awaitable: + def hmget(self, name: KeyT, keys: Sequence[KeyT], *args: FieldT) -> Awaitable: """Returns a list of values ordered identically to ``keys``""" - args = list_or_args(keys, args) - return self.execute_command("HMGET", name, *args) + parsed_args = list_or_args(keys, args) + return self.execute_command("HMGET", name, *parsed_args) def hvals(self, name: KeyT) -> Awaitable: """Return the list of values within hash ``name``""" @@ -3749,7 +3754,7 @@ def georadiusbymember( def _georadiusgeneric( self, command: str, *args: EncodableT, **kwargs: Optional[EncodableT] ) -> Awaitable: - pieces: List[EncodableT] = list(args) + pieces: List[Optional[EncodableT]] = list(args) if kwargs["unit"] and kwargs["unit"] not in ("m", "km", "mi", "ft"): raise DataError("GEORADIUS invalid unit") elif kwargs["unit"]: @@ -3844,6 +3849,7 @@ async def connect(self): async def __aenter__(self): await self.connect() + self.connection = cast(Connection, self.connection) # Connected above. 
await self.connection.send_command("MONITOR") # check that monitor returns 'OK', but don't return it to user response = await self.connection.read_response() @@ -3852,17 +3858,22 @@ async def __aenter__(self): return self async def __aexit__(self, *args): + assert self.connection is not None await self.connection.disconnect() await self.connection_pool.release(self.connection) async def next_command(self) -> MonitorCommandInfo: """Parse the response from a monitor command""" + if self.connection is None: + raise RedisError("Connection already closed.") await self.connect() response = await self.connection.read_response() if isinstance(response, bytes): response = self.connection.encoder.decode(response, force=True) command_time, command_data = response.split(" ", 1) m = self.monitor_re.match(command_data) + if m is None: + raise RedisError("Invalid command received.") db_id, client_info, command = m.groups() command = " ".join(self.command_re.findall(command)) # Redis escapes double quotes because each piece of the command @@ -3919,7 +3930,7 @@ def __init__( self.connection_pool = connection_pool self.shard_hint = shard_hint self.ignore_subscribe_messages = ignore_subscribe_messages - self.connection = None + self.connection: Optional[Connection] = None # we need to know the encoding options for this connection in order # to lookup channel and pattern names for callback handlers. self.encoder = self.connection_pool.get_encoder() @@ -3933,10 +3944,10 @@ def __init__( b"pong", self.encoder.encode(self.HEALTH_CHECK_MESSAGE), ] - self.channels = {} - self.pending_unsubscribe_channels = set() - self.patterns = {} - self.pending_unsubscribe_patterns = set() + self.channels: Dict[ChannelT, PubSubHandler] = {} + self.pending_unsubscribe_channels: Set[ChannelT] = set() + self.patterns: Dict[ChannelT, PubSubHandler] = {} + self.pending_unsubscribe_patterns: Set[ChannelT] = set() self._lock = asyncio.Lock() async def __aenter__(self): @@ -4056,7 +4067,7 @@ async def check_health(self): "PING", self.HEALTH_CHECK_MESSAGE, check_health=False ) - def _normalize_keys(self, data: Mapping[AnyChannelT, EncodableT]): + def _normalize_keys(self, data: _NormalizeKeysT) -> _NormalizeKeysT: """ normalize channel/pattern names to be either bytes or strings based on whether responses are automatically decoded. this saves us @@ -4064,9 +4075,9 @@ def _normalize_keys(self, data: Mapping[AnyChannelT, EncodableT]): """ encode = self.encoder.encode decode = self.encoder.decode - return {decode(encode(k)): v for k, v in data.items()} + return {decode(encode(k)): v for k, v in data.items()} # type: ignore[return-value] - async def psubscribe(self, *args: ChannelT, **kwargs: Callable): + async def psubscribe(self, *args: ChannelT, **kwargs: PubSubHandler): """ Subscribe to channel patterns. Patterns supplied as keyword arguments expect a pattern name as the key and a callable as the value. A @@ -4074,15 +4085,17 @@ async def psubscribe(self, *args: ChannelT, **kwargs: Callable): received on that pattern rather than producing a message via ``listen()``. """ - if args: - args = list_or_args(args[0], args[1:]) - new_patterns: Dict[ChannelT, Optional[Callable]] = dict.fromkeys(args) - new_patterns.update(kwargs) + # Mixed types.. 
+ parsed_args = list_or_args(args[0], args[1:]) if args else args # type: ignore[arg-type] + new_patterns: Dict[ChannelT, PubSubHandler] = dict.fromkeys(parsed_args) + # Mypy bug: https://github.com/python/mypy/issues/10970 + new_patterns.update(kwargs) # type: ignore[arg-type] ret_val = await self.execute_command("PSUBSCRIBE", *new_patterns.keys()) # update the patterns dict AFTER we send the command. we don't want to # subscribe twice to these patterns, once for the command and again # for the reconnection. - new_patterns = self._normalize_keys(new_patterns) + # Mypy bug: https://github.com/python/mypy/issues/10970 + new_patterns = self._normalize_keys(new_patterns) # type: ignore[type-var] self.patterns.update(new_patterns) self.pending_unsubscribe_patterns.difference_update(new_patterns) return ret_val @@ -4093,12 +4106,15 @@ def punsubscribe(self, *args: EncodableT) -> Awaitable: all patterns. """ if args: - args = list_or_args(args[0], args[1:]) - patterns = self._normalize_keys(dict.fromkeys(args)) + # Mixed types... + parsed_args: Sequence[EncodableT] = list_or_args(args[0], args[1:]) + patterns: Mapping[EncodableT, Optional[PubSubHandler]] = self._normalize_keys(dict.fromkeys(parsed_args)) else: - patterns = self.patterns + parsed_args = args + # Mypy bug: https://github.com/python/mypy/issues/10970 + patterns = self.patterns # type: ignore[assignment] self.pending_unsubscribe_patterns.update(patterns) - return self.execute_command("PUNSUBSCRIBE", *args) + return self.execute_command("PUNSUBSCRIBE", *parsed_args) async def subscribe(self, *args: ChannelT, **kwargs: Callable): """ @@ -4108,15 +4124,16 @@ async def subscribe(self, *args: ChannelT, **kwargs: Callable): that channel rather than producing a message via ``listen()`` or ``get_message()``. """ - if args: - args = list_or_args(args[0], args[1:]) - new_channels = dict.fromkeys(args) - new_channels.update(kwargs) + parsed_args = list_or_args(args[0], args[1:]) if args else () + new_channels = dict.fromkeys(parsed_args) + # Mypy bug: https://github.com/python/mypy/issues/10970 + new_channels.update(kwargs) # type: ignore[arg-type] ret_val = await self.execute_command("SUBSCRIBE", *new_channels.keys()) # update the channels dict AFTER we send the command. we don't want to # subscribe twice to these channels, once for the command and again # for the reconnection. 
- new_channels = self._normalize_keys(new_channels) + # Mypy bug: https://github.com/python/mypy/issues/10970 + new_channels = self._normalize_keys(new_channels) # type: ignore[type-var] self.channels.update(new_channels) self.pending_unsubscribe_channels.difference_update(new_channels) return ret_val @@ -4127,12 +4144,12 @@ def unsubscribe(self, *args) -> Awaitable: all channels """ if args: - args = list_or_args(args[0], args[1:]) - channels = self._normalize_keys(dict.fromkeys(args)) + parsed_args = list_or_args(args[0], args[1:]) + channels = self._normalize_keys(dict.fromkeys(parsed_args)) else: channels = self.channels self.pending_unsubscribe_channels.update(channels) - return self.execute_command("UNSUBSCRIBE", *args) + return self.execute_command("UNSUBSCRIBE", *parsed_args) async def listen(self) -> AsyncIterator: """Listen for messages on channels this client has been subscribed to""" @@ -4225,7 +4242,7 @@ def handle_message(self, response, ignore_subscribe_messages=False): async def run( self, *, - exception_handler: "PSWorkerThreadExcHandlerT" = None, + exception_handler: Optional["PSWorkerThreadExcHandlerT"] = None, poll_timeout: float = 1.0, ) -> None: """Process pub/sub messages using registered callbacks. @@ -4309,21 +4326,21 @@ class Pipeline(Redis): # lgtm [py/init-calls-subclass] def __init__( self, connection_pool: ConnectionPool, - response_callbacks: Mapping[str, ResponseCallbackT], + response_callbacks: MutableMapping[str, ResponseCallbackT], transaction: bool, shard_hint: Optional[str], ): self.connection_pool = connection_pool self.connection = None self.response_callbacks = response_callbacks - self.transaction = transaction + self.is_transaction = transaction self.shard_hint = shard_hint self.watching = False self.command_stack: CommandStackT = [] self.scripts: Set[Script] = set() self.explicit_transaction = False - async def __aenter__(self) -> "Pipeline": + async def __aenter__(self: _RedisT) -> _RedisT: return self async def __aexit__(self, exc_type, exc_value, traceback): @@ -4412,6 +4429,7 @@ async def immediate_execute_command(self, *args, **options): command_name, self.shard_hint ) self.connection = conn + conn = cast(Connection, conn) try: await conn.send_command(*args) return await self.parse_response(conn, command_name, **options) @@ -4462,7 +4480,7 @@ def pipeline_execute_command(self, *args, **options): async def _execute_transaction( self, connection: Connection, commands: CommandStackT, raise_on_error ): - cmds = chain([(("MULTI",), {})], commands, [(("EXEC",), {})]) + cmds: Iterable[CommandT] = chain([(("MULTI",), {})], commands, [(("EXEC",), {})]) all_cmds = connection.pack_commands( [args for args, options in cmds if EMPTY_RESPONSE not in options] ) @@ -4508,7 +4526,8 @@ async def _execute_transaction( response.insert(i, e) if len(response) != len(commands): - await self.connection.disconnect() + if self.connection: + await self.connection.disconnect() raise ResponseError( "Wrong number of response items from pipeline execution" ) from None @@ -4591,7 +4610,7 @@ async def execute(self, raise_on_error: bool = True): return [] if self.scripts: await self.load_scripts() - if self.transaction or self.explicit_transaction: + if self.is_transaction or self.explicit_transaction: execute = self._execute_transaction else: execute = self._execute_pipeline @@ -4602,6 +4621,7 @@ async def execute(self, raise_on_error: bool = True): # assign to self.connection so reset() releases the connection # back to the pool after we're done self.connection = conn + 
conn = cast(Connection, conn) try: return await execute(conn, stack, raise_on_error) @@ -4646,8 +4666,10 @@ def __init__(self, registered_client: Redis, script: ScriptTextT): # We need the encoding from the client in order to generate an # accurate byte representation of the script encoder = registered_client.connection_pool.get_encoder() - script = encoder.encode(script) - self.sha = hashlib.sha1(script).hexdigest() + script_bytes = encoder.encode(script) + else: + script_bytes = script + self.sha = hashlib.sha1(script_bytes).hexdigest() async def __call__( self, @@ -4681,7 +4703,7 @@ class BitFieldOperation: Command builder for BITFIELD commands. """ - def __init__(self, client: Redis, key: str, default_overflow: Optional[str] = None): + def __init__(self, client: Redis, key: KeyT, default_overflow: Optional[str] = None): self.client = client self.key = key self._default_overflow = default_overflow @@ -4765,7 +4787,7 @@ def set(self, fmt: str, offset: BitfieldOffsetT, value: int): @property def command(self): - cmd = ["BITFIELD", self.key] + cmd: List[EncodableT] = ["BITFIELD", self.key] for ops in self.operations: cmd.extend(ops) return cmd From e7d456d24bb423fffc030f49546b021995c173d8 Mon Sep 17 00:00:00 2001 From: Sam Bull Date: Sat, 21 Aug 2021 17:37:52 +0100 Subject: [PATCH 10/19] Fixes --- aioredis/client.py | 31 +++++++++++++------------------ aioredis/utils.py | 16 +++++++++++++--- 2 files changed, 26 insertions(+), 21 deletions(-) diff --git a/aioredis/client.py b/aioredis/client.py index 863c086e4..c372307d5 100644 --- a/aioredis/client.py +++ b/aioredis/client.py @@ -82,7 +82,7 @@ _KeyT = TypeVar("_KeyT", bound=KeyT) _ArgT = TypeVar("_ArgT", KeyT, EncodableT) _RedisT = TypeVar("_RedisT", bound="Redis") -_NormalizeKeysT = TypeVar("_NormalizeKeysT", bound=Mapping[EncodableT, object]) +_NormalizeKeysT = TypeVar("_NormalizeKeysT", bound=Mapping[ChannelT, object]) def list_or_args(keys: Union[_KeyT, Iterable[_KeyT]], args: Optional[Iterable[_ArgT]]) -> List[Union[_KeyT, _ArgT]]: @@ -906,7 +906,7 @@ def __init__( self.single_connection_client = single_connection_client self.connection: Optional[Connection] = None - self.response_callbacks: MutableMapping[str, ResponseCallbackT] = CaseInsensitiveDict(self.__class__.RESPONSE_CALLBACKS) + self.response_callbacks: MutableMapping[Union[str, bytes], ResponseCallbackT] = CaseInsensitiveDict(self.__class__.RESPONSE_CALLBACKS) def __repr__(self): return f"{self.__class__.__name__}<{self.connection_pool!r}>" @@ -4085,8 +4085,7 @@ async def psubscribe(self, *args: ChannelT, **kwargs: PubSubHandler): received on that pattern rather than producing a message via ``listen()``. """ - # Mixed types.. - parsed_args = list_or_args(args[0], args[1:]) if args else args # type: ignore[arg-type] + parsed_args = list_or_args((args[0],), args[1:]) if args else args new_patterns: Dict[ChannelT, PubSubHandler] = dict.fromkeys(parsed_args) # Mypy bug: https://github.com/python/mypy/issues/10970 new_patterns.update(kwargs) # type: ignore[arg-type] @@ -4094,25 +4093,22 @@ async def psubscribe(self, *args: ChannelT, **kwargs: PubSubHandler): # update the patterns dict AFTER we send the command. we don't want to # subscribe twice to these patterns, once for the command and again # for the reconnection. 
- # Mypy bug: https://github.com/python/mypy/issues/10970 - new_patterns = self._normalize_keys(new_patterns) # type: ignore[type-var] + new_patterns = self._normalize_keys(new_patterns) self.patterns.update(new_patterns) self.pending_unsubscribe_patterns.difference_update(new_patterns) return ret_val - def punsubscribe(self, *args: EncodableT) -> Awaitable: + def punsubscribe(self, *args: ChannelT) -> Awaitable: """ Unsubscribe from the supplied patterns. If empty, unsubscribe from all patterns. """ if args: - # Mixed types... - parsed_args: Sequence[EncodableT] = list_or_args(args[0], args[1:]) - patterns: Mapping[EncodableT, Optional[PubSubHandler]] = self._normalize_keys(dict.fromkeys(parsed_args)) + parsed_args = list_or_args((args[0],), args[1:]) + patterns: Iterable[ChannelT] = self._normalize_keys(dict.fromkeys(parsed_args)).keys() else: - parsed_args = args - # Mypy bug: https://github.com/python/mypy/issues/10970 - patterns = self.patterns # type: ignore[assignment] + parsed_args = [] + patterns = self.patterns self.pending_unsubscribe_patterns.update(patterns) return self.execute_command("PUNSUBSCRIBE", *parsed_args) @@ -4124,7 +4120,7 @@ async def subscribe(self, *args: ChannelT, **kwargs: Callable): that channel rather than producing a message via ``listen()`` or ``get_message()``. """ - parsed_args = list_or_args(args[0], args[1:]) if args else () + parsed_args = list_or_args((args[0],), args[1:]) if args else () new_channels = dict.fromkeys(parsed_args) # Mypy bug: https://github.com/python/mypy/issues/10970 new_channels.update(kwargs) # type: ignore[arg-type] @@ -4132,8 +4128,7 @@ async def subscribe(self, *args: ChannelT, **kwargs: Callable): # update the channels dict AFTER we send the command. we don't want to # subscribe twice to these channels, once for the command and again # for the reconnection. - # Mypy bug: https://github.com/python/mypy/issues/10970 - new_channels = self._normalize_keys(new_channels) # type: ignore[type-var] + new_channels = self._normalize_keys(new_channels) self.channels.update(new_channels) self.pending_unsubscribe_channels.difference_update(new_channels) return ret_val @@ -4326,7 +4321,7 @@ class Pipeline(Redis): # lgtm [py/init-calls-subclass] def __init__( self, connection_pool: ConnectionPool, - response_callbacks: MutableMapping[str, ResponseCallbackT], + response_callbacks: MutableMapping[Union[str, bytes], ResponseCallbackT], transaction: bool, shard_hint: Optional[str], ): @@ -4575,7 +4570,7 @@ def raise_first_error(self, commands: CommandStackT, response: Iterable[Any]): self.annotate_exception(r, i + 1, commands[i][0]) raise r - def annotate_exception(self, exception: Exception, number: int, command: str): + def annotate_exception(self, exception: Exception, number: int, command: Iterable[object]): cmd = " ".join(map(safe_str, command)) msg = f"Command # {number} ({cmd}) of pipeline caused error: {exception.args}" exception.args = (msg,) + exception.args[1:] diff --git a/aioredis/utils.py b/aioredis/utils.py index 16b48f4c7..730e5a729 100644 --- a/aioredis/utils.py +++ b/aioredis/utils.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Union, TypeVar, overload if TYPE_CHECKING: from aioredis import Redis @@ -13,6 +13,9 @@ HIREDIS_AVAILABLE = False +_T = TypeVar("_T") + + def from_url(url, **kwargs): """ Returns an active Redis client generated from the given database URL. 
@@ -37,11 +40,18 @@ async def __aexit__(self, exc_type, exc_val, exc_tb): del self.p -def str_if_bytes(value): +# Mypy bug: https://github.com/python/mypy/issues/11005 +@overload +def str_if_bytes(value: bytes) -> str: # type: ignore[misc] + ... +@overload +def str_if_bytes(value: _T) -> _T: + ... +def str_if_bytes(value: object) -> object: return ( value.decode("utf-8", errors="replace") if isinstance(value, bytes) else value ) -def safe_str(value): +def safe_str(value: object) -> str: return str(str_if_bytes(value)) From 7011de68a6e54051870ecbb563e1f475f11fd444 Mon Sep 17 00:00:00 2001 From: Sam Bull Date: Sat, 21 Aug 2021 17:47:31 +0100 Subject: [PATCH 11/19] Black --- aioredis/client.py | 23 +++++++++++++++++------ aioredis/utils.py | 4 ++++ 2 files changed, 21 insertions(+), 6 deletions(-) diff --git a/aioredis/client.py b/aioredis/client.py index c372307d5..80c20d6c4 100644 --- a/aioredis/client.py +++ b/aioredis/client.py @@ -85,7 +85,9 @@ _NormalizeKeysT = TypeVar("_NormalizeKeysT", bound=Mapping[ChannelT, object]) -def list_or_args(keys: Union[_KeyT, Iterable[_KeyT]], args: Optional[Iterable[_ArgT]]) -> List[Union[_KeyT, _ArgT]]: +def list_or_args( + keys: Union[_KeyT, Iterable[_KeyT]], args: Optional[Iterable[_ArgT]] +) -> List[Union[_KeyT, _ArgT]]: # returns a single new list combining keys and args key_list: List[Union[_KeyT, _ArgT]] try: @@ -782,6 +784,8 @@ class Redis: "ZSCAN": parse_zscan, } + response_callbacks: MutableMapping[Union[str, bytes], ResponseCallbackT] + @classmethod def from_url(cls, url: str, **kwargs): """ @@ -906,7 +910,7 @@ def __init__( self.single_connection_client = single_connection_client self.connection: Optional[Connection] = None - self.response_callbacks: MutableMapping[Union[str, bytes], ResponseCallbackT] = CaseInsensitiveDict(self.__class__.RESPONSE_CALLBACKS) + self.response_callbacks = CaseInsensitiveDict(self.__class__.RESPONSE_CALLBACKS) def __repr__(self): return f"{self.__class__.__name__}<{self.connection_pool!r}>" @@ -4103,9 +4107,10 @@ def punsubscribe(self, *args: ChannelT) -> Awaitable: Unsubscribe from the supplied patterns. If empty, unsubscribe from all patterns. """ + patterns: Iterable[ChannelT] if args: parsed_args = list_or_args((args[0],), args[1:]) - patterns: Iterable[ChannelT] = self._normalize_keys(dict.fromkeys(parsed_args)).keys() + patterns = self._normalize_keys(dict.fromkeys(parsed_args)).keys() else: parsed_args = [] patterns = self.patterns @@ -4475,7 +4480,9 @@ def pipeline_execute_command(self, *args, **options): async def _execute_transaction( self, connection: Connection, commands: CommandStackT, raise_on_error ): - cmds: Iterable[CommandT] = chain([(("MULTI",), {})], commands, [(("EXEC",), {})]) + cmds: Iterable[CommandT] = chain( + [(("MULTI",), {})], commands, [(("EXEC",), {})] + ) all_cmds = connection.pack_commands( [args for args, options in cmds if EMPTY_RESPONSE not in options] ) @@ -4570,7 +4577,9 @@ def raise_first_error(self, commands: CommandStackT, response: Iterable[Any]): self.annotate_exception(r, i + 1, commands[i][0]) raise r - def annotate_exception(self, exception: Exception, number: int, command: Iterable[object]): + def annotate_exception( + self, exception: Exception, number: int, command: Iterable[object] + ) -> None: cmd = " ".join(map(safe_str, command)) msg = f"Command # {number} ({cmd}) of pipeline caused error: {exception.args}" exception.args = (msg,) + exception.args[1:] @@ -4698,7 +4707,9 @@ class BitFieldOperation: Command builder for BITFIELD commands. 
""" - def __init__(self, client: Redis, key: KeyT, default_overflow: Optional[str] = None): + def __init__( + self, client: Redis, key: KeyT, default_overflow: Optional[str] = None + ): self.client = client self.key = key self._default_overflow = default_overflow diff --git a/aioredis/utils.py b/aioredis/utils.py index 730e5a729..1311e6f45 100644 --- a/aioredis/utils.py +++ b/aioredis/utils.py @@ -44,9 +44,13 @@ async def __aexit__(self, exc_type, exc_val, exc_tb): @overload def str_if_bytes(value: bytes) -> str: # type: ignore[misc] ... + + @overload def str_if_bytes(value: _T) -> _T: ... + + def str_if_bytes(value: object) -> object: return ( value.decode("utf-8", errors="replace") if isinstance(value, bytes) else value From 250e6ed163b38efcc961a3154c3080ec8c89a016 Mon Sep 17 00:00:00 2001 From: Sam Bull Date: Sat, 21 Aug 2021 17:49:37 +0100 Subject: [PATCH 12/19] Imports --- aioredis/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aioredis/utils.py b/aioredis/utils.py index 1311e6f45..8a1010e52 100644 --- a/aioredis/utils.py +++ b/aioredis/utils.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Union, TypeVar, overload +from typing import TYPE_CHECKING, TypeVar, overload if TYPE_CHECKING: from aioredis import Redis From 8aff9bb55b24fdc7b5ec77115e22e7c9898c1b2a Mon Sep 17 00:00:00 2001 From: Sam Bull Date: Sat, 21 Aug 2021 17:58:20 +0100 Subject: [PATCH 13/19] Simplify cmds. --- aioredis/client.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/aioredis/client.py b/aioredis/client.py index 80c20d6c4..83209100b 100644 --- a/aioredis/client.py +++ b/aioredis/client.py @@ -5,7 +5,6 @@ import re import time as mod_time import warnings -from itertools import chain from typing import ( AbstractSet, Any, @@ -4480,9 +4479,9 @@ def pipeline_execute_command(self, *args, **options): async def _execute_transaction( self, connection: Connection, commands: CommandStackT, raise_on_error ): - cmds: Iterable[CommandT] = chain( - [(("MULTI",), {})], commands, [(("EXEC",), {})] - ) + pre: CommandT = (("MULTI",), {}) + post: CommandT = (("EXEC",), {}) + cmds = (pre, *commands, post) all_cmds = connection.pack_commands( [args for args, options in cmds if EMPTY_RESPONSE not in options] ) From ae2dfa6597f7c1f62b40c248dadd63f16cbdb27d Mon Sep 17 00:00:00 2001 From: Sam Bull Date: Sat, 21 Aug 2021 18:01:50 +0100 Subject: [PATCH 14/19] Simplify cmds further. 
--- aioredis/client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aioredis/client.py b/aioredis/client.py index 83209100b..6ea3f5746 100644 --- a/aioredis/client.py +++ b/aioredis/client.py @@ -4483,7 +4483,7 @@ async def _execute_transaction( post: CommandT = (("EXEC",), {}) cmds = (pre, *commands, post) all_cmds = connection.pack_commands( - [args for args, options in cmds if EMPTY_RESPONSE not in options] + args for args, options in cmds if EMPTY_RESPONSE not in options ) await connection.send_packed_command(all_cmds) errors = [] From f613529eb92b3d8de9a7057e5f94cfe765b7035b Mon Sep 17 00:00:00 2001 From: Sam Bull Date: Sat, 21 Aug 2021 18:09:58 +0100 Subject: [PATCH 15/19] Give up --- aioredis/client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aioredis/client.py b/aioredis/client.py index 6ea3f5746..2e53d8e54 100644 --- a/aioredis/client.py +++ b/aioredis/client.py @@ -4476,7 +4476,7 @@ def pipeline_execute_command(self, *args, **options): self.command_stack.append((args, options)) return self - async def _execute_transaction( + async def _execute_transaction( # noqa: C901 self, connection: Connection, commands: CommandStackT, raise_on_error ): pre: CommandT = (("MULTI",), {}) From f5a8e80882c3dfa70cf9cc2e44d110faeb579d74 Mon Sep 17 00:00:00 2001 From: Sam Bull Date: Sat, 21 Aug 2021 20:35:18 +0100 Subject: [PATCH 16/19] Fix --- aioredis/client.py | 1 + 1 file changed, 1 insertion(+) diff --git a/aioredis/client.py b/aioredis/client.py index 2e53d8e54..3d0ba7aee 100644 --- a/aioredis/client.py +++ b/aioredis/client.py @@ -4146,6 +4146,7 @@ def unsubscribe(self, *args) -> Awaitable: parsed_args = list_or_args(args[0], args[1:]) channels = self._normalize_keys(dict.fromkeys(parsed_args)) else: + parsed_args = [] channels = self.channels self.pending_unsubscribe_channels.update(channels) return self.execute_command("UNSUBSCRIBE", *parsed_args) From 792bfb026b012d46ca91813a63f0cdece5ced0ce Mon Sep 17 00:00:00 2001 From: Sam Bull Date: Sun, 19 Sep 2021 23:46:21 +0100 Subject: [PATCH 17/19] Rename to _Sentinel. --- aioredis/connection.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/aioredis/connection.py b/aioredis/connection.py index 0f41bac48..ee3d85033 100644 --- a/aioredis/connection.py +++ b/aioredis/connection.py @@ -83,11 +83,11 @@ SERVER_CLOSED_CONNECTION_ERROR = "Connection closed by server." -class Sentinel(enum.Enum): +class _Sentinel(enum.Enum): sentinel = object() -SENTINEL = Sentinel.sentinel +SENTINEL = _Sentinel.sentinel MODULE_LOAD_ERROR = "Error loading the extension. Please check the server logs." NO_SUCH_MODULE_ERROR = "Error unloading module: no such module with that name" MODULE_UNLOAD_NOT_POSSIBLE_ERROR = "Error unloading module: operation not possible." @@ -251,7 +251,7 @@ def length(self): async def _read_from_socket( self, length: Optional[int] = None, - timeout: Union[float, None, Sentinel] = SENTINEL, + timeout: Union[float, None, _Sentinel] = SENTINEL, raise_on_timeout: bool = True, ) -> bool: buf = self._buffer @@ -487,7 +487,7 @@ async def can_read(self, timeout: float): async def read_from_socket( self, - timeout: Union[float, None, Sentinel] = SENTINEL, + timeout: Union[float, None, _Sentinel] = SENTINEL, raise_on_timeout: bool = True, ): if self._stream is None or self._reader is None: From e78801be38a81578ecd05c814fa5204b3a861746 Mon Sep 17 00:00:00 2001 From: Sam Bull Date: Sun, 19 Sep 2021 23:47:56 +0100 Subject: [PATCH 18/19] Move self.encoder check. 
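
Not part of the diff, just a toy illustration of the narrowing this change
leans on (the ToyParser name and attributes are made up; the real guard checks
self._buffer and self.encoder in PythonParser.read_response): checking both
Optional attributes in one early guard lets mypy treat them as non-None for
the rest of the method, so the later assert on the encoder can be dropped.

    from typing import Optional

    class ToyParser:
        # Both attributes start out unset, much like the parser before connect.
        def __init__(self) -> None:
            self.buffer: Optional[bytes] = None
            self.encoding: Optional[str] = None

        def read(self) -> str:
            # A single early guard narrows both Optionals for the rest of the
            # method, so no assert is needed where the encoding is used.
            if self.buffer is None or self.encoding is None:
                raise ConnectionError("Connection closed by server.")
            return self.buffer.decode(self.encoding)

    p = ToyParser()
    p.buffer, p.encoding = b"PONG", "utf-8"
    print(p.read())  # PONG
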
--- aioredis/connection.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/aioredis/connection.py b/aioredis/connection.py index ee3d85033..e6c74596e 100644 --- a/aioredis/connection.py +++ b/aioredis/connection.py @@ -394,7 +394,7 @@ async def can_read(self, timeout: float): return self._buffer and bool(await self._buffer.can_read(timeout)) async def read_response(self) -> Union[EncodableT, ResponseError, None]: - if not self._buffer: + if not self._buffer or not self.encoder: raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) raw = await self._buffer.readline() if not raw: @@ -437,7 +437,6 @@ async def read_response(self) -> Union[EncodableT, ResponseError, None]: return None response = [(await self.read_response()) for _ in range(length)] if isinstance(response, bytes): - assert self.encoder is not None response = self.encoder.decode(response) return response From 6180fa8cafb7c626c3414a4838124c3d8ae1b239 Mon Sep 17 00:00:00 2001 From: Sam Bull Date: Sun, 19 Sep 2021 23:50:47 +0100 Subject: [PATCH 19/19] More precise type for auth_args. --- aioredis/connection.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/aioredis/connection.py b/aioredis/connection.py index e6c74596e..93910971b 100644 --- a/aioredis/connection.py +++ b/aioredis/connection.py @@ -754,8 +754,9 @@ async def on_connect(self): # if username and/or password are set, authenticate if self.username or self.password: + auth_args: Union[Tuple[str], Tuple[str, str]] if self.username: - auth_args: Tuple[str, ...] = (self.username, self.password or "") + auth_args = (self.username, self.password or "") else: # Mypy bug: https://github.com/python/mypy/issues/10944 auth_args = (self.password or "",)
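
A hedged standalone sketch of what the narrower annotation buys (the function
name and defaults below are illustrative, not aioredis API): each member of
Union[Tuple[str], Tuple[str, str]] fixes the arity of the AUTH arguments,
unlike the open-ended Tuple[str, ...], so mypy knows exactly which shapes can
reach the command that unpacks them.

    from typing import Tuple, Union

    def build_auth_command(username: str = "", password: str = "") -> Tuple[str, ...]:
        # Mirrors the on_connect branch: AUTH is sent either as (password,) or
        # as (username, password); the union records exactly those two shapes.
        auth_args: Union[Tuple[str], Tuple[str, str]]
        if username:
            auth_args = (username, password or "")
        else:
            auth_args = (password or "",)
        return ("AUTH", *auth_args)

    print(build_auth_command(password="secret"))  # ('AUTH', 'secret')
    print(build_auth_command("app", "secret"))    # ('AUTH', 'app', 'secret')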