Add bytes memory version of Math.modExp (#4893)

Co-authored-by: Hadrien Croubois <hadrien.croubois@gmail.com>
This commit is contained in:
Ernesto García
2024-02-14 03:06:34 -06:00
committed by GitHub
parent ae1bafcb48
commit 4e7e6e54da
5 changed files with 180 additions and 36 deletions

View File

@ -226,6 +226,33 @@ contract MathTest is Test {
}
}
function testModExpMemory(uint256 b, uint256 e, uint256 m) public {
if (m == 0) {
vm.expectRevert(stdError.divisionError);
}
bytes memory result = Math.modExp(abi.encodePacked(b), abi.encodePacked(e), abi.encodePacked(m));
assertEq(result.length, 0x20);
uint256 res = abi.decode(result, (uint256));
assertLt(res, m);
assertEq(res, _nativeModExp(b, e, m));
}
function testTryModExpMemory(uint256 b, uint256 e, uint256 m) public {
(bool success, bytes memory result) = Math.tryModExp(
abi.encodePacked(b),
abi.encodePacked(e),
abi.encodePacked(m)
);
if (success) {
assertEq(result.length, 0x20); // m is a uint256, so abi.encodePacked(m).length is 0x20
uint256 res = abi.decode(result, (uint256));
assertLt(res, m);
assertEq(res, _nativeModExp(b, e, m));
} else {
assertEq(result.length, 0);
}
}
function _nativeModExp(uint256 b, uint256 e, uint256 m) private pure returns (uint256) {
if (m == 1) return 0;
uint256 r = 1;

View File

@ -4,12 +4,19 @@ const { loadFixture } = require('@nomicfoundation/hardhat-network-helpers');
const { PANIC_CODES } = require('@nomicfoundation/hardhat-chai-matchers/panic');
const { Rounding } = require('../../helpers/enums');
const { min, max } = require('../../helpers/math');
const { min, max, modExp } = require('../../helpers/math');
const { generators } = require('../../helpers/random');
const { range } = require('../../../scripts/helpers');
const { product } = require('../../helpers/iterate');
const RoundingDown = [Rounding.Floor, Rounding.Trunc];
const RoundingUp = [Rounding.Ceil, Rounding.Expand];
const bytes = (value, width = undefined) => ethers.Typed.bytes(ethers.toBeHex(value, width));
const uint256 = value => ethers.Typed.uint256(value);
bytes.zero = '0x';
uint256.zero = 0n;
async function testCommutative(fn, lhs, rhs, expected, ...extra) {
expect(await fn(lhs, rhs, ...extra)).to.deep.equal(expected);
expect(await fn(rhs, lhs, ...extra)).to.deep.equal(expected);
@ -141,24 +148,6 @@ describe('Math', function () {
});
});
describe('tryModExp', function () {
it('is correctly returning true and calculating modulus', async function () {
const base = 3n;
const exponent = 200n;
const modulus = 50n;
expect(await this.mock.$tryModExp(base, exponent, modulus)).to.deep.equal([true, base ** exponent % modulus]);
});
it('is correctly returning false when modulus is 0', async function () {
const base = 3n;
const exponent = 200n;
const modulus = 0n;
expect(await this.mock.$tryModExp(base, exponent, modulus)).to.deep.equal([false, 0n]);
});
});
describe('max', function () {
it('is correctly detected in both position', async function () {
await testCommutative(this.mock.$max, 1234n, 5678n, max(1234n, 5678n));
@ -354,20 +343,79 @@ describe('Math', function () {
});
describe('modExp', function () {
it('is correctly calculating modulus', async function () {
const base = 3n;
const exponent = 200n;
const modulus = 50n;
for (const [name, type] of Object.entries({ uint256, bytes })) {
describe(`with ${name} inputs`, function () {
it('is correctly calculating modulus', async function () {
const b = 3n;
const e = 200n;
const m = 50n;
expect(await this.mock.$modExp(base, exponent, modulus)).to.equal(base ** exponent % modulus);
expect(await this.mock.$modExp(type(b), type(e), type(m))).to.equal(type(b ** e % m).value);
});
it('is correctly reverting when modulus is zero', async function () {
const b = 3n;
const e = 200n;
const m = 0n;
await expect(this.mock.$modExp(type(b), type(e), type(m))).to.be.revertedWithPanic(
PANIC_CODES.DIVISION_BY_ZERO,
);
});
});
}
describe('with large bytes inputs', function () {
for (const [[b, log2b], [e, log2e], [m, log2m]] of product(
range(320, 512, 64).map(e => [2n ** BigInt(e) + 1n, e]),
range(320, 512, 64).map(e => [2n ** BigInt(e) + 1n, e]),
range(320, 512, 64).map(e => [2n ** BigInt(e) + 1n, e]),
)) {
it(`calculates b ** e % m (b=2**${log2b}+1) (e=2**${log2e}+1) (m=2**${log2m}+1)`, async function () {
const mLength = ethers.dataLength(ethers.toBeHex(m));
expect(await this.mock.$modExp(bytes(b), bytes(e), bytes(m))).to.equal(bytes(modExp(b, e, m), mLength).value);
});
}
});
});
it('is correctly reverting when modulus is zero', async function () {
const base = 3n;
const exponent = 200n;
const modulus = 0n;
describe('tryModExp', function () {
for (const [name, type] of Object.entries({ uint256, bytes })) {
describe(`with ${name} inputs`, function () {
it('is correctly calculating modulus', async function () {
const b = 3n;
const e = 200n;
const m = 50n;
await expect(this.mock.$modExp(base, exponent, modulus)).to.be.revertedWithPanic(PANIC_CODES.DIVISION_BY_ZERO);
expect(await this.mock.$tryModExp(type(b), type(e), type(m))).to.deep.equal([true, type(b ** e % m).value]);
});
it('is correctly reverting when modulus is zero', async function () {
const b = 3n;
const e = 200n;
const m = 0n;
expect(await this.mock.$tryModExp(type(b), type(e), type(m))).to.deep.equal([false, type.zero]);
});
});
}
describe('with large bytes inputs', function () {
for (const [[b, log2b], [e, log2e], [m, log2m]] of product(
range(320, 513, 64).map(e => [2n ** BigInt(e) + 1n, e]),
range(320, 513, 64).map(e => [2n ** BigInt(e) + 1n, e]),
range(320, 513, 64).map(e => [2n ** BigInt(e) + 1n, e]),
)) {
it(`calculates b ** e % m (b=2**${log2b}+1) (e=2**${log2e}+1) (m=2**${log2m}+1)`, async function () {
const mLength = ethers.dataLength(ethers.toBeHex(m));
expect(await this.mock.$tryModExp(bytes(b), bytes(e), bytes(m))).to.deep.equal([
true,
bytes(modExp(b, e, m), mLength).value,
]);
});
}
});
});