diff --git a/.github/workflows/checks.yml b/.github/workflows/checks.yml index fd8d5e31b..988c18bce 100644 --- a/.github/workflows/checks.yml +++ b/.github/workflows/checks.yml @@ -56,6 +56,19 @@ jobs: with: token: ${{ github.token }} + foundry-tests: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + with: + submodules: recursive + - name: Install Foundry + uses: foundry-rs/foundry-toolchain@v1 + with: + version: nightly + - name: Run tests + run: forge test -vv + coverage: if: github.repository != 'OpenZeppelin/openzeppelin-contracts-upgradeable' runs-on: ubuntu-latest diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 000000000..888d42dcd --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "lib/forge-std"] + path = lib/forge-std + url = https://github.com/foundry-rs/forge-std diff --git a/contracts/mocks/MulticallTest.sol b/contracts/mocks/MulticallTest.sol index f1a3a9cfe..4e527eff1 100644 --- a/contracts/mocks/MulticallTest.sol +++ b/contracts/mocks/MulticallTest.sol @@ -5,7 +5,7 @@ pragma solidity ^0.8.0; import "./MulticallTokenMock.sol"; contract MulticallTest { - function testReturnValues( + function checkReturnValues( MulticallTokenMock multicallToken, address[] calldata recipients, uint256[] calldata amounts diff --git a/hardhat/skip-foundry-tests.js b/hardhat/skip-foundry-tests.js new file mode 100644 index 000000000..b8030288d --- /dev/null +++ b/hardhat/skip-foundry-tests.js @@ -0,0 +1,7 @@ +const { subtask } = require('hardhat/config'); +const { TASK_COMPILE_SOLIDITY_GET_SOURCE_PATHS } = require('hardhat/builtin-tasks/task-names'); + +subtask(TASK_COMPILE_SOLIDITY_GET_SOURCE_PATHS) + .setAction(async (_, __, runSuper) => + (await runSuper()).filter((path) => !path.endsWith('.t.sol')), + ); diff --git a/lib/forge-std b/lib/forge-std new file mode 160000 index 000000000..ca8d6e00e --- /dev/null +++ b/lib/forge-std @@ -0,0 +1 @@ +Subproject commit ca8d6e00ea9cb035f6856ff732203c9a3c48b966 diff --git a/test/utils/Multicall.test.js b/test/utils/Multicall.test.js index c6453bb61..61c26344c 100644 --- a/test/utils/Multicall.test.js +++ b/test/utils/Multicall.test.js @@ -31,7 +31,7 @@ contract('MulticallToken', function (accounts) { const recipients = [alice, bob]; const amounts = [amount / 2, amount / 3].map(n => new BN(n)); - await multicallTest.testReturnValues(this.multicallToken.address, recipients, amounts); + await multicallTest.checkReturnValues(this.multicallToken.address, recipients, amounts); }); it('reverts previous calls', async function () { diff --git a/test/utils/math/Math.t.sol b/test/utils/math/Math.t.sol new file mode 100644 index 000000000..2baa3f16c --- /dev/null +++ b/test/utils/math/Math.t.sol @@ -0,0 +1,118 @@ +// SPDX-License-Identifier: MIT + +pragma solidity ^0.8.0; + +import "forge-std/Test.sol"; + +import "../../../contracts/utils/math/Math.sol"; +import "../../../contracts/utils/math/SafeMath.sol"; + +contract MathTest is Test { + // SQRT + function testSqrt(uint256 input, uint8 r) public { + Math.Rounding rounding = _asRounding(r); + + uint256 result = Math.sqrt(input, rounding); + + // square of result is bigger than input + if (_squareBigger(result, input)) { + assertTrue(rounding == Math.Rounding.Up); + assertTrue(_squareSmaller(result - 1, input)); + } + // square of result is smaller than input + else if (_squareSmaller(result, input)) { + assertFalse(rounding == Math.Rounding.Up); + assertTrue(_squareBigger(result + 1, input)); + } + } + + function _squareBigger(uint256 value, uint256 ref) private pure returns (bool) { + (bool noOverflow, uint256 square) = SafeMath.tryMul(value, value); + return !noOverflow || square > ref; + } + + function _squareSmaller(uint256 value, uint256 ref) private pure returns (bool) { + return value * value < ref; + } + + // LOG2 + function testLog2(uint256 input, uint8 r) public { + Math.Rounding rounding = _asRounding(r); + + uint256 result = Math.log2(input, rounding); + + if (input == 0) { + assertEq(result, 0); + } else if (_powerOf2Bigger(result, input)) { + assertTrue(rounding == Math.Rounding.Up); + assertTrue(_powerOf2Smaller(result - 1, input)); + } else if (_powerOf2Smaller(result, input)) { + assertFalse(rounding == Math.Rounding.Up); + assertTrue(_powerOf2Bigger(result + 1, input)); + } + } + + function _powerOf2Bigger(uint256 value, uint256 ref) private pure returns (bool) { + return value >= 256 || 2**value > ref; // 2**256 overflows uint256 + } + + function _powerOf2Smaller(uint256 value, uint256 ref) private pure returns (bool) { + return 2**value < ref; + } + + // LOG10 + function testLog10(uint256 input, uint8 r) public { + Math.Rounding rounding = _asRounding(r); + + uint256 result = Math.log10(input, rounding); + + if (input == 0) { + assertEq(result, 0); + } else if (_powerOf10Bigger(result, input)) { + assertTrue(rounding == Math.Rounding.Up); + assertTrue(_powerOf10Smaller(result - 1, input)); + } else if (_powerOf10Smaller(result, input)) { + assertFalse(rounding == Math.Rounding.Up); + assertTrue(_powerOf10Bigger(result + 1, input)); + } + } + + function _powerOf10Bigger(uint256 value, uint256 ref) private pure returns (bool) { + return value >= 78 || 10**value > ref; // 10**78 overflows uint256 + } + + function _powerOf10Smaller(uint256 value, uint256 ref) private pure returns (bool) { + return 10**value < ref; + } + + // LOG256 + function testLog256(uint256 input, uint8 r) public { + Math.Rounding rounding = _asRounding(r); + + uint256 result = Math.log256(input, rounding); + + if (input == 0) { + assertEq(result, 0); + } else if (_powerOf256Bigger(result, input)) { + assertTrue(rounding == Math.Rounding.Up); + assertTrue(_powerOf256Smaller(result - 1, input)); + } else if (_powerOf256Smaller(result, input)) { + assertFalse(rounding == Math.Rounding.Up); + assertTrue(_powerOf256Bigger(result + 1, input)); + } + } + + function _powerOf256Bigger(uint256 value, uint256 ref) private pure returns (bool) { + return value >= 32 || 256**value > ref; // 256**32 overflows uint256 + } + + function _powerOf256Smaller(uint256 value, uint256 ref) private pure returns (bool) { + return 256**value < ref; + } + + // Helpers + function _asRounding(uint8 r) private returns (Math.Rounding) { + vm.assume(r < uint8(type(Math.Rounding).max)); + return Math.Rounding(r); + } +}