From 441dc141ac99622de7e535fa75dfc74af939019c Mon Sep 17 00:00:00 2001 From: Hadrien Croubois Date: Tue, 4 Feb 2025 20:30:53 +0100 Subject: [PATCH] Add Bytes32x2Set (#5442) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Ernesto GarcĂ­a --- .changeset/lucky-teachers-sip.md | 5 + .changeset/ten-peas-mix.md | 5 + contracts/utils/cryptography/Hashes.sol | 4 +- contracts/utils/structs/EnumerableSet.sol | 113 ++++++++++++++++ scripts/generate/templates/EnumerableSet.js | 121 +++++++++++++++++- .../generate/templates/EnumerableSet.opts.js | 14 +- test/utils/structs/EnumerableSet.test.js | 11 +- 7 files changed, 261 insertions(+), 12 deletions(-) create mode 100644 .changeset/lucky-teachers-sip.md create mode 100644 .changeset/ten-peas-mix.md diff --git a/.changeset/lucky-teachers-sip.md b/.changeset/lucky-teachers-sip.md new file mode 100644 index 000000000..fab22e266 --- /dev/null +++ b/.changeset/lucky-teachers-sip.md @@ -0,0 +1,5 @@ +--- +'openzeppelin-solidity': minor +--- + +`EnumerableSet`: Add `Bytes32x2Set` that handles (ordered) pairs of bytes32. diff --git a/.changeset/ten-peas-mix.md b/.changeset/ten-peas-mix.md new file mode 100644 index 000000000..4e7ae24b0 --- /dev/null +++ b/.changeset/ten-peas-mix.md @@ -0,0 +1,5 @@ +--- +'openzeppelin-solidity': minor +--- + +`Hashes`: Expose `efficientKeccak256` for hashing non-commutative pairs of bytes32 without allocating extra memory. diff --git a/contracts/utils/cryptography/Hashes.sol b/contracts/utils/cryptography/Hashes.sol index 893883164..6b7168e87 100644 --- a/contracts/utils/cryptography/Hashes.sol +++ b/contracts/utils/cryptography/Hashes.sol @@ -15,13 +15,13 @@ library Hashes { * NOTE: Equivalent to the `standardNodeHash` in our https://github.com/OpenZeppelin/merkle-tree[JavaScript library]. */ function commutativeKeccak256(bytes32 a, bytes32 b) internal pure returns (bytes32) { - return a < b ? _efficientKeccak256(a, b) : _efficientKeccak256(b, a); + return a < b ? efficientKeccak256(a, b) : efficientKeccak256(b, a); } /** * @dev Implementation of keccak256(abi.encode(a, b)) that doesn't allocate or expand memory. */ - function _efficientKeccak256(bytes32 a, bytes32 b) private pure returns (bytes32 value) { + function efficientKeccak256(bytes32 a, bytes32 b) internal pure returns (bytes32 value) { assembly ("memory-safe") { mstore(0x00, a) mstore(0x20, b) diff --git a/contracts/utils/structs/EnumerableSet.sol b/contracts/utils/structs/EnumerableSet.sol index 065202e82..4aa27da17 100644 --- a/contracts/utils/structs/EnumerableSet.sol +++ b/contracts/utils/structs/EnumerableSet.sol @@ -4,6 +4,8 @@ pragma solidity ^0.8.20; +import {Hashes} from "../cryptography/Hashes.sol"; + /** * @dev Library for managing * https://en.wikipedia.org/wiki/Set_(abstract_data_type)[sets] of primitive @@ -372,4 +374,115 @@ library EnumerableSet { return result; } + + struct Bytes32x2Set { + // Storage of set values + bytes32[2][] _values; + // Position is the index of the value in the `values` array plus 1. + // Position 0 is used to mean a value is not in the self. + mapping(bytes32 valueHash => uint256) _positions; + } + + /** + * @dev Add a value to a self. O(1). + * + * Returns true if the value was added to the set, that is if it was not + * already present. + */ + function add(Bytes32x2Set storage self, bytes32[2] memory value) internal returns (bool) { + if (!contains(self, value)) { + self._values.push(value); + // The value is stored at length-1, but we add 1 to all indexes + // and use 0 as a sentinel value + self._positions[_hash(value)] = self._values.length; + return true; + } else { + return false; + } + } + + /** + * @dev Removes a value from a self. O(1). + * + * Returns true if the value was removed from the set, that is if it was + * present. + */ + function remove(Bytes32x2Set storage self, bytes32[2] memory value) internal returns (bool) { + // We cache the value's position to prevent multiple reads from the same storage slot + bytes32 valueHash = _hash(value); + uint256 position = self._positions[valueHash]; + + if (position != 0) { + // Equivalent to contains(self, value) + // To delete an element from the _values array in O(1), we swap the element to delete with the last one in + // the array, and then remove the last element (sometimes called as 'swap and pop'). + // This modifies the order of the array, as noted in {at}. + + uint256 valueIndex = position - 1; + uint256 lastIndex = self._values.length - 1; + + if (valueIndex != lastIndex) { + bytes32[2] memory lastValue = self._values[lastIndex]; + + // Move the lastValue to the index where the value to delete is + self._values[valueIndex] = lastValue; + // Update the tracked position of the lastValue (that was just moved) + self._positions[_hash(lastValue)] = position; + } + + // Delete the slot where the moved value was stored + self._values.pop(); + + // Delete the tracked position for the deleted slot + delete self._positions[valueHash]; + + return true; + } else { + return false; + } + } + + /** + * @dev Returns true if the value is in the self. O(1). + */ + function contains(Bytes32x2Set storage self, bytes32[2] memory value) internal view returns (bool) { + return self._positions[_hash(value)] != 0; + } + + /** + * @dev Returns the number of values on the self. O(1). + */ + function length(Bytes32x2Set storage self) internal view returns (uint256) { + return self._values.length; + } + + /** + * @dev Returns the value stored at position `index` in the self. O(1). + * + * Note that there are no guarantees on the ordering of values inside the + * array, and it may change when more values are added or removed. + * + * Requirements: + * + * - `index` must be strictly less than {length}. + */ + function at(Bytes32x2Set storage self, uint256 index) internal view returns (bytes32[2] memory) { + return self._values[index]; + } + + /** + * @dev Return the entire set in an array + * + * WARNING: This operation will copy the entire storage to memory, which can be quite expensive. This is designed + * to mostly be used by view accessors that are queried without any gas fees. Developers should keep in mind that + * this function has an unbounded cost, and using it as part of a state-changing function may render the function + * uncallable if the set grows to a point where copying to memory consumes too much gas to fit in a block. + */ + function values(Bytes32x2Set storage self) internal view returns (bytes32[2][] memory) { + return self._values; + } + + function _hash(bytes32[2] memory value) private pure returns (bytes32) { + return Hashes.efficientKeccak256(value[0], value[1]); + } } diff --git a/scripts/generate/templates/EnumerableSet.js b/scripts/generate/templates/EnumerableSet.js index 02eccd0df..cf9a1fc67 100644 --- a/scripts/generate/templates/EnumerableSet.js +++ b/scripts/generate/templates/EnumerableSet.js @@ -5,6 +5,8 @@ const { TYPES } = require('./EnumerableSet.opts'); const header = `\ pragma solidity ^0.8.20; +import {Hashes} from "../cryptography/Hashes.sol"; + /** * @dev Library for managing * https://en.wikipedia.org/wiki/Set_(abstract_data_type)[sets] of primitive @@ -233,6 +235,121 @@ function values(${name} storage set) internal view returns (${type}[] memory) { } `; +const memorySet = ({ name, type }) => `\ +struct ${name} { + // Storage of set values + ${type}[] _values; + // Position is the index of the value in the \`values\` array plus 1. + // Position 0 is used to mean a value is not in the self. + mapping(bytes32 valueHash => uint256) _positions; +} + +/** + * @dev Add a value to a self. O(1). + * + * Returns true if the value was added to the set, that is if it was not + * already present. + */ +function add(${name} storage self, ${type} memory value) internal returns (bool) { + if (!contains(self, value)) { + self._values.push(value); + // The value is stored at length-1, but we add 1 to all indexes + // and use 0 as a sentinel value + self._positions[_hash(value)] = self._values.length; + return true; + } else { + return false; + } +} + +/** + * @dev Removes a value from a self. O(1). + * + * Returns true if the value was removed from the set, that is if it was + * present. + */ +function remove(${name} storage self, ${type} memory value) internal returns (bool) { + // We cache the value's position to prevent multiple reads from the same storage slot + bytes32 valueHash = _hash(value); + uint256 position = self._positions[valueHash]; + + if (position != 0) { + // Equivalent to contains(self, value) + // To delete an element from the _values array in O(1), we swap the element to delete with the last one in + // the array, and then remove the last element (sometimes called as 'swap and pop'). + // This modifies the order of the array, as noted in {at}. + + uint256 valueIndex = position - 1; + uint256 lastIndex = self._values.length - 1; + + if (valueIndex != lastIndex) { + ${type} memory lastValue = self._values[lastIndex]; + + // Move the lastValue to the index where the value to delete is + self._values[valueIndex] = lastValue; + // Update the tracked position of the lastValue (that was just moved) + self._positions[_hash(lastValue)] = position; + } + + // Delete the slot where the moved value was stored + self._values.pop(); + + // Delete the tracked position for the deleted slot + delete self._positions[valueHash]; + + return true; + } else { + return false; + } +} + +/** + * @dev Returns true if the value is in the self. O(1). + */ +function contains(${name} storage self, ${type} memory value) internal view returns (bool) { + return self._positions[_hash(value)] != 0; +} + +/** + * @dev Returns the number of values on the self. O(1). + */ +function length(${name} storage self) internal view returns (uint256) { + return self._values.length; +} + +/** + * @dev Returns the value stored at position \`index\` in the self. O(1). + * + * Note that there are no guarantees on the ordering of values inside the + * array, and it may change when more values are added or removed. + * + * Requirements: + * + * - \`index\` must be strictly less than {length}. + */ +function at(${name} storage self, uint256 index) internal view returns (${type} memory) { + return self._values[index]; +} + +/** + * @dev Return the entire set in an array + * + * WARNING: This operation will copy the entire storage to memory, which can be quite expensive. This is designed + * to mostly be used by view accessors that are queried without any gas fees. Developers should keep in mind that + * this function has an unbounded cost, and using it as part of a state-changing function may render the function + * uncallable if the set grows to a point where copying to memory consumes too much gas to fit in a block. + */ +function values(${name} storage self) internal view returns (${type}[] memory) { + return self._values; +} +`; + +const hashes = `\ +function _hash(bytes32[2] memory value) private pure returns (bytes32) { + return Hashes.efficientKeccak256(value[0], value[1]); +} +`; + // GENERATE module.exports = format( header.trimEnd(), @@ -240,7 +357,9 @@ module.exports = format( format( [].concat( defaultSet, - TYPES.map(details => customSet(details)), + TYPES.filter(({ size }) => size == undefined).map(details => customSet(details)), + TYPES.filter(({ size }) => size != undefined).map(details => memorySet(details)), + hashes, ), ).trimEnd(), '}', diff --git a/scripts/generate/templates/EnumerableSet.opts.js b/scripts/generate/templates/EnumerableSet.opts.js index 739f0acdf..a8173f3cf 100644 --- a/scripts/generate/templates/EnumerableSet.opts.js +++ b/scripts/generate/templates/EnumerableSet.opts.js @@ -1,12 +1,16 @@ const { capitalize } = require('../../helpers'); -const mapType = str => (str == 'uint256' ? 'Uint' : capitalize(str)); +const mapType = ({ type, size }) => [type == 'uint256' ? 'Uint' : capitalize(type), size].filter(Boolean).join('x'); -const formatType = type => ({ - name: `${mapType(type)}Set`, - type, +const formatType = ({ type, size = undefined }) => ({ + name: `${mapType({ type, size })}Set`, + type: size != undefined ? `${type}[${size}]` : type, + base: size != undefined ? type : undefined, + size, }); -const TYPES = ['bytes32', 'address', 'uint256'].map(formatType); +const TYPES = [{ type: 'bytes32' }, { type: 'bytes32', size: 2 }, { type: 'address' }, { type: 'uint256' }].map( + formatType, +); module.exports = { TYPES, formatType }; diff --git a/test/utils/structs/EnumerableSet.test.js b/test/utils/structs/EnumerableSet.test.js index 66d666058..cccc4a3e0 100644 --- a/test/utils/structs/EnumerableSet.test.js +++ b/test/utils/structs/EnumerableSet.test.js @@ -20,10 +20,13 @@ async function fixture() { const mock = await ethers.deployContract('$EnumerableSet'); const env = Object.fromEntries( - TYPES.map(({ name, type }) => [ + TYPES.map(({ name, type, base, size }) => [ type, { - values: Array.from({ length: 3 }, generators[type]), + values: Array.from( + { length: 3 }, + size ? () => Array.from({ length: size }, generators[base]) : generators[type], + ), methods: getMethods(mock, { add: `$add(uint256,${type})`, remove: `$remove(uint256,${type})`, @@ -33,8 +36,8 @@ async function fixture() { values: `$values_EnumerableSet_${name}(uint256)`, }), events: { - addReturn: `return$add_EnumerableSet_${name}_${type}`, - removeReturn: `return$remove_EnumerableSet_${name}_${type}`, + addReturn: `return$add_EnumerableSet_${name}_${type.replace(/[[\]]/g, '_')}`, + removeReturn: `return$remove_EnumerableSet_${name}_${type.replace(/[[\]]/g, '_')}`, }, }, ]),