Add bytes memory version of Math.modExp (#4893)

Co-authored-by: Hadrien Croubois <hadrien.croubois@gmail.com>
This commit is contained in:
Ernesto García
2024-02-14 03:06:34 -06:00
committed by GitHub
parent ae1bafcb48
commit 4e7e6e54da
5 changed files with 180 additions and 36 deletions

View File

@ -3,7 +3,6 @@
pragma solidity ^0.8.20;
import {Address} from "../Address.sol";
import {Panic} from "../Panic.sol";
import {SafeCast} from "./SafeCast.sol";
@ -289,11 +288,7 @@ library Math {
function modExp(uint256 b, uint256 e, uint256 m) internal view returns (uint256) {
(bool success, uint256 result) = tryModExp(b, e, m);
if (!success) {
if (m == 0) {
Panic.panic(Panic.DIVISION_BY_ZERO);
} else {
revert Address.FailedInnerCall();
}
Panic.panic(Panic.DIVISION_BY_ZERO);
}
return result;
}
@ -335,6 +330,57 @@ library Math {
}
}
/**
* @dev Variant of {modExp} that supports inputs of arbitrary length.
*/
function modExp(bytes memory b, bytes memory e, bytes memory m) internal view returns (bytes memory) {
(bool success, bytes memory result) = tryModExp(b, e, m);
if (!success) {
Panic.panic(Panic.DIVISION_BY_ZERO);
}
return result;
}
/**
* @dev Variant of {tryModExp} that supports inputs of arbitrary length.
*/
function tryModExp(
bytes memory b,
bytes memory e,
bytes memory m
) internal view returns (bool success, bytes memory result) {
if (_zeroBytes(m)) return (false, new bytes(0));
uint256 mLen = m.length;
// Encode call args in result and move the free memory pointer
result = abi.encodePacked(b.length, e.length, mLen, b, e, m);
/// @solidity memory-safe-assembly
assembly {
let dataPtr := add(result, 0x20)
// Write result on top of args to avoid allocating extra memory.
success := staticcall(gas(), 0x05, dataPtr, mload(result), dataPtr, mLen)
// Overwrite the length.
// result.length > returndatasize() is guaranteed because returndatasize() == m.length
mstore(result, mLen)
// Set the memory pointer after the returned data.
mstore(0x40, add(dataPtr, mLen))
}
}
/**
* @dev Returns whether the provided byte array is zero.
*/
function _zeroBytes(bytes memory byteArray) private pure returns (bool) {
for (uint256 i = 0; i < byteArray.length; ++i) {
if (byteArray[i] != 0) {
return false;
}
}
return true;
}
/**
* @dev Returns the square root of a number. If the number is not a perfect square, the value is rounded
* towards zero.