Add saturating (unsigned) math operations and optimize try operations (#5527)

This commit is contained in:
Hadrien Croubois
2025-02-27 10:03:54 +01:00
committed by GitHub
parent 506e1f827a
commit a9b1f58b00
3 changed files with 109 additions and 16 deletions

View File

@ -0,0 +1,5 @@
---
'openzeppelin-solidity': minor
---
`Math`: Add saturating arithmetic operations `saturatingAdd`, `saturatingSub` and `saturatingMul`.

View File

@ -51,8 +51,8 @@ library Math {
function tryAdd(uint256 a, uint256 b) internal pure returns (bool success, uint256 result) {
unchecked {
uint256 c = a + b;
if (c < a) return (false, 0);
return (true, c);
success = c >= a;
result = c * SafeCast.toUint(success);
}
}
@ -61,8 +61,9 @@ library Math {
*/
function trySub(uint256 a, uint256 b) internal pure returns (bool success, uint256 result) {
unchecked {
if (b > a) return (false, 0);
return (true, a - b);
uint256 c = a - b;
success = c <= a;
result = c * SafeCast.toUint(success);
}
}
@ -71,13 +72,14 @@ library Math {
*/
function tryMul(uint256 a, uint256 b) internal pure returns (bool success, uint256 result) {
unchecked {
// Gas optimization: this is cheaper than requiring 'a' not being zero, but the
// benefit is lost if 'b' is also tested.
// See: https://github.com/OpenZeppelin/openzeppelin-contracts/pull/522
if (a == 0) return (true, 0);
uint256 c = a * b;
if (c / a != b) return (false, 0);
return (true, c);
assembly ("memory-safe") {
// Only true when the multiplication doesn't overflow
// (c / a == b) || (a == 0)
success := or(eq(div(c, a), b), iszero(a))
}
// equivalent to: success ? c : 0
result = c * SafeCast.toUint(success);
}
}
@ -86,8 +88,11 @@ library Math {
*/
function tryDiv(uint256 a, uint256 b) internal pure returns (bool success, uint256 result) {
unchecked {
if (b == 0) return (false, 0);
return (true, a / b);
success = b > 0;
assembly ("memory-safe") {
// The `DIV` opcode returns zero when the denominator is 0.
result := div(a, b)
}
}
}
@ -96,10 +101,37 @@ library Math {
*/
function tryMod(uint256 a, uint256 b) internal pure returns (bool success, uint256 result) {
unchecked {
if (b == 0) return (false, 0);
return (true, a % b);
success = b > 0;
assembly ("memory-safe") {
// The `MOD` opcode returns zero when the denominator is 0.
result := mod(a, b)
}
}
}
/**
* @dev Unsigned saturating addition, bounds to `2²⁵⁶ - 1` instead of overflowing.
*/
function saturatingAdd(uint256 a, uint256 b) internal pure returns (uint256) {
(bool success, uint256 result) = tryAdd(a, b);
return ternary(success, result, type(uint256).max);
}
/**
* @dev Unsigned saturating subtraction, bounds to zero instead of overflowing.
*/
function saturatingSub(uint256 a, uint256 b) internal pure returns (uint256) {
(, uint256 result) = trySub(a, b);
return result;
}
/**
* @dev Unsigned saturating multiplication, bounds to `2²⁵⁶ - 1` instead of overflowing.
*/
function saturatingMul(uint256 a, uint256 b) internal pure returns (uint256) {
(bool success, uint256 result) = tryMul(a, b);
return ternary(success, result, type(uint256).max);
}
/**
* @dev Branchless ternary evaluation for `a ? b : c`. Gas costs are constant.
@ -192,7 +224,7 @@ library Math {
// Make division exact by subtracting the remainder from [high low].
uint256 remainder;
assembly {
assembly ("memory-safe") {
// Compute remainder using mulmod.
remainder := mulmod(x, y, denominator)
@ -205,7 +237,7 @@ library Math {
// Always >= 1. See https://cs.stackexchange.com/q/138556/92363.
uint256 twos = denominator & (0 - denominator);
assembly {
assembly ("memory-safe") {
// Divide denominator by twos.
denominator := div(denominator, twos)

View File

@ -168,6 +168,62 @@ describe('Math', function () {
});
});
describe('saturatingAdd', function () {
it('adds correctly', async function () {
const a = 5678n;
const b = 1234n;
await testCommutative(this.mock.$saturatingAdd, a, b, a + b);
await testCommutative(this.mock.$saturatingAdd, a, 0n, a);
await testCommutative(this.mock.$saturatingAdd, ethers.MaxUint256, 0n, ethers.MaxUint256);
});
it('bounds on addition overflow', async function () {
await testCommutative(this.mock.$saturatingAdd, ethers.MaxUint256, 1n, ethers.MaxUint256);
await expect(this.mock.$saturatingAdd(ethers.MaxUint256, ethers.MaxUint256)).to.eventually.equal(
ethers.MaxUint256,
);
});
});
describe('saturatingSub', function () {
it('subtracts correctly', async function () {
const a = 5678n;
const b = 1234n;
await expect(this.mock.$saturatingSub(a, b)).to.eventually.equal(a - b);
await expect(this.mock.$saturatingSub(a, a)).to.eventually.equal(0n);
await expect(this.mock.$saturatingSub(a, 0n)).to.eventually.equal(a);
await expect(this.mock.$saturatingSub(0n, a)).to.eventually.equal(0n);
await expect(this.mock.$saturatingSub(ethers.MaxUint256, 1n)).to.eventually.equal(ethers.MaxUint256 - 1n);
});
it('bounds on subtraction overflow', async function () {
await expect(this.mock.$saturatingSub(0n, 1n)).to.eventually.equal(0n);
await expect(this.mock.$saturatingSub(1n, 2n)).to.eventually.equal(0n);
await expect(this.mock.$saturatingSub(1n, ethers.MaxUint256)).to.eventually.equal(0n);
await expect(this.mock.$saturatingSub(ethers.MaxUint256 - 1n, ethers.MaxUint256)).to.eventually.equal(0n);
});
});
describe('saturatingMul', function () {
it('multiplies correctly', async function () {
const a = 1234n;
const b = 5678n;
await testCommutative(this.mock.$saturatingMul, a, b, a * b);
});
it('multiplies by zero correctly', async function () {
const a = 0n;
const b = 5678n;
await testCommutative(this.mock.$saturatingMul, a, b, 0n);
});
it('bounds on multiplication overflow', async function () {
const a = ethers.MaxUint256;
const b = 2n;
await testCommutative(this.mock.$saturatingMul, a, b, ethers.MaxUint256);
});
});
describe('max', function () {
it('is correctly detected in both position', async function () {
await testCommutative(this.mock.$max, 1234n, 5678n, max(1234n, 5678n));