diff --git a/changes/105.fix b/changes/105.fix
new file mode 100644
index 00000000..9aaa410b
--- /dev/null
+++ b/changes/105.fix
@@ -0,0 +1 @@
+Update mypy to 0.930 and fix newly discovered type errors
diff --git a/setup.cfg b/setup.cfg
index 0af12eda..e9ac4381 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -83,7 +83,7 @@ lint =
     flake8>=4.0.1
     flake8-commas>=2.1
 typecheck =
-    mypy>=0.920
+    mypy>=0.930
     types-python-dateutil
     types-toml
     types-setuptools
diff --git a/src/ai/backend/common/etcd.py b/src/ai/backend/common/etcd.py
index c6445a2d..225e579b 100644
--- a/src/ai/backend/common/etcd.py
+++ b/src/ai/backend/common/etcd.py
@@ -7,6 +7,8 @@
 using callbacks in separate threads.
 '''

+from __future__ import annotations
+
 import asyncio
 from collections import namedtuple, ChainMap
 from concurrent.futures import ThreadPoolExecutor
@@ -15,10 +17,22 @@
 import logging
 import time
 from typing import (
-    Any, Awaitable, Callable, Iterable, Optional, Union, AsyncGenerator,
-    Dict, Mapping,
+    Awaitable,
+    Callable,
+    Dict,
+    Iterable,
+    Mapping,
+    MutableMapping,
+    Optional,
     Tuple,
+    TypeVar,
+    Union,
+    cast,
+)
+from typing_extensions import (  # FIXME: move to typing when we migrate to Python 3.10
+    Concatenate,
+    ParamSpec,
 )
 from urllib.parse import quote as _quote, unquote
@@ -101,9 +115,18 @@ async def reauthenticate(etcd_sync, creds, executor):
         EtcdTokenCallCredentials(resp.token))


-def reconn_reauth_adaptor(meth: Callable[..., Awaitable[Any]]):
+P = ParamSpec("P")
+R = TypeVar("R")
+
+
+# FIXME: when mypy begins to support typing.Concatenate, remove "type: ignore" comments
+# (ref: https://github.com/python/mypy/issues/8645)
+def reconn_reauth_adaptor(
+    meth: Callable[Concatenate[AsyncEtcd, P], Awaitable[R]],  # type: ignore
+) -> Callable[Concatenate[AsyncEtcd, P], Awaitable[R]]:  # type: ignore
+
     @functools.wraps(meth)
-    async def wrapped(self, *args, **kwargs):
+    async def wrapped(self: AsyncEtcd, *args: P.args, **kwargs: P.kwargs) -> R:
         num_reauth_tries = 0
         num_reconn_tries = 0
         while True:
@@ -131,14 +154,21 @@ async def wrapped(self, *args, **kwargs):
                     continue
                 else:
                     raise
+
     return wrapped


 class AsyncEtcd:

-    def __init__(self, addr: HostPortPair, namespace: str,
-                 scope_prefix_map: Mapping[ConfigScopes, str], *,
-                 credentials=None, encoding='utf8'):
+    def __init__(
+        self,
+        addr: HostPortPair,
+        namespace: str,
+        scope_prefix_map: Mapping[ConfigScopes, str],
+        *,
+        credentials=None,
+        encoding='utf8',
+    ) -> None:
         self.scope_prefix_map = t.Dict({
             t.Key(ConfigScopes.GLOBAL): t.String(allow_blank=True),
             t.Key(ConfigScopes.SGROUP, optional=True): t.String,
@@ -188,10 +218,26 @@ def _demangle_key(self, k: Union[bytes, str]) -> str:
             k = k[len(prefix):]
         return k

+    def _merge_scope_prefix_map(
+        self,
+        override: Mapping[ConfigScopes, str] = None,
+    ) -> Mapping[ConfigScopes, str]:
+        """
+        This stub ensures immutable usage of the ChainMap because ChainMap does *not*
+        have the immutable version in typeshed.
+        (ref: https://github.com/python/typeshed/issues/6042)
+        """
+        return ChainMap(cast(MutableMapping, override) or {}, self.scope_prefix_map)
+
     @reconn_reauth_adaptor
-    async def put(self, key: str, val: str, *,
-                  scope: ConfigScopes = ConfigScopes.GLOBAL,
-                  scope_prefix_map: Mapping[ConfigScopes, str] = None):
+    async def put(
+        self,
+        key: str,
+        val: str,
+        *,
+        scope: ConfigScopes = ConfigScopes.GLOBAL,
+        scope_prefix_map: Mapping[ConfigScopes, str] = None,
+    ):
         """
         Put a single key-value pair to the etcd.
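For reference, the ParamSpec/Concatenate pattern adopted by reconn_reauth_adaptor above can be reduced to the following self-contained sketch. This is illustration only, not part of the patch; the Client/fetch/with_retry names are invented. It shows how a decorator can wrap an async method while preserving the method's parameter types (P) and return type (R) for the type checker:

    import asyncio
    import functools
    from typing import Awaitable, Callable, TypeVar

    from typing_extensions import Concatenate, ParamSpec

    P = ParamSpec("P")
    R = TypeVar("R")


    def with_retry(
        meth: Callable[Concatenate["Client", P], Awaitable[R]],
    ) -> Callable[Concatenate["Client", P], Awaitable[R]]:
        # The wrapper accepts the same arguments as the wrapped method (plus self)
        # and returns the same type, so call sites keep precise type checking.
        @functools.wraps(meth)
        async def wrapped(self: "Client", *args: P.args, **kwargs: P.kwargs) -> R:
            try:
                return await meth(self, *args, **kwargs)
            except ConnectionError:
                # e.g. reconnect/reauthenticate here, then retry once
                return await meth(self, *args, **kwargs)
        return wrapped


    class Client:
        @with_retry
        async def fetch(self, key: str) -> str:
            return f"value-of-{key}"


    print(asyncio.run(Client().fetch("a")))  # value-of-a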
@@ -201,17 +247,21 @@ async def put(self, key: str, val: str, *,
         :param scope_prefix_map: The scope map used to mangle the prefix for the config scope.
         :return:
         """
-        scope_prefix_map = ChainMap(scope_prefix_map or {}, self.scope_prefix_map)
-        scope_prefix = scope_prefix_map[scope]
+        scope_prefix = self._merge_scope_prefix_map(scope_prefix_map)[scope]
         mangled_key = self._mangle_key(f'{_slash(scope_prefix)}{key}')
         return await self.loop.run_in_executor(
             self.executor,
             lambda: self.etcd_sync.put(mangled_key, str(val).encode(self.encoding)))

     @reconn_reauth_adaptor
-    async def put_prefix(self, key: str, dict_obj: Mapping[str, str], *,
-                         scope: ConfigScopes = ConfigScopes.GLOBAL,
-                         scope_prefix_map: Mapping[ConfigScopes, str] = None):
+    async def put_prefix(
+        self,
+        key: str,
+        dict_obj: Mapping[str, str],
+        *,
+        scope: ConfigScopes = ConfigScopes.GLOBAL,
+        scope_prefix_map: Mapping[ConfigScopes, str] = None,
+    ):
         """
         Put a nested dict object under the given key prefix.
         All keys in the dict object are automatically quoted to avoid conflicts with the path separator.
@@ -222,8 +272,7 @@ async def put_prefix(self, key: str, dict_obj: Mapping[str, str], *,
         :param scope_prefix_map: The scope map used to mangle the prefix for the config scope.
         :return:
         """
-        scope_prefix_map = ChainMap(scope_prefix_map or {}, self.scope_prefix_map)
-        scope_prefix = scope_prefix_map[scope]
+        scope_prefix = self._merge_scope_prefix_map(scope_prefix_map)[scope]
         flattened_dict: Dict[str, str] = {}

         def _flatten(prefix: str, inner_dict: Mapping[str, str]) -> None:
@@ -250,9 +299,13 @@ def _flatten(prefix: str, inner_dict: Mapping[str, str]) -> None:
             ))

     @reconn_reauth_adaptor
-    async def put_dict(self, dict_obj: Mapping[str, str], *,
-                       scope: ConfigScopes = ConfigScopes.GLOBAL,
-                       scope_prefix_map: Mapping[ConfigScopes, str] = None):
+    async def put_dict(
+        self,
+        dict_obj: Mapping[str, str],
+        *,
+        scope: ConfigScopes = ConfigScopes.GLOBAL,
+        scope_prefix_map: Mapping[ConfigScopes, str] = None,
+    ):
         """
         Put a flattened key-value pairs into the etcd.
         Since the given dict must be a flattened one, its keys must be quoted as needed by the caller.
@@ -263,8 +316,7 @@ async def put_dict(self, dict_obj: Mapping[str, str], *,
         :param scope_prefix_map: The scope map used to mangle the prefix for the config scope.
         :return:
         """
-        scope_prefix_map = ChainMap(scope_prefix_map or {}, self.scope_prefix_map)
-        scope_prefix = scope_prefix_map[scope]
+        scope_prefix = self._merge_scope_prefix_map(scope_prefix_map)[scope]
         return await self.loop.run_in_executor(
             self.executor,
             lambda: self.etcd_sync.transaction(
@@ -276,10 +328,13 @@ async def put_dict(self, dict_obj: Mapping[str, str], *,
             ))

     @reconn_reauth_adaptor
-    async def get(self, key: str, *,
-                  scope: ConfigScopes = ConfigScopes.MERGED,
-                  scope_prefix_map: Mapping[ConfigScopes, str] = None) \
-            -> Optional[str]:
+    async def get(
+        self,
+        key: str,
+        *,
+        scope: ConfigScopes = ConfigScopes.MERGED,
+        scope_prefix_map: Mapping[ConfigScopes, str] = None,
+    ) -> Optional[str]:
         """
         Get a single key from the etcd.
         Returns ``None`` if the key does not exist.
@@ -298,22 +353,22 @@ async def get_impl(key: str) -> Optional[str]:
                 lambda: self.etcd_sync.get(mangled_key))
             return val.decode(self.encoding) if val is not None else None

-        scope_prefix_map = ChainMap(scope_prefix_map or {}, self.scope_prefix_map)
+        _scope_prefix_map = self._merge_scope_prefix_map(scope_prefix_map)
         if scope == ConfigScopes.MERGED or scope == ConfigScopes.NODE:
-            scope_prefixes = [scope_prefix_map[ConfigScopes.GLOBAL]]
-            p = scope_prefix_map.get(ConfigScopes.SGROUP)
+            scope_prefixes = [_scope_prefix_map[ConfigScopes.GLOBAL]]
+            p = _scope_prefix_map.get(ConfigScopes.SGROUP)
             if p is not None:
                 scope_prefixes.insert(0, p)
-            p = scope_prefix_map.get(ConfigScopes.NODE)
+            p = _scope_prefix_map.get(ConfigScopes.NODE)
             if p is not None:
                 scope_prefixes.insert(0, p)
         elif scope == ConfigScopes.SGROUP:
-            scope_prefixes = [scope_prefix_map[ConfigScopes.GLOBAL]]
-            p = scope_prefix_map.get(ConfigScopes.SGROUP)
+            scope_prefixes = [_scope_prefix_map[ConfigScopes.GLOBAL]]
+            p = _scope_prefix_map.get(ConfigScopes.SGROUP)
             if p is not None:
                 scope_prefixes.insert(0, p)
         elif scope == ConfigScopes.GLOBAL:
-            scope_prefixes = [scope_prefix_map[ConfigScopes.GLOBAL]]
+            scope_prefixes = [_scope_prefix_map[ConfigScopes.GLOBAL]]
         else:
             raise ValueError('Invalid scope prefix value')
         values = await asyncio.gather(*[
@@ -328,10 +383,13 @@ async def get_impl(key: str) -> Optional[str]:
         return value

     @reconn_reauth_adaptor
-    async def get_prefix(self, key_prefix: str,
-                         scope: ConfigScopes = ConfigScopes.MERGED,
-                         scope_prefix_map: Mapping[ConfigScopes, str] = None) \
-            -> Mapping[str, Optional[str]]:
+    async def get_prefix(
+        self,
+        key_prefix: str,
+        *,
+        scope: ConfigScopes = ConfigScopes.MERGED,
+        scope_prefix_map: Mapping[ConfigScopes, str] = None,
+    ) -> Mapping[str, Optional[str]]:
         """
         Retrieves all key-value pairs under the given key prefix as a nested dictionary.
         All dictionary keys are automatically unquoted.
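The `_merge_scope_prefix_map()` helper introduced above centralizes the ChainMap construction that these hunks keep replacing. The behavior it relies on is ChainMap's left-to-right lookup order: a caller-supplied override shadows the instance defaults without copying or mutating either mapping. A minimal standalone illustration, with made-up prefix values rather than the real ConfigScopes keys:

    from collections import ChainMap

    defaults = {"global": "config/", "sgroup": "sgroup/default/"}
    override = {"sgroup": "sgroup/gpu/"}

    merged = ChainMap(override, defaults)
    print(merged["sgroup"])  # 'sgroup/gpu/'  -- the override wins
    print(merged["global"])  # 'config/'      -- falls back to the defaults
    print(defaults)          # unchanged; ChainMap only views its maps, it does not copy them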
@@ -376,22 +434,22 @@ async def get_prefix_impl(key_prefix: str) -> Iterable[Tuple[str, str]]:
                  t[0].decode(self.encoding))
                 for t in results)

-        scope_prefix_map = ChainMap(scope_prefix_map or {}, self.scope_prefix_map)
+        _scope_prefix_map = self._merge_scope_prefix_map(scope_prefix_map)
         if scope == ConfigScopes.MERGED or scope == ConfigScopes.NODE:
-            scope_prefixes = [scope_prefix_map[ConfigScopes.GLOBAL]]
-            p = scope_prefix_map.get(ConfigScopes.SGROUP)
+            scope_prefixes = [_scope_prefix_map[ConfigScopes.GLOBAL]]
+            p = _scope_prefix_map.get(ConfigScopes.SGROUP)
             if p is not None:
                 scope_prefixes.insert(0, p)
-            p = scope_prefix_map.get(ConfigScopes.NODE)
+            p = _scope_prefix_map.get(ConfigScopes.NODE)
             if p is not None:
                 scope_prefixes.insert(0, p)
         elif scope == ConfigScopes.SGROUP:
-            scope_prefixes = [scope_prefix_map[ConfigScopes.GLOBAL]]
-            p = scope_prefix_map.get(ConfigScopes.SGROUP)
+            scope_prefixes = [_scope_prefix_map[ConfigScopes.GLOBAL]]
+            p = _scope_prefix_map.get(ConfigScopes.SGROUP)
             if p is not None:
                 scope_prefixes.insert(0, p)
         elif scope == ConfigScopes.GLOBAL:
-            scope_prefixes = [scope_prefix_map[ConfigScopes.GLOBAL]]
+            scope_prefixes = [_scope_prefix_map[ConfigScopes.GLOBAL]]
         else:
             raise ValueError('Invalid scope prefix value')
         pair_sets = await asyncio.gather(*[
@@ -408,11 +466,16 @@ async def get_prefix_impl(key_prefix: str) -> Iterable[Tuple[str, str]]:
     get_prefix_dict = get_prefix

     @reconn_reauth_adaptor
-    async def replace(self, key: str, initial_val: str, new_val: str, *,
-                      scope: ConfigScopes = ConfigScopes.GLOBAL,
-                      scope_prefix_map: Mapping[ConfigScopes, str] = None) -> bool:
-        scope_prefix_map = ChainMap(scope_prefix_map or {}, self.scope_prefix_map)
-        scope_prefix = scope_prefix_map[scope]
+    async def replace(
+        self,
+        key: str,
+        initial_val: str,
+        new_val: str,
+        *,
+        scope: ConfigScopes = ConfigScopes.GLOBAL,
+        scope_prefix_map: Mapping[ConfigScopes, str] = None,
+    ) -> bool:
+        scope_prefix = self._merge_scope_prefix_map(scope_prefix_map)[scope]
         mangled_key = self._mangle_key(f'{_slash(scope_prefix)}{key}')
         success = await self.loop.run_in_executor(
             self.executor,
@@ -420,22 +483,28 @@ async def replace(self, key: str, initial_val: str, new_val: str, *,
         return success

     @reconn_reauth_adaptor
-    async def delete(self, key: str, *,
-                     scope: ConfigScopes = ConfigScopes.GLOBAL,
-                     scope_prefix_map: Mapping[ConfigScopes, str] = None):
-        scope_prefix_map = ChainMap(scope_prefix_map or {}, self.scope_prefix_map)
-        scope_prefix = scope_prefix_map[scope]
+    async def delete(
+        self,
+        key: str,
+        *,
+        scope: ConfigScopes = ConfigScopes.GLOBAL,
+        scope_prefix_map: Mapping[ConfigScopes, str] = None,
+    ):
+        scope_prefix = self._merge_scope_prefix_map(scope_prefix_map)[scope]
         mangled_key = self._mangle_key(f'{_slash(scope_prefix)}{key}')
         return await self.loop.run_in_executor(
             self.executor,
             lambda: self.etcd_sync.delete(mangled_key))

     @reconn_reauth_adaptor
-    async def delete_multi(self, keys: Iterable[str], *,
-                           scope: ConfigScopes = ConfigScopes.GLOBAL,
-                           scope_prefix_map: Mapping[ConfigScopes, str] = None):
-        scope_prefix_map = ChainMap(scope_prefix_map or {}, self.scope_prefix_map)
-        scope_prefix = scope_prefix_map[scope]
+    async def delete_multi(
+        self,
+        keys: Iterable[str],
+        *,
+        scope: ConfigScopes = ConfigScopes.GLOBAL,
+        scope_prefix_map: Mapping[ConfigScopes, str] = None,
+    ):
+        scope_prefix = self._merge_scope_prefix_map(scope_prefix_map)[scope]
         return await self.loop.run_in_executor(
             self.executor,
             lambda: self.etcd_sync.transaction(
@@ -446,17 +515,24 @@ async def delete_multi(self, keys: Iterable[str], *,
             ))

     @reconn_reauth_adaptor
-    async def delete_prefix(self, key_prefix: str, *,
-                            scope: ConfigScopes = ConfigScopes.GLOBAL,
-                            scope_prefix_map: Mapping[ConfigScopes, str] = None):
-        scope_prefix_map = ChainMap(scope_prefix_map or {}, self.scope_prefix_map)
-        scope_prefix = scope_prefix_map[scope]
+    async def delete_prefix(
+        self,
+        key_prefix: str,
+        *,
+        scope: ConfigScopes = ConfigScopes.GLOBAL,
+        scope_prefix_map: Mapping[ConfigScopes, str] = None,
+    ):
+        scope_prefix = self._merge_scope_prefix_map(scope_prefix_map)[scope]
         mangled_key_prefix = self._mangle_key(f'{_slash(scope_prefix)}{key_prefix}')
         return await self.loop.run_in_executor(
             self.executor,
             lambda: self.etcd_sync.delete_prefix(mangled_key_prefix))

-    def _watch_cb(self, queue: asyncio.Queue, resp: etcd3.watch.WatchResponse) -> None:
+    def _watch_cb(
+        self,
+        queue: asyncio.Queue,
+        resp: etcd3.watch.WatchResponse,
+    ) -> None:
         if isinstance(resp, grpc.RpcError):
             if (
                 resp.code() == grpc.StatusCode.UNAVAILABLE or
@@ -543,8 +619,7 @@ async def watch(
         cleanup_event: asyncio.Event = None,
         wait_timeout: float = None,
     ) -> AsyncGenerator[Union[QueueSentinel, Event], None]:
-        scope_prefix_map = ChainMap(scope_prefix_map or {}, self.scope_prefix_map)
-        scope_prefix = scope_prefix_map[scope]
+        scope_prefix = self._merge_scope_prefix_map(scope_prefix_map)[scope]
         scope_prefix_len = len(f'{_slash(scope_prefix)}')
         mangled_key = self._mangle_key(f'{_slash(scope_prefix)}{key}')
         # NOTE: yield from in async-generator is not supported.
@@ -582,8 +657,7 @@ async def watch_prefix(
         cleanup_event: asyncio.Event = None,
         wait_timeout: float = None,
     ) -> AsyncGenerator[Union[QueueSentinel, Event], None]:
-        scope_prefix_map = ChainMap(scope_prefix_map or {}, self.scope_prefix_map)
-        scope_prefix = scope_prefix_map[scope]
+        scope_prefix = self._merge_scope_prefix_map(scope_prefix_map)[scope]
         scope_prefix_len = len(f'{_slash(scope_prefix)}')
         mangled_key_prefix = self._mangle_key(f'{_slash(scope_prefix)}{key_prefix}')
         while True:
diff --git a/src/ai/backend/common/utils.py b/src/ai/backend/common/utils.py
index f2118fa4..8942fb12 100644
--- a/src/ai/backend/common/utils.py
+++ b/src/ai/backend/common/utils.py
@@ -222,12 +222,7 @@ def __and__(self, other):
             return self.value == other
         raise TypeError

-    def __rand__(self, other):
-        if isinstance(other, (set, frozenset)):
-            return self.value in other
-        if isinstance(other, str):
-            return self.value == other
-        raise TypeError
+    __rand__ = __and__

     def __xor__(self, other):
         if isinstance(other, (set, frozenset)):
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 24b5e4c3..7907c518 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -26,22 +26,22 @@
 )


-def test_odict():
+def test_odict() -> None:
     assert odict(('a', 1), ('b', 2)) == OrderedDict([('a', 1), ('b', 2)])


-def test_dict2kvlist():
+def test_dict2kvlist() -> None:
     ret = list(dict2kvlist({'a': 1, 'b': 2}))
     assert set(ret) == {'a', 1, 'b', 2}


-def test_generate_uuid():
+def test_generate_uuid() -> None:
     u = generate_uuid()
     assert len(u) == 22
     assert isinstance(u, str)


-def test_random_seq():
+def test_random_seq() -> None:
     assert [*get_random_seq(10, 11, 1)] == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
     assert [*get_random_seq(10, 6, 2)] == [0, 2, 4, 6, 8, 10]
     with pytest.raises(AssertionError):
@@ -57,7 +57,7 @@ def test_random_seq():
         assert x > last_x + 1


-def test_nmget():
+def test_nmget() -> None:
     o = {'a': {'b': 1}, 'x': None}
     assert nmget(o, 'a', 0) == {'b': 1}
     assert nmget(o, 'a.b', 0) == 1
@@ -68,7 +68,7 @@ def test_nmget():
     assert nmget(o, 'x', 0, null_as_default=False) is None


-def test_readable_size_to_bytes():
+def test_readable_size_to_bytes() -> None:
     assert readable_size_to_bytes(2) == 2
     assert readable_size_to_bytes('2') == 2
     assert readable_size_to_bytes('2K') == 2 * (2 ** 10)
@@ -93,7 +93,7 @@ def test_readable_size_to_bytes():
         readable_size_to_bytes('TT')


-def test_str_to_timedelta():
+def test_str_to_timedelta() -> None:
     assert str_to_timedelta('1d2h3m4s') == timedelta(days=1, hours=2, minutes=3, seconds=4)
     assert str_to_timedelta('1d2h3m') == timedelta(days=1, hours=2, minutes=3)
     assert str_to_timedelta('1d2h') == timedelta(days=1, hours=2)
@@ -128,7 +128,7 @@ def test_str_to_timedelta():


 @pytest.mark.asyncio
-async def test_curl_returns_stripped_body(mocker):
+async def test_curl_returns_stripped_body(mocker) -> None:
     mock_get = mocker.patch.object(aiohttp.ClientSession, 'get')
     mock_resp = {'status': 200, 'text': mock_corofunc(b'success ')}
     mock_get.return_value = AsyncContextManagerMock(**mock_resp)
@@ -140,7 +140,7 @@ async def test_curl_returns_stripped_body(mocker):


 @pytest.mark.asyncio
-async def test_curl_returns_default_value_if_not_success(mocker):
+async def test_curl_returns_default_value_if_not_success(mocker) -> None:
     mock_get = mocker.patch.object(aiohttp.ClientSession, 'get')
     mock_resp = {'status': 400, 'text': mock_corofunc(b'bad request')}
     mock_get.return_value = AsyncContextManagerMock(**mock_resp)
@@ -154,9 +154,11 @@ async def test_curl_returns_default_value_if_not_success(mocker):
     assert resp == 'default'


-def test_string_set_flag():
+def test_string_set_flag() -> None:

-    class MyFlags(StringSetFlag):
+    # FIXME: Remove "type: ignore" when mypy gets released with
+    # python/mypy#11579.
+    class MyFlags(StringSetFlag):  # type: ignore
         A = 'a'
         B = 'b'
@@ -193,14 +195,14 @@ class MyFlags(StringSetFlag):


 class TestAsyncBarrier:

-    def test_async_barrier_initialization(self):
+    def test_async_barrier_initialization(self) -> None:
         barrier = AsyncBarrier(num_parties=5)
         assert barrier.num_parties == 5
         assert barrier.cond is not None  # default condition

     @pytest.mark.asyncio
-    async def test_wait_notify_all_if_cound_eq_num_parties(self, mocker):
+    async def test_wait_notify_all_if_cound_eq_num_parties(self, mocker) -> None:
         mock_cond = mocker.patch.object(asyncio, 'Condition')
         mock_resp = {
             'notify_all': mock.Mock(),
@@ -214,8 +216,9 @@ async def test_wait_notify_all_if_cound_eq_num_parties(self, mocker):

         await barrier.wait()
         assert barrier.count == 1
-        mock_cond.return_value.notify_all.assert_called_once_with()
-        mock_cond.return_value.wait.assert_not_called()
+        # The methods are added at runtime.
+        mock_cond.return_value.notify_all.assert_called_once_with()  # type: ignore
+        mock_cond.return_value.wait.assert_not_called()  # type: ignore

     def test_async_barrier_reset(self):
         barrier = AsyncBarrier(num_parties=5)
@@ -227,7 +230,7 @@ def test_async_barrier_reset(self):


 @pytest.mark.asyncio
-async def test_run_through():
+async def test_run_through() -> None:
     i = 0
@@ -270,7 +273,7 @@ def do_sync():


 @pytest.mark.asyncio
-async def test_async_file_writer_str():
+async def test_async_file_writer_str() -> None:
     # 1. Get temporary filename
     with NamedTemporaryFile() as temp_file:
         file_name = temp_file.name
@@ -298,7 +301,7 @@ async def test_async_file_writer_str():


 @pytest.mark.asyncio
-async def test_async_file_writer_bytes():
+async def test_async_file_writer_bytes() -> None:
     # 1. Get temporary filename
     with NamedTemporaryFile() as temp_file:
         file_name = temp_file.name
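The `__rand__ = __and__` change in utils.py above works because the operator's checks are symmetric in the operand order for the types StringSetFlag accepts (set membership, string equality), so the reflected method can simply reuse the forward one. A standalone sketch with a made-up stand-in class, not the real StringSetFlag:

    class Flag:
        def __init__(self, value: str) -> None:
            self.value = value

        def __and__(self, other):
            if isinstance(other, (set, frozenset)):
                return self.value in other
            if isinstance(other, str):
                return self.value == other
            raise TypeError

        # The reflected operator reuses the same logic, as in the patch.
        __rand__ = __and__


    f = Flag("a")
    print(f & {"a", "b"})  # True
    print({"a", "b"} & f)  # True: set.__and__ returns NotImplemented, so Python calls f.__rand__
    print("a" & f)         # True: str has no __and__, so Python calls f.__rand__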