From 4fc9fd3efec867012deded6327fb259ea0082ab6 Mon Sep 17 00:00:00 2001 From: Troy Salem Date: Wed, 1 Jun 2022 16:36:42 -0400 Subject: [PATCH] Support more efficient merkle proofs through calldata (#3200) Co-authored-by: Hadrien Croubois Co-authored-by: Francisco Giordano --- CHANGELOG.md | 1 + contracts/mocks/MerkleProofWrapper.sol | 24 ++++++--- contracts/utils/cryptography/MerkleProof.sol | 55 ++++++++++++++------ test/utils/cryptography/MerkleProof.test.js | 12 +++-- 4 files changed, 67 insertions(+), 25 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 985bebf8887..f897e66de28 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,7 @@ * `EnumerableMap`: add new `Bytes32ToUintMap` map type. ([#3416](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/3416)) * `SafeCast`: add support for many more types, using procedural code generation. ([#3245](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/3245)) * `MerkleProof`: add `multiProofVerify` to prove multiple values are part of a Merkle tree. ([#3276](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/3276)) + * `MerkleProof`: add calldata versions of the functions to avoid copying input arrays to memory and save gas. ([#3200](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/3200)) * `ERC721`, `ERC1155`: simplified revert reasons. ([#3254](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/3254)) * `ERC721`: removed redundant require statement. ([#3434](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/3434)) * `PaymentSplitter`: add `releasable` getters. ([#3350](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/3350)) diff --git a/contracts/mocks/MerkleProofWrapper.sol b/contracts/mocks/MerkleProofWrapper.sol index a58e69f3b0e..519222613e2 100644 --- a/contracts/mocks/MerkleProofWrapper.sol +++ b/contracts/mocks/MerkleProofWrapper.sol @@ -13,23 +13,35 @@ contract MerkleProofWrapper { return MerkleProof.verify(proof, root, leaf); } + function verifyCalldata( + bytes32[] calldata proof, + bytes32 root, + bytes32 leaf + ) public pure returns (bool) { + return MerkleProof.verifyCalldata(proof, root, leaf); + } + function processProof(bytes32[] memory proof, bytes32 leaf) public pure returns (bytes32) { return MerkleProof.processProof(proof, leaf); } + function processProofCalldata(bytes32[] calldata proof, bytes32 leaf) public pure returns (bytes32) { + return MerkleProof.processProofCalldata(proof, leaf); + } + function multiProofVerify( bytes32 root, - bytes32[] memory leafs, - bytes32[] memory proofs, - bool[] memory proofFlag + bytes32[] calldata leafs, + bytes32[] calldata proofs, + bool[] calldata proofFlag ) public pure returns (bool) { return MerkleProof.multiProofVerify(root, leafs, proofs, proofFlag); } function processMultiProof( - bytes32[] memory leafs, - bytes32[] memory proofs, - bool[] memory proofFlag + bytes32[] calldata leafs, + bytes32[] calldata proofs, + bool[] calldata proofFlag ) public pure returns (bytes32) { return MerkleProof.processMultiProof(leafs, proofs, proofFlag); } diff --git a/contracts/utils/cryptography/MerkleProof.sol b/contracts/utils/cryptography/MerkleProof.sol index 21d64c00ab1..5b225757d41 100644 --- a/contracts/utils/cryptography/MerkleProof.sol +++ b/contracts/utils/cryptography/MerkleProof.sol @@ -32,6 +32,19 @@ library MerkleProof { return processProof(proof, leaf) == root; } + /** + * @dev Calldata version of {verify} + * + * _Available since v4.7._ + */ + function verifyCalldata( + bytes32[] calldata proof, + bytes32 root, + bytes32 leaf + ) internal pure returns (bool) { + return processProofCalldata(proof, leaf) == root; + } + /** * @dev Returns the rebuilt hash obtained by traversing a Merkle tree up * from `leaf` using `proof`. A `proof` is valid if and only if the rebuilt @@ -48,6 +61,19 @@ library MerkleProof { return computedHash; } + /** + * @dev Calldata version of {processProof} + * + * _Available since v4.7._ + */ + function processProofCalldata(bytes32[] calldata proof, bytes32 leaf) internal pure returns (bytes32) { + bytes32 computedHash = leaf; + for (uint256 i = 0; i < proof.length; i++) { + computedHash = _hashPair(computedHash, proof[i]); + } + return computedHash; + } + /** * @dev Returns true if a `leafs` can be proved to be a part of a Merkle tree * defined by `root`. For this, `proofs` for each leaf must be provided, containing @@ -58,11 +84,11 @@ library MerkleProof { */ function multiProofVerify( bytes32 root, - bytes32[] memory leafs, - bytes32[] memory proofs, - bool[] memory proofFlag + bytes32[] calldata leaves, + bytes32[] calldata proofs, + bool[] calldata proofFlag ) internal pure returns (bool) { - return processMultiProof(leafs, proofs, proofFlag) == root; + return processMultiProof(leaves, proofs, proofFlag) == root; } /** @@ -73,20 +99,19 @@ library MerkleProof { * _Available since v4.7._ */ function processMultiProof( - bytes32[] memory leafs, - bytes32[] memory proofs, - bool[] memory proofFlag + bytes32[] calldata leaves, + bytes32[] calldata proofs, + bool[] calldata proofFlag ) internal pure returns (bytes32 merkleRoot) { // This function rebuild the root hash by traversing the tree up from the leaves. The root is rebuilt by - // consuming and producing values on a queue. The queue starts with the `leafs` array, then goes onto the + // consuming and producing values on a queue. The queue starts with the `leaves` array, then goes onto the // `hashes` array. At the end of the process, the last hash in the `hashes` array should contain the root of // the merkle tree. - uint256 leafsLen = leafs.length; - uint256 proofsLen = proofs.length; + uint256 leavesLen = leaves.length; uint256 totalHashes = proofFlag.length; // Check proof validity. - require(leafsLen + proofsLen - 1 == totalHashes, "MerkleProof: invalid multiproof"); + require(leavesLen + proofs.length - 1 == totalHashes, "MerkleProof: invalid multiproof"); // The xxxPos values are "pointers" to the next value to consume in each array. All accesses are done using // `xxx[xxxPos++]`, which return the current value and increment the pointer, thus mimicking a queue's "pop". @@ -100,15 +125,15 @@ library MerkleProof { // - depending on the flag, either another value for the "main queue" (merging branches) or an element from the // `proofs` array. for (uint256 i = 0; i < totalHashes; i++) { - bytes32 a = leafPos < leafsLen ? leafs[leafPos++] : hashes[hashPos++]; - bytes32 b = proofFlag[i] ? leafPos < leafsLen ? leafs[leafPos++] : hashes[hashPos++] : proofs[proofPos++]; + bytes32 a = leafPos < leavesLen ? leaves[leafPos++] : hashes[hashPos++]; + bytes32 b = proofFlag[i] ? leafPos < leavesLen ? leaves[leafPos++] : hashes[hashPos++] : proofs[proofPos++]; hashes[i] = _hashPair(a, b); } if (totalHashes > 0) { return hashes[totalHashes - 1]; - } else if (leafsLen > 0) { - return leafs[0]; + } else if (leavesLen > 0) { + return leaves[0]; } else { return proofs[0]; } diff --git a/test/utils/cryptography/MerkleProof.test.js b/test/utils/cryptography/MerkleProof.test.js index c3adfb34c3d..eeef6f1b530 100644 --- a/test/utils/cryptography/MerkleProof.test.js +++ b/test/utils/cryptography/MerkleProof.test.js @@ -25,12 +25,14 @@ contract('MerkleProof', function (accounts) { const proof = merkleTree.getHexProof(leaf); expect(await this.merkleProof.verify(proof, root, leaf)).to.equal(true); + expect(await this.merkleProof.verifyCalldata(proof, root, leaf)).to.equal(true); // For demonstration, it is also possible to create valid proofs for certain 64-byte values *not* in elements: const noSuchLeaf = keccak256( Buffer.concat([keccak256(elements[0]), keccak256(elements[1])].sort(Buffer.compare)), ); expect(await this.merkleProof.verify(proof.slice(1), root, noSuchLeaf)).to.equal(true); + expect(await this.merkleProof.verifyCalldata(proof.slice(1), root, noSuchLeaf)).to.equal(true); }); it('returns false for an invalid Merkle proof', async function () { @@ -47,6 +49,7 @@ contract('MerkleProof', function (accounts) { const badProof = badMerkleTree.getHexProof(badElements[0]); expect(await this.merkleProof.verify(badProof, correctRoot, correctLeaf)).to.equal(false); + expect(await this.merkleProof.verifyCalldata(badProof, correctRoot, correctLeaf)).to.equal(false); }); it('returns false for a Merkle proof of invalid length', async function () { @@ -61,6 +64,7 @@ contract('MerkleProof', function (accounts) { const badProof = proof.slice(0, proof.length - 5); expect(await this.merkleProof.verify(badProof, root, leaf)).to.equal(false); + expect(await this.merkleProof.verifyCalldata(badProof, root, leaf)).to.equal(false); }); }); @@ -93,7 +97,7 @@ contract('MerkleProof', function (accounts) { it('revert with invalid multi proof #1', async function () { const fill = Buffer.alloc(32); // This could be anything, we are reconstructing a fake branch const leaves = ['a', 'b', 'c', 'd'].map(keccak256).sort(Buffer.compare); - const badLeave = keccak256('e'); + const badLeaf = keccak256('e'); const merkleTree = new MerkleTree(leaves, keccak256, { sort: true }); const root = merkleTree.getRoot(); @@ -101,7 +105,7 @@ contract('MerkleProof', function (accounts) { await expectRevert( this.merkleProof.multiProofVerify( root, - [ leaves[0], badLeave ], // A, E + [ leaves[0], badLeaf ], // A, E [ leaves[1], fill, merkleTree.layers[1][1] ], [ false, false, false ], ), @@ -112,7 +116,7 @@ contract('MerkleProof', function (accounts) { it('revert with invalid multi proof #2', async function () { const fill = Buffer.alloc(32); // This could be anything, we are reconstructing a fake branch const leaves = ['a', 'b', 'c', 'd'].map(keccak256).sort(Buffer.compare); - const badLeave = keccak256('e'); + const badLeaf = keccak256('e'); const merkleTree = new MerkleTree(leaves, keccak256, { sort: true }); const root = merkleTree.getRoot(); @@ -120,7 +124,7 @@ contract('MerkleProof', function (accounts) { await expectRevert( this.merkleProof.multiProofVerify( root, - [ badLeave, leaves[0] ], // A, E + [ badLeaf, leaves[0] ], // A, E [ leaves[1], fill, merkleTree.layers[1][1] ], [ false, false, false, false ], ),