Add variants of Array.sort for address[] and bytes32[] (#4883)
Co-authored-by: Ernesto García <ernestognw@gmail.com>
This commit is contained in:
@ -13,7 +13,7 @@ library Arrays {
|
||||
using StorageSlot for bytes32;
|
||||
|
||||
/**
|
||||
* @dev Sort an array (in memory) in increasing order.
|
||||
* @dev Sort an array of bytes32 (in memory) following the provided comparator function.
|
||||
*
|
||||
* This function does the sorting "in place", meaning that it overrides the input. The object is returned for
|
||||
* convenience, but that returned value can be discarded safely if the caller has a memory pointer to the array.
|
||||
@ -23,55 +23,167 @@ library Arrays {
|
||||
* when executing this as part of a transaction. If the array being sorted is too large, the sort operation may
|
||||
* consume more gas than is available in a block, leading to potential DoS.
|
||||
*/
|
||||
function sort(uint256[] memory array) internal pure returns (uint256[] memory) {
|
||||
_quickSort(array, 0, array.length);
|
||||
function sort(
|
||||
bytes32[] memory array,
|
||||
function(bytes32, bytes32) pure returns (bool) comp
|
||||
) internal pure returns (bytes32[] memory) {
|
||||
_quickSort(_begin(array), _end(array), comp);
|
||||
return array;
|
||||
}
|
||||
|
||||
/**
|
||||
* @dev Performs a quick sort on an array in memory. The array is sorted in increasing order.
|
||||
*
|
||||
* Invariant: `i <= j <= array.length`. This is the case when initially called by {sort} and is preserved in
|
||||
* subcalls.
|
||||
* @dev Variant of {sort} that sorts an array of bytes32 in increasing order.
|
||||
*/
|
||||
function _quickSort(uint256[] memory array, uint256 i, uint256 j) private pure {
|
||||
function sort(bytes32[] memory array) internal pure returns (bytes32[] memory) {
|
||||
return sort(array, _defaultComp);
|
||||
}
|
||||
|
||||
/**
|
||||
* @dev Variant of {sort} that sorts an array of address following a provided comparator function.
|
||||
*/
|
||||
function sort(
|
||||
address[] memory array,
|
||||
function(address, address) pure returns (bool) comp
|
||||
) internal pure returns (address[] memory) {
|
||||
sort(_castToBytes32Array(array), _castToBytes32Comp(comp));
|
||||
return array;
|
||||
}
|
||||
|
||||
/**
|
||||
* @dev Variant of {sort} that sorts an array of address in increasing order.
|
||||
*/
|
||||
function sort(address[] memory array) internal pure returns (address[] memory) {
|
||||
sort(_castToBytes32Array(array), _defaultComp);
|
||||
return array;
|
||||
}
|
||||
|
||||
/**
|
||||
* @dev Variant of {sort} that sorts an array of uint256 following a provided comparator function.
|
||||
*/
|
||||
function sort(
|
||||
uint256[] memory array,
|
||||
function(uint256, uint256) pure returns (bool) comp
|
||||
) internal pure returns (uint256[] memory) {
|
||||
sort(_castToBytes32Array(array), _castToBytes32Comp(comp));
|
||||
return array;
|
||||
}
|
||||
|
||||
/**
|
||||
* @dev Variant of {sort} that sorts an array of uint256 in increasing order.
|
||||
*/
|
||||
function sort(uint256[] memory array) internal pure returns (uint256[] memory) {
|
||||
sort(_castToBytes32Array(array), _defaultComp);
|
||||
return array;
|
||||
}
|
||||
|
||||
/**
|
||||
* @dev Performs a quick sort of a segment of memory. The segment sorted starts at `begin` (inclusive), and stops
|
||||
* at end (exclusive). Sorting follows the `comp` comparator.
|
||||
*
|
||||
* Invariant: `begin <= end`. This is the case when initially called by {sort} and is preserved in subcalls.
|
||||
*
|
||||
* IMPORTANT: Memory locations between `begin` and `end` are not validated/zeroed. This function should
|
||||
* be used only if the limits are within a memory array.
|
||||
*/
|
||||
function _quickSort(uint256 begin, uint256 end, function(bytes32, bytes32) pure returns (bool) comp) private pure {
|
||||
unchecked {
|
||||
// Can't overflow given `i <= j`
|
||||
if (j - i < 2) return;
|
||||
if (end - begin < 0x40) return;
|
||||
|
||||
// Use first element as pivot
|
||||
uint256 pivot = unsafeMemoryAccess(array, i);
|
||||
bytes32 pivot = _mload(begin);
|
||||
// Position where the pivot should be at the end of the loop
|
||||
uint256 index = i;
|
||||
uint256 pos = begin;
|
||||
|
||||
for (uint256 k = i + 1; k < j; ++k) {
|
||||
// Unsafe access is safe given `k < j <= array.length`.
|
||||
if (unsafeMemoryAccess(array, k) < pivot) {
|
||||
// If array[k] is smaller than the pivot, we increment the index and move array[k] there.
|
||||
_swap(array, ++index, k);
|
||||
for (uint256 it = begin + 0x20; it < end; it += 0x20) {
|
||||
if (comp(_mload(it), pivot)) {
|
||||
// If the value stored at the iterator's position comes before the pivot, we increment the
|
||||
// position of the pivot and move the value there.
|
||||
pos += 0x20;
|
||||
_swap(pos, it);
|
||||
}
|
||||
}
|
||||
|
||||
// Swap pivot into place
|
||||
_swap(array, i, index);
|
||||
|
||||
_quickSort(array, i, index); // Sort the left side of the pivot
|
||||
_quickSort(array, index + 1, j); // Sort the right side of the pivot
|
||||
_swap(begin, pos); // Swap pivot into place
|
||||
_quickSort(begin, pos, comp); // Sort the left side of the pivot
|
||||
_quickSort(pos + 0x20, end, comp); // Sort the right side of the pivot
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @dev Swaps the elements at positions `i` and `j` in the `arr` array.
|
||||
* @dev Pointer to the memory location of the first element of `array`.
|
||||
*/
|
||||
function _swap(uint256[] memory arr, uint256 i, uint256 j) private pure {
|
||||
function _begin(bytes32[] memory array) private pure returns (uint256 ptr) {
|
||||
/// @solidity memory-safe-assembly
|
||||
assembly {
|
||||
let start := add(arr, 0x20) // Pointer to the first element of the array
|
||||
let pos_i := add(start, mul(i, 0x20))
|
||||
let pos_j := add(start, mul(j, 0x20))
|
||||
let val_i := mload(pos_i)
|
||||
let val_j := mload(pos_j)
|
||||
mstore(pos_i, val_j)
|
||||
mstore(pos_j, val_i)
|
||||
ptr := add(array, 0x20)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @dev Pointer to the memory location of the first memory word (32bytes) after `array`. This is the memory word
|
||||
* that comes just after the last element of the array.
|
||||
*/
|
||||
function _end(bytes32[] memory array) private pure returns (uint256 ptr) {
|
||||
unchecked {
|
||||
return _begin(array) + array.length * 0x20;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @dev Load memory word (as a bytes32) at location `ptr`.
|
||||
*/
|
||||
function _mload(uint256 ptr) private pure returns (bytes32 value) {
|
||||
assembly {
|
||||
value := mload(ptr)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @dev Swaps the elements memory location `ptr1` and `ptr2`.
|
||||
*/
|
||||
function _swap(uint256 ptr1, uint256 ptr2) private pure {
|
||||
assembly {
|
||||
let value1 := mload(ptr1)
|
||||
let value2 := mload(ptr2)
|
||||
mstore(ptr1, value2)
|
||||
mstore(ptr2, value1)
|
||||
}
|
||||
}
|
||||
|
||||
/// @dev Comparator for sorting arrays in increasing order.
|
||||
function _defaultComp(bytes32 a, bytes32 b) private pure returns (bool) {
|
||||
return a < b;
|
||||
}
|
||||
|
||||
/// @dev Helper: low level cast address memory array to uint256 memory array
|
||||
function _castToBytes32Array(address[] memory input) private pure returns (bytes32[] memory output) {
|
||||
assembly {
|
||||
output := input
|
||||
}
|
||||
}
|
||||
|
||||
/// @dev Helper: low level cast uint256 memory array to uint256 memory array
|
||||
function _castToBytes32Array(uint256[] memory input) private pure returns (bytes32[] memory output) {
|
||||
assembly {
|
||||
output := input
|
||||
}
|
||||
}
|
||||
|
||||
/// @dev Helper: low level cast address comp function to bytes32 comp function
|
||||
function _castToBytes32Comp(
|
||||
function(address, address) pure returns (bool) input
|
||||
) private pure returns (function(bytes32, bytes32) pure returns (bool) output) {
|
||||
assembly {
|
||||
output := input
|
||||
}
|
||||
}
|
||||
|
||||
/// @dev Helper: low level cast uint256 comp function to bytes32 comp function
|
||||
function _castToBytes32Comp(
|
||||
function(uint256, uint256) pure returns (bool) input
|
||||
) private pure returns (function(bytes32, bytes32) pure returns (bool) output) {
|
||||
assembly {
|
||||
output := input
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user