Add bytes memory version of Math.modExp (#4893)
Co-authored-by: Hadrien Croubois <hadrien.croubois@gmail.com>
This commit is contained in:
@ -2,4 +2,4 @@
|
|||||||
'openzeppelin-solidity': minor
|
'openzeppelin-solidity': minor
|
||||||
---
|
---
|
||||||
|
|
||||||
`Math`: Add `modExp` function that exposes the `EIP-198` precompile.
|
`Math`: Add `modExp` function that exposes the `EIP-198` precompile. Includes `uint256` and `bytes memory` versions.
|
||||||
|
|||||||
@ -3,7 +3,6 @@
|
|||||||
|
|
||||||
pragma solidity ^0.8.20;
|
pragma solidity ^0.8.20;
|
||||||
|
|
||||||
import {Address} from "../Address.sol";
|
|
||||||
import {Panic} from "../Panic.sol";
|
import {Panic} from "../Panic.sol";
|
||||||
import {SafeCast} from "./SafeCast.sol";
|
import {SafeCast} from "./SafeCast.sol";
|
||||||
|
|
||||||
@ -289,11 +288,7 @@ library Math {
|
|||||||
function modExp(uint256 b, uint256 e, uint256 m) internal view returns (uint256) {
|
function modExp(uint256 b, uint256 e, uint256 m) internal view returns (uint256) {
|
||||||
(bool success, uint256 result) = tryModExp(b, e, m);
|
(bool success, uint256 result) = tryModExp(b, e, m);
|
||||||
if (!success) {
|
if (!success) {
|
||||||
if (m == 0) {
|
Panic.panic(Panic.DIVISION_BY_ZERO);
|
||||||
Panic.panic(Panic.DIVISION_BY_ZERO);
|
|
||||||
} else {
|
|
||||||
revert Address.FailedInnerCall();
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
@ -335,6 +330,57 @@ library Math {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @dev Variant of {modExp} that supports inputs of arbitrary length.
|
||||||
|
*/
|
||||||
|
function modExp(bytes memory b, bytes memory e, bytes memory m) internal view returns (bytes memory) {
|
||||||
|
(bool success, bytes memory result) = tryModExp(b, e, m);
|
||||||
|
if (!success) {
|
||||||
|
Panic.panic(Panic.DIVISION_BY_ZERO);
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @dev Variant of {tryModExp} that supports inputs of arbitrary length.
|
||||||
|
*/
|
||||||
|
function tryModExp(
|
||||||
|
bytes memory b,
|
||||||
|
bytes memory e,
|
||||||
|
bytes memory m
|
||||||
|
) internal view returns (bool success, bytes memory result) {
|
||||||
|
if (_zeroBytes(m)) return (false, new bytes(0));
|
||||||
|
|
||||||
|
uint256 mLen = m.length;
|
||||||
|
|
||||||
|
// Encode call args in result and move the free memory pointer
|
||||||
|
result = abi.encodePacked(b.length, e.length, mLen, b, e, m);
|
||||||
|
|
||||||
|
/// @solidity memory-safe-assembly
|
||||||
|
assembly {
|
||||||
|
let dataPtr := add(result, 0x20)
|
||||||
|
// Write result on top of args to avoid allocating extra memory.
|
||||||
|
success := staticcall(gas(), 0x05, dataPtr, mload(result), dataPtr, mLen)
|
||||||
|
// Overwrite the length.
|
||||||
|
// result.length > returndatasize() is guaranteed because returndatasize() == m.length
|
||||||
|
mstore(result, mLen)
|
||||||
|
// Set the memory pointer after the returned data.
|
||||||
|
mstore(0x40, add(dataPtr, mLen))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @dev Returns whether the provided byte array is zero.
|
||||||
|
*/
|
||||||
|
function _zeroBytes(bytes memory byteArray) private pure returns (bool) {
|
||||||
|
for (uint256 i = 0; i < byteArray.length; ++i) {
|
||||||
|
if (byteArray[i] != 0) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @dev Returns the square root of a number. If the number is not a perfect square, the value is rounded
|
* @dev Returns the square root of a number. If the number is not a perfect square, the value is rounded
|
||||||
* towards zero.
|
* towards zero.
|
||||||
|
|||||||
@ -3,8 +3,31 @@ const max = (...values) => values.slice(1).reduce((x, y) => (x > y ? x : y), val
|
|||||||
const min = (...values) => values.slice(1).reduce((x, y) => (x < y ? x : y), values.at(0));
|
const min = (...values) => values.slice(1).reduce((x, y) => (x < y ? x : y), values.at(0));
|
||||||
const sum = (...values) => values.slice(1).reduce((x, y) => x + y, values.at(0));
|
const sum = (...values) => values.slice(1).reduce((x, y) => x + y, values.at(0));
|
||||||
|
|
||||||
|
// Computes modexp without BigInt overflow for large numbers
|
||||||
|
function modExp(b, e, m) {
|
||||||
|
let result = 1n;
|
||||||
|
|
||||||
|
// If e is a power of two, modexp can be calculated as:
|
||||||
|
// for (let result = b, i = 0; i < log2(e); i++) result = modexp(result, 2, m)
|
||||||
|
//
|
||||||
|
// Given any natural number can be written in terms of powers of 2 (i.e. binary)
|
||||||
|
// then modexp can be calculated for any e, by multiplying b**i for all i where
|
||||||
|
// binary(e)[i] is 1 (i.e. a power of two).
|
||||||
|
for (let base = b % m; e > 0n; base = base ** 2n % m) {
|
||||||
|
// Least significant bit is 1
|
||||||
|
if (e % 2n == 1n) {
|
||||||
|
result = (result * base) % m;
|
||||||
|
}
|
||||||
|
|
||||||
|
e /= 2n; // Binary pop
|
||||||
|
}
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
module.exports = {
|
module.exports = {
|
||||||
min,
|
min,
|
||||||
max,
|
max,
|
||||||
sum,
|
sum,
|
||||||
|
modExp,
|
||||||
};
|
};
|
||||||
|
|||||||
@ -226,6 +226,33 @@ contract MathTest is Test {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function testModExpMemory(uint256 b, uint256 e, uint256 m) public {
|
||||||
|
if (m == 0) {
|
||||||
|
vm.expectRevert(stdError.divisionError);
|
||||||
|
}
|
||||||
|
bytes memory result = Math.modExp(abi.encodePacked(b), abi.encodePacked(e), abi.encodePacked(m));
|
||||||
|
assertEq(result.length, 0x20);
|
||||||
|
uint256 res = abi.decode(result, (uint256));
|
||||||
|
assertLt(res, m);
|
||||||
|
assertEq(res, _nativeModExp(b, e, m));
|
||||||
|
}
|
||||||
|
|
||||||
|
function testTryModExpMemory(uint256 b, uint256 e, uint256 m) public {
|
||||||
|
(bool success, bytes memory result) = Math.tryModExp(
|
||||||
|
abi.encodePacked(b),
|
||||||
|
abi.encodePacked(e),
|
||||||
|
abi.encodePacked(m)
|
||||||
|
);
|
||||||
|
if (success) {
|
||||||
|
assertEq(result.length, 0x20); // m is a uint256, so abi.encodePacked(m).length is 0x20
|
||||||
|
uint256 res = abi.decode(result, (uint256));
|
||||||
|
assertLt(res, m);
|
||||||
|
assertEq(res, _nativeModExp(b, e, m));
|
||||||
|
} else {
|
||||||
|
assertEq(result.length, 0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
function _nativeModExp(uint256 b, uint256 e, uint256 m) private pure returns (uint256) {
|
function _nativeModExp(uint256 b, uint256 e, uint256 m) private pure returns (uint256) {
|
||||||
if (m == 1) return 0;
|
if (m == 1) return 0;
|
||||||
uint256 r = 1;
|
uint256 r = 1;
|
||||||
|
|||||||
@ -4,12 +4,19 @@ const { loadFixture } = require('@nomicfoundation/hardhat-network-helpers');
|
|||||||
const { PANIC_CODES } = require('@nomicfoundation/hardhat-chai-matchers/panic');
|
const { PANIC_CODES } = require('@nomicfoundation/hardhat-chai-matchers/panic');
|
||||||
|
|
||||||
const { Rounding } = require('../../helpers/enums');
|
const { Rounding } = require('../../helpers/enums');
|
||||||
const { min, max } = require('../../helpers/math');
|
const { min, max, modExp } = require('../../helpers/math');
|
||||||
const { generators } = require('../../helpers/random');
|
const { generators } = require('../../helpers/random');
|
||||||
|
const { range } = require('../../../scripts/helpers');
|
||||||
|
const { product } = require('../../helpers/iterate');
|
||||||
|
|
||||||
const RoundingDown = [Rounding.Floor, Rounding.Trunc];
|
const RoundingDown = [Rounding.Floor, Rounding.Trunc];
|
||||||
const RoundingUp = [Rounding.Ceil, Rounding.Expand];
|
const RoundingUp = [Rounding.Ceil, Rounding.Expand];
|
||||||
|
|
||||||
|
const bytes = (value, width = undefined) => ethers.Typed.bytes(ethers.toBeHex(value, width));
|
||||||
|
const uint256 = value => ethers.Typed.uint256(value);
|
||||||
|
bytes.zero = '0x';
|
||||||
|
uint256.zero = 0n;
|
||||||
|
|
||||||
async function testCommutative(fn, lhs, rhs, expected, ...extra) {
|
async function testCommutative(fn, lhs, rhs, expected, ...extra) {
|
||||||
expect(await fn(lhs, rhs, ...extra)).to.deep.equal(expected);
|
expect(await fn(lhs, rhs, ...extra)).to.deep.equal(expected);
|
||||||
expect(await fn(rhs, lhs, ...extra)).to.deep.equal(expected);
|
expect(await fn(rhs, lhs, ...extra)).to.deep.equal(expected);
|
||||||
@ -141,24 +148,6 @@ describe('Math', function () {
|
|||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
describe('tryModExp', function () {
|
|
||||||
it('is correctly returning true and calculating modulus', async function () {
|
|
||||||
const base = 3n;
|
|
||||||
const exponent = 200n;
|
|
||||||
const modulus = 50n;
|
|
||||||
|
|
||||||
expect(await this.mock.$tryModExp(base, exponent, modulus)).to.deep.equal([true, base ** exponent % modulus]);
|
|
||||||
});
|
|
||||||
|
|
||||||
it('is correctly returning false when modulus is 0', async function () {
|
|
||||||
const base = 3n;
|
|
||||||
const exponent = 200n;
|
|
||||||
const modulus = 0n;
|
|
||||||
|
|
||||||
expect(await this.mock.$tryModExp(base, exponent, modulus)).to.deep.equal([false, 0n]);
|
|
||||||
});
|
|
||||||
});
|
|
||||||
|
|
||||||
describe('max', function () {
|
describe('max', function () {
|
||||||
it('is correctly detected in both position', async function () {
|
it('is correctly detected in both position', async function () {
|
||||||
await testCommutative(this.mock.$max, 1234n, 5678n, max(1234n, 5678n));
|
await testCommutative(this.mock.$max, 1234n, 5678n, max(1234n, 5678n));
|
||||||
@ -354,20 +343,79 @@ describe('Math', function () {
|
|||||||
});
|
});
|
||||||
|
|
||||||
describe('modExp', function () {
|
describe('modExp', function () {
|
||||||
it('is correctly calculating modulus', async function () {
|
for (const [name, type] of Object.entries({ uint256, bytes })) {
|
||||||
const base = 3n;
|
describe(`with ${name} inputs`, function () {
|
||||||
const exponent = 200n;
|
it('is correctly calculating modulus', async function () {
|
||||||
const modulus = 50n;
|
const b = 3n;
|
||||||
|
const e = 200n;
|
||||||
|
const m = 50n;
|
||||||
|
|
||||||
expect(await this.mock.$modExp(base, exponent, modulus)).to.equal(base ** exponent % modulus);
|
expect(await this.mock.$modExp(type(b), type(e), type(m))).to.equal(type(b ** e % m).value);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('is correctly reverting when modulus is zero', async function () {
|
||||||
|
const b = 3n;
|
||||||
|
const e = 200n;
|
||||||
|
const m = 0n;
|
||||||
|
|
||||||
|
await expect(this.mock.$modExp(type(b), type(e), type(m))).to.be.revertedWithPanic(
|
||||||
|
PANIC_CODES.DIVISION_BY_ZERO,
|
||||||
|
);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
describe('with large bytes inputs', function () {
|
||||||
|
for (const [[b, log2b], [e, log2e], [m, log2m]] of product(
|
||||||
|
range(320, 512, 64).map(e => [2n ** BigInt(e) + 1n, e]),
|
||||||
|
range(320, 512, 64).map(e => [2n ** BigInt(e) + 1n, e]),
|
||||||
|
range(320, 512, 64).map(e => [2n ** BigInt(e) + 1n, e]),
|
||||||
|
)) {
|
||||||
|
it(`calculates b ** e % m (b=2**${log2b}+1) (e=2**${log2e}+1) (m=2**${log2m}+1)`, async function () {
|
||||||
|
const mLength = ethers.dataLength(ethers.toBeHex(m));
|
||||||
|
|
||||||
|
expect(await this.mock.$modExp(bytes(b), bytes(e), bytes(m))).to.equal(bytes(modExp(b, e, m), mLength).value);
|
||||||
|
});
|
||||||
|
}
|
||||||
});
|
});
|
||||||
|
});
|
||||||
|
|
||||||
it('is correctly reverting when modulus is zero', async function () {
|
describe('tryModExp', function () {
|
||||||
const base = 3n;
|
for (const [name, type] of Object.entries({ uint256, bytes })) {
|
||||||
const exponent = 200n;
|
describe(`with ${name} inputs`, function () {
|
||||||
const modulus = 0n;
|
it('is correctly calculating modulus', async function () {
|
||||||
|
const b = 3n;
|
||||||
|
const e = 200n;
|
||||||
|
const m = 50n;
|
||||||
|
|
||||||
await expect(this.mock.$modExp(base, exponent, modulus)).to.be.revertedWithPanic(PANIC_CODES.DIVISION_BY_ZERO);
|
expect(await this.mock.$tryModExp(type(b), type(e), type(m))).to.deep.equal([true, type(b ** e % m).value]);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('is correctly reverting when modulus is zero', async function () {
|
||||||
|
const b = 3n;
|
||||||
|
const e = 200n;
|
||||||
|
const m = 0n;
|
||||||
|
|
||||||
|
expect(await this.mock.$tryModExp(type(b), type(e), type(m))).to.deep.equal([false, type.zero]);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
describe('with large bytes inputs', function () {
|
||||||
|
for (const [[b, log2b], [e, log2e], [m, log2m]] of product(
|
||||||
|
range(320, 513, 64).map(e => [2n ** BigInt(e) + 1n, e]),
|
||||||
|
range(320, 513, 64).map(e => [2n ** BigInt(e) + 1n, e]),
|
||||||
|
range(320, 513, 64).map(e => [2n ** BigInt(e) + 1n, e]),
|
||||||
|
)) {
|
||||||
|
it(`calculates b ** e % m (b=2**${log2b}+1) (e=2**${log2e}+1) (m=2**${log2m}+1)`, async function () {
|
||||||
|
const mLength = ethers.dataLength(ethers.toBeHex(m));
|
||||||
|
|
||||||
|
expect(await this.mock.$tryModExp(bytes(b), bytes(e), bytes(m))).to.deep.equal([
|
||||||
|
true,
|
||||||
|
bytes(modExp(b, e, m), mLength).value,
|
||||||
|
]);
|
||||||
|
});
|
||||||
|
}
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user