Skip to content

Commit

Permalink
Add test for prepare_models retry error condition
Browse files Browse the repository at this point in the history
  • Loading branch information
RazerM committed Sep 29, 2020
1 parent cb6582d commit 073f339
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 3 deletions.
5 changes: 3 additions & 2 deletions celery/backends/database/session.py
Expand Up @@ -12,6 +12,8 @@

__all__ = ('SessionManager',)

PREPARE_MODELS_MAX_RETRIES = 10


def _after_fork_cleanup_session(session):
session._after_fork()
Expand Down Expand Up @@ -57,13 +59,12 @@ def prepare_models(self, engine):
# create them, which is a race condition. If it raises an error
# in one iteration, the next may pass all the existence checks
# and the call will succeed.
max_retries = 10
retries = 0
while True:
try:
ResultModelBase.metadata.create_all(engine)
except DatabaseError:
if retries < max_retries:
if retries < PREPARE_MODELS_MAX_RETRIES:
sleep_amount_ms = get_exponential_backoff_interval(
10, retries, 1000, True
)
Expand Down
28 changes: 27 additions & 1 deletion t/unit/backends/test_database.py
Expand Up @@ -13,7 +13,8 @@
from celery.backends.database import (DatabaseBackend, retry, session, # noqa
session_cleanup)
from celery.backends.database.models import Task, TaskSet # noqa
from celery.backends.database.session import SessionManager # noqa
from celery.backends.database.session import ( # noqa
PREPARE_MODELS_MAX_RETRIES, ResultModelBase, SessionManager)
from t import skip # noqa


Expand Down Expand Up @@ -398,3 +399,28 @@ def test_coverage_madness(self):
SessionManager()
finally:
session.register_after_fork = prev

@patch('celery.backends.database.session.create_engine')
def test_prepare_models_terminates(self, create_engine):
"""SessionManager.prepare_models has retry logic because the creation
of database tables by multiple workers is racy. This test patches
the used method to always raise, so we can verify that it does
eventually terminate.
"""
from sqlalchemy.dialects.sqlite import dialect
from sqlalchemy.exc import DatabaseError

sqlite = dialect.dbapi()
manager = SessionManager()
engine = manager.get_engine('dburi')

def raise_err(bind):
raise DatabaseError("", "", [], sqlite.DatabaseError)

patch_create_all = patch.object(
ResultModelBase.metadata, 'create_all', side_effect=raise_err)

with pytest.raises(DatabaseError), patch_create_all as mock_create_all:
manager.prepare_models(engine)

assert mock_create_all.call_count == PREPARE_MODELS_MAX_RETRIES + 1

0 comments on commit 073f339

Please sign in to comment.