diff --git a/.gitignore b/.gitignore index 2002142..6c6327b 100644 --- a/.gitignore +++ b/.gitignore @@ -3,6 +3,7 @@ .*.swp /.idea/ +/.vscode/ /dist/ /distribute*.tar.gz diff --git a/doc/compatibility.rst b/doc/compatibility.rst index 1429553..2545519 100644 --- a/doc/compatibility.rst +++ b/doc/compatibility.rst @@ -5,9 +5,10 @@ Compatibility with standards .. index:: compatibility Python-RSA implements encryption and signatures according to PKCS#1 -version 1.5. This makes it compatible with the OpenSSL RSA module. +version 1.5. Additionally, Python-RSA implements multiprime encryption according to PKCS#1 +version 2.1. This makes it largely compatible with the OpenSSL RSA module. -Keys are stored in PEM or DER format according to PKCS#1 v1.5. Private +Keys are stored in PEM or DER format according to PKCS#1 v2.1. Private keys are compatible with OpenSSL. However, OpenSSL uses X.509 for its public keys, which are not supported. diff --git a/doc/index.rst b/doc/index.rst index a0a1573..47c0084 100644 --- a/doc/index.rst +++ b/doc/index.rst @@ -8,7 +8,8 @@ Welcome to Python-RSA's documentation! Python-RSA is a pure-Python RSA implementation. It supports encryption and decryption, signing and verifying signatures, and key -generation according to PKCS#1 version 1.5. +generation according to PKCS#1 version 1.5. Additionally, Python-RSA +implements multirime encryption according to PKCS#1 version 2.1. If you have the time and skill to improve the implementation, by all means be my guest. The best way is to clone the `Git diff --git a/doc/usage.rst b/doc/usage.rst index 5c60e2d..946014e 100644 --- a/doc/usage.rst +++ b/doc/usage.rst @@ -40,6 +40,11 @@ Alternatively you can use :py:meth:`rsa.PrivateKey.load_pkcs1` and ... keydata = privatefile.read() >>> privkey = rsa.PrivateKey.load_pkcs1(keydata) +Python-RSA also allows you to generate RSA keys with multiple primes: + + >>> import rsa + >>> (pubkey, privkey) = rsa.newkeys(512, nprimes=3) + Time to generate a key ++++++++++++++++++++++ diff --git a/rsa/core.py b/rsa/core.py index 38b526c..7649f11 100644 --- a/rsa/core.py +++ b/rsa/core.py @@ -17,6 +17,8 @@ This is the actual core RSA implementation, which is only defined mathematically on integers. """ +import itertools +import typing def assert_int(var: int, name: str) -> None: @@ -51,3 +53,37 @@ def decrypt_int(cyphertext: int, dkey: int, n: int) -> int: message = pow(cyphertext, dkey, n) return message + + +def decrypt_int_fast( + cyphertext: int, + rs: typing.List[int], + ds: typing.List[int], + ts: typing.List[int], +) -> int: + """Decrypts a cypher text more quickly using the Chinese Remainder Theorem.""" + + assert_int(cyphertext, "cyphertext") + for r in rs: + assert_int(r, "r") + for d in ds: + assert_int(d, "d") + for t in ts: + assert_int(t, "t") + + p, q, rs = rs[0], rs[1], rs[2:] + exp1, exp2, ds = ds[0], ds[1], ds[2:] + coef, ts = ts[0], ts[1:] + + M1 = pow(cyphertext, exp1, p) + M2 = pow(cyphertext, exp2, q) + h = ((M1 - M2) * coef) % p + m = M2 + q * h + + Ms = [pow(cyphertext, d, r) for d, r in zip(ds, rs)] + Rs = list(itertools.accumulate([p, q] + rs, lambda x, y: x*y)) + for R, r, M, t in zip(Rs[1:], rs, Ms, ts): + h = ((M - m) * t) % r + m += R * h + + return m \ No newline at end of file diff --git a/rsa/key.py b/rsa/key.py index 37e26b0..fd30447 100644 --- a/rsa/key.py +++ b/rsa/key.py @@ -32,9 +32,11 @@ """ import abc +import math import threading import typing import warnings +import itertools import rsa.prime import rsa.pem @@ -389,7 +391,9 @@ class PrivateKey(AbstractKey): """Represents a private RSA key. This key is also known as the 'decryption key'. It contains the 'n', 'e', - 'd', 'p', 'q' and other values. + 'd', 'p', 'q' and other values. For example ,in the case of multiprime RSA, + it additionally contains the lists 'rs', 'ds', and 'ts' which contain the + factors, exponents, and coefficients for the other primes. Supports attributes as well as dictionary-like access. Attribute access is faster, though. @@ -409,9 +413,19 @@ class PrivateKey(AbstractKey): """ - __slots__ = ("d", "p", "q", "exp1", "exp2", "coef") + __slots__ = ("d", "p", "q", "exp1", "exp2", "coef", "rs", "ds", "ts") + + def __init__( + self, + n: int, + e: int, + d: int, + p: int, + q: int, + rs: typing.Optional[typing.List[int]] = None, + ) -> None: + rs = [] if rs is None else rs - def __init__(self, n: int, e: int, d: int, p: int, q: int) -> None: AbstractKey.__init__(self, n, e) self.d = d self.p = p @@ -422,25 +436,72 @@ def __init__(self, n: int, e: int, d: int, p: int, q: int) -> None: self.exp2 = int(d % (q - 1)) self.coef = rsa.common.inverse(q, p) + # Calculate other primes' exponents and coefficients. + self.rs = rs + self.ds = [int(d % (r - 1)) for r in rs] + Rs = list(itertools.accumulate([p, q] + rs, lambda x, y: x*y)) + self.ts = [pow(R, -1, r) for R, r in zip(Rs[1:], rs)] + def __getitem__(self, key: str) -> int: return getattr(self, key) def __repr__(self) -> str: - return "PrivateKey(%i, %i, %i, %i, %i)" % ( - self.n, - self.e, - self.d, - self.p, - self.q, - ) + if self.rs: + return "PrivateKey(%i, %i, %i, %i, %i, %s)" % ( + self.n, + self.e, + self.d, + self.p, + self.q, + self.rs, + ) + else: + return "PrivateKey(%i, %i, %i, %i, %i)" % ( + self.n, + self.e, + self.d, + self.p, + self.q, + ) - def __getstate__(self) -> typing.Tuple[int, int, int, int, int, int, int, int]: + def __getstate__(self) -> typing.Tuple: """Returns the key as tuple for pickling.""" - return self.n, self.e, self.d, self.p, self.q, self.exp1, self.exp2, self.coef + if self.rs: + return ( + self.n, + self.e, + self.d, + self.p, + self.q, + self.exp1, + self.exp2, + self.coef, + self.rs, + self.ds, + self.ts, + ) + else: + return self.n, self.e, self.d, self.p, self.q, self.exp1, self.exp2, self.coef - def __setstate__(self, state: typing.Tuple[int, int, int, int, int, int, int, int]) -> None: + def __setstate__(self, state: typing.Tuple) -> None: """Sets the key from tuple.""" - self.n, self.e, self.d, self.p, self.q, self.exp1, self.exp2, self.coef = state + if len(state) != 8: + ( + self.n, + self.e, + self.d, + self.p, + self.q, + self.exp1, + self.exp2, + self.coef, + self.rs, + self.ds, + self.ts, + ) = state + else: + self.n, self.e, self.d, self.p, self.q, self.exp1, self.exp2, self.coef = state + self.rs = self.ds = self.ts = [] AbstractKey.__init__(self, self.n, self.e) def __eq__(self, other: typing.Any) -> bool: @@ -450,22 +511,28 @@ def __eq__(self, other: typing.Any) -> bool: if not isinstance(other, PrivateKey): return False - return ( - self.n == other.n - and self.e == other.e - and self.d == other.d - and self.p == other.p - and self.q == other.q - and self.exp1 == other.exp1 - and self.exp2 == other.exp2 - and self.coef == other.coef - ) + return all([getattr(self, k) == getattr(other, k) for k in self.__slots__]) def __ne__(self, other: typing.Any) -> bool: return not (self == other) def __hash__(self) -> int: - return hash((self.n, self.e, self.d, self.p, self.q, self.exp1, self.exp2, self.coef)) + if self.rs: + return hash(( + self.n, + self.e, + self.d, + self.p, + self.q, + self.exp1, + self.exp2, + self.coef, + *self.rs, + *self.ds, + *self.ts + )) + else: + return hash((self.n, self.e, self.d, self.p, self.q, self.exp1, self.exp2, self.coef)) def blinded_decrypt(self, encrypted: int) -> int: """Decrypts the message using blinding to prevent side-channel attacks. @@ -479,19 +546,14 @@ def blinded_decrypt(self, encrypted: int) -> int: # Blinding and un-blinding should be using the same factor blinded, blindfac_inverse = self.blind(encrypted) - - # Instead of using the core functionality, use the Chinese Remainder - # Theorem and be 2-4x faster. This the same as: - # - # decrypted = rsa.core.decrypt_int(blinded, self.d, self.n) - s1 = pow(blinded, self.exp1, self.p) - s2 = pow(blinded, self.exp2, self.q) - h = ((s1 - s2) * self.coef) % self.p - decrypted = s2 + self.q * h - + decrypted = rsa.core.decrypt_int_fast( + blinded, + [self.p, self.q] + self.rs, + [self.exp1, self.exp2] + self.ds, + [self.coef] + self.ts, + ) return self.unblind(decrypted, blindfac_inverse) - @classmethod def _load_pkcs1_der(cls, keyfile: bytes) -> "PrivateKey": """Loads a key in PKCS#1 DER format. @@ -536,12 +598,15 @@ def _load_pkcs1_der(cls, keyfile: bytes) -> "PrivateKey": if priv[0] != 0: raise ValueError("Unable to read this file, version %s != 0" % priv[0]) - as_ints = map(int, priv[1:6]) - key = cls(*as_ints) - + n, e, d, p, q = map(int, priv[1:6]) exp1, exp2, coef = map(int, priv[6:9]) + rs = map(int, priv[9::3]) + ds = map(int, priv[10::3]) + ts = map(int, priv[11::3]) - if (key.exp1, key.exp2, key.coef) != (exp1, exp2, coef): + key = cls(n, e, d, p, q, list(rs)) + + if (key.exp1, key.exp2, key.coef, key.rs, key.ds, key.ts) != (exp1, exp2, coef, rs, ds, ts): warnings.warn( "You have provided a malformed keyfile. Either the exponents " "or the coefficient are incorrect. Using the correct values " @@ -561,6 +626,14 @@ def _save_pkcs1_der(self) -> bytes: from pyasn1.type import univ, namedtype from pyasn1.codec.der import encoder + other_fields = [ + ( + namedtype.NamedType("prime%d" % (i + 3), univ.Integer()), + namedtype.NamedType("exponent%d" % (i + 3), univ.Integer()), + namedtype.NamedType("coefficient%d" % (i + 3), univ.Integer()), + ) for i in range(len(self.rs)) + ] + class AsnPrivKey(univ.Sequence): componentType = namedtype.NamedTypes( namedtype.NamedType("version", univ.Integer()), @@ -572,6 +645,7 @@ class AsnPrivKey(univ.Sequence): namedtype.NamedType("exponent1", univ.Integer()), namedtype.NamedType("exponent2", univ.Integer()), namedtype.NamedType("coefficient", univ.Integer()), + *list(itertools.chain(*other_fields)) ) # Create the ASN object @@ -585,6 +659,10 @@ class AsnPrivKey(univ.Sequence): asn_key.setComponentByName("exponent1", self.exp1) asn_key.setComponentByName("exponent2", self.exp2) asn_key.setComponentByName("coefficient", self.coef) + for i, (r, d, t) in enumerate(zip(self.rs, self.ds, self.ts), start=3): + asn_key.setComponentByName("prime%d" % i, r) + asn_key.setComponentByName("exponent%d" % i, d) + asn_key.setComponentByName("coefficient%d" % i, t) return encoder.encode(asn_key) @@ -615,6 +693,35 @@ def _save_pkcs1_pem(self) -> bytes: return rsa.pem.save_pem(der, b"RSA PRIVATE KEY") +def find_primes( + nbits: int, + getprime_func: typing.Callable[[int], int] = rsa.prime.getprime, + accurate: bool = True, + nprimes: int = 2, +) -> typing.List[int]: + """Returns a list of different primes with nbits divided evenly among them. + + :param nbits: the number of bits for the primes to sum to. + :param getprime_func: the getprime function, defaults to + :py:func:`rsa.prime.getprime`. + :param accurate: whether to enable accurate mode or not. + :returns: list of primes in descending order. + + """ + if nprimes == 2: + return list(find_p_q(nbits // 2, getprime_func, accurate)) + + quo, rem = divmod(nbits, nprimes) + factor_lengths = [quo + 1]*rem + [quo] * (nprimes - rem) + + while True: + primes = [getprime_func(length) for length in factor_lengths] + if len(set(primes)) == len(primes): + break + + return list(reversed(sorted(primes))) + + def find_p_q( nbits: int, getprime_func: typing.Callable[[int], int] = rsa.prime.getprime, @@ -695,7 +802,12 @@ def is_acceptable(p: int, q: int) -> bool: return max(p, q), min(p, q) -def calculate_keys_custom_exponent(p: int, q: int, exponent: int) -> typing.Tuple[int, int]: +def calculate_keys_custom_exponent( + p: int, + q: int, + exponent: int, + rs: typing.Optional[typing.List[int]] = None, +) -> typing.Tuple[int, int]: """Calculates an encryption and a decryption key given p, q and an exponent, and returns them as a tuple (e, d) @@ -705,10 +817,11 @@ def calculate_keys_custom_exponent(p: int, q: int, exponent: int) -> typing.Tupl what you're doing, as the exponent influences how difficult your private key can be cracked. A very common choice for e is 65537. :type exponent: int + :param rs: the list of other large primes """ - phi_n = (p - 1) * (q - 1) + phi_n = math.prod([x - 1 for x in [p, q] + ([] if rs is None else rs)]) try: d = rsa.common.inverse(exponent, phi_n) @@ -747,8 +860,9 @@ def gen_keys( getprime_func: typing.Callable[[int], int], accurate: bool = True, exponent: int = DEFAULT_EXPONENT, -) -> typing.Tuple[int, int, int, int]: - """Generate RSA keys of nbits bits. Returns (p, q, e, d). + nprimes: int = 2, +) -> typing.Tuple: + """Generate RSA keys of nbits bits. Returns (p, q, e, d) or (p, q, e, d, rs). Note: this can take a long time, depending on the key size. @@ -760,19 +874,24 @@ def gen_keys( what you're doing, as the exponent influences how difficult your private key can be cracked. A very common choice for e is 65537. :type exponent: int + :param nprimes: the number of prime factors comprising the modulus. """ - # Regenerate p and q values, until calculate_keys doesn't raise a + # Regenerate prime values, until calculate_keys_custom_exponent doesn't raise a # ValueError. while True: - (p, q) = find_p_q(nbits // 2, getprime_func, accurate) + primes = find_primes(nbits, getprime_func, accurate, nprimes) + p, q, rs = primes[0], primes[1], primes[2:] try: - (e, d) = calculate_keys_custom_exponent(p, q, exponent=exponent) + (e, d) = calculate_keys_custom_exponent(p, q, exponent=exponent, rs=rs) break except ValueError: pass - return p, q, e, d + if rs: + return p, q, e, d, rs + else: + return p, q, e, d def newkeys( @@ -780,6 +899,7 @@ def newkeys( accurate: bool = True, poolsize: int = 1, exponent: int = DEFAULT_EXPONENT, + nprimes: int = 2, ) -> typing.Tuple[PublicKey, PrivateKey]: """Generates public and private keys, and returns them as (pub, priv). @@ -787,7 +907,7 @@ def newkeys( :py:class:`rsa.PublicKey` object. The private key is also known as the 'decryption key' and is a :py:class:`rsa.PrivateKey` object. - :param nbits: the number of bits required to store ``n = p*q``. + :param nbits: the number of bits required to store the modulus ``n``. :param accurate: when True, ``n`` will have exactly the number of bits you asked for. However, this makes key generation much slower. When False, `n`` may have slightly less bits. @@ -798,6 +918,7 @@ def newkeys( what you're doing, as the exponent influences how difficult your private key can be cracked. A very common choice for e is 65537. :type exponent: int + :param nprimes: the number of prime factors comprising the modulus. :returns: a tuple (:py:class:`rsa.PublicKey`, :py:class:`rsa.PrivateKey`) @@ -812,6 +933,9 @@ def newkeys( if poolsize < 1: raise ValueError("Pool size (%i) should be >= 1" % poolsize) + if nprimes < 2: + raise ValueError("Number of primes (%i) should be >= 2" % nprimes) + # Determine which getprime function to use if poolsize > 1: from rsa import parallel @@ -823,12 +947,17 @@ def getprime_func(nbits: int) -> int: getprime_func = rsa.prime.getprime # Generate the key components - (p, q, e, d) = gen_keys(nbits, getprime_func, accurate=accurate, exponent=exponent) + result = gen_keys(nbits, getprime_func, accurate=accurate, exponent=exponent, nprimes=nprimes) + if len(result) == 4: + p, q, e, d = result + rs = [] + else: + p, q, e, d, rs = result # Create the key objects - n = p * q + n = math.prod([p, q] + rs) - return (PublicKey(n, e), PrivateKey(n, e, d, p, q)) + return (PublicKey(n, e), PrivateKey(n, e, d, p, q, rs)) __all__ = ["PublicKey", "PrivateKey", "newkeys"] diff --git a/tests/test_key.py b/tests/test_key.py index c570830..0975a4f 100644 --- a/tests/test_key.py +++ b/tests/test_key.py @@ -75,6 +75,19 @@ def getprime(_): self.assertEqual(39317, p) self.assertEqual(33107, q) + def test_multiprime(self): + primes = [64123, 50957, 39317, 33107] + exponent = 2**2**4 + 1 + + def getprime(_): + return primes.pop(0) + (p, q, e, d, rs) = rsa.key.gen_keys( + 128, accurate=False, getprime_func=getprime, exponent=exponent, nprimes=4 + ) + self.assertEqual(64123, p) + self.assertEqual(50957, q) + self.assertEqual(rs, [39317, 33107]) + class HashTest(unittest.TestCase): """Test hashing of keys""" diff --git a/tests/test_load_save_keys.py b/tests/test_load_save_keys.py index 3e90999..557066f 100644 --- a/tests/test_load_save_keys.py +++ b/tests/test_load_save_keys.py @@ -209,6 +209,155 @@ def test_load_from_disk(self): self.assertEqual(14617195220284816877, privkey.q) +MP_B64PRIV_DER = b"MEoCAQACCDsGeWW1Om5xAgMBAAECBQDHn4npAgMA+nsCAwDHDQICcw0CAwCWDQICTjYCAwCZlQICaukCAioVAgMAgVMCAksjAgIq3g==" +MP_PRIVATE_DER = base64.standard_b64decode(MP_B64PRIV_DER) + +MP_B64PUB_DER = b"MA8CCDsGeWW1Om5xAgMBAAE=" +MP_PUBLIC_DER = base64.standard_b64decode(MP_B64PUB_DER) + +MP_PRIVATE_PEM = ( + b"""\ +-----BEGIN CONFUSING STUFF----- +Cruft before the key + +-----BEGIN RSA PRIVATE KEY----- +Comment: something blah + +""" + + MP_B64PRIV_DER + + b""" +-----END RSA PRIVATE KEY----- + +Stuff after the key +-----END CONFUSING STUFF----- +""" +) + +MP_CLEAN_PRIVATE_PEM = ( + b"""\ +-----BEGIN RSA PRIVATE KEY----- +""" + + MP_B64PRIV_DER + + b""" +-----END RSA PRIVATE KEY----- +""" +) + +MP_PUBLIC_PEM = ( + b"""\ +-----BEGIN CONFUSING STUFF----- +Cruft before the key + +-----BEGIN RSA PUBLIC KEY----- +Comment: something blah + +""" + + MP_B64PUB_DER + + b""" +-----END RSA PUBLIC KEY----- + +Stuff after the key +-----END CONFUSING STUFF----- +""" +) + +MP_CLEAN_PUBLIC_PEM = ( + b"""\ +-----BEGIN RSA PUBLIC KEY----- +""" + + MP_B64PUB_DER + + b""" +-----END RSA PUBLIC KEY----- +""" +) + + +class MultiprimeDerTest(unittest.TestCase): + """Test saving and loading multiprime DER keys.""" + + def test_load_multiprime_private_key(self): + """Test loading private DER keys.""" + + key = rsa.key.PrivateKey.load_pkcs1(MP_PRIVATE_DER, "DER") + expected = rsa.key.PrivateKey(4253220375837175409, 65537, 3349121513, 64123, 50957, [39317, 33107]) + + self.assertEqual(expected, key) + self.assertEqual(key.exp1, 29453) + self.assertEqual(key.exp2, 38413) + self.assertEqual(key.coef, 20022) + self.assertEqual(key.ds, [27369, 19235]) + self.assertEqual(key.ts, [10773, 10974]) + + def test_save_multiprime_private_key(self): + """Test saving private DER keys.""" + + key = rsa.key.PrivateKey(4253220375837175409, 65537, 3349121513, 64123, 50957, [39317, 33107]) + der = key.save_pkcs1("DER") + + self.assertIsInstance(der, bytes) + self.assertEqual(MP_PRIVATE_DER, der) + + def test_load_multiprime_public_key(self): + """Test loading public DER keys.""" + + key = rsa.key.PublicKey.load_pkcs1(MP_PUBLIC_DER, "DER") + expected = rsa.key.PublicKey(4253220375837175409, 65537) + + self.assertEqual(expected, key) + + def test_save_multiprime_public_key(self): + """Test saving public DER keys.""" + + key = rsa.key.PublicKey(4253220375837175409, 65537) + der = key.save_pkcs1("DER") + + self.assertIsInstance(der, bytes) + self.assertEqual(MP_PUBLIC_DER, der) + + +class MultiprimePemTest(unittest.TestCase): + """Test saving and loading multiprime PEM keys.""" + + def test_load_multiprime_private_key(self): + """Test loading private PEM files.""" + + key = rsa.key.PrivateKey.load_pkcs1(MP_PRIVATE_PEM, "PEM") + expected = rsa.key.PrivateKey(4253220375837175409, 65537, 3349121513, 64123, 50957, [39317, 33107]) + + self.assertEqual(expected, key) + self.assertEqual(key.exp1, 29453) + self.assertEqual(key.exp2, 38413) + self.assertEqual(key.coef, 20022) + self.assertEqual(key.ds, [27369, 19235]) + self.assertEqual(key.ts, [10773, 10974]) + + def test_save_multiprime_private_key(self): + """Test saving private PEM files.""" + + key = rsa.key.PrivateKey(4253220375837175409, 65537, 3349121513, 64123, 50957, [39317, 33107]) + pem = key.save_pkcs1("PEM") + + self.assertIsInstance(pem, bytes) + self.assertEqual(MP_CLEAN_PRIVATE_PEM.replace(b"\n", b""), pem.replace(b"\n", b"")) + + def test_load_multiprime_public_key(self): + """Test loading public PEM files.""" + + key = rsa.key.PublicKey.load_pkcs1(MP_PUBLIC_PEM, "PEM") + expected = rsa.key.PublicKey(4253220375837175409, 65537) + + self.assertEqual(expected, key) + + def test_save_multiprime_public_key(self): + """Test saving public PEM files.""" + + key = rsa.key.PublicKey(4253220375837175409, 65537) + pem = key.save_pkcs1("PEM") + + self.assertIsInstance(pem, bytes) + self.assertEqual(MP_CLEAN_PUBLIC_PEM, pem) + + class PickleTest(unittest.TestCase): """Test saving and loading keys by pickling.""" @@ -222,6 +371,16 @@ def test_private_key(self): for attr in rsa.key.AbstractKey.__slots__: self.assertTrue(hasattr(unpickled, attr)) + def test_multiprime_private_key(self): + pk = rsa.key.PrivateKey(4253220375837175409, 65537, 3349121513, 64123, 50957, [39317, 33107]) + + pickled = pickle.dumps(pk) + unpickled = pickle.loads(pickled) + self.assertEqual(pk, unpickled) + + for attr in rsa.key.AbstractKey.__slots__: + self.assertTrue(hasattr(unpickled, attr)) + def test_public_key(self): pk = rsa.key.PublicKey(3727264081, 65537) diff --git a/tests/test_pkcs1.py b/tests/test_pkcs1.py index 1ce819f..b79eab9 100644 --- a/tests/test_pkcs1.py +++ b/tests/test_pkcs1.py @@ -63,6 +63,23 @@ def test_randomness(self): self.assertNotEqual(encrypted1, encrypted2) +class MultiprimeBinaryTest(unittest.TestCase): + def setUp(self): + (self.pub, self.priv) = rsa.newkeys(256, nprimes=3) + + def test_multiprime_enc_dec(self): + message = struct.pack(">IIII", 0, 0, 0, 1) + print("\n\tMessage: %r" % message) + + encrypted = pkcs1.encrypt(message, self.pub) + print("\tEncrypted: %r" % encrypted) + + decrypted = pkcs1.decrypt(encrypted, self.priv) + print("\tDecrypted: %r" % decrypted) + + self.assertEqual(message, decrypted) + + class ExtraZeroesTest(unittest.TestCase): def setUp(self): # Key, cyphertext, and plaintext taken from https://github.com/sybrenstuvel/python-rsa/issues/146 @@ -182,6 +199,18 @@ def test_apppend_zeroes(self): pkcs1.verify(message, signature, self.pub) +class MultiprimeSignatureTest(unittest.TestCase): + def setUp(self): + (self.pub, self.priv) = rsa.newkeys(512, nprimes=3) + + def test_multiprime_sign_verify(self): + """Test happy flow of sign and verify""" + + message = b"je moeder" + signature = pkcs1.sign(message, self.priv, "SHA-256") + self.assertEqual("SHA-256", pkcs1.verify(message, signature, self.pub)) + + class PaddingSizeTest(unittest.TestCase): def test_too_little_padding(self): """Padding less than 8 bytes should be rejected."""