Refactor Heap.sol to remove index and lookup (#5190)

Co-authored-by: Ernesto García <ernestognw@gmail.com>
This commit is contained in:
Hadrien Croubois
2024-09-19 14:29:39 +02:00
committed by GitHub
parent f20981528f
commit 3cfebcb5c4
7 changed files with 155 additions and 1008 deletions

View File

@ -1,33 +1,25 @@
// SPDX-License-Identifier: MIT
// This file was procedurally generated from scripts/generate/templates/Heap.js.
pragma solidity ^0.8.20;
import {Math} from "../math/Math.sol";
import {SafeCast} from "../math/SafeCast.sol";
import {Comparators} from "../Comparators.sol";
import {Arrays} from "../Arrays.sol";
import {Panic} from "../Panic.sol";
import {StorageSlot} from "../StorageSlot.sol";
/**
* @dev Library for managing https://en.wikipedia.org/wiki/Binary_heap[binary heap] that can be used as
* https://en.wikipedia.org/wiki/Priority_queue[priority queue].
*
* Heaps are represented as an array of Node objects. This array stores two overlapping structures:
* * A tree structure where the first element (index 0) is the root, and where the node at index i is the child of the
* node at index (i-1)/2 and the father of nodes at index 2*i+1 and 2*i+2. Each node stores the index (in the array)
* where the corresponding value is stored.
* * A list of payloads values where each index contains a value and a lookup index. The type of the value depends on
* the variant being used. The lookup is the index of the node (in the tree) that points to this value.
*
* Some invariants:
* ```
* i == heap.data[heap.data[i].index].lookup // for all indices i
* i == heap.data[heap.data[i].lookup].index // for all indices i
* ```
* Heaps are represented as an tree of values where the first element (index 0) is the root, and where the node at
* index i is the child of the node at index (i-1)/2 and the father of nodes at index 2*i+1 and 2*i+2. Each node
* stores an element of the heap.
*
* The structure is ordered so that each node is bigger than its parent. An immediate consequence is that the
* highest priority value is the one at the root. This value can be looked up in constant time (O(1)) at
* `heap.data[heap.data[0].index].value`
* `heap.tree[0].value`
*
* The structure is designed to perform the following operations with the corresponding complexities:
*
@ -37,8 +29,13 @@ import {Panic} from "../Panic.sol";
* * replace (replace the highest priority value with a new value): O(log(n))
* * length (get the number of elements): O(1)
* * clear (remove all elements): O(1)
*
* IMPORTANT: This library allows for the use of custom comparator functions. Given that manipulating
* memory can lead to unexpected behavior. Consider verifying that the comparator does not manipulate
* the Heap's state directly and that it follows the Solidity memory safety rules.
*/
library Heap {
using Arrays for *;
using Math for *;
using SafeCast for *;
@ -48,24 +45,15 @@ library Heap {
* Each element of that structure uses 2 storage slots.
*/
struct Uint256Heap {
Uint256HeapNode[] data;
}
/**
* @dev Internal node type for Uint256Heap. Stores a value of type uint256.
*/
struct Uint256HeapNode {
uint256 value;
uint64 index; // position -> value
uint64 lookup; // value -> position
uint256[] tree;
}
/**
* @dev Lookup the root element of the heap.
*/
function peek(Uint256Heap storage self) internal view returns (uint256) {
// self.data[0] will `ARRAY_ACCESS_OUT_OF_BOUNDS` panic if heap is empty.
return _unsafeNodeAccess(self, self.data[0].index).value;
// self.tree[0] will `ARRAY_ACCESS_OUT_OF_BOUNDS` panic if heap is empty.
return self.tree[0];
}
/**
@ -89,44 +77,19 @@ library Heap {
function(uint256, uint256) view returns (bool) comp
) internal returns (uint256) {
unchecked {
uint64 size = length(self);
uint256 size = length(self);
if (size == 0) Panic.panic(Panic.EMPTY_ARRAY_POP);
uint64 last = size - 1;
// cache
uint256 rootValue = self.tree.unsafeAccess(0).value;
uint256 lastValue = self.tree.unsafeAccess(size - 1).value;
// get root location (in the data array) and value
Uint256HeapNode storage rootNode = _unsafeNodeAccess(self, 0);
uint64 rootIdx = rootNode.index;
Uint256HeapNode storage rootData = _unsafeNodeAccess(self, rootIdx);
Uint256HeapNode storage lastNode = _unsafeNodeAccess(self, last);
uint256 rootDataValue = rootData.value;
// swap last leaf with root, shrink tree and re-heapify
self.tree.pop();
self.tree.unsafeAccess(0).value = lastValue;
_siftDown(self, size - 1, 0, lastValue, comp);
// if root is not the last element of the data array (that will get popped), reorder the data array.
if (rootIdx != last) {
// get details about the value stored in the last element of the array (that will get popped)
uint64 lastDataIdx = lastNode.lookup;
uint256 lastDataValue = lastNode.value;
// copy these values to the location of the root (that is safe, and that we no longer use)
rootData.value = lastDataValue;
rootData.lookup = lastDataIdx;
// update the tree node that used to point to that last element (value now located where the root was)
_unsafeNodeAccess(self, lastDataIdx).index = rootIdx;
}
// get last leaf location (in the data array) and value
uint64 lastIdx = lastNode.index;
uint256 lastValue = _unsafeNodeAccess(self, lastIdx).value;
// move the last leaf to the root, pop last leaf ...
rootNode.index = lastIdx;
_unsafeNodeAccess(self, lastIdx).lookup = 0;
self.data.pop();
// ... and heapify
_siftDown(self, last, 0, lastValue, comp);
// return root value
return rootDataValue;
return rootValue;
}
}
@ -151,10 +114,10 @@ library Heap {
uint256 value,
function(uint256, uint256) view returns (bool) comp
) internal {
uint64 size = length(self);
if (size == type(uint64).max) Panic.panic(Panic.RESOURCE_ERROR);
uint256 size = length(self);
self.data.push(Uint256HeapNode({index: size, lookup: size, value: value}));
// push new item and re-heapify
self.tree.push(value);
_siftUp(self, size, value, comp);
}
@ -181,396 +144,108 @@ library Heap {
uint256 newValue,
function(uint256, uint256) view returns (bool) comp
) internal returns (uint256) {
uint64 size = length(self);
uint256 size = length(self);
if (size == 0) Panic.panic(Panic.EMPTY_ARRAY_POP);
// position of the node that holds the data for the root
uint64 rootIdx = _unsafeNodeAccess(self, 0).index;
// storage pointer to the node that holds the data for the root
Uint256HeapNode storage rootData = _unsafeNodeAccess(self, rootIdx);
// cache
uint256 oldValue = self.tree.unsafeAccess(0).value;
// cache old value and replace it
uint256 oldValue = rootData.value;
rootData.value = newValue;
// re-heapify
// replace and re-heapify
self.tree.unsafeAccess(0).value = newValue;
_siftDown(self, size, 0, newValue, comp);
// return old root value
return oldValue;
}
/**
* @dev Returns the number of elements in the heap.
*/
function length(Uint256Heap storage self) internal view returns (uint64) {
return self.data.length.toUint64();
function length(Uint256Heap storage self) internal view returns (uint256) {
return self.tree.length;
}
/**
* @dev Removes all elements in the heap.
*/
function clear(Uint256Heap storage self) internal {
Uint256HeapNode[] storage data = self.data;
assembly ("memory-safe") {
sstore(data.slot, 0)
}
self.tree.unsafeSetLength(0);
}
/**
* @dev Swap node `i` and `j` in the tree.
*/
function _swap(Uint256Heap storage self, uint64 i, uint64 j) private {
Uint256HeapNode storage ni = _unsafeNodeAccess(self, i);
Uint256HeapNode storage nj = _unsafeNodeAccess(self, j);
uint64 ii = ni.index;
uint64 jj = nj.index;
// update pointers to the data (swap the value)
ni.index = jj;
nj.index = ii;
// update lookup pointers for consistency
_unsafeNodeAccess(self, ii).lookup = j;
_unsafeNodeAccess(self, jj).lookup = i;
function _swap(Uint256Heap storage self, uint256 i, uint256 j) private {
StorageSlot.Uint256Slot storage ni = self.tree.unsafeAccess(i);
StorageSlot.Uint256Slot storage nj = self.tree.unsafeAccess(j);
(ni.value, nj.value) = (nj.value, ni.value);
}
/**
* @dev Perform heap maintenance on `self`, starting at position `pos` (with the `value`), using `comp` as a
* @dev Perform heap maintenance on `self`, starting at `index` (with the `value`), using `comp` as a
* comparator, and moving toward the leaves of the underlying tree.
*
* NOTE: This is a private function that is called in a trusted context with already cached parameters. `length`
* and `value` could be extracted from `self` and `pos`, but that would require redundant storage read. These
* and `value` could be extracted from `self` and `index`, but that would require redundant storage read. These
* parameters are not verified. It is the caller role to make sure the parameters are correct.
*/
function _siftDown(
Uint256Heap storage self,
uint64 size,
uint64 pos,
uint256 size,
uint256 index,
uint256 value,
function(uint256, uint256) view returns (bool) comp
) private {
uint256 left = 2 * pos + 1; // this could overflow uint64
uint256 right = 2 * pos + 2; // this could overflow uint64
// Check if there is a risk of overflow when computing the indices of the child nodes. If that is the case,
// there cannot be child nodes in the tree, so sifting is done.
if (index >= type(uint256).max / 2) return;
if (right < size) {
// the check guarantees that `left` and `right` are both valid uint64
uint64 lIndex = uint64(left);
uint64 rIndex = uint64(right);
uint256 lValue = _unsafeNodeAccess(self, _unsafeNodeAccess(self, lIndex).index).value;
uint256 rValue = _unsafeNodeAccess(self, _unsafeNodeAccess(self, rIndex).index).value;
// Compute the indices of the potential child nodes
uint256 lIndex = 2 * index + 1;
uint256 rIndex = 2 * index + 2;
// Three cases:
// 1. Both children exist: sifting may continue on one of the branch (selection required)
// 2. Only left child exist: sifting may contineu on the left branch (no selection required)
// 3. Neither child exist: sifting is done
if (rIndex < size) {
uint256 lValue = self.tree.unsafeAccess(lIndex).value;
uint256 rValue = self.tree.unsafeAccess(rIndex).value;
if (comp(lValue, value) || comp(rValue, value)) {
uint64 index = uint64(comp(lValue, rValue).ternary(lIndex, rIndex));
_swap(self, pos, index);
_siftDown(self, size, index, value, comp);
uint256 cIndex = comp(lValue, rValue).ternary(lIndex, rIndex);
_swap(self, index, cIndex);
_siftDown(self, size, cIndex, value, comp);
}
} else if (left < size) {
// the check guarantees that `left` is a valid uint64
uint64 lIndex = uint64(left);
uint256 lValue = _unsafeNodeAccess(self, _unsafeNodeAccess(self, lIndex).index).value;
} else if (lIndex < size) {
uint256 lValue = self.tree.unsafeAccess(lIndex).value;
if (comp(lValue, value)) {
_swap(self, pos, lIndex);
_swap(self, index, lIndex);
_siftDown(self, size, lIndex, value, comp);
}
}
}
/**
* @dev Perform heap maintenance on `self`, starting at position `pos` (with the `value`), using `comp` as a
* @dev Perform heap maintenance on `self`, starting at `index` (with the `value`), using `comp` as a
* comparator, and moving toward the root of the underlying tree.
*
* NOTE: This is a private function that is called in a trusted context with already cached parameters. `value`
* could be extracted from `self` and `pos`, but that would require redundant storage read. These parameters are not
* could be extracted from `self` and `index`, but that would require redundant storage read. These parameters are not
* verified. It is the caller role to make sure the parameters are correct.
*/
function _siftUp(
Uint256Heap storage self,
uint64 pos,
uint256 index,
uint256 value,
function(uint256, uint256) view returns (bool) comp
) private {
unchecked {
while (pos > 0) {
uint64 parent = (pos - 1) / 2;
uint256 parentValue = _unsafeNodeAccess(self, _unsafeNodeAccess(self, parent).index).value;
while (index > 0) {
uint256 parentIndex = (index - 1) / 2;
uint256 parentValue = self.tree.unsafeAccess(parentIndex).value;
if (comp(parentValue, value)) break;
_swap(self, pos, parent);
pos = parent;
_swap(self, index, parentIndex);
index = parentIndex;
}
}
}
function _unsafeNodeAccess(
Uint256Heap storage self,
uint64 pos
) private pure returns (Uint256HeapNode storage result) {
assembly ("memory-safe") {
mstore(0x00, self.slot)
result.slot := add(keccak256(0x00, 0x20), mul(pos, 2))
}
}
/**
* @dev Binary heap that supports values of type uint208.
*
* Each element of that structure uses 1 storage slots.
*/
struct Uint208Heap {
Uint208HeapNode[] data;
}
/**
* @dev Internal node type for Uint208Heap. Stores a value of type uint208.
*/
struct Uint208HeapNode {
uint208 value;
uint24 index; // position -> value
uint24 lookup; // value -> position
}
/**
* @dev Lookup the root element of the heap.
*/
function peek(Uint208Heap storage self) internal view returns (uint208) {
// self.data[0] will `ARRAY_ACCESS_OUT_OF_BOUNDS` panic if heap is empty.
return _unsafeNodeAccess(self, self.data[0].index).value;
}
/**
* @dev Remove (and return) the root element for the heap using the default comparator.
*
* NOTE: All inserting and removal from a heap should always be done using the same comparator. Mixing comparator
* during the lifecycle of a heap will result in undefined behavior.
*/
function pop(Uint208Heap storage self) internal returns (uint208) {
return pop(self, Comparators.lt);
}
/**
* @dev Remove (and return) the root element for the heap using the provided comparator.
*
* NOTE: All inserting and removal from a heap should always be done using the same comparator. Mixing comparator
* during the lifecycle of a heap will result in undefined behavior.
*/
function pop(
Uint208Heap storage self,
function(uint256, uint256) view returns (bool) comp
) internal returns (uint208) {
unchecked {
uint24 size = length(self);
if (size == 0) Panic.panic(Panic.EMPTY_ARRAY_POP);
uint24 last = size - 1;
// get root location (in the data array) and value
Uint208HeapNode storage rootNode = _unsafeNodeAccess(self, 0);
uint24 rootIdx = rootNode.index;
Uint208HeapNode storage rootData = _unsafeNodeAccess(self, rootIdx);
Uint208HeapNode storage lastNode = _unsafeNodeAccess(self, last);
uint208 rootDataValue = rootData.value;
// if root is not the last element of the data array (that will get popped), reorder the data array.
if (rootIdx != last) {
// get details about the value stored in the last element of the array (that will get popped)
uint24 lastDataIdx = lastNode.lookup;
uint208 lastDataValue = lastNode.value;
// copy these values to the location of the root (that is safe, and that we no longer use)
rootData.value = lastDataValue;
rootData.lookup = lastDataIdx;
// update the tree node that used to point to that last element (value now located where the root was)
_unsafeNodeAccess(self, lastDataIdx).index = rootIdx;
}
// get last leaf location (in the data array) and value
uint24 lastIdx = lastNode.index;
uint208 lastValue = _unsafeNodeAccess(self, lastIdx).value;
// move the last leaf to the root, pop last leaf ...
rootNode.index = lastIdx;
_unsafeNodeAccess(self, lastIdx).lookup = 0;
self.data.pop();
// ... and heapify
_siftDown(self, last, 0, lastValue, comp);
// return root value
return rootDataValue;
}
}
/**
* @dev Insert a new element in the heap using the default comparator.
*
* NOTE: All inserting and removal from a heap should always be done using the same comparator. Mixing comparator
* during the lifecycle of a heap will result in undefined behavior.
*/
function insert(Uint208Heap storage self, uint208 value) internal {
insert(self, value, Comparators.lt);
}
/**
* @dev Insert a new element in the heap using the provided comparator.
*
* NOTE: All inserting and removal from a heap should always be done using the same comparator. Mixing comparator
* during the lifecycle of a heap will result in undefined behavior.
*/
function insert(
Uint208Heap storage self,
uint208 value,
function(uint256, uint256) view returns (bool) comp
) internal {
uint24 size = length(self);
if (size == type(uint24).max) Panic.panic(Panic.RESOURCE_ERROR);
self.data.push(Uint208HeapNode({index: size, lookup: size, value: value}));
_siftUp(self, size, value, comp);
}
/**
* @dev Return the root element for the heap, and replace it with a new value, using the default comparator.
* This is equivalent to using {pop} and {insert}, but requires only one rebalancing operation.
*
* NOTE: All inserting and removal from a heap should always be done using the same comparator. Mixing comparator
* during the lifecycle of a heap will result in undefined behavior.
*/
function replace(Uint208Heap storage self, uint208 newValue) internal returns (uint208) {
return replace(self, newValue, Comparators.lt);
}
/**
* @dev Return the root element for the heap, and replace it with a new value, using the provided comparator.
* This is equivalent to using {pop} and {insert}, but requires only one rebalancing operation.
*
* NOTE: All inserting and removal from a heap should always be done using the same comparator. Mixing comparator
* during the lifecycle of a heap will result in undefined behavior.
*/
function replace(
Uint208Heap storage self,
uint208 newValue,
function(uint256, uint256) view returns (bool) comp
) internal returns (uint208) {
uint24 size = length(self);
if (size == 0) Panic.panic(Panic.EMPTY_ARRAY_POP);
// position of the node that holds the data for the root
uint24 rootIdx = _unsafeNodeAccess(self, 0).index;
// storage pointer to the node that holds the data for the root
Uint208HeapNode storage rootData = _unsafeNodeAccess(self, rootIdx);
// cache old value and replace it
uint208 oldValue = rootData.value;
rootData.value = newValue;
// re-heapify
_siftDown(self, size, 0, newValue, comp);
// return old root value
return oldValue;
}
/**
* @dev Returns the number of elements in the heap.
*/
function length(Uint208Heap storage self) internal view returns (uint24) {
return self.data.length.toUint24();
}
/**
* @dev Removes all elements in the heap.
*/
function clear(Uint208Heap storage self) internal {
Uint208HeapNode[] storage data = self.data;
assembly ("memory-safe") {
sstore(data.slot, 0)
}
}
/**
* @dev Swap node `i` and `j` in the tree.
*/
function _swap(Uint208Heap storage self, uint24 i, uint24 j) private {
Uint208HeapNode storage ni = _unsafeNodeAccess(self, i);
Uint208HeapNode storage nj = _unsafeNodeAccess(self, j);
uint24 ii = ni.index;
uint24 jj = nj.index;
// update pointers to the data (swap the value)
ni.index = jj;
nj.index = ii;
// update lookup pointers for consistency
_unsafeNodeAccess(self, ii).lookup = j;
_unsafeNodeAccess(self, jj).lookup = i;
}
/**
* @dev Perform heap maintenance on `self`, starting at position `pos` (with the `value`), using `comp` as a
* comparator, and moving toward the leaves of the underlying tree.
*
* NOTE: This is a private function that is called in a trusted context with already cached parameters. `length`
* and `value` could be extracted from `self` and `pos`, but that would require redundant storage read. These
* parameters are not verified. It is the caller role to make sure the parameters are correct.
*/
function _siftDown(
Uint208Heap storage self,
uint24 size,
uint24 pos,
uint208 value,
function(uint256, uint256) view returns (bool) comp
) private {
uint256 left = 2 * pos + 1; // this could overflow uint24
uint256 right = 2 * pos + 2; // this could overflow uint24
if (right < size) {
// the check guarantees that `left` and `right` are both valid uint24
uint24 lIndex = uint24(left);
uint24 rIndex = uint24(right);
uint208 lValue = _unsafeNodeAccess(self, _unsafeNodeAccess(self, lIndex).index).value;
uint208 rValue = _unsafeNodeAccess(self, _unsafeNodeAccess(self, rIndex).index).value;
if (comp(lValue, value) || comp(rValue, value)) {
uint24 index = uint24(comp(lValue, rValue).ternary(lIndex, rIndex));
_swap(self, pos, index);
_siftDown(self, size, index, value, comp);
}
} else if (left < size) {
// the check guarantees that `left` is a valid uint24
uint24 lIndex = uint24(left);
uint208 lValue = _unsafeNodeAccess(self, _unsafeNodeAccess(self, lIndex).index).value;
if (comp(lValue, value)) {
_swap(self, pos, lIndex);
_siftDown(self, size, lIndex, value, comp);
}
}
}
/**
* @dev Perform heap maintenance on `self`, starting at position `pos` (with the `value`), using `comp` as a
* comparator, and moving toward the root of the underlying tree.
*
* NOTE: This is a private function that is called in a trusted context with already cached parameters. `value`
* could be extracted from `self` and `pos`, but that would require redundant storage read. These parameters are not
* verified. It is the caller role to make sure the parameters are correct.
*/
function _siftUp(
Uint208Heap storage self,
uint24 pos,
uint208 value,
function(uint256, uint256) view returns (bool) comp
) private {
unchecked {
while (pos > 0) {
uint24 parent = (pos - 1) / 2;
uint208 parentValue = _unsafeNodeAccess(self, _unsafeNodeAccess(self, parent).index).value;
if (comp(parentValue, value)) break;
_swap(self, pos, parent);
pos = parent;
}
}
}
function _unsafeNodeAccess(
Uint208Heap storage self,
uint24 pos
) private pure returns (Uint208HeapNode storage result) {
assembly ("memory-safe") {
mstore(0x00, self.slot)
result.slot := add(keccak256(0x00, 0x20), pos)
}
}
}

View File

@ -37,7 +37,6 @@ for (const [file, template] of Object.entries({
'utils/structs/Checkpoints.sol': './templates/Checkpoints.js',
'utils/structs/EnumerableSet.sol': './templates/EnumerableSet.js',
'utils/structs/EnumerableMap.sol': './templates/EnumerableMap.js',
'utils/structs/Heap.sol': './templates/Heap.js',
'utils/SlotDerivation.sol': './templates/SlotDerivation.js',
'utils/StorageSlot.sol': './templates/StorageSlot.js',
'utils/Arrays.sol': './templates/Arrays.js',
@ -50,7 +49,6 @@ for (const [file, template] of Object.entries({
// Tests
for (const [file, template] of Object.entries({
'utils/structs/Checkpoints.t.sol': './templates/Checkpoints.t.js',
'utils/structs/Heap.t.sol': './templates/Heap.t.js',
'utils/Packing.t.sol': './templates/Packing.t.js',
'utils/SlotDerivation.t.sol': './templates/SlotDerivation.t.js',
})) {

View File

@ -1,327 +0,0 @@
const format = require('../format-lines');
const { TYPES } = require('./Heap.opts');
const { capitalize } = require('../../helpers');
/* eslint-disable max-len */
const header = `\
pragma solidity ^0.8.20;
import {Math} from "../math/Math.sol";
import {SafeCast} from "../math/SafeCast.sol";
import {Comparators} from "../Comparators.sol";
import {Panic} from "../Panic.sol";
/**
* @dev Library for managing https://en.wikipedia.org/wiki/Binary_heap[binary heap] that can be used as
* https://en.wikipedia.org/wiki/Priority_queue[priority queue].
*
* Heaps are represented as an array of Node objects. This array stores two overlapping structures:
* * A tree structure where the first element (index 0) is the root, and where the node at index i is the child of the
* node at index (i-1)/2 and the father of nodes at index 2*i+1 and 2*i+2. Each node stores the index (in the array)
* where the corresponding value is stored.
* * A list of payloads values where each index contains a value and a lookup index. The type of the value depends on
* the variant being used. The lookup is the index of the node (in the tree) that points to this value.
*
* Some invariants:
* \`\`\`
* i == heap.data[heap.data[i].index].lookup // for all indices i
* i == heap.data[heap.data[i].lookup].index // for all indices i
* \`\`\`
*
* The structure is ordered so that each node is bigger than its parent. An immediate consequence is that the
* highest priority value is the one at the root. This value can be looked up in constant time (O(1)) at
* \`heap.data[heap.data[0].index].value\`
*
* The structure is designed to perform the following operations with the corresponding complexities:
*
* * peek (get the highest priority value): O(1)
* * insert (insert a value): O(log(n))
* * pop (remove the highest priority value): O(log(n))
* * replace (replace the highest priority value with a new value): O(log(n))
* * length (get the number of elements): O(1)
* * clear (remove all elements): O(1)
*/
`;
const generate = ({ struct, node, valueType, indexType, blockSize }) => `\
/**
* @dev Binary heap that supports values of type ${valueType}.
*
* Each element of that structure uses ${blockSize} storage slots.
*/
struct ${struct} {
${node}[] data;
}
/**
* @dev Internal node type for ${struct}. Stores a value of type ${valueType}.
*/
struct ${node} {
${valueType} value;
${indexType} index; // position -> value
${indexType} lookup; // value -> position
}
/**
* @dev Lookup the root element of the heap.
*/
function peek(${struct} storage self) internal view returns (${valueType}) {
// self.data[0] will \`ARRAY_ACCESS_OUT_OF_BOUNDS\` panic if heap is empty.
return _unsafeNodeAccess(self, self.data[0].index).value;
}
/**
* @dev Remove (and return) the root element for the heap using the default comparator.
*
* NOTE: All inserting and removal from a heap should always be done using the same comparator. Mixing comparator
* during the lifecycle of a heap will result in undefined behavior.
*/
function pop(${struct} storage self) internal returns (${valueType}) {
return pop(self, Comparators.lt);
}
/**
* @dev Remove (and return) the root element for the heap using the provided comparator.
*
* NOTE: All inserting and removal from a heap should always be done using the same comparator. Mixing comparator
* during the lifecycle of a heap will result in undefined behavior.
*/
function pop(
${struct} storage self,
function(uint256, uint256) view returns (bool) comp
) internal returns (${valueType}) {
unchecked {
${indexType} size = length(self);
if (size == 0) Panic.panic(Panic.EMPTY_ARRAY_POP);
${indexType} last = size - 1;
// get root location (in the data array) and value
${node} storage rootNode = _unsafeNodeAccess(self, 0);
${indexType} rootIdx = rootNode.index;
${node} storage rootData = _unsafeNodeAccess(self, rootIdx);
${node} storage lastNode = _unsafeNodeAccess(self, last);
${valueType} rootDataValue = rootData.value;
// if root is not the last element of the data array (that will get popped), reorder the data array.
if (rootIdx != last) {
// get details about the value stored in the last element of the array (that will get popped)
${indexType} lastDataIdx = lastNode.lookup;
${valueType} lastDataValue = lastNode.value;
// copy these values to the location of the root (that is safe, and that we no longer use)
rootData.value = lastDataValue;
rootData.lookup = lastDataIdx;
// update the tree node that used to point to that last element (value now located where the root was)
_unsafeNodeAccess(self, lastDataIdx).index = rootIdx;
}
// get last leaf location (in the data array) and value
${indexType} lastIdx = lastNode.index;
${valueType} lastValue = _unsafeNodeAccess(self, lastIdx).value;
// move the last leaf to the root, pop last leaf ...
rootNode.index = lastIdx;
_unsafeNodeAccess(self, lastIdx).lookup = 0;
self.data.pop();
// ... and heapify
_siftDown(self, last, 0, lastValue, comp);
// return root value
return rootDataValue;
}
}
/**
* @dev Insert a new element in the heap using the default comparator.
*
* NOTE: All inserting and removal from a heap should always be done using the same comparator. Mixing comparator
* during the lifecycle of a heap will result in undefined behavior.
*/
function insert(${struct} storage self, ${valueType} value) internal {
insert(self, value, Comparators.lt);
}
/**
* @dev Insert a new element in the heap using the provided comparator.
*
* NOTE: All inserting and removal from a heap should always be done using the same comparator. Mixing comparator
* during the lifecycle of a heap will result in undefined behavior.
*/
function insert(
${struct} storage self,
${valueType} value,
function(uint256, uint256) view returns (bool) comp
) internal {
${indexType} size = length(self);
if (size == type(${indexType}).max) Panic.panic(Panic.RESOURCE_ERROR);
self.data.push(${struct}Node({index: size, lookup: size, value: value}));
_siftUp(self, size, value, comp);
}
/**
* @dev Return the root element for the heap, and replace it with a new value, using the default comparator.
* This is equivalent to using {pop} and {insert}, but requires only one rebalancing operation.
*
* NOTE: All inserting and removal from a heap should always be done using the same comparator. Mixing comparator
* during the lifecycle of a heap will result in undefined behavior.
*/
function replace(${struct} storage self, ${valueType} newValue) internal returns (${valueType}) {
return replace(self, newValue, Comparators.lt);
}
/**
* @dev Return the root element for the heap, and replace it with a new value, using the provided comparator.
* This is equivalent to using {pop} and {insert}, but requires only one rebalancing operation.
*
* NOTE: All inserting and removal from a heap should always be done using the same comparator. Mixing comparator
* during the lifecycle of a heap will result in undefined behavior.
*/
function replace(
${struct} storage self,
${valueType} newValue,
function(uint256, uint256) view returns (bool) comp
) internal returns (${valueType}) {
${indexType} size = length(self);
if (size == 0) Panic.panic(Panic.EMPTY_ARRAY_POP);
// position of the node that holds the data for the root
${indexType} rootIdx = _unsafeNodeAccess(self, 0).index;
// storage pointer to the node that holds the data for the root
${node} storage rootData = _unsafeNodeAccess(self, rootIdx);
// cache old value and replace it
${valueType} oldValue = rootData.value;
rootData.value = newValue;
// re-heapify
_siftDown(self, size, 0, newValue, comp);
// return old root value
return oldValue;
}
/**
* @dev Returns the number of elements in the heap.
*/
function length(${struct} storage self) internal view returns (${indexType}) {
return self.data.length.to${capitalize(indexType)}();
}
/**
* @dev Removes all elements in the heap.
*/
function clear(${struct} storage self) internal {
${struct}Node[] storage data = self.data;
assembly ("memory-safe") {
sstore(data.slot, 0)
}
}
/**
* @dev Swap node \`i\` and \`j\` in the tree.
*/
function _swap(${struct} storage self, ${indexType} i, ${indexType} j) private {
${node} storage ni = _unsafeNodeAccess(self, i);
${node} storage nj = _unsafeNodeAccess(self, j);
${indexType} ii = ni.index;
${indexType} jj = nj.index;
// update pointers to the data (swap the value)
ni.index = jj;
nj.index = ii;
// update lookup pointers for consistency
_unsafeNodeAccess(self, ii).lookup = j;
_unsafeNodeAccess(self, jj).lookup = i;
}
/**
* @dev Perform heap maintenance on \`self\`, starting at position \`pos\` (with the \`value\`), using \`comp\` as a
* comparator, and moving toward the leaves of the underlying tree.
*
* NOTE: This is a private function that is called in a trusted context with already cached parameters. \`length\`
* and \`value\` could be extracted from \`self\` and \`pos\`, but that would require redundant storage read. These
* parameters are not verified. It is the caller role to make sure the parameters are correct.
*/
function _siftDown(
${struct} storage self,
${indexType} size,
${indexType} pos,
${valueType} value,
function(uint256, uint256) view returns (bool) comp
) private {
uint256 left = 2 * pos + 1; // this could overflow ${indexType}
uint256 right = 2 * pos + 2; // this could overflow ${indexType}
if (right < size) {
// the check guarantees that \`left\` and \`right\` are both valid ${indexType}
${indexType} lIndex = ${indexType}(left);
${indexType} rIndex = ${indexType}(right);
${valueType} lValue = _unsafeNodeAccess(self, _unsafeNodeAccess(self, lIndex).index).value;
${valueType} rValue = _unsafeNodeAccess(self, _unsafeNodeAccess(self, rIndex).index).value;
if (comp(lValue, value) || comp(rValue, value)) {
${indexType} index = ${indexType}(comp(lValue, rValue).ternary(lIndex, rIndex));
_swap(self, pos, index);
_siftDown(self, size, index, value, comp);
}
} else if (left < size) {
// the check guarantees that \`left\` is a valid ${indexType}
${indexType} lIndex = ${indexType}(left);
${valueType} lValue = _unsafeNodeAccess(self, _unsafeNodeAccess(self, lIndex).index).value;
if (comp(lValue, value)) {
_swap(self, pos, lIndex);
_siftDown(self, size, lIndex, value, comp);
}
}
}
/**
* @dev Perform heap maintenance on \`self\`, starting at position \`pos\` (with the \`value\`), using \`comp\` as a
* comparator, and moving toward the root of the underlying tree.
*
* NOTE: This is a private function that is called in a trusted context with already cached parameters. \`value\`
* could be extracted from \`self\` and \`pos\`, but that would require redundant storage read. These parameters are not
* verified. It is the caller role to make sure the parameters are correct.
*/
function _siftUp(
${struct} storage self,
${indexType} pos,
${valueType} value,
function(uint256, uint256) view returns (bool) comp
) private {
unchecked {
while (pos > 0) {
${indexType} parent = (pos - 1) / 2;
${valueType} parentValue = _unsafeNodeAccess(self, _unsafeNodeAccess(self, parent).index).value;
if (comp(parentValue, value)) break;
_swap(self, pos, parent);
pos = parent;
}
}
}
function _unsafeNodeAccess(
${struct} storage self,
${indexType} pos
) private pure returns (${node} storage result) {
assembly ("memory-safe") {
mstore(0x00, self.slot)
result.slot := add(keccak256(0x00, 0x20), ${blockSize == 1 ? 'pos' : `mul(pos, ${blockSize})`})
}
}
`;
// GENERATE
module.exports = format(
header.trimEnd(),
'library Heap {',
format(
[].concat(
'using Math for *;',
'using SafeCast for *;',
'',
TYPES.map(type => generate(type)),
),
).trimEnd(),
'}',
);

View File

@ -1,13 +0,0 @@
const makeType = (valueSize, indexSize) => ({
struct: `Uint${valueSize}Heap`,
node: `Uint${valueSize}HeapNode`,
valueSize,
valueType: `uint${valueSize}`,
indexSize,
indexType: `uint${indexSize}`,
blockSize: Math.ceil((valueSize + 2 * indexSize) / 256),
});
module.exports = {
TYPES: [makeType(256, 64), makeType(208, 24)],
};

View File

@ -1,89 +0,0 @@
const format = require('../format-lines');
const { TYPES } = require('./Heap.opts');
/* eslint-disable max-len */
const header = `\
pragma solidity ^0.8.20;
import {Test} from "forge-std/Test.sol";
import {Math} from "@openzeppelin/contracts/utils/math/Math.sol";
import {Heap} from "@openzeppelin/contracts/utils/structs/Heap.sol";
import {Comparators} from "@openzeppelin/contracts/utils/Comparators.sol";
`;
const generate = ({ struct, valueType }) => `\
contract ${struct}Test is Test {
using Heap for Heap.${struct};
Heap.${struct} internal heap;
function _validateHeap(function(uint256, uint256) view returns (bool) comp) internal {
for (uint32 i = 0; i < heap.length(); ++i) {
// lookups
assertEq(i, heap.data[heap.data[i].index].lookup);
assertEq(i, heap.data[heap.data[i].lookup].index);
// ordering: each node has a value bigger then its parent
if (i > 0)
assertFalse(comp(heap.data[heap.data[i].index].value, heap.data[heap.data[(i - 1) / 2].index].value));
}
}
function testFuzz(${valueType}[] calldata input) public {
vm.assume(input.length < 0x20);
assertEq(heap.length(), 0);
uint256 min = type(uint256).max;
for (uint256 i = 0; i < input.length; ++i) {
heap.insert(input[i]);
assertEq(heap.length(), i + 1);
_validateHeap(Comparators.lt);
min = Math.min(min, input[i]);
assertEq(heap.peek(), min);
}
uint256 max = 0;
for (uint256 i = 0; i < input.length; ++i) {
${valueType} top = heap.peek();
${valueType} pop = heap.pop();
assertEq(heap.length(), input.length - i - 1);
_validateHeap(Comparators.lt);
assertEq(pop, top);
assertGe(pop, max);
max = pop;
}
}
function testFuzzGt(${valueType}[] calldata input) public {
vm.assume(input.length < 0x20);
assertEq(heap.length(), 0);
uint256 max = 0;
for (uint256 i = 0; i < input.length; ++i) {
heap.insert(input[i], Comparators.gt);
assertEq(heap.length(), i + 1);
_validateHeap(Comparators.gt);
max = Math.max(max, input[i]);
assertEq(heap.peek(), max);
}
uint256 min = type(uint256).max;
for (uint256 i = 0; i < input.length; ++i) {
${valueType} top = heap.peek();
${valueType} pop = heap.pop(Comparators.gt);
assertEq(heap.length(), input.length - i - 1);
_validateHeap(Comparators.gt);
assertEq(pop, top);
assertLe(pop, min);
min = pop;
}
}
}
`;
// GENERATE
module.exports = format(header, ...TYPES.map(type => generate(type)));

View File

@ -1,5 +1,4 @@
// SPDX-License-Identifier: MIT
// This file was procedurally generated from scripts/generate/templates/Heap.t.js.
pragma solidity ^0.8.20;
@ -14,14 +13,8 @@ contract Uint256HeapTest is Test {
Heap.Uint256Heap internal heap;
function _validateHeap(function(uint256, uint256) view returns (bool) comp) internal {
for (uint32 i = 0; i < heap.length(); ++i) {
// lookups
assertEq(i, heap.data[heap.data[i].index].lookup);
assertEq(i, heap.data[heap.data[i].lookup].index);
// ordering: each node has a value bigger then its parent
if (i > 0)
assertFalse(comp(heap.data[heap.data[i].index].value, heap.data[heap.data[(i - 1) / 2].index].value));
for (uint32 i = 1; i < heap.length(); ++i) {
assertFalse(comp(heap.tree[i], heap.tree[(i - 1) / 2]));
}
}
@ -79,75 +72,3 @@ contract Uint256HeapTest is Test {
}
}
}
contract Uint208HeapTest is Test {
using Heap for Heap.Uint208Heap;
Heap.Uint208Heap internal heap;
function _validateHeap(function(uint256, uint256) view returns (bool) comp) internal {
for (uint32 i = 0; i < heap.length(); ++i) {
// lookups
assertEq(i, heap.data[heap.data[i].index].lookup);
assertEq(i, heap.data[heap.data[i].lookup].index);
// ordering: each node has a value bigger then its parent
if (i > 0)
assertFalse(comp(heap.data[heap.data[i].index].value, heap.data[heap.data[(i - 1) / 2].index].value));
}
}
function testFuzz(uint208[] calldata input) public {
vm.assume(input.length < 0x20);
assertEq(heap.length(), 0);
uint256 min = type(uint256).max;
for (uint256 i = 0; i < input.length; ++i) {
heap.insert(input[i]);
assertEq(heap.length(), i + 1);
_validateHeap(Comparators.lt);
min = Math.min(min, input[i]);
assertEq(heap.peek(), min);
}
uint256 max = 0;
for (uint256 i = 0; i < input.length; ++i) {
uint208 top = heap.peek();
uint208 pop = heap.pop();
assertEq(heap.length(), input.length - i - 1);
_validateHeap(Comparators.lt);
assertEq(pop, top);
assertGe(pop, max);
max = pop;
}
}
function testFuzzGt(uint208[] calldata input) public {
vm.assume(input.length < 0x20);
assertEq(heap.length(), 0);
uint256 max = 0;
for (uint256 i = 0; i < input.length; ++i) {
heap.insert(input[i], Comparators.gt);
assertEq(heap.length(), i + 1);
_validateHeap(Comparators.gt);
max = Math.max(max, input[i]);
assertEq(heap.peek(), max);
}
uint256 min = type(uint256).max;
for (uint256 i = 0; i < input.length; ++i) {
uint208 top = heap.peek();
uint208 pop = heap.pop(Comparators.gt);
assertEq(heap.length(), input.length - i - 1);
_validateHeap(Comparators.gt);
assertEq(pop, top);
assertLe(pop, min);
min = pop;
}
}
}

View File

@ -3,8 +3,6 @@ const { expect } = require('chai');
const { loadFixture } = require('@nomicfoundation/hardhat-network-helpers');
const { PANIC_CODES } = require('@nomicfoundation/hardhat-chai-matchers/panic');
const { TYPES } = require('../../../scripts/generate/templates/Heap.opts');
async function fixture() {
const mock = await ethers.deployContract('$Heap');
return { mock };
@ -15,63 +13,48 @@ describe('Heap', function () {
Object.assign(this, await loadFixture(fixture));
});
for (const { struct, valueType } of TYPES) {
describe(struct, function () {
const popEvent = `return$pop_Heap_${struct}`;
const replaceEvent = `return$replace_Heap_${struct}_${valueType}`;
beforeEach(async function () {
this.helper = {
clear: (...args) => this.mock[`$clear_Heap_${struct}`](0, ...args),
insert: (...args) => this.mock[`$insert(uint256,${valueType})`](0, ...args),
replace: (...args) => this.mock[`$replace(uint256,${valueType})`](0, ...args),
length: (...args) => this.mock[`$length_Heap_${struct}`](0, ...args),
pop: (...args) => this.mock[`$pop_Heap_${struct}`](0, ...args),
peek: (...args) => this.mock[`$peek_Heap_${struct}`](0, ...args),
};
});
describe('Uint256Heap', function () {
it('starts empty', async function () {
expect(await this.helper.length()).to.equal(0n);
expect(await this.mock.$length(0)).to.equal(0n);
});
it('peek, pop and replace from empty', async function () {
await expect(this.helper.peek()).to.be.revertedWithPanic(PANIC_CODES.ARRAY_ACCESS_OUT_OF_BOUNDS);
await expect(this.helper.pop()).to.be.revertedWithPanic(PANIC_CODES.POP_ON_EMPTY_ARRAY);
await expect(this.helper.replace(0n)).to.be.revertedWithPanic(PANIC_CODES.POP_ON_EMPTY_ARRAY);
await expect(this.mock.$peek(0)).to.be.revertedWithPanic(PANIC_CODES.ARRAY_ACCESS_OUT_OF_BOUNDS);
await expect(this.mock.$pop(0)).to.be.revertedWithPanic(PANIC_CODES.POP_ON_EMPTY_ARRAY);
await expect(this.mock.$replace(0, 0n)).to.be.revertedWithPanic(PANIC_CODES.POP_ON_EMPTY_ARRAY);
});
it('clear', async function () {
await this.helper.insert(42n);
await this.mock.$insert(0, 42n);
expect(await this.helper.length()).to.equal(1n);
expect(await this.helper.peek()).to.equal(42n);
expect(await this.mock.$length(0)).to.equal(1n);
expect(await this.mock.$peek(0)).to.equal(42n);
await this.helper.clear();
await this.mock.$clear(0);
expect(await this.helper.length()).to.equal(0n);
await expect(this.helper.peek()).to.be.revertedWithPanic(PANIC_CODES.ARRAY_ACCESS_OUT_OF_BOUNDS);
expect(await this.mock.$length(0)).to.equal(0n);
await expect(this.mock.$peek(0)).to.be.revertedWithPanic(PANIC_CODES.ARRAY_ACCESS_OUT_OF_BOUNDS);
});
it('support duplicated items', async function () {
expect(await this.helper.length()).to.equal(0n);
expect(await this.mock.$length(0)).to.equal(0n);
// insert 5 times
await this.helper.insert(42n);
await this.helper.insert(42n);
await this.helper.insert(42n);
await this.helper.insert(42n);
await this.helper.insert(42n);
await this.mock.$insert(0, 42n);
await this.mock.$insert(0, 42n);
await this.mock.$insert(0, 42n);
await this.mock.$insert(0, 42n);
await this.mock.$insert(0, 42n);
// pop 5 times
await expect(this.helper.pop()).to.emit(this.mock, popEvent).withArgs(42n);
await expect(this.helper.pop()).to.emit(this.mock, popEvent).withArgs(42n);
await expect(this.helper.pop()).to.emit(this.mock, popEvent).withArgs(42n);
await expect(this.helper.pop()).to.emit(this.mock, popEvent).withArgs(42n);
await expect(this.helper.pop()).to.emit(this.mock, popEvent).withArgs(42n);
await expect(this.mock.$pop(0)).to.emit(this.mock, 'return$pop').withArgs(42n);
await expect(this.mock.$pop(0)).to.emit(this.mock, 'return$pop').withArgs(42n);
await expect(this.mock.$pop(0)).to.emit(this.mock, 'return$pop').withArgs(42n);
await expect(this.mock.$pop(0)).to.emit(this.mock, 'return$pop').withArgs(42n);
await expect(this.mock.$pop(0)).to.emit(this.mock, 'return$pop').withArgs(42n);
// popping a 6th time panics
await expect(this.helper.pop()).to.be.revertedWithPanic(PANIC_CODES.POP_ON_EMPTY_ARRAY);
await expect(this.mock.$pop(0)).to.be.revertedWithPanic(PANIC_CODES.POP_ON_EMPTY_ARRAY);
});
it('insert, pop and replace', async function () {
@ -97,35 +80,34 @@ describe('Heap', function () {
]) {
switch (op) {
case 'insert':
await this.helper.insert(value);
await this.mock.$insert(0, value);
heap.push(value);
heap.sort((a, b) => a - b);
break;
case 'pop':
if (heap.length == 0) {
await expect(this.helper.pop()).to.be.revertedWithPanic(PANIC_CODES.POP_ON_EMPTY_ARRAY);
await expect(this.mock.$pop(0)).to.be.revertedWithPanic(PANIC_CODES.POP_ON_EMPTY_ARRAY);
} else {
await expect(this.helper.pop()).to.emit(this.mock, popEvent).withArgs(heap.shift());
await expect(this.mock.$pop(0)).to.emit(this.mock, 'return$pop').withArgs(heap.shift());
}
break;
case 'replace':
if (heap.length == 0) {
await expect(this.helper.replace(value)).to.be.revertedWithPanic(PANIC_CODES.POP_ON_EMPTY_ARRAY);
await expect(this.mock.$replace(0, value)).to.be.revertedWithPanic(PANIC_CODES.POP_ON_EMPTY_ARRAY);
} else {
await expect(this.helper.replace(value)).to.emit(this.mock, replaceEvent).withArgs(heap.shift());
await expect(this.mock.$replace(0, value)).to.emit(this.mock, 'return$replace').withArgs(heap.shift());
heap.push(value);
heap.sort((a, b) => a - b);
}
break;
}
expect(await this.helper.length()).to.equal(heap.length);
expect(await this.mock.$length(0)).to.equal(heap.length);
if (heap.length == 0) {
await expect(this.helper.peek()).to.be.revertedWithPanic(PANIC_CODES.ARRAY_ACCESS_OUT_OF_BOUNDS);
await expect(this.mock.$peek(0)).to.be.revertedWithPanic(PANIC_CODES.ARRAY_ACCESS_OUT_OF_BOUNDS);
} else {
expect(await this.helper.peek()).to.equal(heap[0]);
expect(await this.mock.$peek(0)).to.equal(heap[0]);
}
}
});
});
}
});