
Add type hints to expiring cache. #9730

Merged 2 commits on Apr 6, 2021
1 change: 1 addition & 0 deletions changelog.d/9730.misc
@@ -0,0 +1 @@
Add type hints to expiring cache.
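The heart of the PR is parameterising ExpiringCache over its key and value types. For readers unfamiliar with the pattern, here is a minimal, self-contained sketch of the Generic/TypeVar technique the PR adopts (a toy class, not Synapse's implementation):

from typing import Dict, Generic, Optional, TypeVar

KT = TypeVar("KT")
VT = TypeVar("VT")


class TypedCache(Generic[KT, VT]):
    """Toy cache: mypy checks keys against KT and values against VT."""

    def __init__(self) -> None:
        self._data = {}  # type: Dict[KT, VT]

    def __setitem__(self, key: KT, value: VT) -> None:
        self._data[key] = value

    def get(self, key: KT) -> Optional[VT]:
        return self._data.get(key)


cache = TypedCache()  # type: TypedCache[str, int]
cache["hits"] = 1
hits = cache.get("hits")  # mypy infers Optional[int]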
2 changes: 1 addition & 1 deletion synapse/federation/federation_client.py
@@ -102,7 +102,7 @@ def __init__(self, hs: "HomeServer"):
max_len=1000,
expiry_ms=120 * 1000,
reset_expiry_on_get=False,
- )
+ )  # type: ExpiringCache[str, EventBase]

def _clear_tried_cache(self):
"""Clear pdu_destination_tried cache"""
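Note the annotation style: the codebase's convention at the time was comment-style `# type:` annotations on the closing parenthesis rather than PEP 526 syntax. Both spellings mean the same thing to a type checker; a sketch (the helper function and cache name are made up for illustration):

from synapse.events import EventBase
from synapse.util import Clock
from synapse.util.caches.expiringcache import ExpiringCache


def build_pdu_cache(clock: Clock):
    # Comment-style annotation, as used throughout this PR:
    cache = ExpiringCache(
        cache_name="pdu_cache",
        clock=clock,
        max_len=1000,
        expiry_ms=120 * 1000,
    )  # type: ExpiringCache[str, EventBase]

    # PEP 526 equivalent on newer code:
    # cache: ExpiringCache[str, EventBase] = ExpiringCache(...)
    return cache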
4 changes: 2 additions & 2 deletions synapse/handlers/device.py
@@ -631,7 +631,7 @@ def __init__(self, hs: "HomeServer", device_handler: DeviceHandler):
max_len=10000,
expiry_ms=30 * 60 * 1000,
iterable=True,
- )
+ )  # type: ExpiringCache[str, Set[str]]

# Attempt to resync out of sync device lists every 30s.
self._resync_retry_in_progress = False
@@ -760,7 +760,7 @@ async def _need_to_do_resync(
"""Given a list of updates for a user figure out if we need to do a full
resync, or whether we have enough data that we can just apply the delta.
"""
- seen_updates = self._seen_updates.get(user_id, set())
+ seen_updates = self._seen_updates.get(user_id, set())  # type: Set[str]

extremity = await self.store.get_device_list_last_stream_id_for_remote(user_id)

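The `iterable=True` flag matters for the `ExpiringCache[str, Set[str]]` annotation above: the cache's length is the total number of stream ids across all users, not the number of users, so `max_len=10000` caps the total number of cached stream ids. A standalone toy illustrating that sizing rule (not the Synapse class):

from typing import Dict, Set


class IterableSized:
    """Mimics ExpiringCache(iterable=True) sizing: len() sums value sizes."""

    def __init__(self) -> None:
        self._data = {}  # type: Dict[str, Set[str]]

    def __setitem__(self, key: str, value: Set[str]) -> None:
        self._data[key] = value

    def __len__(self) -> int:
        return sum(len(v) for v in self._data.values())


d = IterableSized()
d["@alice:example.org"] = {"s1", "s2"}
d["@bob:example.org"] = {"s3"}
assert len(d) == 3  # three seen stream ids across two users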
12 changes: 0 additions & 12 deletions synapse/handlers/e2e_keys.py
@@ -38,7 +38,6 @@
)
from synapse.util import json_decoder, unwrapFirstError
from synapse.util.async_helpers import Linearizer
- from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.retryutils import NotRetryingDestination

if TYPE_CHECKING:
@@ -1292,17 +1291,6 @@ def __init__(self, hs: "HomeServer", e2e_keys_handler: E2eKeysHandler):
# user_id -> list of updates waiting to be handled.
self._pending_updates = {} # type: Dict[str, List[Tuple[JsonDict, JsonDict]]]

- # Recently seen stream ids. We don't bother keeping these in the DB,
- # but they're useful to have them about to reduce the number of spurious
- # resyncs.
- self._seen_updates = ExpiringCache(
-     cache_name="signing_key_update_edu",
-     clock=self.clock,
-     max_len=10000,
-     expiry_ms=30 * 60 * 1000,
-     iterable=True,
- )

async def incoming_signing_key_update(
self, origin: str, edu_content: JsonDict
) -> None:
10 changes: 6 additions & 4 deletions synapse/handlers/sync.py
@@ -252,13 +252,13 @@ def __init__(self, hs: "HomeServer"):
self.storage = hs.get_storage()
self.state_store = self.storage.state

- # ExpiringCache((User, Device)) -> LruCache(state_key => event_id)
+ # ExpiringCache((User, Device)) -> LruCache(user_id => event_id)
self.lazy_loaded_members_cache = ExpiringCache(
"lazy_loaded_members_cache",
self.clock,
max_len=0,
expiry_ms=LAZY_LOADED_MEMBERS_CACHE_MAX_AGE,
- )
+ )  # type: ExpiringCache[Tuple[str, Optional[str]], LruCache[str, str]]

async def wait_for_sync_for_user(
self,
@@ -733,8 +733,10 @@ async def compute_summary(

def get_lazy_loaded_members_cache(
self, cache_key: Tuple[str, Optional[str]]
- ) -> LruCache:
- cache = self.lazy_loaded_members_cache.get(cache_key)
+ ) -> LruCache[str, str]:
+ cache = self.lazy_loaded_members_cache.get(
+     cache_key
+ )  # type: Optional[LruCache[str, str]]
if cache is None:
logger.debug("creating LruCache for %r", cache_key)
cache = LruCache(LAZY_LOADED_MEMBERS_CACHE_MAX_SIZE)
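The sync change is the one place a cache value is itself a cache: the outer ExpiringCache maps a (user_id, device_id) pair to a per-device LruCache, and the explicit `# type: Optional[...]` on the `.get()` call is needed because a miss returns None. A sketch of the same get-or-create pattern (helper name and cache size are illustrative):

from typing import Optional, Tuple

from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.caches.lrucache import LruCache


def member_cache_for(
    outer: "ExpiringCache[Tuple[str, Optional[str]], LruCache[str, str]]",
    key: Tuple[str, Optional[str]],
) -> "LruCache[str, str]":
    cache = outer.get(key)  # type: Optional[LruCache[str, str]]
    if cache is None:
        cache = LruCache(100)  # cap chosen arbitrarily for the sketch
        outer[key] = cache
    return cache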
2 changes: 1 addition & 1 deletion synapse/rest/media/v1/preview_url_resource.py
@@ -175,7 +175,7 @@ def __init__(
clock=self.clock,
# don't spider URLs more often than once an hour
expiry_ms=ONE_HOUR,
- )
+ )  # type: ExpiringCache[str, ObservableDeferred]

if self._worker_run_media_background_jobs:
self._cleaner_loop = self.clock.looping_call(
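Here the cached values are ObservableDeferreds: the pending or recent download for a URL is kept for an hour so that requests for the same URL share one result rather than re-spidering it. A rough asyncio analogue of the pattern (a plain dict stands in for the expiring cache; not Synapse's Twisted code):

import asyncio
from typing import Dict

_inflight = {}  # type: Dict[str, "asyncio.Task[str]"]


async def fetch(url: str) -> str:
    await asyncio.sleep(0.1)  # stand-in for the real download
    return "<html>...</html>"


async def preview(url: str) -> str:
    # Concurrent callers for the same URL share one underlying fetch,
    # the role ObservableDeferred plays in the cache above.
    task = _inflight.get(url)
    if task is None:
        task = asyncio.create_task(fetch(url))
        _inflight[url] = task
    return await task


async def main() -> None:
    a, b = await asyncio.gather(
        preview("https://example.com"), preview("https://example.com")
    )
    assert a == b  # one download served both callers


asyncio.run(main())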
5 changes: 3 additions & 2 deletions synapse/state/__init__.py
@@ -22,6 +22,7 @@
Callable,
DefaultDict,
Dict,
+ FrozenSet,
Iterable,
List,
Optional,
@@ -515,7 +516,7 @@ def __init__(self, hs):
expiry_ms=EVICTION_TIMEOUT_SECONDS * 1000,
iterable=True,
reset_expiry_on_get=True,
- )
+ )  # type: ExpiringCache[FrozenSet[int], _StateCacheEntry]

#
# stuff for tracking time spent on state-res by room
@@ -536,7 +537,7 @@ async def resolve_state_groups(
state_groups_ids: Dict[int, StateMap[str]],
event_map: Optional[Dict[str, EventBase]],
state_res_store: "StateResolutionStore",
- ):
+ ) -> _StateCacheEntry:
"""Resolves conflicts between a set of state groups

Always generates a new state group (unless we hit the cache), so should
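The key type here is FrozenSet[int] because a state-resolution request is identified by the unordered set of state group ids involved, and the cache's backing dict requires hashable keys, which rules out a plain set:

key = frozenset([12, 97, 430])
assert key == frozenset([430, 12, 97])  # order-insensitive
cache = {key: "resolved state"}         # hashable, so usable as a key
# cache[{12, 97, 430}] would raise TypeError: unhashable type: 'set'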
83 changes: 51 additions & 32 deletions synapse/util/caches/expiringcache.py
@@ -15,40 +15,50 @@

import logging
from collections import OrderedDict
+ from typing import Any, Generic, Optional, TypeVar, Union, overload

import attr
+ from typing_extensions import Literal

from synapse.config import cache as cache_config
from synapse.metrics.background_process_metrics import run_as_background_process
+ from synapse.util import Clock
from synapse.util.caches import register_cache

logger = logging.getLogger(__name__)


- SENTINEL = object()
+ SENTINEL = object()  # type: Any


+ T = TypeVar("T")
+ KT = TypeVar("KT")
+ VT = TypeVar("VT")

- class ExpiringCache:
+ class ExpiringCache(Generic[KT, VT]):
def __init__(
self,
- cache_name,
- clock,
- max_len=0,
- expiry_ms=0,
- reset_expiry_on_get=False,
- iterable=False,
+ cache_name: str,
+ clock: Clock,
+ max_len: int = 0,
+ expiry_ms: int = 0,
+ reset_expiry_on_get: bool = False,
+ iterable: bool = False,
):
"""
Args:
- cache_name (str): Name of this cache, used for logging.
- clock (Clock)
- max_len (int): Max size of dict. If the dict grows larger than this
+ cache_name: Name of this cache, used for logging.
+ clock
+ max_len: Max size of dict. If the dict grows larger than this
then the oldest items get automatically evicted. Default is 0,
which indicates there is no max limit.
- expiry_ms (int): How long before an item is evicted from the cache
+ expiry_ms: How long before an item is evicted from the cache
in milliseconds. Default is 0, indicating items never get
evicted based on time.
- reset_expiry_on_get (bool): If true, will reset the expiry time for
+ reset_expiry_on_get: If true, will reset the expiry time for
an item on access. Defaults to False.
- iterable (bool): If true, the size is calculated by summing the
+ iterable: If true, the size is calculated by summing the
sizes of all entries, rather than the number of entries.
"""
self._cache_name = cache_name
@@ -62,7 +72,7 @@ def __init__(
self._expiry_ms = expiry_ms
self._reset_expiry_on_get = reset_expiry_on_get

- self._cache = OrderedDict()
+ self._cache = OrderedDict()  # type: OrderedDict[KT, _CacheEntry]

self.iterable = iterable

@@ -79,12 +89,12 @@ def f():

self._clock.looping_call(f, self._expiry_ms / 2)

- def __setitem__(self, key, value):
+ def __setitem__(self, key: KT, value: VT) -> None:
now = self._clock.time_msec()
self._cache[key] = _CacheEntry(now, value)
self.evict()

- def evict(self):
+ def evict(self) -> None:
# Evict if there are now too many items
while self._max_size and len(self) > self._max_size:
_key, value = self._cache.popitem(last=False)
@@ -93,7 +103,7 @@ def evict(self):
else:
self.metrics.inc_evictions()

- def __getitem__(self, key):
+ def __getitem__(self, key: KT) -> VT:
try:
entry = self._cache[key]
self.metrics.inc_hits()
@@ -106,7 +116,7 @@ def __getitem__(self, key):

return entry.value

- def pop(self, key, default=SENTINEL):
+ def pop(self, key: KT, default: T = SENTINEL) -> Union[VT, T]:
"""Removes and returns the value with the given key from the cache.

If the key isn't in the cache then `default` will be returned if
@@ -115,29 +125,40 @@ def pop(self, key, default=SENTINEL):
Identical functionality to `dict.pop(..)`.
"""

- value = self._cache.pop(key, default)
+ value = self._cache.pop(key, SENTINEL)
+ # The key was not found.
if value is SENTINEL:
-     raise KeyError(key)
+     if default is SENTINEL:
+         raise KeyError(key)
+     return default

- return value
+ return value.value

- def __contains__(self, key):
+ def __contains__(self, key: KT) -> bool:
return key in self._cache

- def get(self, key, default=None):
+ @overload
+ def get(self, key: KT, default: Literal[None] = None) -> Optional[VT]:
+     ...
+
+ @overload
+ def get(self, key: KT, default: T) -> Union[VT, T]:
+     ...
+
+ def get(self, key: KT, default: Optional[T] = None) -> Union[VT, Optional[T]]:
try:
return self[key]
except KeyError:
return default

- def setdefault(self, key, value):
+ def setdefault(self, key: KT, value: VT) -> VT:
try:
return self[key]
except KeyError:
self[key] = value
return value

- def _prune_cache(self):
+ def _prune_cache(self) -> None:
if not self._expiry_ms:
# zero expiry time means don't expire. This should never get called
# since we have this check in start too.
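Two behavioural subtleties in the hunk above are worth spelling out. First, the old pop() returned the raw _CacheEntry wrapper on a hit rather than the stored value; the rewrite unwraps .value and uses the module-level SENTINEL to tell "no default supplied" apart from an explicit None. Second, the @overload pair teaches mypy that get(key) may return None while get(key, some_default) cannot. A condensed, self-contained sketch of both patterns (a toy class, not the real one):

from typing import Any, Dict, Optional, TypeVar, Union, overload

from typing_extensions import Literal

T = TypeVar("T")

_SENTINEL = object()  # type: Any  # distinguishes "no default" from None


class Demo:
    def __init__(self) -> None:
        self._data = {}  # type: Dict[str, int]

    def pop(self, key: str, default: T = _SENTINEL) -> Union[int, T]:
        value = self._data.pop(key, _SENTINEL)
        if value is _SENTINEL:
            if default is _SENTINEL:
                raise KeyError(key)
            return default
        return value

    @overload
    def get(self, key: str, default: Literal[None] = None) -> Optional[int]:
        ...

    @overload
    def get(self, key: str, default: T) -> Union[int, T]:
        ...

    def get(self, key, default=None):
        try:
            return self._data[key]
        except KeyError:
            return default


d = Demo()
assert d.pop("missing", 0) == 0
d.get("missing")         # mypy: Optional[int]
d.get("missing", set())  # mypy: Union[int, set], never None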
@@ -166,7 +187,7 @@ def _prune_cache(self):
len(self),
)

- def __len__(self):
+ def __len__(self) -> int:
if self.iterable:
return sum(len(entry.value) for entry in self._cache.values())
else:
@@ -190,9 +211,7 @@ def set_cache_factor(self, factor: float) -> bool:
return False


+ @attr.s(slots=True)
class _CacheEntry:
-     __slots__ = ["time", "value"]
-
-     def __init__(self, time, value):
-         self.time = time
-         self.value = value
+     time = attr.ib(type=int)
+     value = attr.ib()
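Finally, _CacheEntry trades a hand-written __slots__ class for attrs: @attr.s(slots=True) generates __init__ and the slots declaration with the same memory profile. A quick sketch of the generated behaviour, using the same older attr.s API as the file (class name changed to avoid clashing with the real one):

import attr


@attr.s(slots=True)
class _Entry:
    time = attr.ib(type=int)
    value = attr.ib()


e = _Entry(time=1000, value={"k": "v"})
assert e.time == 1000
# slots=True means no per-instance __dict__, so
# e.extra = 1 would raise AttributeError.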