diff --git a/celery/security/serialization.py b/celery/security/serialization.py index c58ef906542..937abe63c72 100644 --- a/celery/security/serialization.py +++ b/celery/security/serialization.py @@ -29,7 +29,8 @@ def serialize(self, data): assert self._cert is not None with reraise_errors('Unable to serialize: {0!r}', (Exception,)): content_type, content_encoding, body = dumps( - bytes_to_str(data), serializer=self._serializer) + data, serializer=self._serializer) + # What we sign is the serialized body, not the body itself. # this way the receiver doesn't have to decode the contents # to verify the signature (and thus avoiding potential flaws @@ -48,7 +49,7 @@ def deserialize(self, data): payload['signer'], payload['body']) self._cert_store[signer].verify(body, signature, self._digest) - return loads(bytes_to_str(body), payload['content_type'], + return loads(body, payload['content_type'], payload['content_encoding'], force=True) def _pack(self, body, content_type, content_encoding, signer, signature, @@ -84,7 +85,7 @@ def _unpack(self, payload, sep=str_to_bytes('\x00\x01')): 'signature': signature, 'content_type': bytes_to_str(v[0]), 'content_encoding': bytes_to_str(v[1]), - 'body': bytes_to_str(v[2]), + 'body': v[2], } diff --git a/t/unit/security/test_serialization.py b/t/unit/security/test_serialization.py index 6caf3857b81..cb16d9f14fc 100644 --- a/t/unit/security/test_serialization.py +++ b/t/unit/security/test_serialization.py @@ -16,15 +16,19 @@ class test_secureserializer(SecurityCase): - def _get_s(self, key, cert, certs): + def _get_s(self, key, cert, certs, serializer="json"): store = CertStore() for c in certs: store.add_cert(Certificate(c)) - return SecureSerializer(PrivateKey(key), Certificate(cert), store) + return SecureSerializer( + PrivateKey(key), Certificate(cert), store, serializer=serializer + ) - def test_serialize(self): - s = self._get_s(KEY1, CERT1, [CERT1]) - assert s.deserialize(s.serialize('foo')) == 'foo' + @pytest.mark.parametrize("data", [1, "foo", b"foo", {"foo": 1}]) + @pytest.mark.parametrize("serializer", ["json", "pickle"]) + def test_serialize(self, data, serializer): + s = self._get_s(KEY1, CERT1, [CERT1], serializer=serializer) + assert s.deserialize(s.serialize(data)) == data def test_deserialize(self): s = self._get_s(KEY1, CERT1, [CERT1])