From 0f1a53b84ab15e15bb257c4d9ce2b3459d2ed176 Mon Sep 17 00:00:00 2001 From: Michal Kuffa Date: Tue, 29 Sep 2020 15:49:12 +0200 Subject: [PATCH] Fix custom headers propagation for protocol 1 hybrid messages --- celery/worker/strategy.py | 1 + t/unit/worker/test_request.py | 4 ++-- t/unit/worker/test_strategy.py | 6 +++++- 3 files changed, 8 insertions(+), 3 deletions(-) diff --git a/celery/worker/strategy.py b/celery/worker/strategy.py index 64d3c5337f..8fb1eabd31 100644 --- a/celery/worker/strategy.py +++ b/celery/worker/strategy.py @@ -50,6 +50,7 @@ def hybrid_to_proto2(message, body): 'kwargsrepr': body.get('kwargsrepr'), 'origin': body.get('origin'), } + headers.update(message.headers or {}) embed = { 'callbacks': body.get('callbacks'), diff --git a/t/unit/worker/test_request.py b/t/unit/worker/test_request.py index 039af717b2..3ed7c553d1 100644 --- a/t/unit/worker/test_request.py +++ b/t/unit/worker/test_request.py @@ -1204,8 +1204,8 @@ def test_execute_using_pool_with_none_timelimit_header(self): def test_execute_using_pool__defaults_of_hybrid_to_proto2(self): weakref_ref = Mock(name='weakref.ref') - headers = strategy.hybrid_to_proto2('', {'id': uuid(), - 'task': self.mytask.name})[1] + headers = strategy.hybrid_to_proto2(Mock(headers=None), {'id': uuid(), + 'task': self.mytask.name})[1] job = self.zRequest(revoked_tasks=set(), ref=weakref_ref, **headers) job.execute_using_pool(self.pool) assert job._apply_result diff --git a/t/unit/worker/test_strategy.py b/t/unit/worker/test_strategy.py index 6b93dab74d..88abe4dcd2 100644 --- a/t/unit/worker/test_strategy.py +++ b/t/unit/worker/test_strategy.py @@ -271,7 +271,7 @@ def failed(): class test_hybrid_to_proto2: def setup(self): - self.message = Mock(name='message') + self.message = Mock(name='message', headers={"custom": "header"}) self.body = { 'args': (1,), 'kwargs': {'foo': 'baz'}, @@ -288,3 +288,7 @@ def test_retries_custom_value(self): self.body['retries'] = _custom_value _, headers, _, _ = hybrid_to_proto2(self.message, self.body) assert headers.get('retries') == _custom_value + + def test_custom_headers(self): + _, headers, _, _ = hybrid_to_proto2(self.message, self.body) + assert headers.get("custom") == "header"