-
-
Notifications
You must be signed in to change notification settings - Fork 4.6k
/
test_serialization.py
69 lines (55 loc) · 2.4 KB
/
test_serialization.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import base64
import os
import pytest
from kombu.serialization import registry
from kombu.utils.encoding import bytes_to_str
from celery.exceptions import SecurityError
from celery.security.certificate import Certificate, CertStore
from celery.security.key import PrivateKey
from celery.security.serialization import SecureSerializer, register_auth
from . import CERT1, CERT2, KEY1, KEY2
from .case import SecurityCase
class test_secureserializer(SecurityCase):
    """Tests for :class:`SecureSerializer` signing and verification."""

    def _get_s(self, key, cert, certs, serializer="json"):
        """Build a serializer that signs with *key*/*cert* and trusts *certs*.

        :param key: private key material passed to :class:`PrivateKey`.
        :param cert: certificate material identifying this end.
        :param certs: iterable of certificate materials added to the trust
            store used when verifying incoming messages.
        :param serializer: name of the underlying payload serializer.
        :returns: a configured :class:`SecureSerializer`.
        """
        store = CertStore()
        for c in certs:
            store.add_cert(Certificate(c))
        return SecureSerializer(
            PrivateKey(key), Certificate(cert), store, serializer=serializer
        )

    @pytest.mark.parametrize("data", [1, "foo", b"foo", {"foo": 1}])
    @pytest.mark.parametrize("serializer", ["json", "pickle"])
    def test_serialize(self, data, serializer):
        """A signed payload round-trips unchanged for each serializer."""
        s = self._get_s(KEY1, CERT1, [CERT1], serializer=serializer)
        assert s.deserialize(s.serialize(data)) == data

    def test_deserialize(self):
        """Input that is not a signed message raises SecurityError."""
        s = self._get_s(KEY1, CERT1, [CERT1])
        with pytest.raises(SecurityError):
            s.deserialize('bad data')

    def test_unmatched_key_cert(self):
        """Signing with KEY1 while presenting CERT2 fails verification."""
        s = self._get_s(KEY1, CERT2, [CERT1, CERT2])
        with pytest.raises(SecurityError):
            s.deserialize(s.serialize('foo'))

    def test_unknown_source(self):
        """A message whose sender certificate is not trusted is rejected."""
        # s1 trusts only CERT2 and s2 trusts nothing, so neither can verify
        # a message signed under its own CERT1.
        s1 = self._get_s(KEY1, CERT1, [CERT2])
        s2 = self._get_s(KEY1, CERT1, [])
        with pytest.raises(SecurityError):
            s1.deserialize(s1.serialize('foo'))
        with pytest.raises(SecurityError):
            s2.deserialize(s2.serialize('foo'))

    def test_self_send(self):
        """Two serializers sharing the same identity accept each other."""
        s1 = self._get_s(KEY1, CERT1, [CERT1])
        s2 = self._get_s(KEY1, CERT1, [CERT1])
        assert s2.deserialize(s1.serialize('foo')) == 'foo'

    def test_separate_ends(self):
        """A receiver trusting the sender's certificate accepts its messages."""
        s1 = self._get_s(KEY1, CERT1, [CERT2])
        s2 = self._get_s(KEY2, CERT2, [CERT1])
        assert s2.deserialize(s1.serialize('foo')) == 'foo'

    def test_register_auth(self):
        """register_auth makes a decoder available for 'application/data'."""
        # NOTE(review): this mutates kombu's global serializer registry and
        # does not undo it; positional argument meaning depends on the
        # register_auth signature -- confirm against celery.security.
        register_auth(KEY1, None, CERT1, '')
        assert 'application/data' in registry._decoders

    def test_lots_of_sign(self):
        """Many random payloads each survive a sign/verify round trip."""
        # The serializer configuration does not vary per iteration, so build
        # it once instead of re-parsing the key and certificates 1000 times.
        s = self._get_s(KEY1, CERT1, [CERT1])
        for _ in range(1000):
            # 265 random bytes, urlsafe-base64-encoded into a text payload.
            rdata = bytes_to_str(base64.urlsafe_b64encode(os.urandom(265)))
            assert s.deserialize(s.serialize(rdata)) == rdata