Merge pull request #537 from JWCook/concurrency-tests
Fix some thread safety issues and improve concurrency tests
JWCook committed Feb 23, 2022
2 parents 6e595f3 + 3c77e2c commit cc5305f
Showing 10 changed files with 264 additions and 171 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/deploy.yml
@@ -74,8 +74,8 @@ jobs:
run: |
source $VENV
pytest -x ${{ env.XDIST_ARGS }} tests/unit
-pytest -x ${{ env.XDIST_ARGS }} tests/integration -k 'not multithreaded'
-STRESS_TEST_MULTIPLIER=5 pytest tests/integration -k 'multithreaded'
+pytest -x ${{ env.XDIST_ARGS }} tests/integration -k 'not concurrency'
+STRESS_TEST_MULTIPLIER=10 pytest tests/integration -k 'concurrency'
# Deploy stable builds on tags only, and pre-release builds from manual trigger ("workflow_dispatch")
release:
8 changes: 7 additions & 1 deletion HISTORY.md
@@ -1,10 +1,16 @@
# History

-## 0.9.3 (Unreleased)
+## 0.9.3 (2022-02-22)
* Fix handling of BSON serializer differences between pymongo's `bson` and the standalone `bson` codec
* Handle `CorruptGridFile` error in GridFS backend
* Fix cache path expansion for user directories (`~/...`) for SQLite and filesystem backends
* Fix request normalization for request body with a list as a JSON root
* Skip normalizing a JSON request body if it's excessively large (>10MB) due to performance impact
+* Fix some thread safety issues:
+    * Fix race condition in SQLite backend with dropping and recreating tables in multiple threads
+    * Fix race condition in filesystem backend when one thread deletes a file after it's opened but
+      before it is read by a different thread
+    * Fix multiple race conditions in GridFS backend

## 0.9.2 (2022-02-15)
* Fix serialization in filesystem backend with binary content that is also valid UTF-8
14 changes: 13 additions & 1 deletion noxfile.py
@@ -1,7 +1,8 @@
"""Notes:
-* 'test' command: nox will use poetry.lock to determine dependency versions
+* 'test-<python version>' commands: nox will use poetry.lock to determine dependency versions
* 'lint' command: tools and environments are managed by pre-commit
* All other commands: the current environment will be used instead of creating new ones
* Run `nox -l` to see all available commands
"""
from os.path import join
from shutil import rmtree
@@ -19,6 +20,7 @@

UNIT_TESTS = join('tests', 'unit')
INTEGRATION_TESTS = join('tests', 'integration')
+STRESS_TEST_MULTIPLIER = 10
COVERAGE_ARGS = (
    '--cov --cov-report=term --cov-report=html'  # Generate HTML + stdout coverage report
)
@@ -52,6 +54,16 @@ def coverage(session):
    session.run(*cmd_2.split(' '))


+@session(python=False, name='stress')
+def stress_test(session):
+    """Run concurrency tests with a higher stress test multiplier"""
+    cmd = f'pytest {INTEGRATION_TESTS} -rs -k concurrency'
+    session.run(
+        *cmd.split(' '),
+        env={'STRESS_TEST_MULTIPLIER': str(STRESS_TEST_MULTIPLIER)},
+    )


@session(python=False)
def docs(session):
    """Build Sphinx documentation"""
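With this session defined, the longer-running concurrency tests can be run locally with `nox -e stress`, which is equivalent to `STRESS_TEST_MULTIPLIER=10 pytest tests/integration -rs -k concurrency`.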
238 changes: 122 additions & 116 deletions poetry.lock

Large diffs are not rendered by default.

14 changes: 11 additions & 3 deletions requests_cache/backends/filesystem.py
@@ -40,6 +40,7 @@
from pathlib import Path
from pickle import PickleError
from shutil import rmtree
+from threading import RLock
from typing import Iterator

from ..serializers import SERIALIZERS
@@ -81,6 +82,10 @@ def clear(self):
        self.responses.clear()
        self.redirects.init_db()

+    def remove_expired_responses(self, *args, **kwargs):
+        with self.responses._lock:
+            return super().remove_expired_responses(*args, **kwargs)


class FileDict(BaseStorage):
    """A dictionary-like interface to files on the local filesystem"""
@@ -97,14 +102,16 @@ def __init__(
        self.cache_dir = get_cache_path(cache_name, use_cache_dir=use_cache_dir, use_temp=use_temp)
        self.extension = _get_extension(extension, self.serializer)
        self.is_binary = getattr(self.serializer, 'is_binary', False)
+        self._lock = RLock()
        makedirs(self.cache_dir, exist_ok=True)

    @contextmanager
    def _try_io(self, ignore_errors: bool = False):
        """Attempt an I/O operation, and either ignore errors or re-raise them as KeyErrors"""
        try:
-            yield
-        except (IOError, OSError, PickleError) as e:
+            with self._lock:
+                yield
+        except (EOFError, IOError, OSError, PickleError) as e:
            if not ignore_errors:
                raise KeyError(e)

@@ -142,7 +149,8 @@ def keys(self):

    def paths(self) -> Iterator[Path]:
        """Get absolute file paths to all cached responses"""
-        return self.cache_dir.glob(f'*{self.extension}')
+        with self._lock:
+            return self.cache_dir.glob(f'*{self.extension}')


def _get_extension(extension: str = None, serializer=None) -> str:
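For context on the race this fixes: without a shared lock, one thread can delete a cache file in the window between another thread opening it and reading it, which surfaces as an unexpected `EOFError` or `OSError` rather than a normal cache miss. A minimal sketch of the pattern the PR applies, with illustrative names (`read_cached`, `delete_cached`), not the library's actual internals:

```python
from pathlib import Path
from threading import RLock

_lock = RLock()  # one lock shared by all threads using the same cache

def read_cached(path: Path) -> bytes:
    # Without the lock, a concurrent delete between open() and read() can
    # raise EOFError/OSError; holding it closes that window
    with _lock:
        return path.read_bytes()

def delete_cached(path: Path) -> None:
    with _lock:
        path.unlink(missing_ok=True)  # missing_ok requires Python 3.8+
```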
44 changes: 32 additions & 12 deletions requests_cache/backends/gridfs.py
@@ -11,13 +11,19 @@
:classes-only:
:nosignatures:
"""
+from logging import getLogger
+from threading import RLock

from gridfs import GridFS
+from gridfs.errors import CorruptGridFile, FileExists
from pymongo import MongoClient

from .._utils import get_valid_kwargs
from .base import BaseCache, BaseStorage
from .mongodb import MongoDict

+logger = getLogger(__name__)


class GridFSCache(BaseCache):
    """GridFS cache backend.
@@ -39,6 +45,10 @@ def __init__(self, db_name: str, **kwargs):
            db_name, collection_name='redirects', connection=self.responses.connection, **kwargs
        )

+    def remove_expired_responses(self, *args, **kwargs):
+        with self.responses._lock:
+            return super().remove_expired_responses(*args, **kwargs)


class GridFSPickleDict(BaseStorage):
    """A dictionary-like interface for a GridFS database
@@ -56,27 +66,37 @@ def __init__(self, db_name, collection_name=None, connection=None, **kwargs):
        self.connection = connection or MongoClient(**connection_kwargs)
        self.db = self.connection[db_name]
        self.fs = GridFS(self.db)
+        self._lock = RLock()

    def __getitem__(self, key):
-        result = self.fs.find_one({'_id': key})
-        if result is None:
+        try:
+            with self._lock:
+                result = self.fs.find_one({'_id': key})
+                if result is None:
+                    raise KeyError
+                return self.serializer.loads(result.read())
+        except CorruptGridFile as e:
+            logger.warning(e, exc_info=True)
            raise KeyError
-        return self.serializer.loads(result.read())

    def __setitem__(self, key, item):
-        try:
-            self.__delitem__(key)
-        except KeyError:
-            pass
        value = self.serializer.dumps(item)
        encoding = None if isinstance(value, bytes) else 'utf-8'
-        self.fs.put(value, encoding=encoding, **{'_id': key})
+
+        with self._lock:
+            try:
+                self.fs.delete(key)
+                self.fs.put(value, encoding=encoding, **{'_id': key})
+            # This can happen because GridFS is not thread-safe for concurrent writes
+            except FileExists as e:
+                logger.warning(e, exc_info=True)

    def __delitem__(self, key):
-        res = self.fs.find_one({'_id': key})
-        if res is None:
-            raise KeyError
-        self.fs.delete(res._id)
+        with self._lock:
+            res = self.fs.find_one({'_id': key})
+            if res is None:
+                raise KeyError
+            self.fs.delete(res._id)

    def __len__(self):
        return self.db['fs.files'].estimated_document_count()
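A note on the choice of `RLock` over a plain `Lock` here: `remove_expired_responses()` takes `self.responses._lock` and then calls storage methods (`__getitem__`, `__delitem__`) that acquire the same lock again from the same thread. A re-entrant lock permits this; a plain `Lock` would deadlock. A minimal demonstration:

```python
from threading import RLock

lock = RLock()

def remove_expired():
    with lock:        # outer acquisition, like remove_expired_responses()
        delete_item()

def delete_item():
    with lock:        # same thread re-acquires the same lock; RLock allows this
        print('deleted')

remove_expired()  # prints 'deleted'; with a plain threading.Lock this would deadlock
```

Also note that a per-process lock only serializes threads within one process; writers in separate processes can presumably still collide, which would explain why `FileExists` is caught and logged rather than raised.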
13 changes: 9 additions & 4 deletions requests_cache/backends/sqlite.py
@@ -131,6 +131,10 @@ def clear(self):
        self.responses.init_db()
        self.redirects.init_db()

+    def remove_expired_responses(self, *args, **kwargs):
+        with self.responses._lock, self.redirects._lock:
+            return super().remove_expired_responses(*args, **kwargs)


class SQLiteDict(BaseStorage):
    """A dictionary-like interface for SQLite"""
@@ -248,10 +252,11 @@ def bulk_delete(self, keys=None, values=None):
            con.execute(statement, args)

    def clear(self):
-        with self.connection(commit=True) as con:
-            con.execute(f'DROP TABLE IF EXISTS {self.table_name}')
-        self.init_db()
-        self.vacuum()
+        with self._lock:
+            with self.connection(commit=True) as con:
+                con.execute(f'DROP TABLE IF EXISTS {self.table_name}')
+            self.init_db()
+            self.vacuum()

    def vacuum(self):
        with self.connection(commit=True) as con:
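The SQLite race from the changelog: two threads calling `clear()` concurrently could interleave `DROP TABLE` and table re-creation, so one thread ends up writing to a table the other just dropped. Holding the instance lock across the whole drop-and-recreate sequence avoids this. A condensed sketch, assuming an illustrative schema rather than the library's actual one:

```python
import sqlite3
from threading import RLock

_lock = RLock()

def clear(db_path: str, table: str = 'responses') -> None:
    # The lock makes DROP + CREATE atomic with respect to other threads,
    # so none of them sees the table missing mid-rebuild
    with _lock:
        with sqlite3.connect(db_path) as con:
            con.execute(f'DROP TABLE IF EXISTS {table}')
            con.execute(f'CREATE TABLE IF NOT EXISTS {table} (key TEXT PRIMARY KEY, value BLOB)')
```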
3 changes: 2 additions & 1 deletion tests/conftest.py
@@ -27,8 +27,9 @@

# Allow running longer stress tests with an environment variable
STRESS_TEST_MULTIPLIER = int(os.getenv('STRESS_TEST_MULTIPLIER', '1'))
-N_THREADS = 2 * STRESS_TEST_MULTIPLIER
+N_WORKERS = 2 * STRESS_TEST_MULTIPLIER
N_ITERATIONS = 4 * STRESS_TEST_MULTIPLIER
+N_REQUESTS_PER_ITERATION = 10 + 10 * STRESS_TEST_MULTIPLIER

HTTPBIN_METHODS = ['GET', 'POST', 'PUT', 'PATCH', 'DELETE']
HTTPBIN_FORMATS = [
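Worked through: the default `STRESS_TEST_MULTIPLIER=1` yields 2 workers, 4 iterations, and 20 requests per iteration, while the CI deploy workflow's value of 10 yields 20 workers, 40 iterations, and 110 requests per iteration.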
79 changes: 48 additions & 31 deletions tests/integration/base_cache_test.py
@@ -1,8 +1,11 @@
"""Common tests to run for all backends (BaseCache subclasses)"""
import json
+from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
from datetime import datetime
+from functools import partial
from io import BytesIO
-from threading import Thread
+from logging import getLogger
+from random import randint
from time import sleep, time
from typing import Dict, Type
from unittest.mock import MagicMock, patch
@@ -24,12 +27,15 @@
    HTTPDATE_STR,
    LAST_MODIFIED,
    N_ITERATIONS,
-    N_THREADS,
+    N_REQUESTS_PER_ITERATION,
+    N_WORKERS,
    USE_PYTEST_HTTPBIN,
    assert_delta_approx_equal,
    httpbin,
)

+logger = getLogger(__name__)

# Handle optional dependencies if they're not installed; if so, skips will be shown in pytest output
TEST_SERIALIZERS = SERIALIZERS.copy()
try:
@@ -81,7 +87,7 @@ def test_all_response_formats(self, response_format, serializer):
            pytest.skip(f'Dependencies not installed for {serializer}')

        session = self.init_session(serializer=serializer)
-        # Temporary workaround for this issue: https://github.com/kevin1024/pytest-httpbin/issues/60
+        # Workaround for this issue: https://github.com/kevin1024/pytest-httpbin/issues/60
        if response_format == 'json' and USE_PYTEST_HTTPBIN:
            session.allowable_codes = (200, 404)

@@ -226,7 +232,7 @@ def test_conditional_request__max_age_0(self, cache_headers, validator_headers):
        'validator_headers', [{'ETag': ETAG}, {'Last-Modified': LAST_MODIFIED}]
    )
    @pytest.mark.parametrize('cache_headers', [{'Cache-Control': 'max-age=0'}])
-    def test_conditional_request_refreshenes_expire_date(self, cache_headers, validator_headers):
+    def test_conditional_request_refreshes_expire_date(self, cache_headers, validator_headers):
        """Test that a revalidation attempt that returns a 304 makes a stale entry fresh again,
        respecting the Cache-Control header of the 304 response."""
        url = httpbin('response-headers')
@@ -300,33 +306,6 @@ def test_remove_expired_responses(self):
        assert not session.cache.has_url(httpbin('redirect/1'))
        assert not any([session.cache.has_url(httpbin(f)) for f in HTTPBIN_FORMATS])

-    @pytest.mark.parametrize('iteration', range(N_ITERATIONS))
-    def test_multithreaded(self, iteration):
-        """Run a multi-threaded stress test for each backend"""
-        session = self.init_session()
-        start = time()
-        url = httpbin('anything')
-
-        def send_requests():
-            for i in range(N_ITERATIONS):
-                session.get(url, params={f'key_{i}': f'value_{i}'})
-
-        threads = [Thread(target=send_requests) for i in range(N_THREADS)]
-        for t in threads:
-            t.start()
-        for t in threads:
-            t.join()
-
-        elapsed = time() - start
-        average = (elapsed * 1000) / (N_ITERATIONS * N_THREADS)
-        print(
-            f'{self.backend_class}: Ran {N_ITERATIONS} iterations with {N_THREADS} threads each in {elapsed} s'
-        )
-        print(f'Average time per request: {average} ms')
-
-        for i in range(N_ITERATIONS):
-            assert session.cache.has_url(f'{url}?key_{i}=value_{i}')

    @pytest.mark.parametrize('method', HTTPBIN_METHODS)
    def test_filter_request_headers(self, method):
        url = httpbin(method.lower())
@@ -364,3 +343,41 @@ def test_filter_request_post_data(self, post_type):
        elif post_type == 'json':
            body = json.loads(response.request.body)
            assert "api_key" not in body

+    @pytest.mark.parametrize('executor_class', [ThreadPoolExecutor, ProcessPoolExecutor])
+    @pytest.mark.parametrize('iteration', range(N_ITERATIONS))
+    def test_concurrency(self, iteration, executor_class):
+        """Run multithreaded and multiprocess stress tests for each backend.
+        The number of workers (threads/processes), iterations, and requests per iteration can be
+        increased via the `STRESS_TEST_MULTIPLIER` environment variable.
+        """
+        start = time()
+        url = httpbin('anything')
+        self.init_session(clear=True)
+
+        session_factory = partial(self.init_session, clear=False)
+        request_func = partial(_send_request, session_factory, url)
+        with executor_class(max_workers=N_WORKERS) as executor:
+            _ = list(executor.map(request_func, range(N_REQUESTS_PER_ITERATION)))
+
+        # Some logging for debug purposes
+        elapsed = time() - start
+        average = (elapsed * 1000) / N_REQUESTS_PER_ITERATION
+        worker_type = 'threads' if executor_class is ThreadPoolExecutor else 'processes'
+        logger.info(
+            f'{self.backend_class.__name__}: Ran {N_REQUESTS_PER_ITERATION} requests with '
+            f'{N_WORKERS} {worker_type} in {elapsed} s\n'
+            f'Average time per request: {average} ms'
+        )
+
+
+def _send_request(session_factory, url, _=None):
+    """Concurrent request function for stress tests. Defined in module scope so it can be
+    serialized to multiple processes.
+    """
+    # Use fewer unique requests than total requests per iteration, so we get some cache hits
+    n_unique_responses = int(N_REQUESTS_PER_ITERATION / 4)
+    i = randint(1, n_unique_responses)
+
+    session = session_factory()
+    return session.get(url, params={f'key_{i}': f'value_{i}'})
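The module-scope placement of `_send_request` matters: `ProcessPoolExecutor` pickles the callable and its arguments to ship them to worker processes, and only module-level functions (and `partial`s wrapping them) pickle cleanly; a lambda or nested function would raise a pickling error. A self-contained sketch of the same pattern, with illustrative names and URL:

```python
from concurrent.futures import ProcessPoolExecutor
from functools import partial

def send_request(base_url: str, i: int) -> str:
    # Module-level, so it can be pickled and shipped to worker processes
    return f'{base_url}?key_{i}=value_{i}'

if __name__ == '__main__':
    request_func = partial(send_request, 'https://example.com/anything')  # bind fixed args
    with ProcessPoolExecutor(max_workers=2) as executor:
        print(list(executor.map(request_func, range(4))))
```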
18 changes: 18 additions & 0 deletions tests/integration/test_mongodb.py
@@ -1,6 +1,8 @@
from unittest.mock import patch

import pytest
+from gridfs import GridFS
+from gridfs.errors import CorruptGridFile, FileExists
from pymongo import MongoClient

from requests_cache._utils import get_valid_kwargs
@@ -65,6 +67,22 @@ def test_connection_kwargs(self, mock_get_valid_kwargs, mock_client, mock_gridfs
        GridFSPickleDict('test', host='http://0.0.0.0', port=1234, invalid_kwarg='???')
        mock_client.assert_called_with(host='http://0.0.0.0', port=1234)

+    def test_corrupt_file(self):
+        """A corrupted file should be handled and raise a KeyError instead"""
+        cache = self.init_cache()
+        cache['key'] = 'value'
+        with pytest.raises(KeyError), patch.object(GridFS, 'find_one', side_effect=CorruptGridFile):
+            cache['key']
+
+    def test_file_exists(self):
+        cache = self.init_cache()
+
+        # This write should just quietly fail
+        with patch.object(GridFS, 'put', side_effect=FileExists):
+            cache['key'] = 'value_1'
+
+        assert 'key' not in cache


class TestGridFSCache(BaseCacheTest):
    backend_class = GridFSCache
