From 42a3f466bae1cb5f49b781fd4ada5b10d637be92 Mon Sep 17 00:00:00 2001 From: maybe-sybr <58414429+maybe-sybr@users.noreply.github.com> Date: Wed, 14 Oct 2020 10:39:19 +1100 Subject: [PATCH] fix: Retain chord header result structure in Redis This change fixes the chord result flattening issue which manifested when using the Redis backend due to its deliberate throwing away of information about the header result structure. Rather than assuming that all results which contribute to the finalisation of a chord should be siblings, this change checks if any are complex (ie. `GroupResult`s) and falls back to behaviour similar to that implemented in the `KeyValueStoreBackend` which restores the original `GroupResult` object and `join()`s it. We retain the original behaviour which is billed as an optimisation in f09b041. We could behave better in the complex header result case by not bothering to stash the results of contributing tasks under the `.j` zset since we won't be using them, but without checking for the presence of the complex group result on every `on_chord_part_return()` call, we can't be sure that we won't need those stashed results later on. This would be an opportunity for optimisation in future if we were to use an `EVAL` to only do the `zadd()` if the group result key doesn't exist. However, avoiding the result encoding work in `on_chord_part_return()` would be more complicated. For now, it's not worth the brainpower. This change also slightly refactors the redis backend unit tests to make it easier to build fixtures and hit both the complex and simple result structure cases. --- celery/backends/redis.py | 61 ++++++++----- t/unit/backends/test_redis.py | 156 ++++++++++++++++++++++------------ 2 files changed, 143 insertions(+), 74 deletions(-) diff --git a/celery/backends/redis.py b/celery/backends/redis.py index 2c428823538..0aed3eefea6 100644 --- a/celery/backends/redis.py +++ b/celery/backends/redis.py @@ -13,6 +13,7 @@ from celery._state import task_join_will_block from celery.canvas import maybe_signature from celery.exceptions import ChordError, ImproperlyConfigured +from celery.result import GroupResult, allow_join_result from celery.utils.functional import dictfilter from celery.utils.log import get_logger from celery.utils.time import humanize_seconds @@ -401,12 +402,14 @@ def _unpack_chord_result(self, tup, decode, return retval def apply_chord(self, header_result, body, **kwargs): - # Overrides this to avoid calling GroupResult.save - # pylint: disable=method-hidden - # Note that KeyValueStoreBackend.__init__ sets self.apply_chord - # if the implements_incr attr is set. Redis backend doesn't set - # this flag. - pass + # If any of the child results of this chord are complex (ie. group + # results themselves), we need to save `header_result` to ensure that + # the expected structure is retained when we finish the chord and pass + # the results onward to the body in `on_chord_part_return()`. We don't + # do this is all cases to retain an optimisation in the common case + # where a chord header is comprised of simple result objects. + if any(isinstance(nr, GroupResult) for nr in header_result.results): + header_result.save(backend=self) @cached_property def _chord_zset(self): @@ -449,20 +452,34 @@ def on_chord_part_return(self, request, state, result, callback = maybe_signature(request.chord, app=app) total = callback['chord_size'] + totaldiff if readycount == total: - decode, unpack = self.decode, self._unpack_chord_result - with client.pipeline() as pipe: - if self._chord_zset: - pipeline = pipe.zrange(jkey, 0, -1) - else: - pipeline = pipe.lrange(jkey, 0, total) - resl, = pipeline.execute() - try: - callback.delay([unpack(tup, decode) for tup in resl]) + hr = GroupResult.restore(gid) + if hr is not None: + # If we manage to restore a `GroupResult`, then it must + # have been complex and saved by `apply_chord()` earlier. + # + # Before we can join the `GroupResult`, it needs to be + # manually marked as ready to avoid blocking + hr.on_ready() + # We'll `join()` it to get the results and ensure they are + # structured as intended rather than the flattened version + # we'd construct without any other information. + j = hr.join_native if hr.supports_native_join else hr.join + with allow_join_result(): + resl = j(timeout=3.0, propagate=True) + else: + # Otherwise simply extract and decode the results we + # stashed along the way, which should be faster for large + # numbers of simple results in the chord header. + decode, unpack = self.decode, self._unpack_chord_result with client.pipeline() as pipe: - _, _ = pipe \ - .delete(jkey) \ - .delete(tkey) \ - .execute() + if self._chord_zset: + pipeline = pipe.zrange(jkey, 0, -1) + else: + pipeline = pipe.lrange(jkey, 0, total) + resl, = pipeline.execute() + resl = [unpack(tup, decode) for tup in resl] + try: + callback.delay(resl) except Exception as exc: # pylint: disable=broad-except logger.exception( 'Chord callback for %r raised: %r', request.group, exc) @@ -470,6 +487,12 @@ def on_chord_part_return(self, request, state, result, callback, ChordError(f'Callback error: {exc!r}'), ) + finally: + with client.pipeline() as pipe: + _, _ = pipe \ + .delete(jkey) \ + .delete(tkey) \ + .execute() except ChordError as exc: logger.exception('Chord %r raised: %r', request.group, exc) return self.chord_error_from_stack(callback, exc) diff --git a/t/unit/backends/test_redis.py b/t/unit/backends/test_redis.py index 2029edc3c29..128d3c6c6b0 100644 --- a/t/unit/backends/test_redis.py +++ b/t/unit/backends/test_redis.py @@ -1,3 +1,4 @@ +import itertools import json import random import ssl @@ -274,7 +275,7 @@ def test_drain_events_connection_error(self, parent_on_state_change, cancel_for) assert consumer._pubsub._subscribed_to == {b'celery-task-meta-initial'} -class test_RedisBackend: +class basetest_RedisBackend: def get_backend(self): from celery.backends.redis import RedisBackend @@ -287,11 +288,42 @@ def get_E_LOST(self): from celery.backends.redis import E_LOST return E_LOST + def create_task(self, i, group_id="group_id"): + tid = uuid() + task = Mock(name=f'task-{tid}') + task.name = 'foobarbaz' + self.app.tasks['foobarbaz'] = task + task.request.chord = signature(task) + task.request.id = tid + task.request.chord['chord_size'] = 10 + task.request.group = group_id + task.request.group_index = i + return task + + @contextmanager + def chord_context(self, size=1): + with patch('celery.backends.redis.maybe_signature') as ms: + request = Mock(name='request') + request.id = 'id1' + request.group = 'gid1' + request.group_index = None + tasks = [ + self.create_task(i, group_id=request.group) + for i in range(size) + ] + callback = ms.return_value = Signature('add') + callback.id = 'id1' + callback['chord_size'] = size + callback.delay = Mock(name='callback.delay') + yield tasks, request, callback + def setup(self): self.Backend = self.get_backend() self.E_LOST = self.get_E_LOST() self.b = self.Backend(app=self.app) + +class test_RedisBackend(basetest_RedisBackend): @pytest.mark.usefixtures('depends_on_current_app') def test_reduce(self): pytest.importorskip('redis') @@ -623,20 +655,36 @@ def test_set_no_expire(self): self.b.expires = None self.b._set_with_state('foo', 'bar', states.SUCCESS) - def create_task(self, i): + def test_process_cleanup(self): + self.b.process_cleanup() + + def test_get_set_forget(self): tid = uuid() - task = Mock(name=f'task-{tid}') - task.name = 'foobarbaz' - self.app.tasks['foobarbaz'] = task - task.request.chord = signature(task) - task.request.id = tid - task.request.chord['chord_size'] = 10 - task.request.group = 'group_id' - task.request.group_index = i - return task + self.b.store_result(tid, 42, states.SUCCESS) + assert self.b.get_state(tid) == states.SUCCESS + assert self.b.get_result(tid) == 42 + self.b.forget(tid) + assert self.b.get_state(tid) == states.PENDING + + def test_set_expires(self): + self.b = self.Backend(expires=512, app=self.app) + tid = uuid() + key = self.b.get_key_for_task(tid) + self.b.store_result(tid, 42, states.SUCCESS) + self.b.client.expire.assert_called_with( + key, 512, + ) + - @patch('celery.result.GroupResult.restore') - def test_on_chord_part_return(self, restore): +class test_RedisBackend_chords_simple(basetest_RedisBackend): + @pytest.fixture(scope="class", autouse=True) + def simple_header_result(self): + with patch( + "celery.result.GroupResult.restore", return_value=None, + ) as p: + yield p + + def test_on_chord_part_return(self): tasks = [self.create_task(i) for i in range(10)] random.shuffle(tasks) @@ -652,8 +700,7 @@ def test_on_chord_part_return(self, restore): call(jkey, 86400), call(tkey, 86400), ]) - @patch('celery.result.GroupResult.restore') - def test_on_chord_part_return__unordered(self, restore): + def test_on_chord_part_return__unordered(self): self.app.conf.result_backend_transport_options = dict( result_chord_ordered=False, ) @@ -673,8 +720,7 @@ def test_on_chord_part_return__unordered(self, restore): call(jkey, 86400), call(tkey, 86400), ]) - @patch('celery.result.GroupResult.restore') - def test_on_chord_part_return__ordered(self, restore): + def test_on_chord_part_return__ordered(self): self.app.conf.result_backend_transport_options = dict( result_chord_ordered=True, ) @@ -694,8 +740,7 @@ def test_on_chord_part_return__ordered(self, restore): call(jkey, 86400), call(tkey, 86400), ]) - @patch('celery.result.GroupResult.restore') - def test_on_chord_part_return_no_expiry(self, restore): + def test_on_chord_part_return_no_expiry(self): old_expires = self.b.expires self.b.expires = None tasks = [self.create_task(i) for i in range(10)] @@ -712,8 +757,7 @@ def test_on_chord_part_return_no_expiry(self, restore): self.b.expires = old_expires - @patch('celery.result.GroupResult.restore') - def test_on_chord_part_return_expire_set_to_zero(self, restore): + def test_on_chord_part_return_expire_set_to_zero(self): old_expires = self.b.expires self.b.expires = 0 tasks = [self.create_task(i) for i in range(10)] @@ -730,8 +774,7 @@ def test_on_chord_part_return_expire_set_to_zero(self, restore): self.b.expires = old_expires - @patch('celery.result.GroupResult.restore') - def test_on_chord_part_return_no_expiry__unordered(self, restore): + def test_on_chord_part_return_no_expiry__unordered(self): self.app.conf.result_backend_transport_options = dict( result_chord_ordered=False, ) @@ -752,8 +795,7 @@ def test_on_chord_part_return_no_expiry__unordered(self, restore): self.b.expires = old_expires - @patch('celery.result.GroupResult.restore') - def test_on_chord_part_return_no_expiry__ordered(self, restore): + def test_on_chord_part_return_no_expiry__ordered(self): self.app.conf.result_backend_transport_options = dict( result_chord_ordered=True, ) @@ -926,39 +968,43 @@ def test_on_chord_part_return__other_error__ordered(self): callback.id, exc=ANY, ) - @contextmanager - def chord_context(self, size=1): - with patch('celery.backends.redis.maybe_signature') as ms: - tasks = [self.create_task(i) for i in range(size)] - request = Mock(name='request') - request.id = 'id1' - request.group = 'gid1' - request.group_index = None - callback = ms.return_value = Signature('add') - callback.id = 'id1' - callback['chord_size'] = size - callback.delay = Mock(name='callback.delay') - yield tasks, request, callback - def test_process_cleanup(self): - self.b.process_cleanup() +class test_RedisBackend_chords_complex(basetest_RedisBackend): + @pytest.fixture(scope="function", autouse=True) + def complex_header_result(self): + with patch("celery.result.GroupResult.restore") as p: + yield p - def test_get_set_forget(self): - tid = uuid() - self.b.store_result(tid, 42, states.SUCCESS) - assert self.b.get_state(tid) == states.SUCCESS - assert self.b.get_result(tid) == 42 - self.b.forget(tid) - assert self.b.get_state(tid) == states.PENDING + @pytest.mark.parametrize("supports_native_join", (True, False)) + def test_on_chord_part_return( + self, complex_header_result, supports_native_join, + ): + mock_result_obj = complex_header_result.return_value + mock_result_obj.supports_native_join = supports_native_join - def test_set_expires(self): - self.b = self.Backend(expires=512, app=self.app) - tid = uuid() - key = self.b.get_key_for_task(tid) - self.b.store_result(tid, 42, states.SUCCESS) - self.b.client.expire.assert_called_with( - key, 512, - ) + tasks = [self.create_task(i) for i in range(10)] + random.shuffle(tasks) + + with self.chord_context(10) as (tasks, request, callback): + for task, result_val in zip(tasks, itertools.cycle((42, ))): + self.b.on_chord_part_return( + task.request, states.SUCCESS, result_val, + ) + # Confirm that `zadd` was called even though we won't end up + # using the data pushed into the sorted set + assert self.b.client.zadd.call_count == 1 + self.b.client.zadd.reset_mock() + # Confirm that neither `zrange` not `lrange` were called + self.b.client.zrange.assert_not_called() + self.b.client.lrange.assert_not_called() + # Confirm that the `GroupResult.restore` mock was called + complex_header_result.assert_called_once_with(request.group) + # Confirm the the callback was called with the `join()`ed group result + if supports_native_join: + expected_join = mock_result_obj.join_native + else: + expected_join = mock_result_obj.join + callback.delay.assert_called_once_with(expected_join()) class test_SentinelBackend: