Fast, branchless unsigned int absolute difference


I have a program which spends most of its time computing the Euclidean distance between RGB values (3-tuples of unsigned 8-bit Word8). I need a fast, branchless unsigned int absolute difference function such that

unsigned_difference :: Word8 -> Word8 -> Word8
unsigned_difference a b = max a b - min a b

in particular,

unsigned_difference a b == unsigned_difference b a
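The specification above can be written as a small C++ reference, which makes the symmetry property directly testable. The question says the differences feed a Euclidean distance between RGB triples. Assuming they are only ever squared, the square of a signed difference already equals the square of the absolute difference, so the squared distance needs no absolute value at all. The `Rgb` struct and the function names are hypothetical and used only for illustration.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

// Reference semantics from the question: max a b - min a b.
// Never underflows, because max(a, b) >= min(a, b).
std::uint8_t absdiff_ref(std::uint8_t a, std::uint8_t b) {
    return static_cast<std::uint8_t>(std::max(a, b) - std::min(a, b));
}

// Hypothetical pixel type: three unsigned 8-bit channels.
struct Rgb {
    std::uint8_t r, g, b;
};

// Squared Euclidean distance. Squaring discards the sign, so
// (x - y) * (x - y) == |x - y| * |x - y| and no absolute value is needed.
// Each channel difference lies in [-255, 255], so the total is at most
// 3 * 255 * 255 = 195075, which fits easily in an int.
int dist2(const Rgb& p, const Rgb& q) {
    int dr = int(p.r) - int(q.r);
    int dg = int(p.g) - int(q.g);
    int db = int(p.b) - int(q.b);
    return dr * dr + dg * dg + db * db;
}
```

If the real workload only compares distances, the square root can usually be skipped as well, since it is monotonic.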

I came up with the following, using new primops from GHC 7.8:

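{-# LANGUAGE MagicHash #-}
import GHC.Exts (Int (I#), (<#), (>#), (*#), (+#), (-#))
-- (a < b) * (b - a) + (a > b) * (a - b)
unsigned_difference :: Int -> Int -> Int
unsigned_difference (I# a) (I# b) =
    I# ((a <# b) *# (b -# a) +# (a ># b) *# (a -# b))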

which ghc -O2 -S compiles to

.Lc42U:
    movq 7(%rbx),%rax
    movq $ghczmprim_GHCziTypes_Izh_con_info,-8(%r12)
    movq 8(%rbp),%rbx
    movq %rbx,%rcx
    subq %rax,%rcx
    cmpq %rax,%rbx
    setg %dl
    movzbl %dl,%edx
    imulq %rcx,%rdx
    movq %rax,%rcx
    subq %rbx,%rcx
    cmpq %rax,%rbx
    setl %al
    movzbl %al,%eax
    imulq %rcx,%rax
    addq %rdx,%rax
    movq %rax,(%r12)
    leaq -7(%r12),%rbx
    addq $16,%rbp
    jmp *(%rbp)

compiling with ghc -O2 -fllvm -optlo -O3 -S produces the following asm:

.LBB6_1:
    movq    7(%rbx), %rsi
    movq    $ghczmprim_GHCziTypes_Izh_con_info, 8(%rax)
    movq    8(%rbp), %rcx
    movq    %rsi, %rdx
    subq    %rcx, %rdx
    xorl    %edi, %edi
    subq    %rsi, %rcx
    cmovleq %rdi, %rcx
    cmovgeq %rdi, %rdx
    addq    %rcx, %rdx
    movq    %rdx, 16(%rax)
    movq    16(%rbp), %rax
    addq    $16, %rbp
    leaq    -7(%r12), %rbx
    jmpq    *%rax  # TAILCALL

So LLVM manages to replace the set-and-multiply sequences with (more efficient?) conditional move instructions. Unfortunately, compiling with -fllvm has little effect on the runtime of my program.

However, there are two problems with this function.

  • I want to compare Word8, but the comparison primops necessitate the use of Int. This causes needless allocation as I'm forced to store a 64-bit Int rather than a Word8.

I've profiled and confirmed that the use of fromIntegral :: Word8 -> Int is responsible for 42.4 percent of the program's total allocations.

  • My version uses 2 comparisons, 2 multiplications and 2 subtractions. I wonder if there is a more efficient method, using bitwise operations or SIMD instructions and exploiting the fact that I'm comparing Word8.
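On the second point, a common way to drop the multiplications is to turn a single comparison into an all-ones or all-zeros mask and use it for a conditional negation. The C++ sketch below shows the idea. The name `absdiff_mask` is hypothetical. It still performs one comparison, and whether the compiler lowers `a < b` to setcc, cmov, or something else is up to the compiler.

```cpp
#include <cassert>
#include <cstdint>

// Replaces the two multiplications of the primop version with a single
// comparison-derived mask and a conditional negation.
std::uint8_t absdiff_mask(std::uint8_t a, std::uint8_t b) {
    std::uint32_t d = std::uint32_t(a) - std::uint32_t(b); // wraps when a < b
    std::uint32_t m = 0u - std::uint32_t(a < b);           // all ones if a < b, else 0
    // (d ^ m) - m is d when m == 0, and ~d + 1 == -d == b - a (mod 2^32)
    // when m is all ones.
    return static_cast<std::uint8_t>((d ^ m) - m);
}
```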

I had previously tagged the question C/C++ to attract attention from those more inclined to bit manipulation. My question uses Haskell, but I'd accept an answer implementing a correct method in any language.

Conclusion:

I've decided to use

w8_sad :: Word8 -> Word8 -> Int16
w8_sad a b = xor (diff + mask) mask
    where diff = fromIntegral a - fromIntegral b
          mask = unsafeShiftR diff 15

as it is faster than my original unsigned_difference function, and simple to implement. SIMD intrinsics in Haskell haven't reached maturity yet. So, while the SIMD versions are faster, I decided to go with a scalar version.
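The arithmetic behind `w8_sad` may be easier to follow in C-like code. The following is a sketch of the same sign-mask technique in C++. It reuses the name `w8_sad` from the Haskell version; everything else is an assumption of this sketch.

```cpp
#include <cassert>
#include <cstdint>

// C++ transliteration of w8_sad: widen to 16 bits so the sign of a - b
// survives, build a mask from the sign bit, then conditionally negate.
std::int16_t w8_sad(std::uint8_t a, std::uint8_t b) {
    // a - b lies in [-255, 255], which fits in int16_t without overflow.
    std::int16_t diff = static_cast<std::int16_t>(a - b);
    // Arithmetic right shift: 0 if diff >= 0, -1 (all ones) if diff < 0.
    // Guaranteed arithmetic since C++20; universal in practice before that.
    std::int16_t mask = static_cast<std::int16_t>(diff >> 15);
    // mask == 0 gives diff; mask == -1 gives (diff - 1) ^ ~0 == -diff.
    return static_cast<std::int16_t>((diff + mask) ^ mask);
}
```

For example, with a = 3 and b = 10: diff = -7, mask = -1, diff + mask = -8, and -8 ^ -1 = 7.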

cdk asked Mar 17 '14 00:03


3 Answers

Well, I tried to benchmark a bit. I use Criterion for the benchmarks, because it does proper significance tests. I also use QuickCheck here to ensure that all methods return the same results.

I compiled with GHC 7.6.3 (so I couldn't include your primops function, unfortunately) and with -O3:

ghc -O3 AbsDiff.hs -o AbsDiff && ./AbsDiff

Primarily, we can see the difference between a naive implementation and a bit of bit fiddling:

absdiff1_w8 :: Word8 -> Word8 -> Word8
absdiff1_w8 a b = max a b - min a b

absdiff2_w8 :: Word8 -> Word8 -> Word8
absdiff2_w8 a b = unsafeCoerce $ xor (v + mask) mask
  where v = (unsafeCoerce a::Int64) - (unsafeCoerce b::Int64)
        mask = unsafeShiftR v 63

Output:

benchmarking absdiff_Word8/1
mean: 249.8591 us, lb 248.1229 us, ub 252.4321 us, ci 0.950
....

benchmarking absdiff_Word8/2
mean: 202.5095 us, lb 200.8041 us, ub 206.7602 us, ci 0.950
...

I use the integer absolute value trick from "Bit Twiddling Hacks". Unfortunately, we need casts. I don't think it is possible to solve the problem well in the domain of Word8 alone, but it seems sensible to use the native integer type anyway (there's definitely no need to create a heap object, though).
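To illustrate why widening matters, here is a C++ sketch comparing the same bit trick applied to an 8-bit signed difference and to a native int. Both function names are hypothetical. Because a - b spans [-255, 255], the 8-bit version loses the sign as soon as the difference no longer fits in int8_t.

```cpp
#include <cassert>
#include <cstdint>
#include <limits>

// The sign-mask trick kept in 8 bits. a - b spans [-255, 255], which does
// not fit in int8_t; after wrapping, results are wrong when |a - b| > 128.
std::uint8_t absdiff_narrow(std::uint8_t a, std::uint8_t b) {
    // Conversion to int8_t wraps modulo 256 (guaranteed since C++20,
    // implementation-defined but universal before that).
    std::int8_t diff = static_cast<std::int8_t>(a - b);
    int mask = diff >> 7; // 0 or -1, taken from the wrapped sign bit
    return static_cast<std::uint8_t>((diff + mask) ^ mask);
}

// The same trick on a native int: the sign of a - b is exact.
std::uint8_t absdiff_wide(std::uint8_t a, std::uint8_t b) {
    int diff = int(a) - int(b);
    int mask = diff >> std::numeric_limits<int>::digits; // 0 or -1
    return static_cast<std::uint8_t>((diff + mask) ^ mask);
}
```

For a = 200 and b = 0, the narrow version wraps 200 to -56 and returns 56, while the wide version returns 200.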

It doesn't really look like a large difference, but my test setup is also not perfect: I am mapping the function over a large list of random values to rule out branch prediction making the branching version seem more efficient than it is. This causes thunks to build up in memory, which could influence the timings a lot. When we subtract the constant overhead for maintaining the list, we could well see a lot more than the 20% speedup.

The generated assembly is actually pretty good (this is an inlined version of the function):

.Lc4BB:
    leaq 7(%rbx),%rax
    movq 8(%rbp),%rbx
    subq (%rax),%rbx
    movq %rbx,%rax
    sarq $63,%rax
    movq $base_GHCziInt_I64zh_con_info,-8(%r12)
    addq %rax,%rbx
    xorq %rax,%rbx
    movq %rbx,0(%r12)
    leaq -7(%r12),%rbx
    movq $s4z0_info,8(%rbp)

1 subtraction, 1 addition, 1 right-shift, 1 xor and no branch, as expected. Using the LLVM backend doesn't improve the runtime noticeably.

Hope this is useful if you want to try out more stuff.

{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Main where

import Data.Word
import Data.Int
import Data.Bits
import Control.Arrow ((***))
import Control.DeepSeq (force)
import Control.Exception (evaluate)
import Control.Monad
import System.Random
import Unsafe.Coerce

import Test.QuickCheck hiding ((.&.))
import Criterion.Main

absdiff1_w8 :: Word8 -> Word8 -> Word8
absdiff1_w8 !a !b = max a b - min a b

absdiff1_int16 :: Int16 -> Int16 -> Int16
absdiff1_int16 a b = max a b - min a b

absdiff1_int :: Int -> Int -> Int
absdiff1_int a b = max a b - min a b

absdiff2_int16 :: Int16 -> Int16 -> Int16
absdiff2_int16 a b = xor (v + mask) mask
  where v = a - b
        mask = unsafeShiftR v 15

absdiff2_w8 :: Word8 -> Word8 -> Word8
absdiff2_w8 !a !b = unsafeCoerce $ xor (v + mask) mask
  where !v = (unsafeCoerce a::Int64) - (unsafeCoerce b::Int64)
        !mask = unsafeShiftR v 63

absdiff3_w8 :: Word8 -> Word8 -> Word8
absdiff3_w8 a b = if a > b then a - b else b - a

{-absdiff4_int :: Int -> Int -> Int-}
{-absdiff4_int (I# a) (I# b) =-}
    {-I# ((a <# b) *# (b -# a) +# (a ># b) *# (a -# b))-}

e2e :: (Enum a, Enum b) => a -> b
e2e = toEnum . fromEnum

prop_same1 x y = absdiff1_w8 x y == absdiff2_w8 x y
prop_same2 (x::Word8) (y::Word8) = absdiff1_int16 x' y' == absdiff2_int16 x' y'
    where x' = e2e x
          y' = e2e y

check = quickCheck prop_same1
     >> quickCheck prop_same2

instance (Random x, Random y) => Random (x, y) where
  random gen1 =
    let (x, gen2) = random gen1
        (y, gen3) = random gen2
    in ((x,y),gen3)

main =
    do check
       !pairs_w8 <- fmap force $ replicateM 10000 (randomIO :: IO (Word8,Word8))
       let !pairs_int16 = force $ map (e2e *** e2e) pairs_w8
       defaultMain
         [ bgroup "absdiff_Word8" [ bench "1" $ nf (map (uncurry absdiff1_w8)) pairs_w8
                                  , bench "2" $ nf (map (uncurry absdiff2_w8)) pairs_w8
                                  , bench "3" $ nf (map (uncurry absdiff3_w8)) pairs_w8
                                  ]
         , bgroup "absdiff_Int16" [ bench "1" $ nf (map (uncurry absdiff1_int16)) pairs_int16
                                  , bench "2" $ nf (map (uncurry absdiff2_int16)) pairs_int16
                                  ]
         {-, bgroup "absdiff_Int"   [ bench "1" $ whnf (absdiff1_int 13) 14-}
                                  {-, bench "2" $ whnf (absdiff3_int 13) 14-}
                                  {-]-}
         ]
13 revs answered Sep 23 '22 01:09


If you are targeting a system with SSE instructions, you could use them for a nice performance boost. I tested this against the other posted methods, and it seems to be the fastest approach.

Example results for diffing a large number of values:

diff0: 188.020679 ms // branching
diff1: 118.934970 ms // max min
diff2: 97.087710 ms  // branchless mul add
diff3: 54.495269 ms  // branchless signed
diff4: 31.159628 ms  // sse
diff5: 30.855885 ms  // sse v2

My full test code is below. I used SSE2 instructions, which are widely available on x86 CPUs nowadays, through SSE intrinsics, which should be quite portable (MSVC, GCC, Clang, Intel compilers, etc.).

Notes:

  • Effectively, this calculates the max, then the min, and then subtracts, but processes 16 values at once with each instruction.
  • Unrolling it in diff5 seems to have little effect, but it could possibly be tweaked.
  • The fallback for the last 15 or fewer values currently uses the signed trick in a loop, but it could possibly be sped up further with unrolling and/or SSE.
  • The functions themselves are quite simple, so they should be easy to port to anything with SSE intrinsics or asm.
  • I used Windows-specific timing functions because std::chrono::high_resolution_clock has low precision in the MSVC implementation. Sorry for that, and for the dirty mix of C/C++ test code.
  • After timing, the results are checked against the reference branching implementation, so they should be correct.

Please leave a comment if you have any questions/suggestions regarding the code or this approach in general.

#include <cstdlib>
#include <cstdint>
#include <cstdio>
#include <cmath>
#include <random>
#include <algorithm>

#define WIN32_LEAN_AND_MEAN
#define NOMINMAX
#include <Windows.h>

#include <emmintrin.h> // sse2

// branching
void diff0(const std::uint8_t* a, const std::uint8_t* b, std::uint8_t* res,
    std::size_t n)
{
    for (std::size_t i = 0; i < n; i++) {
        res[i] = a[i] > b[i] ? a[i] - b[i] : b[i] - a[i];
    }
}

// max min
void diff1(const std::uint8_t* a, const std::uint8_t* b, std::uint8_t* res,
    std::size_t n)
{
    for (std::size_t i = 0; i < n; i++) {
        res[i] = std::max(a[i], b[i]) - std::min(a[i], b[i]);
    }
}

// branchless mul add
void diff2(const std::uint8_t* a, const std::uint8_t* b, std::uint8_t* res,
    std::size_t n)
{
    for (std::size_t i = 0; i < n; i++) {
        res[i] = (a[i] > b[i]) * (a[i] - b[i]) + (a[i] < b[i]) * (b[i] - a[i]);
    }
}

// branchless signed
void diff3(const std::uint8_t* a, const std::uint8_t* b, std::uint8_t* res,
    std::size_t n)
{
    for (std::size_t i = 0; i < n; i++) {
        std::int16_t  diff = a[i] - b[i];
        std::uint16_t mask = diff >> 15;
        res[i] = (diff + mask) ^ mask;
    }
}

// sse
void diff4(const std::uint8_t* a, const std::uint8_t* b, std::uint8_t* res,
    std::size_t n)
{
    auto pA = reinterpret_cast<const __m128i*>(a);
    auto pB = reinterpret_cast<const __m128i*>(b);
    auto pRes = reinterpret_cast<__m128i*>(res);
    std::size_t i = 0;
    for (std::size_t j = n / 16; j--; i++) {
        __m128i max = _mm_max_epu8(_mm_load_si128(pA + i), _mm_load_si128(pB + i));
        __m128i min = _mm_min_epu8(_mm_load_si128(pA + i), _mm_load_si128(pB + i));
        _mm_store_si128(pRes + i, _mm_sub_epi8(max, min));
    }
    for (i *= 16; i < n; i++) { // fallback for the remaining <16 values
        std::int16_t  diff = a[i] - b[i];
        std::uint16_t mask = diff >> 15;
        res[i] = (diff + mask) ^ mask;
    }
}

// sse v2
void diff5(const std::uint8_t* a, const std::uint8_t* b, std::uint8_t* res,
    std::size_t n)
{
    auto pA = reinterpret_cast<const __m128i*>(a);
    auto pB = reinterpret_cast<const __m128i*>(b);
    auto pRes = reinterpret_cast<__m128i*>(res);
    std::size_t i = 0;
    const std::size_t UNROLL = 2;
    for (std::size_t j = n / (16 * UNROLL); j--; i += UNROLL) {
        __m128i max0 = _mm_max_epu8(_mm_load_si128(pA + i + 0), _mm_load_si128(pB + i + 0));
        __m128i min0 = _mm_min_epu8(_mm_load_si128(pA + i + 0), _mm_load_si128(pB + i + 0));
        __m128i max1 = _mm_max_epu8(_mm_load_si128(pA + i + 1), _mm_load_si128(pB + i + 1));
        __m128i min1 = _mm_min_epu8(_mm_load_si128(pA + i + 1), _mm_load_si128(pB + i + 1));
        _mm_store_si128(pRes + i + 0, _mm_sub_epi8(max0, min0));
        _mm_store_si128(pRes + i + 1, _mm_sub_epi8(max1, min1));
    }
    for (std::size_t j = n % (16 * UNROLL) / 16; j--; i++) {
        __m128i max = _mm_max_epu8(_mm_load_si128(pA + i), _mm_load_si128(pB + i));
        __m128i min = _mm_min_epu8(_mm_load_si128(pA + i), _mm_load_si128(pB + i));
        _mm_store_si128(pRes + i, _mm_sub_epi8(max, min));
    }
    for (i *= 16; i < n; i++) { // fallback for the remaining <16 values
        std::int16_t  diff = a[i] - b[i];
        std::uint16_t mask = diff >> 15;
        res[i] = (diff + mask) ^ mask;
    }
}

int main() {
    const std::size_t ALIGN = 16; // _mm_load_si128/_mm_store_si128 require 16-byte alignment
    const std::size_t N = 10 * 1024 * 1024 * 3;

    auto a = static_cast<uint8_t*>(_mm_malloc(N, ALIGN));
    auto b = static_cast<uint8_t*>(_mm_malloc(N, ALIGN));

    { // fill with random values
        std::mt19937 engine(std::random_device{}());
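        // std::uniform_int is non-standard, and uint8_t is not a valid
        // IntType for uniform_int_distribution, so draw ints and narrow them.
        std::uniform_int_distribution<int> distribution(0, 255);
        for (std::size_t i = 0; i < N; i++) {
            a[i] = static_cast<std::uint8_t>(distribution(engine));
            b[i] = static_cast<std::uint8_t>(distribution(engine));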
        }
    }

    auto res0 = static_cast<uint8_t*>(_mm_malloc(N, ALIGN)); // diff0 results
    auto resX = static_cast<uint8_t*>(_mm_malloc(N, ALIGN)); // diff1+ results

    LARGE_INTEGER f, t0, t1;
    QueryPerformanceFrequency(&f);

    QueryPerformanceCounter(&t0);
    diff0(a, b, res0, N);
    QueryPerformanceCounter(&t1);
    printf("diff0: %.6f ms\n",
        static_cast<double>(t1.QuadPart - t0.QuadPart) / f.QuadPart * 1000);

#define TEST(diffX)\
    QueryPerformanceCounter(&t0);\
    diffX(a, b, resX, N);\
    QueryPerformanceCounter(&t1);\
    printf("%s: %.6f ms\n", #diffX,\
        static_cast<double>(t1.QuadPart - t0.QuadPart) / f.QuadPart * 1000);\
    for (std::size_t i = 0; i < N; i++) {\
        if (resX[i] != res0[i]) {\
            printf("error: %s(%03u, %03u) == %03u != %03u\n", #diffX,\
                a[i], b[i], resX[i], res0[i]);\
            break;\
        }\
    }

    TEST(diff1);
    TEST(diff2);
    TEST(diff3);
    TEST(diff4);
    TEST(diff5);

    _mm_free(a);
    _mm_free(b);
    _mm_free(res0);
    _mm_free(resX);

    getc(stdin);
    return 0;
}
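Two variations on this approach may be worth trying. The sketch below assumes an x86 target with SSE2, and the function name is hypothetical. First, diff4/diff5 use _mm_load_si128/_mm_store_si128, which require 16-byte-aligned pointers (hence _mm_malloc in main); _mm_loadu_si128/_mm_storeu_si128 lift that requirement. Second, the per-byte absolute difference can also be formed from two saturating subtractions, because one of them is always zero. This variant has not been benchmarked against diff4/diff5 here.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <emmintrin.h> // SSE2

// Per byte, |a - b| == subs_epu8(a, b) | subs_epu8(b, a): one saturating
// difference is always 0, the other is the true difference.
// Unaligned loads/stores, so the buffers need no special alignment.
void absdiff_sse2_unaligned(const std::uint8_t* a, const std::uint8_t* b,
                            std::uint8_t* res, std::size_t n) {
    std::size_t i = 0;
    for (; i + 16 <= n; i += 16) {
        __m128i va = _mm_loadu_si128(reinterpret_cast<const __m128i*>(a + i));
        __m128i vb = _mm_loadu_si128(reinterpret_cast<const __m128i*>(b + i));
        __m128i d = _mm_or_si128(_mm_subs_epu8(va, vb), _mm_subs_epu8(vb, va));
        _mm_storeu_si128(reinterpret_cast<__m128i*>(res + i), d);
    }
    for (; i < n; ++i) { // scalar tail for the last n % 16 bytes
        res[i] = a[i] > b[i] ? a[i] - b[i] : b[i] - a[i];
    }
}
```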
user2802841 answered Sep 22 '22 01:09


Edit: I changed my answer; I had optimizations misconfigured when I first tested this.

I set up a quick test bed for this in C, and I'm finding that

a - b + (a < b) * ((b - a) << 1);

is a hair better, at least in my setup. The advantage of my approach is that it eliminates a comparison: your version implicitly handles a - b == 0 as if it were a separate case, which is not necessary.

My test, compared with yours:

  • Your implementation: 371ms
  • This implementation: 324ms
  • Speedup: 14%
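A quick way to check the formula is to test it exhaustively over all 8-bit inputs. The C++ sketch below does that. It assumes, as in a typical C test bed, that a and b are unsigned 32-bit values, and the function name is hypothetical.

```cpp
#include <cassert>
#include <cstdint>

// a - b + (a < b) * ((b - a) << 1), evaluated in unsigned 32-bit arithmetic.
// When a >= b this is just a - b. When a < b, a - b wraps, and adding
// 2 * (b - a) gives (a - b) + 2 * (b - a) == b - a modulo 2^32.
std::uint32_t absdiff_voidstar(std::uint32_t a, std::uint32_t b) {
    return a - b + (a < b) * ((b - a) << 1);
}
```

Because both the subtraction and the doubled correction wrap modulo 2^32, the identity holds for any pair of 32-bit unsigned inputs, not only for 8-bit ones.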

I tried an approach with a non-branching absolute value, and the results were better. Note that it is irrelevant whether the compiler treats the inputs or outputs as signed. The computation wraps around for large unsigned values, but since it only has to work on small values (8-bit inputs, as stated in the question), that is sufficient.

int32_t  diff = a - b;        // s32: signed 32-bit
uint32_t mask = diff >> 31;   // u32: unsigned 32-bit; 0, or all ones if diff < 0
return (diff + mask) ^ mask;

  • Your implementation: 371ms
  • This implementation: 241ms
  • Speedup: 53%!
VoidStar answered Sep 23 '22 01:09