From 784d4f71b1ee3faa5509ee86b586caf310d9d2be Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ernesto=20Garc=C3=ADa?= Date: Tue, 3 Jun 2025 08:26:06 -0600 Subject: [PATCH] Add non-value types in EnumerableSet and EnumerableMap (#5658) Co-authored-by: Hadrien Croubois --- .changeset/long-hornets-mate.md | 5 + .changeset/pink-dolls-shop.md | 5 + contracts/utils/structs/EnumerableMap.sol | 121 ++++++++- contracts/utils/structs/EnumerableSet.sol | 249 +++++++++++++++++- scripts/generate/run.js | 13 +- scripts/generate/templates/Enumerable.opts.js | 53 ++++ scripts/generate/templates/EnumerableMap.js | 160 +++++++++-- .../generate/templates/EnumerableMap.opts.js | 19 -- scripts/generate/templates/EnumerableSet.js | 140 +++++++++- .../generate/templates/EnumerableSet.opts.js | 12 - test/utils/structs/EnumerableMap.behavior.js | 2 +- test/utils/structs/EnumerableMap.test.js | 59 +++-- test/utils/structs/EnumerableSet.test.js | 33 +-- 13 files changed, 770 insertions(+), 101 deletions(-) create mode 100644 .changeset/long-hornets-mate.md create mode 100644 .changeset/pink-dolls-shop.md create mode 100644 scripts/generate/templates/Enumerable.opts.js delete mode 100644 scripts/generate/templates/EnumerableMap.opts.js delete mode 100644 scripts/generate/templates/EnumerableSet.opts.js diff --git a/.changeset/long-hornets-mate.md b/.changeset/long-hornets-mate.md new file mode 100644 index 000000000..29f8f82a3 --- /dev/null +++ b/.changeset/long-hornets-mate.md @@ -0,0 +1,5 @@ +--- +'openzeppelin-solidity': minor +--- + +`EnumerableMap`: Add support for `BytesToBytesMap` type. diff --git a/.changeset/pink-dolls-shop.md b/.changeset/pink-dolls-shop.md new file mode 100644 index 000000000..2b9fbf6c8 --- /dev/null +++ b/.changeset/pink-dolls-shop.md @@ -0,0 +1,5 @@ +--- +'openzeppelin-solidity': minor +--- + +`EnumerableSet`: Add support for `StringSet` and `BytesSet` types. diff --git a/contracts/utils/structs/EnumerableMap.sol b/contracts/utils/structs/EnumerableMap.sol index 09fa498fc..1c67aacaf 100644 --- a/contracts/utils/structs/EnumerableMap.sol +++ b/contracts/utils/structs/EnumerableMap.sol @@ -39,6 +39,7 @@ import {EnumerableSet} from "./EnumerableSet.sol"; * - `address -> address` (`AddressToAddressMap`) since v5.1.0 * - `address -> bytes32` (`AddressToBytes32Map`) since v5.1.0 * - `bytes32 -> address` (`Bytes32ToAddressMap`) since v5.1.0 + * - `bytes -> bytes` (`BytesToBytesMap`) since v5.4.0 * * [WARNING] * ==== @@ -51,7 +52,7 @@ import {EnumerableSet} from "./EnumerableSet.sol"; * ==== */ library EnumerableMap { - using EnumerableSet for EnumerableSet.Bytes32Set; + using EnumerableSet for *; // To implement this library for multiple types with as little code repetition as possible, we write it in // terms of a generic Map type with bytes32 keys and values. The Map implementation uses private functions, @@ -997,4 +998,122 @@ library EnumerableMap { return result; } + + /** + * @dev Query for a nonexistent map key. + */ + error EnumerableMapNonexistentBytesKey(bytes key); + + struct BytesToBytesMap { + // Storage of keys + EnumerableSet.BytesSet _keys; + mapping(bytes key => bytes) _values; + } + + /** + * @dev Adds a key-value pair to a map, or updates the value for an existing + * key. O(1). + * + * Returns true if the key was added to the map, that is if it was not + * already present. + */ + function set(BytesToBytesMap storage map, bytes memory key, bytes memory value) internal returns (bool) { + map._values[key] = value; + return map._keys.add(key); + } + + /** + * @dev Removes a key-value pair from a map. O(1). + * + * Returns true if the key was removed from the map, that is if it was present. + */ + function remove(BytesToBytesMap storage map, bytes memory key) internal returns (bool) { + delete map._values[key]; + return map._keys.remove(key); + } + + /** + * @dev Removes all the entries from a map. O(n). + * + * WARNING: Developers should keep in mind that this function has an unbounded cost and using it may render the + * function uncallable if the map grows to the point where clearing it consumes too much gas to fit in a block. + */ + function clear(BytesToBytesMap storage map) internal { + uint256 len = length(map); + for (uint256 i = 0; i < len; ++i) { + delete map._values[map._keys.at(i)]; + } + map._keys.clear(); + } + + /** + * @dev Returns true if the key is in the map. O(1). + */ + function contains(BytesToBytesMap storage map, bytes memory key) internal view returns (bool) { + return map._keys.contains(key); + } + + /** + * @dev Returns the number of key-value pairs in the map. O(1). + */ + function length(BytesToBytesMap storage map) internal view returns (uint256) { + return map._keys.length(); + } + + /** + * @dev Returns the key-value pair stored at position `index` in the map. O(1). + * + * Note that there are no guarantees on the ordering of entries inside the + * array, and it may change when more entries are added or removed. + * + * Requirements: + * + * - `index` must be strictly less than {length}. + */ + function at( + BytesToBytesMap storage map, + uint256 index + ) internal view returns (bytes memory key, bytes memory value) { + key = map._keys.at(index); + value = map._values[key]; + } + + /** + * @dev Tries to returns the value associated with `key`. O(1). + * Does not revert if `key` is not in the map. + */ + function tryGet( + BytesToBytesMap storage map, + bytes memory key + ) internal view returns (bool exists, bytes memory value) { + value = map._values[key]; + exists = bytes(value).length != 0 || contains(map, key); + } + + /** + * @dev Returns the value associated with `key`. O(1). + * + * Requirements: + * + * - `key` must be in the map. + */ + function get(BytesToBytesMap storage map, bytes memory key) internal view returns (bytes memory value) { + bool exists; + (exists, value) = tryGet(map, key); + if (!exists) { + revert EnumerableMapNonexistentBytesKey(key); + } + } + + /** + * @dev Return the an array containing all the keys + * + * WARNING: This operation will copy the entire storage to memory, which can be quite expensive. This is designed + * to mostly be used by view accessors that are queried without any gas fees. Developers should keep in mind that + * this function has an unbounded cost, and using it as part of a state-changing function may render the function + * uncallable if the map grows to a point where copying to memory consumes too much gas to fit in a block. + */ + function keys(BytesToBytesMap storage map) internal view returns (bytes[] memory) { + return map._keys.values(); + } } diff --git a/contracts/utils/structs/EnumerableSet.sol b/contracts/utils/structs/EnumerableSet.sol index ec8bb3779..1c037241b 100644 --- a/contracts/utils/structs/EnumerableSet.sol +++ b/contracts/utils/structs/EnumerableSet.sol @@ -28,8 +28,13 @@ import {Arrays} from "../Arrays.sol"; * } * ``` * - * As of v3.3.0, sets of type `bytes32` (`Bytes32Set`), `address` (`AddressSet`) - * and `uint256` (`UintSet`) are supported. + * The following types are supported: + * + * - `bytes32` (`Bytes32Set`) since v3.3.0 + * - `address` (`AddressSet`) since v3.3.0 + * - `uint256` (`UintSet`) since v3.3.0 + * - `string` (`StringSet`) since v5.4.0 + * - `bytes` (`BytesSet`) since v5.4.0 * * [WARNING] * ==== @@ -419,4 +424,244 @@ library EnumerableSet { return result; } + + struct StringSet { + // Storage of set values + string[] _values; + // Position is the index of the value in the `values` array plus 1. + // Position 0 is used to mean a value is not in the set. + mapping(string value => uint256) _positions; + } + + /** + * @dev Add a value to a set. O(1). + * + * Returns true if the value was added to the set, that is if it was not + * already present. + */ + function add(StringSet storage self, string memory value) internal returns (bool) { + if (!contains(self, value)) { + self._values.push(value); + // The value is stored at length-1, but we add 1 to all indexes + // and use 0 as a sentinel value + self._positions[value] = self._values.length; + return true; + } else { + return false; + } + } + + /** + * @dev Removes a value from a set. O(1). + * + * Returns true if the value was removed from the set, that is if it was + * present. + */ + function remove(StringSet storage self, string memory value) internal returns (bool) { + // We cache the value's position to prevent multiple reads from the same storage slot + uint256 position = self._positions[value]; + + if (position != 0) { + // Equivalent to contains(self, value) + // To delete an element from the _values array in O(1), we swap the element to delete with the last one in + // the array, and then remove the last element (sometimes called as 'swap and pop'). + // This modifies the order of the array, as noted in {at}. + + uint256 valueIndex = position - 1; + uint256 lastIndex = self._values.length - 1; + + if (valueIndex != lastIndex) { + string memory lastValue = self._values[lastIndex]; + + // Move the lastValue to the index where the value to delete is + self._values[valueIndex] = lastValue; + // Update the tracked position of the lastValue (that was just moved) + self._positions[lastValue] = position; + } + + // Delete the slot where the moved value was stored + self._values.pop(); + + // Delete the tracked position for the deleted slot + delete self._positions[value]; + + return true; + } else { + return false; + } + } + + /** + * @dev Removes all the values from a set. O(n). + * + * WARNING: Developers should keep in mind that this function has an unbounded cost and using it may render the + * function uncallable if the set grows to the point where clearing it consumes too much gas to fit in a block. + */ + function clear(StringSet storage set) internal { + uint256 len = length(set); + for (uint256 i = 0; i < len; ++i) { + delete set._positions[set._values[i]]; + } + Arrays.unsafeSetLength(set._values, 0); + } + + /** + * @dev Returns true if the value is in the set. O(1). + */ + function contains(StringSet storage self, string memory value) internal view returns (bool) { + return self._positions[value] != 0; + } + + /** + * @dev Returns the number of values on the set. O(1). + */ + function length(StringSet storage self) internal view returns (uint256) { + return self._values.length; + } + + /** + * @dev Returns the value stored at position `index` in the set. O(1). + * + * Note that there are no guarantees on the ordering of values inside the + * array, and it may change when more values are added or removed. + * + * Requirements: + * + * - `index` must be strictly less than {length}. + */ + function at(StringSet storage self, uint256 index) internal view returns (string memory) { + return self._values[index]; + } + + /** + * @dev Return the entire set in an array + * + * WARNING: This operation will copy the entire storage to memory, which can be quite expensive. This is designed + * to mostly be used by view accessors that are queried without any gas fees. Developers should keep in mind that + * this function has an unbounded cost, and using it as part of a state-changing function may render the function + * uncallable if the set grows to a point where copying to memory consumes too much gas to fit in a block. + */ + function values(StringSet storage self) internal view returns (string[] memory) { + return self._values; + } + + struct BytesSet { + // Storage of set values + bytes[] _values; + // Position is the index of the value in the `values` array plus 1. + // Position 0 is used to mean a value is not in the set. + mapping(bytes value => uint256) _positions; + } + + /** + * @dev Add a value to a set. O(1). + * + * Returns true if the value was added to the set, that is if it was not + * already present. + */ + function add(BytesSet storage self, bytes memory value) internal returns (bool) { + if (!contains(self, value)) { + self._values.push(value); + // The value is stored at length-1, but we add 1 to all indexes + // and use 0 as a sentinel value + self._positions[value] = self._values.length; + return true; + } else { + return false; + } + } + + /** + * @dev Removes a value from a set. O(1). + * + * Returns true if the value was removed from the set, that is if it was + * present. + */ + function remove(BytesSet storage self, bytes memory value) internal returns (bool) { + // We cache the value's position to prevent multiple reads from the same storage slot + uint256 position = self._positions[value]; + + if (position != 0) { + // Equivalent to contains(self, value) + // To delete an element from the _values array in O(1), we swap the element to delete with the last one in + // the array, and then remove the last element (sometimes called as 'swap and pop'). + // This modifies the order of the array, as noted in {at}. + + uint256 valueIndex = position - 1; + uint256 lastIndex = self._values.length - 1; + + if (valueIndex != lastIndex) { + bytes memory lastValue = self._values[lastIndex]; + + // Move the lastValue to the index where the value to delete is + self._values[valueIndex] = lastValue; + // Update the tracked position of the lastValue (that was just moved) + self._positions[lastValue] = position; + } + + // Delete the slot where the moved value was stored + self._values.pop(); + + // Delete the tracked position for the deleted slot + delete self._positions[value]; + + return true; + } else { + return false; + } + } + + /** + * @dev Removes all the values from a set. O(n). + * + * WARNING: Developers should keep in mind that this function has an unbounded cost and using it may render the + * function uncallable if the set grows to the point where clearing it consumes too much gas to fit in a block. + */ + function clear(BytesSet storage set) internal { + uint256 len = length(set); + for (uint256 i = 0; i < len; ++i) { + delete set._positions[set._values[i]]; + } + Arrays.unsafeSetLength(set._values, 0); + } + + /** + * @dev Returns true if the value is in the set. O(1). + */ + function contains(BytesSet storage self, bytes memory value) internal view returns (bool) { + return self._positions[value] != 0; + } + + /** + * @dev Returns the number of values on the set. O(1). + */ + function length(BytesSet storage self) internal view returns (uint256) { + return self._values.length; + } + + /** + * @dev Returns the value stored at position `index` in the set. O(1). + * + * Note that there are no guarantees on the ordering of values inside the + * array, and it may change when more values are added or removed. + * + * Requirements: + * + * - `index` must be strictly less than {length}. + */ + function at(BytesSet storage self, uint256 index) internal view returns (bytes memory) { + return self._values[index]; + } + + /** + * @dev Return the entire set in an array + * + * WARNING: This operation will copy the entire storage to memory, which can be quite expensive. This is designed + * to mostly be used by view accessors that are queried without any gas fees. Developers should keep in mind that + * this function has an unbounded cost, and using it as part of a state-changing function may render the function + * uncallable if the set grows to a point where copying to memory consumes too much gas to fit in a block. + */ + function values(BytesSet storage self) internal view returns (bytes[] memory) { + return self._values; + } } diff --git a/scripts/generate/run.js b/scripts/generate/run.js index 6779c93f4..394bb3952 100755 --- a/scripts/generate/run.js +++ b/scripts/generate/run.js @@ -1,6 +1,6 @@ #!/usr/bin/env node -// const cp = require('child_process'); +const cp = require('child_process'); const fs = require('fs'); const path = require('path'); const format = require('./format-lines'); @@ -13,7 +13,7 @@ function getVersion(path) { } } -function generateFromTemplate(file, template, outputPrefix = '') { +function generateFromTemplate(file, template, outputPrefix = '', lint = false) { const script = path.relative(path.join(__dirname, '../..'), __filename); const input = path.join(path.dirname(script), template); const output = path.join(outputPrefix, file); @@ -27,9 +27,12 @@ function generateFromTemplate(file, template, outputPrefix = '') { ); fs.writeFileSync(output, content); - // cp.execFileSync('prettier', ['--write', output]); + lint && cp.execFileSync('prettier', ['--write', output]); } +// Some templates needs to go through the linter after generation +const needsLinter = ['utils/structs/EnumerableMap.sol']; + // Contracts for (const [file, template] of Object.entries({ 'utils/cryptography/MerkleProof.sol': './templates/MerkleProof.js', @@ -45,7 +48,7 @@ for (const [file, template] of Object.entries({ 'mocks/StorageSlotMock.sol': './templates/StorageSlotMock.js', 'mocks/TransientSlotMock.sol': './templates/TransientSlotMock.js', })) { - generateFromTemplate(file, template, './contracts/'); + generateFromTemplate(file, template, './contracts/', needsLinter.includes(file)); } // Tests @@ -54,5 +57,5 @@ for (const [file, template] of Object.entries({ 'utils/Packing.t.sol': './templates/Packing.t.js', 'utils/SlotDerivation.t.sol': './templates/SlotDerivation.t.js', })) { - generateFromTemplate(file, template, './test/'); + generateFromTemplate(file, template, './test/', needsLinter.includes(file)); } diff --git a/scripts/generate/templates/Enumerable.opts.js b/scripts/generate/templates/Enumerable.opts.js new file mode 100644 index 000000000..50c7349b3 --- /dev/null +++ b/scripts/generate/templates/Enumerable.opts.js @@ -0,0 +1,53 @@ +const { capitalize, mapValues } = require('../../helpers'); + +const typeDescr = ({ type, size = 0, memory = false }) => { + memory |= size > 0; + + const name = [type == 'uint256' ? 'Uint' : capitalize(type), size].filter(Boolean).join('x'); + const base = size ? type : undefined; + const typeFull = size ? `${type}[${size}]` : type; + const typeLoc = memory ? `${typeFull} memory` : typeFull; + return { name, type: typeFull, typeLoc, base, size, memory }; +}; + +const toSetTypeDescr = value => ({ + name: value.name + 'Set', + value, +}); + +const toMapTypeDescr = ({ key, value }) => ({ + name: `${key.name}To${value.name}Map`, + keySet: toSetTypeDescr(key), + key, + value, +}); + +const SET_TYPES = [ + { type: 'bytes32' }, + { type: 'address' }, + { type: 'uint256' }, + { type: 'string', memory: true }, + { type: 'bytes', memory: true }, +] + .map(typeDescr) + .map(toSetTypeDescr); + +const MAP_TYPES = [] + .concat( + // value type maps + ['uint256', 'address', 'bytes32'] + .flatMap((keyType, _, array) => array.map(valueType => ({ key: { type: keyType }, value: { type: valueType } }))) + .slice(0, -1), // remove bytes32 → bytes32 (last one) that is already defined + // non-value type maps + { key: { type: 'bytes', memory: true }, value: { type: 'bytes', memory: true } }, + ) + .map(entry => mapValues(entry, typeDescr)) + .map(toMapTypeDescr); + +module.exports = { + SET_TYPES, + MAP_TYPES, + typeDescr, + toSetTypeDescr, + toMapTypeDescr, +}; diff --git a/scripts/generate/templates/EnumerableMap.js b/scripts/generate/templates/EnumerableMap.js index 284e5ac02..557421888 100644 --- a/scripts/generate/templates/EnumerableMap.js +++ b/scripts/generate/templates/EnumerableMap.js @@ -1,6 +1,6 @@ const format = require('../format-lines'); const { fromBytes32, toBytes32 } = require('./conversion'); -const { TYPES } = require('./EnumerableMap.opts'); +const { MAP_TYPES } = require('./Enumerable.opts'); const header = `\ pragma solidity ^0.8.20; @@ -40,6 +40,7 @@ import {EnumerableSet} from "./EnumerableSet.sol"; * - \`address -> address\` (\`AddressToAddressMap\`) since v5.1.0 * - \`address -> bytes32\` (\`AddressToBytes32Map\`) since v5.1.0 * - \`bytes32 -> address\` (\`Bytes32ToAddressMap\`) since v5.1.0 + * - \`bytes -> bytes\` (\`BytesToBytesMap\`) since v5.4.0 * * [WARNING] * ==== @@ -176,7 +177,7 @@ function keys(Bytes32ToBytes32Map storage map) internal view returns (bytes32[] } `; -const customMap = ({ name, keyType, valueType }) => `\ +const customMap = ({ name, key, value }) => `\ // ${name} struct ${name} { @@ -190,8 +191,8 @@ struct ${name} { * Returns true if the key was added to the map, that is if it was not * already present. */ -function set(${name} storage map, ${keyType} key, ${valueType} value) internal returns (bool) { - return set(map._inner, ${toBytes32(keyType, 'key')}, ${toBytes32(valueType, 'value')}); +function set(${name} storage map, ${key.type} key, ${value.type} value) internal returns (bool) { + return set(map._inner, ${toBytes32(key.type, 'key')}, ${toBytes32(value.type, 'value')}); } /** @@ -199,8 +200,8 @@ function set(${name} storage map, ${keyType} key, ${valueType} value) internal r * * Returns true if the key was removed from the map, that is if it was present. */ -function remove(${name} storage map, ${keyType} key) internal returns (bool) { - return remove(map._inner, ${toBytes32(keyType, 'key')}); +function remove(${name} storage map, ${key.type} key) internal returns (bool) { + return remove(map._inner, ${toBytes32(key.type, 'key')}); } /** @@ -216,8 +217,8 @@ function clear(${name} storage map) internal { /** * @dev Returns true if the key is in the map. O(1). */ -function contains(${name} storage map, ${keyType} key) internal view returns (bool) { - return contains(map._inner, ${toBytes32(keyType, 'key')}); +function contains(${name} storage map, ${key.type} key) internal view returns (bool) { + return contains(map._inner, ${toBytes32(key.type, 'key')}); } /** @@ -236,18 +237,18 @@ function length(${name} storage map) internal view returns (uint256) { * * - \`index\` must be strictly less than {length}. */ -function at(${name} storage map, uint256 index) internal view returns (${keyType} key, ${valueType} value) { +function at(${name} storage map, uint256 index) internal view returns (${key.type} key, ${value.type} value) { (bytes32 atKey, bytes32 val) = at(map._inner, index); - return (${fromBytes32(keyType, 'atKey')}, ${fromBytes32(valueType, 'val')}); + return (${fromBytes32(key.type, 'atKey')}, ${fromBytes32(value.type, 'val')}); } /** * @dev Tries to returns the value associated with \`key\`. O(1). * Does not revert if \`key\` is not in the map. */ -function tryGet(${name} storage map, ${keyType} key) internal view returns (bool exists, ${valueType} value) { - (bool success, bytes32 val) = tryGet(map._inner, ${toBytes32(keyType, 'key')}); - return (success, ${fromBytes32(valueType, 'val')}); +function tryGet(${name} storage map, ${key.type} key) internal view returns (bool exists, ${value.type} value) { + (bool success, bytes32 val) = tryGet(map._inner, ${toBytes32(key.type, 'key')}); + return (success, ${fromBytes32(value.type, 'val')}); } /** @@ -257,8 +258,8 @@ function tryGet(${name} storage map, ${keyType} key) internal view returns (bool * * - \`key\` must be in the map. */ -function get(${name} storage map, ${keyType} key) internal view returns (${valueType}) { - return ${fromBytes32(valueType, `get(map._inner, ${toBytes32(keyType, 'key')})`)}; +function get(${name} storage map, ${key.type} key) internal view returns (${value.type}) { + return ${fromBytes32(value.type, `get(map._inner, ${toBytes32(key.type, 'key')})`)}; } /** @@ -269,9 +270,9 @@ function get(${name} storage map, ${keyType} key) internal view returns (${value * this function has an unbounded cost, and using it as part of a state-changing function may render the function * uncallable if the map grows to a point where copying to memory consumes too much gas to fit in a block. */ -function keys(${name} storage map) internal view returns (${keyType}[] memory) { +function keys(${name} storage map) internal view returns (${key.type}[] memory) { bytes32[] memory store = keys(map._inner); - ${keyType}[] memory result; + ${key.type}[] memory result; assembly ("memory-safe") { result := store @@ -281,16 +282,137 @@ function keys(${name} storage map) internal view returns (${keyType}[] memory) { } `; +const memoryMap = ({ name, keySet, key, value }) => `\ +/** + * @dev Query for a nonexistent map key. + */ +error EnumerableMapNonexistent${key.name}Key(${key.type} key); + +struct ${name} { + // Storage of keys + EnumerableSet.${keySet.name} _keys; + mapping(${key.type} key => ${value.type}) _values; +} + +/** + * @dev Adds a key-value pair to a map, or updates the value for an existing + * key. O(1). + * + * Returns true if the key was added to the map, that is if it was not + * already present. + */ +function set(${name} storage map, ${key.typeLoc} key, ${value.typeLoc} value) internal returns (bool) { + map._values[key] = value; + return map._keys.add(key); +} + +/** + * @dev Removes a key-value pair from a map. O(1). + * + * Returns true if the key was removed from the map, that is if it was present. + */ +function remove(${name} storage map, ${key.typeLoc} key) internal returns (bool) { + delete map._values[key]; + return map._keys.remove(key); +} + +/** + * @dev Removes all the entries from a map. O(n). + * + * WARNING: Developers should keep in mind that this function has an unbounded cost and using it may render the + * function uncallable if the map grows to the point where clearing it consumes too much gas to fit in a block. + */ +function clear(${name} storage map) internal { + uint256 len = length(map); + for (uint256 i = 0; i < len; ++i) { + delete map._values[map._keys.at(i)]; + } + map._keys.clear(); +} + +/** + * @dev Returns true if the key is in the map. O(1). + */ +function contains(${name} storage map, ${key.typeLoc} key) internal view returns (bool) { + return map._keys.contains(key); +} + +/** + * @dev Returns the number of key-value pairs in the map. O(1). + */ +function length(${name} storage map) internal view returns (uint256) { + return map._keys.length(); +} + +/** + * @dev Returns the key-value pair stored at position \`index\` in the map. O(1). + * + * Note that there are no guarantees on the ordering of entries inside the + * array, and it may change when more entries are added or removed. + * + * Requirements: + * + * - \`index\` must be strictly less than {length}. + */ +function at( + ${name} storage map, + uint256 index +) internal view returns (${key.typeLoc} key, ${value.typeLoc} value) { + key = map._keys.at(index); + value = map._values[key]; +} + +/** + * @dev Tries to returns the value associated with \`key\`. O(1). + * Does not revert if \`key\` is not in the map. + */ +function tryGet( + ${name} storage map, + ${key.typeLoc} key +) internal view returns (bool exists, ${value.typeLoc} value) { + value = map._values[key]; + exists = ${value.memory ? 'bytes(value).length != 0' : `value != ${value.type}(0)`} || contains(map, key); +} + +/** + * @dev Returns the value associated with \`key\`. O(1). + * + * Requirements: + * + * - \`key\` must be in the map. + */ +function get(${name} storage map, ${key.typeLoc} key) internal view returns (${value.typeLoc} value) { + bool exists; + (exists, value) = tryGet(map, key); + if (!exists) { + revert EnumerableMapNonexistent${key.name}Key(key); + } +} + +/** + * @dev Return the an array containing all the keys + * + * WARNING: This operation will copy the entire storage to memory, which can be quite expensive. This is designed + * to mostly be used by view accessors that are queried without any gas fees. Developers should keep in mind that + * this function has an unbounded cost, and using it as part of a state-changing function may render the function + * uncallable if the map grows to a point where copying to memory consumes too much gas to fit in a block. + */ +function keys(${name} storage map) internal view returns (${key.type}[] memory) { + return map._keys.values(); +} +`; + // GENERATE module.exports = format( header.trimEnd(), 'library EnumerableMap {', format( [].concat( - 'using EnumerableSet for EnumerableSet.Bytes32Set;', + 'using EnumerableSet for *;', '', defaultMap, - TYPES.map(details => customMap(details)), + MAP_TYPES.filter(({ key, value }) => !(key.memory || value.memory)).map(customMap), + MAP_TYPES.filter(({ key, value }) => key.memory || value.memory).map(memoryMap), ), ).trimEnd(), '}', diff --git a/scripts/generate/templates/EnumerableMap.opts.js b/scripts/generate/templates/EnumerableMap.opts.js deleted file mode 100644 index d26ab05b2..000000000 --- a/scripts/generate/templates/EnumerableMap.opts.js +++ /dev/null @@ -1,19 +0,0 @@ -const { capitalize } = require('../../helpers'); - -const mapType = str => (str == 'uint256' ? 'Uint' : capitalize(str)); - -const formatType = (keyType, valueType) => ({ - name: `${mapType(keyType)}To${mapType(valueType)}Map`, - keyType, - valueType, -}); - -const TYPES = ['uint256', 'address', 'bytes32'] - .flatMap((key, _, array) => array.map(value => [key, value])) - .slice(0, -1) // remove bytes32 → byte32 (last one) that is already defined - .map(args => formatType(...args)); - -module.exports = { - TYPES, - formatType, -}; diff --git a/scripts/generate/templates/EnumerableSet.js b/scripts/generate/templates/EnumerableSet.js index 3169d6a46..ac620b88a 100644 --- a/scripts/generate/templates/EnumerableSet.js +++ b/scripts/generate/templates/EnumerableSet.js @@ -1,6 +1,6 @@ const format = require('../format-lines'); const { fromBytes32, toBytes32 } = require('./conversion'); -const { TYPES } = require('./EnumerableSet.opts'); +const { SET_TYPES } = require('./Enumerable.opts'); const header = `\ pragma solidity ^0.8.20; @@ -29,8 +29,13 @@ import {Arrays} from "../Arrays.sol"; * } * \`\`\` * - * As of v3.3.0, sets of type \`bytes32\` (\`Bytes32Set\`), \`address\` (\`AddressSet\`) - * and \`uint256\` (\`UintSet\`) are supported. + * The following types are supported: + * + * - \`bytes32\` (\`Bytes32Set\`) since v3.3.0 + * - \`address\` (\`AddressSet\`) since v3.3.0 + * - \`uint256\` (\`UintSet\`) since v3.3.0 + * - \`string\` (\`StringSet\`) since v5.4.0 + * - \`bytes\` (\`BytesSet\`) since v5.4.0 * * [WARNING] * ==== @@ -44,6 +49,7 @@ import {Arrays} from "../Arrays.sol"; */ `; +// NOTE: this should be deprecated in favor of a more native construction in v6.0 const defaultSet = `\ // To implement this library for multiple types with as little code // repetition as possible, we write it in terms of a generic Set type with @@ -175,7 +181,8 @@ function _values(Set storage set) private view returns (bytes32[] memory) { } `; -const customSet = ({ name, type }) => `\ +// NOTE: this should be deprecated in favor of a more native construction in v6.0 +const customSet = ({ name, value: { type } }) => `\ // ${name} struct ${name} { @@ -260,6 +267,128 @@ function values(${name} storage set) internal view returns (${type}[] memory) { } `; +const memorySet = ({ name, value }) => `\ +struct ${name} { + // Storage of set values + ${value.type}[] _values; + // Position is the index of the value in the \`values\` array plus 1. + // Position 0 is used to mean a value is not in the set. + mapping(${value.type} value => uint256) _positions; +} + +/** + * @dev Add a value to a set. O(1). + * + * Returns true if the value was added to the set, that is if it was not + * already present. + */ +function add(${name} storage self, ${value.type} memory value) internal returns (bool) { + if (!contains(self, value)) { + self._values.push(value); + // The value is stored at length-1, but we add 1 to all indexes + // and use 0 as a sentinel value + self._positions[value] = self._values.length; + return true; + } else { + return false; + } +} + +/** + * @dev Removes a value from a set. O(1). + * + * Returns true if the value was removed from the set, that is if it was + * present. + */ +function remove(${name} storage self, ${value.type} memory value) internal returns (bool) { + // We cache the value's position to prevent multiple reads from the same storage slot + uint256 position = self._positions[value]; + + if (position != 0) { + // Equivalent to contains(self, value) + // To delete an element from the _values array in O(1), we swap the element to delete with the last one in + // the array, and then remove the last element (sometimes called as 'swap and pop'). + // This modifies the order of the array, as noted in {at}. + + uint256 valueIndex = position - 1; + uint256 lastIndex = self._values.length - 1; + + if (valueIndex != lastIndex) { + ${value.type} memory lastValue = self._values[lastIndex]; + + // Move the lastValue to the index where the value to delete is + self._values[valueIndex] = lastValue; + // Update the tracked position of the lastValue (that was just moved) + self._positions[lastValue] = position; + } + + // Delete the slot where the moved value was stored + self._values.pop(); + + // Delete the tracked position for the deleted slot + delete self._positions[value]; + + return true; + } else { + return false; + } +} + +/** + * @dev Removes all the values from a set. O(n). + * + * WARNING: Developers should keep in mind that this function has an unbounded cost and using it may render the + * function uncallable if the set grows to the point where clearing it consumes too much gas to fit in a block. + */ +function clear(${name} storage set) internal { + uint256 len = length(set); + for (uint256 i = 0; i < len; ++i) { + delete set._positions[set._values[i]]; + } + Arrays.unsafeSetLength(set._values, 0); +} + +/** + * @dev Returns true if the value is in the set. O(1). + */ +function contains(${name} storage self, ${value.type} memory value) internal view returns (bool) { + return self._positions[value] != 0; +} + +/** + * @dev Returns the number of values on the set. O(1). + */ +function length(${name} storage self) internal view returns (uint256) { + return self._values.length; +} + +/** + * @dev Returns the value stored at position \`index\` in the set. O(1). + * + * Note that there are no guarantees on the ordering of values inside the + * array, and it may change when more values are added or removed. + * + * Requirements: + * + * - \`index\` must be strictly less than {length}. + */ +function at(${name} storage self, uint256 index) internal view returns (${value.type} memory) { + return self._values[index]; +} + +/** + * @dev Return the entire set in an array + * + * WARNING: This operation will copy the entire storage to memory, which can be quite expensive. This is designed + * to mostly be used by view accessors that are queried without any gas fees. Developers should keep in mind that + * this function has an unbounded cost, and using it as part of a state-changing function may render the function + * uncallable if the set grows to a point where copying to memory consumes too much gas to fit in a block. + */ +function values(${name} storage self) internal view returns (${value.type}[] memory) { + return self._values; +} +`; + // GENERATE module.exports = format( header.trimEnd(), @@ -267,7 +396,8 @@ module.exports = format( format( [].concat( defaultSet, - TYPES.map(details => customSet(details)), + SET_TYPES.filter(({ value }) => !value.memory).map(customSet), + SET_TYPES.filter(({ value }) => value.memory).map(memorySet), ), ).trimEnd(), '}', diff --git a/scripts/generate/templates/EnumerableSet.opts.js b/scripts/generate/templates/EnumerableSet.opts.js deleted file mode 100644 index 739f0acdf..000000000 --- a/scripts/generate/templates/EnumerableSet.opts.js +++ /dev/null @@ -1,12 +0,0 @@ -const { capitalize } = require('../../helpers'); - -const mapType = str => (str == 'uint256' ? 'Uint' : capitalize(str)); - -const formatType = type => ({ - name: `${mapType(type)}Set`, - type, -}); - -const TYPES = ['bytes32', 'address', 'uint256'].map(formatType); - -module.exports = { TYPES, formatType }; diff --git a/test/utils/structs/EnumerableMap.behavior.js b/test/utils/structs/EnumerableMap.behavior.js index cf8de2e9d..11806dae6 100644 --- a/test/utils/structs/EnumerableMap.behavior.js +++ b/test/utils/structs/EnumerableMap.behavior.js @@ -176,7 +176,7 @@ function shouldBehaveLikeMap() { .withArgs( this.key?.memory || this.value?.memory ? this.keyB - : ethers.AbiCoder.defaultAbiCoder().encode([this.keyType], [this.keyB]), + : ethers.AbiCoder.defaultAbiCoder().encode([this.key.type], [this.keyB]), ); }); }); diff --git a/test/utils/structs/EnumerableMap.test.js b/test/utils/structs/EnumerableMap.test.js index cb4b77a65..7319ba38d 100644 --- a/test/utils/structs/EnumerableMap.test.js +++ b/test/utils/structs/EnumerableMap.test.js @@ -3,43 +3,58 @@ const { loadFixture } = require('@nomicfoundation/hardhat-network-helpers'); const { mapValues } = require('../../helpers/iterate'); const { generators } = require('../../helpers/random'); -const { TYPES, formatType } = require('../../../scripts/generate/templates/EnumerableMap.opts'); +const { MAP_TYPES, typeDescr, toMapTypeDescr } = require('../../../scripts/generate/templates/Enumerable.opts'); const { shouldBehaveLikeMap } = require('./EnumerableMap.behavior'); // Add Bytes32ToBytes32Map that must be tested but is not part of the generated types. -TYPES.unshift(formatType('bytes32', 'bytes32')); +MAP_TYPES.unshift(toMapTypeDescr({ key: typeDescr({ type: 'bytes32' }), value: typeDescr({ type: 'bytes32' }) })); async function fixture() { const mock = await ethers.deployContract('$EnumerableMap'); + const env = Object.fromEntries( - TYPES.map(({ name, keyType, valueType }) => [ + MAP_TYPES.map(({ name, key, value }) => [ name, { - keyType, - keys: Array.from({ length: 3 }, generators[keyType]), - values: Array.from({ length: 3 }, generators[valueType]), - zeroValue: generators[valueType].zero, + key, + value, + keys: Array.from({ length: 3 }, generators[key.type]), + values: Array.from({ length: 3 }, generators[value.type]), + zeroValue: generators[value.type].zero, methods: mapValues( - { - set: `$set(uint256,${keyType},${valueType})`, - get: `$get_EnumerableMap_${name}(uint256,${keyType})`, - tryGet: `$tryGet_EnumerableMap_${name}(uint256,${keyType})`, - remove: `$remove_EnumerableMap_${name}(uint256,${keyType})`, - clear: `$clear_EnumerableMap_${name}(uint256)`, - length: `$length_EnumerableMap_${name}(uint256)`, - at: `$at_EnumerableMap_${name}(uint256,uint256)`, - contains: `$contains_EnumerableMap_${name}(uint256,${keyType})`, - keys: `$keys_EnumerableMap_${name}(uint256)`, - }, + MAP_TYPES.filter(map => map.key.name == key.name).length == 1 + ? { + set: `$set(uint256,${key.type},${value.type})`, + get: `$get(uint256,${key.type})`, + tryGet: `$tryGet(uint256,${key.type})`, + remove: `$remove(uint256,${key.type})`, + contains: `$contains(uint256,${key.type})`, + clear: `$clear_EnumerableMap_${name}(uint256)`, + length: `$length_EnumerableMap_${name}(uint256)`, + at: `$at_EnumerableMap_${name}(uint256,uint256)`, + keys: `$keys_EnumerableMap_${name}(uint256)`, + } + : { + set: `$set(uint256,${key.type},${value.type})`, + get: `$get_EnumerableMap_${name}(uint256,${key.type})`, + tryGet: `$tryGet_EnumerableMap_${name}(uint256,${key.type})`, + remove: `$remove_EnumerableMap_${name}(uint256,${key.type})`, + contains: `$contains_EnumerableMap_${name}(uint256,${key.type})`, + clear: `$clear_EnumerableMap_${name}(uint256)`, + length: `$length_EnumerableMap_${name}(uint256)`, + at: `$at_EnumerableMap_${name}(uint256,uint256)`, + keys: `$keys_EnumerableMap_${name}(uint256)`, + }, fnSig => (...args) => mock.getFunction(fnSig)(0, ...args), ), events: { - setReturn: `return$set_EnumerableMap_${name}_${keyType}_${valueType}`, - removeReturn: `return$remove_EnumerableMap_${name}_${keyType}`, + setReturn: `return$set_EnumerableMap_${name}_${key.type}_${value.type}`, + removeReturn: `return$remove_EnumerableMap_${name}_${key.type}`, }, + error: key.memory || value.memory ? `EnumerableMapNonexistent${key.name}Key` : `EnumerableMapNonexistentKey`, }, ]), ); @@ -52,8 +67,8 @@ describe('EnumerableMap', function () { Object.assign(this, await loadFixture(fixture)); }); - for (const { name } of TYPES) { - describe(name, function () { + for (const { name, key, value } of MAP_TYPES) { + describe(`${name} (enumerable map from ${key.type} to ${value.type})`, function () { beforeEach(async function () { Object.assign(this, this.env[name]); [this.keyA, this.keyB, this.keyC] = this.keys; diff --git a/test/utils/structs/EnumerableSet.test.js b/test/utils/structs/EnumerableSet.test.js index 1f92727a4..e111d2197 100644 --- a/test/utils/structs/EnumerableSet.test.js +++ b/test/utils/structs/EnumerableSet.test.js @@ -3,39 +3,42 @@ const { loadFixture } = require('@nomicfoundation/hardhat-network-helpers'); const { mapValues } = require('../../helpers/iterate'); const { generators } = require('../../helpers/random'); -const { TYPES } = require('../../../scripts/generate/templates/EnumerableSet.opts'); +const { SET_TYPES } = require('../../../scripts/generate/templates/Enumerable.opts'); const { shouldBehaveLikeSet } = require('./EnumerableSet.behavior'); -const getMethods = (mock, fnSigs) => { - return mapValues( +const getMethods = (mock, fnSigs) => + mapValues( fnSigs, fnSig => (...args) => mock.getFunction(fnSig)(0, ...args), ); -}; async function fixture() { const mock = await ethers.deployContract('$EnumerableSet'); const env = Object.fromEntries( - TYPES.map(({ name, type }) => [ - type, + SET_TYPES.map(({ name, value }) => [ + name, { - values: Array.from({ length: 3 }, generators[type]), + value, + values: Array.from( + { length: 3 }, + value.size ? () => Array.from({ length: value.size }, generators[value.base]) : generators[value.type], + ), methods: getMethods(mock, { - add: `$add(uint256,${type})`, - remove: `$remove(uint256,${type})`, + add: `$add(uint256,${value.type})`, + remove: `$remove(uint256,${value.type})`, + contains: `$contains(uint256,${value.type})`, clear: `$clear_EnumerableSet_${name}(uint256)`, - contains: `$contains(uint256,${type})`, length: `$length_EnumerableSet_${name}(uint256)`, at: `$at_EnumerableSet_${name}(uint256,uint256)`, values: `$values_EnumerableSet_${name}(uint256)`, }), events: { - addReturn: `return$add_EnumerableSet_${name}_${type}`, - removeReturn: `return$remove_EnumerableSet_${name}_${type}`, + addReturn: `return$add_EnumerableSet_${name}_${value.type.replace(/[[\]]/g, '_')}`, + removeReturn: `return$remove_EnumerableSet_${name}_${value.type.replace(/[[\]]/g, '_')}`, }, }, ]), @@ -49,10 +52,10 @@ describe('EnumerableSet', function () { Object.assign(this, await loadFixture(fixture)); }); - for (const { type } of TYPES) { - describe(type, function () { + for (const { name, value } of SET_TYPES) { + describe(`${name} (enumerable set of ${value.type})`, function () { beforeEach(function () { - Object.assign(this, this.env[type]); + Object.assign(this, this.env[name]); [this.valueA, this.valueB, this.valueC] = this.values; });