diff --git a/.changeset/shiny-poets-whisper.md b/.changeset/shiny-poets-whisper.md new file mode 100644 index 000000000..cdef23914 --- /dev/null +++ b/.changeset/shiny-poets-whisper.md @@ -0,0 +1,5 @@ +--- +'openzeppelin-solidity': minor +--- + +`Math`: Add `modExp` function that exposes the `EIP-198` precompile. diff --git a/.changeset/silver-swans-promise.md b/.changeset/silver-swans-promise.md new file mode 100644 index 000000000..1d2ff2e9e --- /dev/null +++ b/.changeset/silver-swans-promise.md @@ -0,0 +1,5 @@ +--- +'openzeppelin-solidity': minor +--- + +`Panic`: Add a library for reverting with panic codes. diff --git a/.changeset/smart-bugs-switch.md b/.changeset/smart-bugs-switch.md new file mode 100644 index 000000000..6bdea5f2d --- /dev/null +++ b/.changeset/smart-bugs-switch.md @@ -0,0 +1,5 @@ +--- +'openzeppelin-solidity': minor +--- + +`Math`: MathOverflowedMulDiv error was replaced with native panic codes. diff --git a/contracts/mocks/_import.sol b/contracts/mocks/_import.sol new file mode 100644 index 000000000..726b084c5 --- /dev/null +++ b/contracts/mocks/_import.sol @@ -0,0 +1,40 @@ +// SPDX-License-Identifier: MIT + +pragma solidity ^0.8.20; + +import {Address} from "../utils/Address.sol"; +import {Arrays} from "../utils/Arrays.sol"; +import {Base64} from "../utils/Base64.sol"; +import {BitMaps} from "../utils/structs/BitMaps.sol"; +import {Checkpoints} from "../utils/structs/Checkpoints.sol"; +import {Context} from "../utils/Context.sol"; +import {Create2} from "../utils/Create2.sol"; +import {DoubleEndedQueue} from "../utils/structs/DoubleEndedQueue.sol"; +import {ECDSA} from "../utils/cryptography/ECDSA.sol"; +import {EIP712} from "../utils/cryptography/EIP712.sol"; +import {EnumerableMap} from "../utils/structs/EnumerableMap.sol"; +import {EnumerableSet} from "../utils/structs/EnumerableSet.sol"; +import {ERC165} from "../utils/introspection/ERC165.sol"; +import {ERC165Checker} from "../utils/introspection/ERC165Checker.sol"; +import {IERC165} from "../utils/introspection/IERC165.sol"; +import {Math} from "../utils/math/Math.sol"; +import {MerkleProof} from "../utils/cryptography/MerkleProof.sol"; +import {MessageHashUtils} from "../utils/cryptography/MessageHashUtils.sol"; +import {Multicall} from "../utils/Multicall.sol"; +import {Nonces} from "../utils/Nonces.sol"; +import {Panic} from "../utils/Panic.sol"; +import {Pausable} from "../utils/Pausable.sol"; +import {ReentrancyGuard} from "../utils/ReentrancyGuard.sol"; +import {SafeCast} from "../utils/math/SafeCast.sol"; +import {ShortStrings} from "../utils/ShortStrings.sol"; +import {SignatureChecker} from "../utils/cryptography/SignatureChecker.sol"; +import {SignedMath} from "../utils/math/SignedMath.sol"; +import {StorageSlot} from "../utils/StorageSlot.sol"; +import {Strings} from "../utils/Strings.sol"; +import {Time} from "../utils/types/Time.sol"; + +abstract contract ExposeImports { + // This will be transpiled, causing all the imports above to be transpiled when running the upgradeable tests. + // This trick is necessary for testing libraries such as Panic.sol (which are not imported by any other transpiled + // contracts and would otherwise not be exposed). +} diff --git a/contracts/utils/Panic.sol b/contracts/utils/Panic.sol new file mode 100644 index 000000000..4561c7a13 --- /dev/null +++ b/contracts/utils/Panic.sol @@ -0,0 +1,55 @@ +// SPDX-License-Identifier: MIT + +pragma solidity ^0.8.20; + +/** + * @dev Helper library for emitting standardized panic codes. + * + * ```solidity + * contract Example { + * using Panic for uint256; + * + * // Use any of the declared internal constants + * function foo() { Panic.GENERIC.panic(); } + * + * // Alternatively + * function foo() { Panic.panic(Panic.GENERIC); } + * } + * ``` + * + * Follows the list from libsolutil: https://github.com/ethereum/solidity/blob/v0.8.24/libsolutil/ErrorCodes.h + */ +// slither-disable-next-line unused-state +library Panic { + /// @dev generic / unspecified error + uint256 internal constant GENERIC = 0x00; + /// @dev used by the assert() builtin + uint256 internal constant ASSERT = 0x01; + /// @dev arithmetic underflow or overflow + uint256 internal constant UNDER_OVERFLOW = 0x11; + /// @dev division or modulo by zero + uint256 internal constant DIVISION_BY_ZERO = 0x12; + /// @dev enum conversion error + uint256 internal constant ENUM_CONVERSION_ERROR = 0x21; + /// @dev invalid encoding in storage + uint256 internal constant STORAGE_ENCODING_ERROR = 0x22; + /// @dev empty array pop + uint256 internal constant EMPTY_ARRAY_POP = 0x31; + /// @dev array out of bounds access + uint256 internal constant ARRAY_OUT_OF_BOUNDS = 0x32; + /// @dev resource error (too large allocation or too large array) + uint256 internal constant RESOURCE_ERROR = 0x41; + /// @dev calling invalid internal function + uint256 internal constant INVALID_INTERNAL_FUNCTION = 0x51; + + /// @dev Reverts with a panic code. Recommended to use with + /// the internal constants with predefined codes. + function panic(uint256 code) internal pure { + /// @solidity memory-safe-assembly + assembly { + mstore(0x00, shl(0xe0, 0x4e487b71)) + mstore(0x04, code) + revert(0x00, 0x24) + } + } +} diff --git a/contracts/utils/math/Math.sol b/contracts/utils/math/Math.sol index a826dfd96..c2d419eb9 100644 --- a/contracts/utils/math/Math.sol +++ b/contracts/utils/math/Math.sol @@ -3,15 +3,13 @@ pragma solidity ^0.8.20; +import {Address} from "../Address.sol"; +import {Panic} from "../Panic.sol"; + /** * @dev Standard math utilities missing in the Solidity language. */ library Math { - /** - * @dev Muldiv operation overflow. - */ - error MathOverflowedMulDiv(); - enum Rounding { Floor, // Toward negative infinity Ceil, // Toward positive infinity @@ -107,7 +105,7 @@ library Math { function ceilDiv(uint256 a, uint256 b) internal pure returns (uint256) { if (b == 0) { // Guarantee the same behavior as in a regular Solidity division. - return a / b; + Panic.panic(Panic.DIVISION_BY_ZERO); } // The following calculation ensures accurate ceiling division without overflow. @@ -149,7 +147,7 @@ library Math { // Make sure the result is less than 2^256. Also prevents denominator == 0. if (denominator <= prod1) { - revert MathOverflowedMulDiv(); + Panic.panic(denominator == 0 ? Panic.DIVISION_BY_ZERO : Panic.UNDER_OVERFLOW); } /////////////////////////////////////////////// @@ -226,6 +224,9 @@ library Math { * If n is not a prime, then Z/nZ is not a field, and some elements might not be inversible. * * If the input value is not inversible, 0 is returned. + * + * NOTE: If you know for sure that n is (big) a prime, it may be cheaper to use Ferma's little theorem and get the + * inverse using `Math.modExp(a, n - 2, n)`. */ function invMod(uint256 a, uint256 n) internal pure returns (uint256) { unchecked { @@ -275,6 +276,68 @@ library Math { } } + /** + * @dev Returns the modular exponentiation of the specified base, exponent and modulus (b ** e % m) + * + * Requirements: + * - modulus can't be zero + * - underlying staticcall to precompile must succeed + * + * IMPORTANT: The result is only valid if the underlying call succeeds. When using this function, make + * sure the chain you're using it on supports the precompiled contract for modular exponentiation + * at address 0x05 as specified in https://eips.ethereum.org/EIPS/eip-198[EIP-198]. Otherwise, + * the underlying function will succeed given the lack of a revert, but the result may be incorrectly + * interpreted as 0. + */ + function modExp(uint256 b, uint256 e, uint256 m) internal view returns (uint256) { + (bool success, uint256 result) = tryModExp(b, e, m); + if (!success) { + if (m == 0) { + Panic.panic(Panic.DIVISION_BY_ZERO); + } else { + revert Address.FailedInnerCall(); + } + } + return result; + } + + /** + * @dev Returns the modular exponentiation of the specified base, exponent and modulus (b ** e % m). + * It includes a success flag indicating if the operation succeeded. Operation will be marked has failed if trying + * to operate modulo 0 or if the underlying precompile reverted. + * + * IMPORTANT: The result is only valid if the success flag is true. When using this function, make sure the chain + * you're using it on supports the precompiled contract for modular exponentiation at address 0x05 as specified in + * https://eips.ethereum.org/EIPS/eip-198[EIP-198]. Otherwise, the underlying function will succeed given the lack + * of a revert, but the result may be incorrectly interpreted as 0. + */ + function tryModExp(uint256 b, uint256 e, uint256 m) internal view returns (bool success, uint256 result) { + if (m == 0) return (false, 0); + /// @solidity memory-safe-assembly + assembly { + let ptr := mload(0x40) + // | Offset | Content | Content (Hex) | + // |-----------|------------|--------------------------------------------------------------------| + // | 0x00:0x1f | size of b | 0x0000000000000000000000000000000000000000000000000000000000000020 | + // | 0x20:0x3f | size of e | 0x0000000000000000000000000000000000000000000000000000000000000020 | + // | 0x40:0x5f | size of m | 0x0000000000000000000000000000000000000000000000000000000000000020 | + // | 0x60:0x7f | value of b | 0x<.............................................................b> | + // | 0x80:0x9f | value of e | 0x<.............................................................e> | + // | 0xa0:0xbf | value of m | 0x<.............................................................m> | + mstore(ptr, 0x20) + mstore(add(ptr, 0x20), 0x20) + mstore(add(ptr, 0x40), 0x20) + mstore(add(ptr, 0x60), b) + mstore(add(ptr, 0x80), e) + mstore(add(ptr, 0xa0), m) + + // Given the result < m, it's guaranteed to fit in 32 bytes, + // so we can use the memory scratch space located at offset 0. + success := staticcall(gas(), 0x05, ptr, 0xc0, 0x00, 0x20) + result := mload(0x00) + } + } + /** * @dev Returns the square root of a number. If the number is not a perfect square, the value is rounded * towards zero. diff --git a/scripts/generate/templates/Checkpoints.t.js b/scripts/generate/templates/Checkpoints.t.js index d21beb53e..7e6a738db 100644 --- a/scripts/generate/templates/Checkpoints.t.js +++ b/scripts/generate/templates/Checkpoints.t.js @@ -26,14 +26,14 @@ function _bound${capitalize(opts.keyTypeName)}( ${opts.keyTypeName} x, ${opts.keyTypeName} min, ${opts.keyTypeName} max -) internal view returns (${opts.keyTypeName}) { +) internal pure returns (${opts.keyTypeName}) { return SafeCast.to${capitalize(opts.keyTypeName)}(bound(uint256(x), uint256(min), uint256(max))); } function _prepareKeys( ${opts.keyTypeName}[] memory keys, ${opts.keyTypeName} maxSpread -) internal view { +) internal pure { ${opts.keyTypeName} lastKey = 0; for (uint256 i = 0; i < keys.length; ++i) { ${opts.keyTypeName} key = _bound${capitalize(opts.keyTypeName)}(keys[i], lastKey, lastKey + maxSpread); diff --git a/test/utils/Panic.test.js b/test/utils/Panic.test.js new file mode 100644 index 000000000..cdd492601 --- /dev/null +++ b/test/utils/Panic.test.js @@ -0,0 +1,37 @@ +const { ethers } = require('hardhat'); +const { expect } = require('chai'); +const { loadFixture } = require('@nomicfoundation/hardhat-network-helpers'); +const { PANIC_CODES } = require('@nomicfoundation/hardhat-chai-matchers/panic'); + +async function fixture() { + return { mock: await ethers.deployContract('$Panic') }; +} + +describe('Panic', function () { + beforeEach(async function () { + Object.assign(this, await loadFixture(fixture)); + }); + + for (const [name, code] of Object.entries({ + GENERIC: 0x0, + ASSERT: PANIC_CODES.ASSERTION_ERROR, + UNDER_OVERFLOW: PANIC_CODES.ARITHMETIC_UNDER_OR_OVERFLOW, + DIVISION_BY_ZERO: PANIC_CODES.DIVISION_BY_ZERO, + ENUM_CONVERSION_ERROR: PANIC_CODES.ENUM_CONVERSION_OUT_OF_BOUNDS, + STORAGE_ENCODING_ERROR: PANIC_CODES.INCORRECTLY_ENCODED_STORAGE_BYTE_ARRAY, + EMPTY_ARRAY_POP: PANIC_CODES.POP_ON_EMPTY_ARRAY, + ARRAY_OUT_OF_BOUNDS: PANIC_CODES.ARRAY_ACCESS_OUT_OF_BOUNDS, + RESOURCE_ERROR: PANIC_CODES.TOO_MUCH_MEMORY_ALLOCATED, + INVALID_INTERNAL_FUNCTION: PANIC_CODES.ZERO_INITIALIZED_VARIABLE, + })) { + describe(`${name} (${ethers.toBeHex(code)})`, function () { + it('exposes panic code as constant', async function () { + expect(await this.mock.getFunction(`$${name}`)()).to.equal(code); + }); + + it('reverts with panic when called', async function () { + await expect(this.mock.$panic(code)).to.be.revertedWithPanic(code); + }); + }); + } +}); diff --git a/test/utils/math/Math.t.sol b/test/utils/math/Math.t.sol index 75d28041d..7b49e8a88 100644 --- a/test/utils/math/Math.t.sol +++ b/test/utils/math/Math.t.sol @@ -2,7 +2,7 @@ pragma solidity ^0.8.20; -import {Test} from "forge-std/Test.sol"; +import {Test, stdError} from "forge-std/Test.sol"; import {Math} from "@openzeppelin/contracts/utils/math/Math.sol"; @@ -186,7 +186,7 @@ contract MathTest is Test { // Full precision for q * d (uint256 qdHi, uint256 qdLo) = _mulHighLow(q, d); // Add remainder of x * y / d (computed as rem = (x * y % d)) - (uint256 qdRemLo, uint256 c) = _addCarry(qdLo, _mulmod(x, y, d)); + (uint256 qdRemLo, uint256 c) = _addCarry(qdLo, mulmod(x, y, d)); uint256 qdRemHi = qdHi + c; // Full precision check that x * y = q * d + rem @@ -201,14 +201,42 @@ contract MathTest is Test { vm.assume(xyHi >= d); // we are outside the scope of {testMulDiv}, we expect muldiv to revert - try this.muldiv(x, y, d) returns (uint256) { - fail(); - } catch {} + vm.expectRevert(d == 0 ? stdError.divisionError : stdError.arithmeticError); + Math.mulDiv(x, y, d); } - // External call - function muldiv(uint256 x, uint256 y, uint256 d) external pure returns (uint256) { - return Math.mulDiv(x, y, d); + // MOD EXP + function testModExp(uint256 b, uint256 e, uint256 m) public { + if (m == 0) { + vm.expectRevert(stdError.divisionError); + } + uint256 result = Math.modExp(b, e, m); + assertLt(result, m); + assertEq(result, _nativeModExp(b, e, m)); + } + + function testTryModExp(uint256 b, uint256 e, uint256 m) public { + (bool success, uint256 result) = Math.tryModExp(b, e, m); + assertEq(success, m != 0); + if (success) { + assertLt(result, m); + assertEq(result, _nativeModExp(b, e, m)); + } else { + assertEq(result, 0); + } + } + + function _nativeModExp(uint256 b, uint256 e, uint256 m) private pure returns (uint256) { + if (m == 1) return 0; + uint256 r = 1; + while (e > 0) { + if (e % 2 > 0) { + r = mulmod(r, b, m); + } + b = mulmod(b, b, m); + e >>= 1; + } + return r; } // Helpers @@ -217,12 +245,6 @@ contract MathTest is Test { return Math.Rounding(r); } - function _mulmod(uint256 x, uint256 y, uint256 z) private pure returns (uint256 r) { - assembly { - r := mulmod(x, y, z) - } - } - function _mulHighLow(uint256 x, uint256 y) private pure returns (uint256 high, uint256 low) { (uint256 x0, uint256 x1) = (x & type(uint128).max, x >> 128); (uint256 y0, uint256 y1) = (y & type(uint128).max, y >> 128); diff --git a/test/utils/math/Math.test.js b/test/utils/math/Math.test.js index abf43f073..2762fcc57 100644 --- a/test/utils/math/Math.test.js +++ b/test/utils/math/Math.test.js @@ -141,6 +141,24 @@ describe('Math', function () { }); }); + describe('tryModExp', function () { + it('is correctly returning true and calculating modulus', async function () { + const base = 3n; + const exponent = 200n; + const modulus = 50n; + + expect(await this.mock.$tryModExp(base, exponent, modulus)).to.deep.equal([true, base ** exponent % modulus]); + }); + + it('is correctly returning false when modulus is 0', async function () { + const base = 3n; + const exponent = 200n; + const modulus = 0n; + + expect(await this.mock.$tryModExp(base, exponent, modulus)).to.deep.equal([false, 0n]); + }); + }); + describe('max', function () { it('is correctly detected in both position', async function () { await testCommutative(this.mock.$max, 1234n, 5678n, max(1234n, 5678n)); @@ -222,7 +240,7 @@ describe('Math', function () { }); }); - describe('muldiv', function () { + describe('mulDiv', function () { it('divide by 0', async function () { const a = 1n; const b = 1n; @@ -234,9 +252,8 @@ describe('Math', function () { const a = 5n; const b = ethers.MaxUint256; const c = 2n; - await expect(this.mock.$mulDiv(a, b, c, Rounding.Floor)).to.be.revertedWithCustomError( - this.mock, - 'MathOverflowedMulDiv', + await expect(this.mock.$mulDiv(a, b, c, Rounding.Floor)).to.be.revertedWithPanic( + PANIC_CODES.ARITHMETIC_UNDER_OR_OVERFLOW, ); }); @@ -336,6 +353,24 @@ describe('Math', function () { } }); + describe('modExp', function () { + it('is correctly calculating modulus', async function () { + const base = 3n; + const exponent = 200n; + const modulus = 50n; + + expect(await this.mock.$modExp(base, exponent, modulus)).to.equal(base ** exponent % modulus); + }); + + it('is correctly reverting when modulus is zero', async function () { + const base = 3n; + const exponent = 200n; + const modulus = 0n; + + await expect(this.mock.$modExp(base, exponent, modulus)).to.be.revertedWithPanic(PANIC_CODES.DIVISION_BY_ZERO); + }); + }); + describe('sqrt', function () { it('rounds down', async function () { for (const rounding of RoundingDown) { diff --git a/test/utils/structs/Checkpoints.t.sol b/test/utils/structs/Checkpoints.t.sol index 7bdbcfddf..72d209f4c 100644 --- a/test/utils/structs/Checkpoints.t.sol +++ b/test/utils/structs/Checkpoints.t.sol @@ -17,11 +17,11 @@ contract CheckpointsTrace224Test is Test { Checkpoints.Trace224 internal _ckpts; // helpers - function _boundUint32(uint32 x, uint32 min, uint32 max) internal view returns (uint32) { + function _boundUint32(uint32 x, uint32 min, uint32 max) internal pure returns (uint32) { return SafeCast.toUint32(bound(uint256(x), uint256(min), uint256(max))); } - function _prepareKeys(uint32[] memory keys, uint32 maxSpread) internal view { + function _prepareKeys(uint32[] memory keys, uint32 maxSpread) internal pure { uint32 lastKey = 0; for (uint256 i = 0; i < keys.length; ++i) { uint32 key = _boundUint32(keys[i], lastKey, lastKey + maxSpread); @@ -125,11 +125,11 @@ contract CheckpointsTrace208Test is Test { Checkpoints.Trace208 internal _ckpts; // helpers - function _boundUint48(uint48 x, uint48 min, uint48 max) internal view returns (uint48) { + function _boundUint48(uint48 x, uint48 min, uint48 max) internal pure returns (uint48) { return SafeCast.toUint48(bound(uint256(x), uint256(min), uint256(max))); } - function _prepareKeys(uint48[] memory keys, uint48 maxSpread) internal view { + function _prepareKeys(uint48[] memory keys, uint48 maxSpread) internal pure { uint48 lastKey = 0; for (uint256 i = 0; i < keys.length; ++i) { uint48 key = _boundUint48(keys[i], lastKey, lastKey + maxSpread); @@ -233,11 +233,11 @@ contract CheckpointsTrace160Test is Test { Checkpoints.Trace160 internal _ckpts; // helpers - function _boundUint96(uint96 x, uint96 min, uint96 max) internal view returns (uint96) { + function _boundUint96(uint96 x, uint96 min, uint96 max) internal pure returns (uint96) { return SafeCast.toUint96(bound(uint256(x), uint256(min), uint256(max))); } - function _prepareKeys(uint96[] memory keys, uint96 maxSpread) internal view { + function _prepareKeys(uint96[] memory keys, uint96 maxSpread) internal pure { uint96 lastKey = 0; for (uint256 i = 0; i < keys.length; ++i) { uint96 key = _boundUint96(keys[i], lastKey, lastKey + maxSpread);