From 140d66fad8b3d8bf1fe793fc73d5e0a9f55f16e3 Mon Sep 17 00:00:00 2001 From: Chris Gorman Date: Fri, 16 Feb 2024 08:56:29 -0700 Subject: [PATCH] Improved integer square root (#4403) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Hadrien Croubois Co-authored-by: Ernesto García --- contracts/utils/math/Math.sol | 132 ++++++++++++++++++++++++++-------- 1 file changed, 103 insertions(+), 29 deletions(-) diff --git a/contracts/utils/math/Math.sol b/contracts/utils/math/Math.sol index 0f9dce94f44..a8851f8ebef 100644 --- a/contracts/utils/math/Math.sol +++ b/contracts/utils/math/Math.sol @@ -385,38 +385,112 @@ library Math { * @dev Returns the square root of a number. If the number is not a perfect square, the value is rounded * towards zero. * - * Inspired by Henry S. Warren, Jr.'s "Hacker's Delight" (Chapter 11). + * This method is based on Newton's method for computing square roots; the algorithm is restricted to only + * using integer operations. */ function sqrt(uint256 a) internal pure returns (uint256) { - if (a == 0) { - return 0; - } - - // For our first guess, we get the biggest power of 2 which is smaller than the square root of the target. - // - // We know that the "msb" (most significant bit) of our target number `a` is a power of 2 such that we have - // `msb(a) <= a < 2*msb(a)`. This value can be written `msb(a)=2**k` with `k=log2(a)`. - // - // This can be rewritten `2**log2(a) <= a < 2**(log2(a) + 1)` - // → `sqrt(2**k) <= sqrt(a) < sqrt(2**(k+1))` - // → `2**(k/2) <= sqrt(a) < 2**((k+1)/2) <= 2**(k/2 + 1)` - // - // Consequently, `2**(log2(a) / 2)` is a good first approximation of `sqrt(a)` with at least 1 correct bit. - uint256 result = 1 << (log2(a) >> 1); - - // At this point `result` is an estimation with one bit of precision. We know the true value is a uint128, - // since it is the square root of a uint256. Newton's method converges quadratically (precision doubles at - // every iteration). We thus need at most 7 iteration to turn our partial result with one bit of precision - // into the expected uint128 result. unchecked { - result = (result + a / result) >> 1; - result = (result + a / result) >> 1; - result = (result + a / result) >> 1; - result = (result + a / result) >> 1; - result = (result + a / result) >> 1; - result = (result + a / result) >> 1; - result = (result + a / result) >> 1; - return min(result, a / result); + // Take care of easy edge cases when a == 0 or a == 1 + if (a <= 1) { + return a; + } + + // In this function, we use Newton's method to get a root of `f(x) := x² - a`. It involves building a + // sequence x_n that converges toward sqrt(a). For each iteration x_n, we also define the error between + // the current value as `ε_n = | x_n - sqrt(a) |`. + // + // For our first estimation, we consider `e` the smallest power of 2 which is bigger than the square root + // of the target. (i.e. `2**(e-1) ≤ sqrt(a) < 2**e`). We know that `e ≤ 128` because `(2¹²⁸)² = 2²⁵⁶` is + // bigger than any uint256. + // + // By noticing that + // `2**(e-1) ≤ sqrt(a) < 2**e → (2**(e-1))² ≤ a < (2**e)² → 2**(2*e-2) ≤ a < 2**(2*e)` + // we can deduce that `e - 1` is `log2(a) / 2`. We can thus compute `x_n = 2**(e-1)` using a method similar + // to the msb function. + uint256 aa = a; + uint256 xn = 1; + + if (aa >= (1 << 128)) { + aa >>= 128; + xn <<= 64; + } + if (aa >= (1 << 64)) { + aa >>= 64; + xn <<= 32; + } + if (aa >= (1 << 32)) { + aa >>= 32; + xn <<= 16; + } + if (aa >= (1 << 16)) { + aa >>= 16; + xn <<= 8; + } + if (aa >= (1 << 8)) { + aa >>= 8; + xn <<= 4; + } + if (aa >= (1 << 4)) { + aa >>= 4; + xn <<= 2; + } + if (aa >= (1 << 2)) { + xn <<= 1; + } + + // We now have x_n such that `x_n = 2**(e-1) ≤ sqrt(a) < 2**e = 2 * x_n`. This implies ε_n ≤ 2**(e-1). + // + // We can refine our estimation by noticing that the the middle of that interval minimizes the error. + // If we move x_n to equal 2**(e-1) + 2**(e-2), then we reduce the error to ε_n ≤ 2**(e-2). + // This is going to be our x_0 (and ε_0) + xn = (3 * xn) >> 1; // ε_0 := | x_0 - sqrt(a) | ≤ 2**(e-2) + + // From here, Newton's method give us: + // x_{n+1} = (x_n + a / x_n) / 2 + // + // One should note that: + // x_{n+1}² - a = ((x_n + a / x_n) / 2)² - a + // = ((x_n² + a) / (2 * x_n))² - a + // = (x_n⁴ + 2 * a * x_n² + a²) / (4 * x_n²) - a + // = (x_n⁴ + 2 * a * x_n² + a² - 4 * a * x_n²) / (4 * x_n²) + // = (x_n⁴ - 2 * a * x_n² + a²) / (4 * x_n²) + // = (x_n² - a)² / (2 * x_n)² + // = ((x_n² - a) / (2 * x_n))² + // ≥ 0 + // Which proves that for all n ≥ 1, sqrt(a) ≤ x_n + // + // This gives us the proof of quadratic convergence of the sequence: + // ε_{n+1} = | x_{n+1} - sqrt(a) | + // = | (x_n + a / x_n) / 2 - sqrt(a) | + // = | (x_n² + a - 2*x_n*sqrt(a)) / (2 * x_n) | + // = | (x_n - sqrt(a))² / (2 * x_n) | + // = | ε_n² / (2 * x_n) | + // = ε_n² / | (2 * x_n) | + // + // For the first iteration, we have a special case where x_0 is known: + // ε_1 = ε_0² / | (2 * x_0) | + // ≤ (2**(e-2))² / (2 * (2**(e-1) + 2**(e-2))) + // ≤ 2**(2*e-4) / (3 * 2**(e-1)) + // ≤ 2**(e-3) / 3 + // ≤ 2**(e-3-log2(3)) + // ≤ 2**(e-4.5) + // + // For the following iterations, we use the fact that, 2**(e-1) ≤ sqrt(a) ≤ x_n: + // ε_{n+1} = ε_n² / | (2 * x_n) | + // ≤ (2**(e-k))² / (2 * 2**(e-1)) + // ≤ 2**(2*e-2*k) / 2**e + // ≤ 2**(e-2*k) + xn = (xn + a / xn) >> 1; // ε_1 := | x_1 - sqrt(a) | ≤ 2**(e-4.5) -- special case, see above + xn = (xn + a / xn) >> 1; // ε_2 := | x_2 - sqrt(a) | ≤ 2**(e-9) -- general case with k = 4.5 + xn = (xn + a / xn) >> 1; // ε_3 := | x_3 - sqrt(a) | ≤ 2**(e-18) -- general case with k = 9 + xn = (xn + a / xn) >> 1; // ε_4 := | x_4 - sqrt(a) | ≤ 2**(e-36) -- general case with k = 18 + xn = (xn + a / xn) >> 1; // ε_5 := | x_5 - sqrt(a) | ≤ 2**(e-72) -- general case with k = 36 + xn = (xn + a / xn) >> 1; // ε_6 := | x_6 - sqrt(a) | ≤ 2**(e-144) -- general case with k = 72 + + // Because e ≤ 128 (as discussed during the first estimation phase), we know have reached a precision + // ε_6 ≤ 2**(e-144) < 1. Given we're operating on integers, then we can ensure that xn is now either + // sqrt(a) or sqrt(a) + 1. + return xn - SafeCast.toUint(xn > a / xn); } }