diff --git a/docs/source/merkleproof.rst b/docs/source/merkleproof.rst new file mode 100644 index 000000000..4dcc94e47 --- /dev/null +++ b/docs/source/merkleproof.rst @@ -0,0 +1,9 @@ +MerkleProof +============================================= + +Merkle proof verification for leaves of a Merkle tree. + +verifyProof(bytes _proof, bytes32 _root, bytes32 _leaf) internal constant returns (bool) +""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""" + +Verifies a Merkle proof proving the existence of a leaf in a Merkle tree. Assumes that each pair of leaves and each pair of pre-images is sorted. diff --git a/package.json b/package.json index 3f03eeb18..28c21f386 100644 --- a/package.json +++ b/package.json @@ -38,6 +38,7 @@ "babel-register": "^6.23.0", "coveralls": "^2.13.1", "ethereumjs-testrpc": "^3.0.2", + "ethereumjs-util": "^5.1.2", "mocha-lcov-reporter": "^1.3.0", "solidity-coverage": "^0.1.0", "truffle": "3.2.2" diff --git a/test/MerkleProof.js b/test/MerkleProof.js index 8ff58ec34..f2e096e09 100644 --- a/test/MerkleProof.js +++ b/test/MerkleProof.js @@ -1,54 +1,43 @@ -var MerkleProofMock = artifacts.require("./helpers/MerkleProofMock.sol"); +var MerkleProof = artifacts.require("./MerkleProof.sol"); + +import { sha3 } from "ethereumjs-util"; +import MerkleTree from "./helpers/merkleTree.js"; contract('MerkleProof', function(accounts) { - let merkleProof; + let merkleProof; - before(async function() { - merkleProof = await MerkleProofMock.new(); + before(async function() { + merkleProof = await MerkleProof.new(); + }); + + describe("verifyProof", function() { + it("should return true for a valid Merkle proof", async function() { + const elements = ["a", "b", "c", "d"].map(el => sha3(el)); + const merkleTree = new MerkleTree(elements); + + const root = merkleTree.getHexRoot(); + + const proof = merkleTree.getHexProof(elements[0]); + + const leaf = merkleTree.bufToHex(elements[0]); + + const result = await merkleProof.verifyProof(proof, root, leaf); + assert.isOk(result, "verifyProof did not return true for a valid proof"); }); - describe("verifyProof", function() { - it("should return true for a valid Merkle proof given even number of leaves", async function() { - // const elements = ["a", "b", "c", "d"].map(el => sha3(el)); - // const merkleTree = new MerkleTree(elements); + it("should return false for an invalid Merkle proof", async function() { + const elements = ["a", "b", "c"].map(el => sha3(el)); + const merkleTree = new MerkleTree(elements); - // const root = merkleTree.getHexRoot(); + const root = merkleTree.getHexRoot(); - // const proof = merkleTree.getHexProof(elements[0]); + const proof = merkleTree.getHexProof(elements[0]); + const badProof = proof.slice(0, proof.length - 32); - // const leaf = merkleTree.bufToHex(elements[0]); + const leaf = merkleTree.bufToHex(elements[0]); - // const validProof = await merkleProof.verifyProof(proof, root, leaf); - // assert.isOk(validProof, "verifyProof did not return true for a valid proof given even number of leaves"); - }); - - it("should return true for a valid Merkle proof given odd number of leaves", async function () { - // const elements = ["a", "b", "c"].map(el => sha3(el)); - // const merkleTree = new MerkleTree(elements); - - // const root = merkleTree.getHexRoot(); - - // const proof = merkleTree.getHexProof(elements[0]); - - // const leaf = merkleTree.bufToHex(elements[0]); - - // const validProof = await merkleProof.verifyProof(proof, root, leaf); - // assert.isOk(validProof, "verifyProof did not return true for a valid proof given odd number of leaves"); - }); - - it("should return false for an invalid Merkle proof", async function() { - // const elements = ["a", "b", "c"].map(el => sha3(el)); - // const merkleTree = new MerkleTree(elements); - - // const root = merkleTree.getHexRoot(); - - // const proof = merkleTree.getHexProof(elements[0]); - // const badProof = proof.slice(0, proof.length - 32); - - // const leaf = merkleTree.bufToHex(elements[0]); - - // const validProof = await merkleProof.verifyProof(badProof, root, leaf); - // assert.isNotOk(validProof, "verifyProof did not return false for an invalid proof"); - }); + const result = await merkleProof.verifyProof(badProof, root, leaf); + assert.isNotOk(result, "verifyProof did not return false for an invalid proof"); }); + }); }); diff --git a/test/helpers/merkleTree.js b/test/helpers/merkleTree.js new file mode 100644 index 000000000..dc00eb12f --- /dev/null +++ b/test/helpers/merkleTree.js @@ -0,0 +1,135 @@ +import { sha3 } from "ethereumjs-util"; + +export default class MerkleTree { + constructor(elements) { + // Filter empty strings + this.elements = elements.filter(el => el); + + // Check if elements are 32 byte buffers + if (this.elements.some(el => el.length !== 32 || !Buffer.isBuffer(el))) { + throw new Error("Elements must be 32 byte buffers"); + } + + // Deduplicate elements + this.elements = this.bufDedup(this.elements); + // Sort elements + this.elements.sort(Buffer.compare); + + // Create layers + this.layers = this.getLayers(this.elements); + } + + getLayers(elements) { + if (elements.length == 0) { + return [[""]]; + } + + const layers = []; + layers.push(elements); + + // Get next layer until we reach the root + while (layers[layers.length - 1].length > 1) { + layers.push(this.getNextLayer(layers[layers.length - 1])); + } + + return layers; + } + + getNextLayer(elements) { + return elements.reduce((layer, el, idx, arr) => { + if (idx % 2 === 0) { + // Hash the current element with its pair element + layer.push(this.combinedHash(el, arr[idx + 1])); + } + + return layer; + }, []); + } + + combinedHash(first, second) { + if (!first) { return second; } + if (!second) { return first; } + + return sha3(this.sortAndConcat(first, second)); + } + + getRoot() { + return this.layers[this.layers.length - 1][0]; + } + + getHexRoot() { + return this.bufToHex(this.getRoot()); + } + + getProof(el) { + let idx = this.bufIndexOf(el, this.elements); + + if (idx === -1) { + throw new Error("Element does not exist in Merkle tree"); + } + + return this.layers.reduce((proof, layer) => { + const pairElement = this.getPairElement(idx, layer); + + if (pairElement) { + proof.push(pairElement); + } + + idx = Math.floor(idx / 2); + + return proof; + }, []); + } + + getHexProof(el) { + const proof = this.getProof(el); + + return this.bufArrToHex(proof); + } + + getPairElement(idx, layer) { + const pairIdx = idx % 2 === 0 ? idx + 1 : idx - 1; + + if (pairIdx < layer.length) { + return layer[pairIdx]; + } else { + return null; + } + } + + bufIndexOf(el, arr) { + for (let i = 0; i < arr.length; i++) { + if (el.equals(arr[i])) { + return i; + } + } + + return -1; + } + + bufDedup(elements) { + return elements.filter((el, idx) => { + return this.bufIndexOf(el, elements) === idx; + }); + } + + bufToHex(el) { + if (!Buffer.isBuffer(el)) { + throw new Error("Element is not a buffer"); + } + + return "0x" + el.toString("hex"); + } + + bufArrToHex(arr) { + if (arr.some(el => !Buffer.isBuffer(el))) { + throw new Error("Array is not an array of buffers"); + } + + return "0x" + arr.map(el => el.toString("hex")).join(""); + } + + sortAndConcat(...args) { + return Buffer.concat([...args].sort(Buffer.compare)); + } +}