diff --git a/contracts/utils/cryptography/MerkleProof.sol b/contracts/utils/cryptography/MerkleProof.sol index be1910062..ab268364f 100644 --- a/contracts/utils/cryptography/MerkleProof.sol +++ b/contracts/utils/cryptography/MerkleProof.sol @@ -105,7 +105,7 @@ library MerkleProof { * This version handles proofs in calldata with the default hashing function. */ function verifyCalldata(bytes32[] calldata proof, bytes32 root, bytes32 leaf) internal pure returns (bool) { - return processProof(proof, leaf) == root; + return processProofCalldata(proof, leaf) == root; } /** @@ -138,7 +138,7 @@ library MerkleProof { bytes32 leaf, function(bytes32, bytes32) view returns (bytes32) hasher ) internal view returns (bool) { - return processProof(proof, leaf, hasher) == root; + return processProofCalldata(proof, leaf, hasher) == root; } /** @@ -200,15 +200,16 @@ library MerkleProof { // `hashes` array. At the end of the process, the last hash in the `hashes` array should contain the root of // the Merkle tree. uint256 leavesLen = leaves.length; + uint256 proofFlagsLen = proofFlags.length; // Check proof validity. - if (leavesLen + proof.length != proofFlags.length + 1) { + if (leavesLen + proof.length != proofFlagsLen + 1) { revert MerkleProofInvalidMultiproof(); } // The xxxPos values are "pointers" to the next value to consume in each array. All accesses are done using // `xxx[xxxPos++]`, which return the current value and increment the pointer, thus mimicking a queue's "pop". - bytes32[] memory hashes = new bytes32[](proofFlags.length); + bytes32[] memory hashes = new bytes32[](proofFlagsLen); uint256 leafPos = 0; uint256 hashPos = 0; uint256 proofPos = 0; @@ -217,7 +218,7 @@ library MerkleProof { // get the next hash. // - depending on the flag, either another value from the "main queue" (merging branches) or an element from the // `proof` array. - for (uint256 i = 0; i < proofFlags.length; i++) { + for (uint256 i = 0; i < proofFlagsLen; i++) { bytes32 a = leafPos < leavesLen ? leaves[leafPos++] : hashes[hashPos++]; bytes32 b = proofFlags[i] ? (leafPos < leavesLen ? leaves[leafPos++] : hashes[hashPos++]) @@ -225,12 +226,12 @@ library MerkleProof { hashes[i] = Hashes.commutativeKeccak256(a, b); } - if (proofFlags.length > 0) { + if (proofFlagsLen > 0) { if (proofPos != proof.length) { revert MerkleProofInvalidMultiproof(); } unchecked { - return hashes[proofFlags.length - 1]; + return hashes[proofFlagsLen - 1]; } } else if (leavesLen > 0) { return leaves[0]; @@ -280,15 +281,16 @@ library MerkleProof { // `hashes` array. At the end of the process, the last hash in the `hashes` array should contain the root of // the Merkle tree. uint256 leavesLen = leaves.length; + uint256 proofFlagsLen = proofFlags.length; // Check proof validity. - if (leavesLen + proof.length != proofFlags.length + 1) { + if (leavesLen + proof.length != proofFlagsLen + 1) { revert MerkleProofInvalidMultiproof(); } // The xxxPos values are "pointers" to the next value to consume in each array. All accesses are done using // `xxx[xxxPos++]`, which return the current value and increment the pointer, thus mimicking a queue's "pop". - bytes32[] memory hashes = new bytes32[](proofFlags.length); + bytes32[] memory hashes = new bytes32[](proofFlagsLen); uint256 leafPos = 0; uint256 hashPos = 0; uint256 proofPos = 0; @@ -297,7 +299,7 @@ library MerkleProof { // get the next hash. // - depending on the flag, either another value from the "main queue" (merging branches) or an element from the // `proof` array. - for (uint256 i = 0; i < proofFlags.length; i++) { + for (uint256 i = 0; i < proofFlagsLen; i++) { bytes32 a = leafPos < leavesLen ? leaves[leafPos++] : hashes[hashPos++]; bytes32 b = proofFlags[i] ? (leafPos < leavesLen ? leaves[leafPos++] : hashes[hashPos++]) @@ -305,12 +307,12 @@ library MerkleProof { hashes[i] = hasher(a, b); } - if (proofFlags.length > 0) { + if (proofFlagsLen > 0) { if (proofPos != proof.length) { revert MerkleProofInvalidMultiproof(); } unchecked { - return hashes[proofFlags.length - 1]; + return hashes[proofFlagsLen - 1]; } } else if (leavesLen > 0) { return leaves[0]; @@ -331,9 +333,9 @@ library MerkleProof { bytes32[] calldata proof, bool[] calldata proofFlags, bytes32 root, - bytes32[] calldata leaves + bytes32[] memory leaves ) internal pure returns (bool) { - return processMultiProof(proof, proofFlags, leaves) == root; + return processMultiProofCalldata(proof, proofFlags, leaves) == root; } /** @@ -351,22 +353,23 @@ library MerkleProof { function processMultiProofCalldata( bytes32[] calldata proof, bool[] calldata proofFlags, - bytes32[] calldata leaves + bytes32[] memory leaves ) internal pure returns (bytes32 merkleRoot) { // This function rebuilds the root hash by traversing the tree up from the leaves. The root is rebuilt by // consuming and producing values on a queue. The queue starts with the `leaves` array, then goes onto the // `hashes` array. At the end of the process, the last hash in the `hashes` array should contain the root of // the Merkle tree. uint256 leavesLen = leaves.length; + uint256 proofFlagsLen = proofFlags.length; // Check proof validity. - if (leavesLen + proof.length != proofFlags.length + 1) { + if (leavesLen + proof.length != proofFlagsLen + 1) { revert MerkleProofInvalidMultiproof(); } // The xxxPos values are "pointers" to the next value to consume in each array. All accesses are done using // `xxx[xxxPos++]`, which return the current value and increment the pointer, thus mimicking a queue's "pop". - bytes32[] memory hashes = new bytes32[](proofFlags.length); + bytes32[] memory hashes = new bytes32[](proofFlagsLen); uint256 leafPos = 0; uint256 hashPos = 0; uint256 proofPos = 0; @@ -375,7 +378,7 @@ library MerkleProof { // get the next hash. // - depending on the flag, either another value from the "main queue" (merging branches) or an element from the // `proof` array. - for (uint256 i = 0; i < proofFlags.length; i++) { + for (uint256 i = 0; i < proofFlagsLen; i++) { bytes32 a = leafPos < leavesLen ? leaves[leafPos++] : hashes[hashPos++]; bytes32 b = proofFlags[i] ? (leafPos < leavesLen ? leaves[leafPos++] : hashes[hashPos++]) @@ -383,12 +386,12 @@ library MerkleProof { hashes[i] = Hashes.commutativeKeccak256(a, b); } - if (proofFlags.length > 0) { + if (proofFlagsLen > 0) { if (proofPos != proof.length) { revert MerkleProofInvalidMultiproof(); } unchecked { - return hashes[proofFlags.length - 1]; + return hashes[proofFlagsLen - 1]; } } else if (leavesLen > 0) { return leaves[0]; @@ -409,10 +412,10 @@ library MerkleProof { bytes32[] calldata proof, bool[] calldata proofFlags, bytes32 root, - bytes32[] calldata leaves, + bytes32[] memory leaves, function(bytes32, bytes32) view returns (bytes32) hasher ) internal view returns (bool) { - return processMultiProof(proof, proofFlags, leaves, hasher) == root; + return processMultiProofCalldata(proof, proofFlags, leaves, hasher) == root; } /** @@ -430,7 +433,7 @@ library MerkleProof { function processMultiProofCalldata( bytes32[] calldata proof, bool[] calldata proofFlags, - bytes32[] calldata leaves, + bytes32[] memory leaves, function(bytes32, bytes32) view returns (bytes32) hasher ) internal view returns (bytes32 merkleRoot) { // This function rebuilds the root hash by traversing the tree up from the leaves. The root is rebuilt by @@ -438,15 +441,16 @@ library MerkleProof { // `hashes` array. At the end of the process, the last hash in the `hashes` array should contain the root of // the Merkle tree. uint256 leavesLen = leaves.length; + uint256 proofFlagsLen = proofFlags.length; // Check proof validity. - if (leavesLen + proof.length != proofFlags.length + 1) { + if (leavesLen + proof.length != proofFlagsLen + 1) { revert MerkleProofInvalidMultiproof(); } // The xxxPos values are "pointers" to the next value to consume in each array. All accesses are done using // `xxx[xxxPos++]`, which return the current value and increment the pointer, thus mimicking a queue's "pop". - bytes32[] memory hashes = new bytes32[](proofFlags.length); + bytes32[] memory hashes = new bytes32[](proofFlagsLen); uint256 leafPos = 0; uint256 hashPos = 0; uint256 proofPos = 0; @@ -455,7 +459,7 @@ library MerkleProof { // get the next hash. // - depending on the flag, either another value from the "main queue" (merging branches) or an element from the // `proof` array. - for (uint256 i = 0; i < proofFlags.length; i++) { + for (uint256 i = 0; i < proofFlagsLen; i++) { bytes32 a = leafPos < leavesLen ? leaves[leafPos++] : hashes[hashPos++]; bytes32 b = proofFlags[i] ? (leafPos < leavesLen ? leaves[leafPos++] : hashes[hashPos++]) @@ -463,12 +467,12 @@ library MerkleProof { hashes[i] = hasher(a, b); } - if (proofFlags.length > 0) { + if (proofFlagsLen > 0) { if (proofPos != proof.length) { revert MerkleProofInvalidMultiproof(); } unchecked { - return hashes[proofFlags.length - 1]; + return hashes[proofFlagsLen - 1]; } } else if (leavesLen > 0) { return leaves[0]; diff --git a/scripts/generate/templates/MerkleProof.js b/scripts/generate/templates/MerkleProof.js index 768a981ee..45486bef3 100644 --- a/scripts/generate/templates/MerkleProof.js +++ b/scripts/generate/templates/MerkleProof.js @@ -56,7 +56,7 @@ function verify${suffix}(${(hash ? formatArgsMultiline : formatArgsSingleLine)( 'bytes32 leaf', hash && `function(bytes32, bytes32) view returns (bytes32) ${hash}`, )}) internal ${visibility} returns (bool) { - return processProof(proof, leaf${hash ? `, ${hash}` : ''}) == root; + return processProof${suffix}(proof, leaf${hash ? `, ${hash}` : ''}) == root; } /** @@ -93,10 +93,10 @@ function multiProofVerify${suffix}(${formatArgsMultiline( `bytes32[] ${location} proof`, `bool[] ${location} proofFlags`, 'bytes32 root', - `bytes32[] ${location} leaves`, + `bytes32[] memory leaves`, hash && `function(bytes32, bytes32) view returns (bytes32) ${hash}`, )}) internal ${visibility} returns (bool) { - return processMultiProof(proof, proofFlags, leaves${hash ? `, ${hash}` : ''}) == root; + return processMultiProof${suffix}(proof, proofFlags, leaves${hash ? `, ${hash}` : ''}) == root; } /** @@ -114,7 +114,7 @@ function multiProofVerify${suffix}(${formatArgsMultiline( function processMultiProof${suffix}(${formatArgsMultiline( `bytes32[] ${location} proof`, `bool[] ${location} proofFlags`, - `bytes32[] ${location} leaves`, + `bytes32[] memory leaves`, hash && `function(bytes32, bytes32) view returns (bytes32) ${hash}`, )}) internal ${visibility} returns (bytes32 merkleRoot) { // This function rebuilds the root hash by traversing the tree up from the leaves. The root is rebuilt by @@ -122,15 +122,16 @@ function processMultiProof${suffix}(${formatArgsMultiline( // \`hashes\` array. At the end of the process, the last hash in the \`hashes\` array should contain the root of // the Merkle tree. uint256 leavesLen = leaves.length; + uint256 proofFlagsLen = proofFlags.length; // Check proof validity. - if (leavesLen + proof.length != proofFlags.length + 1) { + if (leavesLen + proof.length != proofFlagsLen + 1) { revert MerkleProofInvalidMultiproof(); } // The xxxPos values are "pointers" to the next value to consume in each array. All accesses are done using // \`xxx[xxxPos++]\`, which return the current value and increment the pointer, thus mimicking a queue's "pop". - bytes32[] memory hashes = new bytes32[](proofFlags.length); + bytes32[] memory hashes = new bytes32[](proofFlagsLen); uint256 leafPos = 0; uint256 hashPos = 0; uint256 proofPos = 0; @@ -139,7 +140,7 @@ function processMultiProof${suffix}(${formatArgsMultiline( // get the next hash. // - depending on the flag, either another value from the "main queue" (merging branches) or an element from the // \`proof\` array. - for (uint256 i = 0; i < proofFlags.length; i++) { + for (uint256 i = 0; i < proofFlagsLen; i++) { bytes32 a = leafPos < leavesLen ? leaves[leafPos++] : hashes[hashPos++]; bytes32 b = proofFlags[i] ? (leafPos < leavesLen ? leaves[leafPos++] : hashes[hashPos++]) @@ -147,12 +148,12 @@ function processMultiProof${suffix}(${formatArgsMultiline( hashes[i] = ${hash ?? DEFAULT_HASH}(a, b); } - if (proofFlags.length > 0) { + if (proofFlagsLen > 0) { if (proofPos != proof.length) { revert MerkleProofInvalidMultiproof(); } unchecked { - return hashes[proofFlags.length - 1]; + return hashes[proofFlagsLen - 1]; } } else if (leavesLen > 0) { return leaves[0];