Refactor abs without logical branching (#4497)
Co-authored-by: Francisco Giordano <fg@frang.io> Co-authored-by: Hadrien Croubois <hadrien.croubois@gmail.com> Co-authored-by: ernestognw <ernestognw@gmail.com>
This commit is contained in:
committed by
GitHub
parent
72c642e13e
commit
dfae50fa5b
@ -36,8 +36,15 @@ library SignedMath {
|
|||||||
*/
|
*/
|
||||||
function abs(int256 n) internal pure returns (uint256) {
|
function abs(int256 n) internal pure returns (uint256) {
|
||||||
unchecked {
|
unchecked {
|
||||||
// must be unchecked in order to support `n = type(int256).min`
|
// Formula from the "Bit Twiddling Hacks" by Sean Eron Anderson.
|
||||||
return uint256(n >= 0 ? n : -n);
|
// Since `n` is a signed integer, the generated bytecode will use the SAR opcode to perform the right shift,
|
||||||
|
// taking advantage of the most significant (or "sign" bit) in two's complement representation.
|
||||||
|
// This opcode adds new most significant bits set to the value of the previous most significant bit. As a result,
|
||||||
|
// the mask will either be `bytes(0)` (if n is positive) or `~bytes32(0)` (if n is negative).
|
||||||
|
int256 mask = n >> 255;
|
||||||
|
|
||||||
|
// A `bytes(0)` mask leaves the input unchanged, while a `~bytes32(0)` mask complements it.
|
||||||
|
return uint256((n + mask) ^ mask);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
70
test/utils/math/SignedMath.t.sol
Normal file
70
test/utils/math/SignedMath.t.sol
Normal file
@ -0,0 +1,70 @@
|
|||||||
|
// SPDX-License-Identifier: MIT
|
||||||
|
|
||||||
|
pragma solidity ^0.8.20;
|
||||||
|
|
||||||
|
import {Test} from "forge-std/Test.sol";
|
||||||
|
|
||||||
|
import {Math} from "../../../contracts/utils/math/Math.sol";
|
||||||
|
import {SignedMath} from "../../../contracts/utils/math/SignedMath.sol";
|
||||||
|
|
||||||
|
contract SignedMathTest is Test {
|
||||||
|
// MIN
|
||||||
|
function testMin(int256 a, int256 b) public {
|
||||||
|
int256 result = SignedMath.min(a, b);
|
||||||
|
|
||||||
|
assertLe(result, a);
|
||||||
|
assertLe(result, b);
|
||||||
|
assertTrue(result == a || result == b);
|
||||||
|
}
|
||||||
|
|
||||||
|
// MAX
|
||||||
|
function testMax(int256 a, int256 b) public {
|
||||||
|
int256 result = SignedMath.max(a, b);
|
||||||
|
|
||||||
|
assertGe(result, a);
|
||||||
|
assertGe(result, b);
|
||||||
|
assertTrue(result == a || result == b);
|
||||||
|
}
|
||||||
|
|
||||||
|
// AVERAGE
|
||||||
|
// 1. simple test, not full int256 range
|
||||||
|
function testAverage1(int256 a, int256 b) public {
|
||||||
|
a = bound(a, type(int256).min / 2, type(int256).max / 2);
|
||||||
|
b = bound(b, type(int256).min / 2, type(int256).max / 2);
|
||||||
|
|
||||||
|
int256 result = SignedMath.average(a, b);
|
||||||
|
|
||||||
|
assertEq(result, (a + b) / 2);
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. more complex test, full int256 range
|
||||||
|
function testAverage2(int256 a, int256 b) public {
|
||||||
|
(int256 result, int256 min, int256 max) = (
|
||||||
|
SignedMath.average(a, b),
|
||||||
|
SignedMath.min(a, b),
|
||||||
|
SignedMath.max(a, b)
|
||||||
|
);
|
||||||
|
|
||||||
|
// average must be between `a` and `b`
|
||||||
|
assertGe(result, min);
|
||||||
|
assertLe(result, max);
|
||||||
|
|
||||||
|
unchecked {
|
||||||
|
// must be unchecked in order to support `a = type(int256).min, b = type(int256).max`
|
||||||
|
uint256 deltaLower = uint256(result - min);
|
||||||
|
uint256 deltaUpper = uint256(max - result);
|
||||||
|
uint256 remainder = uint256((a & 1) ^ (b & 1));
|
||||||
|
assertEq(remainder, Math.max(deltaLower, deltaUpper) - Math.min(deltaLower, deltaUpper));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ABS
|
||||||
|
function testAbs(int256 a) public {
|
||||||
|
uint256 result = SignedMath.abs(a);
|
||||||
|
|
||||||
|
unchecked {
|
||||||
|
// must be unchecked in order to support `n = type(int256).min`
|
||||||
|
assertEq(result, a < 0 ? uint256(-a) : uint256(a));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user