Add non-value types in EnumerableSet and EnumerableMap (#5658)

Co-authored-by: Hadrien Croubois <hadrien.croubois@gmail.com>
This commit is contained in:
Ernesto García
2025-06-03 08:26:06 -06:00
committed by GitHub
parent 4bafedfe72
commit 784d4f71b1
13 changed files with 770 additions and 101 deletions

View File

@ -0,0 +1,5 @@
---
'openzeppelin-solidity': minor
---
`EnumerableMap`: Add support for `BytesToBytesMap` type.

View File

@ -0,0 +1,5 @@
---
'openzeppelin-solidity': minor
---
`EnumerableSet`: Add support for `StringSet` and `BytesSet` types.

View File

@ -39,6 +39,7 @@ import {EnumerableSet} from "./EnumerableSet.sol";
* - `address -> address` (`AddressToAddressMap`) since v5.1.0
* - `address -> bytes32` (`AddressToBytes32Map`) since v5.1.0
* - `bytes32 -> address` (`Bytes32ToAddressMap`) since v5.1.0
* - `bytes -> bytes` (`BytesToBytesMap`) since v5.4.0
*
* [WARNING]
* ====
@ -51,7 +52,7 @@ import {EnumerableSet} from "./EnumerableSet.sol";
* ====
*/
library EnumerableMap {
using EnumerableSet for EnumerableSet.Bytes32Set;
using EnumerableSet for *;
// To implement this library for multiple types with as little code repetition as possible, we write it in
// terms of a generic Map type with bytes32 keys and values. The Map implementation uses private functions,
@ -997,4 +998,122 @@ library EnumerableMap {
return result;
}
/**
* @dev Query for a nonexistent map key.
*/
error EnumerableMapNonexistentBytesKey(bytes key);
struct BytesToBytesMap {
// Storage of keys
EnumerableSet.BytesSet _keys;
mapping(bytes key => bytes) _values;
}
/**
* @dev Adds a key-value pair to a map, or updates the value for an existing
* key. O(1).
*
* Returns true if the key was added to the map, that is if it was not
* already present.
*/
function set(BytesToBytesMap storage map, bytes memory key, bytes memory value) internal returns (bool) {
map._values[key] = value;
return map._keys.add(key);
}
/**
* @dev Removes a key-value pair from a map. O(1).
*
* Returns true if the key was removed from the map, that is if it was present.
*/
function remove(BytesToBytesMap storage map, bytes memory key) internal returns (bool) {
delete map._values[key];
return map._keys.remove(key);
}
/**
* @dev Removes all the entries from a map. O(n).
*
* WARNING: Developers should keep in mind that this function has an unbounded cost and using it may render the
* function uncallable if the map grows to the point where clearing it consumes too much gas to fit in a block.
*/
function clear(BytesToBytesMap storage map) internal {
uint256 len = length(map);
for (uint256 i = 0; i < len; ++i) {
delete map._values[map._keys.at(i)];
}
map._keys.clear();
}
/**
* @dev Returns true if the key is in the map. O(1).
*/
function contains(BytesToBytesMap storage map, bytes memory key) internal view returns (bool) {
return map._keys.contains(key);
}
/**
* @dev Returns the number of key-value pairs in the map. O(1).
*/
function length(BytesToBytesMap storage map) internal view returns (uint256) {
return map._keys.length();
}
/**
* @dev Returns the key-value pair stored at position `index` in the map. O(1).
*
* Note that there are no guarantees on the ordering of entries inside the
* array, and it may change when more entries are added or removed.
*
* Requirements:
*
* - `index` must be strictly less than {length}.
*/
function at(
BytesToBytesMap storage map,
uint256 index
) internal view returns (bytes memory key, bytes memory value) {
key = map._keys.at(index);
value = map._values[key];
}
/**
* @dev Tries to returns the value associated with `key`. O(1).
* Does not revert if `key` is not in the map.
*/
function tryGet(
BytesToBytesMap storage map,
bytes memory key
) internal view returns (bool exists, bytes memory value) {
value = map._values[key];
exists = bytes(value).length != 0 || contains(map, key);
}
/**
* @dev Returns the value associated with `key`. O(1).
*
* Requirements:
*
* - `key` must be in the map.
*/
function get(BytesToBytesMap storage map, bytes memory key) internal view returns (bytes memory value) {
bool exists;
(exists, value) = tryGet(map, key);
if (!exists) {
revert EnumerableMapNonexistentBytesKey(key);
}
}
/**
* @dev Return the an array containing all the keys
*
* WARNING: This operation will copy the entire storage to memory, which can be quite expensive. This is designed
* to mostly be used by view accessors that are queried without any gas fees. Developers should keep in mind that
* this function has an unbounded cost, and using it as part of a state-changing function may render the function
* uncallable if the map grows to a point where copying to memory consumes too much gas to fit in a block.
*/
function keys(BytesToBytesMap storage map) internal view returns (bytes[] memory) {
return map._keys.values();
}
}

View File

@ -28,8 +28,13 @@ import {Arrays} from "../Arrays.sol";
* }
* ```
*
* As of v3.3.0, sets of type `bytes32` (`Bytes32Set`), `address` (`AddressSet`)
* and `uint256` (`UintSet`) are supported.
* The following types are supported:
*
* - `bytes32` (`Bytes32Set`) since v3.3.0
* - `address` (`AddressSet`) since v3.3.0
* - `uint256` (`UintSet`) since v3.3.0
* - `string` (`StringSet`) since v5.4.0
* - `bytes` (`BytesSet`) since v5.4.0
*
* [WARNING]
* ====
@ -419,4 +424,244 @@ library EnumerableSet {
return result;
}
struct StringSet {
// Storage of set values
string[] _values;
// Position is the index of the value in the `values` array plus 1.
// Position 0 is used to mean a value is not in the set.
mapping(string value => uint256) _positions;
}
/**
* @dev Add a value to a set. O(1).
*
* Returns true if the value was added to the set, that is if it was not
* already present.
*/
function add(StringSet storage self, string memory value) internal returns (bool) {
if (!contains(self, value)) {
self._values.push(value);
// The value is stored at length-1, but we add 1 to all indexes
// and use 0 as a sentinel value
self._positions[value] = self._values.length;
return true;
} else {
return false;
}
}
/**
* @dev Removes a value from a set. O(1).
*
* Returns true if the value was removed from the set, that is if it was
* present.
*/
function remove(StringSet storage self, string memory value) internal returns (bool) {
// We cache the value's position to prevent multiple reads from the same storage slot
uint256 position = self._positions[value];
if (position != 0) {
// Equivalent to contains(self, value)
// To delete an element from the _values array in O(1), we swap the element to delete with the last one in
// the array, and then remove the last element (sometimes called as 'swap and pop').
// This modifies the order of the array, as noted in {at}.
uint256 valueIndex = position - 1;
uint256 lastIndex = self._values.length - 1;
if (valueIndex != lastIndex) {
string memory lastValue = self._values[lastIndex];
// Move the lastValue to the index where the value to delete is
self._values[valueIndex] = lastValue;
// Update the tracked position of the lastValue (that was just moved)
self._positions[lastValue] = position;
}
// Delete the slot where the moved value was stored
self._values.pop();
// Delete the tracked position for the deleted slot
delete self._positions[value];
return true;
} else {
return false;
}
}
/**
* @dev Removes all the values from a set. O(n).
*
* WARNING: Developers should keep in mind that this function has an unbounded cost and using it may render the
* function uncallable if the set grows to the point where clearing it consumes too much gas to fit in a block.
*/
function clear(StringSet storage set) internal {
uint256 len = length(set);
for (uint256 i = 0; i < len; ++i) {
delete set._positions[set._values[i]];
}
Arrays.unsafeSetLength(set._values, 0);
}
/**
* @dev Returns true if the value is in the set. O(1).
*/
function contains(StringSet storage self, string memory value) internal view returns (bool) {
return self._positions[value] != 0;
}
/**
* @dev Returns the number of values on the set. O(1).
*/
function length(StringSet storage self) internal view returns (uint256) {
return self._values.length;
}
/**
* @dev Returns the value stored at position `index` in the set. O(1).
*
* Note that there are no guarantees on the ordering of values inside the
* array, and it may change when more values are added or removed.
*
* Requirements:
*
* - `index` must be strictly less than {length}.
*/
function at(StringSet storage self, uint256 index) internal view returns (string memory) {
return self._values[index];
}
/**
* @dev Return the entire set in an array
*
* WARNING: This operation will copy the entire storage to memory, which can be quite expensive. This is designed
* to mostly be used by view accessors that are queried without any gas fees. Developers should keep in mind that
* this function has an unbounded cost, and using it as part of a state-changing function may render the function
* uncallable if the set grows to a point where copying to memory consumes too much gas to fit in a block.
*/
function values(StringSet storage self) internal view returns (string[] memory) {
return self._values;
}
struct BytesSet {
// Storage of set values
bytes[] _values;
// Position is the index of the value in the `values` array plus 1.
// Position 0 is used to mean a value is not in the set.
mapping(bytes value => uint256) _positions;
}
/**
* @dev Add a value to a set. O(1).
*
* Returns true if the value was added to the set, that is if it was not
* already present.
*/
function add(BytesSet storage self, bytes memory value) internal returns (bool) {
if (!contains(self, value)) {
self._values.push(value);
// The value is stored at length-1, but we add 1 to all indexes
// and use 0 as a sentinel value
self._positions[value] = self._values.length;
return true;
} else {
return false;
}
}
/**
* @dev Removes a value from a set. O(1).
*
* Returns true if the value was removed from the set, that is if it was
* present.
*/
function remove(BytesSet storage self, bytes memory value) internal returns (bool) {
// We cache the value's position to prevent multiple reads from the same storage slot
uint256 position = self._positions[value];
if (position != 0) {
// Equivalent to contains(self, value)
// To delete an element from the _values array in O(1), we swap the element to delete with the last one in
// the array, and then remove the last element (sometimes called as 'swap and pop').
// This modifies the order of the array, as noted in {at}.
uint256 valueIndex = position - 1;
uint256 lastIndex = self._values.length - 1;
if (valueIndex != lastIndex) {
bytes memory lastValue = self._values[lastIndex];
// Move the lastValue to the index where the value to delete is
self._values[valueIndex] = lastValue;
// Update the tracked position of the lastValue (that was just moved)
self._positions[lastValue] = position;
}
// Delete the slot where the moved value was stored
self._values.pop();
// Delete the tracked position for the deleted slot
delete self._positions[value];
return true;
} else {
return false;
}
}
/**
* @dev Removes all the values from a set. O(n).
*
* WARNING: Developers should keep in mind that this function has an unbounded cost and using it may render the
* function uncallable if the set grows to the point where clearing it consumes too much gas to fit in a block.
*/
function clear(BytesSet storage set) internal {
uint256 len = length(set);
for (uint256 i = 0; i < len; ++i) {
delete set._positions[set._values[i]];
}
Arrays.unsafeSetLength(set._values, 0);
}
/**
* @dev Returns true if the value is in the set. O(1).
*/
function contains(BytesSet storage self, bytes memory value) internal view returns (bool) {
return self._positions[value] != 0;
}
/**
* @dev Returns the number of values on the set. O(1).
*/
function length(BytesSet storage self) internal view returns (uint256) {
return self._values.length;
}
/**
* @dev Returns the value stored at position `index` in the set. O(1).
*
* Note that there are no guarantees on the ordering of values inside the
* array, and it may change when more values are added or removed.
*
* Requirements:
*
* - `index` must be strictly less than {length}.
*/
function at(BytesSet storage self, uint256 index) internal view returns (bytes memory) {
return self._values[index];
}
/**
* @dev Return the entire set in an array
*
* WARNING: This operation will copy the entire storage to memory, which can be quite expensive. This is designed
* to mostly be used by view accessors that are queried without any gas fees. Developers should keep in mind that
* this function has an unbounded cost, and using it as part of a state-changing function may render the function
* uncallable if the set grows to a point where copying to memory consumes too much gas to fit in a block.
*/
function values(BytesSet storage self) internal view returns (bytes[] memory) {
return self._values;
}
}

View File

@ -1,6 +1,6 @@
#!/usr/bin/env node
// const cp = require('child_process');
const cp = require('child_process');
const fs = require('fs');
const path = require('path');
const format = require('./format-lines');
@ -13,7 +13,7 @@ function getVersion(path) {
}
}
function generateFromTemplate(file, template, outputPrefix = '') {
function generateFromTemplate(file, template, outputPrefix = '', lint = false) {
const script = path.relative(path.join(__dirname, '../..'), __filename);
const input = path.join(path.dirname(script), template);
const output = path.join(outputPrefix, file);
@ -27,9 +27,12 @@ function generateFromTemplate(file, template, outputPrefix = '') {
);
fs.writeFileSync(output, content);
// cp.execFileSync('prettier', ['--write', output]);
lint && cp.execFileSync('prettier', ['--write', output]);
}
// Some templates needs to go through the linter after generation
const needsLinter = ['utils/structs/EnumerableMap.sol'];
// Contracts
for (const [file, template] of Object.entries({
'utils/cryptography/MerkleProof.sol': './templates/MerkleProof.js',
@ -45,7 +48,7 @@ for (const [file, template] of Object.entries({
'mocks/StorageSlotMock.sol': './templates/StorageSlotMock.js',
'mocks/TransientSlotMock.sol': './templates/TransientSlotMock.js',
})) {
generateFromTemplate(file, template, './contracts/');
generateFromTemplate(file, template, './contracts/', needsLinter.includes(file));
}
// Tests
@ -54,5 +57,5 @@ for (const [file, template] of Object.entries({
'utils/Packing.t.sol': './templates/Packing.t.js',
'utils/SlotDerivation.t.sol': './templates/SlotDerivation.t.js',
})) {
generateFromTemplate(file, template, './test/');
generateFromTemplate(file, template, './test/', needsLinter.includes(file));
}

View File

@ -0,0 +1,53 @@
const { capitalize, mapValues } = require('../../helpers');
const typeDescr = ({ type, size = 0, memory = false }) => {
memory |= size > 0;
const name = [type == 'uint256' ? 'Uint' : capitalize(type), size].filter(Boolean).join('x');
const base = size ? type : undefined;
const typeFull = size ? `${type}[${size}]` : type;
const typeLoc = memory ? `${typeFull} memory` : typeFull;
return { name, type: typeFull, typeLoc, base, size, memory };
};
const toSetTypeDescr = value => ({
name: value.name + 'Set',
value,
});
const toMapTypeDescr = ({ key, value }) => ({
name: `${key.name}To${value.name}Map`,
keySet: toSetTypeDescr(key),
key,
value,
});
const SET_TYPES = [
{ type: 'bytes32' },
{ type: 'address' },
{ type: 'uint256' },
{ type: 'string', memory: true },
{ type: 'bytes', memory: true },
]
.map(typeDescr)
.map(toSetTypeDescr);
const MAP_TYPES = []
.concat(
// value type maps
['uint256', 'address', 'bytes32']
.flatMap((keyType, _, array) => array.map(valueType => ({ key: { type: keyType }, value: { type: valueType } })))
.slice(0, -1), // remove bytes32 → bytes32 (last one) that is already defined
// non-value type maps
{ key: { type: 'bytes', memory: true }, value: { type: 'bytes', memory: true } },
)
.map(entry => mapValues(entry, typeDescr))
.map(toMapTypeDescr);
module.exports = {
SET_TYPES,
MAP_TYPES,
typeDescr,
toSetTypeDescr,
toMapTypeDescr,
};

View File

@ -1,6 +1,6 @@
const format = require('../format-lines');
const { fromBytes32, toBytes32 } = require('./conversion');
const { TYPES } = require('./EnumerableMap.opts');
const { MAP_TYPES } = require('./Enumerable.opts');
const header = `\
pragma solidity ^0.8.20;
@ -40,6 +40,7 @@ import {EnumerableSet} from "./EnumerableSet.sol";
* - \`address -> address\` (\`AddressToAddressMap\`) since v5.1.0
* - \`address -> bytes32\` (\`AddressToBytes32Map\`) since v5.1.0
* - \`bytes32 -> address\` (\`Bytes32ToAddressMap\`) since v5.1.0
* - \`bytes -> bytes\` (\`BytesToBytesMap\`) since v5.4.0
*
* [WARNING]
* ====
@ -176,7 +177,7 @@ function keys(Bytes32ToBytes32Map storage map) internal view returns (bytes32[]
}
`;
const customMap = ({ name, keyType, valueType }) => `\
const customMap = ({ name, key, value }) => `\
// ${name}
struct ${name} {
@ -190,8 +191,8 @@ struct ${name} {
* Returns true if the key was added to the map, that is if it was not
* already present.
*/
function set(${name} storage map, ${keyType} key, ${valueType} value) internal returns (bool) {
return set(map._inner, ${toBytes32(keyType, 'key')}, ${toBytes32(valueType, 'value')});
function set(${name} storage map, ${key.type} key, ${value.type} value) internal returns (bool) {
return set(map._inner, ${toBytes32(key.type, 'key')}, ${toBytes32(value.type, 'value')});
}
/**
@ -199,8 +200,8 @@ function set(${name} storage map, ${keyType} key, ${valueType} value) internal r
*
* Returns true if the key was removed from the map, that is if it was present.
*/
function remove(${name} storage map, ${keyType} key) internal returns (bool) {
return remove(map._inner, ${toBytes32(keyType, 'key')});
function remove(${name} storage map, ${key.type} key) internal returns (bool) {
return remove(map._inner, ${toBytes32(key.type, 'key')});
}
/**
@ -216,8 +217,8 @@ function clear(${name} storage map) internal {
/**
* @dev Returns true if the key is in the map. O(1).
*/
function contains(${name} storage map, ${keyType} key) internal view returns (bool) {
return contains(map._inner, ${toBytes32(keyType, 'key')});
function contains(${name} storage map, ${key.type} key) internal view returns (bool) {
return contains(map._inner, ${toBytes32(key.type, 'key')});
}
/**
@ -236,18 +237,18 @@ function length(${name} storage map) internal view returns (uint256) {
*
* - \`index\` must be strictly less than {length}.
*/
function at(${name} storage map, uint256 index) internal view returns (${keyType} key, ${valueType} value) {
function at(${name} storage map, uint256 index) internal view returns (${key.type} key, ${value.type} value) {
(bytes32 atKey, bytes32 val) = at(map._inner, index);
return (${fromBytes32(keyType, 'atKey')}, ${fromBytes32(valueType, 'val')});
return (${fromBytes32(key.type, 'atKey')}, ${fromBytes32(value.type, 'val')});
}
/**
* @dev Tries to returns the value associated with \`key\`. O(1).
* Does not revert if \`key\` is not in the map.
*/
function tryGet(${name} storage map, ${keyType} key) internal view returns (bool exists, ${valueType} value) {
(bool success, bytes32 val) = tryGet(map._inner, ${toBytes32(keyType, 'key')});
return (success, ${fromBytes32(valueType, 'val')});
function tryGet(${name} storage map, ${key.type} key) internal view returns (bool exists, ${value.type} value) {
(bool success, bytes32 val) = tryGet(map._inner, ${toBytes32(key.type, 'key')});
return (success, ${fromBytes32(value.type, 'val')});
}
/**
@ -257,8 +258,8 @@ function tryGet(${name} storage map, ${keyType} key) internal view returns (bool
*
* - \`key\` must be in the map.
*/
function get(${name} storage map, ${keyType} key) internal view returns (${valueType}) {
return ${fromBytes32(valueType, `get(map._inner, ${toBytes32(keyType, 'key')})`)};
function get(${name} storage map, ${key.type} key) internal view returns (${value.type}) {
return ${fromBytes32(value.type, `get(map._inner, ${toBytes32(key.type, 'key')})`)};
}
/**
@ -269,9 +270,9 @@ function get(${name} storage map, ${keyType} key) internal view returns (${value
* this function has an unbounded cost, and using it as part of a state-changing function may render the function
* uncallable if the map grows to a point where copying to memory consumes too much gas to fit in a block.
*/
function keys(${name} storage map) internal view returns (${keyType}[] memory) {
function keys(${name} storage map) internal view returns (${key.type}[] memory) {
bytes32[] memory store = keys(map._inner);
${keyType}[] memory result;
${key.type}[] memory result;
assembly ("memory-safe") {
result := store
@ -281,16 +282,137 @@ function keys(${name} storage map) internal view returns (${keyType}[] memory) {
}
`;
const memoryMap = ({ name, keySet, key, value }) => `\
/**
* @dev Query for a nonexistent map key.
*/
error EnumerableMapNonexistent${key.name}Key(${key.type} key);
struct ${name} {
// Storage of keys
EnumerableSet.${keySet.name} _keys;
mapping(${key.type} key => ${value.type}) _values;
}
/**
* @dev Adds a key-value pair to a map, or updates the value for an existing
* key. O(1).
*
* Returns true if the key was added to the map, that is if it was not
* already present.
*/
function set(${name} storage map, ${key.typeLoc} key, ${value.typeLoc} value) internal returns (bool) {
map._values[key] = value;
return map._keys.add(key);
}
/**
* @dev Removes a key-value pair from a map. O(1).
*
* Returns true if the key was removed from the map, that is if it was present.
*/
function remove(${name} storage map, ${key.typeLoc} key) internal returns (bool) {
delete map._values[key];
return map._keys.remove(key);
}
/**
* @dev Removes all the entries from a map. O(n).
*
* WARNING: Developers should keep in mind that this function has an unbounded cost and using it may render the
* function uncallable if the map grows to the point where clearing it consumes too much gas to fit in a block.
*/
function clear(${name} storage map) internal {
uint256 len = length(map);
for (uint256 i = 0; i < len; ++i) {
delete map._values[map._keys.at(i)];
}
map._keys.clear();
}
/**
* @dev Returns true if the key is in the map. O(1).
*/
function contains(${name} storage map, ${key.typeLoc} key) internal view returns (bool) {
return map._keys.contains(key);
}
/**
* @dev Returns the number of key-value pairs in the map. O(1).
*/
function length(${name} storage map) internal view returns (uint256) {
return map._keys.length();
}
/**
* @dev Returns the key-value pair stored at position \`index\` in the map. O(1).
*
* Note that there are no guarantees on the ordering of entries inside the
* array, and it may change when more entries are added or removed.
*
* Requirements:
*
* - \`index\` must be strictly less than {length}.
*/
function at(
${name} storage map,
uint256 index
) internal view returns (${key.typeLoc} key, ${value.typeLoc} value) {
key = map._keys.at(index);
value = map._values[key];
}
/**
* @dev Tries to returns the value associated with \`key\`. O(1).
* Does not revert if \`key\` is not in the map.
*/
function tryGet(
${name} storage map,
${key.typeLoc} key
) internal view returns (bool exists, ${value.typeLoc} value) {
value = map._values[key];
exists = ${value.memory ? 'bytes(value).length != 0' : `value != ${value.type}(0)`} || contains(map, key);
}
/**
* @dev Returns the value associated with \`key\`. O(1).
*
* Requirements:
*
* - \`key\` must be in the map.
*/
function get(${name} storage map, ${key.typeLoc} key) internal view returns (${value.typeLoc} value) {
bool exists;
(exists, value) = tryGet(map, key);
if (!exists) {
revert EnumerableMapNonexistent${key.name}Key(key);
}
}
/**
* @dev Return the an array containing all the keys
*
* WARNING: This operation will copy the entire storage to memory, which can be quite expensive. This is designed
* to mostly be used by view accessors that are queried without any gas fees. Developers should keep in mind that
* this function has an unbounded cost, and using it as part of a state-changing function may render the function
* uncallable if the map grows to a point where copying to memory consumes too much gas to fit in a block.
*/
function keys(${name} storage map) internal view returns (${key.type}[] memory) {
return map._keys.values();
}
`;
// GENERATE
module.exports = format(
header.trimEnd(),
'library EnumerableMap {',
format(
[].concat(
'using EnumerableSet for EnumerableSet.Bytes32Set;',
'using EnumerableSet for *;',
'',
defaultMap,
TYPES.map(details => customMap(details)),
MAP_TYPES.filter(({ key, value }) => !(key.memory || value.memory)).map(customMap),
MAP_TYPES.filter(({ key, value }) => key.memory || value.memory).map(memoryMap),
),
).trimEnd(),
'}',

View File

@ -1,19 +0,0 @@
const { capitalize } = require('../../helpers');
const mapType = str => (str == 'uint256' ? 'Uint' : capitalize(str));
const formatType = (keyType, valueType) => ({
name: `${mapType(keyType)}To${mapType(valueType)}Map`,
keyType,
valueType,
});
const TYPES = ['uint256', 'address', 'bytes32']
.flatMap((key, _, array) => array.map(value => [key, value]))
.slice(0, -1) // remove bytes32 → byte32 (last one) that is already defined
.map(args => formatType(...args));
module.exports = {
TYPES,
formatType,
};

View File

@ -1,6 +1,6 @@
const format = require('../format-lines');
const { fromBytes32, toBytes32 } = require('./conversion');
const { TYPES } = require('./EnumerableSet.opts');
const { SET_TYPES } = require('./Enumerable.opts');
const header = `\
pragma solidity ^0.8.20;
@ -29,8 +29,13 @@ import {Arrays} from "../Arrays.sol";
* }
* \`\`\`
*
* As of v3.3.0, sets of type \`bytes32\` (\`Bytes32Set\`), \`address\` (\`AddressSet\`)
* and \`uint256\` (\`UintSet\`) are supported.
* The following types are supported:
*
* - \`bytes32\` (\`Bytes32Set\`) since v3.3.0
* - \`address\` (\`AddressSet\`) since v3.3.0
* - \`uint256\` (\`UintSet\`) since v3.3.0
* - \`string\` (\`StringSet\`) since v5.4.0
* - \`bytes\` (\`BytesSet\`) since v5.4.0
*
* [WARNING]
* ====
@ -44,6 +49,7 @@ import {Arrays} from "../Arrays.sol";
*/
`;
// NOTE: this should be deprecated in favor of a more native construction in v6.0
const defaultSet = `\
// To implement this library for multiple types with as little code
// repetition as possible, we write it in terms of a generic Set type with
@ -175,7 +181,8 @@ function _values(Set storage set) private view returns (bytes32[] memory) {
}
`;
const customSet = ({ name, type }) => `\
// NOTE: this should be deprecated in favor of a more native construction in v6.0
const customSet = ({ name, value: { type } }) => `\
// ${name}
struct ${name} {
@ -260,6 +267,128 @@ function values(${name} storage set) internal view returns (${type}[] memory) {
}
`;
const memorySet = ({ name, value }) => `\
struct ${name} {
// Storage of set values
${value.type}[] _values;
// Position is the index of the value in the \`values\` array plus 1.
// Position 0 is used to mean a value is not in the set.
mapping(${value.type} value => uint256) _positions;
}
/**
* @dev Add a value to a set. O(1).
*
* Returns true if the value was added to the set, that is if it was not
* already present.
*/
function add(${name} storage self, ${value.type} memory value) internal returns (bool) {
if (!contains(self, value)) {
self._values.push(value);
// The value is stored at length-1, but we add 1 to all indexes
// and use 0 as a sentinel value
self._positions[value] = self._values.length;
return true;
} else {
return false;
}
}
/**
* @dev Removes a value from a set. O(1).
*
* Returns true if the value was removed from the set, that is if it was
* present.
*/
function remove(${name} storage self, ${value.type} memory value) internal returns (bool) {
// We cache the value's position to prevent multiple reads from the same storage slot
uint256 position = self._positions[value];
if (position != 0) {
// Equivalent to contains(self, value)
// To delete an element from the _values array in O(1), we swap the element to delete with the last one in
// the array, and then remove the last element (sometimes called as 'swap and pop').
// This modifies the order of the array, as noted in {at}.
uint256 valueIndex = position - 1;
uint256 lastIndex = self._values.length - 1;
if (valueIndex != lastIndex) {
${value.type} memory lastValue = self._values[lastIndex];
// Move the lastValue to the index where the value to delete is
self._values[valueIndex] = lastValue;
// Update the tracked position of the lastValue (that was just moved)
self._positions[lastValue] = position;
}
// Delete the slot where the moved value was stored
self._values.pop();
// Delete the tracked position for the deleted slot
delete self._positions[value];
return true;
} else {
return false;
}
}
/**
* @dev Removes all the values from a set. O(n).
*
* WARNING: Developers should keep in mind that this function has an unbounded cost and using it may render the
* function uncallable if the set grows to the point where clearing it consumes too much gas to fit in a block.
*/
function clear(${name} storage set) internal {
uint256 len = length(set);
for (uint256 i = 0; i < len; ++i) {
delete set._positions[set._values[i]];
}
Arrays.unsafeSetLength(set._values, 0);
}
/**
* @dev Returns true if the value is in the set. O(1).
*/
function contains(${name} storage self, ${value.type} memory value) internal view returns (bool) {
return self._positions[value] != 0;
}
/**
* @dev Returns the number of values on the set. O(1).
*/
function length(${name} storage self) internal view returns (uint256) {
return self._values.length;
}
/**
* @dev Returns the value stored at position \`index\` in the set. O(1).
*
* Note that there are no guarantees on the ordering of values inside the
* array, and it may change when more values are added or removed.
*
* Requirements:
*
* - \`index\` must be strictly less than {length}.
*/
function at(${name} storage self, uint256 index) internal view returns (${value.type} memory) {
return self._values[index];
}
/**
* @dev Return the entire set in an array
*
* WARNING: This operation will copy the entire storage to memory, which can be quite expensive. This is designed
* to mostly be used by view accessors that are queried without any gas fees. Developers should keep in mind that
* this function has an unbounded cost, and using it as part of a state-changing function may render the function
* uncallable if the set grows to a point where copying to memory consumes too much gas to fit in a block.
*/
function values(${name} storage self) internal view returns (${value.type}[] memory) {
return self._values;
}
`;
// GENERATE
module.exports = format(
header.trimEnd(),
@ -267,7 +396,8 @@ module.exports = format(
format(
[].concat(
defaultSet,
TYPES.map(details => customSet(details)),
SET_TYPES.filter(({ value }) => !value.memory).map(customSet),
SET_TYPES.filter(({ value }) => value.memory).map(memorySet),
),
).trimEnd(),
'}',

View File

@ -1,12 +0,0 @@
const { capitalize } = require('../../helpers');
const mapType = str => (str == 'uint256' ? 'Uint' : capitalize(str));
const formatType = type => ({
name: `${mapType(type)}Set`,
type,
});
const TYPES = ['bytes32', 'address', 'uint256'].map(formatType);
module.exports = { TYPES, formatType };

View File

@ -176,7 +176,7 @@ function shouldBehaveLikeMap() {
.withArgs(
this.key?.memory || this.value?.memory
? this.keyB
: ethers.AbiCoder.defaultAbiCoder().encode([this.keyType], [this.keyB]),
: ethers.AbiCoder.defaultAbiCoder().encode([this.key.type], [this.keyB]),
);
});
});

View File

@ -3,43 +3,58 @@ const { loadFixture } = require('@nomicfoundation/hardhat-network-helpers');
const { mapValues } = require('../../helpers/iterate');
const { generators } = require('../../helpers/random');
const { TYPES, formatType } = require('../../../scripts/generate/templates/EnumerableMap.opts');
const { MAP_TYPES, typeDescr, toMapTypeDescr } = require('../../../scripts/generate/templates/Enumerable.opts');
const { shouldBehaveLikeMap } = require('./EnumerableMap.behavior');
// Add Bytes32ToBytes32Map that must be tested but is not part of the generated types.
TYPES.unshift(formatType('bytes32', 'bytes32'));
MAP_TYPES.unshift(toMapTypeDescr({ key: typeDescr({ type: 'bytes32' }), value: typeDescr({ type: 'bytes32' }) }));
async function fixture() {
const mock = await ethers.deployContract('$EnumerableMap');
const env = Object.fromEntries(
TYPES.map(({ name, keyType, valueType }) => [
MAP_TYPES.map(({ name, key, value }) => [
name,
{
keyType,
keys: Array.from({ length: 3 }, generators[keyType]),
values: Array.from({ length: 3 }, generators[valueType]),
zeroValue: generators[valueType].zero,
key,
value,
keys: Array.from({ length: 3 }, generators[key.type]),
values: Array.from({ length: 3 }, generators[value.type]),
zeroValue: generators[value.type].zero,
methods: mapValues(
{
set: `$set(uint256,${keyType},${valueType})`,
get: `$get_EnumerableMap_${name}(uint256,${keyType})`,
tryGet: `$tryGet_EnumerableMap_${name}(uint256,${keyType})`,
remove: `$remove_EnumerableMap_${name}(uint256,${keyType})`,
clear: `$clear_EnumerableMap_${name}(uint256)`,
length: `$length_EnumerableMap_${name}(uint256)`,
at: `$at_EnumerableMap_${name}(uint256,uint256)`,
contains: `$contains_EnumerableMap_${name}(uint256,${keyType})`,
keys: `$keys_EnumerableMap_${name}(uint256)`,
},
MAP_TYPES.filter(map => map.key.name == key.name).length == 1
? {
set: `$set(uint256,${key.type},${value.type})`,
get: `$get(uint256,${key.type})`,
tryGet: `$tryGet(uint256,${key.type})`,
remove: `$remove(uint256,${key.type})`,
contains: `$contains(uint256,${key.type})`,
clear: `$clear_EnumerableMap_${name}(uint256)`,
length: `$length_EnumerableMap_${name}(uint256)`,
at: `$at_EnumerableMap_${name}(uint256,uint256)`,
keys: `$keys_EnumerableMap_${name}(uint256)`,
}
: {
set: `$set(uint256,${key.type},${value.type})`,
get: `$get_EnumerableMap_${name}(uint256,${key.type})`,
tryGet: `$tryGet_EnumerableMap_${name}(uint256,${key.type})`,
remove: `$remove_EnumerableMap_${name}(uint256,${key.type})`,
contains: `$contains_EnumerableMap_${name}(uint256,${key.type})`,
clear: `$clear_EnumerableMap_${name}(uint256)`,
length: `$length_EnumerableMap_${name}(uint256)`,
at: `$at_EnumerableMap_${name}(uint256,uint256)`,
keys: `$keys_EnumerableMap_${name}(uint256)`,
},
fnSig =>
(...args) =>
mock.getFunction(fnSig)(0, ...args),
),
events: {
setReturn: `return$set_EnumerableMap_${name}_${keyType}_${valueType}`,
removeReturn: `return$remove_EnumerableMap_${name}_${keyType}`,
setReturn: `return$set_EnumerableMap_${name}_${key.type}_${value.type}`,
removeReturn: `return$remove_EnumerableMap_${name}_${key.type}`,
},
error: key.memory || value.memory ? `EnumerableMapNonexistent${key.name}Key` : `EnumerableMapNonexistentKey`,
},
]),
);
@ -52,8 +67,8 @@ describe('EnumerableMap', function () {
Object.assign(this, await loadFixture(fixture));
});
for (const { name } of TYPES) {
describe(name, function () {
for (const { name, key, value } of MAP_TYPES) {
describe(`${name} (enumerable map from ${key.type} to ${value.type})`, function () {
beforeEach(async function () {
Object.assign(this, this.env[name]);
[this.keyA, this.keyB, this.keyC] = this.keys;

View File

@ -3,39 +3,42 @@ const { loadFixture } = require('@nomicfoundation/hardhat-network-helpers');
const { mapValues } = require('../../helpers/iterate');
const { generators } = require('../../helpers/random');
const { TYPES } = require('../../../scripts/generate/templates/EnumerableSet.opts');
const { SET_TYPES } = require('../../../scripts/generate/templates/Enumerable.opts');
const { shouldBehaveLikeSet } = require('./EnumerableSet.behavior');
const getMethods = (mock, fnSigs) => {
return mapValues(
const getMethods = (mock, fnSigs) =>
mapValues(
fnSigs,
fnSig =>
(...args) =>
mock.getFunction(fnSig)(0, ...args),
);
};
async function fixture() {
const mock = await ethers.deployContract('$EnumerableSet');
const env = Object.fromEntries(
TYPES.map(({ name, type }) => [
type,
SET_TYPES.map(({ name, value }) => [
name,
{
values: Array.from({ length: 3 }, generators[type]),
value,
values: Array.from(
{ length: 3 },
value.size ? () => Array.from({ length: value.size }, generators[value.base]) : generators[value.type],
),
methods: getMethods(mock, {
add: `$add(uint256,${type})`,
remove: `$remove(uint256,${type})`,
add: `$add(uint256,${value.type})`,
remove: `$remove(uint256,${value.type})`,
contains: `$contains(uint256,${value.type})`,
clear: `$clear_EnumerableSet_${name}(uint256)`,
contains: `$contains(uint256,${type})`,
length: `$length_EnumerableSet_${name}(uint256)`,
at: `$at_EnumerableSet_${name}(uint256,uint256)`,
values: `$values_EnumerableSet_${name}(uint256)`,
}),
events: {
addReturn: `return$add_EnumerableSet_${name}_${type}`,
removeReturn: `return$remove_EnumerableSet_${name}_${type}`,
addReturn: `return$add_EnumerableSet_${name}_${value.type.replace(/[[\]]/g, '_')}`,
removeReturn: `return$remove_EnumerableSet_${name}_${value.type.replace(/[[\]]/g, '_')}`,
},
},
]),
@ -49,10 +52,10 @@ describe('EnumerableSet', function () {
Object.assign(this, await loadFixture(fixture));
});
for (const { type } of TYPES) {
describe(type, function () {
for (const { name, value } of SET_TYPES) {
describe(`${name} (enumerable set of ${value.type})`, function () {
beforeEach(function () {
Object.assign(this, this.env[type]);
Object.assign(this, this.env[name]);
[this.valueA, this.valueB, this.valueC] = this.values;
});