Add Math.modExp and a Panic library (#3298)
Co-authored-by: Hadrien Croubois <hadrien.croubois@gmail.com> Co-authored-by: ernestognw <ernestognw@gmail.com>
This commit is contained in:
@ -2,7 +2,7 @@
|
||||
|
||||
pragma solidity ^0.8.20;
|
||||
|
||||
import {Test} from "forge-std/Test.sol";
|
||||
import {Test, stdError} from "forge-std/Test.sol";
|
||||
|
||||
import {Math} from "@openzeppelin/contracts/utils/math/Math.sol";
|
||||
|
||||
@ -186,7 +186,7 @@ contract MathTest is Test {
|
||||
// Full precision for q * d
|
||||
(uint256 qdHi, uint256 qdLo) = _mulHighLow(q, d);
|
||||
// Add remainder of x * y / d (computed as rem = (x * y % d))
|
||||
(uint256 qdRemLo, uint256 c) = _addCarry(qdLo, _mulmod(x, y, d));
|
||||
(uint256 qdRemLo, uint256 c) = _addCarry(qdLo, mulmod(x, y, d));
|
||||
uint256 qdRemHi = qdHi + c;
|
||||
|
||||
// Full precision check that x * y = q * d + rem
|
||||
@ -201,14 +201,42 @@ contract MathTest is Test {
|
||||
vm.assume(xyHi >= d);
|
||||
|
||||
// we are outside the scope of {testMulDiv}, we expect muldiv to revert
|
||||
try this.muldiv(x, y, d) returns (uint256) {
|
||||
fail();
|
||||
} catch {}
|
||||
vm.expectRevert(d == 0 ? stdError.divisionError : stdError.arithmeticError);
|
||||
Math.mulDiv(x, y, d);
|
||||
}
|
||||
|
||||
// External call
|
||||
function muldiv(uint256 x, uint256 y, uint256 d) external pure returns (uint256) {
|
||||
return Math.mulDiv(x, y, d);
|
||||
// MOD EXP
|
||||
function testModExp(uint256 b, uint256 e, uint256 m) public {
|
||||
if (m == 0) {
|
||||
vm.expectRevert(stdError.divisionError);
|
||||
}
|
||||
uint256 result = Math.modExp(b, e, m);
|
||||
assertLt(result, m);
|
||||
assertEq(result, _nativeModExp(b, e, m));
|
||||
}
|
||||
|
||||
function testTryModExp(uint256 b, uint256 e, uint256 m) public {
|
||||
(bool success, uint256 result) = Math.tryModExp(b, e, m);
|
||||
assertEq(success, m != 0);
|
||||
if (success) {
|
||||
assertLt(result, m);
|
||||
assertEq(result, _nativeModExp(b, e, m));
|
||||
} else {
|
||||
assertEq(result, 0);
|
||||
}
|
||||
}
|
||||
|
||||
function _nativeModExp(uint256 b, uint256 e, uint256 m) private pure returns (uint256) {
|
||||
if (m == 1) return 0;
|
||||
uint256 r = 1;
|
||||
while (e > 0) {
|
||||
if (e % 2 > 0) {
|
||||
r = mulmod(r, b, m);
|
||||
}
|
||||
b = mulmod(b, b, m);
|
||||
e >>= 1;
|
||||
}
|
||||
return r;
|
||||
}
|
||||
|
||||
// Helpers
|
||||
@ -217,12 +245,6 @@ contract MathTest is Test {
|
||||
return Math.Rounding(r);
|
||||
}
|
||||
|
||||
function _mulmod(uint256 x, uint256 y, uint256 z) private pure returns (uint256 r) {
|
||||
assembly {
|
||||
r := mulmod(x, y, z)
|
||||
}
|
||||
}
|
||||
|
||||
function _mulHighLow(uint256 x, uint256 y) private pure returns (uint256 high, uint256 low) {
|
||||
(uint256 x0, uint256 x1) = (x & type(uint128).max, x >> 128);
|
||||
(uint256 y0, uint256 y1) = (y & type(uint128).max, y >> 128);
|
||||
|
||||
@ -141,6 +141,24 @@ describe('Math', function () {
|
||||
});
|
||||
});
|
||||
|
||||
describe('tryModExp', function () {
|
||||
it('is correctly returning true and calculating modulus', async function () {
|
||||
const base = 3n;
|
||||
const exponent = 200n;
|
||||
const modulus = 50n;
|
||||
|
||||
expect(await this.mock.$tryModExp(base, exponent, modulus)).to.deep.equal([true, base ** exponent % modulus]);
|
||||
});
|
||||
|
||||
it('is correctly returning false when modulus is 0', async function () {
|
||||
const base = 3n;
|
||||
const exponent = 200n;
|
||||
const modulus = 0n;
|
||||
|
||||
expect(await this.mock.$tryModExp(base, exponent, modulus)).to.deep.equal([false, 0n]);
|
||||
});
|
||||
});
|
||||
|
||||
describe('max', function () {
|
||||
it('is correctly detected in both position', async function () {
|
||||
await testCommutative(this.mock.$max, 1234n, 5678n, max(1234n, 5678n));
|
||||
@ -222,7 +240,7 @@ describe('Math', function () {
|
||||
});
|
||||
});
|
||||
|
||||
describe('muldiv', function () {
|
||||
describe('mulDiv', function () {
|
||||
it('divide by 0', async function () {
|
||||
const a = 1n;
|
||||
const b = 1n;
|
||||
@ -234,9 +252,8 @@ describe('Math', function () {
|
||||
const a = 5n;
|
||||
const b = ethers.MaxUint256;
|
||||
const c = 2n;
|
||||
await expect(this.mock.$mulDiv(a, b, c, Rounding.Floor)).to.be.revertedWithCustomError(
|
||||
this.mock,
|
||||
'MathOverflowedMulDiv',
|
||||
await expect(this.mock.$mulDiv(a, b, c, Rounding.Floor)).to.be.revertedWithPanic(
|
||||
PANIC_CODES.ARITHMETIC_UNDER_OR_OVERFLOW,
|
||||
);
|
||||
});
|
||||
|
||||
@ -336,6 +353,24 @@ describe('Math', function () {
|
||||
}
|
||||
});
|
||||
|
||||
describe('modExp', function () {
|
||||
it('is correctly calculating modulus', async function () {
|
||||
const base = 3n;
|
||||
const exponent = 200n;
|
||||
const modulus = 50n;
|
||||
|
||||
expect(await this.mock.$modExp(base, exponent, modulus)).to.equal(base ** exponent % modulus);
|
||||
});
|
||||
|
||||
it('is correctly reverting when modulus is zero', async function () {
|
||||
const base = 3n;
|
||||
const exponent = 200n;
|
||||
const modulus = 0n;
|
||||
|
||||
await expect(this.mock.$modExp(base, exponent, modulus)).to.be.revertedWithPanic(PANIC_CODES.DIVISION_BY_ZERO);
|
||||
});
|
||||
});
|
||||
|
||||
describe('sqrt', function () {
|
||||
it('rounds down', async function () {
|
||||
for (const rounding of RoundingDown) {
|
||||
|
||||
Reference in New Issue
Block a user