fix: Retain chord header result structure in Redis
This change fixes the chord result flattening issue which manifested
when using the Redis backend, caused by that backend deliberately
discarding information about the header result structure. Rather than
assuming that all results which contribute to the finalisation of a
chord are siblings, this change checks whether any of them are complex
(i.e. `GroupResult`s) and, if so, falls back to behaviour similar to
that of the `KeyValueStoreBackend`, which restores the original
`GroupResult` object and `join()`s it.
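
As an illustration of the behaviour being fixed, here is a minimal
sketch of the kind of workflow affected. It assumes a configured Celery
app named `app` using the Redis result backend; the task names and the
exact result values are illustrative only.

from celery import chord, group

@app.task
def add(x, y):
    return x + y

@app.task
def collect(results):
    # With the header structure retained, a nested group in the chord
    # header arrives as a nested list, e.g. something like [3, [7, 11]]
    # rather than the flattened [3, 7, 11] the Redis backend produced
    # before this change.
    return results

workflow = chord(
    # Header with one simple child and one complex (group) child.
    [add.s(1, 2), group(add.s(3, 4), add.s(5, 6))],
    collect.s(),
)
result = workflow.delay()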

We retain the original behaviour, which is billed as an optimisation
in f09b041. We could do better in the complex header result case by not
bothering to stash the results of contributing tasks under the `.j`
zset, since we won't be using them, but without checking for the
presence of the complex group result on every `on_chord_part_return()`
call, we can't be sure that we won't need those stashed results later
on. This would be an opportunity for future optimisation if we were to
use an `EVAL` to only perform the `zadd()` when the group result key
doesn't exist. However, avoiding the result encoding work in
`on_chord_part_return()` would be more complicated. For now, it's not
worth the brainpower.
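
As a rough sketch of that `EVAL` idea (not part of this change; the key
arguments, score handling, and the bare redis-py `client` below are
illustrative assumptions only):

# Only stash the encoded result in the `.j` zset when no saved
# `GroupResult` exists for the group. Because EVAL runs atomically on
# the Redis server, there is no race between the existence check and
# the ZADD.
CONDITIONAL_STASH = """
if redis.call('EXISTS', KEYS[1]) == 1 then
    return 0
end
return redis.call('ZADD', KEYS[2], ARGV[1], ARGV[2])
"""

def maybe_stash_result(client, group_key, jkey, score, encoded_result):
    # `client` is a redis-py client; 2 is the number of KEYS arguments.
    return client.eval(
        CONDITIONAL_STASH, 2, group_key, jkey, score, encoded_result,
    )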

This change also slightly refactors the Redis backend unit tests to
make it easier to build fixtures and to exercise both the complex and
simple result structure cases.
maybe-sybr committed Oct 14, 2020
1 parent f8ab428 commit 8422937
Showing 3 changed files with 194 additions and 77 deletions.
65 changes: 46 additions & 19 deletions celery/backends/redis.py
@@ -13,6 +13,7 @@
from celery._state import task_join_will_block
from celery.canvas import maybe_signature
from celery.exceptions import ChordError, ImproperlyConfigured
from celery.result import GroupResult, allow_join_result
from celery.utils.functional import dictfilter
from celery.utils.log import get_logger
from celery.utils.time import humanize_seconds
@@ -401,12 +402,14 @@ def _unpack_chord_result(self, tup, decode,
return retval

def apply_chord(self, header_result, body, **kwargs):
# Overrides this to avoid calling GroupResult.save
# pylint: disable=method-hidden
# Note that KeyValueStoreBackend.__init__ sets self.apply_chord
# if the implements_incr attr is set. Redis backend doesn't set
# this flag.
pass
# If any of the child results of this chord are complex (ie. group
# results themselves), we need to save `header_result` to ensure that
# the expected structure is retained when we finish the chord and pass
# the results onward to the body in `on_chord_part_return()`. We don't
# do this in all cases to retain an optimisation in the common case
# where a chord header is comprised of simple result objects.
if any(isinstance(nr, GroupResult) for nr in header_result.results):
header_result.save(backend=self)

@cached_property
def _chord_zset(self):
@@ -449,27 +452,51 @@ def on_chord_part_return(self, request, state, result,
callback = maybe_signature(request.chord, app=app)
total = callback['chord_size'] + totaldiff
if readycount == total:
decode, unpack = self.decode, self._unpack_chord_result
with client.pipeline() as pipe:
if self._chord_zset:
pipeline = pipe.zrange(jkey, 0, -1)
else:
pipeline = pipe.lrange(jkey, 0, total)
resl, = pipeline.execute()
try:
callback.delay([unpack(tup, decode) for tup in resl])
header_result = GroupResult.restore(gid)
if header_result is not None:
# If we manage to restore a `GroupResult`, then it must
# have been complex and saved by `apply_chord()` earlier.
#
# Before we can join the `GroupResult`, it needs to be
# manually marked as ready to avoid blocking
header_result.on_ready()
# We'll `join()` it to get the results and ensure they are
# structured as intended rather than the flattened version
# we'd construct without any other information.
join_func = (
header_result.join_native
if header_result.supports_native_join
else header_result.join
)
with allow_join_result():
resl = join_func(timeout=3.0, propagate=True)
else:
# Otherwise simply extract and decode the results we
# stashed along the way, which should be faster for large
# numbers of simple results in the chord header.
decode, unpack = self.decode, self._unpack_chord_result
with client.pipeline() as pipe:
_, _ = pipe \
.delete(jkey) \
.delete(tkey) \
.execute()
if self._chord_zset:
pipeline = pipe.zrange(jkey, 0, -1)
else:
pipeline = pipe.lrange(jkey, 0, total)
resl, = pipeline.execute()
resl = [unpack(tup, decode) for tup in resl]
try:
callback.delay(resl)
except Exception as exc: # pylint: disable=broad-except
logger.exception(
'Chord callback for %r raised: %r', request.group, exc)
return self.chord_error_from_stack(
callback,
ChordError(f'Callback error: {exc!r}'),
)
finally:
with client.pipeline() as pipe:
_, _ = pipe \
.delete(jkey) \
.delete(tkey) \
.execute()
except ChordError as exc:
logger.exception('Chord %r raised: %r', request.group, exc)
return self.chord_error_from_stack(callback, exc)
13 changes: 10 additions & 3 deletions t/integration/test_canvas.py
@@ -1,3 +1,4 @@
import re
from datetime import datetime, timedelta
from time import sleep

@@ -892,9 +893,15 @@ def test_chord_on_error(self, manager):
# So for clarity of our test, we instead do it here.

# Use the error callback's result to find the failed task.
error_callback_result = AsyncResult(
res.children[0].children[0].result[0])
failed_task_id = error_callback_result.result.args[0].split()[3]
uuid_patt = re.compile(
r"[0-9A-Fa-f]{8}-([0-9A-Fa-f]{4}-){3}[0-9A-Fa-f]{12}"
)
callback_chord_exc = AsyncResult(
res.children[0].children[0].result[0]
).result
failed_task_id = uuid_patt.search(str(callback_chord_exc))
assert (failed_task_id is not None), "No task ID in %r" % callback_chord_exc
failed_task_id = failed_task_id.group()

# Use new group_id result metadata to get group ID.
failed_task_result = AsyncResult(failed_task_id)
193 changes: 138 additions & 55 deletions t/unit/backends/test_redis.py
@@ -1,3 +1,4 @@
import itertools
import json
import random
import ssl
@@ -274,7 +275,7 @@ def test_drain_events_connection_error(self, parent_on_state_change, cancel_for)
assert consumer._pubsub._subscribed_to == {b'celery-task-meta-initial'}


class test_RedisBackend:
class basetest_RedisBackend:
def get_backend(self):
from celery.backends.redis import RedisBackend

@@ -287,11 +288,42 @@ def get_E_LOST(self):
from celery.backends.redis import E_LOST
return E_LOST

def create_task(self, i, group_id="group_id"):
tid = uuid()
task = Mock(name=f'task-{tid}')
task.name = 'foobarbaz'
self.app.tasks['foobarbaz'] = task
task.request.chord = signature(task)
task.request.id = tid
task.request.chord['chord_size'] = 10
task.request.group = group_id
task.request.group_index = i
return task

@contextmanager
def chord_context(self, size=1):
with patch('celery.backends.redis.maybe_signature') as ms:
request = Mock(name='request')
request.id = 'id1'
request.group = 'gid1'
request.group_index = None
tasks = [
self.create_task(i, group_id=request.group)
for i in range(size)
]
callback = ms.return_value = Signature('add')
callback.id = 'id1'
callback['chord_size'] = size
callback.delay = Mock(name='callback.delay')
yield tasks, request, callback

def setup(self):
self.Backend = self.get_backend()
self.E_LOST = self.get_E_LOST()
self.b = self.Backend(app=self.app)


class test_RedisBackend(basetest_RedisBackend):
@pytest.mark.usefixtures('depends_on_current_app')
def test_reduce(self):
pytest.importorskip('redis')
@@ -623,20 +655,36 @@ def test_set_no_expire(self):
self.b.expires = None
self.b._set_with_state('foo', 'bar', states.SUCCESS)

def create_task(self, i):
def test_process_cleanup(self):
self.b.process_cleanup()

def test_get_set_forget(self):
tid = uuid()
task = Mock(name=f'task-{tid}')
task.name = 'foobarbaz'
self.app.tasks['foobarbaz'] = task
task.request.chord = signature(task)
task.request.id = tid
task.request.chord['chord_size'] = 10
task.request.group = 'group_id'
task.request.group_index = i
return task
self.b.store_result(tid, 42, states.SUCCESS)
assert self.b.get_state(tid) == states.SUCCESS
assert self.b.get_result(tid) == 42
self.b.forget(tid)
assert self.b.get_state(tid) == states.PENDING

@patch('celery.result.GroupResult.restore')
def test_on_chord_part_return(self, restore):
def test_set_expires(self):
self.b = self.Backend(expires=512, app=self.app)
tid = uuid()
key = self.b.get_key_for_task(tid)
self.b.store_result(tid, 42, states.SUCCESS)
self.b.client.expire.assert_called_with(
key, 512,
)


class test_RedisBackend_chords_simple(basetest_RedisBackend):
@pytest.fixture(scope="class", autouse=True)
def simple_header_result(self):
with patch(
"celery.result.GroupResult.restore", return_value=None,
) as p:
yield p

def test_on_chord_part_return(self):
tasks = [self.create_task(i) for i in range(10)]
random.shuffle(tasks)

@@ -652,8 +700,7 @@ def test_on_chord_part_return(self, restore):
call(jkey, 86400), call(tkey, 86400),
])

@patch('celery.result.GroupResult.restore')
def test_on_chord_part_return__unordered(self, restore):
def test_on_chord_part_return__unordered(self):
self.app.conf.result_backend_transport_options = dict(
result_chord_ordered=False,
)
@@ -673,8 +720,7 @@ def test_on_chord_part_return__unordered(self, restore):
call(jkey, 86400), call(tkey, 86400),
])

@patch('celery.result.GroupResult.restore')
def test_on_chord_part_return__ordered(self, restore):
def test_on_chord_part_return__ordered(self):
self.app.conf.result_backend_transport_options = dict(
result_chord_ordered=True,
)
@@ -694,8 +740,7 @@ def test_on_chord_part_return__ordered(self, restore):
call(jkey, 86400), call(tkey, 86400),
])

@patch('celery.result.GroupResult.restore')
def test_on_chord_part_return_no_expiry(self, restore):
def test_on_chord_part_return_no_expiry(self):
old_expires = self.b.expires
self.b.expires = None
tasks = [self.create_task(i) for i in range(10)]
@@ -712,8 +757,7 @@ def test_on_chord_part_return_no_expiry(self, restore):

self.b.expires = old_expires

@patch('celery.result.GroupResult.restore')
def test_on_chord_part_return_expire_set_to_zero(self, restore):
def test_on_chord_part_return_expire_set_to_zero(self):
old_expires = self.b.expires
self.b.expires = 0
tasks = [self.create_task(i) for i in range(10)]
@@ -730,8 +774,7 @@ def test_on_chord_part_return_expire_set_to_zero(self, restore):

self.b.expires = old_expires

@patch('celery.result.GroupResult.restore')
def test_on_chord_part_return_no_expiry__unordered(self, restore):
def test_on_chord_part_return_no_expiry__unordered(self):
self.app.conf.result_backend_transport_options = dict(
result_chord_ordered=False,
)
@@ -752,8 +795,7 @@ def test_on_chord_part_return_no_expiry__unordered(self, restore):

self.b.expires = old_expires

@patch('celery.result.GroupResult.restore')
def test_on_chord_part_return_no_expiry__ordered(self, restore):
def test_on_chord_part_return_no_expiry__ordered(self):
self.app.conf.result_backend_transport_options = dict(
result_chord_ordered=True,
)
@@ -926,39 +968,80 @@ def test_on_chord_part_return__other_error__ordered(self):
callback.id, exc=ANY,
)

@contextmanager
def chord_context(self, size=1):
with patch('celery.backends.redis.maybe_signature') as ms:
tasks = [self.create_task(i) for i in range(size)]
request = Mock(name='request')
request.id = 'id1'
request.group = 'gid1'
request.group_index = None
callback = ms.return_value = Signature('add')
callback.id = 'id1'
callback['chord_size'] = size
callback.delay = Mock(name='callback.delay')
yield tasks, request, callback

def test_process_cleanup(self):
self.b.process_cleanup()
class test_RedisBackend_chords_complex(basetest_RedisBackend):
@pytest.fixture(scope="function", autouse=True)
def complex_header_result(self):
with patch("celery.result.GroupResult.restore") as p:
yield p

def test_apply_chord_complex_header(self):
mock_header_result = Mock()
# No results in the header at all - won't call `save()`
mock_header_result.results = tuple()
self.b.apply_chord(mock_header_result, None)
mock_header_result.save.assert_not_called()
mock_header_result.save.reset_mock()
# A single simple result in the header - won't call `save()`
mock_header_result.results = (self.app.AsyncResult("foo"), )
self.b.apply_chord(mock_header_result, None)
mock_header_result.save.assert_not_called()
mock_header_result.save.reset_mock()
# Many simple results in the header - won't call `save()`
mock_header_result.results = (self.app.AsyncResult("foo"), ) * 42
self.b.apply_chord(mock_header_result, None)
mock_header_result.save.assert_not_called()
mock_header_result.save.reset_mock()
# A single complex result in the header - will call `save()`
mock_header_result.results = (self.app.GroupResult("foo"), )
self.b.apply_chord(mock_header_result, None)
mock_header_result.save.assert_called_once_with(backend=self.b)
mock_header_result.save.reset_mock()
# Many complex results in the header - will call `save()`
mock_header_result.results = (self.app.GroupResult("foo"), ) * 42
self.b.apply_chord(mock_header_result, None)
mock_header_result.save.assert_called_once_with(backend=self.b)
mock_header_result.save.reset_mock()
# Mixed simple and complex results in the header - will call `save()`
mock_header_result.results = itertools.islice(
itertools.cycle((
self.app.AsyncResult("foo"), self.app.GroupResult("foo"),
)), 42,
)
self.b.apply_chord(mock_header_result, None)
mock_header_result.save.assert_called_once_with(backend=self.b)
mock_header_result.save.reset_mock()

def test_get_set_forget(self):
tid = uuid()
self.b.store_result(tid, 42, states.SUCCESS)
assert self.b.get_state(tid) == states.SUCCESS
assert self.b.get_result(tid) == 42
self.b.forget(tid)
assert self.b.get_state(tid) == states.PENDING
@pytest.mark.parametrize("supports_native_join", (True, False))
def test_on_chord_part_return(
self, complex_header_result, supports_native_join,
):
mock_result_obj = complex_header_result.return_value
mock_result_obj.supports_native_join = supports_native_join

def test_set_expires(self):
self.b = self.Backend(expires=512, app=self.app)
tid = uuid()
key = self.b.get_key_for_task(tid)
self.b.store_result(tid, 42, states.SUCCESS)
self.b.client.expire.assert_called_with(
key, 512,
)
tasks = [self.create_task(i) for i in range(10)]
random.shuffle(tasks)

with self.chord_context(10) as (tasks, request, callback):
for task, result_val in zip(tasks, itertools.cycle((42, ))):
self.b.on_chord_part_return(
task.request, states.SUCCESS, result_val,
)
# Confirm that `zadd` was called even though we won't end up
# using the data pushed into the sorted set
assert self.b.client.zadd.call_count == 1
self.b.client.zadd.reset_mock()
# Confirm that neither `zrange` nor `lrange` were called
self.b.client.zrange.assert_not_called()
self.b.client.lrange.assert_not_called()
# Confirm that the `GroupResult.restore` mock was called
complex_header_result.assert_called_once_with(request.group)
# Confirm that the callback was called with the `join()`ed group result
if supports_native_join:
expected_join = mock_result_obj.join_native
else:
expected_join = mock_result_obj.join
callback.delay.assert_called_once_with(expected_join())


class test_SentinelBackend:
