diff --git a/.changeset/nervous-pans-grow.md b/.changeset/nervous-pans-grow.md new file mode 100644 index 000000000..b86a075c6 --- /dev/null +++ b/.changeset/nervous-pans-grow.md @@ -0,0 +1,5 @@ +--- +'openzeppelin-solidity': patch +--- + +`SafeCast`: Add `toUint(bool)` for operating on `bool` values as `uint256`. diff --git a/contracts/utils/math/Math.sol b/contracts/utils/math/Math.sol index c2d419eb9..be05506ec 100644 --- a/contracts/utils/math/Math.sol +++ b/contracts/utils/math/Math.sol @@ -5,6 +5,7 @@ pragma solidity ^0.8.20; import {Address} from "../Address.sol"; import {Panic} from "../Panic.sol"; +import {SafeCast} from "./SafeCast.sol"; /** * @dev Standard math utilities missing in the Solidity language. @@ -210,11 +211,7 @@ library Math { * @dev Calculates x * y / denominator with full precision, following the selected rounding direction. */ function mulDiv(uint256 x, uint256 y, uint256 denominator, Rounding rounding) internal pure returns (uint256) { - uint256 result = mulDiv(x, y, denominator); - if (unsignedRoundsUp(rounding) && mulmod(x, y, denominator) > 0) { - result += 1; - } - return result; + return mulDiv(x, y, denominator) + SafeCast.toUint(unsignedRoundsUp(rounding) && mulmod(x, y, denominator) > 0); } /** @@ -383,7 +380,7 @@ library Math { function sqrt(uint256 a, Rounding rounding) internal pure returns (uint256) { unchecked { uint256 result = sqrt(a); - return result + (unsignedRoundsUp(rounding) && result * result < a ? 1 : 0); + return result + SafeCast.toUint(unsignedRoundsUp(rounding) && result * result < a); } } @@ -393,38 +390,37 @@ library Math { */ function log2(uint256 value) internal pure returns (uint256) { uint256 result = 0; + uint256 exp; unchecked { - if (value >> 128 > 0) { - value >>= 128; - result += 128; - } - if (value >> 64 > 0) { - value >>= 64; - result += 64; - } - if (value >> 32 > 0) { - value >>= 32; - result += 32; - } - if (value >> 16 > 0) { - value >>= 16; - result += 16; - } - if (value >> 8 > 0) { - value >>= 8; - result += 8; - } - if (value >> 4 > 0) { - value >>= 4; - result += 4; - } - if (value >> 2 > 0) { - value >>= 2; - result += 2; - } - if (value >> 1 > 0) { - result += 1; - } + exp = 128 * SafeCast.toUint(value > (1 << 128) - 1); + value >>= exp; + result += exp; + + exp = 64 * SafeCast.toUint(value > (1 << 64) - 1); + value >>= exp; + result += exp; + + exp = 32 * SafeCast.toUint(value > (1 << 32) - 1); + value >>= exp; + result += exp; + + exp = 16 * SafeCast.toUint(value > (1 << 16) - 1); + value >>= exp; + result += exp; + + exp = 8 * SafeCast.toUint(value > (1 << 8) - 1); + value >>= exp; + result += exp; + + exp = 4 * SafeCast.toUint(value > (1 << 4) - 1); + value >>= exp; + result += exp; + + exp = 2 * SafeCast.toUint(value > (1 << 2) - 1); + value >>= exp; + result += exp; + + result += SafeCast.toUint(value > 1); } return result; } @@ -436,7 +432,7 @@ library Math { function log2(uint256 value, Rounding rounding) internal pure returns (uint256) { unchecked { uint256 result = log2(value); - return result + (unsignedRoundsUp(rounding) && 1 << result < value ? 1 : 0); + return result + SafeCast.toUint(unsignedRoundsUp(rounding) && 1 << result < value); } } @@ -485,7 +481,7 @@ library Math { function log10(uint256 value, Rounding rounding) internal pure returns (uint256) { unchecked { uint256 result = log10(value); - return result + (unsignedRoundsUp(rounding) && 10 ** result < value ? 1 : 0); + return result + SafeCast.toUint(unsignedRoundsUp(rounding) && 10 ** result < value); } } @@ -497,26 +493,25 @@ library Math { */ function log256(uint256 value) internal pure returns (uint256) { uint256 result = 0; + uint256 isGt; unchecked { - if (value >> 128 > 0) { - value >>= 128; - result += 16; - } - if (value >> 64 > 0) { - value >>= 64; - result += 8; - } - if (value >> 32 > 0) { - value >>= 32; - result += 4; - } - if (value >> 16 > 0) { - value >>= 16; - result += 2; - } - if (value >> 8 > 0) { - result += 1; - } + isGt = SafeCast.toUint(value > (1 << 128) - 1); + value >>= isGt * 128; + result += isGt * 16; + + isGt = SafeCast.toUint(value > (1 << 64) - 1); + value >>= isGt * 64; + result += isGt * 8; + + isGt = SafeCast.toUint(value > (1 << 32) - 1); + value >>= isGt * 32; + result += isGt * 4; + + isGt = SafeCast.toUint(value > (1 << 16) - 1); + value >>= isGt * 16; + result += isGt * 2; + + result += SafeCast.toUint(value > (1 << 8) - 1); } return result; } @@ -528,7 +523,7 @@ library Math { function log256(uint256 value, Rounding rounding) internal pure returns (uint256) { unchecked { uint256 result = log256(value); - return result + (unsignedRoundsUp(rounding) && 1 << (result << 3) < value ? 1 : 0); + return result + SafeCast.toUint(unsignedRoundsUp(rounding) && 1 << (result << 3) < value); } } diff --git a/contracts/utils/math/SafeCast.sol b/contracts/utils/math/SafeCast.sol index 0ed458b43..d8de2e17c 100644 --- a/contracts/utils/math/SafeCast.sol +++ b/contracts/utils/math/SafeCast.sol @@ -5,7 +5,7 @@ pragma solidity ^0.8.20; /** - * @dev Wrappers over Solidity's uintXX/intXX casting operators with added overflow + * @dev Wrappers over Solidity's uintXX/intXX/bool casting operators with added overflow * checks. * * Downcasting from uint256/int256 in Solidity does not revert on overflow. This can @@ -1150,4 +1150,14 @@ library SafeCast { } return int256(value); } + + /** + * @dev Cast a boolean (false or true) to a uint256 (0 or 1) with no jump. + */ + function toUint(bool b) internal pure returns (uint256 u) { + /// @solidity memory-safe-assembly + assembly { + u := iszero(iszero(b)) + } + } } diff --git a/scripts/generate/templates/SafeCast.js b/scripts/generate/templates/SafeCast.js index f1954a753..a10ee75c9 100644 --- a/scripts/generate/templates/SafeCast.js +++ b/scripts/generate/templates/SafeCast.js @@ -7,7 +7,7 @@ const header = `\ pragma solidity ^0.8.20; /** - * @dev Wrappers over Solidity's uintXX/intXX casting operators with added overflow + * @dev Wrappers over Solidity's uintXX/intXX/bool casting operators with added overflow * checks. * * Downcasting from uint256/int256 in Solidity does not revert on overflow. This can @@ -116,11 +116,23 @@ function toUint${length}(int${length} value) internal pure returns (uint${length } `; +const boolToUint = ` + /** + * @dev Cast a boolean (false or true) to a uint256 (0 or 1) with no jump. + */ + function toUint(bool b) internal pure returns (uint256 u) { + /// @solidity memory-safe-assembly + assembly { + u := iszero(iszero(b)) + } + } +`; + // GENERATE module.exports = format( header.trimEnd(), 'library SafeCast {', errors, - [...LENGTHS.map(toUintDownCast), toUint(256), ...LENGTHS.map(toIntDownCast), toInt(256)], + [...LENGTHS.map(toUintDownCast), toUint(256), ...LENGTHS.map(toIntDownCast), toInt(256), boolToUint], '}', ); diff --git a/test/utils/math/SafeCast.test.js b/test/utils/math/SafeCast.test.js index ecf55dc35..aa609faf0 100644 --- a/test/utils/math/SafeCast.test.js +++ b/test/utils/math/SafeCast.test.js @@ -146,4 +146,14 @@ describe('SafeCast', function () { .withArgs(ethers.MaxUint256); }); }); + + describe('toUint (bool)', function () { + it('toUint(false) should be 0', async function () { + expect(await this.mock.$toUint(false)).to.equal(0n); + }); + + it('toUint(true) should be 1', async function () { + expect(await this.mock.$toUint(true)).to.equal(1n); + }); + }); });