Key-level locking which corrects multithreading performance #224

Closed · wants to merge 1 commit
src/cachetools/__init__.py (48 additions, 22 deletions)

@@ -19,6 +19,7 @@
 import functools
 import random
 import time
+from threading import RLock

 from .keys import hashkey
@@ -526,20 +527,33 @@ def wrapper(*args, **kwargs):

         else:

+            _key_level_locks = {}
horpto: Maybe better to use a defaultdict, to avoid allocating an extra RLock?

Author: Interesting! That datatype looks ideal here. Out of curiosity, where is the extra RLock being allocated here?

horpto (Oct 21, 2021): Not here, actually, but a couple of lines below, on `_key_level_lock = _key_level_locks.setdefault(k, RLock())`. An RLock is always created there, even when it is not needed. And since that line runs under the lock, it is already atomic.
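A minimal sketch of that suggestion (illustrative only, not part of the diff):

    from collections import defaultdict
    from threading import RLock

    # defaultdict only invokes its factory (RLock) on a missing key, so a
    # throwaway RLock is no longer constructed on every call
    _key_level_locks = defaultdict(RLock)

    # inside wrapper, while holding `lock`:
    #     _key_level_lock = _key_level_locks[k]  # creates the RLock lazily

With the plain dict, `_key_level_locks.setdefault(k, RLock())` constructs a fresh RLock on every call and immediately discards it whenever an entry for `k` already exists.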


             def wrapper(*args, **kwargs):
                 k = key(*args, **kwargs)
-                try:
-                    with lock:
+                with lock:
+                    try:
                         return cache[k]
-                except KeyError:
-                    pass  # key not found
-                v = func(*args, **kwargs)
-                # in case of a race, prefer the item already in the cache
-                try:
-                    with lock:
-                        return cache.setdefault(k, v)
-                except ValueError:
-                    return v  # value too large
+                    except KeyError:
+                        pass  # key not found
+                    _key_level_lock = _key_level_locks.setdefault(k, RLock())
+                # Only one caller may generate or retrieve a value for this key
+                with _key_level_lock:
+                    try:
+                        with lock:
+                            # In the case that this thread was blocked, retrieve and return the now computed value
+                            return cache[k]
+                    except KeyError:
+                        # Otherwise compute it on this thread since the key is not in the cache
+                        v = func(*args, **kwargs)
+                    finally:
+                        with lock:
+                            _key_level_locks.pop(k)
+                try:
+                    with lock:
+                        return cache.setdefault(k, v)
+                except ValueError:
+                    return v  # value too large

         return functools.update_wrapper(wrapper, func)
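To make the motivation concrete, here is a standalone sketch (not part of the diff; `slow_square`, the cache size, and the timings are illustrative) of the race the released wrapper allows:

    import threading
    import time
    from concurrent.futures import ThreadPoolExecutor

    import cachetools

    calls = 0

    @cachetools.cached(cachetools.LRUCache(maxsize=8), lock=threading.RLock())
    def slow_square(x):
        global calls
        calls += 1
        time.sleep(0.1)  # simulate an expensive computation
        return x * x

    with ThreadPoolExecutor(max_workers=4) as ex:
        print(list(ex.map(slow_square, [3] * 4)))  # [9, 9, 9, 9]

    # With the released wrapper, each of the four threads can miss and call
    # slow_square, and setdefault() keeps whichever value lands first, so
    # `calls` may be as high as 4. With this PR's key-level lock, one thread
    # computes while the rest block on the per-key RLock and then read the
    # cached value, so `calls` stays at 1.
    print(calls)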

@@ -572,24 +586,36 @@ def wrapper(self, *args, **kwargs):
                 return v

         else:
+            _key_level_locks = {}
             def wrapper(self, *args, **kwargs):
                 c = cache(self)
                 if c is None:
                     return method(self, *args, **kwargs)
                 k = key(*args, **kwargs)
-                try:
-                    with lock(self):
+                with lock(self):
+                    try:
                         return c[k]
-                except KeyError:
-                    pass  # key not found
-                v = method(self, *args, **kwargs)
-                # in case of a race, prefer the item already in the cache
-                try:
-                    with lock(self):
-                        return c.setdefault(k, v)
-                except ValueError:
-                    return v  # value too large
+                    except KeyError:
+                        pass
+                    _key_level_lock = _key_level_locks.setdefault(k, RLock())
+                # Only one caller may generate or retrieve a value for this key
+                with _key_level_lock:
+                    try:
+                        with lock(self):
+                            # In the case that this thread was blocked, retrieve and return the now computed value
+                            return c[k]
+                    except KeyError:
+                        # Otherwise compute it on this thread since the key is not in the cache
+                        v = method(self, *args, **kwargs)
+                    finally:
+                        with lock(self):
+                            _key_level_locks.pop(k)
+                try:
+                    with lock(self):
+                        return c.setdefault(k, v)
+                except ValueError:
+                    return v  # value too large

         return functools.update_wrapper(wrapper, method)

src/cachetools/func.py (17 additions, 7 deletions)

@@ -47,6 +47,7 @@ def _cache(cache, typed):
     def decorator(func):
         key = keys.typedkey if typed else keys.hashkey
         lock = RLock()
+        key_level_locks = {}
         stats = [0, 0]

         def wrapper(*args, **kwargs):
@@ -58,13 +59,22 @@ def wrapper(*args, **kwargs):
                     return v
                 except KeyError:
                     stats[1] += 1
-            v = func(*args, **kwargs)
-            # in case of a race, prefer the item already in the cache
-            try:
-                with lock:
-                    return cache.setdefault(k, v)
-            except ValueError:
-                return v  # value too large
+                _key_level_lock = key_level_locks.setdefault(k, RLock())
+            # Only one caller may generate or retrieve a value for this key
+            with _key_level_lock:
+                try:
+                    # In the case that this thread was blocked, retrieve and return the now computed value
+                    return cache[k]
Owner: See above.

+                except KeyError:
+                    # Otherwise compute it on this thread since the key is not in the cache
+                    v = func(*args, **kwargs)
Owner: An exception here may cause a memory leak.

+                finally:
+                    with lock:
+                        key_level_locks.pop(k)
+            try:
+                return cache.setdefault(k, v)
+            except ValueError:
+                return v  # value too large

         def cache_info():
             with lock:
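One plausible reading of that concern (a hedged sketch, not from the PR thread): the cleanup in the finally block uses key_level_locks.pop(k), which raises KeyError if another thread already removed the entry, and any path that skips the cleanup would leave dead per-key locks accumulating in key_level_locks. A more defensive cleanup might look like:

    finally:
        with lock:
            # pop with a default, so a thread that finds the entry already
            # removed by the computing thread does not raise KeyError out of
            # the finally block (which would also discard the value being
            # returned from the try block)
            key_level_locks.pop(k, None)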
tests/test_method.py (10 additions, 10 deletions)

@@ -140,11 +140,11 @@ def __add__(self, other):
     def test_locked_dict(self):
         cached = Locked({})

-        self.assertEqual(cached.get(0), 1)
-        self.assertEqual(cached.get(1), 3)
-        self.assertEqual(cached.get(1), 3)
-        self.assertEqual(cached.get(1.0), 3)
-        self.assertEqual(cached.get(2.0), 7)
+        self.assertEqual(cached.get(0), 2)
+        self.assertEqual(cached.get(1), 6)
+        self.assertEqual(cached.get(1), 6)
+        self.assertEqual(cached.get(1.0), 6)
+        self.assertEqual(cached.get(2.0), 12)

     def test_locked_nocache(self):
         cached = Locked(None)
@@ -158,8 +158,8 @@ def test_locked_nocache(self):
     def test_locked_nospace(self):
         cached = Locked(LRUCache(maxsize=0))

-        self.assertEqual(cached.get(0), 1)
-        self.assertEqual(cached.get(1), 3)
-        self.assertEqual(cached.get(1), 5)
Owner: AFAICS you should _not_ need to change the existing unit tests.

-        self.assertEqual(cached.get(1.0), 7)
-        self.assertEqual(cached.get(1.0), 9)
+        self.assertEqual(cached.get(0), 2)
+        self.assertEqual(cached.get(1), 6)
+        self.assertEqual(cached.get(1), 10)
+        self.assertEqual(cached.get(1.0), 14)
+        self.assertEqual(cached.get(1.0), 18)
tests/test_wrapper.py (34 additions, 4 deletions)

@@ -1,4 +1,6 @@
+import time
 import unittest
+from concurrent.futures import ThreadPoolExecutor

 import cachetools
 import cachetools.keys
@@ -92,11 +94,11 @@ def __exit__(self, *exc):
         self.assertEqual(len(cache), 0)
         self.assertEqual(wrapper.__wrapped__, self.func)
         self.assertEqual(wrapper(0), 0)
-        self.assertEqual(Lock.count, 2)
+        self.assertEqual(Lock.count, 4)
Owner: See above.

         self.assertEqual(wrapper(1), 1)
-        self.assertEqual(Lock.count, 4)
+        self.assertEqual(Lock.count, 8)
         self.assertEqual(wrapper(1), 1)
-        self.assertEqual(Lock.count, 5)
+        self.assertEqual(Lock.count, 9)
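(For orientation, and assuming the Lock class above counts acquisitions of the shared lock only: under the key-level scheme a cache miss acquires the shared lock four times — initial check, nested re-check, pop in the finally block, setdefault — while a hit acquires it once, hence 4 after wrapper(0), 8 after the first wrapper(1), and 9 after the repeated wrapper(1).)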


class CacheWrapperTest(unittest.TestCase, DecoratorTestMixin):
@@ -132,7 +134,35 @@ def __exit__(self, *exc):

         self.assertEqual(wrapper(0), 0)
         self.assertEqual(len(cache), 0)
-        self.assertEqual(Lock.count, 2)
+        self.assertEqual(Lock.count, 4)  # initial miss, nested re-check, pop, setdefault
+
+    def test_doesnt_execute_multiple_times_when_multithreading(self):
+        class Lock:
+
+            count = 0
+
+            def __enter__(self):
+                Lock.count += 1
+
+            def __exit__(self, *exc):
+                pass
+
+        def _long_func(*args, **kwargs):
+            time.sleep(1)
+            return self.func(*args, **kwargs)
+
+        cache = self.cache(5)
+        wrapper = cachetools.cached(cache, lock=Lock())(_long_func)
+
+        self.assertEqual(len(cache), 0)
+        self.assertEqual(wrapper.__wrapped__, _long_func)
+        with ThreadPoolExecutor(max_workers=5) as executor:
+            executor.map(wrapper, [1] * 10)
+        # only called the wrapped function once
+        self.assertEqual(self.func(), 1)
+        # accessed the cache under lock 21 times
+        self.assertEqual(Lock.count, 21)  # 10x top level, 5x key-lvl, 5x nested, 1x setdefault
+        # all of our arguments were the same (1)
+        self.assertEqual(len(cache), 1)

class DictWrapperTest(unittest.TestCase, DecoratorTestMixin):