From b6e07917eb3725e7a1304b01d8d9211b2b48526f Mon Sep 17 00:00:00 2001 From: Hadrien Croubois Date: Thu, 4 Apr 2024 22:33:30 +0200 Subject: [PATCH] Transient version of ReentrancyGuard (#4988) Co-authored-by: ernestognw --- .changeset/witty-chicken-smile.md | 5 ++ contracts/mocks/ReentrancyTransientMock.sol | 50 +++++++++++ contracts/utils/README.adoc | 3 + contracts/utils/ReentrancyGuard.sol | 3 + contracts/utils/ReentrancyGuardTransient.sol | 58 +++++++++++++ test/utils/ReentrancyGuard.test.js | 87 ++++++++++---------- 6 files changed, 164 insertions(+), 42 deletions(-) create mode 100644 .changeset/witty-chicken-smile.md create mode 100644 contracts/mocks/ReentrancyTransientMock.sol create mode 100644 contracts/utils/ReentrancyGuardTransient.sol diff --git a/.changeset/witty-chicken-smile.md b/.changeset/witty-chicken-smile.md new file mode 100644 index 000000000..6fae3e744 --- /dev/null +++ b/.changeset/witty-chicken-smile.md @@ -0,0 +1,5 @@ +--- +'openzeppelin-solidity': minor +--- + +`ReentrancyGuardTransient`: Added a variant of `ReentrancyGuard` that uses transient storage. diff --git a/contracts/mocks/ReentrancyTransientMock.sol b/contracts/mocks/ReentrancyTransientMock.sol new file mode 100644 index 000000000..f0e61ea8c --- /dev/null +++ b/contracts/mocks/ReentrancyTransientMock.sol @@ -0,0 +1,50 @@ +// SPDX-License-Identifier: MIT + +pragma solidity ^0.8.24; + +import {ReentrancyGuardTransient} from "../utils/ReentrancyGuardTransient.sol"; +import {ReentrancyAttack} from "./ReentrancyAttack.sol"; + +contract ReentrancyTransientMock is ReentrancyGuardTransient { + uint256 public counter; + + constructor() { + counter = 0; + } + + function callback() external nonReentrant { + _count(); + } + + function countLocalRecursive(uint256 n) public nonReentrant { + if (n > 0) { + _count(); + countLocalRecursive(n - 1); + } + } + + function countThisRecursive(uint256 n) public nonReentrant { + if (n > 0) { + _count(); + (bool success, ) = address(this).call(abi.encodeCall(this.countThisRecursive, (n - 1))); + require(success, "ReentrancyTransientMock: failed call"); + } + } + + function countAndCall(ReentrancyAttack attacker) public nonReentrant { + _count(); + attacker.callSender(abi.encodeCall(this.callback, ())); + } + + function _count() private { + counter += 1; + } + + function guardedCheckEntered() public nonReentrant { + require(_reentrancyGuardEntered()); + } + + function unguardedCheckNotEntered() public view { + require(!_reentrancyGuardEntered()); + } +} diff --git a/contracts/utils/README.adoc b/contracts/utils/README.adoc index 98747b41f..5e8f183a9 100644 --- a/contracts/utils/README.adoc +++ b/contracts/utils/README.adoc @@ -13,6 +13,7 @@ Miscellaneous contracts and libraries containing utility functions you can use t * {MerkleProof}: Functions for verifying https://en.wikipedia.org/wiki/Merkle_tree[Merkle Tree] proofs. * {EIP712}: Contract with functions to allow processing signed typed structure data according to https://eips.ethereum.org/EIPS/eip-712[EIP-712]. * {ReentrancyGuard}: A modifier that can prevent reentrancy during certain functions. + * {ReentrancyGuardTransient}: Variant of {ReentrancyGuard} that uses transient storage (https://eips.ethereum.org/EIPS/eip-1153[EIP-1153]). * {Pausable}: A common emergency response mechanism that can pause functionality while a remediation is pending. * {Nonces}: Utility for tracking and verifying address nonces that only increment. * {ERC165, ERC165Checker}: Utilities for inspecting interfaces supported by contracts. @@ -65,6 +66,8 @@ Because Solidity does not support generic types, {EnumerableMap} and {Enumerable {{ReentrancyGuard}} +{{ReentrancyGuardTransient}} + {{Pausable}} {{Nonces}} diff --git a/contracts/utils/ReentrancyGuard.sol b/contracts/utils/ReentrancyGuard.sol index 291d92fd5..081851170 100644 --- a/contracts/utils/ReentrancyGuard.sol +++ b/contracts/utils/ReentrancyGuard.sol @@ -15,6 +15,9 @@ pragma solidity ^0.8.20; * those functions `private`, and then adding `external` `nonReentrant` entry * points to them. * + * TIP: If EIP-1153 (transient storage) is available on the chain you're deploying at, + * consider using {ReentrancyGuardTransient} instead. + * * TIP: If you would like to learn more about reentrancy and alternative ways * to protect against it, check out our blog post * https://blog.openzeppelin.com/reentrancy-after-istanbul/[Reentrancy After Istanbul]. diff --git a/contracts/utils/ReentrancyGuardTransient.sol b/contracts/utils/ReentrancyGuardTransient.sol new file mode 100644 index 000000000..0389b8620 --- /dev/null +++ b/contracts/utils/ReentrancyGuardTransient.sol @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT + +pragma solidity ^0.8.24; + +import {StorageSlot} from "./StorageSlot.sol"; + +/** + * @dev Variant of {ReentrancyGuard} that uses transient storage. + * + * NOTE: This variant only works on networks where EIP-1153 is available. + */ +abstract contract ReentrancyGuardTransient { + using StorageSlot for *; + + // keccak256(abi.encode(uint256(keccak256("openzeppelin.storage.ReentrancyGuard")) - 1)) & ~bytes32(uint256(0xff)) + bytes32 private constant REENTRANCY_GUARD_STORAGE = + 0x9b779b17422d0df92223018b32b4d1fa46e071723d6817e2486d003becc55f00; + + /** + * @dev Unauthorized reentrant call. + */ + error ReentrancyGuardReentrantCall(); + + /** + * @dev Prevents a contract from calling itself, directly or indirectly. + * Calling a `nonReentrant` function from another `nonReentrant` + * function is not supported. It is possible to prevent this from happening + * by making the `nonReentrant` function external, and making it call a + * `private` function that does the actual work. + */ + modifier nonReentrant() { + _nonReentrantBefore(); + _; + _nonReentrantAfter(); + } + + function _nonReentrantBefore() private { + // On the first call to nonReentrant, _status will be NOT_ENTERED + if (_reentrancyGuardEntered()) { + revert ReentrancyGuardReentrantCall(); + } + + // Any calls to nonReentrant after this point will fail + REENTRANCY_GUARD_STORAGE.asBoolean().tstore(true); + } + + function _nonReentrantAfter() private { + REENTRANCY_GUARD_STORAGE.asBoolean().tstore(false); + } + + /** + * @dev Returns true if the reentrancy guard is currently set to "entered", which indicates there is a + * `nonReentrant` function in the call stack. + */ + function _reentrancyGuardEntered() internal view returns (bool) { + return REENTRANCY_GUARD_STORAGE.asBoolean().tload(); + } +} diff --git a/test/utils/ReentrancyGuard.test.js b/test/utils/ReentrancyGuard.test.js index 871967e2f..c4418563e 100644 --- a/test/utils/ReentrancyGuard.test.js +++ b/test/utils/ReentrancyGuard.test.js @@ -2,46 +2,49 @@ const { ethers } = require('hardhat'); const { expect } = require('chai'); const { loadFixture } = require('@nomicfoundation/hardhat-network-helpers'); -async function fixture() { - const mock = await ethers.deployContract('ReentrancyMock'); - return { mock }; +for (const variant of ['', 'Transient']) { + describe(`Reentrancy${variant}Guard`, function () { + async function fixture() { + const name = `Reentrancy${variant}Mock`; + const mock = await ethers.deployContract(name); + return { name, mock }; + } + + beforeEach(async function () { + Object.assign(this, await loadFixture(fixture)); + }); + + it('nonReentrant function can be called', async function () { + expect(await this.mock.counter()).to.equal(0n); + await this.mock.callback(); + expect(await this.mock.counter()).to.equal(1n); + }); + + it('does not allow remote callback', async function () { + const attacker = await ethers.deployContract('ReentrancyAttack'); + await expect(this.mock.countAndCall(attacker)).to.be.revertedWith('ReentrancyAttack: failed call'); + }); + + it('_reentrancyGuardEntered should be true when guarded', async function () { + await this.mock.guardedCheckEntered(); + }); + + it('_reentrancyGuardEntered should be false when unguarded', async function () { + await this.mock.unguardedCheckNotEntered(); + }); + + // The following are more side-effects than intended behavior: + // I put them here as documentation, and to monitor any changes + // in the side-effects. + it('does not allow local recursion', async function () { + await expect(this.mock.countLocalRecursive(10n)).to.be.revertedWithCustomError( + this.mock, + 'ReentrancyGuardReentrantCall', + ); + }); + + it('does not allow indirect local recursion', async function () { + await expect(this.mock.countThisRecursive(10n)).to.be.revertedWith(`${this.name}: failed call`); + }); + }); } - -describe('ReentrancyGuard', function () { - beforeEach(async function () { - Object.assign(this, await loadFixture(fixture)); - }); - - it('nonReentrant function can be called', async function () { - expect(await this.mock.counter()).to.equal(0n); - await this.mock.callback(); - expect(await this.mock.counter()).to.equal(1n); - }); - - it('does not allow remote callback', async function () { - const attacker = await ethers.deployContract('ReentrancyAttack'); - await expect(this.mock.countAndCall(attacker)).to.be.revertedWith('ReentrancyAttack: failed call'); - }); - - it('_reentrancyGuardEntered should be true when guarded', async function () { - await this.mock.guardedCheckEntered(); - }); - - it('_reentrancyGuardEntered should be false when unguarded', async function () { - await this.mock.unguardedCheckNotEntered(); - }); - - // The following are more side-effects than intended behavior: - // I put them here as documentation, and to monitor any changes - // in the side-effects. - it('does not allow local recursion', async function () { - await expect(this.mock.countLocalRecursive(10n)).to.be.revertedWithCustomError( - this.mock, - 'ReentrancyGuardReentrantCall', - ); - }); - - it('does not allow indirect local recursion', async function () { - await expect(this.mock.countThisRecursive(10n)).to.be.revertedWith('ReentrancyMock: failed call'); - }); -});