Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix chord element counting #6354

Merged
merged 6 commits into from Oct 14, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
6 changes: 5 additions & 1 deletion celery/backends/base.py
Expand Up @@ -919,7 +919,11 @@ def on_chord_part_return(self, request, state, result, **kwargs):
ChordError(f'GroupResult {gid} no longer exists'),
)
val = self.incr(key)
size = len(deps)
# Set the chord size to the value defined in the request, or fall back
# to the number of dependencies we can see from the restored result
size = request.chord.get("chord_size")
if size is None:
size = len(deps)
if val > size: # pragma: no cover
logger.warning('Chord counter incremented too many times for %r',
gid)
Expand Down
65 changes: 46 additions & 19 deletions celery/backends/redis.py
Expand Up @@ -13,6 +13,7 @@
from celery._state import task_join_will_block
from celery.canvas import maybe_signature
from celery.exceptions import ChordError, ImproperlyConfigured
from celery.result import GroupResult, allow_join_result
from celery.utils.functional import dictfilter
from celery.utils.log import get_logger
from celery.utils.time import humanize_seconds
Expand Down Expand Up @@ -401,12 +402,14 @@ def _unpack_chord_result(self, tup, decode,
return retval

def apply_chord(self, header_result, body, **kwargs):
# Overrides this to avoid calling GroupResult.save
# pylint: disable=method-hidden
# Note that KeyValueStoreBackend.__init__ sets self.apply_chord
# if the implements_incr attr is set. Redis backend doesn't set
# this flag.
pass
# If any of the child results of this chord are complex (ie. group
# results themselves), we need to save `header_result` to ensure that
# the expected structure is retained when we finish the chord and pass
# the results onward to the body in `on_chord_part_return()`. We don't
# do this in all cases to retain an optimisation in the common case
# where a chord header is comprised of simple result objects.
if any(isinstance(nr, GroupResult) for nr in header_result.results):
header_result.save(backend=self)
maybe-sybr marked this conversation as resolved.
Show resolved Hide resolved

@cached_property
def _chord_zset(self):
Expand Down Expand Up @@ -449,27 +452,51 @@ def on_chord_part_return(self, request, state, result,
callback = maybe_signature(request.chord, app=app)
total = callback['chord_size'] + totaldiff
if readycount == total:
decode, unpack = self.decode, self._unpack_chord_result
with client.pipeline() as pipe:
if self._chord_zset:
pipeline = pipe.zrange(jkey, 0, -1)
else:
pipeline = pipe.lrange(jkey, 0, total)
resl, = pipeline.execute()
try:
callback.delay([unpack(tup, decode) for tup in resl])
header_result = GroupResult.restore(gid)
if header_result is not None:
# If we manage to restore a `GroupResult`, then it must
# have been complex and saved by `apply_chord()` earlier.
#
# Before we can join the `GroupResult`, it needs to be
# manually marked as ready to avoid blocking
header_result.on_ready()
# We'll `join()` it to get the results and ensure they are
# structured as intended rather than the flattened version
# we'd construct without any other information.
join_func = (
header_result.join_native
if header_result.supports_native_join
else header_result.join
)
with allow_join_result():
resl = join_func(timeout=3.0, propagate=True)
else:
# Otherwise simply extract and decode the results we
# stashed along the way, which should be faster for large
# numbers of simple results in the chord header.
decode, unpack = self.decode, self._unpack_chord_result
with client.pipeline() as pipe:
_, _ = pipe \
.delete(jkey) \
.delete(tkey) \
.execute()
if self._chord_zset:
pipeline = pipe.zrange(jkey, 0, -1)
else:
pipeline = pipe.lrange(jkey, 0, total)
resl, = pipeline.execute()
resl = [unpack(tup, decode) for tup in resl]
try:
callback.delay(resl)
except Exception as exc: # pylint: disable=broad-except
logger.exception(
'Chord callback for %r raised: %r', request.group, exc)
return self.chord_error_from_stack(
callback,
ChordError(f'Callback error: {exc!r}'),
)
finally:
with client.pipeline() as pipe:
_, _ = pipe \
.delete(jkey) \
.delete(tkey) \
.execute()
except ChordError as exc:
logger.exception('Chord %r raised: %r', request.group, exc)
return self.chord_error_from_stack(callback, exc)
Expand Down
58 changes: 39 additions & 19 deletions celery/canvas.py
Expand Up @@ -122,6 +122,9 @@ class Signature(dict):

TYPES = {}
_app = _type = None
# The following fields must not be changed during freezing/merging because
# to do so would disrupt completion of parent tasks
_IMMUTABLE_OPTIONS = {"group_id"}

@classmethod
def register_type(cls, name=None):
Expand Down Expand Up @@ -224,14 +227,22 @@ def apply_async(self, args=None, kwargs=None, route_name=None, **options):
def _merge(self, args=None, kwargs=None, options=None, force=False):
args = args if args else ()
kwargs = kwargs if kwargs else {}
options = options if options else {}
if options is not None:
# We build a new options dictionary where values in `options`
# override values in `self.options` except for keys which are
# noted as being immutable (unrelated to signature immutability)
# implying that allowing their value to change would stall tasks
new_options = dict(self.options, **{
k: v for k, v in options.items()
if k not in self._IMMUTABLE_OPTIONS or k not in self.options
})
else:
new_options = self.options
if self.immutable and not force:
return (self.args, self.kwargs,
dict(self.options,
**options) if options else self.options)
return (self.args, self.kwargs, new_options)
return (tuple(args) + tuple(self.args) if args else self.args,
dict(self.kwargs, **kwargs) if kwargs else self.kwargs,
dict(self.options, **options) if options else self.options)
new_options)

def clone(self, args=None, kwargs=None, **opts):
"""Create a copy of this signature.
Expand Down Expand Up @@ -286,7 +297,7 @@ def freeze(self, _id=None, group_id=None, chord=None,
opts['parent_id'] = parent_id
if 'reply_to' not in opts:
opts['reply_to'] = self.app.oid
if group_id:
if group_id and "group_id" not in opts:
opts['group_id'] = group_id
if chord:
opts['chord'] = chord
Expand Down Expand Up @@ -1372,21 +1383,30 @@ def apply(self, args=None, kwargs=None,
args=(tasks.apply(args, kwargs).get(propagate=propagate),),
)

def _traverse_tasks(self, tasks, value=None):
stack = deque(tasks)
while stack:
task = stack.popleft()
if isinstance(task, group):
stack.extend(task.tasks)
elif isinstance(task, _chain) and isinstance(task.tasks[-1], group):
stack.extend(task.tasks[-1].tasks)
else:
yield task if value is None else value
@classmethod
def __descend(cls, sig_obj):
# Sometimes serialized signatures might make their way here
if not isinstance(sig_obj, Signature) and isinstance(sig_obj, dict):
sig_obj = Signature.from_dict(sig_obj)
maybe-sybr marked this conversation as resolved.
Show resolved Hide resolved
if isinstance(sig_obj, group):
# Each task in a group counts toward this chord
subtasks = getattr(sig_obj.tasks, "tasks", sig_obj.tasks)
return sum(cls.__descend(task) for task in subtasks)
elif isinstance(sig_obj, _chain):
# The last element in a chain counts toward this chord
return cls.__descend(sig_obj.tasks[-1])
elif isinstance(sig_obj, chord):
# The child chord's body counts toward this chord
return cls.__descend(sig_obj.body)
elif isinstance(sig_obj, Signature):
# Each simple signature counts as 1 completion for this chord
return 1
# Any other types are assumed to be iterables of simple signatures
return len(sig_obj)
maybe-sybr marked this conversation as resolved.
Show resolved Hide resolved

def __length_hint__(self):
tasks = (self.tasks.tasks if isinstance(self.tasks, group)
else self.tasks)
return sum(self._traverse_tasks(tasks, 1))
tasks = getattr(self.tasks, "tasks", self.tasks)
return sum(self.__descend(task) for task in tasks)

def run(self, header, body, partial_args, app=None, interval=None,
countdown=1, max_retries=None, eager=False,
Expand Down
112 changes: 99 additions & 13 deletions t/integration/test_canvas.py
@@ -1,12 +1,12 @@
import os
import re
from datetime import datetime, timedelta
from time import sleep

import pytest

from celery import chain, chord, group, signature
from celery.backends.base import BaseKeyValueStoreBackend
from celery.exceptions import ChordError, TimeoutError
from celery.exceptions import TimeoutError
from celery.result import AsyncResult, GroupResult, ResultSet

from .conftest import get_active_redis_channels, get_redis_connection
Expand Down Expand Up @@ -423,6 +423,34 @@ def test_nested_chain_group_lone(self, manager):
res = sig.delay()
assert res.get(timeout=TIMEOUT) == [42, 42]

def test_nested_chain_group_mid(self, manager):
    """
    Test that a mid-point group in a chain completes.

    Skips when the backend does not support chords — presumably because
    the mid-chain group is promoted to a chord internally (TODO confirm).
    """
    try:
        manager.app.backend.ensure_chords_allowed()
    except NotImplementedError as e:
        raise pytest.skip(e.args[0])

    sig = chain(
        identity.s(42),  # 42
        group(identity.s(), identity.s()),  # [42, 42]
        identity.s(),  # [42, 42]
    )
    res = sig.delay()
    # The task following the group receives the group's results as a list
    assert res.get(timeout=TIMEOUT) == [42, 42]

def test_nested_chain_group_last(self, manager):
    """
    Test that a final group in a chain with preceding tasks completes.
    """
    # NOTE(review): unlike test_nested_chain_group_mid there is no
    # ensure_chords_allowed() guard here — confirm a trailing group
    # really completes on backends without chord support.
    sig = chain(
        identity.s(42),  # 42
        group(identity.s(), identity.s()),  # [42, 42]
    )
    res = sig.delay()
    assert res.get(timeout=TIMEOUT) == [42, 42]


class test_result_set:

Expand Down Expand Up @@ -522,6 +550,16 @@ def test_group_lone(self, manager):
res = sig.delay()
assert res.get(timeout=TIMEOUT) == [42, 42]

def test_nested_group_group(self, manager):
    """
    Confirm that groups nested inside groups get unrolled.
    """
    sig = group(
        group(identity.s(42), identity.s(42)),  # [42, 42]
    )  # [42, 42] due to unrolling
    res = sig.delay()
    # A single flat list — the inner group is flattened into the outer one
    assert res.get(timeout=TIMEOUT) == [42, 42]


def assert_ids(r, expected_value, expected_root_id, expected_parent_id):
root_id, parent_id, value = r.get(timeout=TIMEOUT)
Expand Down Expand Up @@ -653,10 +691,12 @@ def test_eager_chord_inside_task(self, manager):

chord_add.app.conf.task_always_eager = prev

@flaky
def test_group_chain(self, manager):
if not manager.app.conf.result_backend.startswith('redis'):
raise pytest.skip('Requires redis result backend.')
try:
manager.app.backend.ensure_chords_allowed()
except NotImplementedError as e:
raise pytest.skip(e.args[0])

c = (
add.s(2, 2) |
group(add.s(i) for i in range(4)) |
Expand All @@ -665,11 +705,6 @@ def test_group_chain(self, manager):
res = c()
assert res.get(timeout=TIMEOUT) == [12, 13, 14, 15]

@flaky
@pytest.mark.xfail(os.environ['TEST_BACKEND'] == 'cache+pylibmc://',
reason="Not supported yet by the cache backend.",
strict=True,
raises=ChordError)
def test_nested_group_chain(self, manager):
try:
manager.app.backend.ensure_chords_allowed()
Expand Down Expand Up @@ -858,9 +893,15 @@ def test_chord_on_error(self, manager):
# So for clarity of our test, we instead do it here.

# Use the error callback's result to find the failed task.
error_callback_result = AsyncResult(
res.children[0].children[0].result[0])
failed_task_id = error_callback_result.result.args[0].split()[3]
uuid_patt = re.compile(
r"[0-9A-Fa-f]{8}-([0-9A-Fa-f]{4}-){3}[0-9A-Fa-f]{12}"
)
callback_chord_exc = AsyncResult(
res.children[0].children[0].result[0]
).result
failed_task_id = uuid_patt.search(str(callback_chord_exc))
assert (failed_task_id is not None), "No task ID in %r" % callback_exc
failed_task_id = failed_task_id.group()

# Use new group_id result metadata to get group ID.
failed_task_result = AsyncResult(failed_task_id)
Expand Down Expand Up @@ -1009,3 +1050,48 @@ def test_priority_chain(self, manager):
c = return_priority.signature(priority=3) | return_priority.signature(
priority=5)
assert c().get(timeout=TIMEOUT) == "Priority: 5"

def test_nested_chord_group(self, manager):
    """
    Confirm that groups nested inside chords get unrolled.
    """
    # Chord support in the backend is required by definition here
    try:
        manager.app.backend.ensure_chords_allowed()
    except NotImplementedError as e:
        raise pytest.skip(e.args[0])

    sig = chord(
        (
            group(identity.s(42), identity.s(42)),  # [42, 42]
        ),
        identity.s()  # [42, 42]
    )
    res = sig.delay()
    # The body sees the unrolled group's results as one flat list
    assert res.get(timeout=TIMEOUT) == [42, 42]

def test_nested_chord_group_chain_group_tail(self, manager):
    """
    Sanity check that a deeply nested group is completed as expected.

    Groups at the end of chains nested in chords have had issues and this
    simple test sanity-checks that such a task structure can be completed.
    """
    try:
        manager.app.backend.ensure_chords_allowed()
    except NotImplementedError as e:
        raise pytest.skip(e.args[0])

    sig = chord(
        group(
            chain(
                identity.s(42),  # 42
                group(
                    identity.s(),  # 42
                    identity.s(),  # 42
                ),  # [42, 42]
            ),  # [42, 42]
        ),  # [[42, 42]] since the chain prevents unrolling
        identity.s(),  # [[42, 42]]
    )
    res = sig.delay()
    # Nested structure is preserved: one chain result, itself a list
    assert res.get(timeout=TIMEOUT) == [[42, 42]]