Skip to content

Commit

Permalink
Add bytes memory version of Math.modExp (#4893)
Browse files Browse the repository at this point in the history
Co-authored-by: Hadrien Croubois <hadrien.croubois@gmail.com>
  • Loading branch information
ernestognw and Amxx committed Feb 14, 2024
1 parent ae1bafc commit 4e7e6e5
Show file tree
Hide file tree
Showing 5 changed files with 180 additions and 36 deletions.
2 changes: 1 addition & 1 deletion .changeset/shiny-poets-whisper.md
Expand Up @@ -2,4 +2,4 @@
'openzeppelin-solidity': minor
---

`Math`: Add `modExp` function that exposes the `EIP-198` precompile.
`Math`: Add `modExp` function that exposes the `EIP-198` precompile. Includes `uint256` and `bytes memory` versions.
58 changes: 52 additions & 6 deletions contracts/utils/math/Math.sol
Expand Up @@ -3,7 +3,6 @@

pragma solidity ^0.8.20;

import {Address} from "../Address.sol";
import {Panic} from "../Panic.sol";
import {SafeCast} from "./SafeCast.sol";

Expand Down Expand Up @@ -289,11 +288,7 @@ library Math {
function modExp(uint256 b, uint256 e, uint256 m) internal view returns (uint256) {
(bool success, uint256 result) = tryModExp(b, e, m);
if (!success) {
if (m == 0) {
Panic.panic(Panic.DIVISION_BY_ZERO);
} else {
revert Address.FailedInnerCall();
}
Panic.panic(Panic.DIVISION_BY_ZERO);
}
return result;
}
Expand Down Expand Up @@ -335,6 +330,57 @@ library Math {
}
}

/**
* @dev Variant of {modExp} that supports inputs of arbitrary length.
*/
function modExp(bytes memory b, bytes memory e, bytes memory m) internal view returns (bytes memory) {
(bool success, bytes memory result) = tryModExp(b, e, m);
if (!success) {
Panic.panic(Panic.DIVISION_BY_ZERO);
}
return result;
}

/**
* @dev Variant of {tryModExp} that supports inputs of arbitrary length.
*/
function tryModExp(
bytes memory b,
bytes memory e,
bytes memory m
) internal view returns (bool success, bytes memory result) {
if (_zeroBytes(m)) return (false, new bytes(0));

uint256 mLen = m.length;

// Encode call args in result and move the free memory pointer
result = abi.encodePacked(b.length, e.length, mLen, b, e, m);

/// @solidity memory-safe-assembly
assembly {
let dataPtr := add(result, 0x20)
// Write result on top of args to avoid allocating extra memory.
success := staticcall(gas(), 0x05, dataPtr, mload(result), dataPtr, mLen)
// Overwrite the length.
// result.length > returndatasize() is guaranteed because returndatasize() == m.length
mstore(result, mLen)
// Set the memory pointer after the returned data.
mstore(0x40, add(dataPtr, mLen))
}
}

/**
* @dev Returns whether the provided byte array is zero.
*/
function _zeroBytes(bytes memory byteArray) private pure returns (bool) {
for (uint256 i = 0; i < byteArray.length; ++i) {
if (byteArray[i] != 0) {
return false;
}
}
return true;
}

/**
* @dev Returns the square root of a number. If the number is not a perfect square, the value is rounded
* towards zero.
Expand Down
23 changes: 23 additions & 0 deletions test/helpers/math.js
Expand Up @@ -3,8 +3,31 @@ const max = (...values) => values.slice(1).reduce((x, y) => (x > y ? x : y), val
const min = (...values) => values.slice(1).reduce((x, y) => (x < y ? x : y), values.at(0));
const sum = (...values) => values.slice(1).reduce((x, y) => x + y, values.at(0));

// Computes modexp without BigInt overflow for large numbers
function modExp(b, e, m) {
let result = 1n;

// If e is a power of two, modexp can be calculated as:
// for (let result = b, i = 0; i < log2(e); i++) result = modexp(result, 2, m)
//
// Given any natural number can be written in terms of powers of 2 (i.e. binary)
// then modexp can be calculated for any e, by multiplying b**i for all i where
// binary(e)[i] is 1 (i.e. a power of two).
for (let base = b % m; e > 0n; base = base ** 2n % m) {
// Least significant bit is 1
if (e % 2n == 1n) {
result = (result * base) % m;
}

e /= 2n; // Binary pop
}

return result;
}

module.exports = {
min,
max,
sum,
modExp,
};
27 changes: 27 additions & 0 deletions test/utils/math/Math.t.sol
Expand Up @@ -226,6 +226,33 @@ contract MathTest is Test {
}
}

function testModExpMemory(uint256 b, uint256 e, uint256 m) public {
if (m == 0) {
vm.expectRevert(stdError.divisionError);
}
bytes memory result = Math.modExp(abi.encodePacked(b), abi.encodePacked(e), abi.encodePacked(m));
assertEq(result.length, 0x20);
uint256 res = abi.decode(result, (uint256));
assertLt(res, m);
assertEq(res, _nativeModExp(b, e, m));
}

function testTryModExpMemory(uint256 b, uint256 e, uint256 m) public {
(bool success, bytes memory result) = Math.tryModExp(
abi.encodePacked(b),
abi.encodePacked(e),
abi.encodePacked(m)
);
if (success) {
assertEq(result.length, 0x20); // m is a uint256, so abi.encodePacked(m).length is 0x20
uint256 res = abi.decode(result, (uint256));
assertLt(res, m);
assertEq(res, _nativeModExp(b, e, m));
} else {
assertEq(result.length, 0);
}
}

function _nativeModExp(uint256 b, uint256 e, uint256 m) private pure returns (uint256) {
if (m == 1) return 0;
uint256 r = 1;
Expand Down
106 changes: 77 additions & 29 deletions test/utils/math/Math.test.js
Expand Up @@ -4,12 +4,19 @@ const { loadFixture } = require('@nomicfoundation/hardhat-network-helpers');
const { PANIC_CODES } = require('@nomicfoundation/hardhat-chai-matchers/panic');

const { Rounding } = require('../../helpers/enums');
const { min, max } = require('../../helpers/math');
const { min, max, modExp } = require('../../helpers/math');
const { generators } = require('../../helpers/random');
const { range } = require('../../../scripts/helpers');
const { product } = require('../../helpers/iterate');

const RoundingDown = [Rounding.Floor, Rounding.Trunc];
const RoundingUp = [Rounding.Ceil, Rounding.Expand];

const bytes = (value, width = undefined) => ethers.Typed.bytes(ethers.toBeHex(value, width));
const uint256 = value => ethers.Typed.uint256(value);
bytes.zero = '0x';
uint256.zero = 0n;

async function testCommutative(fn, lhs, rhs, expected, ...extra) {
expect(await fn(lhs, rhs, ...extra)).to.deep.equal(expected);
expect(await fn(rhs, lhs, ...extra)).to.deep.equal(expected);
Expand Down Expand Up @@ -141,24 +148,6 @@ describe('Math', function () {
});
});

describe('tryModExp', function () {
it('is correctly returning true and calculating modulus', async function () {
const base = 3n;
const exponent = 200n;
const modulus = 50n;

expect(await this.mock.$tryModExp(base, exponent, modulus)).to.deep.equal([true, base ** exponent % modulus]);
});

it('is correctly returning false when modulus is 0', async function () {
const base = 3n;
const exponent = 200n;
const modulus = 0n;

expect(await this.mock.$tryModExp(base, exponent, modulus)).to.deep.equal([false, 0n]);
});
});

describe('max', function () {
it('is correctly detected in both position', async function () {
await testCommutative(this.mock.$max, 1234n, 5678n, max(1234n, 5678n));
Expand Down Expand Up @@ -354,20 +343,79 @@ describe('Math', function () {
});

describe('modExp', function () {
it('is correctly calculating modulus', async function () {
const base = 3n;
const exponent = 200n;
const modulus = 50n;
for (const [name, type] of Object.entries({ uint256, bytes })) {
describe(`with ${name} inputs`, function () {
it('is correctly calculating modulus', async function () {
const b = 3n;
const e = 200n;
const m = 50n;

expect(await this.mock.$modExp(type(b), type(e), type(m))).to.equal(type(b ** e % m).value);
});

expect(await this.mock.$modExp(base, exponent, modulus)).to.equal(base ** exponent % modulus);
it('is correctly reverting when modulus is zero', async function () {
const b = 3n;
const e = 200n;
const m = 0n;

await expect(this.mock.$modExp(type(b), type(e), type(m))).to.be.revertedWithPanic(
PANIC_CODES.DIVISION_BY_ZERO,
);
});
});
}

describe('with large bytes inputs', function () {
for (const [[b, log2b], [e, log2e], [m, log2m]] of product(
range(320, 512, 64).map(e => [2n ** BigInt(e) + 1n, e]),
range(320, 512, 64).map(e => [2n ** BigInt(e) + 1n, e]),
range(320, 512, 64).map(e => [2n ** BigInt(e) + 1n, e]),
)) {
it(`calculates b ** e % m (b=2**${log2b}+1) (e=2**${log2e}+1) (m=2**${log2m}+1)`, async function () {
const mLength = ethers.dataLength(ethers.toBeHex(m));

expect(await this.mock.$modExp(bytes(b), bytes(e), bytes(m))).to.equal(bytes(modExp(b, e, m), mLength).value);
});
}
});
});

describe('tryModExp', function () {
for (const [name, type] of Object.entries({ uint256, bytes })) {
describe(`with ${name} inputs`, function () {
it('is correctly calculating modulus', async function () {
const b = 3n;
const e = 200n;
const m = 50n;

expect(await this.mock.$tryModExp(type(b), type(e), type(m))).to.deep.equal([true, type(b ** e % m).value]);
});

it('is correctly reverting when modulus is zero', async function () {
const base = 3n;
const exponent = 200n;
const modulus = 0n;
it('is correctly reverting when modulus is zero', async function () {
const b = 3n;
const e = 200n;
const m = 0n;

await expect(this.mock.$modExp(base, exponent, modulus)).to.be.revertedWithPanic(PANIC_CODES.DIVISION_BY_ZERO);
expect(await this.mock.$tryModExp(type(b), type(e), type(m))).to.deep.equal([false, type.zero]);
});
});
}

describe('with large bytes inputs', function () {
for (const [[b, log2b], [e, log2e], [m, log2m]] of product(
range(320, 513, 64).map(e => [2n ** BigInt(e) + 1n, e]),
range(320, 513, 64).map(e => [2n ** BigInt(e) + 1n, e]),
range(320, 513, 64).map(e => [2n ** BigInt(e) + 1n, e]),
)) {
it(`calculates b ** e % m (b=2**${log2b}+1) (e=2**${log2e}+1) (m=2**${log2m}+1)`, async function () {
const mLength = ethers.dataLength(ethers.toBeHex(m));

expect(await this.mock.$tryModExp(bytes(b), bytes(e), bytes(m))).to.deep.equal([
true,
bytes(modExp(b, e, m), mLength).value,
]);
});
}
});
});

Expand Down

0 comments on commit 4e7e6e5

Please sign in to comment.