From 5e3ba29b088f77483d1284b69d5c5927593ab457 Mon Sep 17 00:00:00 2001 From: Renan Souza Date: Wed, 27 Mar 2024 10:17:19 -0300 Subject: [PATCH] Procedurally generate Arrays.sol (#4859) Co-authored-by: ernestognw Co-authored-by: Hadrien Croubois --- contracts/utils/Arrays.sol | 24 +- scripts/generate/run.js | 1 + scripts/generate/templates/Arrays.js | 386 ++++++++++++++++++++++ scripts/generate/templates/Arrays.opts.js | 3 + test/utils/Arrays.test.js | 21 +- 5 files changed, 416 insertions(+), 19 deletions(-) create mode 100644 scripts/generate/templates/Arrays.js create mode 100644 scripts/generate/templates/Arrays.opts.js diff --git a/contracts/utils/Arrays.sol b/contracts/utils/Arrays.sol index 022520598..ac4dcba2c 100644 --- a/contracts/utils/Arrays.sol +++ b/contracts/utils/Arrays.sol @@ -1,5 +1,6 @@ // SPDX-License-Identifier: MIT // OpenZeppelin Contracts (last updated v5.0.0) (utils/Arrays.sol) +// This file was procedurally generated from scripts/generate/templates/Arrays.js. pragma solidity ^0.8.20; @@ -35,11 +36,20 @@ library Arrays { * @dev Variant of {sort} that sorts an array of bytes32 in increasing order. */ function sort(bytes32[] memory array) internal pure returns (bytes32[] memory) { - return sort(array, _defaultComp); + sort(array, _defaultComp); + return array; } /** - * @dev Variant of {sort} that sorts an array of address following a provided comparator function. + * @dev Sort an array of address (in memory) following the provided comparator function. + * + * This function does the sorting "in place", meaning that it overrides the input. The object is returned for + * convenience, but that returned value can be discarded safely if the caller has a memory pointer to the array. + * + * NOTE: this function's cost is `O(n · log(n))` in average and `O(n²)` in the worst case, with n the length of the + * array. 
Using it in view functions that are executed through `eth_call` is safe, but one should be very careful + * when executing this as part of a transaction. If the array being sorted is too large, the sort operation may + * consume more gas than is available in a block, leading to potential DoS. */ function sort( address[] memory array, @@ -58,7 +68,15 @@ library Arrays { } /** - * @dev Variant of {sort} that sorts an array of uint256 following a provided comparator function. + * @dev Sort an array of uint256 (in memory) following the provided comparator function. + * + * This function does the sorting "in place", meaning that it overrides the input. The object is returned for + * convenience, but that returned value can be discarded safely if the caller has a memory pointer to the array. + * + * NOTE: this function's cost is `O(n · log(n))` in average and `O(n²)` in the worst case, with n the length of the + * array. Using it in view functions that are executed through `eth_call` is safe, but one should be very careful + * when executing this as part of a transaction. If the array being sorted is too large, the sort operation may + * consume more gas than is available in a block, leading to potential DoS. 
*/ function sort( uint256[] memory array, diff --git a/scripts/generate/run.js b/scripts/generate/run.js index 53589455a..902f6dd66 100755 --- a/scripts/generate/run.js +++ b/scripts/generate/run.js @@ -37,6 +37,7 @@ for (const [file, template] of Object.entries({ 'utils/structs/EnumerableMap.sol': './templates/EnumerableMap.js', 'utils/structs/Checkpoints.sol': './templates/Checkpoints.js', 'utils/StorageSlot.sol': './templates/StorageSlot.js', + 'utils/Arrays.sol': './templates/Arrays.js', })) { generateFromTemplate(file, template, './contracts/'); } diff --git a/scripts/generate/templates/Arrays.js b/scripts/generate/templates/Arrays.js new file mode 100644 index 000000000..bd0daa1f5 --- /dev/null +++ b/scripts/generate/templates/Arrays.js @@ -0,0 +1,386 @@ +const format = require('../format-lines'); +const { capitalize } = require('../../helpers'); +const { TYPES } = require('./Arrays.opts'); + +const header = `\ +pragma solidity ^0.8.20; + +import {StorageSlot} from "./StorageSlot.sol"; +import {Math} from "./math/Math.sol"; + +/** + * @dev Collection of functions related to array types. + */ +`; + +const sort = type => `\ + /** + * @dev Sort an array of ${type} (in memory) following the provided comparator function. + * + * This function does the sorting "in place", meaning that it overrides the input. The object is returned for + * convenience, but that returned value can be discarded safely if the caller has a memory pointer to the array. + * + * NOTE: this function's cost is \`O(n · log(n))\` in average and \`O(n²)\` in the worst case, with n the length of the + * array. Using it in view functions that are executed through \`eth_call\` is safe, but one should be very careful + * when executing this as part of a transaction. If the array being sorted is too large, the sort operation may + * consume more gas than is available in a block, leading to potential DoS. 
+ */ + function sort( + ${type}[] memory array, + function(${type}, ${type}) pure returns (bool) comp + ) internal pure returns (${type}[] memory) { + ${ + type === 'bytes32' + ? '_quickSort(_begin(array), _end(array), comp);' + : 'sort(_castToBytes32Array(array), _castToBytes32Comp(comp));' + } + return array; + } + + /** + * @dev Variant of {sort} that sorts an array of ${type} in increasing order. + */ + function sort(${type}[] memory array) internal pure returns (${type}[] memory) { + ${type === 'bytes32' ? 'sort(array, _defaultComp);' : 'sort(_castToBytes32Array(array), _defaultComp);'} + return array; + } +`; + +const quickSort = ` +/** + * @dev Performs a quick sort of a segment of memory. The segment sorted starts at \`begin\` (inclusive), and stops + * at end (exclusive). Sorting follows the \`comp\` comparator. + * + * Invariant: \`begin <= end\`. This is the case when initially called by {sort} and is preserved in subcalls. + * + * IMPORTANT: Memory locations between \`begin\` and \`end\` are not validated/zeroed. This function should + * be used only if the limits are within a memory array. + */ +function _quickSort(uint256 begin, uint256 end, function(bytes32, bytes32) pure returns (bool) comp) private pure { + unchecked { + if (end - begin < 0x40) return; + + // Use first element as pivot + bytes32 pivot = _mload(begin); + // Position where the pivot should be at the end of the loop + uint256 pos = begin; + + for (uint256 it = begin + 0x20; it < end; it += 0x20) { + if (comp(_mload(it), pivot)) { + // If the value stored at the iterator's position comes before the pivot, we increment the + // position of the pivot and move the value there. + pos += 0x20; + _swap(pos, it); + } + } + + _swap(begin, pos); // Swap pivot into place + _quickSort(begin, pos, comp); // Sort the left side of the pivot + _quickSort(pos + 0x20, end, comp); // Sort the right side of the pivot + } +} + +/** + * @dev Pointer to the memory location of the first element of \`array\`. 
+ */ +function _begin(bytes32[] memory array) private pure returns (uint256 ptr) { + /// @solidity memory-safe-assembly + assembly { + ptr := add(array, 0x20) + } +} + +/** + * @dev Pointer to the memory location of the first memory word (32bytes) after \`array\`. This is the memory word + * that comes just after the last element of the array. + */ +function _end(bytes32[] memory array) private pure returns (uint256 ptr) { + unchecked { + return _begin(array) + array.length * 0x20; + } +} + +/** + * @dev Load memory word (as a bytes32) at location \`ptr\`. + */ +function _mload(uint256 ptr) private pure returns (bytes32 value) { + assembly { + value := mload(ptr) + } +} + +/** + * @dev Swaps the elements at memory locations \`ptr1\` and \`ptr2\`. + */ +function _swap(uint256 ptr1, uint256 ptr2) private pure { + assembly { + let value1 := mload(ptr1) + let value2 := mload(ptr2) + mstore(ptr1, value2) + mstore(ptr2, value1) + } +} +`; + +const defaultComparator = ` + /// @dev Comparator for sorting arrays in increasing order. + function _defaultComp(bytes32 a, bytes32 b) private pure returns (bool) { + return a < b; + } +`; + +const castArray = type => `\ + /// @dev Helper: low level cast ${type} memory array to bytes32 memory array + function _castToBytes32Array(${type}[] memory input) private pure returns (bytes32[] memory output) { + assembly { + output := input + } + } +`; + +const castComparator = type => `\ + /// @dev Helper: low level cast ${type} comp function to bytes32 comp function + function _castToBytes32Comp( + function(${type}, ${type}) pure returns (bool) input + ) private pure returns (function(bytes32, bytes32) pure returns (bool) output) { + assembly { + output := input + } + } +`; + +const search = ` +/** + * @dev Searches a sorted \`array\` and returns the first index that contains + * a value greater or equal to \`element\`. If no such index exists (i.e.
all + * values in the array are strictly less than \`element\`), the array length is + * returned. Time complexity O(log n). + * + * NOTE: The \`array\` is expected to be sorted in ascending order, and to + * contain no repeated elements. + * + * IMPORTANT: Deprecated. This implementation behaves as {lowerBound} but lacks + * support for repeated elements in the array. The {lowerBound} function should + * be used instead. + */ +function findUpperBound(uint256[] storage array, uint256 element) internal view returns (uint256) { + uint256 low = 0; + uint256 high = array.length; + + if (high == 0) { + return 0; + } + + while (low < high) { + uint256 mid = Math.average(low, high); + + // Note that mid will always be strictly less than high (i.e. it will be a valid array index) + // because Math.average rounds towards zero (it does integer division with truncation). + if (unsafeAccess(array, mid).value > element) { + high = mid; + } else { + low = mid + 1; + } + } + + // At this point \`low\` is the exclusive upper bound. We will return the inclusive upper bound. + if (low > 0 && unsafeAccess(array, low - 1).value == element) { + return low - 1; + } else { + return low; + } +} + +/** + * @dev Searches an \`array\` sorted in ascending order and returns the first + * index that contains a value greater or equal than \`element\`. If no such index + * exists (i.e. all values in the array are strictly less than \`element\`), the array + * length is returned. Time complexity O(log n). + * + * See C++'s https://en.cppreference.com/w/cpp/algorithm/lower_bound[lower_bound]. + */ +function lowerBound(uint256[] storage array, uint256 element) internal view returns (uint256) { + uint256 low = 0; + uint256 high = array.length; + + if (high == 0) { + return 0; + } + + while (low < high) { + uint256 mid = Math.average(low, high); + + // Note that mid will always be strictly less than high (i.e. 
it will be a valid array index) + // because Math.average rounds towards zero (it does integer division with truncation). + if (unsafeAccess(array, mid).value < element) { + // this cannot overflow because mid < high + unchecked { + low = mid + 1; + } + } else { + high = mid; + } + } + + return low; +} + +/** + * @dev Searches an \`array\` sorted in ascending order and returns the first + * index that contains a value strictly greater than \`element\`. If no such index + * exists (i.e. all values in the array are less than or equal to \`element\`), the array + * length is returned. Time complexity O(log n). + * + * See C++'s https://en.cppreference.com/w/cpp/algorithm/upper_bound[upper_bound]. + */ +function upperBound(uint256[] storage array, uint256 element) internal view returns (uint256) { + uint256 low = 0; + uint256 high = array.length; + + if (high == 0) { + return 0; + } + + while (low < high) { + uint256 mid = Math.average(low, high); + + // Note that mid will always be strictly less than high (i.e. it will be a valid array index) + // because Math.average rounds towards zero (it does integer division with truncation). + if (unsafeAccess(array, mid).value > element) { + high = mid; + } else { + // this cannot overflow because mid < high + unchecked { + low = mid + 1; + } + } + } + + return low; +} + +/** + * @dev Same as {lowerBound}, but with an array in memory. + */ +function lowerBoundMemory(uint256[] memory array, uint256 element) internal pure returns (uint256) { + uint256 low = 0; + uint256 high = array.length; + + if (high == 0) { + return 0; + } + + while (low < high) { + uint256 mid = Math.average(low, high); + + // Note that mid will always be strictly less than high (i.e. it will be a valid array index) + // because Math.average rounds towards zero (it does integer division with truncation).
+ if (unsafeMemoryAccess(array, mid) < element) { + // this cannot overflow because mid < high + unchecked { + low = mid + 1; + } + } else { + high = mid; + } + } + + return low; +} + +/** + * @dev Same as {upperBound}, but with an array in memory. + */ +function upperBoundMemory(uint256[] memory array, uint256 element) internal pure returns (uint256) { + uint256 low = 0; + uint256 high = array.length; + + if (high == 0) { + return 0; + } + + while (low < high) { + uint256 mid = Math.average(low, high); + + // Note that mid will always be strictly less than high (i.e. it will be a valid array index) + // because Math.average rounds towards zero (it does integer division with truncation). + if (unsafeMemoryAccess(array, mid) > element) { + high = mid; + } else { + // this cannot overflow because mid < high + unchecked { + low = mid + 1; + } + } + } + + return low; +} +`; + +const unsafeAccessStorage = type => ` +/** +* @dev Access an array in an "unsafe" way. Skips solidity "index-out-of-range" check. +* +* WARNING: Only use if you are certain \`pos\` is lower than the array length. +*/ +function unsafeAccess(${type}[] storage arr, uint256 pos) internal pure returns (StorageSlot.${capitalize( + type, +)}Slot storage) { + bytes32 slot; + // We use assembly to calculate the storage slot of the element at index \`pos\` of the dynamic array \`arr\` + // following https://docs.soliditylang.org/en/v0.8.20/internals/layout_in_storage.html#mappings-and-dynamic-arrays. + + /// @solidity memory-safe-assembly + assembly { + mstore(0, arr.slot) + slot := add(keccak256(0, 0x20), pos) + } + return slot.get${capitalize(type)}Slot(); +}`; + +const unsafeAccessMemory = type => ` +/** + * @dev Access an array in an "unsafe" way. Skips solidity "index-out-of-range" check. + * + * WARNING: Only use if you are certain \`pos\` is lower than the array length. 
+ */ +function unsafeMemoryAccess(${type}[] memory arr, uint256 pos) internal pure returns (${type} res) { + assembly { + res := mload(add(add(arr, 0x20), mul(pos, 0x20))) + } +} +`; + +const unsafeSetLength = type => ` +/** + * @dev Helper to set the length of a dynamic array. Directly writing to \`.length\` is forbidden. + * + * WARNING: this does not clear elements if length is reduced, or initialize elements if length is increased. + */ +function unsafeSetLength(${type}[] storage array, uint256 len) internal { + assembly { + sstore(array.slot, len) + } +}`; + +// GENERATE +module.exports = format( + header.trimEnd(), + 'library Arrays {', + 'using StorageSlot for bytes32;', + // sorting, comparator, helpers and internal + sort('bytes32'), + TYPES.filter(type => type !== 'bytes32').map(sort), + quickSort, + defaultComparator, + TYPES.filter(type => type !== 'bytes32').map(castArray), + TYPES.filter(type => type !== 'bytes32').map(castComparator), + // lookup + search, + // unsafe (direct) storage and memory access + TYPES.map(unsafeAccessStorage), + TYPES.map(unsafeAccessMemory), + TYPES.map(unsafeSetLength), + '}', +); diff --git a/scripts/generate/templates/Arrays.opts.js b/scripts/generate/templates/Arrays.opts.js new file mode 100644 index 000000000..67f329972 --- /dev/null +++ b/scripts/generate/templates/Arrays.opts.js @@ -0,0 +1,3 @@ +const TYPES = ['address', 'bytes32', 'uint256']; + +module.exports = { TYPES }; diff --git a/test/utils/Arrays.test.js b/test/utils/Arrays.test.js index 306851993..bcb385897 100644 --- a/test/utils/Arrays.test.js +++ b/test/utils/Arrays.test.js @@ -3,6 +3,8 @@ const { expect } = require('chai'); const { loadFixture } = require('@nomicfoundation/hardhat-network-helpers'); const { generators } = require('../helpers/random'); +const { capitalize } = require('../../scripts/helpers'); +const { TYPES } = require('../../scripts/generate/templates/Arrays.opts'); // See https://en.cppreference.com/w/cpp/algorithm/lower_bound const
lowerBound = (array, value) => { @@ -117,25 +119,12 @@ describe('Arrays', function () { } }); - for (const [type, { artifact, format }] of Object.entries({ - address: { - artifact: 'AddressArraysMock', - format: x => ethers.getAddress(ethers.toBeHex(x, 20)), - }, - bytes32: { - artifact: 'Bytes32ArraysMock', - format: x => ethers.toBeHex(x, 32), - }, - uint256: { - artifact: 'Uint256ArraysMock', - format: x => ethers.toBigInt(x), - }, - })) { + for (const type of TYPES) { const elements = Array.from({ length: 10 }, generators[type]); describe(type, function () { const fixture = async () => { - return { instance: await ethers.deployContract(artifact, [elements]) }; + return { instance: await ethers.deployContract(`${capitalize(type)}ArraysMock`, [elements]) }; }; beforeEach(async function () { @@ -222,7 +211,7 @@ describe('Arrays', function () { it('unsafeMemoryAccess loop around', async function () { for (let i = 251n; i < 256n; ++i) { - expect(await this.mock[fragment](elements, 2n ** i - 1n)).to.equal(format(elements.length)); + expect(await this.mock[fragment](elements, 2n ** i - 1n)).to.equal(BigInt(elements.length)); expect(await this.mock[fragment](elements, 2n ** i + 0n)).to.equal(elements[0]); expect(await this.mock[fragment](elements, 2n ** i + 1n)).to.equal(elements[1]); }