From 71bc0f7774abb500273010ffeea971c1f82c6b1a Mon Sep 17 00:00:00 2001 From: Hadrien Croubois Date: Fri, 28 Feb 2025 21:22:56 +0100 Subject: [PATCH] Add function to update a leaf in a MerkleTree structure (#5453) Co-authored-by: Arr00 <13561405+arr00@users.noreply.github.com> --- .changeset/good-zebras-ring.md | 5 + contracts/mocks/MerkleTreeMock.sol | 8 ++ contracts/utils/structs/MerkleTree.sol | 92 +++++++++++++++++ test/utils/structs/MerkleTree.test.js | 136 ++++++++++++++++++++----- 4 files changed, 213 insertions(+), 28 deletions(-) create mode 100644 .changeset/good-zebras-ring.md diff --git a/.changeset/good-zebras-ring.md b/.changeset/good-zebras-ring.md new file mode 100644 index 000000000..776fdc790 --- /dev/null +++ b/.changeset/good-zebras-ring.md @@ -0,0 +1,5 @@ +--- +'openzeppelin-solidity': minor +--- + +`MerkleTree`: Add an update function that replaces a previously inserted leaf with a new value, updating the tree root along the way. diff --git a/contracts/mocks/MerkleTreeMock.sol b/contracts/mocks/MerkleTreeMock.sol index dcde6b658..48ee1a6e4 100644 --- a/contracts/mocks/MerkleTreeMock.sol +++ b/contracts/mocks/MerkleTreeMock.sol @@ -14,6 +14,7 @@ contract MerkleTreeMock { bytes32 public root; event LeafInserted(bytes32 leaf, uint256 index, bytes32 root); + event LeafUpdated(bytes32 oldLeaf, bytes32 newLeaf, uint256 index, bytes32 root); function setup(uint8 _depth, bytes32 _zero) public { root = _tree.setup(_depth, _zero); @@ -25,6 +26,13 @@ contract MerkleTreeMock { root = currentRoot; } + function update(uint256 index, bytes32 oldValue, bytes32 newValue, bytes32[] memory proof) public { + (bytes32 oldRoot, bytes32 newRoot) = _tree.update(index, oldValue, newValue, proof); + if (oldRoot != root) revert MerkleTree.MerkleTreeUpdateInvalidProof(); + emit LeafUpdated(oldValue, newValue, index, newRoot); + root = newRoot; + } + function depth() public view returns (uint256) { return _tree.depth(); } diff --git a/contracts/utils/structs/MerkleTree.sol b/contracts/utils/structs/MerkleTree.sol index a52cfc91d..c48a1247a 100644 --- a/contracts/utils/structs/MerkleTree.sol +++ b/contracts/utils/structs/MerkleTree.sol @@ -6,6 +6,7 @@ pragma solidity ^0.8.20; import {Hashes} from "../cryptography/Hashes.sol"; import {Arrays} from "../Arrays.sol"; import {Panic} from "../Panic.sol"; +import {StorageSlot} from "../StorageSlot.sol"; /** * @dev Library for managing https://wikipedia.org/wiki/Merkle_Tree[Merkle Tree] data structures. @@ -27,6 +28,12 @@ import {Panic} from "../Panic.sol"; * _Available since v5.1._ */ library MerkleTree { + /// @dev Error emitted when trying to update a leaf that was not previously pushed. + error MerkleTreeUpdateInvalidIndex(uint256 index, uint256 length); + + /// @dev Error emitted when the proof used during an update is invalid (could not reproduce the side). + error MerkleTreeUpdateInvalidProof(); + /** * @dev A complete `bytes32` Merkle tree. * @@ -166,6 +173,91 @@ library MerkleTree { return (index, currentLevelHash); } + /** + * @dev Change the value of the leaf at position `index` from `oldValue` to `newValue`. Returns the recomputed "old" + * root (before the update) and "new" root (after the update). The caller must verify that the reconstructed old + * root is the last known one. + * + * The `proof` must be an up-to-date inclusion proof for the leaf being update. This means that this function is + * vulnerable to front-running. Any {push} or {update} operation (that changes the root of the tree) would render + * all "in flight" updates invalid. + * + * This variant uses {Hashes-commutativeKeccak256} to hash internal nodes. It should only be used on merkle trees + * that were setup using the same (default) hashing function (i.e. by calling + * {xref-MerkleTree-setup-struct-MerkleTree-Bytes32PushTree-uint8-bytes32-}[the default setup] function). + */ + function update( + Bytes32PushTree storage self, + uint256 index, + bytes32 oldValue, + bytes32 newValue, + bytes32[] memory proof + ) internal returns (bytes32 oldRoot, bytes32 newRoot) { + return update(self, index, oldValue, newValue, proof, Hashes.commutativeKeccak256); + } + + /** + * @dev Change the value of the leaf at position `index` from `oldValue` to `newValue`. Returns the recomputed "old" + * root (before the update) and "new" root (after the update). The caller must verify that the reconstructed old + * root is the last known one. + * + * The `proof` must be an up-to-date inclusion proof for the leaf being update. This means that this function is + * vulnerable to front-running. Any {push} or {update} operation (that changes the root of the tree) would render + * all "in flight" updates invalid. + * + * This variant uses a custom hashing function to hash internal nodes. It should only be called with the same + * function as the one used during the initial setup of the merkle tree. + */ + function update( + Bytes32PushTree storage self, + uint256 index, + bytes32 oldValue, + bytes32 newValue, + bytes32[] memory proof, + function(bytes32, bytes32) view returns (bytes32) fnHash + ) internal returns (bytes32 oldRoot, bytes32 newRoot) { + unchecked { + // Check index range + uint256 length = self._nextLeafIndex; + if (index >= length) revert MerkleTreeUpdateInvalidIndex(index, length); + + // Cache read + uint256 treeDepth = depth(self); + + // Workaround stack too deep + bytes32[] storage sides = self._sides; + + // This cannot overflow because: 0 <= index < length + uint256 lastIndex = length - 1; + uint256 currentIndex = index; + bytes32 currentLevelHashOld = oldValue; + bytes32 currentLevelHashNew = newValue; + for (uint32 i = 0; i < treeDepth; i++) { + bool isLeft = currentIndex % 2 == 0; + + lastIndex >>= 1; + currentIndex >>= 1; + + if (isLeft && currentIndex == lastIndex) { + StorageSlot.Bytes32Slot storage side = Arrays.unsafeAccess(sides, i); + if (side.value != currentLevelHashOld) revert MerkleTreeUpdateInvalidProof(); + side.value = currentLevelHashNew; + } + + bytes32 sibling = proof[i]; + currentLevelHashOld = fnHash( + isLeft ? currentLevelHashOld : sibling, + isLeft ? sibling : currentLevelHashOld + ); + currentLevelHashNew = fnHash( + isLeft ? currentLevelHashNew : sibling, + isLeft ? sibling : currentLevelHashNew + ); + } + return (currentLevelHashOld, currentLevelHashNew); + } + } + /** * @dev Tree's depth (set at initialization) */ diff --git a/test/utils/structs/MerkleTree.test.js b/test/utils/structs/MerkleTree.test.js index bec39ceea..f0380ed02 100644 --- a/test/utils/structs/MerkleTree.test.js +++ b/test/utils/structs/MerkleTree.test.js @@ -5,18 +5,23 @@ const { PANIC_CODES } = require('@nomicfoundation/hardhat-chai-matchers/panic'); const { StandardMerkleTree } = require('@openzeppelin/merkle-tree'); const { generators } = require('../../helpers/random'); +const { range } = require('../../helpers/iterate'); -const makeTree = (leaves = [ethers.ZeroHash]) => +const DEPTH = 4; // 16 slots + +const makeTree = (leaves = [], length = 2 ** DEPTH, zero = ethers.ZeroHash) => StandardMerkleTree.of( - leaves.map(leaf => [leaf]), + [] + .concat( + leaves, + Array.from({ length: length - leaves.length }, () => zero), + ) + .map(leaf => [leaf]), ['bytes32'], { sortLeaves: false }, ); -const hashLeaf = leaf => makeTree().leafHash([leaf]); - -const DEPTH = 4n; // 16 slots -const ZERO = hashLeaf(ethers.ZeroHash); +const ZERO = makeTree().leafHash([ethers.ZeroHash]); async function fixture() { const mock = await ethers.deployContract('MerkleTreeMock'); @@ -30,57 +35,132 @@ describe('MerkleTree', function () { }); it('sets initial values at setup', async function () { - const merkleTree = makeTree(Array.from({ length: 2 ** Number(DEPTH) }, () => ethers.ZeroHash)); + const merkleTree = makeTree(); - expect(await this.mock.root()).to.equal(merkleTree.root); - expect(await this.mock.depth()).to.equal(DEPTH); - expect(await this.mock.nextLeafIndex()).to.equal(0n); + await expect(this.mock.root()).to.eventually.equal(merkleTree.root); + await expect(this.mock.depth()).to.eventually.equal(DEPTH); + await expect(this.mock.nextLeafIndex()).to.eventually.equal(0n); }); describe('push', function () { - it('tree is correctly updated', async function () { - const leaves = Array.from({ length: 2 ** Number(DEPTH) }, () => ethers.ZeroHash); + it('pushing correctly updates the tree', async function () { + const leaves = []; // for each leaf slot - for (const i in leaves) { - // generate random leaf and hash it - const hashedLeaf = hashLeaf((leaves[i] = generators.bytes32())); + for (const i in range(2 ** DEPTH)) { + // generate random leaf + leaves.push(generators.bytes32()); - // update leaf list and rebuild tree. + // rebuild tree. const tree = makeTree(leaves); + const hash = tree.leafHash(tree.at(i)); // push value to tree - await expect(this.mock.push(hashedLeaf)).to.emit(this.mock, 'LeafInserted').withArgs(hashedLeaf, i, tree.root); + await expect(this.mock.push(hash)).to.emit(this.mock, 'LeafInserted').withArgs(hash, i, tree.root); // check tree - expect(await this.mock.root()).to.equal(tree.root); - expect(await this.mock.nextLeafIndex()).to.equal(BigInt(i) + 1n); + await expect(this.mock.root()).to.eventually.equal(tree.root); + await expect(this.mock.nextLeafIndex()).to.eventually.equal(BigInt(i) + 1n); } }); - it('revert when tree is full', async function () { + it('pushing to a full tree reverts', async function () { await Promise.all(Array.from({ length: 2 ** Number(DEPTH) }).map(() => this.mock.push(ethers.ZeroHash))); await expect(this.mock.push(ethers.ZeroHash)).to.be.revertedWithPanic(PANIC_CODES.TOO_MUCH_MEMORY_ALLOCATED); }); }); + describe('update', function () { + for (const { leafCount, leafIndex } of range(2 ** DEPTH + 1).flatMap(leafCount => + range(leafCount).map(leafIndex => ({ leafCount, leafIndex })), + )) + it(`updating a leaf correctly updates the tree (leaf #${leafIndex + 1}/${leafCount})`, async function () { + // initial tree + const leaves = Array.from({ length: leafCount }, generators.bytes32); + const oldTree = makeTree(leaves); + + // fill tree and verify root + for (const i in leaves) { + await this.mock.push(oldTree.leafHash(oldTree.at(i))); + } + await expect(this.mock.root()).to.eventually.equal(oldTree.root); + + // create updated tree + leaves[leafIndex] = generators.bytes32(); + const newTree = makeTree(leaves); + + const oldLeafHash = oldTree.leafHash(oldTree.at(leafIndex)); + const newLeafHash = newTree.leafHash(newTree.at(leafIndex)); + + // perform update + await expect(this.mock.update(leafIndex, oldLeafHash, newLeafHash, oldTree.getProof(leafIndex))) + .to.emit(this.mock, 'LeafUpdated') + .withArgs(oldLeafHash, newLeafHash, leafIndex, newTree.root); + + // verify updated root + await expect(this.mock.root()).to.eventually.equal(newTree.root); + + // if there is still room in the tree, fill it + for (const i of range(leafCount, 2 ** DEPTH)) { + // push new value and rebuild tree + leaves.push(generators.bytes32()); + const nextTree = makeTree(leaves); + + // push and verify root + await this.mock.push(nextTree.leafHash(nextTree.at(i))); + await expect(this.mock.root()).to.eventually.equal(nextTree.root); + } + }); + + it('replacing a leaf that was not previously pushed reverts', async function () { + // changing leaf 0 on an empty tree + await expect(this.mock.update(1, ZERO, ZERO, [])) + .to.be.revertedWithCustomError(this.mock, 'MerkleTreeUpdateInvalidIndex') + .withArgs(1, 0); + }); + + it('replacing a leaf using an invalid proof reverts', async function () { + const leafCount = 4; + const leafIndex = 2; + + const leaves = Array.from({ length: leafCount }, generators.bytes32); + const tree = makeTree(leaves); + + // fill tree and verify root + for (const i in leaves) { + await this.mock.push(tree.leafHash(tree.at(i))); + } + await expect(this.mock.root()).to.eventually.equal(tree.root); + + const oldLeafHash = tree.leafHash(tree.at(leafIndex)); + const newLeafHash = generators.bytes32(); + const proof = tree.getProof(leafIndex); + // invalid proof (tamper) + proof[1] = generators.bytes32(); + + await expect(this.mock.update(leafIndex, oldLeafHash, newLeafHash, proof)).to.be.revertedWithCustomError( + this.mock, + 'MerkleTreeUpdateInvalidProof', + ); + }); + }); + it('reset', async function () { // empty tree - const zeroLeaves = Array.from({ length: 2 ** Number(DEPTH) }, () => ethers.ZeroHash); - const zeroTree = makeTree(zeroLeaves); + const emptyTree = makeTree(); // tree with one element - const leaves = Array.from({ length: 2 ** Number(DEPTH) }, () => ethers.ZeroHash); - const hashedLeaf = hashLeaf((leaves[0] = generators.bytes32())); // fill first leaf and hash it + const leaves = [generators.bytes32()]; const tree = makeTree(leaves); + const hash = tree.leafHash(tree.at(0)); // root should be that of a zero tree - expect(await this.mock.root()).to.equal(zeroTree.root); + expect(await this.mock.root()).to.equal(emptyTree.root); expect(await this.mock.nextLeafIndex()).to.equal(0n); // push leaf and check root - await expect(this.mock.push(hashedLeaf)).to.emit(this.mock, 'LeafInserted').withArgs(hashedLeaf, 0, tree.root); + await expect(this.mock.push(hash)).to.emit(this.mock, 'LeafInserted').withArgs(hash, 0, tree.root); expect(await this.mock.root()).to.equal(tree.root); expect(await this.mock.nextLeafIndex()).to.equal(1n); @@ -88,11 +168,11 @@ describe('MerkleTree', function () { // reset tree await this.mock.setup(DEPTH, ZERO); - expect(await this.mock.root()).to.equal(zeroTree.root); + expect(await this.mock.root()).to.equal(emptyTree.root); expect(await this.mock.nextLeafIndex()).to.equal(0n); // re-push leaf and check root - await expect(this.mock.push(hashedLeaf)).to.emit(this.mock, 'LeafInserted').withArgs(hashedLeaf, 0, tree.root); + await expect(this.mock.push(hash)).to.emit(this.mock, 'LeafInserted').withArgs(hash, 0, tree.root); expect(await this.mock.root()).to.equal(tree.root); expect(await this.mock.nextLeafIndex()).to.equal(1n);