Add a Math.inv function that inverse a number in Z/nZ (#4839)
Co-authored-by: ernestognw <ernestognw@gmail.com>
This commit is contained in:
5
.changeset/cool-mangos-compare.md
Normal file
5
.changeset/cool-mangos-compare.md
Normal file
@ -0,0 +1,5 @@
|
||||
---
|
||||
'openzeppelin-solidity': minor
|
||||
---
|
||||
|
||||
`Math`: add an `invMod` function to get the modular multiplicative inverse of a number in Z/nZ.
|
||||
@ -121,9 +121,10 @@ library Math {
|
||||
}
|
||||
|
||||
/**
|
||||
* @notice Calculates floor(x * y / denominator) with full precision. Throws if result overflows a uint256 or
|
||||
* @dev Calculates floor(x * y / denominator) with full precision. Throws if result overflows a uint256 or
|
||||
* denominator == 0.
|
||||
* @dev Original credit to Remco Bloemen under MIT license (https://xn--2-umb.com/21/muldiv) with further edits by
|
||||
*
|
||||
* Original credit to Remco Bloemen under MIT license (https://xn--2-umb.com/21/muldiv) with further edits by
|
||||
* Uniswap Labs also under MIT license.
|
||||
*/
|
||||
function mulDiv(uint256 x, uint256 y, uint256 denominator) internal pure returns (uint256 result) {
|
||||
@ -208,7 +209,7 @@ library Math {
|
||||
}
|
||||
|
||||
/**
|
||||
* @notice Calculates x * y / denominator with full precision, following the selected rounding direction.
|
||||
* @dev Calculates x * y / denominator with full precision, following the selected rounding direction.
|
||||
*/
|
||||
function mulDiv(uint256 x, uint256 y, uint256 denominator, Rounding rounding) internal pure returns (uint256) {
|
||||
uint256 result = mulDiv(x, y, denominator);
|
||||
@ -218,6 +219,62 @@ library Math {
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* @dev Calculate the modular multiplicative inverse of a number in Z/nZ.
|
||||
*
|
||||
* If n is a prime, then Z/nZ is a field. In that case all elements are inversible, expect 0.
|
||||
* If n is not a prime, then Z/nZ is not a field, and some elements might not be inversible.
|
||||
*
|
||||
* If the input value is not inversible, 0 is returned.
|
||||
*/
|
||||
function invMod(uint256 a, uint256 n) internal pure returns (uint256) {
|
||||
unchecked {
|
||||
if (n == 0) return 0;
|
||||
|
||||
// The inverse modulo is calculated using the Extended Euclidean Algorithm (iterative version)
|
||||
// Used to compute integers x and y such that: ax + ny = gcd(a, n).
|
||||
// When the gcd is 1, then the inverse of a modulo n exists and it's x.
|
||||
// ax + ny = 1
|
||||
// ax = 1 + (-y)n
|
||||
// ax ≡ 1 (mod n) # x is the inverse of a modulo n
|
||||
|
||||
// If the remainder is 0 the gcd is n right away.
|
||||
uint256 remainder = a % n;
|
||||
uint256 gcd = n;
|
||||
|
||||
// Therefore the initial coefficients are:
|
||||
// ax + ny = gcd(a, n) = n
|
||||
// 0a + 1n = n
|
||||
int256 x = 0;
|
||||
int256 y = 1;
|
||||
|
||||
while (remainder != 0) {
|
||||
uint256 quotient = gcd / remainder;
|
||||
|
||||
(gcd, remainder) = (
|
||||
// The old remainder is the next gcd to try.
|
||||
remainder,
|
||||
// Compute the next remainder.
|
||||
// Can't overflow given that (a % gcd) * (gcd // (a % gcd)) <= gcd
|
||||
// where gcd is at most n (capped to type(uint256).max)
|
||||
gcd - remainder * quotient
|
||||
);
|
||||
|
||||
(x, y) = (
|
||||
// Increment the coefficient of a.
|
||||
y,
|
||||
// Decrement the coefficient of n.
|
||||
// Can overflow, but the result is casted to uint256 so that the
|
||||
// next value of y is "wrapped around" to a value between 0 and n - 1.
|
||||
x - y * int256(quotient)
|
||||
);
|
||||
}
|
||||
|
||||
if (gcd != 1) return 0; // No inverse exists.
|
||||
return x < 0 ? (n - uint256(-x)) : uint256(x); // Wrap the result if it's negative.
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @dev Returns the square root of a number. If the number is not a perfect square, the value is rounded
|
||||
* towards zero.
|
||||
@ -258,7 +315,7 @@ library Math {
|
||||
}
|
||||
|
||||
/**
|
||||
* @notice Calculates sqrt(a), following the selected rounding direction.
|
||||
* @dev Calculates sqrt(a), following the selected rounding direction.
|
||||
*/
|
||||
function sqrt(uint256 a, Rounding rounding) internal pure returns (uint256) {
|
||||
unchecked {
|
||||
|
||||
@ -55,6 +55,41 @@ contract MathTest is Test {
|
||||
return value * value < ref;
|
||||
}
|
||||
|
||||
// INV
|
||||
function testInvMod(uint256 value, uint256 p) public {
|
||||
_testInvMod(value, p, true);
|
||||
}
|
||||
|
||||
function testInvMod2(uint256 seed) public {
|
||||
uint256 p = 2; // prime
|
||||
_testInvMod(bound(seed, 1, p - 1), p, false);
|
||||
}
|
||||
|
||||
function testInvMod17(uint256 seed) public {
|
||||
uint256 p = 17; // prime
|
||||
_testInvMod(bound(seed, 1, p - 1), p, false);
|
||||
}
|
||||
|
||||
function testInvMod65537(uint256 seed) public {
|
||||
uint256 p = 65537; // prime
|
||||
_testInvMod(bound(seed, 1, p - 1), p, false);
|
||||
}
|
||||
|
||||
function testInvModP256(uint256 seed) public {
|
||||
uint256 p = 0xffffffff00000001000000000000000000000000ffffffffffffffffffffffff; // prime
|
||||
_testInvMod(bound(seed, 1, p - 1), p, false);
|
||||
}
|
||||
|
||||
function _testInvMod(uint256 value, uint256 p, bool allowZero) private {
|
||||
uint256 inverse = Math.invMod(value, p);
|
||||
if (inverse != 0) {
|
||||
assertEq(mulmod(value, inverse, p), 1);
|
||||
assertLt(inverse, p);
|
||||
} else {
|
||||
assertTrue(allowZero);
|
||||
}
|
||||
}
|
||||
|
||||
// LOG2
|
||||
function testLog2(uint256 input, uint8 r) public {
|
||||
Math.Rounding rounding = _asRounding(r);
|
||||
|
||||
@ -5,6 +5,7 @@ const { PANIC_CODES } = require('@nomicfoundation/hardhat-chai-matchers/panic');
|
||||
|
||||
const { Rounding } = require('../../helpers/enums');
|
||||
const { min, max } = require('../../helpers/math');
|
||||
const { randomArray, generators } = require('../../helpers/random');
|
||||
|
||||
const RoundingDown = [Rounding.Floor, Rounding.Trunc];
|
||||
const RoundingUp = [Rounding.Ceil, Rounding.Expand];
|
||||
@ -298,6 +299,43 @@ describe('Math', function () {
|
||||
});
|
||||
});
|
||||
|
||||
describe('invMod', function () {
|
||||
for (const factors of [
|
||||
[0n],
|
||||
[1n],
|
||||
[2n],
|
||||
[17n],
|
||||
[65537n],
|
||||
[0xffffffff00000001000000000000000000000000ffffffffffffffffffffffffn],
|
||||
[3n, 5n],
|
||||
[3n, 7n],
|
||||
[47n, 53n],
|
||||
]) {
|
||||
const p = factors.reduce((acc, f) => acc * f, 1n);
|
||||
|
||||
describe(`using p=${p} which is ${p > 1 && factors.length > 1 ? 'not ' : ''}a prime`, function () {
|
||||
it('trying to inverse 0 returns 0', async function () {
|
||||
expect(await this.mock.$invMod(0, p)).to.equal(0n);
|
||||
expect(await this.mock.$invMod(p, p)).to.equal(0n); // p is 0 mod p
|
||||
});
|
||||
|
||||
if (p != 0) {
|
||||
for (const value of randomArray(generators.uint256, 16)) {
|
||||
const isInversible = factors.every(f => value % f);
|
||||
it(`trying to inverse ${value}`, async function () {
|
||||
const result = await this.mock.$invMod(value, p);
|
||||
if (isInversible) {
|
||||
expect((value * result) % p).to.equal(1n);
|
||||
} else {
|
||||
expect(result).to.equal(0n);
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
describe('sqrt', function () {
|
||||
it('rounds down', async function () {
|
||||
for (const rounding of RoundingDown) {
|
||||
|
||||
Reference in New Issue
Block a user