Clean dirty addresses and booleans (#5195)

Co-authored-by: Hadrien Croubois <hadrien.croubois@gmail.com>
This commit is contained in:
cairo
2024-09-18 18:21:17 +02:00
committed by GitHub
parent 809ded806f
commit 3f901696f7
7 changed files with 50 additions and 7 deletions

View File

@ -70,7 +70,7 @@ library SlotDerivation {
*/ */
function deriveMapping(bytes32 slot, address key) internal pure returns (bytes32 result) { function deriveMapping(bytes32 slot, address key) internal pure returns (bytes32 result) {
assembly ("memory-safe") { assembly ("memory-safe") {
mstore(0x00, key) mstore(0x00, and(key, shr(96, not(0))))
mstore(0x20, slot) mstore(0x20, slot)
result := keccak256(0x00, 0x40) result := keccak256(0x00, 0x40)
} }
@ -81,7 +81,7 @@ library SlotDerivation {
*/ */
function deriveMapping(bytes32 slot, bool key) internal pure returns (bytes32 result) { function deriveMapping(bytes32 slot, bool key) internal pure returns (bytes32 result) {
assembly ("memory-safe") { assembly ("memory-safe") {
mstore(0x00, key) mstore(0x00, iszero(iszero(key)))
mstore(0x20, slot) mstore(0x20, slot)
result := keccak256(0x00, 0x40) result := keccak256(0x00, 0x40)
} }

View File

@ -0,0 +1,5 @@
module.exports = {
address: expr => `and(${expr}, shr(96, not(0)))`,
bool: expr => `iszero(iszero(${expr}))`,
bytes: (expr, size) => `and(${expr}, shl(${256 - 8 * size}, not(0)))`,
};

View File

@ -1,4 +1,5 @@
const format = require('../format-lines'); const format = require('../format-lines');
const sanitize = require('../helpers/sanitize');
const { product } = require('../../helpers'); const { product } = require('../../helpers');
const { SIZES } = require('./Packing.opts'); const { SIZES } = require('./Packing.opts');
@ -44,8 +45,8 @@ function pack_${left}_${right}(bytes${left} left, bytes${right} right) internal
left + right left + right
} result) { } result) {
assembly ("memory-safe") { assembly ("memory-safe") {
left := and(left, shl(${256 - 8 * left}, not(0))) left := ${sanitize.bytes('left', left)}
right := and(right, shl(${256 - 8 * right}, not(0))) right := ${sanitize.bytes('right', right)}
result := or(left, shr(${8 * left}, right)) result := or(left, shr(${8 * left}, right))
} }
} }
@ -55,7 +56,7 @@ const extract = (outer, inner) => `\
function extract_${outer}_${inner}(bytes${outer} self, uint8 offset) internal pure returns (bytes${inner} result) { function extract_${outer}_${inner}(bytes${outer} self, uint8 offset) internal pure returns (bytes${inner} result) {
if (offset > ${outer - inner}) revert OutOfRangeAccess(); if (offset > ${outer - inner}) revert OutOfRangeAccess();
assembly ("memory-safe") { assembly ("memory-safe") {
result := and(shl(mul(8, offset), self), shl(${256 - 8 * inner}, not(0))) result := ${sanitize.bytes('shl(mul(8, offset), self)', inner)}
} }
} }
`; `;
@ -64,7 +65,7 @@ const replace = (outer, inner) => `\
function replace_${outer}_${inner}(bytes${outer} self, bytes${inner} value, uint8 offset) internal pure returns (bytes${outer} result) { function replace_${outer}_${inner}(bytes${outer} self, bytes${inner} value, uint8 offset) internal pure returns (bytes${outer} result) {
bytes${inner} oldValue = extract_${outer}_${inner}(self, offset); bytes${inner} oldValue = extract_${outer}_${inner}(self, offset);
assembly ("memory-safe") { assembly ("memory-safe") {
value := and(value, shl(${256 - 8 * inner}, not(0))) value := ${sanitize.bytes('value', inner)}
result := xor(self, shr(mul(8, offset), xor(oldValue, value))) result := xor(self, shr(mul(8, offset), xor(oldValue, value)))
} }
} }

View File

@ -10,4 +10,6 @@ const TYPES = [
{ type: 'bytes', isValueType: false }, { type: 'bytes', isValueType: false },
].map(type => Object.assign(type, { name: type.name ?? capitalize(type.type) })); ].map(type => Object.assign(type, { name: type.name ?? capitalize(type.type) }));
Object.assign(TYPES, Object.fromEntries(TYPES.map(entry => [entry.type, entry])));
module.exports = { TYPES }; module.exports = { TYPES };

View File

@ -1,4 +1,5 @@
const format = require('../format-lines'); const format = require('../format-lines');
const sanitize = require('../helpers/sanitize');
const { TYPES } = require('./Slot.opts'); const { TYPES } = require('./Slot.opts');
const header = `\ const header = `\
@ -77,7 +78,7 @@ const mapping = ({ type }) => `\
*/ */
function deriveMapping(bytes32 slot, ${type} key) internal pure returns (bytes32 result) { function deriveMapping(bytes32 slot, ${type} key) internal pure returns (bytes32 result) {
assembly ("memory-safe") { assembly ("memory-safe") {
mstore(0x00, key) mstore(0x00, ${(sanitize[type] ?? (x => x))('key')})
mstore(0x20, slot) mstore(0x20, slot)
result := keccak256(0x00, 0x40) result := keccak256(0x00, 0x40)
} }

View File

@ -61,6 +61,18 @@ function testSymbolicDeriveMapping${name}(${type} key) public {
} }
`; `;
const mappingDirty = ({ type, name }) => `\
function testSymbolicDeriveMapping${name}Dirty(bytes32 dirtyKey) public {
${type} key;
assembly {
key := dirtyKey
}
// run the "normal" test using a potentially dirty value
testSymbolicDeriveMapping${name}(key);
}
`;
const boundedMapping = ({ type, name }) => `\ const boundedMapping = ({ type, name }) => `\
mapping(${type} => bytes) private _${type}Mapping; mapping(${type} => bytes) private _${type}Mapping;
@ -107,6 +119,8 @@ module.exports = format(
})), })),
), ),
).map(type => (type.isValueType ? mapping(type) : boundedMapping(type))), ).map(type => (type.isValueType ? mapping(type) : boundedMapping(type))),
mappingDirty(TYPES.bool),
mappingDirty(TYPES.address),
), ),
).trimEnd(), ).trimEnd(),
'}', '}',

View File

@ -225,4 +225,24 @@ contract SlotDerivationTest is Test, SymTest {
assertEq(baseSlot.deriveMapping(key), derivedSlot); assertEq(baseSlot.deriveMapping(key), derivedSlot);
} }
function testSymbolicDeriveMappingBooleanDirty(bytes32 dirtyKey) public {
bool key;
assembly {
key := dirtyKey
}
// run the "normal" test using a potentially dirty value
testSymbolicDeriveMappingBoolean(key);
}
function testSymbolicDeriveMappingAddressDirty(bytes32 dirtyKey) public {
address key;
assembly {
key := dirtyKey
}
// run the "normal" test using a potentially dirty value
testSymbolicDeriveMappingAddress(key);
}
} }