diff --git a/.changeset/eight-radios-check.md b/.changeset/eight-radios-check.md new file mode 100644 index 000000000..431c7fcb1 --- /dev/null +++ b/.changeset/eight-radios-check.md @@ -0,0 +1,5 @@ +--- +'openzeppelin-solidity': minor +--- + +`Checkpoints`: Add a new checkpoint variant `Checkpoint256` using `uint256` type for the value and key. diff --git a/contracts/utils/structs/Checkpoints.sol b/contracts/utils/structs/Checkpoints.sol index 431c9acb0..f8cfd20ee 100644 --- a/contracts/utils/structs/Checkpoints.sol +++ b/contracts/utils/structs/Checkpoints.sol @@ -19,6 +19,209 @@ library Checkpoints { */ error CheckpointUnorderedInsertion(); + struct Trace256 { + Checkpoint256[] _checkpoints; + } + + struct Checkpoint256 { + uint256 _key; + uint256 _value; + } + + /** + * @dev Pushes a (`key`, `value`) pair into a Trace256 so that it is stored as the checkpoint. + * + * Returns previous value and new value. + * + * IMPORTANT: Never accept `key` as a user input, since an arbitrary `type(uint256).max` key set will disable the + * library. + */ + function push( + Trace256 storage self, + uint256 key, + uint256 value + ) internal returns (uint256 oldValue, uint256 newValue) { + return _insert(self._checkpoints, key, value); + } + + /** + * @dev Returns the value in the first (oldest) checkpoint with key greater or equal than the search key, or zero if + * there is none. + */ + function lowerLookup(Trace256 storage self, uint256 key) internal view returns (uint256) { + uint256 len = self._checkpoints.length; + uint256 pos = _lowerBinaryLookup(self._checkpoints, key, 0, len); + return pos == len ? 0 : _unsafeAccess(self._checkpoints, pos)._value; + } + + /** + * @dev Returns the value in the last (most recent) checkpoint with key lower or equal than the search key, or zero + * if there is none. + */ + function upperLookup(Trace256 storage self, uint256 key) internal view returns (uint256) { + uint256 len = self._checkpoints.length; + uint256 pos = _upperBinaryLookup(self._checkpoints, key, 0, len); + return pos == 0 ? 0 : _unsafeAccess(self._checkpoints, pos - 1)._value; + } + + /** + * @dev Returns the value in the last (most recent) checkpoint with key lower or equal than the search key, or zero + * if there is none. + * + * NOTE: This is a variant of {upperLookup} that is optimized to find "recent" checkpoint (checkpoints with high + * keys). + */ + function upperLookupRecent(Trace256 storage self, uint256 key) internal view returns (uint256) { + uint256 len = self._checkpoints.length; + + uint256 low = 0; + uint256 high = len; + + if (len > 5) { + uint256 mid = len - Math.sqrt(len); + if (key < _unsafeAccess(self._checkpoints, mid)._key) { + high = mid; + } else { + low = mid + 1; + } + } + + uint256 pos = _upperBinaryLookup(self._checkpoints, key, low, high); + + return pos == 0 ? 0 : _unsafeAccess(self._checkpoints, pos - 1)._value; + } + + /** + * @dev Returns the value in the most recent checkpoint, or zero if there are no checkpoints. + */ + function latest(Trace256 storage self) internal view returns (uint256) { + uint256 pos = self._checkpoints.length; + return pos == 0 ? 0 : _unsafeAccess(self._checkpoints, pos - 1)._value; + } + + /** + * @dev Returns whether there is a checkpoint in the structure (i.e. it is not empty), and if so the key and value + * in the most recent checkpoint. + */ + function latestCheckpoint(Trace256 storage self) internal view returns (bool exists, uint256 _key, uint256 _value) { + uint256 pos = self._checkpoints.length; + if (pos == 0) { + return (false, 0, 0); + } else { + Checkpoint256 storage ckpt = _unsafeAccess(self._checkpoints, pos - 1); + return (true, ckpt._key, ckpt._value); + } + } + + /** + * @dev Returns the number of checkpoints. + */ + function length(Trace256 storage self) internal view returns (uint256) { + return self._checkpoints.length; + } + + /** + * @dev Returns checkpoint at given position. + */ + function at(Trace256 storage self, uint32 pos) internal view returns (Checkpoint256 memory) { + return self._checkpoints[pos]; + } + + /** + * @dev Pushes a (`key`, `value`) pair into an ordered list of checkpoints, either by inserting a new checkpoint, + * or by updating the last one. + */ + function _insert( + Checkpoint256[] storage self, + uint256 key, + uint256 value + ) private returns (uint256 oldValue, uint256 newValue) { + uint256 pos = self.length; + + if (pos > 0) { + Checkpoint256 storage last = _unsafeAccess(self, pos - 1); + uint256 lastKey = last._key; + uint256 lastValue = last._value; + + // Checkpoint keys must be non-decreasing. + if (lastKey > key) { + revert CheckpointUnorderedInsertion(); + } + + // Update or push new checkpoint + if (lastKey == key) { + last._value = value; + } else { + self.push(Checkpoint256({_key: key, _value: value})); + } + return (lastValue, value); + } else { + self.push(Checkpoint256({_key: key, _value: value})); + return (0, value); + } + } + + /** + * @dev Return the index of the first (oldest) checkpoint with key strictly bigger than the search key, or `high` + * if there is none. `low` and `high` define a section where to do the search, with inclusive `low` and exclusive + * `high`. + * + * WARNING: `high` should not be greater than the array's length. + */ + function _upperBinaryLookup( + Checkpoint256[] storage self, + uint256 key, + uint256 low, + uint256 high + ) private view returns (uint256) { + while (low < high) { + uint256 mid = Math.average(low, high); + if (_unsafeAccess(self, mid)._key > key) { + high = mid; + } else { + low = mid + 1; + } + } + return high; + } + + /** + * @dev Return the index of the first (oldest) checkpoint with key greater or equal than the search key, or `high` + * if there is none. `low` and `high` define a section where to do the search, with inclusive `low` and exclusive + * `high`. + * + * WARNING: `high` should not be greater than the array's length. + */ + function _lowerBinaryLookup( + Checkpoint256[] storage self, + uint256 key, + uint256 low, + uint256 high + ) private view returns (uint256) { + while (low < high) { + uint256 mid = Math.average(low, high); + if (_unsafeAccess(self, mid)._key < key) { + low = mid + 1; + } else { + high = mid; + } + } + return high; + } + + /** + * @dev Access an element of the array without performing bounds check. The position is assumed to be within bounds. + */ + function _unsafeAccess( + Checkpoint256[] storage self, + uint256 pos + ) private pure returns (Checkpoint256 storage result) { + assembly { + mstore(0, self.slot) + result.slot := add(keccak256(0, 0x20), mul(pos, 2)) + } + } + struct Trace224 { Checkpoint224[] _checkpoints; } diff --git a/scripts/generate/templates/Checkpoints.js b/scripts/generate/templates/Checkpoints.js index 0aaa18107..000c181b1 100644 --- a/scripts/generate/templates/Checkpoints.js +++ b/scripts/generate/templates/Checkpoints.js @@ -223,7 +223,7 @@ function _unsafeAccess( ) private pure returns (${opts.checkpointTypeName} storage result) { assembly { mstore(0, self.slot) - result.slot := add(keccak256(0, 0x20), pos) + result.slot := add(keccak256(0, 0x20), ${opts.checkpointSize === 1 ? 'pos' : `mul(pos, ${opts.checkpointSize})`}) } } `; diff --git a/scripts/generate/templates/Checkpoints.opts.js b/scripts/generate/templates/Checkpoints.opts.js index 08b7b910b..c0920290a 100644 --- a/scripts/generate/templates/Checkpoints.opts.js +++ b/scripts/generate/templates/Checkpoints.opts.js @@ -1,11 +1,12 @@ // OPTIONS -const VALUE_SIZES = [224, 208, 160]; +const VALUE_SIZES = [256, 224, 208, 160]; const defaultOpts = size => ({ historyTypeName: `Trace${size}`, checkpointTypeName: `Checkpoint${size}`, checkpointFieldName: '_checkpoints', - keyTypeName: `uint${256 - size}`, + checkpointSize: size < 256 ? 1 : 2, + keyTypeName: size < 256 ? `uint${256 - size}` : 'uint256', keyFieldName: '_key', valueTypeName: `uint${size}`, valueFieldName: '_value', diff --git a/scripts/generate/templates/Checkpoints.t.js b/scripts/generate/templates/Checkpoints.t.js index 77a9cd31a..23d9466d2 100644 --- a/scripts/generate/templates/Checkpoints.t.js +++ b/scripts/generate/templates/Checkpoints.t.js @@ -24,7 +24,11 @@ Checkpoints.${opts.historyTypeName} internal _ckpts; function _bound${capitalize(opts.keyTypeName)}(${opts.keyTypeName} x, ${opts.keyTypeName} min, ${ opts.keyTypeName } max) internal pure returns (${opts.keyTypeName}) { - return SafeCast.to${capitalize(opts.keyTypeName)}(bound(uint256(x), uint256(min), uint256(max))); + return ${ + opts.keyTypeName === 'uint256' + ? 'bound(x, min, max)' + : `SafeCast.to${capitalize(opts.keyTypeName)}(bound(uint256(x), uint256(min), uint256(max)))` + }; } function _prepareKeys(${opts.keyTypeName}[] memory keys, ${opts.keyTypeName} maxSpread) internal pure { diff --git a/test/utils/structs/Checkpoints.t.sol b/test/utils/structs/Checkpoints.t.sol index 74d8fb8b8..6ca7faa46 100644 --- a/test/utils/structs/Checkpoints.t.sol +++ b/test/utils/structs/Checkpoints.t.sol @@ -7,6 +7,114 @@ import {Test} from "forge-std/Test.sol"; import {SafeCast} from "@openzeppelin/contracts/utils/math/SafeCast.sol"; import {Checkpoints} from "@openzeppelin/contracts/utils/structs/Checkpoints.sol"; +contract CheckpointsTrace256Test is Test { + using Checkpoints for Checkpoints.Trace256; + + // Maximum gap between keys used during the fuzzing tests: the `_prepareKeys` function will make sure that + // key#n+1 is in the [key#n, key#n + _KEY_MAX_GAP] range. + uint8 internal constant _KEY_MAX_GAP = 64; + + Checkpoints.Trace256 internal _ckpts; + + // helpers + function _boundUint256(uint256 x, uint256 min, uint256 max) internal pure returns (uint256) { + return bound(x, min, max); + } + + function _prepareKeys(uint256[] memory keys, uint256 maxSpread) internal pure { + uint256 lastKey = 0; + for (uint256 i = 0; i < keys.length; ++i) { + uint256 key = _boundUint256(keys[i], lastKey, lastKey + maxSpread); + keys[i] = key; + lastKey = key; + } + } + + function _assertLatestCheckpoint(bool exist, uint256 key, uint256 value) internal view { + (bool _exist, uint256 _key, uint256 _value) = _ckpts.latestCheckpoint(); + assertEq(_exist, exist); + assertEq(_key, key); + assertEq(_value, value); + } + + // tests + function testPush(uint256[] memory keys, uint256[] memory values, uint256 pastKey) public { + vm.assume(values.length > 0 && values.length <= keys.length); + _prepareKeys(keys, _KEY_MAX_GAP); + + // initial state + assertEq(_ckpts.length(), 0); + assertEq(_ckpts.latest(), 0); + _assertLatestCheckpoint(false, 0, 0); + + uint256 duplicates = 0; + for (uint256 i = 0; i < keys.length; ++i) { + uint256 key = keys[i]; + uint256 value = values[i % values.length]; + if (i > 0 && key == keys[i - 1]) ++duplicates; + + // push + _ckpts.push(key, value); + + // check length & latest + assertEq(_ckpts.length(), i + 1 - duplicates); + assertEq(_ckpts.latest(), value); + _assertLatestCheckpoint(true, key, value); + } + + if (keys.length > 0) { + uint256 lastKey = keys[keys.length - 1]; + if (lastKey > 0) { + pastKey = _boundUint256(pastKey, 0, lastKey - 1); + + vm.expectRevert(); + this.push(pastKey, values[keys.length % values.length]); + } + } + } + + // used to test reverts + function push(uint256 key, uint256 value) external { + _ckpts.push(key, value); + } + + function testLookup(uint256[] memory keys, uint256[] memory values, uint256 lookup) public { + vm.assume(values.length > 0 && values.length <= keys.length); + _prepareKeys(keys, _KEY_MAX_GAP); + + uint256 lastKey = keys.length == 0 ? 0 : keys[keys.length - 1]; + lookup = _boundUint256(lookup, 0, lastKey + _KEY_MAX_GAP); + + uint256 upper = 0; + uint256 lower = 0; + uint256 lowerKey = type(uint256).max; + for (uint256 i = 0; i < keys.length; ++i) { + uint256 key = keys[i]; + uint256 value = values[i % values.length]; + + // push + _ckpts.push(key, value); + + // track expected result of lookups + if (key <= lookup) { + upper = value; + } + // find the first key that is not smaller than the lookup key + if (key >= lookup && (i == 0 || keys[i - 1] < lookup)) { + lowerKey = key; + } + if (key == lowerKey) { + lower = value; + } + } + + // check lookup + assertEq(_ckpts.lowerLookup(lookup), lower); + assertEq(_ckpts.upperLookup(lookup), upper); + assertEq(_ckpts.upperLookupRecent(lookup), upper); + } +} + contract CheckpointsTrace224Test is Test { using Checkpoints for Checkpoints.Trace224; diff --git a/test/utils/structs/Checkpoints.test.js b/test/utils/structs/Checkpoints.test.js index fd22544b9..fe055a778 100644 --- a/test/utils/structs/Checkpoints.test.js +++ b/test/utils/structs/Checkpoints.test.js @@ -2,23 +2,24 @@ const { ethers } = require('hardhat'); const { expect } = require('chai'); const { loadFixture } = require('@nomicfoundation/hardhat-network-helpers'); -const { VALUE_SIZES } = require('../../../scripts/generate/templates/Checkpoints.opts'); +const { OPTS } = require('../../../scripts/generate/templates/Checkpoints.opts'); describe('Checkpoints', function () { - for (const length of VALUE_SIZES) { - describe(`Trace${length}`, function () { + for (const opt of OPTS) { + describe(opt.historyTypeName, function () { const fixture = async () => { const mock = await ethers.deployContract('$Checkpoints'); const methods = { - at: (...args) => mock.getFunction(`$at_Checkpoints_Trace${length}`)(0, ...args), - latest: (...args) => mock.getFunction(`$latest_Checkpoints_Trace${length}`)(0, ...args), - latestCheckpoint: (...args) => mock.getFunction(`$latestCheckpoint_Checkpoints_Trace${length}`)(0, ...args), - length: (...args) => mock.getFunction(`$length_Checkpoints_Trace${length}`)(0, ...args), - push: (...args) => mock.getFunction(`$push(uint256,uint${256 - length},uint${length})`)(0, ...args), - lowerLookup: (...args) => mock.getFunction(`$lowerLookup(uint256,uint${256 - length})`)(0, ...args), - upperLookup: (...args) => mock.getFunction(`$upperLookup(uint256,uint${256 - length})`)(0, ...args), + at: (...args) => mock.getFunction(`$at_Checkpoints_${opt.historyTypeName}`)(0, ...args), + latest: (...args) => mock.getFunction(`$latest_Checkpoints_${opt.historyTypeName}`)(0, ...args), + latestCheckpoint: (...args) => + mock.getFunction(`$latestCheckpoint_Checkpoints_${opt.historyTypeName}`)(0, ...args), + length: (...args) => mock.getFunction(`$length_Checkpoints_${opt.historyTypeName}`)(0, ...args), + push: (...args) => mock.getFunction(`$push(uint256,${opt.keyTypeName},${opt.valueTypeName})`)(0, ...args), + lowerLookup: (...args) => mock.getFunction(`$lowerLookup(uint256,${opt.keyTypeName})`)(0, ...args), + upperLookup: (...args) => mock.getFunction(`$upperLookup(uint256,${opt.keyTypeName})`)(0, ...args), upperLookupRecent: (...args) => - mock.getFunction(`$upperLookupRecent(uint256,uint${256 - length})`)(0, ...args), + mock.getFunction(`$upperLookupRecent(uint256,${opt.keyTypeName})`)(0, ...args), }; return { mock, methods };