diff --git a/celery/backends/database/session.py b/celery/backends/database/session.py index 783190d757..23a205640c 100644 --- a/celery/backends/database/session.py +++ b/celery/backends/database/session.py @@ -17,6 +17,8 @@ __all__ = ('SessionManager',) +PREPARE_MODELS_MAX_RETRIES = 10 + def _after_fork_cleanup_session(session): session._after_fork() @@ -62,13 +64,12 @@ def prepare_models(self, engine): # create them, which is a race condition. If it raises an error # in one iteration, the next may pass all the existence checks # and the call will succeed. - max_retries = 10 retries = 0 while True: try: ResultModelBase.metadata.create_all(engine) except DatabaseError: - if retries < max_retries: + if retries < PREPARE_MODELS_MAX_RETRIES: sleep_amount_ms = get_exponential_backoff_interval( 10, retries, 1000, True ) diff --git a/t/unit/backends/test_database.py b/t/unit/backends/test_database.py index 5cb6741fd3..2e311d6dc5 100644 --- a/t/unit/backends/test_database.py +++ b/t/unit/backends/test_database.py @@ -14,13 +14,15 @@ import sqlalchemy # noqa except ImportError: DatabaseBackend = Task = TaskSet = retry = None # noqa - SessionManager = session_cleanup = None # noqa + SessionManager = ResultModelBase = session_cleanup = None # noqa else: from celery.backends.database import ( DatabaseBackend, retry, session_cleanup, ) from celery.backends.database import session - from celery.backends.database.session import SessionManager + from celery.backends.database.session import ( + PREPARE_MODELS_MAX_RETRIES, ResultModelBase, SessionManager + ) from celery.backends.database.models import Task, TaskSet @@ -411,3 +413,28 @@ def test_coverage_madness(self): SessionManager() finally: session.register_after_fork = prev + + @patch('celery.backends.database.session.create_engine') + def test_prepare_models_terminates(self, create_engine): + """SessionManager.prepare_models has retry logic because the creation + of database tables by multiple workers is racy. This test patches + the used method to always raise, so we can verify that it does + eventually terminate. + """ + from sqlalchemy.dialects.sqlite import dialect + from sqlalchemy.exc import DatabaseError + + sqlite = dialect.dbapi() + manager = SessionManager() + engine = manager.get_engine('dburi') + + def raise_err(bind): + raise DatabaseError("", "", [], sqlite.DatabaseError) + + patch_create_all = patch.object( + ResultModelBase.metadata, 'create_all', side_effect=raise_err) + + with pytest.raises(DatabaseError), patch_create_all as mock_create_all: + manager.prepare_models(engine) + + assert mock_create_all.call_count == PREPARE_MODELS_MAX_RETRIES + 1