Improved integer square root. #4403

Merged
merged 26 commits into from Feb 16, 2024
Changes from 20 commits
95 changes: 66 additions & 29 deletions contracts/utils/math/Math.sol
@@ -339,38 +339,75 @@ library Math {
* @dev Returns the square root of a number. If the number is not a perfect square, the value is rounded
* towards zero.
*
* Inspired by Henry S. Warren, Jr.'s "Hacker's Delight" (Chapter 11).
* This method is based on Newton's method for computing square roots; the algorithm is restricted to only
* using integer operations.
*/
function sqrt(uint256 a) internal pure returns (uint256) {
if (a == 0) {
return 0;
}

// For our first guess, we get the biggest power of 2 which is smaller than the square root of the target.
//
// We know that the "msb" (most significant bit) of our target number `a` is a power of 2 such that we have
// `msb(a) <= a < 2*msb(a)`. This value can be written `msb(a)=2**k` with `k=log2(a)`.
//
// This can be rewritten `2**log2(a) <= a < 2**(log2(a) + 1)`
// → `sqrt(2**k) <= sqrt(a) < sqrt(2**(k+1))`
// → `2**(k/2) <= sqrt(a) < 2**((k+1)/2) <= 2**(k/2 + 1)`
//
// Consequently, `2**(log2(a) / 2)` is a good first approximation of `sqrt(a)` with at least 1 correct bit.
uint256 result = 1 << (log2(a) >> 1);

// At this point `result` is an estimation with one bit of precision. We know the true value is a uint128,
// since it is the square root of a uint256. Newton's method converges quadratically (precision doubles at
// every iteration). We thus need at most 7 iterations to turn our partial result with one bit of precision
// into the expected uint128 result.
unchecked {
result = (result + a / result) >> 1;
result = (result + a / result) >> 1;
result = (result + a / result) >> 1;
result = (result + a / result) >> 1;
result = (result + a / result) >> 1;
result = (result + a / result) >> 1;
result = (result + a / result) >> 1;
return min(result, a / result);
// Take care of easy edge cases when a == 0 or a == 1
if (a <= 1) {
return a;
}

uint256 aAux = a;
uint256 result = 1;

// For our first guess, we find the biggest power of 2 that does not exceed
// the square root of the target (i.e. result = 2**n <= sqrt(a) < 2**(n+1)).
// Since a < 2**256, we have sqrt(a) < 2**128, so n is at most 127.
// We find n by binary search: for k = 128, 64, ..., 4, if aAux >= 2**k then
// sqrt(aAux) >= 2**(k/2), so we shift aAux right by k bits and result left by k/2 bits.
// The final step (k = 2) only needs to update result.
if (aAux >= (1 << 128)) {
aAux >>= 128;
result <<= 64; // sqrt(a) >= 2**(e/2)
}
if (aAux >= (1 << 64)) {
aAux >>= 64;
result <<= 32; // sqrt(a) >= 2**(e/2)
}
if (aAux >= (1 << 32)) {
aAux >>= 32;
result <<= 16; // sqrt(a) >= 2**(e/2)
}
if (aAux >= (1 << 16)) {
aAux >>= 16;
result <<= 8; // sqrt(a) >= 2**(e/2)
}
if (aAux >= (1 << 8)) {
aAux >>= 8;
result <<= 4; // sqrt(a) >= 2**(e/2)
}
if (aAux >= (1 << 4)) {
aAux >>= 4;
result <<= 2; // sqrt(a) >= 2**(e/2)
}
if (aAux >= (1 << 2)) {
result <<= 1;
}

// We know that result <= sqrt(a) < 2 * result, with result = 2**e. We can
// improve the estimation by computing the arithmetic mean of these two
// bounds: (result + 2 * result) / 2 = 3 * result / 2.
result = (3 * result) >> 1;

// We define the error as ε = result - sqrt(a). Then we know that
// result = 2**(e-1) + 2**(e-2), and therefore ε_0 = 2**(e-1) + 2**(e-2) - sqrt(a),
Review comment from @Amxx (Collaborator), Feb 15, 2024:
There is a difference in notation between this and the PDF.

Here we have result = 2**e <= sqrt(a) < 2**(e+1).
The PDF has 2**(e−1) ≤ √n < 2**e (B.2).

This causes a lot of confusion.

// leaving |ε_0| <= 2**(e-2). We also see ε_{n+1} = ε_n**2 / (2 * x_n) <= ε_n**2 / (2 * sqrt(a)),
// where x_n is the n-th estimate and x_n >= sqrt(a), as shown in Walter Rudin,
// Principles of Mathematical Analysis, 3rd ed., McGraw-Hill, New York, 1976, Exercise 3.16 (b).

result = (result + a / result) >> 1; // ε_1 := result - sqrt(a) <= 2**(e-4.5)
result = (result + a / result) >> 1; // ε_2 := result - sqrt(a) <= 2**(e-9)
result = (result + a / result) >> 1; // ε_3 := result - sqrt(a) <= 2**(e-18)
result = (result + a / result) >> 1; // ε_4 := result - sqrt(a) <= 2**(e-36)
result = (result + a / result) >> 1; // ε_5 := result - sqrt(a) <= 2**(e-72)
result = (result + a / result) >> 1; // ε_6 := result - sqrt(a) <= 2**(e-144)

// After 6 iterations, the precision of e is already above 128 (i.e. 144). Meaning that
Review comment (Collaborator):
"precision of e"

We computed e at step 0 (using all the ifs). Here we are dealing with the precision of x_n.

// ε_6 <= 1. Since we are operating on integers, result is now
// either floor(sqrt(a)) or floor(sqrt(a)) + 1.
return result - SafeCast.toUint(result > a / result);
}
}
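
The error bounds annotated on the six Newton steps can be reconstructed as the worked derivation below. It is a sketch under stated assumptions, not the authors' proof: it follows the PDF's convention quoted in the review comment, 2**(e-1) <= sqrt(a) < 2**e (so e is one larger than the exponent in "result = 2**e"), it treats the iteration over the reals and ignores integer truncation, and the symbol x_n for the n-th estimate is introduced here for illustration.

```latex
\begin{align*}
x_{n+1} &= \frac{1}{2}\left(x_n + \frac{a}{x_n}\right), \qquad \varepsilon_n = x_n - \sqrt{a} \\
\varepsilon_{n+1} &= \frac{x_n^2 + a}{2x_n} - \sqrt{a}
                   = \frac{(x_n - \sqrt{a})^2}{2x_n}
                   = \frac{\varepsilon_n^2}{2x_n} \\
\varepsilon_1 &\le \frac{(2^{e-2})^2}{2\,(2^{e-1} + 2^{e-2})}
               = \frac{2^{2e-4}}{3 \cdot 2^{e-1}}
               = \frac{2^{e-3}}{3}
               \le 2^{e-4.5} \\
\varepsilon_{n+1} &\le \frac{\varepsilon_n^2}{2 \cdot 2^{e-1}} = \frac{\varepsilon_n^2}{2^{e}}
               \qquad (n \ge 1,\ \text{since } x_n \ge \sqrt{a} \ge 2^{e-1}) \\
\varepsilon_2 &\le 2^{e-9}, \quad \varepsilon_3 \le 2^{e-18}, \quad \varepsilon_4 \le 2^{e-36}, \quad
\varepsilon_5 \le 2^{e-72}, \quad \varepsilon_6 \le 2^{e-144}
\end{align*}
```

With e <= 128 this gives ε_6 <= 2**(-16), which is consistent with the claim in the code that the result is off by at most one after six iterations.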

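The new `sqrt` body can be modelled outside the EVM to check it against a reference. The following is a minimal Python sketch, assuming inputs in the uint256 range; the name `sqrt_model` is hypothetical and not part of the PR, and Python's arbitrary-precision integers stand in for Solidity's `unchecked` arithmetic (no step here would overflow 256 bits for valid inputs).

```python
import math


def sqrt_model(a: int) -> int:
    """Python model of the new Math.sqrt logic, for 0 <= a < 2**256."""
    # Easy edge cases: a == 0 or a == 1.
    if a <= 1:
        return a

    a_aux = a
    result = 1

    # Binary search for the largest power of 2 not exceeding sqrt(a):
    # if a_aux >= 2**k, then sqrt(a_aux) >= 2**(k/2).
    for k in (128, 64, 32, 16, 8, 4):
        if a_aux >= (1 << k):
            a_aux >>= k
            result <<= k // 2
    if a_aux >= (1 << 2):
        result <<= 1

    # Move the guess to the midpoint of [result, 2 * result).
    result = (3 * result) >> 1

    # Six Newton iterations using integer division.
    for _ in range(6):
        result = (result + a // result) >> 1

    # The estimate is now floor(sqrt(a)) or floor(sqrt(a)) + 1; correct it.
    return result - (1 if result > a // result else 0)
```

Comparing `sqrt_model(a)` with `math.isqrt(a)` over small inputs and over values near `2**256 - 1` is a quick way to sanity-check the six-iteration claim; it does not replace the test suite or the proof referenced in the PR.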