From a32c9a30e346a7cf82860fec746cdb95b97d6105 Mon Sep 17 00:00:00 2001 From: ronhuafeng Date: Fri, 9 Sep 2022 17:09:16 +0100 Subject: [PATCH] Support memory arrays in MerkleTree multiproof (#3493) --- contracts/mocks/MerkleProofWrapper.sol | 25 ++++++-- contracts/utils/cryptography/MerkleProof.sol | 66 +++++++++++++++++++- test/utils/cryptography/MerkleProof.test.js | 26 +++++++- 3 files changed, 109 insertions(+), 8 deletions(-) diff --git a/contracts/mocks/MerkleProofWrapper.sol b/contracts/mocks/MerkleProofWrapper.sol index 589db33a6c2..b74459dc890 100644 --- a/contracts/mocks/MerkleProofWrapper.sol +++ b/contracts/mocks/MerkleProofWrapper.sol @@ -30,19 +30,36 @@ contract MerkleProofWrapper { } function multiProofVerify( + bytes32[] memory proofs, + bool[] memory proofFlag, + bytes32 root, + bytes32[] memory leaves + ) public pure returns (bool) { + return MerkleProof.multiProofVerify(proofs, proofFlag, root, leaves); + } + + function multiProofVerifyCalldata( bytes32[] calldata proofs, bool[] calldata proofFlag, bytes32 root, - bytes32[] calldata leaves + bytes32[] memory leaves ) public pure returns (bool) { - return MerkleProof.multiProofVerify(proofs, proofFlag, root, leaves); + return MerkleProof.multiProofVerifyCalldata(proofs, proofFlag, root, leaves); } function processMultiProof( + bytes32[] memory proofs, + bool[] memory proofFlag, + bytes32[] memory leaves + ) public pure returns (bytes32) { + return MerkleProof.processMultiProof(proofs, proofFlag, leaves); + } + + function processMultiProofCalldata( bytes32[] calldata proofs, bool[] calldata proofFlag, - bytes32[] calldata leaves + bytes32[] memory leaves ) public pure returns (bytes32) { - return MerkleProof.processMultiProof(proofs, proofFlag, leaves); + return MerkleProof.processMultiProofCalldata(proofs, proofFlag, leaves); } } diff --git a/contracts/utils/cryptography/MerkleProof.sol b/contracts/utils/cryptography/MerkleProof.sol index 4c35d46aa8e..11d21a1fd84 100644 --- a/contracts/utils/cryptography/MerkleProof.sol +++ b/contracts/utils/cryptography/MerkleProof.sol @@ -81,12 +81,26 @@ library MerkleProof { * _Available since v4.7._ */ function multiProofVerify( + bytes32[] memory proof, + bool[] memory proofFlags, + bytes32 root, + bytes32[] memory leaves + ) internal pure returns (bool) { + return processMultiProof(proof, proofFlags, leaves) == root; + } + + /** + * @dev Calldata version of {multiProofVerify} + * + * _Available since v4.7._ + */ + function multiProofVerifyCalldata( bytes32[] calldata proof, bool[] calldata proofFlags, bytes32 root, - bytes32[] calldata leaves + bytes32[] memory leaves ) internal pure returns (bool) { - return processMultiProof(proof, proofFlags, leaves) == root; + return processMultiProofCalldata(proof, proofFlags, leaves) == root; } /** @@ -97,9 +111,55 @@ library MerkleProof { * _Available since v4.7._ */ function processMultiProof( + bytes32[] memory proof, + bool[] memory proofFlags, + bytes32[] memory leaves + ) internal pure returns (bytes32 merkleRoot) { + // This function rebuild the root hash by traversing the tree up from the leaves. The root is rebuilt by + // consuming and producing values on a queue. The queue starts with the `leaves` array, then goes onto the + // `hashes` array. At the end of the process, the last hash in the `hashes` array should contain the root of + // the merkle tree. + uint256 leavesLen = leaves.length; + uint256 totalHashes = proofFlags.length; + + // Check proof validity. + require(leavesLen + proof.length - 1 == totalHashes, "MerkleProof: invalid multiproof"); + + // The xxxPos values are "pointers" to the next value to consume in each array. All accesses are done using + // `xxx[xxxPos++]`, which return the current value and increment the pointer, thus mimicking a queue's "pop". + bytes32[] memory hashes = new bytes32[](totalHashes); + uint256 leafPos = 0; + uint256 hashPos = 0; + uint256 proofPos = 0; + // At each step, we compute the next hash using two values: + // - a value from the "main queue". If not all leaves have been consumed, we get the next leaf, otherwise we + // get the next hash. + // - depending on the flag, either another value for the "main queue" (merging branches) or an element from the + // `proof` array. + for (uint256 i = 0; i < totalHashes; i++) { + bytes32 a = leafPos < leavesLen ? leaves[leafPos++] : hashes[hashPos++]; + bytes32 b = proofFlags[i] ? leafPos < leavesLen ? leaves[leafPos++] : hashes[hashPos++] : proof[proofPos++]; + hashes[i] = _hashPair(a, b); + } + + if (totalHashes > 0) { + return hashes[totalHashes - 1]; + } else if (leavesLen > 0) { + return leaves[0]; + } else { + return proof[0]; + } + } + + /** + * @dev Calldata version of {processMultiProof} + * + * _Available since v4.7._ + */ + function processMultiProofCalldata( bytes32[] calldata proof, bool[] calldata proofFlags, - bytes32[] calldata leaves + bytes32[] memory leaves ) internal pure returns (bytes32 merkleRoot) { // This function rebuild the root hash by traversing the tree up from the leaves. The root is rebuilt by // consuming and producing values on a queue. The queue starts with the `leaves` array, then goes onto the diff --git a/test/utils/cryptography/MerkleProof.test.js b/test/utils/cryptography/MerkleProof.test.js index 1bb3de7f12a..2d4aacdacdd 100644 --- a/test/utils/cryptography/MerkleProof.test.js +++ b/test/utils/cryptography/MerkleProof.test.js @@ -79,6 +79,7 @@ contract('MerkleProof', function (accounts) { const proofFlags = merkleTree.getProofFlags(proofLeaves, proof); expect(await this.merkleProof.multiProofVerify(proof, proofFlags, root, proofLeaves)).to.equal(true); + expect(await this.merkleProof.multiProofVerifyCalldata(proof, proofFlags, root, proofLeaves)).to.equal(true); }); it('returns false for an invalid Merkle multi proof', async function () { @@ -91,7 +92,10 @@ contract('MerkleProof', function (accounts) { const badProof = badMerkleTree.getMultiProof(badProofLeaves); const badProofFlags = badMerkleTree.getProofFlags(badProofLeaves, badProof); - expect(await this.merkleProof.multiProofVerify(badProof, badProofFlags, root, badProofLeaves)).to.equal(false); + expect(await this.merkleProof.multiProofVerify(badProof, badProofFlags, root, badProofLeaves)) + .to.equal(false); + expect(await this.merkleProof.multiProofVerifyCalldata(badProof, badProofFlags, root, badProofLeaves)) + .to.equal(false); }); it('revert with invalid multi proof #1', async function () { @@ -111,6 +115,15 @@ contract('MerkleProof', function (accounts) { ), 'MerkleProof: invalid multiproof', ); + await expectRevert( + this.merkleProof.multiProofVerifyCalldata( + [ leaves[1], fill, merkleTree.layers[1][1] ], + [ false, false, false ], + root, + [ leaves[0], badLeaf ], // A, E + ), + 'MerkleProof: invalid multiproof', + ); }); it('revert with invalid multi proof #2', async function () { @@ -130,6 +143,15 @@ contract('MerkleProof', function (accounts) { ), 'reverted with panic code 0x32', ); + await expectRevert( + this.merkleProof.multiProofVerifyCalldata( + [ leaves[1], fill, merkleTree.layers[1][1] ], + [ false, false, false, false ], + root, + [ badLeaf, leaves[0] ], // A, E + ), + 'reverted with panic code 0x32', + ); }); it('limit case: works for tree containing a single leaf', async function () { @@ -142,6 +164,7 @@ contract('MerkleProof', function (accounts) { const proofFlags = merkleTree.getProofFlags(proofLeaves, proof); expect(await this.merkleProof.multiProofVerify(proof, proofFlags, root, proofLeaves)).to.equal(true); + expect(await this.merkleProof.multiProofVerifyCalldata(proof, proofFlags, root, proofLeaves)).to.equal(true); }); it('limit case: can prove empty leaves', async function () { @@ -150,6 +173,7 @@ contract('MerkleProof', function (accounts) { const root = merkleTree.getRoot(); expect(await this.merkleProof.multiProofVerify([ root ], [], root, [])).to.equal(true); + expect(await this.merkleProof.multiProofVerifyCalldata([ root ], [], root, [])).to.equal(true); }); }); });