Get leaves from memory in processMultiProofCalldata (#5140)

Signed-off-by: Hadrien Croubois <hadrien.croubois@gmail.com>
This commit is contained in:
Hadrien Croubois
2024-08-03 00:00:26 +02:00
parent a818284caf
commit de66e2ca51
2 changed files with 42 additions and 37 deletions

View File

@ -105,7 +105,7 @@ library MerkleProof {
* This version handles proofs in calldata with the default hashing function. * This version handles proofs in calldata with the default hashing function.
*/ */
function verifyCalldata(bytes32[] calldata proof, bytes32 root, bytes32 leaf) internal pure returns (bool) { function verifyCalldata(bytes32[] calldata proof, bytes32 root, bytes32 leaf) internal pure returns (bool) {
return processProof(proof, leaf) == root; return processProofCalldata(proof, leaf) == root;
} }
/** /**
@ -138,7 +138,7 @@ library MerkleProof {
bytes32 leaf, bytes32 leaf,
function(bytes32, bytes32) view returns (bytes32) hasher function(bytes32, bytes32) view returns (bytes32) hasher
) internal view returns (bool) { ) internal view returns (bool) {
return processProof(proof, leaf, hasher) == root; return processProofCalldata(proof, leaf, hasher) == root;
} }
/** /**
@ -200,15 +200,16 @@ library MerkleProof {
// `hashes` array. At the end of the process, the last hash in the `hashes` array should contain the root of // `hashes` array. At the end of the process, the last hash in the `hashes` array should contain the root of
// the Merkle tree. // the Merkle tree.
uint256 leavesLen = leaves.length; uint256 leavesLen = leaves.length;
uint256 proofFlagsLen = proofFlags.length;
// Check proof validity. // Check proof validity.
if (leavesLen + proof.length != proofFlags.length + 1) { if (leavesLen + proof.length != proofFlagsLen + 1) {
revert MerkleProofInvalidMultiproof(); revert MerkleProofInvalidMultiproof();
} }
// The xxxPos values are "pointers" to the next value to consume in each array. All accesses are done using // The xxxPos values are "pointers" to the next value to consume in each array. All accesses are done using
// `xxx[xxxPos++]`, which return the current value and increment the pointer, thus mimicking a queue's "pop". // `xxx[xxxPos++]`, which return the current value and increment the pointer, thus mimicking a queue's "pop".
bytes32[] memory hashes = new bytes32[](proofFlags.length); bytes32[] memory hashes = new bytes32[](proofFlagsLen);
uint256 leafPos = 0; uint256 leafPos = 0;
uint256 hashPos = 0; uint256 hashPos = 0;
uint256 proofPos = 0; uint256 proofPos = 0;
@ -217,7 +218,7 @@ library MerkleProof {
// get the next hash. // get the next hash.
// - depending on the flag, either another value from the "main queue" (merging branches) or an element from the // - depending on the flag, either another value from the "main queue" (merging branches) or an element from the
// `proof` array. // `proof` array.
for (uint256 i = 0; i < proofFlags.length; i++) { for (uint256 i = 0; i < proofFlagsLen; i++) {
bytes32 a = leafPos < leavesLen ? leaves[leafPos++] : hashes[hashPos++]; bytes32 a = leafPos < leavesLen ? leaves[leafPos++] : hashes[hashPos++];
bytes32 b = proofFlags[i] bytes32 b = proofFlags[i]
? (leafPos < leavesLen ? leaves[leafPos++] : hashes[hashPos++]) ? (leafPos < leavesLen ? leaves[leafPos++] : hashes[hashPos++])
@ -225,12 +226,12 @@ library MerkleProof {
hashes[i] = Hashes.commutativeKeccak256(a, b); hashes[i] = Hashes.commutativeKeccak256(a, b);
} }
if (proofFlags.length > 0) { if (proofFlagsLen > 0) {
if (proofPos != proof.length) { if (proofPos != proof.length) {
revert MerkleProofInvalidMultiproof(); revert MerkleProofInvalidMultiproof();
} }
unchecked { unchecked {
return hashes[proofFlags.length - 1]; return hashes[proofFlagsLen - 1];
} }
} else if (leavesLen > 0) { } else if (leavesLen > 0) {
return leaves[0]; return leaves[0];
@ -280,15 +281,16 @@ library MerkleProof {
// `hashes` array. At the end of the process, the last hash in the `hashes` array should contain the root of // `hashes` array. At the end of the process, the last hash in the `hashes` array should contain the root of
// the Merkle tree. // the Merkle tree.
uint256 leavesLen = leaves.length; uint256 leavesLen = leaves.length;
uint256 proofFlagsLen = proofFlags.length;
// Check proof validity. // Check proof validity.
if (leavesLen + proof.length != proofFlags.length + 1) { if (leavesLen + proof.length != proofFlagsLen + 1) {
revert MerkleProofInvalidMultiproof(); revert MerkleProofInvalidMultiproof();
} }
// The xxxPos values are "pointers" to the next value to consume in each array. All accesses are done using // The xxxPos values are "pointers" to the next value to consume in each array. All accesses are done using
// `xxx[xxxPos++]`, which return the current value and increment the pointer, thus mimicking a queue's "pop". // `xxx[xxxPos++]`, which return the current value and increment the pointer, thus mimicking a queue's "pop".
bytes32[] memory hashes = new bytes32[](proofFlags.length); bytes32[] memory hashes = new bytes32[](proofFlagsLen);
uint256 leafPos = 0; uint256 leafPos = 0;
uint256 hashPos = 0; uint256 hashPos = 0;
uint256 proofPos = 0; uint256 proofPos = 0;
@ -297,7 +299,7 @@ library MerkleProof {
// get the next hash. // get the next hash.
// - depending on the flag, either another value from the "main queue" (merging branches) or an element from the // - depending on the flag, either another value from the "main queue" (merging branches) or an element from the
// `proof` array. // `proof` array.
for (uint256 i = 0; i < proofFlags.length; i++) { for (uint256 i = 0; i < proofFlagsLen; i++) {
bytes32 a = leafPos < leavesLen ? leaves[leafPos++] : hashes[hashPos++]; bytes32 a = leafPos < leavesLen ? leaves[leafPos++] : hashes[hashPos++];
bytes32 b = proofFlags[i] bytes32 b = proofFlags[i]
? (leafPos < leavesLen ? leaves[leafPos++] : hashes[hashPos++]) ? (leafPos < leavesLen ? leaves[leafPos++] : hashes[hashPos++])
@ -305,12 +307,12 @@ library MerkleProof {
hashes[i] = hasher(a, b); hashes[i] = hasher(a, b);
} }
if (proofFlags.length > 0) { if (proofFlagsLen > 0) {
if (proofPos != proof.length) { if (proofPos != proof.length) {
revert MerkleProofInvalidMultiproof(); revert MerkleProofInvalidMultiproof();
} }
unchecked { unchecked {
return hashes[proofFlags.length - 1]; return hashes[proofFlagsLen - 1];
} }
} else if (leavesLen > 0) { } else if (leavesLen > 0) {
return leaves[0]; return leaves[0];
@ -331,9 +333,9 @@ library MerkleProof {
bytes32[] calldata proof, bytes32[] calldata proof,
bool[] calldata proofFlags, bool[] calldata proofFlags,
bytes32 root, bytes32 root,
bytes32[] calldata leaves bytes32[] memory leaves
) internal pure returns (bool) { ) internal pure returns (bool) {
return processMultiProof(proof, proofFlags, leaves) == root; return processMultiProofCalldata(proof, proofFlags, leaves) == root;
} }
/** /**
@ -351,22 +353,23 @@ library MerkleProof {
function processMultiProofCalldata( function processMultiProofCalldata(
bytes32[] calldata proof, bytes32[] calldata proof,
bool[] calldata proofFlags, bool[] calldata proofFlags,
bytes32[] calldata leaves bytes32[] memory leaves
) internal pure returns (bytes32 merkleRoot) { ) internal pure returns (bytes32 merkleRoot) {
// This function rebuilds the root hash by traversing the tree up from the leaves. The root is rebuilt by // This function rebuilds the root hash by traversing the tree up from the leaves. The root is rebuilt by
// consuming and producing values on a queue. The queue starts with the `leaves` array, then goes onto the // consuming and producing values on a queue. The queue starts with the `leaves` array, then goes onto the
// `hashes` array. At the end of the process, the last hash in the `hashes` array should contain the root of // `hashes` array. At the end of the process, the last hash in the `hashes` array should contain the root of
// the Merkle tree. // the Merkle tree.
uint256 leavesLen = leaves.length; uint256 leavesLen = leaves.length;
uint256 proofFlagsLen = proofFlags.length;
// Check proof validity. // Check proof validity.
if (leavesLen + proof.length != proofFlags.length + 1) { if (leavesLen + proof.length != proofFlagsLen + 1) {
revert MerkleProofInvalidMultiproof(); revert MerkleProofInvalidMultiproof();
} }
// The xxxPos values are "pointers" to the next value to consume in each array. All accesses are done using // The xxxPos values are "pointers" to the next value to consume in each array. All accesses are done using
// `xxx[xxxPos++]`, which return the current value and increment the pointer, thus mimicking a queue's "pop". // `xxx[xxxPos++]`, which return the current value and increment the pointer, thus mimicking a queue's "pop".
bytes32[] memory hashes = new bytes32[](proofFlags.length); bytes32[] memory hashes = new bytes32[](proofFlagsLen);
uint256 leafPos = 0; uint256 leafPos = 0;
uint256 hashPos = 0; uint256 hashPos = 0;
uint256 proofPos = 0; uint256 proofPos = 0;
@ -375,7 +378,7 @@ library MerkleProof {
// get the next hash. // get the next hash.
// - depending on the flag, either another value from the "main queue" (merging branches) or an element from the // - depending on the flag, either another value from the "main queue" (merging branches) or an element from the
// `proof` array. // `proof` array.
for (uint256 i = 0; i < proofFlags.length; i++) { for (uint256 i = 0; i < proofFlagsLen; i++) {
bytes32 a = leafPos < leavesLen ? leaves[leafPos++] : hashes[hashPos++]; bytes32 a = leafPos < leavesLen ? leaves[leafPos++] : hashes[hashPos++];
bytes32 b = proofFlags[i] bytes32 b = proofFlags[i]
? (leafPos < leavesLen ? leaves[leafPos++] : hashes[hashPos++]) ? (leafPos < leavesLen ? leaves[leafPos++] : hashes[hashPos++])
@ -383,12 +386,12 @@ library MerkleProof {
hashes[i] = Hashes.commutativeKeccak256(a, b); hashes[i] = Hashes.commutativeKeccak256(a, b);
} }
if (proofFlags.length > 0) { if (proofFlagsLen > 0) {
if (proofPos != proof.length) { if (proofPos != proof.length) {
revert MerkleProofInvalidMultiproof(); revert MerkleProofInvalidMultiproof();
} }
unchecked { unchecked {
return hashes[proofFlags.length - 1]; return hashes[proofFlagsLen - 1];
} }
} else if (leavesLen > 0) { } else if (leavesLen > 0) {
return leaves[0]; return leaves[0];
@ -409,10 +412,10 @@ library MerkleProof {
bytes32[] calldata proof, bytes32[] calldata proof,
bool[] calldata proofFlags, bool[] calldata proofFlags,
bytes32 root, bytes32 root,
bytes32[] calldata leaves, bytes32[] memory leaves,
function(bytes32, bytes32) view returns (bytes32) hasher function(bytes32, bytes32) view returns (bytes32) hasher
) internal view returns (bool) { ) internal view returns (bool) {
return processMultiProof(proof, proofFlags, leaves, hasher) == root; return processMultiProofCalldata(proof, proofFlags, leaves, hasher) == root;
} }
/** /**
@ -430,7 +433,7 @@ library MerkleProof {
function processMultiProofCalldata( function processMultiProofCalldata(
bytes32[] calldata proof, bytes32[] calldata proof,
bool[] calldata proofFlags, bool[] calldata proofFlags,
bytes32[] calldata leaves, bytes32[] memory leaves,
function(bytes32, bytes32) view returns (bytes32) hasher function(bytes32, bytes32) view returns (bytes32) hasher
) internal view returns (bytes32 merkleRoot) { ) internal view returns (bytes32 merkleRoot) {
// This function rebuilds the root hash by traversing the tree up from the leaves. The root is rebuilt by // This function rebuilds the root hash by traversing the tree up from the leaves. The root is rebuilt by
@ -438,15 +441,16 @@ library MerkleProof {
// `hashes` array. At the end of the process, the last hash in the `hashes` array should contain the root of // `hashes` array. At the end of the process, the last hash in the `hashes` array should contain the root of
// the Merkle tree. // the Merkle tree.
uint256 leavesLen = leaves.length; uint256 leavesLen = leaves.length;
uint256 proofFlagsLen = proofFlags.length;
// Check proof validity. // Check proof validity.
if (leavesLen + proof.length != proofFlags.length + 1) { if (leavesLen + proof.length != proofFlagsLen + 1) {
revert MerkleProofInvalidMultiproof(); revert MerkleProofInvalidMultiproof();
} }
// The xxxPos values are "pointers" to the next value to consume in each array. All accesses are done using // The xxxPos values are "pointers" to the next value to consume in each array. All accesses are done using
// `xxx[xxxPos++]`, which return the current value and increment the pointer, thus mimicking a queue's "pop". // `xxx[xxxPos++]`, which return the current value and increment the pointer, thus mimicking a queue's "pop".
bytes32[] memory hashes = new bytes32[](proofFlags.length); bytes32[] memory hashes = new bytes32[](proofFlagsLen);
uint256 leafPos = 0; uint256 leafPos = 0;
uint256 hashPos = 0; uint256 hashPos = 0;
uint256 proofPos = 0; uint256 proofPos = 0;
@ -455,7 +459,7 @@ library MerkleProof {
// get the next hash. // get the next hash.
// - depending on the flag, either another value from the "main queue" (merging branches) or an element from the // - depending on the flag, either another value from the "main queue" (merging branches) or an element from the
// `proof` array. // `proof` array.
for (uint256 i = 0; i < proofFlags.length; i++) { for (uint256 i = 0; i < proofFlagsLen; i++) {
bytes32 a = leafPos < leavesLen ? leaves[leafPos++] : hashes[hashPos++]; bytes32 a = leafPos < leavesLen ? leaves[leafPos++] : hashes[hashPos++];
bytes32 b = proofFlags[i] bytes32 b = proofFlags[i]
? (leafPos < leavesLen ? leaves[leafPos++] : hashes[hashPos++]) ? (leafPos < leavesLen ? leaves[leafPos++] : hashes[hashPos++])
@ -463,12 +467,12 @@ library MerkleProof {
hashes[i] = hasher(a, b); hashes[i] = hasher(a, b);
} }
if (proofFlags.length > 0) { if (proofFlagsLen > 0) {
if (proofPos != proof.length) { if (proofPos != proof.length) {
revert MerkleProofInvalidMultiproof(); revert MerkleProofInvalidMultiproof();
} }
unchecked { unchecked {
return hashes[proofFlags.length - 1]; return hashes[proofFlagsLen - 1];
} }
} else if (leavesLen > 0) { } else if (leavesLen > 0) {
return leaves[0]; return leaves[0];

View File

@ -56,7 +56,7 @@ function verify${suffix}(${(hash ? formatArgsMultiline : formatArgsSingleLine)(
'bytes32 leaf', 'bytes32 leaf',
hash && `function(bytes32, bytes32) view returns (bytes32) ${hash}`, hash && `function(bytes32, bytes32) view returns (bytes32) ${hash}`,
)}) internal ${visibility} returns (bool) { )}) internal ${visibility} returns (bool) {
return processProof(proof, leaf${hash ? `, ${hash}` : ''}) == root; return processProof${suffix}(proof, leaf${hash ? `, ${hash}` : ''}) == root;
} }
/** /**
@ -93,10 +93,10 @@ function multiProofVerify${suffix}(${formatArgsMultiline(
`bytes32[] ${location} proof`, `bytes32[] ${location} proof`,
`bool[] ${location} proofFlags`, `bool[] ${location} proofFlags`,
'bytes32 root', 'bytes32 root',
`bytes32[] ${location} leaves`, `bytes32[] memory leaves`,
hash && `function(bytes32, bytes32) view returns (bytes32) ${hash}`, hash && `function(bytes32, bytes32) view returns (bytes32) ${hash}`,
)}) internal ${visibility} returns (bool) { )}) internal ${visibility} returns (bool) {
return processMultiProof(proof, proofFlags, leaves${hash ? `, ${hash}` : ''}) == root; return processMultiProof${suffix}(proof, proofFlags, leaves${hash ? `, ${hash}` : ''}) == root;
} }
/** /**
@ -114,7 +114,7 @@ function multiProofVerify${suffix}(${formatArgsMultiline(
function processMultiProof${suffix}(${formatArgsMultiline( function processMultiProof${suffix}(${formatArgsMultiline(
`bytes32[] ${location} proof`, `bytes32[] ${location} proof`,
`bool[] ${location} proofFlags`, `bool[] ${location} proofFlags`,
`bytes32[] ${location} leaves`, `bytes32[] memory leaves`,
hash && `function(bytes32, bytes32) view returns (bytes32) ${hash}`, hash && `function(bytes32, bytes32) view returns (bytes32) ${hash}`,
)}) internal ${visibility} returns (bytes32 merkleRoot) { )}) internal ${visibility} returns (bytes32 merkleRoot) {
// This function rebuilds the root hash by traversing the tree up from the leaves. The root is rebuilt by // This function rebuilds the root hash by traversing the tree up from the leaves. The root is rebuilt by
@ -122,15 +122,16 @@ function processMultiProof${suffix}(${formatArgsMultiline(
// \`hashes\` array. At the end of the process, the last hash in the \`hashes\` array should contain the root of // \`hashes\` array. At the end of the process, the last hash in the \`hashes\` array should contain the root of
// the Merkle tree. // the Merkle tree.
uint256 leavesLen = leaves.length; uint256 leavesLen = leaves.length;
uint256 proofFlagsLen = proofFlags.length;
// Check proof validity. // Check proof validity.
if (leavesLen + proof.length != proofFlags.length + 1) { if (leavesLen + proof.length != proofFlagsLen + 1) {
revert MerkleProofInvalidMultiproof(); revert MerkleProofInvalidMultiproof();
} }
// The xxxPos values are "pointers" to the next value to consume in each array. All accesses are done using // The xxxPos values are "pointers" to the next value to consume in each array. All accesses are done using
// \`xxx[xxxPos++]\`, which return the current value and increment the pointer, thus mimicking a queue's "pop". // \`xxx[xxxPos++]\`, which return the current value and increment the pointer, thus mimicking a queue's "pop".
bytes32[] memory hashes = new bytes32[](proofFlags.length); bytes32[] memory hashes = new bytes32[](proofFlagsLen);
uint256 leafPos = 0; uint256 leafPos = 0;
uint256 hashPos = 0; uint256 hashPos = 0;
uint256 proofPos = 0; uint256 proofPos = 0;
@ -139,7 +140,7 @@ function processMultiProof${suffix}(${formatArgsMultiline(
// get the next hash. // get the next hash.
// - depending on the flag, either another value from the "main queue" (merging branches) or an element from the // - depending on the flag, either another value from the "main queue" (merging branches) or an element from the
// \`proof\` array. // \`proof\` array.
for (uint256 i = 0; i < proofFlags.length; i++) { for (uint256 i = 0; i < proofFlagsLen; i++) {
bytes32 a = leafPos < leavesLen ? leaves[leafPos++] : hashes[hashPos++]; bytes32 a = leafPos < leavesLen ? leaves[leafPos++] : hashes[hashPos++];
bytes32 b = proofFlags[i] bytes32 b = proofFlags[i]
? (leafPos < leavesLen ? leaves[leafPos++] : hashes[hashPos++]) ? (leafPos < leavesLen ? leaves[leafPos++] : hashes[hashPos++])
@ -147,12 +148,12 @@ function processMultiProof${suffix}(${formatArgsMultiline(
hashes[i] = ${hash ?? DEFAULT_HASH}(a, b); hashes[i] = ${hash ?? DEFAULT_HASH}(a, b);
} }
if (proofFlags.length > 0) { if (proofFlagsLen > 0) {
if (proofPos != proof.length) { if (proofPos != proof.length) {
revert MerkleProofInvalidMultiproof(); revert MerkleProofInvalidMultiproof();
} }
unchecked { unchecked {
return hashes[proofFlags.length - 1]; return hashes[proofFlagsLen - 1];
} }
} else if (leavesLen > 0) { } else if (leavesLen > 0) {
return leaves[0]; return leaves[0];