Skip to content

Commit

Permalink
Update mypy to 0.930 (#105)
Browse files Browse the repository at this point in the history
* setup: Update mypy to 0.930

* fix: Modernize common.etcd coding style and fix type errors

  - While pyright (VSCode) supports ParamSpec and Concatenate,
    mypy 0.930 does not support Concatenate yet.
    Temporarily comment related lines with "type: ignore".

  - Introduce a wrapper function to ensure immutable usage of
    ChainMap which is defined as mutable in typeshed, with minimal
    runtime overheads.

* test: Let test_utils type-checked and workaround python/mypy#11850
  • Loading branch information
achimnol committed Dec 27, 2021
1 parent 686d6b7 commit 9c4ee94
Show file tree
Hide file tree
Showing 5 changed files with 167 additions and 94 deletions.
1 change: 1 addition & 0 deletions changes/105.fix
@@ -0,0 +1 @@
Update mypy to 0.930 and fix newly discovered type errors
2 changes: 1 addition & 1 deletion setup.cfg
Expand Up @@ -83,7 +83,7 @@ lint =
flake8>=4.0.1
flake8-commas>=2.1
typecheck =
mypy>=0.920
mypy>=0.930
types-python-dateutil
types-toml
types-setuptools
Expand Down
212 changes: 143 additions & 69 deletions src/ai/backend/common/etcd.py
Expand Up @@ -7,6 +7,8 @@
using callbacks in separate threads.
'''

from __future__ import annotations

import asyncio
from collections import namedtuple, ChainMap
from concurrent.futures import ThreadPoolExecutor
Expand All @@ -15,10 +17,22 @@
import logging
import time
from typing import (
Any, Awaitable, Callable, Iterable, Optional, Union,
AsyncGenerator,
Dict, Mapping,
Awaitable,
Callable,
Dict,
Iterable,
Mapping,
MutableMapping,
Optional,
Tuple,
TypeVar,
Union,
cast,
)
from typing_extensions import ( # FIXME: move to typing when we migrate to Python 3.10
Concatenate,
ParamSpec,
)
from urllib.parse import quote as _quote, unquote

Expand Down Expand Up @@ -101,9 +115,18 @@ async def reauthenticate(etcd_sync, creds, executor):
EtcdTokenCallCredentials(resp.token))


def reconn_reauth_adaptor(meth: Callable[..., Awaitable[Any]]):
P = ParamSpec("P")
R = TypeVar("R")


# FIXME: when mypy begins to support typing.Concatenate, remove "type: ignore" comments
# (ref: https://github.com/python/mypy/issues/8645)
def reconn_reauth_adaptor(
meth: Callable[Concatenate[AsyncEtcd, P], Awaitable[R]], # type: ignore
) -> Callable[Concatenate[AsyncEtcd, P], Awaitable[R]]: # type: ignore

@functools.wraps(meth)
async def wrapped(self, *args, **kwargs):
async def wrapped(self: AsyncEtcd, *args: P.args, **kwargs: P.kwargs) -> R:
num_reauth_tries = 0
num_reconn_tries = 0
while True:
Expand Down Expand Up @@ -131,14 +154,21 @@ async def wrapped(self, *args, **kwargs):
continue
else:
raise

return wrapped


class AsyncEtcd:

def __init__(self, addr: HostPortPair, namespace: str,
scope_prefix_map: Mapping[ConfigScopes, str], *,
credentials=None, encoding='utf8'):
def __init__(
self,
addr: HostPortPair,
namespace: str,
scope_prefix_map: Mapping[ConfigScopes, str],
*,
credentials=None,
encoding='utf8',
) -> None:
self.scope_prefix_map = t.Dict({
t.Key(ConfigScopes.GLOBAL): t.String(allow_blank=True),
t.Key(ConfigScopes.SGROUP, optional=True): t.String,
Expand Down Expand Up @@ -188,10 +218,26 @@ def _demangle_key(self, k: Union[bytes, str]) -> str:
k = k[len(prefix):]
return k

def _merge_scope_prefix_map(
self,
override: Mapping[ConfigScopes, str] = None,
) -> Mapping[ConfigScopes, str]:
"""
This stub ensures immutable usage of the ChainMap because ChainMap does *not*
have the immutable version in typeshed.
(ref: https://github.com/python/typeshed/issues/6042)
"""
return ChainMap(cast(MutableMapping, override) or {}, self.scope_prefix_map)

@reconn_reauth_adaptor
async def put(self, key: str, val: str, *,
scope: ConfigScopes = ConfigScopes.GLOBAL,
scope_prefix_map: Mapping[ConfigScopes, str] = None):
async def put(
self,
key: str,
val: str,
*,
scope: ConfigScopes = ConfigScopes.GLOBAL,
scope_prefix_map: Mapping[ConfigScopes, str] = None,
):
"""
Put a single key-value pair to the etcd.
Expand All @@ -201,17 +247,21 @@ async def put(self, key: str, val: str, *,
:param scope_prefix_map: The scope map used to mangle the prefix for the config scope.
:return:
"""
scope_prefix_map = ChainMap(scope_prefix_map or {}, self.scope_prefix_map)
scope_prefix = scope_prefix_map[scope]
scope_prefix = self._merge_scope_prefix_map(scope_prefix_map)[scope]
mangled_key = self._mangle_key(f'{_slash(scope_prefix)}{key}')
return await self.loop.run_in_executor(
self.executor,
lambda: self.etcd_sync.put(mangled_key, str(val).encode(self.encoding)))

@reconn_reauth_adaptor
async def put_prefix(self, key: str, dict_obj: Mapping[str, str], *,
scope: ConfigScopes = ConfigScopes.GLOBAL,
scope_prefix_map: Mapping[ConfigScopes, str] = None):
async def put_prefix(
self,
key: str,
dict_obj: Mapping[str, str],
*,
scope: ConfigScopes = ConfigScopes.GLOBAL,
scope_prefix_map: Mapping[ConfigScopes, str] = None,
):
"""
Put a nested dict object under the given key prefix.
All keys in the dict object are automatically quoted to avoid conflicts with the path separator.
Expand All @@ -222,8 +272,7 @@ async def put_prefix(self, key: str, dict_obj: Mapping[str, str], *,
:param scope_prefix_map: The scope map used to mangle the prefix for the config scope.
:return:
"""
scope_prefix_map = ChainMap(scope_prefix_map or {}, self.scope_prefix_map)
scope_prefix = scope_prefix_map[scope]
scope_prefix = self._merge_scope_prefix_map(scope_prefix_map)[scope]
flattened_dict: Dict[str, str] = {}

def _flatten(prefix: str, inner_dict: Mapping[str, str]) -> None:
Expand All @@ -250,9 +299,13 @@ def _flatten(prefix: str, inner_dict: Mapping[str, str]) -> None:
))

@reconn_reauth_adaptor
async def put_dict(self, dict_obj: Mapping[str, str], *,
scope: ConfigScopes = ConfigScopes.GLOBAL,
scope_prefix_map: Mapping[ConfigScopes, str] = None):
async def put_dict(
self,
dict_obj: Mapping[str, str],
*,
scope: ConfigScopes = ConfigScopes.GLOBAL,
scope_prefix_map: Mapping[ConfigScopes, str] = None,
):
"""
Put a flattened key-value pairs into the etcd.
Since the given dict must be a flattened one, its keys must be quoted as needed by the caller.
Expand All @@ -263,8 +316,7 @@ async def put_dict(self, dict_obj: Mapping[str, str], *,
:param scope_prefix_map: The scope map used to mangle the prefix for the config scope.
:return:
"""
scope_prefix_map = ChainMap(scope_prefix_map or {}, self.scope_prefix_map)
scope_prefix = scope_prefix_map[scope]
scope_prefix = self._merge_scope_prefix_map(scope_prefix_map)[scope]
return await self.loop.run_in_executor(
self.executor,
lambda: self.etcd_sync.transaction(
Expand All @@ -276,10 +328,13 @@ async def put_dict(self, dict_obj: Mapping[str, str], *,
))

@reconn_reauth_adaptor
async def get(self, key: str, *,
scope: ConfigScopes = ConfigScopes.MERGED,
scope_prefix_map: Mapping[ConfigScopes, str] = None) \
-> Optional[str]:
async def get(
self,
key: str,
*,
scope: ConfigScopes = ConfigScopes.MERGED,
scope_prefix_map: Mapping[ConfigScopes, str] = None,
) -> Optional[str]:
"""
Get a single key from the etcd.
Returns ``None`` if the key does not exist.
Expand All @@ -298,22 +353,22 @@ async def get_impl(key: str) -> Optional[str]:
lambda: self.etcd_sync.get(mangled_key))
return val.decode(self.encoding) if val is not None else None

scope_prefix_map = ChainMap(scope_prefix_map or {}, self.scope_prefix_map)
_scope_prefix_map = self._merge_scope_prefix_map(scope_prefix_map)
if scope == ConfigScopes.MERGED or scope == ConfigScopes.NODE:
scope_prefixes = [scope_prefix_map[ConfigScopes.GLOBAL]]
p = scope_prefix_map.get(ConfigScopes.SGROUP)
scope_prefixes = [_scope_prefix_map[ConfigScopes.GLOBAL]]
p = _scope_prefix_map.get(ConfigScopes.SGROUP)
if p is not None:
scope_prefixes.insert(0, p)
p = scope_prefix_map.get(ConfigScopes.NODE)
p = _scope_prefix_map.get(ConfigScopes.NODE)
if p is not None:
scope_prefixes.insert(0, p)
elif scope == ConfigScopes.SGROUP:
scope_prefixes = [scope_prefix_map[ConfigScopes.GLOBAL]]
p = scope_prefix_map.get(ConfigScopes.SGROUP)
scope_prefixes = [_scope_prefix_map[ConfigScopes.GLOBAL]]
p = _scope_prefix_map.get(ConfigScopes.SGROUP)
if p is not None:
scope_prefixes.insert(0, p)
elif scope == ConfigScopes.GLOBAL:
scope_prefixes = [scope_prefix_map[ConfigScopes.GLOBAL]]
scope_prefixes = [_scope_prefix_map[ConfigScopes.GLOBAL]]
else:
raise ValueError('Invalid scope prefix value')
values = await asyncio.gather(*[
Expand All @@ -328,10 +383,13 @@ async def get_impl(key: str) -> Optional[str]:
return value

@reconn_reauth_adaptor
async def get_prefix(self, key_prefix: str,
scope: ConfigScopes = ConfigScopes.MERGED,
scope_prefix_map: Mapping[ConfigScopes, str] = None) \
-> Mapping[str, Optional[str]]:
async def get_prefix(
self,
key_prefix: str,
*,
scope: ConfigScopes = ConfigScopes.MERGED,
scope_prefix_map: Mapping[ConfigScopes, str] = None,
) -> Mapping[str, Optional[str]]:
"""
Retrieves all key-value pairs under the given key prefix as a nested dictionary.
All dictionary keys are automatically unquoted.
Expand Down Expand Up @@ -376,22 +434,22 @@ async def get_prefix_impl(key_prefix: str) -> Iterable[Tuple[str, str]]:
t[0].decode(self.encoding))
for t in results)

scope_prefix_map = ChainMap(scope_prefix_map or {}, self.scope_prefix_map)
_scope_prefix_map = self._merge_scope_prefix_map(scope_prefix_map)
if scope == ConfigScopes.MERGED or scope == ConfigScopes.NODE:
scope_prefixes = [scope_prefix_map[ConfigScopes.GLOBAL]]
p = scope_prefix_map.get(ConfigScopes.SGROUP)
scope_prefixes = [_scope_prefix_map[ConfigScopes.GLOBAL]]
p = _scope_prefix_map.get(ConfigScopes.SGROUP)
if p is not None:
scope_prefixes.insert(0, p)
p = scope_prefix_map.get(ConfigScopes.NODE)
p = _scope_prefix_map.get(ConfigScopes.NODE)
if p is not None:
scope_prefixes.insert(0, p)
elif scope == ConfigScopes.SGROUP:
scope_prefixes = [scope_prefix_map[ConfigScopes.GLOBAL]]
p = scope_prefix_map.get(ConfigScopes.SGROUP)
scope_prefixes = [_scope_prefix_map[ConfigScopes.GLOBAL]]
p = _scope_prefix_map.get(ConfigScopes.SGROUP)
if p is not None:
scope_prefixes.insert(0, p)
elif scope == ConfigScopes.GLOBAL:
scope_prefixes = [scope_prefix_map[ConfigScopes.GLOBAL]]
scope_prefixes = [_scope_prefix_map[ConfigScopes.GLOBAL]]
else:
raise ValueError('Invalid scope prefix value')
pair_sets = await asyncio.gather(*[
Expand All @@ -408,34 +466,45 @@ async def get_prefix_impl(key_prefix: str) -> Iterable[Tuple[str, str]]:
get_prefix_dict = get_prefix

@reconn_reauth_adaptor
async def replace(self, key: str, initial_val: str, new_val: str, *,
scope: ConfigScopes = ConfigScopes.GLOBAL,
scope_prefix_map: Mapping[ConfigScopes, str] = None) -> bool:
scope_prefix_map = ChainMap(scope_prefix_map or {}, self.scope_prefix_map)
scope_prefix = scope_prefix_map[scope]
async def replace(
self,
key: str,
initial_val: str,
new_val: str,
*,
scope: ConfigScopes = ConfigScopes.GLOBAL,
scope_prefix_map: Mapping[ConfigScopes, str] = None,
) -> bool:
scope_prefix = self._merge_scope_prefix_map(scope_prefix_map)[scope]
mangled_key = self._mangle_key(f'{_slash(scope_prefix)}{key}')
success = await self.loop.run_in_executor(
self.executor,
lambda: self.etcd_sync.replace(mangled_key, initial_val, new_val))
return success

@reconn_reauth_adaptor
async def delete(self, key: str, *,
scope: ConfigScopes = ConfigScopes.GLOBAL,
scope_prefix_map: Mapping[ConfigScopes, str] = None):
scope_prefix_map = ChainMap(scope_prefix_map or {}, self.scope_prefix_map)
scope_prefix = scope_prefix_map[scope]
async def delete(
self,
key: str,
*,
scope: ConfigScopes = ConfigScopes.GLOBAL,
scope_prefix_map: Mapping[ConfigScopes, str] = None,
):
scope_prefix = self._merge_scope_prefix_map(scope_prefix_map)[scope]
mangled_key = self._mangle_key(f'{_slash(scope_prefix)}{key}')
return await self.loop.run_in_executor(
self.executor,
lambda: self.etcd_sync.delete(mangled_key))

@reconn_reauth_adaptor
async def delete_multi(self, keys: Iterable[str], *,
scope: ConfigScopes = ConfigScopes.GLOBAL,
scope_prefix_map: Mapping[ConfigScopes, str] = None):
scope_prefix_map = ChainMap(scope_prefix_map or {}, self.scope_prefix_map)
scope_prefix = scope_prefix_map[scope]
async def delete_multi(
self,
keys: Iterable[str],
*,
scope: ConfigScopes = ConfigScopes.GLOBAL,
scope_prefix_map: Mapping[ConfigScopes, str] = None,
):
scope_prefix = self._merge_scope_prefix_map(scope_prefix_map)[scope]
return await self.loop.run_in_executor(
self.executor,
lambda: self.etcd_sync.transaction(
Expand All @@ -446,17 +515,24 @@ async def delete_multi(self, keys: Iterable[str], *,
))

@reconn_reauth_adaptor
async def delete_prefix(self, key_prefix: str, *,
scope: ConfigScopes = ConfigScopes.GLOBAL,
scope_prefix_map: Mapping[ConfigScopes, str] = None):
scope_prefix_map = ChainMap(scope_prefix_map or {}, self.scope_prefix_map)
scope_prefix = scope_prefix_map[scope]
async def delete_prefix(
self,
key_prefix: str,
*,
scope: ConfigScopes = ConfigScopes.GLOBAL,
scope_prefix_map: Mapping[ConfigScopes, str] = None,
):
scope_prefix = self._merge_scope_prefix_map(scope_prefix_map)[scope]
mangled_key_prefix = self._mangle_key(f'{_slash(scope_prefix)}{key_prefix}')
return await self.loop.run_in_executor(
self.executor,
lambda: self.etcd_sync.delete_prefix(mangled_key_prefix))

def _watch_cb(self, queue: asyncio.Queue, resp: etcd3.watch.WatchResponse) -> None:
def _watch_cb(
self,
queue: asyncio.Queue,
resp: etcd3.watch.WatchResponse,
) -> None:
if isinstance(resp, grpc.RpcError):
if (
resp.code() == grpc.StatusCode.UNAVAILABLE or
Expand Down Expand Up @@ -543,8 +619,7 @@ async def watch(
cleanup_event: asyncio.Event = None,
wait_timeout: float = None,
) -> AsyncGenerator[Union[QueueSentinel, Event], None]:
scope_prefix_map = ChainMap(scope_prefix_map or {}, self.scope_prefix_map)
scope_prefix = scope_prefix_map[scope]
scope_prefix = self._merge_scope_prefix_map(scope_prefix_map)[scope]
scope_prefix_len = len(f'{_slash(scope_prefix)}')
mangled_key = self._mangle_key(f'{_slash(scope_prefix)}{key}')
# NOTE: yield from in async-generator is not supported.
Expand Down Expand Up @@ -582,8 +657,7 @@ async def watch_prefix(
cleanup_event: asyncio.Event = None,
wait_timeout: float = None,
) -> AsyncGenerator[Union[QueueSentinel, Event], None]:
scope_prefix_map = ChainMap(scope_prefix_map or {}, self.scope_prefix_map)
scope_prefix = scope_prefix_map[scope]
scope_prefix = self._merge_scope_prefix_map(scope_prefix_map)[scope]
scope_prefix_len = len(f'{_slash(scope_prefix)}')
mangled_key_prefix = self._mangle_key(f'{_slash(scope_prefix)}{key_prefix}')
while True:
Expand Down

0 comments on commit 9c4ee94

Please sign in to comment.