Add 512bits add and mult operations (#5035)
This commit is contained in:
@ -11,6 +11,48 @@ contract MathTest is Test {
|
||||
assertEq(Math.ternary(f, a, b), f ? a : b);
|
||||
}
|
||||
|
||||
// ADD512 & MUL512
|
||||
function testAdd512(uint256 a, uint256 b) public pure {
|
||||
(uint256 high, uint256 low) = Math.add512(a, b);
|
||||
|
||||
// test against tryAdd
|
||||
(bool success, uint256 result) = Math.tryAdd(a, b);
|
||||
if (success) {
|
||||
assertEq(high, 0);
|
||||
assertEq(low, result);
|
||||
} else {
|
||||
assertEq(high, 1);
|
||||
}
|
||||
|
||||
// test against unchecked
|
||||
unchecked {
|
||||
assertEq(low, a + b); // unchecked allow overflow
|
||||
}
|
||||
}
|
||||
|
||||
function testMul512(uint256 a, uint256 b) public pure {
|
||||
(uint256 high, uint256 low) = Math.mul512(a, b);
|
||||
|
||||
// test against tryMul
|
||||
(bool success, uint256 result) = Math.tryMul(a, b);
|
||||
if (success) {
|
||||
assertEq(high, 0);
|
||||
assertEq(low, result);
|
||||
} else {
|
||||
assertGt(high, 0);
|
||||
}
|
||||
|
||||
// test against unchecked
|
||||
unchecked {
|
||||
assertEq(low, a * b); // unchecked allow overflow
|
||||
}
|
||||
|
||||
// test against alternative method
|
||||
(uint256 _high, uint256 _low) = _mulKaratsuba(a, b);
|
||||
assertEq(high, _high);
|
||||
assertEq(low, _low);
|
||||
}
|
||||
|
||||
// MIN & MAX
|
||||
function testSymbolicMinMax(uint256 a, uint256 b) public pure {
|
||||
assertEq(Math.min(a, b), a < b ? a : b);
|
||||
@ -184,7 +226,7 @@ contract MathTest is Test {
|
||||
// MULDIV
|
||||
function testMulDiv(uint256 x, uint256 y, uint256 d) public pure {
|
||||
// Full precision for x * y
|
||||
(uint256 xyHi, uint256 xyLo) = _mulHighLow(x, y);
|
||||
(uint256 xyHi, uint256 xyLo) = Math.mul512(x, y);
|
||||
|
||||
// Assume result won't overflow (see {testMulDivDomain})
|
||||
// This also checks that `d` is positive
|
||||
@ -194,9 +236,9 @@ contract MathTest is Test {
|
||||
uint256 q = Math.mulDiv(x, y, d);
|
||||
|
||||
// Full precision for q * d
|
||||
(uint256 qdHi, uint256 qdLo) = _mulHighLow(q, d);
|
||||
(uint256 qdHi, uint256 qdLo) = Math.mul512(q, d);
|
||||
// Add remainder of x * y / d (computed as rem = (x * y % d))
|
||||
(uint256 qdRemLo, uint256 c) = _addCarry(qdLo, mulmod(x, y, d));
|
||||
(uint256 c, uint256 qdRemLo) = Math.add512(qdLo, mulmod(x, y, d));
|
||||
uint256 qdRemHi = qdHi + c;
|
||||
|
||||
// Full precision check that x * y = q * d + rem
|
||||
@ -206,7 +248,7 @@ contract MathTest is Test {
|
||||
|
||||
/// forge-config: default.allow_internal_expect_revert = true
|
||||
function testMulDivDomain(uint256 x, uint256 y, uint256 d) public {
|
||||
(uint256 xyHi, ) = _mulHighLow(x, y);
|
||||
(uint256 xyHi, ) = Math.mul512(x, y);
|
||||
|
||||
// Violate {testMulDiv} assumption (covers d is 0 and result overflow)
|
||||
vm.assume(xyHi >= d);
|
||||
@ -266,26 +308,13 @@ contract MathTest is Test {
|
||||
}
|
||||
}
|
||||
|
||||
function _nativeModExp(uint256 b, uint256 e, uint256 m) private pure returns (uint256) {
|
||||
if (m == 1) return 0;
|
||||
uint256 r = 1;
|
||||
while (e > 0) {
|
||||
if (e % 2 > 0) {
|
||||
r = mulmod(r, b, m);
|
||||
}
|
||||
b = mulmod(b, b, m);
|
||||
e >>= 1;
|
||||
}
|
||||
return r;
|
||||
}
|
||||
|
||||
// Helpers
|
||||
function _asRounding(uint8 r) private pure returns (Math.Rounding) {
|
||||
vm.assume(r < uint8(type(Math.Rounding).max));
|
||||
return Math.Rounding(r);
|
||||
}
|
||||
|
||||
function _mulHighLow(uint256 x, uint256 y) private pure returns (uint256 high, uint256 low) {
|
||||
function _mulKaratsuba(uint256 x, uint256 y) private pure returns (uint256 high, uint256 low) {
|
||||
(uint256 x0, uint256 x1) = (x & type(uint128).max, x >> 128);
|
||||
(uint256 y0, uint256 y1) = (y & type(uint128).max, y >> 128);
|
||||
|
||||
@ -305,10 +334,16 @@ contract MathTest is Test {
|
||||
}
|
||||
}
|
||||
|
||||
function _addCarry(uint256 x, uint256 y) private pure returns (uint256 res, uint256 carry) {
|
||||
unchecked {
|
||||
res = x + y;
|
||||
function _nativeModExp(uint256 b, uint256 e, uint256 m) private pure returns (uint256) {
|
||||
if (m == 1) return 0;
|
||||
uint256 r = 1;
|
||||
while (e > 0) {
|
||||
if (e % 2 > 0) {
|
||||
r = mulmod(r, b, m);
|
||||
}
|
||||
b = mulmod(b, b, m);
|
||||
e >>= 1;
|
||||
}
|
||||
carry = res < x ? 1 : 0;
|
||||
return r;
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user