diff --git a/contracts/utils/SlotDerivation.sol b/contracts/utils/SlotDerivation.sol index 62c28a55f..c248edc01 100644 --- a/contracts/utils/SlotDerivation.sol +++ b/contracts/utils/SlotDerivation.sol @@ -70,7 +70,7 @@ library SlotDerivation { */ function deriveMapping(bytes32 slot, address key) internal pure returns (bytes32 result) { assembly ("memory-safe") { - mstore(0x00, key) + mstore(0x00, and(key, shr(96, not(0)))) mstore(0x20, slot) result := keccak256(0x00, 0x40) } @@ -81,7 +81,7 @@ library SlotDerivation { */ function deriveMapping(bytes32 slot, bool key) internal pure returns (bytes32 result) { assembly ("memory-safe") { - mstore(0x00, key) + mstore(0x00, iszero(iszero(key))) mstore(0x20, slot) result := keccak256(0x00, 0x40) } diff --git a/scripts/generate/helpers/sanitize.js b/scripts/generate/helpers/sanitize.js new file mode 100644 index 000000000..e680ec1bf --- /dev/null +++ b/scripts/generate/helpers/sanitize.js @@ -0,0 +1,5 @@ +module.exports = { + address: expr => `and(${expr}, shr(96, not(0)))`, + bool: expr => `iszero(iszero(${expr}))`, + bytes: (expr, size) => `and(${expr}, shl(${256 - 8 * size}, not(0)))`, +}; diff --git a/scripts/generate/templates/Packing.js b/scripts/generate/templates/Packing.js index b9422ef48..d841c2f81 100644 --- a/scripts/generate/templates/Packing.js +++ b/scripts/generate/templates/Packing.js @@ -1,4 +1,5 @@ const format = require('../format-lines'); +const sanitize = require('../helpers/sanitize'); const { product } = require('../../helpers'); const { SIZES } = require('./Packing.opts'); @@ -44,8 +45,8 @@ function pack_${left}_${right}(bytes${left} left, bytes${right} right) internal left + right } result) { assembly ("memory-safe") { - left := and(left, shl(${256 - 8 * left}, not(0))) - right := and(right, shl(${256 - 8 * right}, not(0))) + left := ${sanitize.bytes('left', left)} + right := ${sanitize.bytes('right', right)} result := or(left, shr(${8 * left}, right)) } } @@ -55,7 +56,7 @@ const extract = (outer, inner) => `\ function extract_${outer}_${inner}(bytes${outer} self, uint8 offset) internal pure returns (bytes${inner} result) { if (offset > ${outer - inner}) revert OutOfRangeAccess(); assembly ("memory-safe") { - result := and(shl(mul(8, offset), self), shl(${256 - 8 * inner}, not(0))) + result := ${sanitize.bytes('shl(mul(8, offset), self)', inner)} } } `; @@ -64,7 +65,7 @@ const replace = (outer, inner) => `\ function replace_${outer}_${inner}(bytes${outer} self, bytes${inner} value, uint8 offset) internal pure returns (bytes${outer} result) { bytes${inner} oldValue = extract_${outer}_${inner}(self, offset); assembly ("memory-safe") { - value := and(value, shl(${256 - 8 * inner}, not(0))) + value := ${sanitize.bytes('value', inner)} result := xor(self, shr(mul(8, offset), xor(oldValue, value))) } } diff --git a/scripts/generate/templates/Slot.opts.js b/scripts/generate/templates/Slot.opts.js index aed1f9888..3eca2bcf0 100644 --- a/scripts/generate/templates/Slot.opts.js +++ b/scripts/generate/templates/Slot.opts.js @@ -10,4 +10,6 @@ const TYPES = [ { type: 'bytes', isValueType: false }, ].map(type => Object.assign(type, { name: type.name ?? capitalize(type.type) })); +Object.assign(TYPES, Object.fromEntries(TYPES.map(entry => [entry.type, entry]))); + module.exports = { TYPES }; diff --git a/scripts/generate/templates/SlotDerivation.js b/scripts/generate/templates/SlotDerivation.js index 5311fb3c8..39d0d9e35 100644 --- a/scripts/generate/templates/SlotDerivation.js +++ b/scripts/generate/templates/SlotDerivation.js @@ -1,4 +1,5 @@ const format = require('../format-lines'); +const sanitize = require('../helpers/sanitize'); const { TYPES } = require('./Slot.opts'); const header = `\ @@ -77,7 +78,7 @@ const mapping = ({ type }) => `\ */ function deriveMapping(bytes32 slot, ${type} key) internal pure returns (bytes32 result) { assembly ("memory-safe") { - mstore(0x00, key) + mstore(0x00, ${(sanitize[type] ?? (x => x))('key')}) mstore(0x20, slot) result := keccak256(0x00, 0x40) } diff --git a/scripts/generate/templates/SlotDerivation.t.js b/scripts/generate/templates/SlotDerivation.t.js index dc7b07ff4..f03e1fc25 100644 --- a/scripts/generate/templates/SlotDerivation.t.js +++ b/scripts/generate/templates/SlotDerivation.t.js @@ -61,6 +61,18 @@ function testSymbolicDeriveMapping${name}(${type} key) public { } `; +const mappingDirty = ({ type, name }) => `\ +function testSymbolicDeriveMapping${name}Dirty(bytes32 dirtyKey) public { + ${type} key; + assembly { + key := dirtyKey + } + + // run the "normal" test using a potentially dirty value + testSymbolicDeriveMapping${name}(key); +} +`; + const boundedMapping = ({ type, name }) => `\ mapping(${type} => bytes) private _${type}Mapping; @@ -107,6 +119,8 @@ module.exports = format( })), ), ).map(type => (type.isValueType ? mapping(type) : boundedMapping(type))), + mappingDirty(TYPES.bool), + mappingDirty(TYPES.address), ), ).trimEnd(), '}', diff --git a/test/utils/SlotDerivation.t.sol b/test/utils/SlotDerivation.t.sol index 300a58b5f..4021f0f87 100644 --- a/test/utils/SlotDerivation.t.sol +++ b/test/utils/SlotDerivation.t.sol @@ -225,4 +225,24 @@ contract SlotDerivationTest is Test, SymTest { assertEq(baseSlot.deriveMapping(key), derivedSlot); } + + function testSymbolicDeriveMappingBooleanDirty(bytes32 dirtyKey) public { + bool key; + assembly { + key := dirtyKey + } + + // run the "normal" test using a potentially dirty value + testSymbolicDeriveMappingBoolean(key); + } + + function testSymbolicDeriveMappingAddressDirty(bytes32 dirtyKey) public { + address key; + assembly { + key := dirtyKey + } + + // run the "normal" test using a potentially dirty value + testSymbolicDeriveMappingAddress(key); + } }