Skip to content

Commit

Permalink
fix: Count chord "final elements" correctly
Browse files Browse the repository at this point in the history
This change amends the implementation of `chord.__length_hint__()` to
ensure that all child task types are correctly counted. Specifically:

 * all sub-tasks of a group are counted recursively
 * the final task of a chain is counted recursively
 * the body of a chord is counted recursively
 * all other simple signatures count as a single "final element"

There is also a deserialisation step if a `dict` is seen while counting
the final elements in a chord, however this should become less important
with the merge of #6342 which ensures that tasks are recursively
deserialized by `.from_dict()`.
  • Loading branch information
maybe-sybr committed Sep 18, 2020
1 parent 00500f9 commit 71f0f9e
Show file tree
Hide file tree
Showing 3 changed files with 176 additions and 19 deletions.
34 changes: 21 additions & 13 deletions celery/canvas.py
Expand Up @@ -1363,21 +1363,29 @@ def apply(self, args=None, kwargs=None,
args=(tasks.apply(args, kwargs).get(propagate=propagate),),
)

def _traverse_tasks(self, tasks, value=None):
stack = deque(tasks)
while stack:
task = stack.popleft()
if isinstance(task, group):
stack.extend(task.tasks)
elif isinstance(task, _chain) and isinstance(task.tasks[-1], group):
stack.extend(task.tasks[-1].tasks)
else:
yield task if value is None else value
@classmethod
def __descend(cls, sig_obj):
# Sometimes serialized signatures might make their way here
if isinstance(sig_obj, dict):
sig_obj = Signature.from_dict(sig_obj)
if isinstance(sig_obj, group):
# Each task in a group counts toward this chord
subtasks = getattr(sig_obj.tasks, "tasks", sig_obj.tasks)
return sum(cls.__descend(task) for task in subtasks)
elif isinstance(sig_obj, _chain):
# The last element in a chain counts toward this chord
return cls.__descend(sig_obj.tasks[-1])
elif isinstance(sig_obj, chord):
# The child chord's body counts toward this chord
return cls.__descend(sig_obj.body)
elif isinstance(sig_obj, Signature):
# Each simple signature counts as 1 completion for this chord
return 1
# Any other types are assumed to be iterables of simple signatures
return len(sig_obj)

def __length_hint__(self):
tasks = (self.tasks.tasks if isinstance(self.tasks, group)
else self.tasks)
return sum(self._traverse_tasks(tasks, 1))
return sum(self.__descend(task) for task in self.tasks)

def run(self, header, body, partial_args, app=None, interval=None,
countdown=1, max_retries=None, eager=False,
Expand Down
22 changes: 22 additions & 0 deletions t/integration/test_canvas.py
Expand Up @@ -991,3 +991,25 @@ def test_priority_chain(self, manager):
c = return_priority.signature(priority=3) | return_priority.signature(
priority=5)
assert c().get(timeout=TIMEOUT) == "Priority: 5"

def test_nested_chord_group_chain_group_tail(self, manager):
"""
Sanity check that a deeply nested group is completed as expected.
Groups at the end of chains nested in chords have had issues and this
simple test sanity check that such a tsk structure can be completed.
"""
try:
manager.app.backend.ensure_chords_allowed()
except NotImplementedError as e:
raise pytest.skip(e.args[0])

sig = chord(group(chain(
identity.s(42),
group(
identity.s(),
identity.s(),
),
)), identity.s())
res = sig.delay()
assert res.get(timeout=TIMEOUT) == [42, 42]
139 changes: 133 additions & 6 deletions t/unit/tasks/test_canvas.py
Expand Up @@ -743,12 +743,139 @@ def test_app_fallback_to_current(self):
x = chord([t1], body=t1)
assert x.app is current_app

def test_chord_size_with_groups(self):
x = chord([
self.add.s(2, 2) | group([self.add.si(2, 2), self.add.si(2, 2)]),
self.add.s(2, 2) | group([self.add.si(2, 2), self.add.si(2, 2)]),
], body=self.add.si(2, 2))
assert x.__length_hint__() == 4
def test_chord_size_simple(self):
sig = chord(self.add.s())
assert sig.__length_hint__() == 1

def test_chord_size_with_body(self):
sig = chord(self.add.s(), self.add.s())
assert sig.__length_hint__() == 1

def test_chord_size_explicit_group_single(self):
sig = chord(group(self.add.s()))
assert sig.__length_hint__() == 1

def test_chord_size_explicit_group_many(self):
sig = chord(group([self.add.s()] * 42))
assert sig.__length_hint__() == 42

def test_chord_size_implicit_group_single(self):
sig = chord([self.add.s()])
assert sig.__length_hint__() == 1

def test_chord_size_implicit_group_many(self):
sig = chord([self.add.s()] * 42)
assert sig.__length_hint__() == 42

def test_chord_size_chain_single(self):
sig = chord(chain(self.add.s()))
assert sig.__length_hint__() == 1

def test_chord_size_chain_many(self):
# Chains get flattened into the encapsulating chord so even though the
# chain would only count for 1, the tasks we pulled into the chord's
# header and are counted as a bunch of simple signature objects
sig = chord(chain([self.add.s()] * 42))
assert sig.__length_hint__() == 42

def test_chord_size_nested_chain_chain_single(self):
sig = chord(chain(chain(self.add.s())))
assert sig.__length_hint__() == 1

def test_chord_size_nested_chain_chain_many(self):
# The outer chain will be pulled up into the chord but the lower one
# remains and will only count as a single final element
sig = chord(chain(chain([self.add.s()] * 42)))
assert sig.__length_hint__() == 1

def test_chord_size_implicit_chain_single(self):
sig = chord([self.add.s()])
assert sig.__length_hint__() == 1

def test_chord_size_implicit_chain_many(self):
# This isn't a chain object so the `tasks` attribute can't be lifted
# into the chord - this isn't actually valid and would blow up we tried
# to run it but it sanity checks our recursion
sig = chord([[self.add.s()] * 42])
assert sig.__length_hint__() == 1

def test_chord_size_nested_implicit_chain_chain_single(self):
sig = chord([chain(self.add.s())])
assert sig.__length_hint__() == 1

def test_chord_size_nested_implicit_chain_chain_many(self):
sig = chord([chain([self.add.s()] * 42)])
assert sig.__length_hint__() == 1

# Nested groups in a chain only affect the chord size if they are the last
# element in the chain - in that case each group element is counted
def test_chord_size_nested_group_chain_group_head_single(self):
x = chord(
group(
[group(self.add.s()) | self.add.s()] * 42
),
body=self.add.s()
)
assert x.__length_hint__() == 42

def test_chord_size_nested_group_chain_group_head_many(self):
x = chord(
group(
[group([self.add.s()] * 4) | self.add.s()] * 2
),
body=self.add.s()
)
assert x.__length_hint__() == 2

def test_chord_size_nested_group_chain_group_mid_single(self):
x = chord(
group(
[self.add.s() | group(self.add.s()) | self.add.s()] * 42
),
body=self.add.s()
)
assert x.__length_hint__() == 42

def test_chord_size_nested_group_chain_group_mid_many(self):
x = chord(
group(
[self.add.s() | group([self.add.s()] * 4) | self.add.s()] * 2
),
body=self.add.s()
)
assert x.__length_hint__() == 2

def test_chord_size_nested_group_chain_group_tail_single(self):
x = chord(
group(
[self.add.s() | group(self.add.s())] * 42
),
body=self.add.s()
)
assert x.__length_hint__() == 42

def test_chord_size_nested_group_chain_group_tail_many(self):
x = chord(
group(
[self.add.s() | group([self.add.s()] * 4)] * 2
),
body=self.add.s()
)
assert x.__length_hint__() == 4 * 2

def test_chord_size_nested_implicit_group_chain_group_tail_single(self):
x = chord(
[self.add.s() | group(self.add.s())] * 42,
body=self.add.s()
)
assert x.__length_hint__() == 42

def test_chord_size_nested_implicit_group_chain_group_tail_many(self):
x = chord(
[self.add.s() | group([self.add.s()] * 4)] * 2,
body=self.add.s()
)
assert x.__length_hint__() == 4 * 2

def test_set_immutable(self):
x = chord([Mock(name='t1'), Mock(name='t2')], app=self.app)
Expand Down

0 comments on commit 71f0f9e

Please sign in to comment.