Fix a race condition when publishing a very large chord header (#5850)
* Added a test case which artificially introduces a delay to group.save().

* Fix the race condition by publishing the header's tasks only after the group has been saved.
thedrow authored and auvipy committed Dec 1, 2019
1 parent e0ac7a1 commit a537c2d
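
The race in question: header(*partial_args, ...) publishes the header's tasks to the broker, while app.backend.apply_chord(...) records the chord's bookkeeping (the saved group result that finishing tasks check against) in the result backend. Before this commit the tasks were published first, so with a very large header a fast worker could report a finished task before that bookkeeping existed. A simplified sketch of the two orderings (not celery's actual code; the names mirror chord.run() in the diff below, and arguments and error handling are elided):

    def run_before_fix(header, body, group_id, app, **options):
        # Old order: freeze, publish, then save -- racy.
        header.freeze(group_id=group_id, chord=body)
        header_result = header(task_id=group_id, **options)  # tasks hit the broker...
        app.backend.apply_chord(header_result, body)         # ...before this completes
        return header_result

    def run_after_fix(header, body, group_id, app, **options):
        # New order: freeze() already returns the result, so the backend
        # state can be saved before any task is published.
        header_result = header.freeze(group_id=group_id, chord=body)
        app.backend.apply_chord(header_result, body)
        return header(task_id=group_id, **options)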
Showing 2 changed files with 51 additions and 18 deletions.
6 changes: 3 additions & 3 deletions celery/canvas.py
@@ -1190,7 +1190,7 @@ def freeze(self, _id=None, group_id=None, chord=None,
         try:
             gid = opts['task_id']
         except KeyError:
-            gid = opts['task_id'] = uuid()
+            gid = opts['task_id'] = group_id or uuid()
         if group_id:
             opts['group_id'] = group_id
         if chord:
@@ -1394,8 +1394,7 @@ def run(self, header, body, partial_args, app=None, interval=None,
         options.pop('chord', None)
         options.pop('task_id', None)

-        header.freeze(group_id=group_id, chord=body, root_id=root_id)
-        header_result = header(*partial_args, task_id=group_id, **options)
+        header_result = header.freeze(group_id=group_id, chord=body, root_id=root_id)

         if len(header_result) > 0:
             app.backend.apply_chord(
@@ -1405,6 +1404,7 @@ def run(self, header, body, partial_args, app=None, interval=None,
                 countdown=countdown,
                 max_retries=max_retries,
             )
+            header_result = header(*partial_args, task_id=group_id, **options)
         # The execution of a chord body is normally triggered by its header's
         # tasks completing. If the header is empty this will never happen, so
         # we execute the body manually here.
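Why publishing first was unsafe (a hedged reading of the key/value backend flow; details vary by backend): apply_chord, via _apply_chord_incr, saves the group's metadata in the result backend, and each finishing header task restores that metadata to learn the group's size before incrementing the chord's completion counter. If a task finishes while the save is still in flight, the restore finds nothing and the chord errors out instead of eventually running its body. That is exactly the window the integration test below widens with an artificial one-second sleep.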
63 changes: 48 additions & 15 deletions t/integration/test_canvas.py
@@ -1,10 +1,12 @@
 from __future__ import absolute_import, unicode_literals

 from datetime import datetime, timedelta
+from time import sleep

 import pytest

 from celery import chain, chord, group, signature
+from celery.backends.base import BaseKeyValueStoreBackend
 from celery.exceptions import TimeoutError
 from celery.result import AsyncResult, GroupResult, ResultSet

@@ -31,22 +33,22 @@ class test_link_error:
     @pytest.mark.flaky(reruns=5, reruns_delay=1, cause=is_retryable_exception)
     def test_link_error_eager(self):
         exception = ExpectedException("Task expected to fail", "test")
-        result = fail.apply(args=("test", ), link_error=return_exception.s())
+        result = fail.apply(args=("test",), link_error=return_exception.s())
         actual = result.get(timeout=TIMEOUT, propagate=False)
         assert actual == exception

     @pytest.mark.flaky(reruns=5, reruns_delay=1, cause=is_retryable_exception)
     def test_link_error(self):
         exception = ExpectedException("Task expected to fail", "test")
-        result = fail.apply(args=("test", ), link_error=return_exception.s())
+        result = fail.apply(args=("test",), link_error=return_exception.s())
         actual = result.get(timeout=TIMEOUT, propagate=False)
         assert actual == exception

     @pytest.mark.flaky(reruns=5, reruns_delay=1, cause=is_retryable_exception)
     def test_link_error_callback_error_callback_retries_eager(self):
         exception = ExpectedException("Task expected to fail", "test")
         result = fail.apply(
-            args=("test", ),
+            args=("test",),
             link_error=retry_once.s(countdown=None)
         )
         assert result.get(timeout=TIMEOUT, propagate=False) == exception
@@ -55,30 +57,32 @@ def test_link_error_callback_error_callback_retries_eager(self):
     def test_link_error_callback_retries(self):
         exception = ExpectedException("Task expected to fail", "test")
         result = fail.apply_async(
-            args=("test", ),
+            args=("test",),
             link_error=retry_once.s(countdown=None)
         )
         assert result.get(timeout=TIMEOUT, propagate=False) == exception

     @pytest.mark.flaky(reruns=5, reruns_delay=1, cause=is_retryable_exception)
     def test_link_error_using_signature_eager(self):
-        fail = signature('t.integration.tasks.fail', args=("test", ))
+        fail = signature('t.integration.tasks.fail', args=("test",))
         retrun_exception = signature('t.integration.tasks.return_exception')

         fail.link_error(retrun_exception)

         exception = ExpectedException("Task expected to fail", "test")
-        assert (fail.apply().get(timeout=TIMEOUT, propagate=False), True) == (exception, True)
+        assert (fail.apply().get(timeout=TIMEOUT, propagate=False), True) == (
+            exception, True)

     @pytest.mark.flaky(reruns=5, reruns_delay=1, cause=is_retryable_exception)
     def test_link_error_using_signature(self):
-        fail = signature('t.integration.tasks.fail', args=("test", ))
+        fail = signature('t.integration.tasks.fail', args=("test",))
         retrun_exception = signature('t.integration.tasks.return_exception')

         fail.link_error(retrun_exception)

         exception = ExpectedException("Task expected to fail", "test")
-        assert (fail.delay().get(timeout=TIMEOUT, propagate=False), True) == (exception, True)
+        assert (fail.delay().get(timeout=TIMEOUT, propagate=False), True) == (
+            exception, True)


 class test_chain:
@@ -97,8 +101,8 @@ def test_single_chain(self, manager):
     def test_complex_chain(self, manager):
         c = (
             add.s(2, 2) | (
-                add.s(4) | add_replaced.s(8) | add.s(16) | add.s(32)
-            ) |
+                    add.s(4) | add_replaced.s(8) | add.s(16) | add.s(32)
+                ) |
             group(add.s(i) for i in range(4))
         )
         res = c()
@@ -210,7 +214,8 @@ def test_second_order_replace(self, manager):
             redis_connection.lrange('redis-echo', 0, -1)
         ))

-        expected_messages = [b'In A', b'In B', b'In/Out C', b'Out B', b'Out A']
+        expected_messages = [b'In A', b'In B', b'In/Out C', b'Out B',
+                             b'Out A']
         assert redis_messages == expected_messages

     @pytest.mark.flaky(reruns=5, reruns_delay=1, cause=is_retryable_exception)
@@ -311,7 +316,8 @@ def test_chain_of_task_a_group_and_a_chord(self, manager):
         assert res.get(timeout=TIMEOUT) == 8

     @pytest.mark.flaky(reruns=5, reruns_delay=1, cause=is_retryable_exception)
-    def test_chain_of_chords_as_groups_chained_to_a_task_with_two_tasks(self, manager):
+    def test_chain_of_chords_as_groups_chained_to_a_task_with_two_tasks(self,
+                                                                        manager):
         try:
             manager.app.backend.ensure_chords_allowed()
         except NotImplementedError as e:
@@ -514,6 +520,31 @@ def assert_ping(manager):


 class test_chord:
+    @pytest.mark.flaky(reruns=5, reruns_delay=1, cause=is_retryable_exception)
+    def test_simple_chord_with_a_delay_in_group_save(self, manager, monkeypatch):
+        try:
+            manager.app.backend.ensure_chords_allowed()
+        except NotImplementedError as e:
+            raise pytest.skip(e.args[0])
+
+        if not isinstance(manager.app.backend, BaseKeyValueStoreBackend):
+            raise pytest.skip("The delay may only occur in key/value backends")
+
+        x = manager.app.backend._apply_chord_incr
+
+        def apply_chord_incr_with_sleep(*args, **kwargs):
+            sleep(1)
+            x(*args, **kwargs)
+
+        monkeypatch.setattr(BaseKeyValueStoreBackend,
+                            '_apply_chord_incr',
+                            apply_chord_incr_with_sleep)
+
+        c = group(add.si(1, 1), add.si(1, 1)) | tsum.s()
+
+        result = c()
+        assert result.get() == 4
+
     @pytest.mark.flaky(reruns=5, reruns_delay=1, cause=is_retryable_exception)
     def test_redis_subscribed_channels_leak(self, manager):
         if not manager.app.conf.result_backend.startswith('redis'):
@@ -541,7 +572,7 @@ def test_redis_subscribed_channels_leak(self, manager):
         # (existing from previous tests).
         chord_header_task_count = 2
         assert channels_before_count <= \
-            chord_header_task_count * total_chords + initial_channels_count
+               chord_header_task_count * total_chords + initial_channels_count

         result_values = [
             result.get(timeout=TIMEOUT)
@@ -911,7 +942,8 @@ def test_chain_to_a_chord_with_large_header(self, manager):
         except NotImplementedError as e:
             raise pytest.skip(e.args[0])

-        c = identity.si(1) | group(identity.s() for _ in range(1000)) | tsum.s()
+        c = identity.si(1) | group(
+            identity.s() for _ in range(1000)) | tsum.s()
         res = c.delay()
         assert res.get(timeout=TIMEOUT) == 1000

@@ -922,5 +954,6 @@ def test_priority(self, manager):

     @pytest.mark.flaky(reruns=5, reruns_delay=1, cause=is_retryable_exception)
     def test_priority_chain(self, manager):
-        c = return_priority.signature(priority=3) | return_priority.signature(priority=5)
+        c = return_priority.signature(priority=3) | return_priority.signature(
+            priority=5)
         assert c().get(timeout=TIMEOUT) == "Priority: 5"
