Add Math.modExp and a Panic library (#3298)

Co-authored-by: Hadrien Croubois <hadrien.croubois@gmail.com>
Co-authored-by: ernestognw <ernestognw@gmail.com>
This commit is contained in:
Mihir Wadekar
2024-02-02 09:40:00 -08:00
committed by GitHub
parent cc431f53e0
commit 192e873fcb
11 changed files with 300 additions and 33 deletions

View File

@ -3,15 +3,13 @@
pragma solidity ^0.8.20;
import {Address} from "../Address.sol";
import {Panic} from "../Panic.sol";
/**
* @dev Standard math utilities missing in the Solidity language.
*/
library Math {
/**
* @dev Muldiv operation overflow.
*/
error MathOverflowedMulDiv();
enum Rounding {
Floor, // Toward negative infinity
Ceil, // Toward positive infinity
@ -107,7 +105,7 @@ library Math {
function ceilDiv(uint256 a, uint256 b) internal pure returns (uint256) {
if (b == 0) {
// Guarantee the same behavior as in a regular Solidity division.
return a / b;
Panic.panic(Panic.DIVISION_BY_ZERO);
}
// The following calculation ensures accurate ceiling division without overflow.
@ -149,7 +147,7 @@ library Math {
// Make sure the result is less than 2^256. Also prevents denominator == 0.
if (denominator <= prod1) {
revert MathOverflowedMulDiv();
Panic.panic(denominator == 0 ? Panic.DIVISION_BY_ZERO : Panic.UNDER_OVERFLOW);
}
///////////////////////////////////////////////
@ -226,6 +224,9 @@ library Math {
* If n is not a prime, then Z/nZ is not a field, and some elements might not be inversible.
*
* If the input value is not inversible, 0 is returned.
*
* NOTE: If you know for sure that n is (big) a prime, it may be cheaper to use Ferma's little theorem and get the
* inverse using `Math.modExp(a, n - 2, n)`.
*/
function invMod(uint256 a, uint256 n) internal pure returns (uint256) {
unchecked {
@ -275,6 +276,68 @@ library Math {
}
}
/**
* @dev Returns the modular exponentiation of the specified base, exponent and modulus (b ** e % m)
*
* Requirements:
* - modulus can't be zero
* - underlying staticcall to precompile must succeed
*
* IMPORTANT: The result is only valid if the underlying call succeeds. When using this function, make
* sure the chain you're using it on supports the precompiled contract for modular exponentiation
* at address 0x05 as specified in https://eips.ethereum.org/EIPS/eip-198[EIP-198]. Otherwise,
* the underlying function will succeed given the lack of a revert, but the result may be incorrectly
* interpreted as 0.
*/
function modExp(uint256 b, uint256 e, uint256 m) internal view returns (uint256) {
(bool success, uint256 result) = tryModExp(b, e, m);
if (!success) {
if (m == 0) {
Panic.panic(Panic.DIVISION_BY_ZERO);
} else {
revert Address.FailedInnerCall();
}
}
return result;
}
/**
* @dev Returns the modular exponentiation of the specified base, exponent and modulus (b ** e % m).
* It includes a success flag indicating if the operation succeeded. Operation will be marked has failed if trying
* to operate modulo 0 or if the underlying precompile reverted.
*
* IMPORTANT: The result is only valid if the success flag is true. When using this function, make sure the chain
* you're using it on supports the precompiled contract for modular exponentiation at address 0x05 as specified in
* https://eips.ethereum.org/EIPS/eip-198[EIP-198]. Otherwise, the underlying function will succeed given the lack
* of a revert, but the result may be incorrectly interpreted as 0.
*/
function tryModExp(uint256 b, uint256 e, uint256 m) internal view returns (bool success, uint256 result) {
if (m == 0) return (false, 0);
/// @solidity memory-safe-assembly
assembly {
let ptr := mload(0x40)
// | Offset | Content | Content (Hex) |
// |-----------|------------|--------------------------------------------------------------------|
// | 0x00:0x1f | size of b | 0x0000000000000000000000000000000000000000000000000000000000000020 |
// | 0x20:0x3f | size of e | 0x0000000000000000000000000000000000000000000000000000000000000020 |
// | 0x40:0x5f | size of m | 0x0000000000000000000000000000000000000000000000000000000000000020 |
// | 0x60:0x7f | value of b | 0x<.............................................................b> |
// | 0x80:0x9f | value of e | 0x<.............................................................e> |
// | 0xa0:0xbf | value of m | 0x<.............................................................m> |
mstore(ptr, 0x20)
mstore(add(ptr, 0x20), 0x20)
mstore(add(ptr, 0x40), 0x20)
mstore(add(ptr, 0x60), b)
mstore(add(ptr, 0x80), e)
mstore(add(ptr, 0xa0), m)
// Given the result < m, it's guaranteed to fit in 32 bytes,
// so we can use the memory scratch space located at offset 0.
success := staticcall(gas(), 0x05, ptr, 0xc0, 0x00, 0x20)
result := mload(0x00)
}
}
/**
* @dev Returns the square root of a number. If the number is not a perfect square, the value is rounded
* towards zero.