From dfae50fa5bc63cd6a7ac5eab54523266fa9689fe Mon Sep 17 00:00:00 2001 From: Vladislav Volosnikov Date: Thu, 18 Jan 2024 21:22:47 +0100 Subject: [PATCH] Refactor abs without logical branching (#4497) Co-authored-by: Francisco Giordano Co-authored-by: Hadrien Croubois Co-authored-by: ernestognw --- contracts/utils/math/SignedMath.sol | 11 ++++- test/utils/math/SignedMath.t.sol | 70 +++++++++++++++++++++++++++++ 2 files changed, 79 insertions(+), 2 deletions(-) create mode 100644 test/utils/math/SignedMath.t.sol diff --git a/contracts/utils/math/SignedMath.sol b/contracts/utils/math/SignedMath.sol index 66a615162..535d75761 100644 --- a/contracts/utils/math/SignedMath.sol +++ b/contracts/utils/math/SignedMath.sol @@ -36,8 +36,15 @@ library SignedMath { */ function abs(int256 n) internal pure returns (uint256) { unchecked { - // must be unchecked in order to support `n = type(int256).min` - return uint256(n >= 0 ? n : -n); + // Formula from the "Bit Twiddling Hacks" by Sean Eron Anderson. + // Since `n` is a signed integer, the generated bytecode will use the SAR opcode to perform the right shift, + // taking advantage of the most significant (or "sign" bit) in two's complement representation. + // This opcode adds new most significant bits set to the value of the previous most significant bit. As a result, + // the mask will either be `bytes(0)` (if n is positive) or `~bytes32(0)` (if n is negative). + int256 mask = n >> 255; + + // A `bytes(0)` mask leaves the input unchanged, while a `~bytes32(0)` mask complements it. + return uint256((n + mask) ^ mask); } } } diff --git a/test/utils/math/SignedMath.t.sol b/test/utils/math/SignedMath.t.sol new file mode 100644 index 000000000..fe5900a5d --- /dev/null +++ b/test/utils/math/SignedMath.t.sol @@ -0,0 +1,70 @@ +// SPDX-License-Identifier: MIT + +pragma solidity ^0.8.20; + +import {Test} from "forge-std/Test.sol"; + +import {Math} from "../../../contracts/utils/math/Math.sol"; +import {SignedMath} from "../../../contracts/utils/math/SignedMath.sol"; + +contract SignedMathTest is Test { + // MIN + function testMin(int256 a, int256 b) public { + int256 result = SignedMath.min(a, b); + + assertLe(result, a); + assertLe(result, b); + assertTrue(result == a || result == b); + } + + // MAX + function testMax(int256 a, int256 b) public { + int256 result = SignedMath.max(a, b); + + assertGe(result, a); + assertGe(result, b); + assertTrue(result == a || result == b); + } + + // AVERAGE + // 1. simple test, not full int256 range + function testAverage1(int256 a, int256 b) public { + a = bound(a, type(int256).min / 2, type(int256).max / 2); + b = bound(b, type(int256).min / 2, type(int256).max / 2); + + int256 result = SignedMath.average(a, b); + + assertEq(result, (a + b) / 2); + } + + // 2. more complex test, full int256 range + function testAverage2(int256 a, int256 b) public { + (int256 result, int256 min, int256 max) = ( + SignedMath.average(a, b), + SignedMath.min(a, b), + SignedMath.max(a, b) + ); + + // average must be between `a` and `b` + assertGe(result, min); + assertLe(result, max); + + unchecked { + // must be unchecked in order to support `a = type(int256).min, b = type(int256).max` + uint256 deltaLower = uint256(result - min); + uint256 deltaUpper = uint256(max - result); + uint256 remainder = uint256((a & 1) ^ (b & 1)); + assertEq(remainder, Math.max(deltaLower, deltaUpper) - Math.min(deltaLower, deltaUpper)); + } + } + + // ABS + function testAbs(int256 a) public { + uint256 result = SignedMath.abs(a); + + unchecked { + // must be unchecked in order to support `n = type(int256).min` + assertEq(result, a < 0 ? uint256(-a) : uint256(a)); + } + } +}