From c06bc23b6094c3d2b9b8a24e98eecc78d8583ee3 Mon Sep 17 00:00:00 2001 From: maybe-sybr <58414429+maybe-sybr@users.noreply.github.com> Date: Fri, 18 Sep 2020 14:38:46 +1000 Subject: [PATCH] fix: Count chord "final elements" correctly This change amends the implementation of `chord.__length_hint__()` to ensure that all child task types are correctly counted. Specifically: * all sub-tasks of a group are counted recursively * the final task of a chain is counted recursively * the body of a chord is counted recursively * all other simple signatures count as a single "final element" There is also a deserialisation step if a `dict` is seen while counting the final elements in a chord, however this should become less important with the merge of #6342 which ensures that tasks are recursively deserialized by `.from_dict()`. --- celery/canvas.py | 35 ++++--- t/integration/test_canvas.py | 22 +++++ t/unit/tasks/test_canvas.py | 179 +++++++++++++++++++++++++++++++++-- 3 files changed, 217 insertions(+), 19 deletions(-) diff --git a/celery/canvas.py b/celery/canvas.py index 46bbf3ef82..d81013b7d2 100644 --- a/celery/canvas.py +++ b/celery/canvas.py @@ -1400,21 +1400,30 @@ def apply(self, args=None, kwargs=None, args=(tasks.apply(args, kwargs).get(propagate=propagate),), ) - def _traverse_tasks(self, tasks, value=None): - stack = deque(tasks) - while stack: - task = stack.popleft() - if isinstance(task, group): - stack.extend(task.tasks) - elif isinstance(task, _chain) and isinstance(task.tasks[-1], group): - stack.extend(task.tasks[-1].tasks) - else: - yield task if value is None else value + @classmethod + def __descend(cls, sig_obj): + # Sometimes serialized signatures might make their way here + if not isinstance(sig_obj, Signature) and isinstance(sig_obj, dict): + sig_obj = Signature.from_dict(sig_obj) + if isinstance(sig_obj, group): + # Each task in a group counts toward this chord + subtasks = getattr(sig_obj.tasks, "tasks", sig_obj.tasks) + return sum(cls.__descend(task) for task in subtasks) + elif isinstance(sig_obj, _chain): + # The last element in a chain counts toward this chord + return cls.__descend(sig_obj.tasks[-1]) + elif isinstance(sig_obj, chord): + # The child chord's body counts toward this chord + return cls.__descend(sig_obj.body) + elif isinstance(sig_obj, Signature): + # Each simple signature counts as 1 completion for this chord + return 1 + # Any other types are assumed to be iterables of simple signatures + return len(sig_obj) def __length_hint__(self): - tasks = (self.tasks.tasks if isinstance(self.tasks, group) - else self.tasks) - return sum(self._traverse_tasks(tasks, 1)) + tasks = getattr(self.tasks, "tasks", self.tasks) + return sum(self.__descend(task) for task in tasks) def run(self, header, body, partial_args, app=None, interval=None, countdown=1, max_retries=None, eager=False, diff --git a/t/integration/test_canvas.py b/t/integration/test_canvas.py index 3d9b710320..a5f9d3fddf 100644 --- a/t/integration/test_canvas.py +++ b/t/integration/test_canvas.py @@ -1021,3 +1021,25 @@ def test_priority_chain(self, manager): c = return_priority.signature(priority=3) | return_priority.signature( priority=5) assert c().get(timeout=TIMEOUT) == "Priority: 5" + + def test_nested_chord_group_chain_group_tail(self, manager): + """ + Sanity check that a deeply nested group is completed as expected. + + Groups at the end of chains nested in chords have had issues and this + simple test sanity check that such a tsk structure can be completed. + """ + try: + manager.app.backend.ensure_chords_allowed() + except NotImplementedError as e: + raise pytest.skip(e.args[0]) + + sig = chord(group(chain( + identity.s(42), # -> 42 + group( + identity.s(), # -> 42 + identity.s(), # -> 42 + ), # [42, 42] + )), identity.s()) # [[42, 42]] + res = sig.delay() + assert res.get(timeout=TIMEOUT) == [[42, 42]] diff --git a/t/unit/tasks/test_canvas.py b/t/unit/tasks/test_canvas.py index 113881efcd..a4e7b11ac6 100644 --- a/t/unit/tasks/test_canvas.py +++ b/t/unit/tasks/test_canvas.py @@ -810,12 +810,179 @@ def test_app_fallback_to_current(self): x = chord([t1], body=t1) assert x.app is current_app - def test_chord_size_with_groups(self): - x = chord([ - self.add.s(2, 2) | group([self.add.si(2, 2), self.add.si(2, 2)]), - self.add.s(2, 2) | group([self.add.si(2, 2), self.add.si(2, 2)]), - ], body=self.add.si(2, 2)) - assert x.__length_hint__() == 4 + def test_chord_size_simple(self): + sig = chord(self.add.s()) + assert sig.__length_hint__() == 1 + + def test_chord_size_with_body(self): + sig = chord(self.add.s(), self.add.s()) + assert sig.__length_hint__() == 1 + + def test_chord_size_explicit_group_single(self): + sig = chord(group(self.add.s())) + assert sig.__length_hint__() == 1 + + def test_chord_size_explicit_group_many(self): + sig = chord(group([self.add.s()] * 42)) + assert sig.__length_hint__() == 42 + + def test_chord_size_implicit_group_single(self): + sig = chord([self.add.s()]) + assert sig.__length_hint__() == 1 + + def test_chord_size_implicit_group_many(self): + sig = chord([self.add.s()] * 42) + assert sig.__length_hint__() == 42 + + def test_chord_size_chain_single(self): + sig = chord(chain(self.add.s())) + assert sig.__length_hint__() == 1 + + def test_chord_size_chain_many(self): + # Chains get flattened into the encapsulating chord so even though the + # chain would only count for 1, the tasks we pulled into the chord's + # header and are counted as a bunch of simple signature objects + sig = chord(chain([self.add.s()] * 42)) + assert sig.__length_hint__() == 42 + + def test_chord_size_nested_chain_chain_single(self): + sig = chord(chain(chain(self.add.s()))) + assert sig.__length_hint__() == 1 + + def test_chord_size_nested_chain_chain_many(self): + # The outer chain will be pulled up into the chord but the lower one + # remains and will only count as a single final element + sig = chord(chain(chain([self.add.s()] * 42))) + assert sig.__length_hint__() == 1 + + def test_chord_size_implicit_chain_single(self): + sig = chord([self.add.s()]) + assert sig.__length_hint__() == 1 + + def test_chord_size_implicit_chain_many(self): + # This isn't a chain object so the `tasks` attribute can't be lifted + # into the chord - this isn't actually valid and would blow up we tried + # to run it but it sanity checks our recursion + sig = chord([[self.add.s()] * 42]) + assert sig.__length_hint__() == 1 + + def test_chord_size_nested_implicit_chain_chain_single(self): + sig = chord([chain(self.add.s())]) + assert sig.__length_hint__() == 1 + + def test_chord_size_nested_implicit_chain_chain_many(self): + sig = chord([chain([self.add.s()] * 42)]) + assert sig.__length_hint__() == 1 + + def test_chord_size_nested_chord_body_simple(self): + sig = chord(chord(tuple(), self.add.s())) + assert sig.__length_hint__() == 1 + + def test_chord_size_nested_chord_body_implicit_group_single(self): + sig = chord(chord(tuple(), [self.add.s()])) + assert sig.__length_hint__() == 1 + + def test_chord_size_nested_chord_body_implicit_group_many(self): + sig = chord(chord(tuple(), [self.add.s()] * 42)) + assert sig.__length_hint__() == 42 + + # Nested groups in a chain only affect the chord size if they are the last + # element in the chain - in that case each group element is counted + def test_chord_size_nested_group_chain_group_head_single(self): + x = chord( + group( + [group(self.add.s()) | self.add.s()] * 42 + ), + body=self.add.s() + ) + assert x.__length_hint__() == 42 + + def test_chord_size_nested_group_chain_group_head_many(self): + x = chord( + group( + [group([self.add.s()] * 4) | self.add.s()] * 2 + ), + body=self.add.s() + ) + assert x.__length_hint__() == 2 + + def test_chord_size_nested_group_chain_group_mid_single(self): + x = chord( + group( + [self.add.s() | group(self.add.s()) | self.add.s()] * 42 + ), + body=self.add.s() + ) + assert x.__length_hint__() == 42 + + def test_chord_size_nested_group_chain_group_mid_many(self): + x = chord( + group( + [self.add.s() | group([self.add.s()] * 4) | self.add.s()] * 2 + ), + body=self.add.s() + ) + assert x.__length_hint__() == 2 + + def test_chord_size_nested_group_chain_group_tail_single(self): + x = chord( + group( + [self.add.s() | group(self.add.s())] * 42 + ), + body=self.add.s() + ) + assert x.__length_hint__() == 42 + + def test_chord_size_nested_group_chain_group_tail_many(self): + x = chord( + group( + [self.add.s() | group([self.add.s()] * 4)] * 2 + ), + body=self.add.s() + ) + assert x.__length_hint__() == 4 * 2 + + def test_chord_size_nested_implicit_group_chain_group_tail_single(self): + x = chord( + [self.add.s() | group(self.add.s())] * 42, + body=self.add.s() + ) + assert x.__length_hint__() == 42 + + def test_chord_size_nested_implicit_group_chain_group_tail_many(self): + x = chord( + [self.add.s() | group([self.add.s()] * 4)] * 2, + body=self.add.s() + ) + assert x.__length_hint__() == 4 * 2 + + def test_chord_size_deserialized_element_single(self): + child_sig = self.add.s() + deserialized_child_sig = json.loads(json.dumps(child_sig)) + # We have to break in to be sure that a child remains as a `dict` so we + # can confirm that the length hint will instantiate a `Signature` + # object and then descend as expected + chord_sig = chord(tuple()) + chord_sig.tasks = [deserialized_child_sig] + with patch( + "celery.canvas.Signature.from_dict", return_value=child_sig + ) as mock_from_dict: + assert chord_sig. __length_hint__() == 1 + mock_from_dict.assert_called_once_with(deserialized_child_sig) + + def test_chord_size_deserialized_element_many(self): + child_sig = self.add.s() + deserialized_child_sig = json.loads(json.dumps(child_sig)) + # We have to break in to be sure that a child remains as a `dict` so we + # can confirm that the length hint will instantiate a `Signature` + # object and then descend as expected + chord_sig = chord(tuple()) + chord_sig.tasks = [deserialized_child_sig] * 42 + with patch( + "celery.canvas.Signature.from_dict", return_value=child_sig + ) as mock_from_dict: + assert chord_sig. __length_hint__() == 42 + mock_from_dict.assert_has_calls([call(deserialized_child_sig)] * 42) def test_set_immutable(self): x = chord([Mock(name='t1'), Mock(name='t2')], app=self.app)