Add Math.modExp and a Panic library (#3298)
Co-authored-by: Hadrien Croubois <hadrien.croubois@gmail.com> Co-authored-by: ernestognw <ernestognw@gmail.com>
This commit is contained in:
@ -3,15 +3,13 @@
|
||||
|
||||
pragma solidity ^0.8.20;
|
||||
|
||||
import {Address} from "../Address.sol";
|
||||
import {Panic} from "../Panic.sol";
|
||||
|
||||
/**
|
||||
* @dev Standard math utilities missing in the Solidity language.
|
||||
*/
|
||||
library Math {
|
||||
/**
|
||||
* @dev Muldiv operation overflow.
|
||||
*/
|
||||
error MathOverflowedMulDiv();
|
||||
|
||||
enum Rounding {
|
||||
Floor, // Toward negative infinity
|
||||
Ceil, // Toward positive infinity
|
||||
@ -107,7 +105,7 @@ library Math {
|
||||
function ceilDiv(uint256 a, uint256 b) internal pure returns (uint256) {
|
||||
if (b == 0) {
|
||||
// Guarantee the same behavior as in a regular Solidity division.
|
||||
return a / b;
|
||||
Panic.panic(Panic.DIVISION_BY_ZERO);
|
||||
}
|
||||
|
||||
// The following calculation ensures accurate ceiling division without overflow.
|
||||
@ -149,7 +147,7 @@ library Math {
|
||||
|
||||
// Make sure the result is less than 2^256. Also prevents denominator == 0.
|
||||
if (denominator <= prod1) {
|
||||
revert MathOverflowedMulDiv();
|
||||
Panic.panic(denominator == 0 ? Panic.DIVISION_BY_ZERO : Panic.UNDER_OVERFLOW);
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////
|
||||
@ -226,6 +224,9 @@ library Math {
|
||||
* If n is not a prime, then Z/nZ is not a field, and some elements might not be inversible.
|
||||
*
|
||||
* If the input value is not inversible, 0 is returned.
|
||||
*
|
||||
* NOTE: If you know for sure that n is (big) a prime, it may be cheaper to use Ferma's little theorem and get the
|
||||
* inverse using `Math.modExp(a, n - 2, n)`.
|
||||
*/
|
||||
function invMod(uint256 a, uint256 n) internal pure returns (uint256) {
|
||||
unchecked {
|
||||
@ -275,6 +276,68 @@ library Math {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @dev Returns the modular exponentiation of the specified base, exponent and modulus (b ** e % m)
|
||||
*
|
||||
* Requirements:
|
||||
* - modulus can't be zero
|
||||
* - underlying staticcall to precompile must succeed
|
||||
*
|
||||
* IMPORTANT: The result is only valid if the underlying call succeeds. When using this function, make
|
||||
* sure the chain you're using it on supports the precompiled contract for modular exponentiation
|
||||
* at address 0x05 as specified in https://eips.ethereum.org/EIPS/eip-198[EIP-198]. Otherwise,
|
||||
* the underlying function will succeed given the lack of a revert, but the result may be incorrectly
|
||||
* interpreted as 0.
|
||||
*/
|
||||
function modExp(uint256 b, uint256 e, uint256 m) internal view returns (uint256) {
|
||||
(bool success, uint256 result) = tryModExp(b, e, m);
|
||||
if (!success) {
|
||||
if (m == 0) {
|
||||
Panic.panic(Panic.DIVISION_BY_ZERO);
|
||||
} else {
|
||||
revert Address.FailedInnerCall();
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* @dev Returns the modular exponentiation of the specified base, exponent and modulus (b ** e % m).
|
||||
* It includes a success flag indicating if the operation succeeded. Operation will be marked has failed if trying
|
||||
* to operate modulo 0 or if the underlying precompile reverted.
|
||||
*
|
||||
* IMPORTANT: The result is only valid if the success flag is true. When using this function, make sure the chain
|
||||
* you're using it on supports the precompiled contract for modular exponentiation at address 0x05 as specified in
|
||||
* https://eips.ethereum.org/EIPS/eip-198[EIP-198]. Otherwise, the underlying function will succeed given the lack
|
||||
* of a revert, but the result may be incorrectly interpreted as 0.
|
||||
*/
|
||||
function tryModExp(uint256 b, uint256 e, uint256 m) internal view returns (bool success, uint256 result) {
|
||||
if (m == 0) return (false, 0);
|
||||
/// @solidity memory-safe-assembly
|
||||
assembly {
|
||||
let ptr := mload(0x40)
|
||||
// | Offset | Content | Content (Hex) |
|
||||
// |-----------|------------|--------------------------------------------------------------------|
|
||||
// | 0x00:0x1f | size of b | 0x0000000000000000000000000000000000000000000000000000000000000020 |
|
||||
// | 0x20:0x3f | size of e | 0x0000000000000000000000000000000000000000000000000000000000000020 |
|
||||
// | 0x40:0x5f | size of m | 0x0000000000000000000000000000000000000000000000000000000000000020 |
|
||||
// | 0x60:0x7f | value of b | 0x<.............................................................b> |
|
||||
// | 0x80:0x9f | value of e | 0x<.............................................................e> |
|
||||
// | 0xa0:0xbf | value of m | 0x<.............................................................m> |
|
||||
mstore(ptr, 0x20)
|
||||
mstore(add(ptr, 0x20), 0x20)
|
||||
mstore(add(ptr, 0x40), 0x20)
|
||||
mstore(add(ptr, 0x60), b)
|
||||
mstore(add(ptr, 0x80), e)
|
||||
mstore(add(ptr, 0xa0), m)
|
||||
|
||||
// Given the result < m, it's guaranteed to fit in 32 bytes,
|
||||
// so we can use the memory scratch space located at offset 0.
|
||||
success := staticcall(gas(), 0x05, ptr, 0xc0, 0x00, 0x20)
|
||||
result := mload(0x00)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @dev Returns the square root of a number. If the number is not a perfect square, the value is rounded
|
||||
* towards zero.
|
||||
|
||||
Reference in New Issue
Block a user