I have a program which spends most of its time computing the Euclidean distance between RGB values (3-tuples of unsigned 8-bit Word8). I need a fast, branchless unsigned int absolute difference function such that
unsigned_difference :: Word8 -> Word8 -> Word8
unsigned_difference a b = max a b - min a b
in particular,
unsigned_difference a b == unsigned_difference b a
I came up with the following, using new primops from GHC 7.8:
-- (a < b) * (b - a) + (a > b) * (a - b)
unsigned_difference (I# a) (I# b) =
  I# ((a <# b) *# (b -# a) +# (a ># b) *# (a -# b))
which ghc -O2 -S compiles to:
.Lc42U:
movq 7(%rbx),%rax
movq $ghczmprim_GHCziTypes_Izh_con_info,-8(%r12)
movq 8(%rbp),%rbx
movq %rbx,%rcx
subq %rax,%rcx
cmpq %rax,%rbx
setg %dl
movzbl %dl,%edx
imulq %rcx,%rdx
movq %rax,%rcx
subq %rbx,%rcx
cmpq %rax,%rbx
setl %al
movzbl %al,%eax
imulq %rcx,%rax
addq %rdx,%rax
movq %rax,(%r12)
leaq -7(%r12),%rbx
addq $16,%rbp
jmp *(%rbp)
Compiling with ghc -O2 -fllvm -optlo -O3 -S produces the following asm:
.LBB6_1:
movq 7(%rbx), %rsi
movq $ghczmprim_GHCziTypes_Izh_con_info, 8(%rax)
movq 8(%rbp), %rcx
movq %rsi, %rdx
subq %rcx, %rdx
xorl %edi, %edi
subq %rsi, %rcx
cmovleq %rdi, %rcx
cmovgeq %rdi, %rdx
addq %rcx, %rdx
movq %rdx, 16(%rax)
movq 16(%rbp), %rax
addq $16, %rbp
leaq -7(%r12), %rbx
jmpq *%rax # TAILCALL
So LLVM manages to replace comparisons with (more efficient?) conditional move instructions. Unfortunately, compiling with -fllvm has little effect on the runtime of my program.
However, there are two problems with this function.

1. My values are Word8, but the comparison primops necessitate the use of Int. This causes needless allocation, as I'm forced to store a 64-bit Int rather than a Word8. I've profiled and confirmed that the use of fromIntegral :: Word8 -> Int is responsible for 42.4 percent of the program's total allocations.

2. Ideally I'd like a method that works directly on Word8 (one sketch of what that could look like follows below).

I had previously tagged the question C/C++ to attract attention from those more inclined to bit manipulation. My question uses Haskell, but I'd accept an answer implementing a correct method in any language.
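One candidate shape for such a Word8-only function, included here purely as a sketch (not benchmarked, and GHC is not guaranteed to compile the Bool comparison branchlessly), is a wrapping subtraction followed by a mask-based conditional negate:

import Data.Bits (xor)
import Data.Word (Word8)

-- mask is 0xFF when a < b and 0x00 otherwise; (d `xor` mask) - mask
-- negates d exactly when the mask is set, giving |a - b| without widening.
absDiffW8 :: Word8 -> Word8 -> Word8
absDiffW8 a b = (d `xor` mask) - mask
  where
    d    = a - b                                    -- wraps modulo 256
    mask = negate (fromIntegral (fromEnum (a < b))) -- 0x00 or 0xFF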
Conclusion:
I've decided to use

w8_sad :: Word8 -> Word8 -> Int16
w8_sad a b = xor (diff + mask) mask
  where diff = fromIntegral a - fromIntegral b
        mask = unsafeShiftR diff 15

as it is faster than my original unsigned_difference function, and simple to implement. SIMD intrinsics in Haskell haven't reached maturity yet, so while the SIMD versions are faster, I decided to go with a scalar version.
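To spell out the sign-mask step: on Int16, unsafeShiftR is an arithmetic shift, so mask is -1 exactly when diff is negative, and adding -1 then xor-ing with -1 negates diff in that case. A minimal self-contained check (the module wrapper and example values are mine):

import Data.Bits (unsafeShiftR, xor)
import Data.Int  (Int16)
import Data.Word (Word8)

w8_sad :: Word8 -> Word8 -> Int16
w8_sad a b = xor (diff + mask) mask
  where diff = fromIntegral a - fromIntegral b  -- may be negative, hence Int16
        mask = unsafeShiftR diff 15             -- -1 if diff < 0, else 0

-- e.g. w8_sad 3 10: diff = -7, mask = -1, (-8) `xor` (-1) = 7
main :: IO ()
main = print (w8_sad 3 10, w8_sad 10 3)   -- prints (7,7)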
Well, I tried to benchmark a bit. I use Criterion for the benchmarks, because it does proper significance tests. I also use QuickCheck here to ensure that all methods return the same results.
I compiled with GHC 7.6.3 (so I couldn't include your primops function, unfortunately) and with -O3:

ghc -O3 AbsDiff.hs -o AbsDiff && ./AbsDiff
Primarily we can see the difference between a naive implementation and a bit of fiddling:
absdiff1_w8 :: Word8 -> Word8 -> Word8
absdiff1_w8 a b = max a b - min a b
absdiff2_w8 :: Word8 -> Word8 -> Word8
absdiff2_w8 a b = unsafeCoerce $ xor (v + mask) mask
  where v = (unsafeCoerce a::Int64) - (unsafeCoerce b::Int64)
        mask = unsafeShiftR v 63
Output:
benchmarking absdiff_Word8/1
mean: 249.8591 us, lb 248.1229 us, ub 252.4321 us, ci 0.950
....
benchmarking absdiff_Word8/2
mean: 202.5095 us, lb 200.8041 us, ub 206.7602 us, ci 0.950
...
I use the absolute integer value trick from Bit Twiddling Hacks. Unfortunately we need casts; I don't think it is possible to solve the problem well in the domain of Word8 alone, but it seems sensible to use the native integer type anyway (there's definitely no need to create a heap object, though).
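For comparison, here is the same trick written with fromIntegral instead of unsafeCoerce (my own variant, not part of the benchmark); on a 64-bit GHC the Word8 -> Int64 conversion should compile to a simple widening:

import Data.Bits (unsafeShiftR, xor)
import Data.Int  (Int64)
import Data.Word (Word8)

-- Same sign-mask idea as absdiff2_w8, but widening with fromIntegral
-- rather than reinterpreting the boxed value with unsafeCoerce.
absdiff2_w8' :: Word8 -> Word8 -> Word8
absdiff2_w8' a b = fromIntegral ((v + mask) `xor` mask)
  where
    v    = fromIntegral a - fromIntegral b :: Int64
    mask = v `unsafeShiftR` 63   -- -1 when v is negative, 0 otherwise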
It doesn't really look like a large difference, but my test setup is also not perfect: I am mapping the function over a large list of random values to rule out branch prediction making the branching version seem more efficient than it is. This causes thunks to build up in memory, which could influence the timings a lot. When we subtract the constant overhead for maintaining the list, we could well see a lot more than the 20% speedup.
The generated assembly is actually pretty good (this is an inlined version of the function):
.Lc4BB:
leaq 7(%rbx),%rax
movq 8(%rbp),%rbx
subq (%rax),%rbx
movq %rbx,%rax
sarq $63,%rax
movq $base_GHCziInt_I64zh_con_info,-8(%r12)
addq %rax,%rbx
xorq %rax,%rbx
movq %rbx,0(%r12)
leaq -7(%r12),%rbx
movq $s4z0_info,8(%rbp)
1 subtraction, 1 addition, 1 right-shift, 1 xor and no branch, as expected. Using the LLVM backend doesn't improve the runtime noticeably.
Hope this is useful if you want to try out more stuff. The full benchmark code is below.
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Main where
import Data.Word
import Data.Int
import Data.Bits
import Control.Arrow ((***))
import Control.DeepSeq (force)
import Control.Exception (evaluate)
import Control.Monad
import System.Random
import Unsafe.Coerce
import Test.QuickCheck hiding ((.&.))
import Criterion.Main
absdiff1_w8 :: Word8 -> Word8 -> Word8
absdiff1_w8 !a !b = max a b - min a b
absdiff1_int16 :: Int16 -> Int16 -> Int16
absdiff1_int16 a b = max a b - min a b
absdiff1_int :: Int -> Int -> Int
absdiff1_int a b = max a b - min a b
absdiff2_int16 :: Int16 -> Int16 -> Int16
absdiff2_int16 a b = xor (v + mask) mask
  where v = a - b
        mask = unsafeShiftR v 15
absdiff2_w8 :: Word8 -> Word8 -> Word8
absdiff2_w8 !a !b = unsafeCoerce $ xor (v + mask) mask
  where !v = (unsafeCoerce a::Int64) - (unsafeCoerce b::Int64)
        !mask = unsafeShiftR v 63
absdiff3_w8 :: Word8 -> Word8 -> Word8
absdiff3_w8 a b = if a > b then a - b else b - a
{-absdiff4_int :: Int -> Int -> Int-}
{-absdiff4_int (I# a) (I# b) =-}
{-I# ((a <# b) *# (b -# a) +# (a ># b) *# (a -# b))-}
e2e :: (Enum a, Enum b) => a -> b
e2e = toEnum . fromEnum
prop_same1 x y = absdiff1_w8 x y == absdiff2_w8 x y
prop_same2 (x::Word8) (y::Word8) = absdiff1_int16 x' y' == absdiff2_int16 x' y'
  where x' = e2e x
        y' = e2e y
check = quickCheck prop_same1
        >> quickCheck prop_same2
instance (Random x, Random y) => Random (x, y) where
  random gen1 =
    let (x, gen2) = random gen1
        (y, gen3) = random gen2
    in ((x,y),gen3)
main =
  do check
     !pairs_w8 <- fmap force $ replicateM 10000 (randomIO :: IO (Word8,Word8))
     let !pairs_int16 = force $ map (e2e *** e2e) pairs_w8
     defaultMain
       [ bgroup "absdiff_Word8" [ bench "1" $ nf (map (uncurry absdiff1_w8)) pairs_w8
                                , bench "2" $ nf (map (uncurry absdiff2_w8)) pairs_w8
                                , bench "3" $ nf (map (uncurry absdiff3_w8)) pairs_w8
                                ]
       , bgroup "absdiff_Int16" [ bench "1" $ nf (map (uncurry absdiff1_int16)) pairs_int16
                                , bench "2" $ nf (map (uncurry absdiff2_int16)) pairs_int16
                                ]
       {-, bgroup "absdiff_Int" [ bench "1" $ whnf (absdiff1_int 13) 14-}
       {-, bench "2" $ whnf (absdiff3_int 13) 14-}
       {-]-}
       ]
If you are targeting a system with SSE instructions, you can use them for a nice performance boost. I tested this against the other posted methods and it seems to be the fastest approach.
Example results for diffing a large number of values:
diff0: 188.020679 ms // branching
diff1: 118.934970 ms // max min
diff2: 97.087710 ms // branchless mul add
diff3: 54.495269 ms // branchless signed
diff4: 31.159628 ms // sse
diff5: 30.855885 ms // sse v2
My full test code is below. I used SSE2 instructions, which are widely available in x86-ish CPUs nowadays, through SSE intrinsics, which should be quite portable (MSVC, GCC, Clang, Intel compilers, etc.).
Notes:
- The additional unrolling in diff5 seems to have little effect, but can possibly be tweaked.
- std::chrono::high_resolution_clock has low precision in the MSVC implementation, hence the Windows-specific timing; sorry for that, and for the dirty mix of C/C++ test code.

Please leave a comment if you have any questions/suggestions regarding the code or this approach in general.
#include <cstdlib>
#include <cstdint>
#include <cstdio>
#include <cmath>
#include <random>
#include <algorithm>
#define WIN32_LEAN_AND_MEAN
#define NOMINMAX
#include <Windows.h>
#include <emmintrin.h> // sse2
// branching
void diff0(const std::uint8_t* a, const std::uint8_t* b, std::uint8_t* res,
std::size_t n)
{
for (std::size_t i = 0; i < n; i++) {
res[i] = a[i] > b[i] ? a[i] - b[i] : b[i] - a[i];
}
}
// max min
void diff1(const std::uint8_t* a, const std::uint8_t* b, std::uint8_t* res,
std::size_t n)
{
for (std::size_t i = 0; i < n; i++) {
res[i] = std::max(a[i], b[i]) - std::min(a[i], b[i]);
}
}
// branchless mul add
void diff2(const std::uint8_t* a, const std::uint8_t* b, std::uint8_t* res,
std::size_t n)
{
for (std::size_t i = 0; i < n; i++) {
res[i] = (a[i] > b[i]) * (a[i] - b[i]) + (a[i] < b[i]) * (b[i] - a[i]);
}
}
// branchless signed
void diff3(const std::uint8_t* a, const std::uint8_t* b, std::uint8_t* res,
std::size_t n)
{
for (std::size_t i = 0; i < n; i++) {
std::int16_t diff = a[i] - b[i];
std::uint16_t mask = diff >> 15;
res[i] = (diff + mask) ^ mask;
}
}
// sse
void diff4(const std::uint8_t* a, const std::uint8_t* b, std::uint8_t* res,
std::size_t n)
{
auto pA = reinterpret_cast<const __m128i*>(a);
auto pB = reinterpret_cast<const __m128i*>(b);
auto pRes = reinterpret_cast<__m128i*>(res);
std::size_t i = 0;
for (std::size_t j = n / 16; j--; i++) {
__m128i max = _mm_max_epu8(_mm_load_si128(pA + i), _mm_load_si128(pB + i));
__m128i min = _mm_min_epu8(_mm_load_si128(pA + i), _mm_load_si128(pB + i));
_mm_store_si128(pRes + i, _mm_sub_epi8(max, min));
}
for (i *= 16; i < n; i++) { // fallback for the remaining <16 values
std::int16_t diff = a[i] - b[i];
std::uint16_t mask = diff >> 15;
res[i] = (diff + mask) ^ mask;
}
}
// sse v2
void diff5(const std::uint8_t* a, const std::uint8_t* b, std::uint8_t* res,
std::size_t n)
{
auto pA = reinterpret_cast<const __m128i*>(a);
auto pB = reinterpret_cast<const __m128i*>(b);
auto pRes = reinterpret_cast<__m128i*>(res);
std::size_t i = 0;
const std::size_t UNROLL = 2;
for (std::size_t j = n / (16 * UNROLL); j--; i += UNROLL) {
__m128i max0 = _mm_max_epu8(_mm_load_si128(pA + i + 0), _mm_load_si128(pB + i + 0));
__m128i min0 = _mm_min_epu8(_mm_load_si128(pA + i + 0), _mm_load_si128(pB + i + 0));
__m128i max1 = _mm_max_epu8(_mm_load_si128(pA + i + 1), _mm_load_si128(pB + i + 1));
__m128i min1 = _mm_min_epu8(_mm_load_si128(pA + i + 1), _mm_load_si128(pB + i + 1));
_mm_store_si128(pRes + i + 0, _mm_sub_epi8(max0, min0));
_mm_store_si128(pRes + i + 1, _mm_sub_epi8(max1, min1));
}
for (std::size_t j = n % (16 * UNROLL) / 16; j--; i++) {
__m128i max = _mm_max_epu8(_mm_load_si128(pA + i), _mm_load_si128(pB + i));
__m128i min = _mm_min_epu8(_mm_load_si128(pA + i), _mm_load_si128(pB + i));
_mm_store_si128(pRes + i, _mm_sub_epi8(max, min));
}
for (i *= 16; i < n; i++) { // fallback for the remaining <16 values
std::int16_t diff = a[i] - b[i];
std::uint16_t mask = diff >> 15;
res[i] = (diff + mask) ^ mask;
}
}
int main() {
const std::size_t ALIGN = 16; // sse requires 16-byte alignment
const std::size_t N = 10 * 1024 * 1024 * 3;
auto a = static_cast<uint8_t*>(_mm_malloc(N, ALIGN));
auto b = static_cast<uint8_t*>(_mm_malloc(N, ALIGN));
{ // fill with random values
std::mt19937 engine(std::random_device{}());
std::uniform_int_distribution<int> distribution(0, 255);
for (std::size_t i = 0; i < N; i++) {
a[i] = static_cast<std::uint8_t>(distribution(engine));
b[i] = static_cast<std::uint8_t>(distribution(engine));
}
}
auto res0 = static_cast<uint8_t*>(_mm_malloc(N, ALIGN)); // diff0 results
auto resX = static_cast<uint8_t*>(_mm_malloc(N, ALIGN)); // diff1+ results
LARGE_INTEGER f, t0, t1;
QueryPerformanceFrequency(&f);
QueryPerformanceCounter(&t0);
diff0(a, b, res0, N);
QueryPerformanceCounter(&t1);
printf("diff0: %.6f ms\n",
static_cast<double>(t1.QuadPart - t0.QuadPart) / f.QuadPart * 1000);
#define TEST(diffX)\
QueryPerformanceCounter(&t0);\
diffX(a, b, resX, N);\
QueryPerformanceCounter(&t1);\
printf("%s: %.6f ms\n", #diffX,\
static_cast<double>(t1.QuadPart - t0.QuadPart) / f.QuadPart * 1000);\
for (std::size_t i = 0; i < N; i++) {\
if (resX[i] != res0[i]) {\
printf("error: %s(%03u, %03u) == %03u != %03u\n", #diffX,\
a[i], b[i], resX[i], res0[i]);\
break;\
}\
}
TEST(diff1);
TEST(diff2);
TEST(diff3);
TEST(diff4);
TEST(diff5);
_mm_free(a);
_mm_free(b);
_mm_free(res0);
_mm_free(resX);
getc(stdin);
return 0;
}
Edit: Changing my answer; I had optimizations misconfigured for this.

I set up a quick test bed for this in C, and I'm finding that

a - b + (a < b) * ((b - a) << 1);

is a hair better, at least in my setup. The advantage of my approach is that it eliminates a comparison. Your version implicitly handles a - b == 0 as if it were a separate case, when this is not necessary. In my test, yours takes slightly longer than mine.
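A direct Word8 transcription of that expression (my own sketch, just to check that the identity also holds under wrap-around arithmetic):

import Data.Bits (unsafeShiftL)
import Data.Word (Word8)

-- When a < b, a - b wraps modulo 256; adding 2 * (b - a) then lands
-- back on b - a, so the result is |a - b| without any branch.
absDiffShift :: Word8 -> Word8 -> Word8
absDiffShift a b = a - b + fromIntegral (fromEnum (a < b)) * ((b - a) `unsafeShiftL` 1)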
I also tried an approach with a non-branching absolute value, and the results were better. Note that whether the inputs or outputs are considered signed or not by the compiler is irrelevant. It wraps around for large unsigned values, but since it only has to work on small values (as stated in the question), it should be sufficient.
int32_t diff = a - b;
uint32_t mask = diff >> 31;
return (diff + mask) ^ mask;