Add support for more types in Arrays.sol (#5568)
This commit is contained in:
5
.changeset/rare-shirts-unite.md
Normal file
5
.changeset/rare-shirts-unite.md
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
---
|
||||||
|
'openzeppelin-solidity': minor
|
||||||
|
---
|
||||||
|
|
||||||
|
`Arrays`: Add `unsafeAccess`, `unsafeMemoryAccess` and `unsafeSetLength` for `bytes[]` and `string[]`.
|
||||||
@ -125,3 +125,47 @@ contract Bytes32ArraysMock {
|
|||||||
return _array.length;
|
return _array.length;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
contract BytesArraysMock {
|
||||||
|
using Arrays for bytes[];
|
||||||
|
|
||||||
|
bytes[] private _array;
|
||||||
|
|
||||||
|
constructor(bytes[] memory array) {
|
||||||
|
_array = array;
|
||||||
|
}
|
||||||
|
|
||||||
|
function unsafeAccess(uint256 pos) external view returns (bytes memory) {
|
||||||
|
return _array.unsafeAccess(pos).value;
|
||||||
|
}
|
||||||
|
|
||||||
|
function unsafeSetLength(uint256 newLength) external {
|
||||||
|
_array.unsafeSetLength(newLength);
|
||||||
|
}
|
||||||
|
|
||||||
|
function length() external view returns (uint256) {
|
||||||
|
return _array.length;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
contract StringArraysMock {
|
||||||
|
using Arrays for string[];
|
||||||
|
|
||||||
|
string[] private _array;
|
||||||
|
|
||||||
|
constructor(string[] memory array) {
|
||||||
|
_array = array;
|
||||||
|
}
|
||||||
|
|
||||||
|
function unsafeAccess(uint256 pos) external view returns (string memory) {
|
||||||
|
return _array.unsafeAccess(pos).value;
|
||||||
|
}
|
||||||
|
|
||||||
|
function unsafeSetLength(uint256 newLength) external {
|
||||||
|
_array.unsafeSetLength(newLength);
|
||||||
|
}
|
||||||
|
|
||||||
|
function length() external view returns (uint256) {
|
||||||
|
return _array.length;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@ -414,6 +414,32 @@ library Arrays {
|
|||||||
return slot.deriveArray().offset(pos).getUint256Slot();
|
return slot.deriveArray().offset(pos).getUint256Slot();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @dev Access an array in an "unsafe" way. Skips solidity "index-out-of-range" check.
|
||||||
|
*
|
||||||
|
* WARNING: Only use if you are certain `pos` is lower than the array length.
|
||||||
|
*/
|
||||||
|
function unsafeAccess(bytes[] storage arr, uint256 pos) internal pure returns (StorageSlot.BytesSlot storage) {
|
||||||
|
bytes32 slot;
|
||||||
|
assembly ("memory-safe") {
|
||||||
|
slot := arr.slot
|
||||||
|
}
|
||||||
|
return slot.deriveArray().offset(pos).getBytesSlot();
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @dev Access an array in an "unsafe" way. Skips solidity "index-out-of-range" check.
|
||||||
|
*
|
||||||
|
* WARNING: Only use if you are certain `pos` is lower than the array length.
|
||||||
|
*/
|
||||||
|
function unsafeAccess(string[] storage arr, uint256 pos) internal pure returns (StorageSlot.StringSlot storage) {
|
||||||
|
bytes32 slot;
|
||||||
|
assembly ("memory-safe") {
|
||||||
|
slot := arr.slot
|
||||||
|
}
|
||||||
|
return slot.deriveArray().offset(pos).getStringSlot();
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @dev Access an array in an "unsafe" way. Skips solidity "index-out-of-range" check.
|
* @dev Access an array in an "unsafe" way. Skips solidity "index-out-of-range" check.
|
||||||
*
|
*
|
||||||
@ -447,6 +473,28 @@ library Arrays {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @dev Access an array in an "unsafe" way. Skips solidity "index-out-of-range" check.
|
||||||
|
*
|
||||||
|
* WARNING: Only use if you are certain `pos` is lower than the array length.
|
||||||
|
*/
|
||||||
|
function unsafeMemoryAccess(bytes[] memory arr, uint256 pos) internal pure returns (bytes memory res) {
|
||||||
|
assembly {
|
||||||
|
res := mload(add(add(arr, 0x20), mul(pos, 0x20)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @dev Access an array in an "unsafe" way. Skips solidity "index-out-of-range" check.
|
||||||
|
*
|
||||||
|
* WARNING: Only use if you are certain `pos` is lower than the array length.
|
||||||
|
*/
|
||||||
|
function unsafeMemoryAccess(string[] memory arr, uint256 pos) internal pure returns (string memory res) {
|
||||||
|
assembly {
|
||||||
|
res := mload(add(add(arr, 0x20), mul(pos, 0x20)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @dev Helper to set the length of a dynamic array. Directly writing to `.length` is forbidden.
|
* @dev Helper to set the length of a dynamic array. Directly writing to `.length` is forbidden.
|
||||||
*
|
*
|
||||||
@ -479,4 +527,26 @@ library Arrays {
|
|||||||
sstore(array.slot, len)
|
sstore(array.slot, len)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @dev Helper to set the length of a dynamic array. Directly writing to `.length` is forbidden.
|
||||||
|
*
|
||||||
|
* WARNING: this does not clear elements if length is reduced, of initialize elements if length is increased.
|
||||||
|
*/
|
||||||
|
function unsafeSetLength(bytes[] storage array, uint256 len) internal {
|
||||||
|
assembly ("memory-safe") {
|
||||||
|
sstore(array.slot, len)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @dev Helper to set the length of a dynamic array. Directly writing to `.length` is forbidden.
|
||||||
|
*
|
||||||
|
* WARNING: this does not clear elements if length is reduced, of initialize elements if length is increased.
|
||||||
|
*/
|
||||||
|
function unsafeSetLength(string[] storage array, uint256 len) internal {
|
||||||
|
assembly ("memory-safe") {
|
||||||
|
sstore(array.slot, len)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -17,7 +17,7 @@ import {Math} from "./math/Math.sol";
|
|||||||
|
|
||||||
const sort = type => `\
|
const sort = type => `\
|
||||||
/**
|
/**
|
||||||
* @dev Sort an array of ${type} (in memory) following the provided comparator function.
|
* @dev Sort an array of ${type.name} (in memory) following the provided comparator function.
|
||||||
*
|
*
|
||||||
* This function does the sorting "in place", meaning that it overrides the input. The object is returned for
|
* This function does the sorting "in place", meaning that it overrides the input. The object is returned for
|
||||||
* convenience, but that returned value can be discarded safely if the caller has a memory pointer to the array.
|
* convenience, but that returned value can be discarded safely if the caller has a memory pointer to the array.
|
||||||
@ -30,11 +30,11 @@ const sort = type => `\
|
|||||||
* IMPORTANT: Consider memory side-effects when using custom comparator functions that access memory in an unsafe way.
|
* IMPORTANT: Consider memory side-effects when using custom comparator functions that access memory in an unsafe way.
|
||||||
*/
|
*/
|
||||||
function sort(
|
function sort(
|
||||||
${type}[] memory array,
|
${type.name}[] memory array,
|
||||||
function(${type}, ${type}) pure returns (bool) comp
|
function(${type.name}, ${type.name}) pure returns (bool) comp
|
||||||
) internal pure returns (${type}[] memory) {
|
) internal pure returns (${type.name}[] memory) {
|
||||||
${
|
${
|
||||||
type === 'uint256'
|
type.name === 'uint256'
|
||||||
? '_quickSort(_begin(array), _end(array), comp);'
|
? '_quickSort(_begin(array), _end(array), comp);'
|
||||||
: 'sort(_castToUint256Array(array), _castToUint256Comp(comp));'
|
: 'sort(_castToUint256Array(array), _castToUint256Comp(comp));'
|
||||||
}
|
}
|
||||||
@ -42,10 +42,10 @@ function sort(
|
|||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @dev Variant of {sort} that sorts an array of ${type} in increasing order.
|
* @dev Variant of {sort} that sorts an array of ${type.name} in increasing order.
|
||||||
*/
|
*/
|
||||||
function sort(${type}[] memory array) internal pure returns (${type}[] memory) {
|
function sort(${type.name}[] memory array) internal pure returns (${type.name}[] memory) {
|
||||||
${type === 'uint256' ? 'sort(array, Comparators.lt);' : 'sort(_castToUint256Array(array), Comparators.lt);'}
|
${type.name === 'uint256' ? 'sort(array, Comparators.lt);' : 'sort(_castToUint256Array(array), Comparators.lt);'}
|
||||||
return array;
|
return array;
|
||||||
}
|
}
|
||||||
`;
|
`;
|
||||||
@ -126,8 +126,8 @@ function _swap(uint256 ptr1, uint256 ptr2) private pure {
|
|||||||
`;
|
`;
|
||||||
|
|
||||||
const castArray = type => `\
|
const castArray = type => `\
|
||||||
/// @dev Helper: low level cast ${type} memory array to uint256 memory array
|
/// @dev Helper: low level cast ${type.name} memory array to uint256 memory array
|
||||||
function _castToUint256Array(${type}[] memory input) private pure returns (uint256[] memory output) {
|
function _castToUint256Array(${type.name}[] memory input) private pure returns (uint256[] memory output) {
|
||||||
assembly {
|
assembly {
|
||||||
output := input
|
output := input
|
||||||
}
|
}
|
||||||
@ -135,9 +135,9 @@ function _castToUint256Array(${type}[] memory input) private pure returns (uint2
|
|||||||
`;
|
`;
|
||||||
|
|
||||||
const castComparator = type => `\
|
const castComparator = type => `\
|
||||||
/// @dev Helper: low level cast ${type} comp function to uint256 comp function
|
/// @dev Helper: low level cast ${type.name} comp function to uint256 comp function
|
||||||
function _castToUint256Comp(
|
function _castToUint256Comp(
|
||||||
function(${type}, ${type}) pure returns (bool) input
|
function(${type.name}, ${type.name}) pure returns (bool) input
|
||||||
) private pure returns (function(uint256, uint256) pure returns (bool) output) {
|
) private pure returns (function(uint256, uint256) pure returns (bool) output) {
|
||||||
assembly {
|
assembly {
|
||||||
output := input
|
output := input
|
||||||
@ -320,14 +320,14 @@ const unsafeAccessStorage = type => `\
|
|||||||
*
|
*
|
||||||
* WARNING: Only use if you are certain \`pos\` is lower than the array length.
|
* WARNING: Only use if you are certain \`pos\` is lower than the array length.
|
||||||
*/
|
*/
|
||||||
function unsafeAccess(${type}[] storage arr, uint256 pos) internal pure returns (StorageSlot.${capitalize(
|
function unsafeAccess(${type.name}[] storage arr, uint256 pos) internal pure returns (StorageSlot.${capitalize(
|
||||||
type,
|
type.name,
|
||||||
)}Slot storage) {
|
)}Slot storage) {
|
||||||
bytes32 slot;
|
bytes32 slot;
|
||||||
assembly ("memory-safe") {
|
assembly ("memory-safe") {
|
||||||
slot := arr.slot
|
slot := arr.slot
|
||||||
}
|
}
|
||||||
return slot.deriveArray().offset(pos).get${capitalize(type)}Slot();
|
return slot.deriveArray().offset(pos).get${capitalize(type.name)}Slot();
|
||||||
}
|
}
|
||||||
`;
|
`;
|
||||||
|
|
||||||
@ -337,7 +337,9 @@ const unsafeAccessMemory = type => `\
|
|||||||
*
|
*
|
||||||
* WARNING: Only use if you are certain \`pos\` is lower than the array length.
|
* WARNING: Only use if you are certain \`pos\` is lower than the array length.
|
||||||
*/
|
*/
|
||||||
function unsafeMemoryAccess(${type}[] memory arr, uint256 pos) internal pure returns (${type} res) {
|
function unsafeMemoryAccess(${type.name}[] memory arr, uint256 pos) internal pure returns (${type.name}${
|
||||||
|
type.isValueType ? '' : ' memory'
|
||||||
|
} res) {
|
||||||
assembly {
|
assembly {
|
||||||
res := mload(add(add(arr, 0x20), mul(pos, 0x20)))
|
res := mload(add(add(arr, 0x20), mul(pos, 0x20)))
|
||||||
}
|
}
|
||||||
@ -350,7 +352,7 @@ const unsafeSetLength = type => `\
|
|||||||
*
|
*
|
||||||
* WARNING: this does not clear elements if length is reduced, of initialize elements if length is increased.
|
* WARNING: this does not clear elements if length is reduced, of initialize elements if length is increased.
|
||||||
*/
|
*/
|
||||||
function unsafeSetLength(${type}[] storage array, uint256 len) internal {
|
function unsafeSetLength(${type.name}[] storage array, uint256 len) internal {
|
||||||
assembly ("memory-safe") {
|
assembly ("memory-safe") {
|
||||||
sstore(array.slot, len)
|
sstore(array.slot, len)
|
||||||
}
|
}
|
||||||
@ -367,11 +369,11 @@ module.exports = format(
|
|||||||
'using StorageSlot for bytes32;',
|
'using StorageSlot for bytes32;',
|
||||||
'',
|
'',
|
||||||
// sorting, comparator, helpers and internal
|
// sorting, comparator, helpers and internal
|
||||||
sort('uint256'),
|
sort({ name: 'uint256' }),
|
||||||
TYPES.filter(type => type !== 'uint256').map(sort),
|
TYPES.filter(type => type.isValueType && type.name !== 'uint256').map(sort),
|
||||||
quickSort,
|
quickSort,
|
||||||
TYPES.filter(type => type !== 'uint256').map(castArray),
|
TYPES.filter(type => type.isValueType && type.name !== 'uint256').map(castArray),
|
||||||
TYPES.filter(type => type !== 'uint256').map(castComparator),
|
TYPES.filter(type => type.isValueType && type.name !== 'uint256').map(castComparator),
|
||||||
// lookup
|
// lookup
|
||||||
search,
|
search,
|
||||||
// unsafe (direct) storage and memory access
|
// unsafe (direct) storage and memory access
|
||||||
|
|||||||
@ -1,3 +1,9 @@
|
|||||||
const TYPES = ['address', 'bytes32', 'uint256'];
|
const TYPES = [
|
||||||
|
{ name: 'address', isValueType: true },
|
||||||
|
{ name: 'bytes32', isValueType: true },
|
||||||
|
{ name: 'uint256', isValueType: true },
|
||||||
|
{ name: 'bytes', isValueType: false },
|
||||||
|
{ name: 'string', isValueType: false },
|
||||||
|
];
|
||||||
|
|
||||||
module.exports = { TYPES };
|
module.exports = { TYPES };
|
||||||
|
|||||||
@ -5,14 +5,19 @@ const generators = {
|
|||||||
bytes32: () => ethers.hexlify(ethers.randomBytes(32)),
|
bytes32: () => ethers.hexlify(ethers.randomBytes(32)),
|
||||||
uint256: () => ethers.toBigInt(ethers.randomBytes(32)),
|
uint256: () => ethers.toBigInt(ethers.randomBytes(32)),
|
||||||
int256: () => ethers.toBigInt(ethers.randomBytes(32)) + ethers.MinInt256,
|
int256: () => ethers.toBigInt(ethers.randomBytes(32)) + ethers.MinInt256,
|
||||||
hexBytes: length => ethers.hexlify(ethers.randomBytes(length)),
|
bytes: (length = 32) => ethers.hexlify(ethers.randomBytes(length)),
|
||||||
|
string: () => ethers.uuidV4(ethers.randomBytes(32)),
|
||||||
};
|
};
|
||||||
|
|
||||||
generators.address.zero = ethers.ZeroAddress;
|
generators.address.zero = ethers.ZeroAddress;
|
||||||
generators.bytes32.zero = ethers.ZeroHash;
|
generators.bytes32.zero = ethers.ZeroHash;
|
||||||
generators.uint256.zero = 0n;
|
generators.uint256.zero = 0n;
|
||||||
generators.int256.zero = 0n;
|
generators.int256.zero = 0n;
|
||||||
generators.hexBytes.zero = '0x';
|
generators.bytes.zero = '0x';
|
||||||
|
generators.string.zero = '';
|
||||||
|
|
||||||
|
// alias hexBytes -> bytes
|
||||||
|
generators.hexBytes = generators.bytes;
|
||||||
|
|
||||||
module.exports = {
|
module.exports = {
|
||||||
generators,
|
generators,
|
||||||
|
|||||||
@ -119,23 +119,24 @@ describe('Arrays', function () {
|
|||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
for (const type of TYPES) {
|
for (const { name, isValueType } of TYPES) {
|
||||||
const elements = Array.from({ length: 10 }, generators[type]);
|
const elements = Array.from({ length: 10 }, generators[name]);
|
||||||
|
|
||||||
describe(type, function () {
|
describe(name, function () {
|
||||||
const fixture = async () => {
|
const fixture = async () => {
|
||||||
return { instance: await ethers.deployContract(`${capitalize(type)}ArraysMock`, [elements]) };
|
return { instance: await ethers.deployContract(`${capitalize(name)}ArraysMock`, [elements]) };
|
||||||
};
|
};
|
||||||
|
|
||||||
beforeEach(async function () {
|
beforeEach(async function () {
|
||||||
Object.assign(this, await loadFixture(fixture));
|
Object.assign(this, await loadFixture(fixture));
|
||||||
});
|
});
|
||||||
|
|
||||||
|
if (isValueType) {
|
||||||
describe('sort', function () {
|
describe('sort', function () {
|
||||||
for (const length of [0, 1, 2, 8, 32, 128]) {
|
for (const length of [0, 1, 2, 8, 32, 128]) {
|
||||||
describe(`${type}[] of length ${length}`, function () {
|
describe(`${name}[] of length ${length}`, function () {
|
||||||
beforeEach(async function () {
|
beforeEach(async function () {
|
||||||
this.array = Array.from({ length }, generators[type]);
|
this.array = Array.from({ length }, generators[name]);
|
||||||
});
|
});
|
||||||
|
|
||||||
afterEach(async function () {
|
afterEach(async function () {
|
||||||
@ -174,6 +175,7 @@ describe('Arrays', function () {
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
}
|
||||||
|
|
||||||
describe('unsafeAccess', function () {
|
describe('unsafeAccess', function () {
|
||||||
describe('storage', function () {
|
describe('storage', function () {
|
||||||
@ -197,7 +199,7 @@ describe('Arrays', function () {
|
|||||||
});
|
});
|
||||||
|
|
||||||
describe('memory', function () {
|
describe('memory', function () {
|
||||||
const fragment = `$unsafeMemoryAccess(${type}[] arr, uint256 pos)`;
|
const fragment = `$unsafeMemoryAccess(${name}[] arr, uint256 pos)`;
|
||||||
|
|
||||||
for (const i in elements) {
|
for (const i in elements) {
|
||||||
it(`unsafeMemoryAccess within bounds #${i}`, async function () {
|
it(`unsafeMemoryAccess within bounds #${i}`, async function () {
|
||||||
@ -211,7 +213,9 @@ describe('Arrays', function () {
|
|||||||
|
|
||||||
it('unsafeMemoryAccess loop around', async function () {
|
it('unsafeMemoryAccess loop around', async function () {
|
||||||
for (let i = 251n; i < 256n; ++i) {
|
for (let i = 251n; i < 256n; ++i) {
|
||||||
expect(await this.mock[fragment](elements, 2n ** i - 1n)).to.equal(BigInt(elements.length));
|
expect(await this.mock[fragment](elements, 2n ** i - 1n)).to.equal(
|
||||||
|
isValueType ? BigInt(elements.length) : generators[name].zero,
|
||||||
|
);
|
||||||
expect(await this.mock[fragment](elements, 2n ** i + 0n)).to.equal(elements[0]);
|
expect(await this.mock[fragment](elements, 2n ** i + 0n)).to.equal(elements[0]);
|
||||||
expect(await this.mock[fragment](elements, 2n ** i + 1n)).to.equal(elements[1]);
|
expect(await this.mock[fragment](elements, 2n ** i + 1n)).to.equal(elements[1]);
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user