Skip to content

Commit

Permalink
Prevent cached callables being invoked multiple times when multithreading
Browse files Browse the repository at this point in the history
  • Loading branch information
Jack Brown committed Oct 5, 2021
1 parent 8b9bbc4 commit 220bfab
Show file tree
Hide file tree
Showing 4 changed files with 108 additions and 42 deletions.
70 changes: 48 additions & 22 deletions src/cachetools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import functools
import random
import time
from threading import RLock

from .keys import hashkey

Expand Down Expand Up @@ -526,20 +527,33 @@ def wrapper(*args, **kwargs):

else:

_key_level_locks = {}

def wrapper(*args, **kwargs):
k = key(*args, **kwargs)
try:
with lock:
with lock:
try:
return cache[k]
except KeyError:
pass # key not found
v = func(*args, **kwargs)
# in case of a race, prefer the item already in the cache
try:
with lock:
return cache.setdefault(k, v)
except ValueError:
return v # value too large
except KeyError:
pass # key not found
_key_level_lock = _key_level_locks.setdefault(k, RLock())
# Only one caller may generate or retrieve a value for this key
with _key_level_lock:
try:
with lock:
# In the case that this thread was blocked, retrieve and return the now computed value
return cache[k]
except KeyError:
# Otherwise compute it on this thread since the key is not in the cache
v = func(*args, **kwargs)
finally:
with lock:
_key_level_locks.pop(k)
try:
with lock:
return cache.setdefault(k, v)
except ValueError:
return v # value too large

return functools.update_wrapper(wrapper, func)

Expand Down Expand Up @@ -572,24 +586,36 @@ def wrapper(self, *args, **kwargs):
return v

else:
_key_level_locks = {}

def wrapper(self, *args, **kwargs):
c = cache(self)
if c is None:
return method(self, *args, **kwargs)
k = key(*args, **kwargs)
try:
with lock(self):
with lock(self):
try:
return c[k]
except KeyError:
pass # key not found
v = method(self, *args, **kwargs)
# in case of a race, prefer the item already in the cache
try:
with lock(self):
return c.setdefault(k, v)
except ValueError:
return v # value too large
except KeyError:
pass
_key_level_lock = _key_level_locks.setdefault(k, RLock())
# Only one caller may generate or retrieve a value for this key
with _key_level_lock:
try:
with lock(self):
# In the case that this thread was blocked, retrieve and return the now computed value
return c[k]
except KeyError:
# Otherwise compute it on this thread since the key is not in the cache
v = method(self, *args, **kwargs)
finally:
with lock(self):
_key_level_locks.pop(k)
try:
with lock(self):
return c.setdefault(k, v)
except ValueError:
return v # value too large

return functools.update_wrapper(wrapper, method)

Expand Down
24 changes: 17 additions & 7 deletions src/cachetools/func.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def _cache(cache, typed):
def decorator(func):
key = keys.typedkey if typed else keys.hashkey
lock = RLock()
key_level_locks = {}
stats = [0, 0]

def wrapper(*args, **kwargs):
Expand All @@ -58,13 +59,22 @@ def wrapper(*args, **kwargs):
return v
except KeyError:
stats[1] += 1
v = func(*args, **kwargs)
# in case of a race, prefer the item already in the cache
try:
with lock:
return cache.setdefault(k, v)
except ValueError:
return v # value too large
_key_level_lock = key_level_locks.setdefault(k, RLock())
# Only one caller may generate or retrieve a value for this key
with _key_level_lock:
try:
# In the case that this thread was blocked, retrieve and return the now computed value
return cache[k]
except KeyError:
# Otherwise compute it on this thread since the key is not in the cache
v = func(*args, **kwargs)
finally:
with lock:
key_level_locks.pop(k)
try:
return cache.setdefault(k,v)
except ValueError:
return v # value too large

def cache_info():
with lock:
Expand Down
18 changes: 9 additions & 9 deletions tests/test_method.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,11 +140,11 @@ def __add__(self, other):
def test_locked_dict(self):
cached = Locked({})

self.assertEqual(cached.get(0), 1)
self.assertEqual(cached.get(1), 3)
self.assertEqual(cached.get(1), 3)
self.assertEqual(cached.get(1.0), 3)
self.assertEqual(cached.get(2.0), 7)
self.assertEqual(cached.get(0), 2)
self.assertEqual(cached.get(1), 5)
self.assertEqual(cached.get(1), 5)
self.assertEqual(cached.get(1.0), 5)
self.assertEqual(cached.get(2.0), 10)

def test_locked_nocache(self):
cached = Locked(None)
Expand All @@ -158,8 +158,8 @@ def test_locked_nocache(self):
def test_locked_nospace(self):
cached = Locked(LRUCache(maxsize=0))

self.assertEqual(cached.get(0), 1)
self.assertEqual(cached.get(1), 3)
self.assertEqual(cached.get(0), 2)
self.assertEqual(cached.get(1), 5)
self.assertEqual(cached.get(1.0), 7)
self.assertEqual(cached.get(1.0), 9)
self.assertEqual(cached.get(1), 8)
self.assertEqual(cached.get(1.0), 11)
self.assertEqual(cached.get(1.0), 14)
38 changes: 34 additions & 4 deletions tests/test_wrapper.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import time
import unittest
from concurrent.futures import ThreadPoolExecutor

import cachetools
import cachetools.keys
Expand Down Expand Up @@ -92,11 +94,11 @@ def __exit__(self, *exc):
self.assertEqual(len(cache), 0)
self.assertEqual(wrapper.__wrapped__, self.func)
self.assertEqual(wrapper(0), 0)
self.assertEqual(Lock.count, 2)
self.assertEqual(Lock.count, 3)
self.assertEqual(wrapper(1), 1)
self.assertEqual(Lock.count, 4)
self.assertEqual(Lock.count, 6)
self.assertEqual(wrapper(1), 1)
self.assertEqual(Lock.count, 5)
self.assertEqual(Lock.count, 7)


class CacheWrapperTest(unittest.TestCase, DecoratorTestMixin):
Expand Down Expand Up @@ -132,7 +134,35 @@ def __exit__(self, *exc):

self.assertEqual(wrapper(0), 0)
self.assertEqual(len(cache), 0)
self.assertEqual(Lock.count, 2)
self.assertEqual(Lock.count, 3) # Initial miss, key level miss, setdefault

def test_doesnt_execute_multiple_times_when_multithreading(self):
class Lock:

count = 0

def __enter__(self):
Lock.count += 1

def __exit__(self, *exc):
pass
def _long_func(*args, **kwargs):
time.sleep(1)
return self.func(*args, **kwargs)

cache = self.cache(5)
wrapper = cachetools.cached(cache, lock=Lock())(_long_func)

self.assertEqual(len(cache), 0)
self.assertEqual(wrapper.__wrapped__, _long_func)
with ThreadPoolExecutor(max_workers=5) as executor:
executor.map(wrapper, [1] * 10)
# only called the wrapped function once
self.assertEqual(self.func(), 1)
# Accessed cache under lock 11 times
self.assertEqual(Lock.count, 16) # 10x top level, 5x key-lvl, 1x setdefault
# all of our arguments were the same (1)
self.assertEqual(len(cache), 1)


class DictWrapperTest(unittest.TestCase, DecoratorTestMixin):
Expand Down

0 comments on commit 220bfab

Please sign in to comment.