Add Math.modExp and a Panic library (#3298)

Co-authored-by: Hadrien Croubois <hadrien.croubois@gmail.com>
Co-authored-by: ernestognw <ernestognw@gmail.com>
This commit is contained in:
Mihir Wadekar
2024-02-02 09:40:00 -08:00
committed by GitHub
parent cc431f53e0
commit 192e873fcb
11 changed files with 300 additions and 33 deletions

View File

@ -2,7 +2,7 @@
pragma solidity ^0.8.20;
import {Test} from "forge-std/Test.sol";
import {Test, stdError} from "forge-std/Test.sol";
import {Math} from "@openzeppelin/contracts/utils/math/Math.sol";
@ -186,7 +186,7 @@ contract MathTest is Test {
// Full precision for q * d
(uint256 qdHi, uint256 qdLo) = _mulHighLow(q, d);
// Add remainder of x * y / d (computed as rem = (x * y % d))
(uint256 qdRemLo, uint256 c) = _addCarry(qdLo, _mulmod(x, y, d));
(uint256 qdRemLo, uint256 c) = _addCarry(qdLo, mulmod(x, y, d));
uint256 qdRemHi = qdHi + c;
// Full precision check that x * y = q * d + rem
@ -201,14 +201,42 @@ contract MathTest is Test {
vm.assume(xyHi >= d);
// we are outside the scope of {testMulDiv}, we expect muldiv to revert
try this.muldiv(x, y, d) returns (uint256) {
fail();
} catch {}
vm.expectRevert(d == 0 ? stdError.divisionError : stdError.arithmeticError);
Math.mulDiv(x, y, d);
}
// External call
function muldiv(uint256 x, uint256 y, uint256 d) external pure returns (uint256) {
return Math.mulDiv(x, y, d);
// MOD EXP
function testModExp(uint256 b, uint256 e, uint256 m) public {
if (m == 0) {
vm.expectRevert(stdError.divisionError);
}
uint256 result = Math.modExp(b, e, m);
assertLt(result, m);
assertEq(result, _nativeModExp(b, e, m));
}
function testTryModExp(uint256 b, uint256 e, uint256 m) public {
(bool success, uint256 result) = Math.tryModExp(b, e, m);
assertEq(success, m != 0);
if (success) {
assertLt(result, m);
assertEq(result, _nativeModExp(b, e, m));
} else {
assertEq(result, 0);
}
}
function _nativeModExp(uint256 b, uint256 e, uint256 m) private pure returns (uint256) {
if (m == 1) return 0;
uint256 r = 1;
while (e > 0) {
if (e % 2 > 0) {
r = mulmod(r, b, m);
}
b = mulmod(b, b, m);
e >>= 1;
}
return r;
}
// Helpers
@ -217,12 +245,6 @@ contract MathTest is Test {
return Math.Rounding(r);
}
function _mulmod(uint256 x, uint256 y, uint256 z) private pure returns (uint256 r) {
assembly {
r := mulmod(x, y, z)
}
}
function _mulHighLow(uint256 x, uint256 y) private pure returns (uint256 high, uint256 low) {
(uint256 x0, uint256 x1) = (x & type(uint128).max, x >> 128);
(uint256 y0, uint256 y1) = (y & type(uint128).max, y >> 128);

View File

@ -141,6 +141,24 @@ describe('Math', function () {
});
});
describe('tryModExp', function () {
it('is correctly returning true and calculating modulus', async function () {
const base = 3n;
const exponent = 200n;
const modulus = 50n;
expect(await this.mock.$tryModExp(base, exponent, modulus)).to.deep.equal([true, base ** exponent % modulus]);
});
it('is correctly returning false when modulus is 0', async function () {
const base = 3n;
const exponent = 200n;
const modulus = 0n;
expect(await this.mock.$tryModExp(base, exponent, modulus)).to.deep.equal([false, 0n]);
});
});
describe('max', function () {
it('is correctly detected in both position', async function () {
await testCommutative(this.mock.$max, 1234n, 5678n, max(1234n, 5678n));
@ -222,7 +240,7 @@ describe('Math', function () {
});
});
describe('muldiv', function () {
describe('mulDiv', function () {
it('divide by 0', async function () {
const a = 1n;
const b = 1n;
@ -234,9 +252,8 @@ describe('Math', function () {
const a = 5n;
const b = ethers.MaxUint256;
const c = 2n;
await expect(this.mock.$mulDiv(a, b, c, Rounding.Floor)).to.be.revertedWithCustomError(
this.mock,
'MathOverflowedMulDiv',
await expect(this.mock.$mulDiv(a, b, c, Rounding.Floor)).to.be.revertedWithPanic(
PANIC_CODES.ARITHMETIC_UNDER_OR_OVERFLOW,
);
});
@ -336,6 +353,24 @@ describe('Math', function () {
}
});
describe('modExp', function () {
it('is correctly calculating modulus', async function () {
const base = 3n;
const exponent = 200n;
const modulus = 50n;
expect(await this.mock.$modExp(base, exponent, modulus)).to.equal(base ** exponent % modulus);
});
it('is correctly reverting when modulus is zero', async function () {
const base = 3n;
const exponent = 200n;
const modulus = 0n;
await expect(this.mock.$modExp(base, exponent, modulus)).to.be.revertedWithPanic(PANIC_CODES.DIVISION_BY_ZERO);
});
});
describe('sqrt', function () {
it('rounds down', async function () {
for (const rounding of RoundingDown) {