Add fuzz testing of mulDiv (#3717)

Co-authored-by: Hadrien Croubois <hadrien.croubois@gmail.com>
This commit is contained in:
Francisco
2022-09-23 18:45:20 -03:00
committed by GitHub
parent 408055dfab
commit c08c6e1b84
3 changed files with 109 additions and 2 deletions

View File

@ -8,6 +8,22 @@ import "../../../contracts/utils/math/Math.sol";
import "../../../contracts/utils/math/SafeMath.sol";
contract MathTest is Test {
// CEILDIV
function testCeilDiv(uint256 a, uint256 b) public {
vm.assume(b > 0);
uint256 result = Math.ceilDiv(a, b);
if (result == 0) {
assertEq(a, 0);
} else {
uint256 maxdiv = UINT256_MAX / b;
bool overflow = maxdiv * b < a;
assertTrue(a > b * (result - 1));
assertTrue(overflow ? result == maxdiv + 1 : a <= b * result);
}
}
// SQRT
function testSqrt(uint256 input, uint8 r) public {
Math.Rounding rounding = _asRounding(r);
@ -120,9 +136,98 @@ contract MathTest is Test {
return 256**value < ref;
}
// MULDIV
function testMulDiv(
uint256 x,
uint256 y,
uint256 d
) public {
// Full precision for x * y
(uint256 xyHi, uint256 xyLo) = _mulHighLow(x, y);
// Assume result won't overflow (see {testMulDivDomain})
// This also checks that `d` is positive
vm.assume(xyHi < d);
// Perform muldiv
uint256 q = Math.mulDiv(x, y, d);
// Full precision for q * d
(uint256 qdHi, uint256 qdLo) = _mulHighLow(q, d);
// Add reminder of x * y / d (computed as rem = (x * y % d))
(uint256 qdRemLo, uint256 c) = _addCarry(qdLo, _mulmod(x, y, d));
uint256 qdRemHi = qdHi + c;
// Full precision check that x * y = q * d + rem
assertEq(xyHi, qdRemHi);
assertEq(xyLo, qdRemLo);
}
function testMulDivDomain(
uint256 x,
uint256 y,
uint256 d
) public {
(uint256 xyHi, ) = _mulHighLow(x, y);
// Violate {testMulDiv} assumption (covers d is 0 and result overflow)
vm.assume(xyHi >= d);
// we are outside the scope of {testMulDiv}, we expect muldiv to revert
try this.muldiv(x, y, d) returns (uint256) {
fail();
} catch {}
}
// External call
function muldiv(
uint256 x,
uint256 y,
uint256 d
) external pure returns (uint256) {
return Math.mulDiv(x, y, d);
}
// Helpers
function _asRounding(uint8 r) private returns (Math.Rounding) {
vm.assume(r < uint8(type(Math.Rounding).max));
return Math.Rounding(r);
}
function _mulmod(
uint256 x,
uint256 y,
uint256 z
) private pure returns (uint256 r) {
assembly {
r := mulmod(x, y, z)
}
}
function _mulHighLow(uint256 x, uint256 y) private pure returns (uint256 high, uint256 low) {
(uint256 x0, uint256 x1) = (x & type(uint128).max, x >> 128);
(uint256 y0, uint256 y1) = (y & type(uint128).max, y >> 128);
// Karatsuba algorithm
// https://en.wikipedia.org/wiki/Karatsuba_algorithm
uint256 z2 = x1 * y1;
uint256 z1a = x1 * y0;
uint256 z1b = x0 * y1;
uint256 z0 = x0 * y0;
uint256 carry = ((z1a & type(uint128).max) + (z1b & type(uint128).max) + (z0 >> 128)) >> 128;
high = z2 + (z1a >> 128) + (z1b >> 128) + carry;
unchecked {
low = x * y;
}
}
function _addCarry(uint256 x, uint256 y) private pure returns (uint256 res, uint256 carry) {
unchecked {
res = x + y;
}
carry = res < x ? 1 : 0;
}
}