Skip to content

Commit

Permalink
Support custom node hash in SimpleMerkleTree (#39)
Browse files Browse the repository at this point in the history
Co-authored-by: ernestognw <ernestognw@gmail.com>
  • Loading branch information
Amxx and ernestognw committed Mar 4, 2024
1 parent 6ab2cfb commit 29f611e
Show file tree
Hide file tree
Showing 9 changed files with 399 additions and 99 deletions.
28 changes: 13 additions & 15 deletions src/core.ts
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
import { keccak256 } from '@ethersproject/keccak256';
import { BytesLike, HexString, toHex, toBytes, concat, compare } from './bytes';
import { BytesLike, HexString, toHex, toBytes, compare } from './bytes';
import { NodeHash, standardNodeHash } from './hashes';
import { invariant, throwError, validateArgument } from './utils/errors';

const hashPair = (a: BytesLike, b: BytesLike): HexString => keccak256(concat([a, b].sort(compare)));

const leftChildIndex = (i: number) => 2 * i + 1;
const rightChildIndex = (i: number) => 2 * i + 2;
const parentIndex = (i: number) => (i > 0 ? Math.floor((i - 1) / 2) : throwError('Root has no parent'));
Expand All @@ -18,7 +16,7 @@ const checkLeafNode = (tree: unknown[], i: number) => void (isLeafNode(tree, i)
const checkValidMerkleNode = (node: BytesLike) =>
void (isValidMerkleNode(node) || throwError('Merkle tree nodes must be Uint8Array of length 32'));

export function makeMerkleTree(leaves: BytesLike[]): HexString[] {
export function makeMerkleTree(leaves: BytesLike[], nodeHash: NodeHash = standardNodeHash): HexString[] {
leaves.forEach(checkValidMerkleNode);

validateArgument(leaves.length !== 0, 'Expected non-zero number of leaves');
Expand All @@ -29,7 +27,7 @@ export function makeMerkleTree(leaves: BytesLike[]): HexString[] {
tree[tree.length - 1 - i] = toHex(leaf);
}
for (let i = tree.length - 1 - leaves.length; i >= 0; i--) {
tree[i] = hashPair(tree[leftChildIndex(i)]!, tree[rightChildIndex(i)]!);
tree[i] = nodeHash(tree[leftChildIndex(i)]!, tree[rightChildIndex(i)]!);
}

return tree;
Expand All @@ -46,11 +44,11 @@ export function getProof(tree: BytesLike[], index: number): HexString[] {
return proof;
}

export function processProof(leaf: BytesLike, proof: BytesLike[]): HexString {
export function processProof(leaf: BytesLike, proof: BytesLike[], nodeHash: NodeHash = standardNodeHash): HexString {
checkValidMerkleNode(leaf);
proof.forEach(checkValidMerkleNode);

return toHex(proof.reduce(hashPair, leaf));
return toHex(proof.reduce(nodeHash, leaf));
}

export interface MultiProof<T, L = T> {
Expand All @@ -68,7 +66,7 @@ export function getMultiProof(tree: BytesLike[], indices: number[]): MultiProof<
'Cannot prove duplicated index',
);

const stack = indices.concat(); // copy
const stack = Array.from(indices); // copy
const proof = [];
const proofFlags = [];

Expand Down Expand Up @@ -98,7 +96,7 @@ export function getMultiProof(tree: BytesLike[], indices: number[]): MultiProof<
};
}

export function processMultiProof(multiproof: MultiProof<BytesLike>): HexString {
export function processMultiProof(multiproof: MultiProof<BytesLike>, nodeHash: NodeHash = standardNodeHash): HexString {
multiproof.leaves.forEach(checkValidMerkleNode);
multiproof.proof.forEach(checkValidMerkleNode);

Expand All @@ -111,22 +109,22 @@ export function processMultiProof(multiproof: MultiProof<BytesLike>): HexString
'Provided leaves and multiproof are not compatible',
);

const stack = multiproof.leaves.concat(); // copy
const proof = multiproof.proof.concat(); // copy
const stack = Array.from(multiproof.leaves); // copy
const proof = Array.from(multiproof.proof); // copy

for (const flag of multiproof.proofFlags) {
const a = stack.shift();
const b = flag ? stack.shift() : proof.shift();
invariant(a !== undefined && b !== undefined);
stack.push(hashPair(a, b));
stack.push(nodeHash(a, b));
}

invariant(stack.length + proof.length === 1);

return toHex(stack.pop() ?? proof.shift()!);
}

export function isValidMerkleTree(tree: BytesLike[]): boolean {
export function isValidMerkleTree(tree: BytesLike[], nodeHash: NodeHash = standardNodeHash): boolean {
for (const [i, node] of tree.entries()) {
if (!isValidMerkleNode(node)) {
return false;
Expand All @@ -139,7 +137,7 @@ export function isValidMerkleTree(tree: BytesLike[]): boolean {
if (l < tree.length) {
return false;
}
} else if (compare(node, hashPair(tree[l]!, tree[r]!))) {
} else if (compare(node, nodeHash(tree[l]!, tree[r]!))) {
return false;
}
}
Expand Down
14 changes: 14 additions & 0 deletions src/hashes.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import { defaultAbiCoder } from '@ethersproject/abi';
import { keccak256 } from '@ethersproject/keccak256';
import { BytesLike, HexString, concat, compare } from './bytes';

export type LeafHash<T> = (leaf: T) => HexString;
export type NodeHash = (left: BytesLike, right: BytesLike) => HexString;

export function standardLeafHash<T extends any[]>(types: string[], value: T): HexString {
return keccak256(keccak256(defaultAbiCoder.encode(types, value)));
}

export function standardNodeHash(a: BytesLike, b: BytesLike): HexString {
return keccak256(concat([a, b].sort(compare)));
}
18 changes: 12 additions & 6 deletions src/merkletree.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import {
} from './core';

import { MerkleTreeOptions, defaultOptions } from './options';
import { LeafHash, NodeHash } from './hashes';
import { validateArgument, invariant } from './utils/errors';

export interface MerkleTreeData<T> {
Expand Down Expand Up @@ -40,7 +41,8 @@ export abstract class MerkleTreeImpl<T> implements MerkleTree<T> {
protected constructor(
protected readonly tree: HexString[],
protected readonly values: MerkleTreeData<T>['values'],
public readonly leafHash: MerkleTree<T>['leafHash'],
public readonly leafHash: LeafHash<T>,
protected readonly nodeHash?: NodeHash,
) {
validateArgument(
values.every(({ value }) => typeof value != 'number'),
Expand All @@ -52,7 +54,8 @@ export abstract class MerkleTreeImpl<T> implements MerkleTree<T> {
protected static prepare<T>(
values: T[],
options: MerkleTreeOptions = {},
leafHash: MerkleTree<T>['leafHash'],
leafHash: LeafHash<T>,
nodeHash?: NodeHash,
): [tree: HexString[], indexedValues: MerkleTreeData<T>['values']] {
const sortLeaves = options.sortLeaves ?? defaultOptions.sortLeaves;
const hashedValues = values.map((value, valueIndex) => ({
Expand All @@ -65,7 +68,10 @@ export abstract class MerkleTreeImpl<T> implements MerkleTree<T> {
hashedValues.sort((a, b) => compare(a.hash, b.hash));
}

const tree = makeMerkleTree(hashedValues.map(v => v.hash));
const tree = makeMerkleTree(
hashedValues.map(v => v.hash),
nodeHash,
);

const indexedValues = values.map(value => ({ value, treeIndex: 0 }));
for (const [leafIndex, { valueIndex }] of hashedValues.entries()) {
Expand Down Expand Up @@ -93,7 +99,7 @@ export abstract class MerkleTreeImpl<T> implements MerkleTree<T> {

validate(): void {
this.values.forEach((_, i) => this._validateValueAt(i));
invariant(isValidMerkleTree(this.tree), 'Merkle tree is invalid');
invariant(isValidMerkleTree(this.tree, this.nodeHash), 'Merkle tree is invalid');
}

leafLookup(leaf: T): number {
Expand Down Expand Up @@ -171,10 +177,10 @@ export abstract class MerkleTreeImpl<T> implements MerkleTree<T> {
}

private _verify(leafHash: BytesLike, proof: BytesLike[]): boolean {
return this.root === processProof(leafHash, proof);
return this.root === processProof(leafHash, proof, this.nodeHash);
}

private _verifyMultiProof(multiproof: MultiProof<BytesLike>): boolean {
return this.root === processMultiProof(multiproof);
return this.root === processMultiProof(multiproof, this.nodeHash);
}
}
137 changes: 91 additions & 46 deletions src/simple.test.ts
Original file line number Diff line number Diff line change
@@ -1,23 +1,36 @@
import { test, testProp, fc } from '@fast-check/ava';
import { HashZero as zero } from '@ethersproject/constants';
import { keccak256 } from '@ethersproject/keccak256';
import { SimpleMerkleTree } from './simple';
import { BytesLike, HexString, concat, compare } from './bytes';

const reverseNodeHash = (a: BytesLike, b: BytesLike): HexString => keccak256(concat([a, b].sort(compare).reverse()));
const otherNodeHash = (a: BytesLike, b: BytesLike): HexString => keccak256(reverseNodeHash(a, b)); // double hash

import { toHex } from './bytes';
import { InvalidArgumentError, InvariantError } from './utils/errors';

const leaf = fc.uint8Array({ minLength: 32, maxLength: 32 }).map(toHex);
const leaves = fc.array(leaf, { minLength: 1 });
const options = fc.record({ sortLeaves: fc.oneof(fc.constant(undefined), fc.boolean()) });
const options = fc.record({
sortLeaves: fc.oneof(fc.constant(undefined), fc.boolean()),
nodeHash: fc.oneof(fc.constant(undefined), fc.constant(reverseNodeHash)),
});

const tree = fc.tuple(leaves, options).map(([leaves, options]) => SimpleMerkleTree.of(leaves, options));
const tree = fc
.tuple(leaves, options)
.chain(([leaves, options]) => fc.tuple(fc.constant(SimpleMerkleTree.of(leaves, options)), fc.constant(options)));
const treeAndLeaf = fc.tuple(leaves, options).chain(([leaves, options]) =>
fc.tuple(
fc.constant(SimpleMerkleTree.of(leaves, options)),
fc.constant(options),
fc.nat({ max: leaves.length - 1 }).map(index => ({ value: leaves[index]!, index })),
),
);
const treeAndLeaves = fc.tuple(leaves, options).chain(([leaves, options]) =>
fc.tuple(
fc.constant(SimpleMerkleTree.of(leaves, options)),
fc.constant(options),
fc
.uniqueArray(fc.nat({ max: leaves.length - 1 }))
.map(indices => indices.map(index => ({ value: leaves[index]!, index }))),
Expand All @@ -26,48 +39,64 @@ const treeAndLeaves = fc.tuple(leaves, options).chain(([leaves, options]) =>

fc.configureGlobal({ numRuns: process.env.CI ? 10000 : 100 });

testProp('generates a valid tree', [tree], (t, tree) => {
testProp('generates a valid tree', [tree], (t, [tree]) => {
t.notThrows(() => tree.validate());
});

testProp('generates valid single proofs for all leaves', [treeAndLeaf], (t, [tree, { value: leaf, index }]) => {
const proof1 = tree.getProof(index);
const proof2 = tree.getProof(leaf);

t.deepEqual(proof1, proof2);
t.true(tree.verify(index, proof1));
t.true(tree.verify(leaf, proof1));
t.true(SimpleMerkleTree.verify(tree.root, leaf, proof1));
});
testProp(
'generates valid single proofs for all leaves',
[treeAndLeaf],
(t, [tree, options, { value: leaf, index }]) => {
const proof1 = tree.getProof(index);
const proof2 = tree.getProof(leaf);

t.deepEqual(proof1, proof2);
t.true(tree.verify(index, proof1));
t.true(tree.verify(leaf, proof1));
t.true(SimpleMerkleTree.verify(tree.root, leaf, proof1, options.nodeHash));
},
);

testProp('rejects invalid proofs', [treeAndLeaf, tree], (t, [tree, { value: leaf }], otherTree) => {
const proof = tree.getProof(leaf);
t.false(otherTree.verify(leaf, proof));
t.false(SimpleMerkleTree.verify(otherTree.root, leaf, proof));
});
testProp(
'rejects invalid proofs',
[treeAndLeaf, tree],
(t, [tree, options, { value: leaf }], [otherTree, otherOptions]) => {
const proof = tree.getProof(leaf);
t.false(otherTree.verify(leaf, proof));
t.false(SimpleMerkleTree.verify(otherTree.root, leaf, proof, options.nodeHash));
t.false(SimpleMerkleTree.verify(otherTree.root, leaf, proof, otherOptions.nodeHash));
},
);

testProp('generates valid multiproofs', [treeAndLeaves], (t, [tree, indices]) => {
testProp('generates valid multiproofs', [treeAndLeaves], (t, [tree, options, indices]) => {
const proof1 = tree.getMultiProof(indices.map(e => e.index));
const proof2 = tree.getMultiProof(indices.map(e => e.value));

t.deepEqual(proof1, proof2);
t.true(tree.verifyMultiProof(proof1));
t.true(SimpleMerkleTree.verifyMultiProof(tree.root, proof1));
t.true(SimpleMerkleTree.verifyMultiProof(tree.root, proof1, options.nodeHash));
});

testProp('rejects invalid multiproofs', [treeAndLeaves, tree], (t, [tree, indices], otherTree) => {
const multiProof = tree.getMultiProof(indices.map(e => e.index));

t.false(otherTree.verifyMultiProof(multiProof));
t.false(SimpleMerkleTree.verifyMultiProof(otherTree.root, multiProof));
});
testProp(
'rejects invalid multiproofs',
[treeAndLeaves, tree],
(t, [tree, options, indices], [otherTree, otherOptions]) => {
const multiProof = tree.getMultiProof(indices.map(e => e.index));

t.false(otherTree.verifyMultiProof(multiProof));
t.false(SimpleMerkleTree.verifyMultiProof(otherTree.root, multiProof, options.nodeHash));
t.false(SimpleMerkleTree.verifyMultiProof(otherTree.root, multiProof, otherOptions.nodeHash));
},
);

testProp(
'renders tree representation',
[leaves],
(t, leaves) => {
t.snapshot(SimpleMerkleTree.of(leaves, { sortLeaves: true }).render());
t.snapshot(SimpleMerkleTree.of(leaves, { sortLeaves: false }).render());
t.snapshot(SimpleMerkleTree.of(leaves, { sortLeaves: true, nodeHash: reverseNodeHash }).render());
t.snapshot(SimpleMerkleTree.of(leaves, { sortLeaves: false, nodeHash: reverseNodeHash }).render());
},
{ numRuns: 1, seed: 0 },
);
Expand All @@ -78,24 +107,34 @@ testProp(
(t, leaves) => {
t.snapshot(SimpleMerkleTree.of(leaves, { sortLeaves: true }).dump());
t.snapshot(SimpleMerkleTree.of(leaves, { sortLeaves: false }).dump());
t.snapshot(SimpleMerkleTree.of(leaves, { sortLeaves: true, nodeHash: reverseNodeHash }).dump());
t.snapshot(SimpleMerkleTree.of(leaves, { sortLeaves: false, nodeHash: reverseNodeHash }).dump());
},
{ numRuns: 1, seed: 0 },
);

testProp('dump and load', [tree], (t, tree) => {
const recoveredTree = SimpleMerkleTree.load(tree.dump());
recoveredTree.validate();
testProp('dump and load', [tree], (t, [tree, options]) => {
const dump = tree.dump();
const recoveredTree = SimpleMerkleTree.load(dump, options.nodeHash);
recoveredTree.validate(); // already done in load

t.is(dump.hash, options.nodeHash ? 'custom' : undefined);
t.is(tree.root, recoveredTree.root);
t.is(tree.render(), recoveredTree.render());
t.deepEqual(tree.entries(), recoveredTree.entries());
t.deepEqual(tree.dump(), recoveredTree.dump());
});

testProp('reject out of bounds value index', [tree], (t, tree) => {
testProp('reject out of bounds value index', [tree], (t, [tree]) => {
t.throws(() => tree.getProof(-1), new InvalidArgumentError('Index out of bounds'));
});

// We need at least 2 leaves for internal node hashing to come into play
testProp('reject loading dump with wrong node hash', [fc.array(leaf, { minLength: 2 })], (t, leaves) => {
const dump = SimpleMerkleTree.of(leaves, { nodeHash: reverseNodeHash }).dump();
t.throws(() => SimpleMerkleTree.load(dump, otherNodeHash), new InvariantError('Merkle tree is invalid'));
});

test('reject invalid leaf size', t => {
const invalidLeaf = '0x000000000000000000000000000000000000000000000000000000000000000000';
t.throws(() => SimpleMerkleTree.of([invalidLeaf]), {
Expand All @@ -116,22 +155,28 @@ test('reject unrecognized tree dump', t => {
});

test('reject malformed tree dump', t => {
const loadedTree1 = SimpleMerkleTree.load({
format: 'simple-v1',
tree: [zero],
values: [
{
value: '0x0000000000000000000000000000000000000000000000000000000000000001',
treeIndex: 0,
},
],
});
t.throws(() => loadedTree1.getProof(0), new InvariantError('Merkle tree does not contain the expected value'));
t.throws(
() =>
SimpleMerkleTree.load({
format: 'simple-v1',
tree: [zero],
values: [
{
value: '0x0000000000000000000000000000000000000000000000000000000000000001',
treeIndex: 0,
},
],
}),
new InvariantError('Merkle tree does not contain the expected value'),
);

const loadedTree2 = SimpleMerkleTree.load({
format: 'simple-v1',
tree: [zero, zero, zero],
values: [{ value: zero, treeIndex: 2 }],
});
t.throws(() => loadedTree2.getProof(0), new InvariantError('Unable to prove value'));
t.throws(
() =>
SimpleMerkleTree.load({
format: 'simple-v1',
tree: [zero, zero, zero],
values: [{ value: zero, treeIndex: 2 }],
}),
new InvariantError('Merkle tree is invalid'),
);
});

0 comments on commit 29f611e

Please sign in to comment.