diff --git a/celery/backends/redis.py b/celery/backends/redis.py index 2c42882353..dd3677f569 100644 --- a/celery/backends/redis.py +++ b/celery/backends/redis.py @@ -13,6 +13,7 @@ from celery._state import task_join_will_block from celery.canvas import maybe_signature from celery.exceptions import ChordError, ImproperlyConfigured +from celery.result import GroupResult, allow_join_result from celery.utils.functional import dictfilter from celery.utils.log import get_logger from celery.utils.time import humanize_seconds @@ -401,12 +402,14 @@ def _unpack_chord_result(self, tup, decode, return retval def apply_chord(self, header_result, body, **kwargs): - # Overrides this to avoid calling GroupResult.save - # pylint: disable=method-hidden - # Note that KeyValueStoreBackend.__init__ sets self.apply_chord - # if the implements_incr attr is set. Redis backend doesn't set - # this flag. - pass + # If any of the child results of this chord are complex (ie. group + # results themselves), we need to save `header_result` to ensure that + # the expected structure is retained when we finish the chord and pass + # the results onward to the body in `on_chord_part_return()`. We don't + # do this is all cases to retain an optimisation in the common case + # where a chord header is comprised of simple result objects. + if any(isinstance(nr, GroupResult) for nr in header_result.results): + header_result.save(backend=self) @cached_property def _chord_zset(self): @@ -449,20 +452,38 @@ def on_chord_part_return(self, request, state, result, callback = maybe_signature(request.chord, app=app) total = callback['chord_size'] + totaldiff if readycount == total: - decode, unpack = self.decode, self._unpack_chord_result - with client.pipeline() as pipe: - if self._chord_zset: - pipeline = pipe.zrange(jkey, 0, -1) - else: - pipeline = pipe.lrange(jkey, 0, total) - resl, = pipeline.execute() - try: - callback.delay([unpack(tup, decode) for tup in resl]) + header_result = GroupResult.restore(gid) + if header_result is not None: + # If we manage to restore a `GroupResult`, then it must + # have been complex and saved by `apply_chord()` earlier. + # + # Before we can join the `GroupResult`, it needs to be + # manually marked as ready to avoid blocking + header_result.on_ready() + # We'll `join()` it to get the results and ensure they are + # structured as intended rather than the flattened version + # we'd construct without any other information. + join_func = ( + header_result.join_native + if header_result.supports_native_join + else header_result.join + ) + with allow_join_result(): + resl = join_func(timeout=3.0, propagate=True) + else: + # Otherwise simply extract and decode the results we + # stashed along the way, which should be faster for large + # numbers of simple results in the chord header. + decode, unpack = self.decode, self._unpack_chord_result with client.pipeline() as pipe: - _, _ = pipe \ - .delete(jkey) \ - .delete(tkey) \ - .execute() + if self._chord_zset: + pipeline = pipe.zrange(jkey, 0, -1) + else: + pipeline = pipe.lrange(jkey, 0, total) + resl, = pipeline.execute() + resl = [unpack(tup, decode) for tup in resl] + try: + callback.delay(resl) except Exception as exc: # pylint: disable=broad-except logger.exception( 'Chord callback for %r raised: %r', request.group, exc) @@ -470,6 +491,12 @@ def on_chord_part_return(self, request, state, result, callback, ChordError(f'Callback error: {exc!r}'), ) + finally: + with client.pipeline() as pipe: + _, _ = pipe \ + .delete(jkey) \ + .delete(tkey) \ + .execute() except ChordError as exc: logger.exception('Chord %r raised: %r', request.group, exc) return self.chord_error_from_stack(callback, exc) diff --git a/t/integration/test_canvas.py b/t/integration/test_canvas.py index 48019f0342..3343acf081 100644 --- a/t/integration/test_canvas.py +++ b/t/integration/test_canvas.py @@ -1,3 +1,4 @@ +import re from datetime import datetime, timedelta from time import sleep @@ -874,9 +875,15 @@ def test_chord_on_error(self, manager): # So for clarity of our test, we instead do it here. # Use the error callback's result to find the failed task. - error_callback_result = AsyncResult( - res.children[0].children[0].result[0]) - failed_task_id = error_callback_result.result.args[0].split()[3] + uuid_patt = re.compile( + r"[0-9A-Fa-f]{8}-([0-9A-Fa-f]{4}-){3}[0-9A-Fa-f]{12}" + ) + callback_chord_exc = AsyncResult( + res.children[0].children[0].result[0] + ).result + failed_task_id = uuid_patt.search(str(callback_chord_exc)) + assert (failed_task_id is not None), "No task ID in %r" % callback_exc + failed_task_id = failed_task_id.group() # Use new group_id result metadata to get group ID. failed_task_result = AsyncResult(failed_task_id) diff --git a/t/unit/backends/test_redis.py b/t/unit/backends/test_redis.py index 2029edc3c2..f534077a4f 100644 --- a/t/unit/backends/test_redis.py +++ b/t/unit/backends/test_redis.py @@ -1,3 +1,4 @@ +import itertools import json import random import ssl @@ -274,7 +275,7 @@ def test_drain_events_connection_error(self, parent_on_state_change, cancel_for) assert consumer._pubsub._subscribed_to == {b'celery-task-meta-initial'} -class test_RedisBackend: +class basetest_RedisBackend: def get_backend(self): from celery.backends.redis import RedisBackend @@ -287,11 +288,42 @@ def get_E_LOST(self): from celery.backends.redis import E_LOST return E_LOST + def create_task(self, i, group_id="group_id"): + tid = uuid() + task = Mock(name=f'task-{tid}') + task.name = 'foobarbaz' + self.app.tasks['foobarbaz'] = task + task.request.chord = signature(task) + task.request.id = tid + task.request.chord['chord_size'] = 10 + task.request.group = group_id + task.request.group_index = i + return task + + @contextmanager + def chord_context(self, size=1): + with patch('celery.backends.redis.maybe_signature') as ms: + request = Mock(name='request') + request.id = 'id1' + request.group = 'gid1' + request.group_index = None + tasks = [ + self.create_task(i, group_id=request.group) + for i in range(size) + ] + callback = ms.return_value = Signature('add') + callback.id = 'id1' + callback['chord_size'] = size + callback.delay = Mock(name='callback.delay') + yield tasks, request, callback + def setup(self): self.Backend = self.get_backend() self.E_LOST = self.get_E_LOST() self.b = self.Backend(app=self.app) + +class test_RedisBackend(basetest_RedisBackend): @pytest.mark.usefixtures('depends_on_current_app') def test_reduce(self): pytest.importorskip('redis') @@ -623,20 +655,36 @@ def test_set_no_expire(self): self.b.expires = None self.b._set_with_state('foo', 'bar', states.SUCCESS) - def create_task(self, i): + def test_process_cleanup(self): + self.b.process_cleanup() + + def test_get_set_forget(self): tid = uuid() - task = Mock(name=f'task-{tid}') - task.name = 'foobarbaz' - self.app.tasks['foobarbaz'] = task - task.request.chord = signature(task) - task.request.id = tid - task.request.chord['chord_size'] = 10 - task.request.group = 'group_id' - task.request.group_index = i - return task + self.b.store_result(tid, 42, states.SUCCESS) + assert self.b.get_state(tid) == states.SUCCESS + assert self.b.get_result(tid) == 42 + self.b.forget(tid) + assert self.b.get_state(tid) == states.PENDING - @patch('celery.result.GroupResult.restore') - def test_on_chord_part_return(self, restore): + def test_set_expires(self): + self.b = self.Backend(expires=512, app=self.app) + tid = uuid() + key = self.b.get_key_for_task(tid) + self.b.store_result(tid, 42, states.SUCCESS) + self.b.client.expire.assert_called_with( + key, 512, + ) + + +class test_RedisBackend_chords_simple(basetest_RedisBackend): + @pytest.fixture(scope="class", autouse=True) + def simple_header_result(self): + with patch( + "celery.result.GroupResult.restore", return_value=None, + ) as p: + yield p + + def test_on_chord_part_return(self): tasks = [self.create_task(i) for i in range(10)] random.shuffle(tasks) @@ -652,8 +700,7 @@ def test_on_chord_part_return(self, restore): call(jkey, 86400), call(tkey, 86400), ]) - @patch('celery.result.GroupResult.restore') - def test_on_chord_part_return__unordered(self, restore): + def test_on_chord_part_return__unordered(self): self.app.conf.result_backend_transport_options = dict( result_chord_ordered=False, ) @@ -673,8 +720,7 @@ def test_on_chord_part_return__unordered(self, restore): call(jkey, 86400), call(tkey, 86400), ]) - @patch('celery.result.GroupResult.restore') - def test_on_chord_part_return__ordered(self, restore): + def test_on_chord_part_return__ordered(self): self.app.conf.result_backend_transport_options = dict( result_chord_ordered=True, ) @@ -694,8 +740,7 @@ def test_on_chord_part_return__ordered(self, restore): call(jkey, 86400), call(tkey, 86400), ]) - @patch('celery.result.GroupResult.restore') - def test_on_chord_part_return_no_expiry(self, restore): + def test_on_chord_part_return_no_expiry(self): old_expires = self.b.expires self.b.expires = None tasks = [self.create_task(i) for i in range(10)] @@ -712,8 +757,7 @@ def test_on_chord_part_return_no_expiry(self, restore): self.b.expires = old_expires - @patch('celery.result.GroupResult.restore') - def test_on_chord_part_return_expire_set_to_zero(self, restore): + def test_on_chord_part_return_expire_set_to_zero(self): old_expires = self.b.expires self.b.expires = 0 tasks = [self.create_task(i) for i in range(10)] @@ -730,8 +774,7 @@ def test_on_chord_part_return_expire_set_to_zero(self, restore): self.b.expires = old_expires - @patch('celery.result.GroupResult.restore') - def test_on_chord_part_return_no_expiry__unordered(self, restore): + def test_on_chord_part_return_no_expiry__unordered(self): self.app.conf.result_backend_transport_options = dict( result_chord_ordered=False, ) @@ -752,8 +795,7 @@ def test_on_chord_part_return_no_expiry__unordered(self, restore): self.b.expires = old_expires - @patch('celery.result.GroupResult.restore') - def test_on_chord_part_return_no_expiry__ordered(self, restore): + def test_on_chord_part_return_no_expiry__ordered(self): self.app.conf.result_backend_transport_options = dict( result_chord_ordered=True, ) @@ -926,39 +968,80 @@ def test_on_chord_part_return__other_error__ordered(self): callback.id, exc=ANY, ) - @contextmanager - def chord_context(self, size=1): - with patch('celery.backends.redis.maybe_signature') as ms: - tasks = [self.create_task(i) for i in range(size)] - request = Mock(name='request') - request.id = 'id1' - request.group = 'gid1' - request.group_index = None - callback = ms.return_value = Signature('add') - callback.id = 'id1' - callback['chord_size'] = size - callback.delay = Mock(name='callback.delay') - yield tasks, request, callback - def test_process_cleanup(self): - self.b.process_cleanup() +class test_RedisBackend_chords_complex(basetest_RedisBackend): + @pytest.fixture(scope="function", autouse=True) + def complex_header_result(self): + with patch("celery.result.GroupResult.restore") as p: + yield p + + def test_apply_chord_complex_header(self): + mock_header_result = Mock() + # No results in the header at all - won't call `save()` + mock_header_result.results = tuple() + self.b.apply_chord(mock_header_result, None) + mock_header_result.save.assert_not_called() + mock_header_result.save.reset_mock() + # A single simple result in the header - won't call `save()` + mock_header_result.results = (self.app.AsyncResult("foo"), ) + self.b.apply_chord(mock_header_result, None) + mock_header_result.save.assert_not_called() + mock_header_result.save.reset_mock() + # Many simple results in the header - won't call `save()` + mock_header_result.results = (self.app.AsyncResult("foo"), ) * 42 + self.b.apply_chord(mock_header_result, None) + mock_header_result.save.assert_not_called() + mock_header_result.save.reset_mock() + # A single complex result in the header - will call `save()` + mock_header_result.results = (self.app.GroupResult("foo"), ) + self.b.apply_chord(mock_header_result, None) + mock_header_result.save.assert_called_once_with(backend=self.b) + mock_header_result.save.reset_mock() + # Many complex results in the header - will call `save()` + mock_header_result.results = (self.app.GroupResult("foo"), ) * 42 + self.b.apply_chord(mock_header_result, None) + mock_header_result.save.assert_called_once_with(backend=self.b) + mock_header_result.save.reset_mock() + # Mixed simple and complex results in the header - will call `save()` + mock_header_result.results = itertools.islice( + itertools.cycle(( + self.app.AsyncResult("foo"), self.app.GroupResult("foo"), + )), 42, + ) + self.b.apply_chord(mock_header_result, None) + mock_header_result.save.assert_called_once_with(backend=self.b) + mock_header_result.save.reset_mock() - def test_get_set_forget(self): - tid = uuid() - self.b.store_result(tid, 42, states.SUCCESS) - assert self.b.get_state(tid) == states.SUCCESS - assert self.b.get_result(tid) == 42 - self.b.forget(tid) - assert self.b.get_state(tid) == states.PENDING + @pytest.mark.parametrize("supports_native_join", (True, False)) + def test_on_chord_part_return( + self, complex_header_result, supports_native_join, + ): + mock_result_obj = complex_header_result.return_value + mock_result_obj.supports_native_join = supports_native_join - def test_set_expires(self): - self.b = self.Backend(expires=512, app=self.app) - tid = uuid() - key = self.b.get_key_for_task(tid) - self.b.store_result(tid, 42, states.SUCCESS) - self.b.client.expire.assert_called_with( - key, 512, - ) + tasks = [self.create_task(i) for i in range(10)] + random.shuffle(tasks) + + with self.chord_context(10) as (tasks, request, callback): + for task, result_val in zip(tasks, itertools.cycle((42, ))): + self.b.on_chord_part_return( + task.request, states.SUCCESS, result_val, + ) + # Confirm that `zadd` was called even though we won't end up + # using the data pushed into the sorted set + assert self.b.client.zadd.call_count == 1 + self.b.client.zadd.reset_mock() + # Confirm that neither `zrange` not `lrange` were called + self.b.client.zrange.assert_not_called() + self.b.client.lrange.assert_not_called() + # Confirm that the `GroupResult.restore` mock was called + complex_header_result.assert_called_once_with(request.group) + # Confirm the the callback was called with the `join()`ed group result + if supports_native_join: + expected_join = mock_result_obj.join_native + else: + expected_join = mock_result_obj.join + callback.delay.assert_called_once_with(expected_join()) class test_SentinelBackend: