diff --git a/.changeset/dirty-cobras-smile.md b/.changeset/dirty-cobras-smile.md index f03e936b5..d71194cfc 100644 --- a/.changeset/dirty-cobras-smile.md +++ b/.changeset/dirty-cobras-smile.md @@ -2,4 +2,4 @@ 'openzeppelin-solidity': minor --- -`Arrays`: add a `sort` function. +`Arrays`: add a `sort` functions for `address[]`, `bytes32[]` and `uint256[]` memory arrays. diff --git a/contracts/mocks/ArraysMock.sol b/contracts/mocks/ArraysMock.sol index a2fbb6dea..0e39485fc 100644 --- a/contracts/mocks/ArraysMock.sol +++ b/contracts/mocks/ArraysMock.sol @@ -36,6 +36,18 @@ contract Uint256ArraysMock { function unsafeAccess(uint256 pos) external view returns (uint256) { return _array.unsafeAccess(pos).value; } + + function sort(uint256[] memory array) external pure returns (uint256[] memory) { + return array.sort(); + } + + function sortReverse(uint256[] memory array) external pure returns (uint256[] memory) { + return array.sort(_reverse); + } + + function _reverse(uint256 a, uint256 b) private pure returns (bool) { + return a > b; + } } contract AddressArraysMock { @@ -50,6 +62,18 @@ contract AddressArraysMock { function unsafeAccess(uint256 pos) external view returns (address) { return _array.unsafeAccess(pos).value; } + + function sort(address[] memory array) external pure returns (address[] memory) { + return array.sort(); + } + + function sortReverse(address[] memory array) external pure returns (address[] memory) { + return array.sort(_reverse); + } + + function _reverse(address a, address b) private pure returns (bool) { + return uint160(a) > uint160(b); + } } contract Bytes32ArraysMock { @@ -64,4 +88,16 @@ contract Bytes32ArraysMock { function unsafeAccess(uint256 pos) external view returns (bytes32) { return _array.unsafeAccess(pos).value; } + + function sort(bytes32[] memory array) external pure returns (bytes32[] memory) { + return array.sort(); + } + + function sortReverse(bytes32[] memory array) external pure returns (bytes32[] memory) { + return array.sort(_reverse); + } + + function _reverse(bytes32 a, bytes32 b) private pure returns (bool) { + return uint256(a) > uint256(b); + } } diff --git a/contracts/utils/Arrays.sol b/contracts/utils/Arrays.sol index f1a77f371..6b2d75f50 100644 --- a/contracts/utils/Arrays.sol +++ b/contracts/utils/Arrays.sol @@ -13,7 +13,7 @@ library Arrays { using StorageSlot for bytes32; /** - * @dev Sort an array (in memory) in increasing order. + * @dev Sort an array of bytes32 (in memory) following the provided comparator function. * * This function does the sorting "in place", meaning that it overrides the input. The object is returned for * convenience, but that returned value can be discarded safely if the caller has a memory pointer to the array. @@ -23,55 +23,167 @@ library Arrays { * when executing this as part of a transaction. If the array being sorted is too large, the sort operation may * consume more gas than is available in a block, leading to potential DoS. */ - function sort(uint256[] memory array) internal pure returns (uint256[] memory) { - _quickSort(array, 0, array.length); + function sort( + bytes32[] memory array, + function(bytes32, bytes32) pure returns (bool) comp + ) internal pure returns (bytes32[] memory) { + _quickSort(_begin(array), _end(array), comp); return array; } /** - * @dev Performs a quick sort on an array in memory. The array is sorted in increasing order. - * - * Invariant: `i <= j <= array.length`. This is the case when initially called by {sort} and is preserved in - * subcalls. + * @dev Variant of {sort} that sorts an array of bytes32 in increasing order. */ - function _quickSort(uint256[] memory array, uint256 i, uint256 j) private pure { + function sort(bytes32[] memory array) internal pure returns (bytes32[] memory) { + return sort(array, _defaultComp); + } + + /** + * @dev Variant of {sort} that sorts an array of address following a provided comparator function. + */ + function sort( + address[] memory array, + function(address, address) pure returns (bool) comp + ) internal pure returns (address[] memory) { + sort(_castToBytes32Array(array), _castToBytes32Comp(comp)); + return array; + } + + /** + * @dev Variant of {sort} that sorts an array of address in increasing order. + */ + function sort(address[] memory array) internal pure returns (address[] memory) { + sort(_castToBytes32Array(array), _defaultComp); + return array; + } + + /** + * @dev Variant of {sort} that sorts an array of uint256 following a provided comparator function. + */ + function sort( + uint256[] memory array, + function(uint256, uint256) pure returns (bool) comp + ) internal pure returns (uint256[] memory) { + sort(_castToBytes32Array(array), _castToBytes32Comp(comp)); + return array; + } + + /** + * @dev Variant of {sort} that sorts an array of uint256 in increasing order. + */ + function sort(uint256[] memory array) internal pure returns (uint256[] memory) { + sort(_castToBytes32Array(array), _defaultComp); + return array; + } + + /** + * @dev Performs a quick sort of a segment of memory. The segment sorted starts at `begin` (inclusive), and stops + * at end (exclusive). Sorting follows the `comp` comparator. + * + * Invariant: `begin <= end`. This is the case when initially called by {sort} and is preserved in subcalls. + * + * IMPORTANT: Memory locations between `begin` and `end` are not validated/zeroed. This function should + * be used only if the limits are within a memory array. + */ + function _quickSort(uint256 begin, uint256 end, function(bytes32, bytes32) pure returns (bool) comp) private pure { unchecked { - // Can't overflow given `i <= j` - if (j - i < 2) return; + if (end - begin < 0x40) return; // Use first element as pivot - uint256 pivot = unsafeMemoryAccess(array, i); + bytes32 pivot = _mload(begin); // Position where the pivot should be at the end of the loop - uint256 index = i; + uint256 pos = begin; - for (uint256 k = i + 1; k < j; ++k) { - // Unsafe access is safe given `k < j <= array.length`. - if (unsafeMemoryAccess(array, k) < pivot) { - // If array[k] is smaller than the pivot, we increment the index and move array[k] there. - _swap(array, ++index, k); + for (uint256 it = begin + 0x20; it < end; it += 0x20) { + if (comp(_mload(it), pivot)) { + // If the value stored at the iterator's position comes before the pivot, we increment the + // position of the pivot and move the value there. + pos += 0x20; + _swap(pos, it); } } - // Swap pivot into place - _swap(array, i, index); - - _quickSort(array, i, index); // Sort the left side of the pivot - _quickSort(array, index + 1, j); // Sort the right side of the pivot + _swap(begin, pos); // Swap pivot into place + _quickSort(begin, pos, comp); // Sort the left side of the pivot + _quickSort(pos + 0x20, end, comp); // Sort the right side of the pivot } } /** - * @dev Swaps the elements at positions `i` and `j` in the `arr` array. + * @dev Pointer to the memory location of the first element of `array`. */ - function _swap(uint256[] memory arr, uint256 i, uint256 j) private pure { + function _begin(bytes32[] memory array) private pure returns (uint256 ptr) { + /// @solidity memory-safe-assembly assembly { - let start := add(arr, 0x20) // Pointer to the first element of the array - let pos_i := add(start, mul(i, 0x20)) - let pos_j := add(start, mul(j, 0x20)) - let val_i := mload(pos_i) - let val_j := mload(pos_j) - mstore(pos_i, val_j) - mstore(pos_j, val_i) + ptr := add(array, 0x20) + } + } + + /** + * @dev Pointer to the memory location of the first memory word (32bytes) after `array`. This is the memory word + * that comes just after the last element of the array. + */ + function _end(bytes32[] memory array) private pure returns (uint256 ptr) { + unchecked { + return _begin(array) + array.length * 0x20; + } + } + + /** + * @dev Load memory word (as a bytes32) at location `ptr`. + */ + function _mload(uint256 ptr) private pure returns (bytes32 value) { + assembly { + value := mload(ptr) + } + } + + /** + * @dev Swaps the elements memory location `ptr1` and `ptr2`. + */ + function _swap(uint256 ptr1, uint256 ptr2) private pure { + assembly { + let value1 := mload(ptr1) + let value2 := mload(ptr2) + mstore(ptr1, value2) + mstore(ptr2, value1) + } + } + + /// @dev Comparator for sorting arrays in increasing order. + function _defaultComp(bytes32 a, bytes32 b) private pure returns (bool) { + return a < b; + } + + /// @dev Helper: low level cast address memory array to uint256 memory array + function _castToBytes32Array(address[] memory input) private pure returns (bytes32[] memory output) { + assembly { + output := input + } + } + + /// @dev Helper: low level cast uint256 memory array to uint256 memory array + function _castToBytes32Array(uint256[] memory input) private pure returns (bytes32[] memory output) { + assembly { + output := input + } + } + + /// @dev Helper: low level cast address comp function to bytes32 comp function + function _castToBytes32Comp( + function(address, address) pure returns (bool) input + ) private pure returns (function(bytes32, bytes32) pure returns (bool) output) { + assembly { + output := input + } + } + + /// @dev Helper: low level cast uint256 comp function to bytes32 comp function + function _castToBytes32Comp( + function(uint256, uint256) pure returns (bool) input + ) private pure returns (function(bytes32, bytes32) pure returns (bool) output) { + assembly { + output := input } } diff --git a/scripts/helpers.js b/scripts/helpers.js index fb9aad4fc..2780c4edc 100644 --- a/scripts/helpers.js +++ b/scripts/helpers.js @@ -7,11 +7,7 @@ function range(start, stop = undefined, step = 1) { stop = start; start = 0; } - return start < stop - ? Array(Math.ceil((stop - start) / step)) - .fill() - .map((_, i) => start + i * step) - : []; + return start < stop ? Array.from({ length: Math.ceil((stop - start) / step) }, (_, i) => start + i * step) : []; } function unique(array, op = x => x) { @@ -19,9 +15,7 @@ function unique(array, op = x => x) { } function zip(...args) { - return Array(Math.max(...args.map(arg => arg.length))) - .fill(null) - .map((_, i) => args.map(arg => arg[i])); + return Array.from({ length: Math.max(...args.map(arg => arg.length)) }, (_, i) => args.map(arg => arg[i])); } function capitalize(str) { diff --git a/test/finance/VestingWallet.test.js b/test/finance/VestingWallet.test.js index 8ea87b035..318013575 100644 --- a/test/finance/VestingWallet.test.js +++ b/test/finance/VestingWallet.test.js @@ -55,9 +55,7 @@ async function fixture() { }, }; - const schedule = Array(64) - .fill() - .map((_, i) => (BigInt(i) * duration) / 60n + start); + const schedule = Array.from({ length: 64 }, (_, i) => (BigInt(i) * duration) / 60n + start); const vestingFn = timestamp => min(amount, (amount * (timestamp - start)) / duration); diff --git a/test/helpers/iterate.js b/test/helpers/iterate.js index 79d1c8c83..220f2c368 100644 --- a/test/helpers/iterate.js +++ b/test/helpers/iterate.js @@ -5,9 +5,7 @@ const mapValues = (obj, fn) => Object.fromEntries(Object.entries(obj).map(([k, v const product = (...arrays) => arrays.reduce((a, b) => a.flatMap(ai => b.map(bi => [...ai, bi])), [[]]); const unique = (...array) => array.filter((obj, i) => array.indexOf(obj) === i); const zip = (...args) => - Array(Math.max(...args.map(array => array.length))) - .fill() - .map((_, i) => args.map(array => array[i])); + Array.from({ length: Math.max(...args.map(array => array.length)) }, (_, i) => args.map(array => array[i])); module.exports = { mapValues, diff --git a/test/helpers/random.js b/test/helpers/random.js index d47fba544..48b97768b 100644 --- a/test/helpers/random.js +++ b/test/helpers/random.js @@ -1,7 +1,5 @@ const { ethers } = require('hardhat'); -const randomArray = (generator, arrayLength = 3) => Array(arrayLength).fill().map(generator); - const generators = { address: () => ethers.Wallet.createRandom().address, bytes32: () => ethers.hexlify(ethers.randomBytes(32)), @@ -15,6 +13,5 @@ generators.uint256.zero = 0n; generators.hexBytes.zero = '0x'; module.exports = { - randomArray, generators, }; diff --git a/test/utils/Arrays.test.js b/test/utils/Arrays.test.js index ffe5d5a22..8edf75fa7 100644 --- a/test/utils/Arrays.test.js +++ b/test/utils/Arrays.test.js @@ -2,7 +2,7 @@ const { ethers } = require('hardhat'); const { expect } = require('chai'); const { loadFixture } = require('@nomicfoundation/hardhat-network-helpers'); -const { randomArray, generators } = require('../helpers/random'); +const { generators } = require('../helpers/random'); // See https://en.cppreference.com/w/cpp/algorithm/lower_bound const lowerBound = (array, value) => { @@ -16,9 +16,7 @@ const upperBound = (array, value) => { return i == -1 ? array.length : i; }; -// By default, js "sort" cast to string and then sort in alphabetical order. Use this to sort numbers. -const compareNumbers = (a, b) => (a > b ? 1 : a < b ? -1 : 0); - +const bigintSign = x => (x > 0n ? 1 : x < 0n ? -1 : 0); const hasDuplicates = array => array.some((v, i) => array.indexOf(v) != i); describe('Arrays', function () { @@ -30,42 +28,6 @@ describe('Arrays', function () { Object.assign(this, await loadFixture(fixture)); }); - describe('sort', function () { - for (const length of [0, 1, 2, 8, 32, 128]) { - it(`sort array of length ${length}`, async function () { - this.elements = randomArray(generators.uint256, length); - this.expected = Array.from(this.elements).sort(compareNumbers); - }); - - if (length > 1) { - it(`sort array of length ${length} (identical elements)`, async function () { - this.elements = Array(length).fill(generators.uint256()); - this.expected = this.elements; - }); - - it(`sort array of length ${length} (already sorted)`, async function () { - this.elements = randomArray(generators.uint256, length).sort(compareNumbers); - this.expected = this.elements; - }); - - it(`sort array of length ${length} (sorted in reverse order)`, async function () { - this.elements = randomArray(generators.uint256, length).sort(compareNumbers).reverse(); - this.expected = Array.from(this.elements).reverse(); - }); - - it(`sort array of length ${length} (almost sorted)`, async function () { - this.elements = randomArray(generators.uint256, length).sort(compareNumbers); - this.expected = Array.from(this.elements); - // rotate (move the last element to the front) for an almost sorted effect - this.elements.unshift(this.elements.pop()); - }); - } - } - afterEach(async function () { - expect(await this.mock.$sort(this.elements)).to.deep.equal(this.expected); - }); - }); - describe('search', function () { for (const [title, { array, tests }] of Object.entries({ 'Even number of elements': { @@ -154,22 +116,78 @@ describe('Arrays', function () { } }); - describe('unsafeAccess', function () { - for (const [type, { artifact, elements }] of Object.entries({ - address: { artifact: 'AddressArraysMock', elements: randomArray(generators.address, 10) }, - bytes32: { artifact: 'Bytes32ArraysMock', elements: randomArray(generators.bytes32, 10) }, - uint256: { artifact: 'Uint256ArraysMock', elements: randomArray(generators.uint256, 10) }, - })) { - describe(type, function () { - describe('storage', function () { - const fixture = async () => { - return { instance: await ethers.deployContract(artifact, [elements]) }; - }; + for (const [type, { artifact, elements, comp }] of Object.entries({ + address: { + artifact: 'AddressArraysMock', + elements: Array.from({ length: 10 }, generators.address), + comp: (a, b) => bigintSign(ethers.toBigInt(a) - ethers.toBigInt(b)), + }, + bytes32: { + artifact: 'Bytes32ArraysMock', + elements: Array.from({ length: 10 }, generators.bytes32), + comp: (a, b) => bigintSign(ethers.toBigInt(a) - ethers.toBigInt(b)), + }, + uint256: { + artifact: 'Uint256ArraysMock', + elements: Array.from({ length: 10 }, generators.uint256), + comp: (a, b) => bigintSign(a - b), + }, + })) { + describe(type, function () { + const fixture = async () => { + return { instance: await ethers.deployContract(artifact, [elements]) }; + }; - beforeEach(async function () { - Object.assign(this, await loadFixture(fixture)); + beforeEach(async function () { + Object.assign(this, await loadFixture(fixture)); + }); + + describe('sort', function () { + for (const length of [0, 1, 2, 8, 32, 128]) { + describe(`${type}[] of length ${length}`, function () { + beforeEach(async function () { + this.elements = Array.from({ length }, generators[type]); + }); + + afterEach(async function () { + const expected = Array.from(this.elements).sort(comp); + const reversed = Array.from(expected).reverse(); + expect(await this.instance.sort(this.elements)).to.deep.equal(expected); + expect(await this.instance.sortReverse(this.elements)).to.deep.equal(reversed); + }); + + it('sort array', async function () { + // nothing to do here, beforeEach and afterEach already take care of everything. + }); + + if (length > 1) { + it('sort array for identical elements', async function () { + // duplicate the first value to all elements + this.elements.fill(this.elements.at(0)); + }); + + it('sort already sorted array', async function () { + // pre-sort the elements + this.elements.sort(comp); + }); + + it('sort reversed array', async function () { + // pre-sort in reverse order + this.elements.sort(comp).reverse(); + }); + + it('sort almost sorted array', async function () { + // pre-sort + rotate (move the last element to the front) for an almost sorted effect + this.elements.sort(comp); + this.elements.unshift(this.elements.pop()); + }); + } }); + } + }); + describe('unsafeAccess', function () { + describe('storage', function () { for (const i in elements) { it(`unsafeAccess within bounds #${i}`, async function () { expect(await this.instance.unsafeAccess(i)).to.equal(elements[i]); @@ -195,6 +213,6 @@ describe('Arrays', function () { }); }); }); - } - }); + }); + } }); diff --git a/test/utils/math/Math.test.js b/test/utils/math/Math.test.js index 2762fcc57..b75d3b588 100644 --- a/test/utils/math/Math.test.js +++ b/test/utils/math/Math.test.js @@ -5,7 +5,7 @@ const { PANIC_CODES } = require('@nomicfoundation/hardhat-chai-matchers/panic'); const { Rounding } = require('../../helpers/enums'); const { min, max } = require('../../helpers/math'); -const { randomArray, generators } = require('../../helpers/random'); +const { generators } = require('../../helpers/random'); const RoundingDown = [Rounding.Floor, Rounding.Trunc]; const RoundingUp = [Rounding.Ceil, Rounding.Expand]; @@ -337,7 +337,7 @@ describe('Math', function () { }); if (p != 0) { - for (const value of randomArray(generators.uint256, 16)) { + for (const value of Array.from({ length: 16 }, generators.uint256)) { const isInversible = factors.every(f => value % f); it(`trying to inverse ${value}`, async function () { const result = await this.mock.$invMod(value, p); diff --git a/test/utils/structs/DoubleEndedQueue.test.js b/test/utils/structs/DoubleEndedQueue.test.js index 3547072ee..3615dfbf4 100644 --- a/test/utils/structs/DoubleEndedQueue.test.js +++ b/test/utils/structs/DoubleEndedQueue.test.js @@ -8,13 +8,7 @@ async function fixture() { /** Rebuild the content of the deque as a JS array. */ const getContent = () => - mock.$length(0).then(length => - Promise.all( - Array(Number(length)) - .fill() - .map((_, i) => mock.$at(0, i)), - ), - ); + mock.$length(0).then(length => Promise.all(Array.from({ length: Number(length) }, (_, i) => mock.$at(0, i)))); return { mock, getContent }; } diff --git a/test/utils/structs/EnumerableMap.test.js b/test/utils/structs/EnumerableMap.test.js index 717495c26..5362e873a 100644 --- a/test/utils/structs/EnumerableMap.test.js +++ b/test/utils/structs/EnumerableMap.test.js @@ -2,7 +2,7 @@ const { ethers } = require('hardhat'); const { loadFixture } = require('@nomicfoundation/hardhat-network-helpers'); const { mapValues } = require('../../helpers/iterate'); -const { randomArray, generators } = require('../../helpers/random'); +const { generators } = require('../../helpers/random'); const { TYPES, formatType } = require('../../../scripts/generate/templates/EnumerableMap.opts'); const { shouldBehaveLikeMap } = require('./EnumerableMap.behavior'); @@ -17,8 +17,8 @@ async function fixture() { name, { keyType, - keys: randomArray(generators[keyType]), - values: randomArray(generators[valueType]), + keys: Array.from({ length: 3 }, generators[keyType]), + values: Array.from({ length: 3 }, generators[valueType]), zeroValue: generators[valueType].zero, methods: mapValues( { diff --git a/test/utils/structs/EnumerableSet.test.js b/test/utils/structs/EnumerableSet.test.js index db6c5a453..66d666058 100644 --- a/test/utils/structs/EnumerableSet.test.js +++ b/test/utils/structs/EnumerableSet.test.js @@ -2,7 +2,7 @@ const { ethers } = require('hardhat'); const { loadFixture } = require('@nomicfoundation/hardhat-network-helpers'); const { mapValues } = require('../../helpers/iterate'); -const { randomArray, generators } = require('../../helpers/random'); +const { generators } = require('../../helpers/random'); const { TYPES } = require('../../../scripts/generate/templates/EnumerableSet.opts'); const { shouldBehaveLikeSet } = require('./EnumerableSet.behavior'); @@ -23,7 +23,7 @@ async function fixture() { TYPES.map(({ name, type }) => [ type, { - values: randomArray(generators[type]), + values: Array.from({ length: 3 }, generators[type]), methods: getMethods(mock, { add: `$add(uint256,${type})`, remove: `$remove(uint256,${type})`,