diff --git a/.github/workflows/checks.yml b/.github/workflows/checks.yml index c6b42dd82..2d5c15159 100644 --- a/.github/workflows/checks.yml +++ b/.github/workflows/checks.yml @@ -86,7 +86,7 @@ jobs: - run: rm foundry.toml - uses: crytic/slither-action@v0.3.0 with: - node-version: 18 + node-version: 18.15 codespell: if: github.repository != 'OpenZeppelin/openzeppelin-contracts-upgradeable' diff --git a/.github/workflows/formal-verification.yml b/.github/workflows/formal-verification.yml index 29c02541c..ae5eba006 100644 --- a/.github/workflows/formal-verification.yml +++ b/.github/workflows/formal-verification.yml @@ -1,10 +1,6 @@ name: formal verification on: - push: - branches: - - master - - release-v* pull_request: types: - opened @@ -33,8 +29,20 @@ jobs: if: github.event_name != 'pull_request' || contains(github.event.pull_request.labels.*.name, 'formal-verification') steps: - uses: actions/checkout@v3 + with: + fetch-depth: 0 - name: Set up environment uses: ./.github/actions/setup + - name: identify specs that need to be run + id: arguments + run: | + if [[ ${{ github.event_name }} = 'pull_request' ]]; + then + RESULT=$(git diff ${{ github.event.pull_request.head.sha }}..${{ github.event.pull_request.base.sha }} --name-only certora/specs/*.spec | while IFS= read -r file; do [[ -f $file ]] && basename "${file%.spec}"; done | tr "\n" " ") + else + RESULT='--all' + fi + echo "result=$RESULT" >> "$GITHUB_OUTPUT" - name: Install python uses: actions/setup-python@v4 with: @@ -55,6 +63,6 @@ jobs: - name: Verify specification run: | make -C certora apply - node certora/run.js >> "$GITHUB_STEP_SUMMARY" + node certora/run.js ${{ steps.arguments.outputs.result }} >> "$GITHUB_STEP_SUMMARY" env: CERTORAKEY: ${{ secrets.CERTORAKEY }} diff --git a/certora/harnesses/EnumerableMapHarness.sol b/certora/harnesses/EnumerableMapHarness.sol new file mode 100644 index 000000000..3bcf1b50a --- /dev/null +++ b/certora/harnesses/EnumerableMapHarness.sol @@ -0,0 +1,55 @@ +// SPDX-License-Identifier: MIT + +pragma solidity ^0.8.0; + +import "../patched/utils/structs/EnumerableMap.sol"; + +contract EnumerableMapHarness { + using EnumerableMap for EnumerableMap.Bytes32ToBytes32Map; + + EnumerableMap.Bytes32ToBytes32Map private _map; + + function set(bytes32 key, bytes32 value) public returns (bool) { + return _map.set(key, value); + } + + function remove(bytes32 key) public returns (bool) { + return _map.remove(key); + } + + function contains(bytes32 key) public view returns (bool) { + return _map.contains(key); + } + + function length() public view returns (uint256) { + return _map.length(); + } + + function key_at(uint256 index) public view returns (bytes32) { + (bytes32 key,) = _map.at(index); + return key; + } + + function value_at(uint256 index) public view returns (bytes32) { + (,bytes32 value) = _map.at(index); + return value; + } + + function tryGet_contains(bytes32 key) public view returns (bool) { + (bool contained,) = _map.tryGet(key); + return contained; + } + + function tryGet_value(bytes32 key) public view returns (bytes32) { + (,bytes32 value) = _map.tryGet(key); + return value; + } + + function get(bytes32 key) public view returns (bytes32) { + return _map.get(key); + } + + function _indexOf(bytes32 key) public view returns (uint256) { + return _map._keys._inner._indexes[key]; + } +} diff --git a/certora/harnesses/EnumerableSetHarness.sol b/certora/harnesses/EnumerableSetHarness.sol new file mode 100644 index 000000000..64383e6a4 --- /dev/null +++ b/certora/harnesses/EnumerableSetHarness.sol @@ -0,0 +1,35 @@ +// SPDX-License-Identifier: MIT + +pragma solidity ^0.8.0; + +import "../patched/utils/structs/EnumerableSet.sol"; + +contract EnumerableSetHarness { + using EnumerableSet for EnumerableSet.Bytes32Set; + + EnumerableSet.Bytes32Set private _set; + + function add(bytes32 value) public returns (bool) { + return _set.add(value); + } + + function remove(bytes32 value) public returns (bool) { + return _set.remove(value); + } + + function contains(bytes32 value) public view returns (bool) { + return _set.contains(value); + } + + function length() public view returns (uint256) { + return _set.length(); + } + + function at_(uint256 index) public view returns (bytes32) { + return _set.at(index); + } + + function _indexOf(bytes32 value) public view returns (uint256) { + return _set._inner._indexes[value]; + } +} diff --git a/certora/run.js b/certora/run.js old mode 100644 new mode 100755 index f6ee0cff4..9646c5056 --- a/certora/run.js +++ b/certora/run.js @@ -1,37 +1,65 @@ #!/usr/bin/env node // USAGE: -// node certora/run.js [[CONTRACT_NAME:]SPEC_NAME] [OPTIONS...] +// node certora/run.js [[CONTRACT_NAME:]SPEC_NAME]* [--all] [--options OPTIONS...] [--specs PATH] // EXAMPLES: +// node certora/run.js --all // node certora/run.js AccessControl // node certora/run.js AccessControlHarness:AccessControl -const MAX_PARALLEL = 4; - const proc = require('child_process'); const { PassThrough } = require('stream'); const events = require('events'); -const micromatch = require('micromatch'); -const limit = require('p-limit')(MAX_PARALLEL); -let [, , request = '', ...extraOptions] = process.argv; -if (request.startsWith('-')) { - extraOptions.unshift(request); - request = ''; -} -const [reqSpec, reqContract] = request.split(':').reverse(); +const argv = require('yargs') + .env('') + .options({ + all: { + alias: 'a', + type: 'boolean', + }, + spec: { + alias: 's', + type: 'string', + default: __dirname + '/specs.js', + }, + parallel: { + alias: 'p', + type: 'number', + default: 4, + }, + options: { + alias: 'o', + type: 'array', + default: [], + }, + }).argv; -const specs = require(__dirname + '/specs.js') - .filter(entry => !reqSpec || micromatch.isMatch(entry.spec, reqSpec)) - .filter(entry => !reqContract || micromatch.isMatch(entry.contract, reqContract)); - -if (specs.length === 0) { - console.error(`Error: Requested spec '${request}' not found in specs.json`); - process.exit(1); +function match(entry, request) { + const [reqSpec, reqContract] = request.split(':').reverse(); + return entry.spec == reqSpec && (!reqContract || entry.contract == reqContract); } -for (const { spec, contract, files, options = [] } of Object.values(specs)) { - limit(runCertora, spec, contract, files, [...options.flatMap(opt => opt.split(' ')), ...extraOptions]); +const specs = require(argv.spec).filter(s => argv.all || argv._.some(r => match(s, r))); +const limit = require('p-limit')(argv.parallel); + +if (argv._.length == 0 && !argv.all) { + console.error(`Warning: No specs requested. Did you forgot to toggle '--all'?`); +} + +for (const r of argv._) { + if (!specs.some(s => match(s, r))) { + console.error(`Error: Requested spec '${r}' not found in ${argv.spec}`); + process.exitCode = 1; + } +} + +if (process.exitCode) { + process.exit(process.exitCode); +} + +for (const { spec, contract, files, options = [] } of specs) { + limit(runCertora, spec, contract, files, [...options.flatMap(opt => opt.split(' ')), ...argv.options]); } // Run certora, aggregate the output and print it at the end diff --git a/certora/specs.js b/certora/specs.js index b5ed38fff..e30031d74 100644 --- a/certora/specs.js +++ b/certora/specs.js @@ -60,6 +60,16 @@ module.exports = [].concat( contract: 'DoubleEndedQueueHarness', files: ['certora/harnesses/DoubleEndedQueueHarness.sol'], }, + { + spec: 'EnumerableSet', + contract: 'EnumerableSetHarness', + files: ['certora/harnesses/EnumerableSetHarness.sol'], + }, + { + spec: 'EnumerableMap', + contract: 'EnumerableMapHarness', + files: ['certora/harnesses/EnumerableMapHarness.sol'], + }, // Governance { spec: 'TimelockController', diff --git a/certora/specs/EnumerableMap.spec b/certora/specs/EnumerableMap.spec new file mode 100644 index 000000000..56ef854c6 --- /dev/null +++ b/certora/specs/EnumerableMap.spec @@ -0,0 +1,334 @@ +import "helpers.spec" + +methods { + // library + set(bytes32,bytes32) returns (bool) envfree + remove(bytes32) returns (bool) envfree + contains(bytes32) returns (bool) envfree + length() returns (uint256) envfree + key_at(uint256) returns (bytes32) envfree + value_at(uint256) returns (bytes32) envfree + tryGet_contains(bytes32) returns (bool) envfree + tryGet_value(bytes32) returns (bytes32) envfree + get(bytes32) returns (bytes32) envfree + + // FV + _indexOf(bytes32) returns (uint256) envfree +} + +/* +┌─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┐ +│ Helpers │ +└─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┘ +*/ +function sanity() returns bool { + return length() < max_uint256; +} + +/* +┌─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┐ +│ Invariant: the value mapping is empty for keys that are not in the EnumerableMap. │ +└─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┘ +*/ +invariant noValueIfNotContained(bytes32 key) + !contains(key) => tryGet_value(key) == 0 + { + preserved set(bytes32 otherKey, bytes32 someValue) { + require sanity(); + } + } + +/* +┌─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┐ +│ Invariant: All indexed keys are contained │ +└─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┘ +*/ +invariant indexedContained(uint256 index) + index < length() => contains(key_at(index)) + { + preserved { + requireInvariant consistencyIndex(index); + requireInvariant consistencyIndex(to_uint256(length() - 1)); + } + } + +/* +┌─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┐ +│ Invariant: A value can only be stored at a single location │ +└─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┘ +*/ +invariant atUniqueness(uint256 index1, uint256 index2) + index1 == index2 <=> key_at(index1) == key_at(index2) + { + preserved remove(bytes32 key) { + requireInvariant atUniqueness(index1, to_uint256(length() - 1)); + requireInvariant atUniqueness(index2, to_uint256(length() - 1)); + } + } + +/* +┌─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┐ +│ Invariant: index <> value relationship is consistent │ +│ │ +│ Note that the two consistencyXxx invariants, put together, prove that at_ and _indexOf are inverse of one another. │ +│ This proves that we have a bijection between indices (the enumerability part) and keys (the entries that are set │ +│ and removed from the EnumerableMap). │ +└─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┘ +*/ +invariant consistencyIndex(uint256 index) + index < length() => _indexOf(key_at(index)) == index + 1 + { + preserved remove(bytes32 key) { + requireInvariant consistencyIndex(to_uint256(length() - 1)); + } + } + +invariant consistencyKey(bytes32 key) + contains(key) => ( + _indexOf(key) > 0 && + _indexOf(key) <= length() && + key_at(to_uint256(_indexOf(key) - 1)) == key + ) + { + preserved remove(bytes32 otherKey) { + requireInvariant consistencyKey(otherKey); + requireInvariant atUniqueness( + to_uint256(_indexOf(key) - 1), + to_uint256(_indexOf(otherKey) - 1) + ); + } + } + +/* +┌─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┐ +│ Rule: state only changes by setting or removing elements │ +└─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┘ +*/ +rule stateChange(env e, bytes32 key) { + require sanity(); + requireInvariant consistencyKey(key); + + uint256 lengthBefore = length(); + bool containsBefore = contains(key); + bytes32 valueBefore = tryGet_value(key); + + method f; + calldataarg args; + f(e, args); + + uint256 lengthAfter = length(); + bool containsAfter = contains(key); + bytes32 valueAfter = tryGet_value(key); + + assert lengthBefore != lengthAfter => ( + (f.selector == set(bytes32,bytes32).selector && lengthAfter == lengthBefore + 1) || + (f.selector == remove(bytes32).selector && lengthAfter == lengthBefore - 1) + ); + + assert containsBefore != containsAfter => ( + (f.selector == set(bytes32,bytes32).selector && containsAfter) || + (f.selector == remove(bytes32).selector && !containsAfter) + ); + + assert valueBefore != valueAfter => ( + (f.selector == set(bytes32,bytes32).selector && containsAfter) || + (f.selector == remove(bytes32).selector && !containsAfter && valueAfter == 0) + ); +} + +/* +┌─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┐ +│ Rule: check liveness of view functions. │ +└─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┘ +*/ +rule liveness_1(bytes32 key) { + requireInvariant consistencyKey(key); + + // contains never revert + bool contains = contains@withrevert(key); + assert !lastReverted; + + // tryGet never reverts (key) + tryGet_contains@withrevert(key); + assert !lastReverted; + + // tryGet never reverts (value) + tryGet_value@withrevert(key); + assert !lastReverted; + + // get reverts iff the key is not in the map + get@withrevert(key); + assert !lastReverted <=> contains; +} + +rule liveness_2(uint256 index) { + requireInvariant consistencyIndex(index); + + // length never revert + uint256 length = length@withrevert(); + assert !lastReverted; + + // key_at reverts iff the index is out of bound + key_at@withrevert(index); + assert !lastReverted <=> index < length; + + // value_at reverts iff the index is out of bound + value_at@withrevert(index); + assert !lastReverted <=> index < length; +} + +/* +┌─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┐ +│ Rule: get and tryGet return the expected values. │ +└─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┘ +*/ +rule getAndTryGet(bytes32 key) { + requireInvariant noValueIfNotContained(key); + + bool contained = contains(key); + bool tryContained = tryGet_contains(key); + bytes32 tryValue = tryGet_value(key); + bytes32 value = get@withrevert(key); // revert is not contained + + assert contained == tryContained; + assert contained => tryValue == value; + assert !contained => tryValue == 0; +} + +/* +┌─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┐ +│ Rule: set key-value in EnumerableMap │ +└─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┘ +*/ +rule set(bytes32 key, bytes32 value, bytes32 otherKey) { + require sanity(); + + uint256 lengthBefore = length(); + bool containsBefore = contains(key); + bool containsOtherBefore = contains(otherKey); + bytes32 otherValueBefore = tryGet_value(otherKey); + + bool added = set@withrevert(key, value); + bool success = !lastReverted; + + assert success && contains(key) && get(key) == value, + "liveness & immediate effect"; + + assert added <=> !containsBefore, + "return value: added iff not contained"; + + assert length() == lengthBefore + to_mathint(added ? 1 : 0), + "effect: length increases iff added"; + + assert added => (key_at(lengthBefore) == key && value_at(lengthBefore) == value), + "effect: add at the end"; + + assert containsOtherBefore != contains(otherKey) => (added && key == otherKey), + "side effect: other keys are not affected"; + + assert otherValueBefore != tryGet_value(otherKey) => key == otherKey, + "side effect: values attached to other keys are not affected"; +} + +/* +┌─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┐ +│ Rule: remove key from EnumerableMap │ +└─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┘ +*/ +rule remove(bytes32 key, bytes32 otherKey) { + requireInvariant consistencyKey(key); + requireInvariant consistencyKey(otherKey); + + uint256 lengthBefore = length(); + bool containsBefore = contains(key); + bool containsOtherBefore = contains(otherKey); + bytes32 otherValueBefore = tryGet_value(otherKey); + + bool removed = remove@withrevert(key); + bool success = !lastReverted; + + assert success && !contains(key), + "liveness & immediate effect"; + + assert removed <=> containsBefore, + "return value: removed iff contained"; + + assert length() == lengthBefore - to_mathint(removed ? 1 : 0), + "effect: length decreases iff removed"; + + assert containsOtherBefore != contains(otherKey) => (removed && key == otherKey), + "side effect: other keys are not affected"; + + assert otherValueBefore != tryGet_value(otherKey) => key == otherKey, + "side effect: values attached to other keys are not affected"; +} + +/* +┌─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┐ +│ Rule: when adding a new key, the other keys remain in set, at the same index. │ +└─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┘ +*/ +rule setEnumerability(bytes32 key, bytes32 value, uint256 index) { + require sanity(); + + bytes32 atKeyBefore = key_at(index); + bytes32 atValueBefore = value_at(index); + + set(key, value); + + bytes32 atKeyAfter = key_at@withrevert(index); + assert !lastReverted; + + bytes32 atValueAfter = value_at@withrevert(index); + assert !lastReverted; + + assert atKeyAfter == atKeyBefore; + assert atValueAfter != atValueBefore => ( + key == atKeyBefore && + value == atValueAfter + ); +} + +/* +┌─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┐ +│ Rule: when removing a existing key, the other keys remain in set, at the same index (except for the last one). │ +└─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┘ +*/ +rule removeEnumerability(bytes32 key, uint256 index) { + uint256 last = length() - 1; + + requireInvariant consistencyKey(key); + requireInvariant consistencyIndex(index); + requireInvariant consistencyIndex(last); + + bytes32 atKeyBefore = key_at(index); + bytes32 atValueBefore = value_at(index); + bytes32 lastKeyBefore = key_at(last); + bytes32 lastValueBefore = value_at(last); + + remove(key); + + // can't read last value & keys (length decreased) + bytes32 atKeyAfter = key_at@withrevert(index); + assert lastReverted <=> index == last; + + bytes32 atValueAfter = value_at@withrevert(index); + assert lastReverted <=> index == last; + + // One value that is allowed to change is if previous value was removed, + // in that case the last value before took its place. + assert ( + index != last && + atKeyBefore != atKeyAfter + ) => ( + atKeyBefore == key && + atKeyAfter == lastKeyBefore + ); + + assert ( + index != last && + atValueBefore != atValueAfter + ) => ( + atValueAfter == lastValueBefore + ); +} diff --git a/certora/specs/EnumerableSet.spec b/certora/specs/EnumerableSet.spec new file mode 100644 index 000000000..c94ba2437 --- /dev/null +++ b/certora/specs/EnumerableSet.spec @@ -0,0 +1,247 @@ +import "helpers.spec" + +methods { + // library + add(bytes32) returns (bool) envfree + remove(bytes32) returns (bool) envfree + contains(bytes32) returns (bool) envfree + length() returns (uint256) envfree + at_(uint256) returns (bytes32) envfree + + // FV + _indexOf(bytes32) returns (uint256) envfree +} + +/* +┌─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┐ +│ Helpers │ +└─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┘ +*/ +function sanity() returns bool { + return length() < max_uint256; +} + +/* +┌─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┐ +│ Invariant: All indexed keys are contained │ +└─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┘ +*/ +invariant indexedContained(uint256 index) + index < length() => contains(at_(index)) + { + preserved { + requireInvariant consistencyIndex(index); + requireInvariant consistencyIndex(to_uint256(length() - 1)); + } + } + +/* +┌─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┐ +│ Invariant: A value can only be stored at a single location │ +└─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┘ +*/ +invariant atUniqueness(uint256 index1, uint256 index2) + index1 == index2 <=> at_(index1) == at_(index2) + { + preserved remove(bytes32 key) { + requireInvariant atUniqueness(index1, to_uint256(length() - 1)); + requireInvariant atUniqueness(index2, to_uint256(length() - 1)); + } + } + +/* +┌─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┐ +│ Invariant: index <> key relationship is consistent │ +│ │ +│ Note that the two consistencyXxx invariants, put together, prove that at_ and _indexOf are inverse of one another. │ +│ This proves that we have a bijection between indices (the enumerability part) and keys (the entries that are added │ +│ and removed from the EnumerableSet). │ +└─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┘ +*/ +invariant consistencyIndex(uint256 index) + index < length() => _indexOf(at_(index)) == index + 1 + { + preserved remove(bytes32 key) { + requireInvariant consistencyIndex(to_uint256(length() - 1)); + } + } + +invariant consistencyKey(bytes32 key) + contains(key) => ( + _indexOf(key) > 0 && + _indexOf(key) <= length() && + at_(to_uint256(_indexOf(key) - 1)) == key + ) + { + preserved remove(bytes32 otherKey) { + requireInvariant consistencyKey(otherKey); + requireInvariant atUniqueness( + to_uint256(_indexOf(key) - 1), + to_uint256(_indexOf(otherKey) - 1) + ); + } + } + +/* +┌─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┐ +│ Rule: state only changes by adding or removing elements │ +└─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┘ +*/ +rule stateChange(env e, bytes32 key) { + require sanity(); + requireInvariant consistencyKey(key); + + uint256 lengthBefore = length(); + bool containsBefore = contains(key); + + method f; + calldataarg args; + f(e, args); + + uint256 lengthAfter = length(); + bool containsAfter = contains(key); + + assert lengthBefore != lengthAfter => ( + (f.selector == add(bytes32).selector && lengthAfter == lengthBefore + 1) || + (f.selector == remove(bytes32).selector && lengthAfter == lengthBefore - 1) + ); + + assert containsBefore != containsAfter => ( + (f.selector == add(bytes32).selector && containsAfter) || + (f.selector == remove(bytes32).selector && containsBefore) + ); +} + +/* +┌─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┐ +│ Rule: check liveness of view functions. │ +└─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┘ +*/ +rule liveness_1(bytes32 key) { + requireInvariant consistencyKey(key); + + // contains never revert + contains@withrevert(key); + assert !lastReverted; +} + +rule liveness_2(uint256 index) { + requireInvariant consistencyIndex(index); + + // length never revert + uint256 length = length@withrevert(); + assert !lastReverted; + + // at reverts iff the index is out of bound + at_@withrevert(index); + assert !lastReverted <=> index < length; +} + +/* +┌─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┐ +│ Rule: add key to EnumerableSet if not already contained │ +└─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┘ +*/ +rule add(bytes32 key, bytes32 otherKey) { + require sanity(); + + uint256 lengthBefore = length(); + bool containsBefore = contains(key); + bool containsOtherBefore = contains(otherKey); + + bool added = add@withrevert(key); + bool success = !lastReverted; + + assert success && contains(key), + "liveness & immediate effect"; + + assert added <=> !containsBefore, + "return value: added iff not contained"; + + assert length() == lengthBefore + to_mathint(added ? 1 : 0), + "effect: length increases iff added"; + + assert added => at_(lengthBefore) == key, + "effect: add at the end"; + + assert containsOtherBefore != contains(otherKey) => (added && key == otherKey), + "side effect: other keys are not affected"; +} + +/* +┌─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┐ +│ Rule: remove key from EnumerableSet if already contained │ +└─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┘ +*/ +rule remove(bytes32 key, bytes32 otherKey) { + requireInvariant consistencyKey(key); + requireInvariant consistencyKey(otherKey); + + uint256 lengthBefore = length(); + bool containsBefore = contains(key); + bool containsOtherBefore = contains(otherKey); + + bool removed = remove@withrevert(key); + bool success = !lastReverted; + + assert success && !contains(key), + "liveness & immediate effect"; + + assert removed <=> containsBefore, + "return value: removed iff contained"; + + assert length() == lengthBefore - to_mathint(removed ? 1 : 0), + "effect: length decreases iff removed"; + + assert containsOtherBefore != contains(otherKey) => (removed && key == otherKey), + "side effect: other keys are not affected"; +} + +/* +┌─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┐ +│ Rule: when adding a new key, the other keys remain in set, at the same index. │ +└─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┘ +*/ +rule addEnumerability(bytes32 key, uint256 index) { + require sanity(); + + bytes32 atBefore = at_(index); + add(key); + bytes32 atAfter = at_@withrevert(index); + bool atAfterSuccess = !lastReverted; + + assert atAfterSuccess; + assert atBefore == atAfter; +} + +/* +┌─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┐ +│ Rule: when removing a existing key, the other keys remain in set, at the same index (except for the last one). │ +└─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┘ +*/ +rule removeEnumerability(bytes32 key, uint256 index) { + uint256 last = length() - 1; + + requireInvariant consistencyKey(key); + requireInvariant consistencyIndex(index); + requireInvariant consistencyIndex(last); + + bytes32 atBefore = at_(index); + bytes32 lastBefore = at_(last); + + remove(key); + + // can't read last value (length decreased) + bytes32 atAfter = at_@withrevert(index); + assert lastReverted <=> index == last; + + // One value that is allowed to change is if previous value was removed, + // in that case the last value before took its place. + assert ( + index != last && + atBefore != atAfter + ) => ( + atBefore == key && + atAfter == lastBefore + ); +} diff --git a/certora/specs/Ownable.spec b/certora/specs/Ownable.spec index 48bd84d13..4fdfeb09c 100644 --- a/certora/specs/Ownable.spec +++ b/certora/specs/Ownable.spec @@ -62,10 +62,10 @@ rule onlyCurrentOwnerCanCallOnlyOwner(env e) { │ Rule: ownership can only change in specific ways │ └─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┘ */ -rule onlyOwnerOrPendingOwnerCanChangeOwnership(env e, method f) { +rule onlyOwnerOrPendingOwnerCanChangeOwnership(env e) { address oldCurrent = owner(); - calldataarg args; + method f; calldataarg args; f(e, args); address newCurrent = owner(); diff --git a/contracts/access/AccessControlDefaultAdminRules.sol b/contracts/access/AccessControlDefaultAdminRules.sol index 43fca9350..0c640fb99 100644 --- a/contracts/access/AccessControlDefaultAdminRules.sol +++ b/contracts/access/AccessControlDefaultAdminRules.sol @@ -136,7 +136,7 @@ abstract contract AccessControlDefaultAdminRules is IAccessControlDefaultAdminRu * @dev See {AccessControl-_revokeRole}. */ function _revokeRole(bytes32 role, address account) internal virtual override { - if (role == DEFAULT_ADMIN_ROLE) { + if (role == DEFAULT_ADMIN_ROLE && account == _currentDefaultAdmin) { delete _currentDefaultAdmin; } super._revokeRole(role, account); diff --git a/contracts/interfaces/IERC1967.sol b/contracts/interfaces/IERC1967.sol index e5deebee9..ab4450eec 100644 --- a/contracts/interfaces/IERC1967.sol +++ b/contracts/interfaces/IERC1967.sol @@ -5,7 +5,7 @@ pragma solidity ^0.8.0; /** * @dev ERC-1967: Proxy Storage Slots. This interface contains the events defined in the ERC. * - * _Available since v4.9._ + * _Available since v4.8.3._ */ interface IERC1967 { /** diff --git a/contracts/mocks/CallReceiverMock.sol b/contracts/mocks/CallReceiverMock.sol index 344a1054b..492adbe92 100644 --- a/contracts/mocks/CallReceiverMock.sol +++ b/contracts/mocks/CallReceiverMock.sol @@ -14,6 +14,10 @@ contract CallReceiverMock { return "0x1234"; } + function mockFunctionEmptyReturn() public payable { + emit MockFunctionCalled(); + } + function mockFunctionWithArgs(uint256 a, uint256 b) public payable returns (string memory) { emit MockFunctionCalledWithArgs(a, b); diff --git a/contracts/mocks/ERC20Reentrant.sol b/contracts/mocks/ERC20Reentrant.sol new file mode 100644 index 000000000..c0184b77b --- /dev/null +++ b/contracts/mocks/ERC20Reentrant.sol @@ -0,0 +1,43 @@ +// SPDX-License-Identifier: MIT +pragma solidity ^0.8.0; + +import "../token/ERC20/ERC20.sol"; +import "../token/ERC20/extensions/ERC4626.sol"; + +contract ERC20Reentrant is ERC20("TEST", "TST") { + enum Type { + No, + Before, + After + } + + Type private _reenterType; + address private _reenterTarget; + bytes private _reenterData; + + function scheduleReenter(Type when, address target, bytes calldata data) external { + _reenterType = when; + _reenterTarget = target; + _reenterData = data; + } + + function functionCall(address target, bytes memory data) public returns (bytes memory) { + return Address.functionCall(target, data); + } + + function _beforeTokenTransfer(address from, address to, uint256 amount) internal override { + if (_reenterType == Type.Before) { + _reenterType = Type.No; + functionCall(_reenterTarget, _reenterData); + } + super._beforeTokenTransfer(from, to, amount); + } + + function _afterTokenTransfer(address from, address to, uint256 amount) internal override { + super._afterTokenTransfer(from, to, amount); + if (_reenterType == Type.After) { + _reenterType = Type.No; + functionCall(_reenterTarget, _reenterData); + } + } +} diff --git a/contracts/mocks/TimelockReentrant.sol b/contracts/mocks/TimelockReentrant.sol new file mode 100644 index 000000000..a9344f50d --- /dev/null +++ b/contracts/mocks/TimelockReentrant.sol @@ -0,0 +1,26 @@ +// SPDX-License-Identifier: MIT +pragma solidity ^0.8.0; + +import "../utils/Address.sol"; + +contract TimelockReentrant { + address private _reenterTarget; + bytes private _reenterData; + bool _reentered; + + function disableReentrancy() external { + _reentered = true; + } + + function enableRentrancy(address target, bytes calldata data) external { + _reenterTarget = target; + _reenterData = data; + } + + function reenter() external { + if (!_reentered) { + _reentered = true; + Address.functionCall(_reenterTarget, _reenterData); + } + } +} diff --git a/contracts/utils/Address.sol b/contracts/utils/Address.sol index 433a866d7..5ff14140a 100644 --- a/contracts/utils/Address.sol +++ b/contracts/utils/Address.sol @@ -59,7 +59,7 @@ library Address { * IMPORTANT: because control is transferred to `recipient`, care must be * taken to not create reentrancy vulnerabilities. Consider using * {ReentrancyGuard} or the - * https://solidity.readthedocs.io/en/v0.5.11/security-considerations.html#use-the-checks-effects-interactions-pattern[checks-effects-interactions pattern]. + * https://solidity.readthedocs.io/en/v0.8.0/security-considerations.html#use-the-checks-effects-interactions-pattern[checks-effects-interactions pattern]. */ function sendValue(address payable recipient, uint256 amount) internal { require(address(this).balance >= amount, "Address: insufficient balance"); diff --git a/scripts/generate/templates/Checkpoints.t.js b/scripts/generate/templates/Checkpoints.t.js index 84b5992ad..b3da933a1 100644 --- a/scripts/generate/templates/Checkpoints.t.js +++ b/scripts/generate/templates/Checkpoints.t.js @@ -86,10 +86,12 @@ function testPush( if (keys.length > 0) { ${opts.keyTypeName} lastKey = keys[keys.length - 1]; - pastKey = _bound${capitalize(opts.keyTypeName)}(pastKey, 0, lastKey - 1); - - vm.expectRevert(); - this.push(pastKey, values[keys.length % values.length]); + if (lastKey > 0) { + pastKey = _bound${capitalize(opts.keyTypeName)}(pastKey, 0, lastKey - 1); + + vm.expectRevert(); + this.push(pastKey, values[keys.length % values.length]); + } } } @@ -173,11 +175,13 @@ function testPush( // Can't push any key in the past if (keys.length > 0) { ${opts.keyTypeName} lastKey = keys[keys.length - 1]; - pastKey = _bound${capitalize(opts.keyTypeName)}(pastKey, 0, lastKey - 1); - - vm.roll(pastKey); - vm.expectRevert(); - this.push(values[keys.length % values.length]); + if (lastKey > 0) { + pastKey = _bound${capitalize(opts.keyTypeName)}(pastKey, 0, lastKey - 1); + + vm.roll(pastKey); + vm.expectRevert(); + this.push(values[keys.length % values.length]); + } } } diff --git a/test/access/AccessControl.behavior.js b/test/access/AccessControl.behavior.js index 6c88aa274..49ab44b58 100644 --- a/test/access/AccessControl.behavior.js +++ b/test/access/AccessControl.behavior.js @@ -513,15 +513,12 @@ function shouldBehaveLikeAccessControlDefaultAdminRules(errorPrefix, delay, defa }); describe('when caller is pending default admin and delay has passed', function () { - let from; - beforeEach(async function () { await time.setNextBlockTimestamp(acceptSchedule.addn(1)); - from = newDefaultAdmin; }); it('accepts a transfer and changes default admin', async function () { - const receipt = await this.accessControl.acceptDefaultAdminTransfer({ from }); + const receipt = await this.accessControl.acceptDefaultAdminTransfer({ from: newDefaultAdmin }); // Storage changes expect(await this.accessControl.hasRole(DEFAULT_ADMIN_ROLE, defaultAdmin)).to.be.false; @@ -625,10 +622,9 @@ function shouldBehaveLikeAccessControlDefaultAdminRules(errorPrefix, delay, defa describe('renounces admin', function () { let delayPassed; - let from = defaultAdmin; beforeEach(async function () { - await this.accessControl.beginDefaultAdminTransfer(constants.ZERO_ADDRESS, { from }); + await this.accessControl.beginDefaultAdminTransfer(constants.ZERO_ADDRESS, { from: defaultAdmin }); delayPassed = web3.utils .toBN(await time.latest()) .add(delay) @@ -638,27 +634,37 @@ function shouldBehaveLikeAccessControlDefaultAdminRules(errorPrefix, delay, defa it('reverts if caller is not default admin', async function () { await time.setNextBlockTimestamp(delayPassed); await expectRevert( - this.accessControl.renounceRole(DEFAULT_ADMIN_ROLE, other, { from }), + this.accessControl.renounceRole(DEFAULT_ADMIN_ROLE, other, { from: defaultAdmin }), `${errorPrefix}: can only renounce roles for self`, ); }); + it('keeps defaultAdmin consistent with hasRole if another non-defaultAdmin user renounces the DEFAULT_ADMIN_ROLE', async function () { + await time.setNextBlockTimestamp(delayPassed); + + // This passes because it's a noop + await this.accessControl.renounceRole(DEFAULT_ADMIN_ROLE, other, { from: other }); + + expect(await this.accessControl.hasRole(DEFAULT_ADMIN_ROLE, defaultAdmin)).to.be.true; + expect(await this.accessControl.defaultAdmin()).to.be.equal(defaultAdmin); + }); + it('renounces role', async function () { await time.setNextBlockTimestamp(delayPassed); - const receipt = await this.accessControl.renounceRole(DEFAULT_ADMIN_ROLE, from, { from }); + const receipt = await this.accessControl.renounceRole(DEFAULT_ADMIN_ROLE, defaultAdmin, { from: defaultAdmin }); expect(await this.accessControl.hasRole(DEFAULT_ADMIN_ROLE, defaultAdmin)).to.be.false; - expect(await this.accessControl.hasRole(constants.ZERO_ADDRESS, defaultAdmin)).to.be.false; + expect(await this.accessControl.defaultAdmin()).to.be.equal(constants.ZERO_ADDRESS); expectEvent(receipt, 'RoleRevoked', { role: DEFAULT_ADMIN_ROLE, - account: from, + account: defaultAdmin, }); expect(await this.accessControl.owner()).to.equal(constants.ZERO_ADDRESS); }); it('allows to recover access using the internal _grantRole', async function () { await time.setNextBlockTimestamp(delayPassed); - await this.accessControl.renounceRole(DEFAULT_ADMIN_ROLE, from, { from }); + await this.accessControl.renounceRole(DEFAULT_ADMIN_ROLE, defaultAdmin, { from: defaultAdmin }); const grantRoleReceipt = await this.accessControl.$_grantRole(DEFAULT_ADMIN_ROLE, other); expectEvent(grantRoleReceipt, 'RoleGranted', { @@ -681,7 +687,7 @@ function shouldBehaveLikeAccessControlDefaultAdminRules(errorPrefix, delay, defa it(`reverts if block.timestamp is ${tag} to schedule`, async function () { await time.setNextBlockTimestamp(delayNotPassed.toNumber() + fromSchedule); await expectRevert( - this.accessControl.renounceRole(DEFAULT_ADMIN_ROLE, defaultAdmin, { from }), + this.accessControl.renounceRole(DEFAULT_ADMIN_ROLE, defaultAdmin, { from: defaultAdmin }), `${errorPrefix}: only can renounce in two delayed steps`, ); }); diff --git a/test/governance/TimelockController.test.js b/test/governance/TimelockController.test.js index dde923564..82a84746e 100644 --- a/test/governance/TimelockController.test.js +++ b/test/governance/TimelockController.test.js @@ -10,6 +10,7 @@ const CallReceiverMock = artifacts.require('CallReceiverMock'); const Implementation2 = artifacts.require('Implementation2'); const ERC721 = artifacts.require('$ERC721'); const ERC1155 = artifacts.require('$ERC1155'); +const TimelockReentrant = artifacts.require('$TimelockReentrant'); const MINDELAY = time.duration.days(1); @@ -345,6 +346,82 @@ contract('TimelockController', function (accounts) { `AccessControl: account ${other.toLowerCase()} is missing role ${EXECUTOR_ROLE}`, ); }); + + it('prevents reentrancy execution', async function () { + // Create operation + const reentrant = await TimelockReentrant.new(); + const reentrantOperation = genOperation( + reentrant.address, + 0, + reentrant.contract.methods.reenter().encodeABI(), + ZERO_BYTES32, + salt, + ); + + // Schedule so it can be executed + await this.mock.schedule( + reentrantOperation.target, + reentrantOperation.value, + reentrantOperation.data, + reentrantOperation.predecessor, + reentrantOperation.salt, + MINDELAY, + { from: proposer }, + ); + + // Advance on time to make the operation executable + const timestamp = await this.mock.getTimestamp(reentrantOperation.id); + await time.increaseTo(timestamp); + + // Grant executor role to the reentrant contract + await this.mock.grantRole(EXECUTOR_ROLE, reentrant.address, { from: admin }); + + // Prepare reenter + const data = this.mock.contract.methods + .execute( + reentrantOperation.target, + reentrantOperation.value, + reentrantOperation.data, + reentrantOperation.predecessor, + reentrantOperation.salt, + ) + .encodeABI(); + await reentrant.enableRentrancy(this.mock.address, data); + + // Expect to fail + await expectRevert( + this.mock.execute( + reentrantOperation.target, + reentrantOperation.value, + reentrantOperation.data, + reentrantOperation.predecessor, + reentrantOperation.salt, + { from: executor }, + ), + 'TimelockController: operation is not ready', + ); + + // Disable reentrancy + await reentrant.disableReentrancy(); + const nonReentrantOperation = reentrantOperation; // Not anymore + + // Try again successfully + const receipt = await this.mock.execute( + nonReentrantOperation.target, + nonReentrantOperation.value, + nonReentrantOperation.data, + nonReentrantOperation.predecessor, + nonReentrantOperation.salt, + { from: executor }, + ); + expectEvent(receipt, 'CallExecuted', { + id: nonReentrantOperation.id, + index: web3.utils.toBN(0), + target: nonReentrantOperation.target, + value: web3.utils.toBN(nonReentrantOperation.value), + data: nonReentrantOperation.data, + }); + }); }); }); }); @@ -632,6 +709,84 @@ contract('TimelockController', function (accounts) { 'TimelockController: length mismatch', ); }); + + it('prevents reentrancy execution', async function () { + // Create operation + const reentrant = await TimelockReentrant.new(); + const reentrantBatchOperation = genOperationBatch( + [reentrant.address], + [0], + [reentrant.contract.methods.reenter().encodeABI()], + ZERO_BYTES32, + salt, + ); + + // Schedule so it can be executed + await this.mock.scheduleBatch( + reentrantBatchOperation.targets, + reentrantBatchOperation.values, + reentrantBatchOperation.payloads, + reentrantBatchOperation.predecessor, + reentrantBatchOperation.salt, + MINDELAY, + { from: proposer }, + ); + + // Advance on time to make the operation executable + const timestamp = await this.mock.getTimestamp(reentrantBatchOperation.id); + await time.increaseTo(timestamp); + + // Grant executor role to the reentrant contract + await this.mock.grantRole(EXECUTOR_ROLE, reentrant.address, { from: admin }); + + // Prepare reenter + const data = this.mock.contract.methods + .executeBatch( + reentrantBatchOperation.targets, + reentrantBatchOperation.values, + reentrantBatchOperation.payloads, + reentrantBatchOperation.predecessor, + reentrantBatchOperation.salt, + ) + .encodeABI(); + await reentrant.enableRentrancy(this.mock.address, data); + + // Expect to fail + await expectRevert( + this.mock.executeBatch( + reentrantBatchOperation.targets, + reentrantBatchOperation.values, + reentrantBatchOperation.payloads, + reentrantBatchOperation.predecessor, + reentrantBatchOperation.salt, + { from: executor }, + ), + 'TimelockController: operation is not ready', + ); + + // Disable reentrancy + await reentrant.disableReentrancy(); + const nonReentrantBatchOperation = reentrantBatchOperation; // Not anymore + + // Try again successfully + const receipt = await this.mock.executeBatch( + nonReentrantBatchOperation.targets, + nonReentrantBatchOperation.values, + nonReentrantBatchOperation.payloads, + nonReentrantBatchOperation.predecessor, + nonReentrantBatchOperation.salt, + { from: executor }, + ); + for (const i in nonReentrantBatchOperation.targets) { + expectEvent(receipt, 'CallExecuted', { + id: nonReentrantBatchOperation.id, + index: web3.utils.toBN(i), + target: nonReentrantBatchOperation.targets[i], + value: web3.utils.toBN(nonReentrantBatchOperation.values[i]), + data: nonReentrantBatchOperation.payloads[i], + }); + } + }); }); }); diff --git a/test/token/ERC20/extensions/ERC4626.test.js b/test/token/ERC20/extensions/ERC4626.test.js index ee0998717..55b3e5d20 100644 --- a/test/token/ERC20/extensions/ERC4626.test.js +++ b/test/token/ERC20/extensions/ERC4626.test.js @@ -1,11 +1,14 @@ const { constants, expectEvent, expectRevert } = require('@openzeppelin/test-helpers'); const { expect } = require('chai'); +const { Enum } = require('../../../helpers/enums'); + const ERC20Decimals = artifacts.require('$ERC20DecimalsMock'); const ERC4626 = artifacts.require('$ERC4626'); const ERC4626OffsetMock = artifacts.require('$ERC4626OffsetMock'); const ERC4626FeesMock = artifacts.require('$ERC4626FeesMock'); const ERC20ExcessDecimalsMock = artifacts.require('ERC20ExcessDecimalsMock'); +const ERC20Reentrant = artifacts.require('$ERC20Reentrant'); contract('ERC4626', function (accounts) { const [holder, recipient, spender, other, user1, user2] = accounts; @@ -44,6 +47,178 @@ contract('ERC4626', function (accounts) { } }); + describe('reentrancy', async function () { + const reenterType = Enum('No', 'Before', 'After'); + + const amount = web3.utils.toBN(1000000000000000000); + const reenterAmount = web3.utils.toBN(1000000000); + let token; + let vault; + + beforeEach(async function () { + token = await ERC20Reentrant.new(); + // Use offset 1 so the rate is not 1:1 and we can't possibly confuse assets and shares + vault = await ERC4626OffsetMock.new('', '', token.address, 1); + // Funds and approval for tests + await token.$_mint(holder, amount); + await token.$_mint(other, amount); + await token.$_approve(holder, vault.address, constants.MAX_UINT256); + await token.$_approve(other, vault.address, constants.MAX_UINT256); + await token.$_approve(token.address, vault.address, constants.MAX_UINT256); + }); + + // During a `_deposit`, the vault does `transferFrom(depositor, vault, assets)` -> `_mint(receiver, shares)` + // such that a reentrancy BEFORE the transfer guarantees the price is kept the same. + // If the order of transfer -> mint is changed to mint -> transfer, the reentrancy could be triggered on an + // intermediate state in which the ratio of assets/shares has been decreased (more shares than assets). + it('correct share price is observed during reentrancy before deposit', async function () { + // mint token for deposit + await token.$_mint(token.address, reenterAmount); + + // Schedules a reentrancy from the token contract + await token.scheduleReenter( + reenterType.Before, + vault.address, + vault.contract.methods.deposit(reenterAmount, holder).encodeABI(), + ); + + // Initial share price + const sharesForDeposit = await vault.previewDeposit(amount, { from: holder }); + const sharesForReenter = await vault.previewDeposit(reenterAmount, { from: holder }); + + // Do deposit normally, triggering the _beforeTokenTransfer hook + const receipt = await vault.deposit(amount, holder, { from: holder }); + + // Main deposit event + await expectEvent(receipt, 'Deposit', { + sender: holder, + owner: holder, + assets: amount, + shares: sharesForDeposit, + }); + // Reentrant deposit event → uses the same price + await expectEvent(receipt, 'Deposit', { + sender: token.address, + owner: holder, + assets: reenterAmount, + shares: sharesForReenter, + }); + + // Assert prices is kept + const sharesAfter = await vault.previewDeposit(amount, { from: holder }); + expect(sharesForDeposit).to.be.bignumber.eq(sharesAfter); + }); + + // During a `_withdraw`, the vault does `_burn(owner, shares)` -> `transfer(receiver, assets)` + // such that a reentrancy AFTER the transfer guarantees the price is kept the same. + // If the order of burn -> transfer is changed to transfer -> burn, the reentrancy could be triggered on an + // intermediate state in which the ratio of shares/assets has been decreased (more assets than shares). + it('correct share price is observed during reentrancy after withdraw', async function () { + // Deposit into the vault: holder gets `amount` share, token.address gets `reenterAmount` shares + await vault.deposit(amount, holder, { from: holder }); + await vault.deposit(reenterAmount, token.address, { from: other }); + + // Schedules a reentrancy from the token contract + await token.scheduleReenter( + reenterType.After, + vault.address, + vault.contract.methods.withdraw(reenterAmount, holder, token.address).encodeABI(), + ); + + // Initial share price + const sharesForWithdraw = await vault.previewWithdraw(amount, { from: holder }); + const sharesForReenter = await vault.previewWithdraw(reenterAmount, { from: holder }); + + // Do withdraw normally, triggering the _afterTokenTransfer hook + const receipt = await vault.withdraw(amount, holder, holder, { from: holder }); + + // Main withdraw event + await expectEvent(receipt, 'Withdraw', { + sender: holder, + receiver: holder, + owner: holder, + assets: amount, + shares: sharesForWithdraw, + }); + // Reentrant withdraw event → uses the same price + await expectEvent(receipt, 'Withdraw', { + sender: token.address, + receiver: holder, + owner: token.address, + assets: reenterAmount, + shares: sharesForReenter, + }); + + // Assert price is kept + const sharesAfter = await vault.previewWithdraw(amount, { from: holder }); + expect(sharesForWithdraw).to.be.bignumber.eq(sharesAfter); + }); + + // Donate newly minted tokens to the vault during the reentracy causes the share price to increase. + // Still, the deposit that trigger the reentracy is not affected and get the previewed price. + // Further deposits will get a different price (getting fewer shares for the same amount of assets) + it('share price change during reentracy does not affect deposit', async function () { + // Schedules a reentrancy from the token contract that mess up the share price + await token.scheduleReenter( + reenterType.Before, + token.address, + token.contract.methods.$_mint(vault.address, reenterAmount).encodeABI(), + ); + + // Price before + const sharesBefore = await vault.previewDeposit(amount); + + // Deposit, triggering the _beforeTokenTransfer hook + const receipt = await vault.deposit(amount, holder, { from: holder }); + + // Price is as previewed + await expectEvent(receipt, 'Deposit', { + sender: holder, + owner: holder, + assets: amount, + shares: sharesBefore, + }); + + // Price was modified during reentrancy + const sharesAfter = await vault.previewDeposit(amount); + expect(sharesAfter).to.be.bignumber.lt(sharesBefore); + }); + + // Burn some tokens from the vault during the reentracy causes the share price to drop. + // Still, the withdraw that trigger the reentracy is not affected and get the previewed price. + // Further withdraw will get a different price (needing more shares for the same amount of assets) + it('share price change during reentracy does not affect withdraw', async function () { + await vault.deposit(amount, other, { from: other }); + await vault.deposit(amount, holder, { from: holder }); + + // Schedules a reentrancy from the token contract that mess up the share price + await token.scheduleReenter( + reenterType.After, + token.address, + token.contract.methods.$_burn(vault.address, reenterAmount).encodeABI(), + ); + + // Price before + const sharesBefore = await vault.previewWithdraw(amount); + + // Withdraw, triggering the _afterTokenTransfer hook + const receipt = await vault.withdraw(amount, holder, holder, { from: holder }); + + // Price is as previewed + await expectEvent(receipt, 'Withdraw', { + sender: holder, + receiver: holder, + owner: holder, + assets: amount, + shares: sharesBefore, + }); + + // Price was modified during reentrancy + const sharesAfter = await vault.previewWithdraw(amount); + expect(sharesAfter).to.be.bignumber.gt(sharesBefore); + }); + }); + for (const offset of [0, 6, 18].map(web3.utils.toBN)) { const parseToken = token => web3.utils.toBN(10).pow(decimals).muln(token); const parseShare = share => web3.utils.toBN(10).pow(decimals.add(offset)).muln(share); diff --git a/test/utils/Address.test.js b/test/utils/Address.test.js index a78ae14e6..4f9f9eea1 100644 --- a/test/utils/Address.test.js +++ b/test/utils/Address.test.js @@ -107,6 +107,14 @@ contract('Address', function (accounts) { await expectEvent.inTransaction(receipt.tx, CallReceiverMock, 'MockFunctionCalled'); }); + it('calls the requested empty return function', async function () { + const abiEncodedCall = this.target.contract.methods.mockFunctionEmptyReturn().encodeABI(); + + const receipt = await this.mock.$functionCall(this.target.address, abiEncodedCall); + + await expectEvent.inTransaction(receipt.tx, CallReceiverMock, 'MockFunctionCalled'); + }); + it('reverts when the called function reverts with no reason', async function () { const abiEncodedCall = this.target.contract.methods.mockFunctionRevertsNoReason().encodeABI(); @@ -137,6 +145,11 @@ contract('Address', function (accounts) { await expectRevert.unspecified(this.mock.$functionCall(this.target.address, abiEncodedCall)); }); + it('bubbles up error message if specified', async function () { + const errorMsg = 'Address: expected error'; + await expectRevert(this.mock.$functionCall(this.target.address, '0x12345678', errorMsg), errorMsg); + }); + it('reverts when function does not exist', async function () { const abiEncodedCall = web3.eth.abi.encodeFunctionCall( { @@ -237,6 +250,11 @@ contract('Address', function (accounts) { 'Address: low-level call with value failed', ); }); + + it('bubbles up error message if specified', async function () { + const errorMsg = 'Address: expected error'; + await expectRevert(this.mock.$functionCallWithValue(this.target.address, '0x12345678', 0, errorMsg), errorMsg); + }); }); }); @@ -277,6 +295,11 @@ contract('Address', function (accounts) { await expectRevert(this.mock.$functionStaticCall(recipient, abiEncodedCall), 'Address: call to non-contract'); }); + + it('bubbles up error message if specified', async function () { + const errorMsg = 'Address: expected error'; + await expectRevert(this.mock.$functionCallWithValue(this.target.address, '0x12345678', 0, errorMsg), errorMsg); + }); }); describe('functionDelegateCall', function () { @@ -317,5 +340,22 @@ contract('Address', function (accounts) { await expectRevert(this.mock.$functionDelegateCall(recipient, abiEncodedCall), 'Address: call to non-contract'); }); + + it('bubbles up error message if specified', async function () { + const errorMsg = 'Address: expected error'; + await expectRevert(this.mock.$functionCallWithValue(this.target.address, '0x12345678', 0, errorMsg), errorMsg); + }); + }); + + describe('verifyCallResult', function () { + it('returns returndata on success', async function () { + const returndata = '0x123abc'; + expect(await this.mock.$verifyCallResult(true, returndata, '')).to.equal(returndata); + }); + + it('reverts with return data and error m', async function () { + const errorMsg = 'Address: expected error'; + await expectRevert(this.mock.$verifyCallResult(false, '0x', errorMsg), errorMsg); + }); }); }); diff --git a/test/utils/Checkpoints.t.sol b/test/utils/Checkpoints.t.sol index abdf8b436..f2cb587d5 100644 --- a/test/utils/Checkpoints.t.sol +++ b/test/utils/Checkpoints.t.sol @@ -66,11 +66,13 @@ contract CheckpointsHistoryTest is Test { // Can't push any key in the past if (keys.length > 0) { uint32 lastKey = keys[keys.length - 1]; - pastKey = _boundUint32(pastKey, 0, lastKey - 1); + if (lastKey > 0) { + pastKey = _boundUint32(pastKey, 0, lastKey - 1); - vm.roll(pastKey); - vm.expectRevert(); - this.push(values[keys.length % values.length]); + vm.roll(pastKey); + vm.expectRevert(); + this.push(values[keys.length % values.length]); + } } } @@ -185,10 +187,12 @@ contract CheckpointsTrace224Test is Test { if (keys.length > 0) { uint32 lastKey = keys[keys.length - 1]; - pastKey = _boundUint32(pastKey, 0, lastKey - 1); + if (lastKey > 0) { + pastKey = _boundUint32(pastKey, 0, lastKey - 1); - vm.expectRevert(); - this.push(pastKey, values[keys.length % values.length]); + vm.expectRevert(); + this.push(pastKey, values[keys.length % values.length]); + } } } @@ -291,10 +295,12 @@ contract CheckpointsTrace160Test is Test { if (keys.length > 0) { uint96 lastKey = keys[keys.length - 1]; - pastKey = _boundUint96(pastKey, 0, lastKey - 1); + if (lastKey > 0) { + pastKey = _boundUint96(pastKey, 0, lastKey - 1); - vm.expectRevert(); - this.push(pastKey, values[keys.length % values.length]); + vm.expectRevert(); + this.push(pastKey, values[keys.length % values.length]); + } } }