Skip to content

Commit

Permalink
This resolves issue sybrenstuvel#98
Browse files Browse the repository at this point in the history
  • Loading branch information
RichardThiessen committed Jun 4, 2023
1 parent 771a0b0 commit c781234
Show file tree
Hide file tree
Showing 3 changed files with 130 additions and 51 deletions.
65 changes: 29 additions & 36 deletions rsa/key.py
Expand Up @@ -617,7 +617,7 @@ def _save_pkcs1_pem(self) -> bytes:

def find_p_q(
nbits: int,
getprime_func: typing.Callable[[int], int] = rsa.prime.getprime,
getprime_func: typing.Callable[[int], int] = rsa.prime.getprime_FIPS,
accurate: bool = True,
) -> typing.Tuple[int, int]:
"""Returns a tuple of two different primes of nbits bits each.
Expand All @@ -627,7 +627,7 @@ def find_p_q(
:param nbits: the number of bits in each of p and q.
:param getprime_func: the getprime function, defaults to
:py:func:`rsa.prime.getprime`.
:py:func:`rsa.prime.getprime`.#TODO:update
*Introduced in Python-RSA 3.1*
Expand All @@ -650,45 +650,38 @@ def find_p_q(
"""

total_bits = nbits * 2

# Make sure that p and q aren't too close or the factoring programs can
# factor n.
shift = nbits // 16
pbits = nbits + shift
qbits = nbits - shift

# Choose the two initial primes
p = getprime_func(pbits)
q = getprime_func(qbits)
total_bits = nbits * 2

def is_acceptable(p: int, q: int) -> bool:
"""Returns True iff p and q are acceptable:
- p and q differ
- (p * q) has the right nr of bits (when accurate=True)
NIST.FIPS.186-4 acceptance criteria:
- https://nvlpubs.nist.gov/nistpubs/FIPS/NIST.FIPS.186-4.pdf#%5B%7B%22num%22%3A127%2C%22gen%22%3A0%7D%2C%7B%22name%22%3A%22XYZ%22%7D%2C70%2C223%2C0%5D
- p and q are in the range [2**(nbits+0.5),2**nbits]
- abs(p-q)>2**(nbits-100)
- the strong prime requirements are ignored
- http://people.csail.mit.edu/rivest/RivestSilverman-AreStrongPrimesNeededForRSA.pdf
- (the paper argues they're not needed)
- provable prime requirements for smaller key sizes are ignored
"""

if p == q:
return False

if not accurate:
return True

# Make sure we have just the right amount of bits
end=2**nbits
start=rsa.prime._bigint_divide_by_sqrt_2(end)
if not start<=p<end:return False
if not start<=q<end:return False
if not abs(p-q)>2**(nbits-100):return False
found_size = rsa.common.bit_size(p * q)
return total_bits == found_size

# Keep choosing other primes until they match our requirements.
change_p = False
while not is_acceptable(p, q):
# Change p on one iteration and q on the other
if change_p:
p = getprime_func(pbits)
else:
q = getprime_func(qbits)

change_p = not change_p

#this should be guaranteed by the previous range checks
assert found_size==total_bits
return True

while 1:
# Keep generating primes if there's a failure.
p = getprime_func(nbits)
q = getprime_func(nbits)
#note:failure has negligible probability
if is_acceptable(p, q):break

# We want p > q as described on
# http://www.di-mgt.com.au/rsa_alg.html#crt
Expand Down
24 changes: 9 additions & 15 deletions rsa/parallel.py
Expand Up @@ -22,24 +22,16 @@
"""

import typing
import multiprocessing as mp
from multiprocessing.connection import Connection

import rsa.prime
import rsa.randnum


def _find_prime(nbits: int, pipe: Connection) -> None:
while True:
integer = rsa.randnum.read_random_odd_int(nbits)

# Test for primeness
if rsa.prime.is_prime(integer):
pipe.send(integer)
return


def getprime(nbits: int, poolsize: int) -> int:
def getprime(nbits: int,
poolsize: int,
getprime_func: typing.Callable[[int], int] = rsa.prime.getprime_FIPS
) -> int:
"""Returns a prime number that can be stored in 'nbits' bits.
Works in multiple threads at the same time.
Expand All @@ -62,7 +54,9 @@ def getprime(nbits: int, poolsize: int) -> int:

# Create processes
try:
procs = [mp.Process(target=_find_prime, args=(nbits, pipe_send)) for _ in range(poolsize)]
#target function
target=lambda:pipe_send.send(getprime_func(nbits))
procs = [mp.Process(target=target) for _ in range(poolsize)]
# Start processes
for p in procs:
p.start()
Expand All @@ -85,7 +79,7 @@ def getprime(nbits: int, poolsize: int) -> int:
print("Running doctests 1000x or until failure")
import doctest

for count in range(100):
for count in range(1000):
(failures, tests) = doctest.testmod()
if failures:
break
Expand Down
92 changes: 92 additions & 0 deletions rsa/prime.py
Expand Up @@ -182,6 +182,98 @@ def are_relatively_prime(a: int, b: int) -> bool:
d = gcd(a, b)
return d == 1

def _bigint_divide_by_sqrt_2(n):
"""returns math.ceil(n/sqrt(2))
result is exact
:param n: integer to divide by sqrt(2).
:returns: math.ceil(n/sqrt(2))
>>> import math
>>> _bigint_divide_by_sqrt_2(100)
71
>>> _bigint_divide_by_sqrt_2(2**128)
240615969168004511545033772477625056928
"""
x=n*int(2**63.5)//2**64 #initial approximation
target=n**2 // 2
while 1:#newton approximation
dx=(x*x-target)//(2*x)
x-=dx
if not dx:break
#check this meets the ceil criteria exactly
assert (x )**2>target
assert (x-1)**2<target
return x

def getprimebyrange(start,end,initial=None):
"""Returns a prime number randomly chosen from range(start,end)
randomly chooses an initial point within the range
This can be overriden with the optional initial argument
starts at the initial point scanning range(initial,end) then trying
range(start,initial)
>>> p = getprimebyrange(100,200)
>>> 100<=p<200
True
>>> is_prime(p-1)
False
>>> is_prime(p)
True
>>> is_prime(p+1)
False
>>> getprimebyrange(10000,20000,initial=10000)
10007
>>> getprimebyrange(10000,20000,initial=10010)
10037
>>> #when no primes in range(initial,end), it tries range(start,initial)
>>> getprimebyrange(10000,10020,initial=10010)
10007
"""
#randomly choose the initial point in the range (unless specified)
assert end>start
if initial is None:
initial=start+rsa.randnum.randint(end-start)-1
assert start<=initial<end
#check top part of range
for candidate in range(initial|1, end,2):
if is_prime(candidate):
return candidate
#nothing in the top part of the given range
#check bottom part of range
for candidate in range(start|1,initial,2):
#integer = rsa.randnum.read_random_odd_int(nbits)
# Test for primeness
if is_prime(candidate):
return candidate
#nothing the bottom half either
raise ValueError("no primes in range")

def getprime_FIPS(nbits):
"""Returns a prime number in the range ceil(2**n/sqrt(2)) <= x < 2**n
the product of two such primes will always be between [2**(2*n),2**(2*n-1))]
>>> p = getprime_FIPS(128)
>>> is_prime(p-1)
False
>>> is_prime(p)
True
>>> is_prime(p+1)
False
>>> from rsa import common
>>> common.bit_size(p) == 128
True
>>> common.bit_size(p**2) == 256
True
"""
end=2**nbits
start=_bigint_divide_by_sqrt_2(end)
p=getprimebyrange(start,end)
return p



if __name__ == "__main__":
print("Running doctests 1000x or until failure")
Expand Down

0 comments on commit c781234

Please sign in to comment.