Add saturating (unsigned) math operations and optimize try operations (#5527)
This commit is contained in:
5
.changeset/fair-pumpkins-compete.md
Normal file
5
.changeset/fair-pumpkins-compete.md
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
---
|
||||||
|
'openzeppelin-solidity': minor
|
||||||
|
---
|
||||||
|
|
||||||
|
`Math`: Add saturating arithmetic operations `saturatingAdd`, `saturatingSub` and `saturatingMul`.
|
||||||
@ -51,8 +51,8 @@ library Math {
|
|||||||
function tryAdd(uint256 a, uint256 b) internal pure returns (bool success, uint256 result) {
|
function tryAdd(uint256 a, uint256 b) internal pure returns (bool success, uint256 result) {
|
||||||
unchecked {
|
unchecked {
|
||||||
uint256 c = a + b;
|
uint256 c = a + b;
|
||||||
if (c < a) return (false, 0);
|
success = c >= a;
|
||||||
return (true, c);
|
result = c * SafeCast.toUint(success);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -61,8 +61,9 @@ library Math {
|
|||||||
*/
|
*/
|
||||||
function trySub(uint256 a, uint256 b) internal pure returns (bool success, uint256 result) {
|
function trySub(uint256 a, uint256 b) internal pure returns (bool success, uint256 result) {
|
||||||
unchecked {
|
unchecked {
|
||||||
if (b > a) return (false, 0);
|
uint256 c = a - b;
|
||||||
return (true, a - b);
|
success = c <= a;
|
||||||
|
result = c * SafeCast.toUint(success);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -71,13 +72,14 @@ library Math {
|
|||||||
*/
|
*/
|
||||||
function tryMul(uint256 a, uint256 b) internal pure returns (bool success, uint256 result) {
|
function tryMul(uint256 a, uint256 b) internal pure returns (bool success, uint256 result) {
|
||||||
unchecked {
|
unchecked {
|
||||||
// Gas optimization: this is cheaper than requiring 'a' not being zero, but the
|
|
||||||
// benefit is lost if 'b' is also tested.
|
|
||||||
// See: https://github.com/OpenZeppelin/openzeppelin-contracts/pull/522
|
|
||||||
if (a == 0) return (true, 0);
|
|
||||||
uint256 c = a * b;
|
uint256 c = a * b;
|
||||||
if (c / a != b) return (false, 0);
|
assembly ("memory-safe") {
|
||||||
return (true, c);
|
// Only true when the multiplication doesn't overflow
|
||||||
|
// (c / a == b) || (a == 0)
|
||||||
|
success := or(eq(div(c, a), b), iszero(a))
|
||||||
|
}
|
||||||
|
// equivalent to: success ? c : 0
|
||||||
|
result = c * SafeCast.toUint(success);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -86,8 +88,11 @@ library Math {
|
|||||||
*/
|
*/
|
||||||
function tryDiv(uint256 a, uint256 b) internal pure returns (bool success, uint256 result) {
|
function tryDiv(uint256 a, uint256 b) internal pure returns (bool success, uint256 result) {
|
||||||
unchecked {
|
unchecked {
|
||||||
if (b == 0) return (false, 0);
|
success = b > 0;
|
||||||
return (true, a / b);
|
assembly ("memory-safe") {
|
||||||
|
// The `DIV` opcode returns zero when the denominator is 0.
|
||||||
|
result := div(a, b)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -96,11 +101,38 @@ library Math {
|
|||||||
*/
|
*/
|
||||||
function tryMod(uint256 a, uint256 b) internal pure returns (bool success, uint256 result) {
|
function tryMod(uint256 a, uint256 b) internal pure returns (bool success, uint256 result) {
|
||||||
unchecked {
|
unchecked {
|
||||||
if (b == 0) return (false, 0);
|
success = b > 0;
|
||||||
return (true, a % b);
|
assembly ("memory-safe") {
|
||||||
|
// The `MOD` opcode returns zero when the denominator is 0.
|
||||||
|
result := mod(a, b)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @dev Unsigned saturating addition, bounds to `2²⁵⁶ - 1` instead of overflowing.
|
||||||
|
*/
|
||||||
|
function saturatingAdd(uint256 a, uint256 b) internal pure returns (uint256) {
|
||||||
|
(bool success, uint256 result) = tryAdd(a, b);
|
||||||
|
return ternary(success, result, type(uint256).max);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @dev Unsigned saturating subtraction, bounds to zero instead of overflowing.
|
||||||
|
*/
|
||||||
|
function saturatingSub(uint256 a, uint256 b) internal pure returns (uint256) {
|
||||||
|
(, uint256 result) = trySub(a, b);
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @dev Unsigned saturating multiplication, bounds to `2²⁵⁶ - 1` instead of overflowing.
|
||||||
|
*/
|
||||||
|
function saturatingMul(uint256 a, uint256 b) internal pure returns (uint256) {
|
||||||
|
(bool success, uint256 result) = tryMul(a, b);
|
||||||
|
return ternary(success, result, type(uint256).max);
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @dev Branchless ternary evaluation for `a ? b : c`. Gas costs are constant.
|
* @dev Branchless ternary evaluation for `a ? b : c`. Gas costs are constant.
|
||||||
*
|
*
|
||||||
@ -192,7 +224,7 @@ library Math {
|
|||||||
|
|
||||||
// Make division exact by subtracting the remainder from [high low].
|
// Make division exact by subtracting the remainder from [high low].
|
||||||
uint256 remainder;
|
uint256 remainder;
|
||||||
assembly {
|
assembly ("memory-safe") {
|
||||||
// Compute remainder using mulmod.
|
// Compute remainder using mulmod.
|
||||||
remainder := mulmod(x, y, denominator)
|
remainder := mulmod(x, y, denominator)
|
||||||
|
|
||||||
@ -205,7 +237,7 @@ library Math {
|
|||||||
// Always >= 1. See https://cs.stackexchange.com/q/138556/92363.
|
// Always >= 1. See https://cs.stackexchange.com/q/138556/92363.
|
||||||
|
|
||||||
uint256 twos = denominator & (0 - denominator);
|
uint256 twos = denominator & (0 - denominator);
|
||||||
assembly {
|
assembly ("memory-safe") {
|
||||||
// Divide denominator by twos.
|
// Divide denominator by twos.
|
||||||
denominator := div(denominator, twos)
|
denominator := div(denominator, twos)
|
||||||
|
|
||||||
|
|||||||
@ -168,6 +168,62 @@ describe('Math', function () {
|
|||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
|
describe('saturatingAdd', function () {
|
||||||
|
it('adds correctly', async function () {
|
||||||
|
const a = 5678n;
|
||||||
|
const b = 1234n;
|
||||||
|
await testCommutative(this.mock.$saturatingAdd, a, b, a + b);
|
||||||
|
await testCommutative(this.mock.$saturatingAdd, a, 0n, a);
|
||||||
|
await testCommutative(this.mock.$saturatingAdd, ethers.MaxUint256, 0n, ethers.MaxUint256);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('bounds on addition overflow', async function () {
|
||||||
|
await testCommutative(this.mock.$saturatingAdd, ethers.MaxUint256, 1n, ethers.MaxUint256);
|
||||||
|
await expect(this.mock.$saturatingAdd(ethers.MaxUint256, ethers.MaxUint256)).to.eventually.equal(
|
||||||
|
ethers.MaxUint256,
|
||||||
|
);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('saturatingSub', function () {
|
||||||
|
it('subtracts correctly', async function () {
|
||||||
|
const a = 5678n;
|
||||||
|
const b = 1234n;
|
||||||
|
await expect(this.mock.$saturatingSub(a, b)).to.eventually.equal(a - b);
|
||||||
|
await expect(this.mock.$saturatingSub(a, a)).to.eventually.equal(0n);
|
||||||
|
await expect(this.mock.$saturatingSub(a, 0n)).to.eventually.equal(a);
|
||||||
|
await expect(this.mock.$saturatingSub(0n, a)).to.eventually.equal(0n);
|
||||||
|
await expect(this.mock.$saturatingSub(ethers.MaxUint256, 1n)).to.eventually.equal(ethers.MaxUint256 - 1n);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('bounds on subtraction overflow', async function () {
|
||||||
|
await expect(this.mock.$saturatingSub(0n, 1n)).to.eventually.equal(0n);
|
||||||
|
await expect(this.mock.$saturatingSub(1n, 2n)).to.eventually.equal(0n);
|
||||||
|
await expect(this.mock.$saturatingSub(1n, ethers.MaxUint256)).to.eventually.equal(0n);
|
||||||
|
await expect(this.mock.$saturatingSub(ethers.MaxUint256 - 1n, ethers.MaxUint256)).to.eventually.equal(0n);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('saturatingMul', function () {
|
||||||
|
it('multiplies correctly', async function () {
|
||||||
|
const a = 1234n;
|
||||||
|
const b = 5678n;
|
||||||
|
await testCommutative(this.mock.$saturatingMul, a, b, a * b);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('multiplies by zero correctly', async function () {
|
||||||
|
const a = 0n;
|
||||||
|
const b = 5678n;
|
||||||
|
await testCommutative(this.mock.$saturatingMul, a, b, 0n);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('bounds on multiplication overflow', async function () {
|
||||||
|
const a = ethers.MaxUint256;
|
||||||
|
const b = 2n;
|
||||||
|
await testCommutative(this.mock.$saturatingMul, a, b, ethers.MaxUint256);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
describe('max', function () {
|
describe('max', function () {
|
||||||
it('is correctly detected in both position', async function () {
|
it('is correctly detected in both position', async function () {
|
||||||
await testCommutative(this.mock.$max, 1234n, 5678n, max(1234n, 5678n));
|
await testCommutative(this.mock.$max, 1234n, 5678n, max(1234n, 5678n));
|
||||||
|
|||||||
Reference in New Issue
Block a user