diff --git a/foundry.toml b/foundry.toml new file mode 100644 index 000000000..1076ef55b --- /dev/null +++ b/foundry.toml @@ -0,0 +1,2 @@ +[fuzz] +runs = 10000 diff --git a/package.json b/package.json index 249d96078..84c978e9a 100644 --- a/package.json +++ b/package.json @@ -20,8 +20,8 @@ "lint:fix": "npm run lint:js:fix && npm run lint:sol:fix", "lint:js": "eslint --ignore-path .gitignore .", "lint:js:fix": "eslint --ignore-path .gitignore . --fix", - "lint:sol": "solhint 'contracts/**/*.sol' && prettier -c 'contracts/**/*.sol'", - "lint:sol:fix": "prettier --write \"contracts/**/*.sol\"", + "lint:sol": "solhint '{contracts,test}/**/*.sol' && prettier -c '{contracts,test}/**/*.sol'", + "lint:sol:fix": "prettier --write '{contracts,test}/**/*.sol'", "clean": "hardhat clean && rimraf build contracts/build", "prepare": "scripts/prepare.sh", "prepack": "scripts/prepack.sh", diff --git a/test/utils/math/Math.t.sol b/test/utils/math/Math.t.sol index c72cc88ee..fa682149f 100644 --- a/test/utils/math/Math.t.sol +++ b/test/utils/math/Math.t.sol @@ -8,6 +8,22 @@ import "../../../contracts/utils/math/Math.sol"; import "../../../contracts/utils/math/SafeMath.sol"; contract MathTest is Test { + // CEILDIV + function testCeilDiv(uint256 a, uint256 b) public { + vm.assume(b > 0); + + uint256 result = Math.ceilDiv(a, b); + + if (result == 0) { + assertEq(a, 0); + } else { + uint256 maxdiv = UINT256_MAX / b; + bool overflow = maxdiv * b < a; + assertTrue(a > b * (result - 1)); + assertTrue(overflow ? result == maxdiv + 1 : a <= b * result); + } + } + // SQRT function testSqrt(uint256 input, uint8 r) public { Math.Rounding rounding = _asRounding(r); @@ -120,9 +136,98 @@ contract MathTest is Test { return 256**value < ref; } + // MULDIV + function testMulDiv( + uint256 x, + uint256 y, + uint256 d + ) public { + // Full precision for x * y + (uint256 xyHi, uint256 xyLo) = _mulHighLow(x, y); + + // Assume result won't overflow (see {testMulDivDomain}) + // This also checks that `d` is positive + vm.assume(xyHi < d); + + // Perform muldiv + uint256 q = Math.mulDiv(x, y, d); + + // Full precision for q * d + (uint256 qdHi, uint256 qdLo) = _mulHighLow(q, d); + // Add reminder of x * y / d (computed as rem = (x * y % d)) + (uint256 qdRemLo, uint256 c) = _addCarry(qdLo, _mulmod(x, y, d)); + uint256 qdRemHi = qdHi + c; + + // Full precision check that x * y = q * d + rem + assertEq(xyHi, qdRemHi); + assertEq(xyLo, qdRemLo); + } + + function testMulDivDomain( + uint256 x, + uint256 y, + uint256 d + ) public { + (uint256 xyHi, ) = _mulHighLow(x, y); + + // Violate {testMulDiv} assumption (covers d is 0 and result overflow) + vm.assume(xyHi >= d); + + // we are outside the scope of {testMulDiv}, we expect muldiv to revert + try this.muldiv(x, y, d) returns (uint256) { + fail(); + } catch {} + } + + // External call + function muldiv( + uint256 x, + uint256 y, + uint256 d + ) external pure returns (uint256) { + return Math.mulDiv(x, y, d); + } + // Helpers function _asRounding(uint8 r) private returns (Math.Rounding) { vm.assume(r < uint8(type(Math.Rounding).max)); return Math.Rounding(r); } + + function _mulmod( + uint256 x, + uint256 y, + uint256 z + ) private pure returns (uint256 r) { + assembly { + r := mulmod(x, y, z) + } + } + + function _mulHighLow(uint256 x, uint256 y) private pure returns (uint256 high, uint256 low) { + (uint256 x0, uint256 x1) = (x & type(uint128).max, x >> 128); + (uint256 y0, uint256 y1) = (y & type(uint128).max, y >> 128); + + // Karatsuba algorithm + // https://en.wikipedia.org/wiki/Karatsuba_algorithm + uint256 z2 = x1 * y1; + uint256 z1a = x1 * y0; + uint256 z1b = x0 * y1; + uint256 z0 = x0 * y0; + + uint256 carry = ((z1a & type(uint128).max) + (z1b & type(uint128).max) + (z0 >> 128)) >> 128; + + high = z2 + (z1a >> 128) + (z1b >> 128) + carry; + + unchecked { + low = x * y; + } + } + + function _addCarry(uint256 x, uint256 y) private pure returns (uint256 res, uint256 carry) { + unchecked { + res = x + y; + } + carry = res < x ? 1 : 0; + } }