Skip to content

Commit

Permalink
fix: Chord counting of group children (#6733)
Browse files Browse the repository at this point in the history
* improv: Deconflict `chord` class and kwarg names

* improv: Make `chord.descend` protected not private

This will allow us to call it from other code in this module which needs
to accurately count chord sizes.

* fix: Counting of chord-chain tails of zero tasks

* fix: Chord counting of group children

This change ensures that we only have one piece of code which calculates
chord sizes (ie. `_chord._descend()`, recently made protected so other
canvas classes can use it as required). By doing so, we fix some edge
cases in the chord counting logic which was being used for children of
groups, and also add some unit tests to capture those cases and their
expected behaviours.

This change also introduces an integration test which checks the current
behaviour of chains used as chord bodies when nested in groups. Due to
some misbehaviour, likely with promise fulfillment, the `GroupResult`
object will time out unless all of its children are resolved prior to
`GroupResult` being joined (specifically, native joins block forever or
until timeout). This misbehaviour is tracked by #6734 and the test is
now marked as `xfail`ing to ensure that the current janky behaviour
continues to work as expected rather than regressing.
  • Loading branch information
maybe-sybr committed Apr 28, 2021
1 parent 230c9ac commit ce8a903
Show file tree
Hide file tree
Showing 3 changed files with 328 additions and 18 deletions.
50 changes: 33 additions & 17 deletions celery/canvas.py
Expand Up @@ -1170,21 +1170,25 @@ def _apply_tasks(self, tasks, producer=None, app=None, p=None,
# we are able to tell when we are at the end by checking if
# next_task is None. This enables us to set the chord size
# without burning through the entire generator. See #3021.
chord_size = 0
for task_index, (current_task, next_task) in enumerate(
lookahead(tasks)
):
# We expect that each task must be part of the same group which
# seems sensible enough. If that's somehow not the case we'll
# end up messing up chord counts and there are all sorts of
# awful race conditions to think about. We'll hope it's not!
sig, res, group_id = current_task
_chord = sig.options.get("chord") or chord
if _chord is not None and next_task is None:
chord_size = task_index + 1
if isinstance(sig, _chain):
if sig.tasks[-1].subtask_type == 'chord':
chord_size = sig.tasks[-1].__length_hint__()
else:
chord_size = task_index + len(sig.tasks[-1])
chord_obj = sig.options.get("chord") or chord
# We need to check the chord size of each contributing task so
# that when we get to the final one, we can correctly set the
# size in the backend and the chord can be sensibly completed.
chord_size += _chord._descend(sig)
if chord_obj is not None and next_task is None:
# Per above, sanity check that we only saw one group
app.backend.set_chord_size(group_id, chord_size)
sig.apply_async(producer=producer, add_to_parent=False,
chord=_chord, args=args, kwargs=kwargs,
chord=chord_obj, args=args, kwargs=kwargs,
**options)
# adding callback to result, such that it will gradually
# fulfill the barrier.
Expand Down Expand Up @@ -1296,8 +1300,8 @@ def app(self):
return app if app is not None else current_app


@Signature.register_type()
class chord(Signature):
@Signature.register_type(name="chord")
class _chord(Signature):
r"""Barrier synchronization primitive.
A chord consists of a header and a body.
Expand Down Expand Up @@ -1415,20 +1419,27 @@ def apply(self, args=None, kwargs=None,
)

@classmethod
def __descend(cls, sig_obj):
def _descend(cls, sig_obj):
# Sometimes serialized signatures might make their way here
if not isinstance(sig_obj, Signature) and isinstance(sig_obj, dict):
sig_obj = Signature.from_dict(sig_obj)
if isinstance(sig_obj, group):
# Each task in a group counts toward this chord
subtasks = getattr(sig_obj.tasks, "tasks", sig_obj.tasks)
return sum(cls.__descend(task) for task in subtasks)
return sum(cls._descend(task) for task in subtasks)
elif isinstance(sig_obj, _chain):
# The last element in a chain counts toward this chord
return cls.__descend(sig_obj.tasks[-1])
# The last non-empty element in a chain counts toward this chord
for child_sig in sig_obj.tasks[-1::-1]:
child_size = cls._descend(child_sig)
if child_size > 0:
return child_size
else:
# We have to just hope this chain is part of some encapsulating
# signature which is valid and can fire the chord body
return 0
elif isinstance(sig_obj, chord):
# The child chord's body counts toward this chord
return cls.__descend(sig_obj.body)
return cls._descend(sig_obj.body)
elif isinstance(sig_obj, Signature):
# Each simple signature counts as 1 completion for this chord
return 1
Expand All @@ -1437,7 +1448,7 @@ def __descend(cls, sig_obj):

def __length_hint__(self):
tasks = getattr(self.tasks, "tasks", self.tasks)
return sum(self.__descend(task) for task in tasks)
return sum(self._descend(task) for task in tasks)

def run(self, header, body, partial_args, app=None, interval=None,
countdown=1, max_retries=None, eager=False,
Expand Down Expand Up @@ -1537,6 +1548,11 @@ def _get_app(self, body=None):
body = getitem_property('kwargs.body', 'Body task of chord.')


# Add a back-compat alias for the previous `chord` class name which conflicts
# with keyword arguments elsewhere in this file. The class is defined as
# `_chord` so that functions taking a `chord=` keyword argument can still
# refer to the class; code using the public `chord` name keeps working.
chord = _chord


def signature(varies, *args, **kwargs):
"""Create new signature.
Expand Down
106 changes: 106 additions & 0 deletions t/integration/test_canvas.py
Expand Up @@ -704,6 +704,112 @@ def test_nested_group_group(self, manager):
res = sig.delay()
assert res.get(timeout=TIMEOUT) == [42, 42]

def test_nested_group_chord_counting_simple(self, manager):
    """A group wrapping a chord with a single-task header resolves."""
    # Skip on result backends which cannot run chords at all
    try:
        manager.app.backend.ensure_chords_allowed()
    except NotImplementedError as exc:
        raise pytest.skip(exc.args[0])

    header_task = identity.si(42)
    inner_chord = chord((header_task, ), identity.s())
    outer_group = group((inner_chord, ))
    pending = outer_group.delay()
    # The chord body receives the list of header results, and the group
    # wraps that in one more list
    assert pending.get(timeout=TIMEOUT) == [[42]]

def test_nested_group_chord_counting_chain(self, manager):
    """A chain used as the chord header contributes a single result."""
    # Skip on result backends which cannot run chords at all
    try:
        manager.app.backend.ensure_chords_allowed()
    except NotImplementedError as exc:
        raise pytest.skip(exc.args[0])

    chain_length = 42
    header_chain = chain((identity.si(1337), ) * chain_length)
    inner_chord = chord((header_chain, ), identity.s())
    outer_group = group((inner_chord, ))
    pending = outer_group.delay()
    # However long the chain is, only one value is expected to reach the
    # chord body
    assert pending.get(timeout=TIMEOUT) == [[1337]]

def test_nested_group_chord_counting_group(self, manager):
    """A group used as the chord header contributes one result per task."""
    # Skip on result backends which cannot run chords at all
    try:
        manager.app.backend.ensure_chords_allowed()
    except NotImplementedError as exc:
        raise pytest.skip(exc.args[0])

    group_width = 42
    header_group = group((identity.si(1337), ) * group_width)
    inner_chord = chord((header_group, ), identity.s())
    outer_group = group((inner_chord, ))
    pending = outer_group.delay()
    # Every member of the header group is expected in the body's input
    assert pending.get(timeout=TIMEOUT) == [[1337] * group_width]

def test_nested_group_chord_counting_chord(self, manager):
    """A chord used as the chord header contributes its body's result."""
    # Skip on result backends which cannot run chords at all
    try:
        manager.app.backend.ensure_chords_allowed()
    except NotImplementedError as exc:
        raise pytest.skip(exc.args[0])

    header_width = 42
    header_chord = chord(
        (identity.si(1337), ) * header_width,
        identity.si(31337),
    )
    inner_chord = chord((header_chord, ), identity.s())
    outer_group = group((inner_chord, ))
    pending = outer_group.delay()
    # Only the nested chord's body value (31337) is expected to be passed
    # on to the enclosing chord's body
    assert pending.get(timeout=TIMEOUT) == [[31337]]

def test_nested_group_chord_counting_mixed(self, manager):
    """A chord header mixing task, chain, group and chord is counted."""
    # Skip on result backends which cannot run chords at all
    try:
        manager.app.backend.ensure_chords_allowed()
    except NotImplementedError as exc:
        raise pytest.skip(exc.args[0])

    fanout = 42
    mixed_header = (
        identity.si(42),
        chain((identity.si(42), ) * fanout),
        group((identity.si(42), ) * fanout),
        chord((identity.si(42), ) * fanout, identity.si(1337)),
    )
    inner_chord = chord(mixed_header, identity.s())
    outer_group = group((inner_chord, ))
    pending = outer_group.delay()
    # Expected body input: one value for the plain task, one for the chain,
    # `fanout` values for the group (its result gets unrolled into the
    # encapsulating chord) and one for the nested chord's body
    expected_row = [42, 42] + [42] * fanout + [1337]
    assert pending.get(timeout=TIMEOUT) == [expected_row]

@pytest.mark.xfail(raises=TimeoutError, reason="#6734")
def test_nested_group_chord_body_chain(self, manager):
    """Pin the current behaviour of a chain chord body nested in a group.

    The first join of the group result is expected to time out (ref #6734,
    possibly an unfulfilled promise). Once the child result has been
    fetched directly, joining the group result is expected to succeed.
    The captured timeout is re-raised at the end so the test XFAILs.
    """
    # Skip on result backends which cannot run chords at all
    try:
        manager.app.backend.ensure_chords_allowed()
    except NotImplementedError as exc:
        raise pytest.skip(exc.args[0])

    inner_chord = chord(identity.si(42), chain((identity.s(), )))
    outer_group = group((inner_chord, ))
    pending = outer_group.delay()
    want = [[42]]
    # Use a short timeout here: this simple signature should run quickly
    # and we do not want to block for long on the expected failure
    with pytest.raises(TimeoutError) as timeout_info:
        pending.get(timeout=TIMEOUT / 10)
    # Resolve the child `AsyncResult` by hand so that we do not have to
    # wait on the `GroupResult` a second time
    first_child = pending.children[0]
    assert first_child.get(timeout=TIMEOUT) == want[0]
    assert pending.get(timeout=TIMEOUT) == want
    # Surface the timeout captured above so pytest reports an XFAIL
    raise timeout_info.value


def assert_ids(r, expected_value, expected_root_id, expected_parent_id):
root_id, parent_id, value = r.get(timeout=TIMEOUT)
Expand Down

0 comments on commit ce8a903

Please sign in to comment.