Skip to content

Commit

Permalink
security: SecureSerializer: support generic low-level serializers
Browse files Browse the repository at this point in the history
  • Loading branch information
shirsa authored and Nusnus committed May 2, 2024
1 parent 7ce2e41 commit 2335e29
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 8 deletions.
7 changes: 4 additions & 3 deletions celery/security/serialization.py
Expand Up @@ -29,7 +29,8 @@ def serialize(self, data):
assert self._cert is not None
with reraise_errors('Unable to serialize: {0!r}', (Exception,)):
content_type, content_encoding, body = dumps(
bytes_to_str(data), serializer=self._serializer)
data, serializer=self._serializer)

# What we sign is the serialized body, not the body itself.
# this way the receiver doesn't have to decode the contents
# to verify the signature (and thus avoiding potential flaws
Expand All @@ -48,7 +49,7 @@ def deserialize(self, data):
payload['signer'],
payload['body'])
self._cert_store[signer].verify(body, signature, self._digest)
return loads(bytes_to_str(body), payload['content_type'],
return loads(body, payload['content_type'],
payload['content_encoding'], force=True)

def _pack(self, body, content_type, content_encoding, signer, signature,
Expand Down Expand Up @@ -84,7 +85,7 @@ def _unpack(self, payload, sep=str_to_bytes('\x00\x01')):
'signature': signature,
'content_type': bytes_to_str(v[0]),
'content_encoding': bytes_to_str(v[1]),
'body': bytes_to_str(v[2]),
'body': v[2],
}


Expand Down
14 changes: 9 additions & 5 deletions t/unit/security/test_serialization.py
Expand Up @@ -16,15 +16,19 @@

class test_secureserializer(SecurityCase):

def _get_s(self, key, cert, certs):
def _get_s(self, key, cert, certs, serializer="json"):
store = CertStore()
for c in certs:
store.add_cert(Certificate(c))
return SecureSerializer(PrivateKey(key), Certificate(cert), store)
return SecureSerializer(
PrivateKey(key), Certificate(cert), store, serializer=serializer
)

def test_serialize(self):
s = self._get_s(KEY1, CERT1, [CERT1])
assert s.deserialize(s.serialize('foo')) == 'foo'
@pytest.mark.parametrize("data", [1, "foo", b"foo", {"foo": 1}])
@pytest.mark.parametrize("serializer", ["json", "pickle"])
def test_serialize(self, data, serializer):
s = self._get_s(KEY1, CERT1, [CERT1], serializer=serializer)
assert s.deserialize(s.serialize(data)) == data

def test_deserialize(self):
s = self._get_s(KEY1, CERT1, [CERT1])
Expand Down

0 comments on commit 2335e29

Please sign in to comment.