Logo Questions Linux Laravel Mysql Ubuntu Git Menu
 

Branchless K-means (or other optimizations)

People also ask

What is branchless programming?

Branchless programming is a programming technique that eliminates branches (if, switch, and other conditional statements) from a program. It matters less today than it once did, given how powerful modern systems are and how much code is written in interpreted languages (especially dynamically typed ones).

How do I stop branching in programming?

I believe the most common way to avoid branching is to leverage bit parallelism to reduce the total number of jumps in your code. The longer the basic blocks, the less often the pipeline is flushed.


Too bad we can't use SSE4.1, but very well then, SSE2 it is. I haven't tested this; I only compiled it to check for syntax errors and to see whether the assembly made sense. It's mostly alright, though GCC spills min_index even though some xmm registers are unused, and I'm not sure why that happens.

/*
 * Return the index of the point (x[i], y[i], z[i]) closest to
 * (pt_x, pt_y, pt_z), comparing squared Euclidean distances.
 *
 * Four points are processed per iteration with SSE2. Each of the four lanes
 * keeps its own running minimum distance and index (updated with a
 * branchless and/andnot/or select), and the lanes are reduced at the end.
 *
 * Preconditions:
 *   - x, y and z are 16-byte aligned (they are read with _mm_load_ps).
 *   - Each array holds max(n, 4) elements rounded up to a multiple of 4.
 *     Pad unused slots with points at infinity so they never win.
 *
 * Ties: when several points share the minimum distance, the smallest index
 * is returned, the same result as a scalar loop that uses a strict '<'.
 */
int find_closest(const float *x, const float *y, const float *z,
                 float pt_x, float pt_y, float pt_z, int n) {
    /* Loop-invariant broadcasts, hoisted out of the loop. */
    const __m128 px = _mm_set1_ps(pt_x);
    const __m128 py = _mm_set1_ps(pt_y);
    const __m128 pz = _mm_set1_ps(pt_z);
    const __m128i step = _mm_set1_epi32(4);

    /* The first group of four points seeds the per-lane minima. */
    __m128i min_index = _mm_set_epi32(3, 2, 1, 0);
    __m128 xdif = _mm_sub_ps(px, _mm_load_ps(x));
    __m128 ydif = _mm_sub_ps(py, _mm_load_ps(y));
    __m128 zdif = _mm_sub_ps(pz, _mm_load_ps(z));
    __m128 min_dist = _mm_add_ps(_mm_add_ps(_mm_mul_ps(xdif, xdif),
                                            _mm_mul_ps(ydif, ydif)),
                                 _mm_mul_ps(zdif, zdif));
    __m128i index = min_index;

    for (int i = 4; i < n; i += 4) {
        xdif = _mm_sub_ps(px, _mm_load_ps(x + i));
        ydif = _mm_sub_ps(py, _mm_load_ps(y + i));
        zdif = _mm_sub_ps(pz, _mm_load_ps(z + i));
        __m128 dist = _mm_add_ps(_mm_add_ps(_mm_mul_ps(xdif, xdif),
                                            _mm_mul_ps(ydif, ydif)),
                                 _mm_mul_ps(zdif, zdif));
        index = _mm_add_epi32(index, step);
        /* Strict '<' keeps the earliest index within a lane on ties. */
        __m128i mask = _mm_castps_si128(_mm_cmplt_ps(dist, min_dist));
        min_dist = _mm_min_ps(min_dist, dist);
        min_index = _mm_or_si128(_mm_and_si128(index, mask),
                                 _mm_andnot_si128(mask, min_index));
    }

    /*
     * Unaligned stores: a plain local array is not guaranteed to be 16-byte
     * aligned on every ABI, and _mm_store_ps/_mm_store_si128 require that.
     */
    float mdist[4];
    uint32_t mindex[4];
    _mm_storeu_ps(mdist, min_dist);
    _mm_storeu_si128((__m128i *)mindex, min_index);

    /* Horizontal reduction; equal distances resolve to the smaller index. */
    float closest = mdist[0];
    int closest_i = (int)mindex[0];
    for (int lane = 1; lane < 4; lane++) {
        int lane_i = (int)mindex[lane];
        if (mdist[lane] < closest ||
            (mdist[lane] == closest && lane_i < closest_i)) {
            closest = mdist[lane];
            closest_i = lane_i;
        }
    }
    return closest_i;
}

As usual, it expects the pointers to be 16-byte aligned. The arrays must also be padded to a multiple of 4 elements, and the padding should consist of points at infinity (so they are never closest to the target).

SSE 4.1 would let you replace this

min_index = _mm_or_si128(_mm_and_si128(index, mask), 
                         _mm_andnot_si128(mask, min_index));

By this

min_index = _mm_blendv_epi8(min_index, index, mask);

Here's an asm version, made for vsyasm, tested a bit (seems to work)

; find_closest: returns in eax the index of the point with the smallest
; squared distance to a reference point. Four points are processed per
; iteration, one per SSE lane. Each lane keeps its own running minimum
; distance (xmm5) and index (xmm3), and the lanes are reduced at _tail.
bits 64

section .data

; Constant vectors. Both are 16 bytes, so centroid_index (at +16) is also
; 16-byte aligned, as the movdqa loads below require.
align 16
centroid_four:
    dd 4, 4, 4, 4
centroid_index:
    dd 0, 1, 2, 3

section .text

global find_closest

proc_frame find_closest
    ;
    ;   arguments:
    ;       ecx: number of points (multiple of 4 and at least 4)
    ;       rdx -> array of 3 pointers to floats (x, y, z) (the points)
    ;       r8 -> array of 3 floats (the reference point)
    ;
    ;   returns:
    ;       eax: index of the closest point
    ;
    ;   The x/y/z arrays are read with movaps, so each must be 16-byte aligned.
    ;
    ; Frame layout (0x58 bytes): [rsp+0..63] hold the saved xmm6-xmm9, which
    ; are restored before ret. [rsp+64..79] is scratch for spilling xmm3 at
    ; the end.
    ; NOTE(review): the movaps saves/restores/spill assume rsp is 16-byte
    ; aligned after alloc_stack, i.e. rsp % 16 == 8 on entry (Win64 ABI).
    alloc_stack 0x58
    save_xmm128 xmm6, 0
    save_xmm128 xmm7, 16
    save_xmm128 xmm8, 32
    save_xmm128 xmm9, 48
[endprolog]
    ; broadcast the reference point: xmm0 = 4 x px, xmm1 = 4 x py, xmm2 = 4 x pz
    movss xmm0, [r8]
    shufps xmm0, xmm0, 0
    movss xmm1, [r8 + 4]
    shufps xmm1, xmm1, 0
    movss xmm2, [r8 + 8]
    shufps xmm2, xmm2, 0
    ; pointers to x, y, z in r8, r9, r10
    mov r8, [rdx]
    mov r9, [rdx + 8]
    mov r10, [rdx + 16]
    ; reference point is in xmm0, xmm1, xmm2 (x, y, z)
    movdqa xmm3, [rel centroid_index]   ; min_index
    movdqa xmm4, xmm3                   ; current index
    movdqa xmm9, [rel centroid_four]     ; index increment
    paddd xmm4, xmm9
    ; (xmm4 now holds 4, 5, 6, 7: the indices of the next group of points)
    ; calculate initial min_dist, xmm5
    movaps xmm5, [r8]
    subps xmm5, xmm0
    movaps xmm7, [r9]
    subps xmm7, xmm1
    movaps xmm8, [r10]
    subps xmm8, xmm2
    mulps xmm5, xmm5
    mulps xmm7, xmm7
    mulps xmm8, xmm8
    addps xmm5, xmm7
    addps xmm5, xmm8
    ; advance past the first 4 points
    add r8, 16
    add r9, 16
    add r10, 16
    ; skip the loop when no points remain (original ecx <= 4, unsigned)
    sub ecx, 4
    jna _tail
_loop:
    ; xmm6 = squared distances of the next 4 points
    movaps xmm6, [r8]
    subps xmm6, xmm0
    movaps xmm7, [r9]
    subps xmm7, xmm1
    movaps xmm8, [r10]
    subps xmm8, xmm2
    mulps xmm6, xmm6
    mulps xmm7, xmm7
    mulps xmm8, xmm8
    addps xmm6, xmm7
    addps xmm6, xmm8
    add r8, 16
    add r9, 16
    add r10, 16
    ; keep a copy of dist in xmm7, because cmpps overwrites xmm6 with
    ; mask = (dist < min_dist) (predicate 1 = less-than); then min_dist = min(min_dist, dist)
    movaps xmm7, xmm6
    cmpps xmm6, xmm5, 1
    minps xmm5, xmm7
    ; branchless select: min_index = (index & mask) | (min_index & ~mask)
    movdqa xmm7, xmm6
    pand xmm6, xmm4
    pandn xmm7, xmm3
    por xmm6, xmm7
    movdqa xmm3, xmm6
    paddd xmm4, xmm9
    sub ecx, 4
    ja _loop
_tail:
    ; calculate horizontal minimum
    ; (0xB1 swaps adjacent lanes and 0x4E swaps the 64-bit halves, so
    ; afterwards every lane of xmm0 holds the smallest of the 4 lane minima)
    pshufd xmm0, xmm5, 0xB1
    minps xmm0, xmm5
    pshufd xmm1, xmm0, 0x4E
    minps xmm0, xmm1
    ; find index of the minimum
    ; equality mask of lanes holding the minimum; bsf picks the lowest such
    ; lane (on a tie this is the lowest lane, not necessarily the lowest index)
    ; NOTE(review): if every lane is NaN the mask is 0 and bsf leaves eax undefined
    cmpps xmm0, xmm5, 0
    movmskps eax, xmm0
    bsf eax, eax
    ; index into xmm3, sort of
    ; (spill the per-lane indices and load the lane selected by eax; the
    ; 32-bit bsf result zero-extends into rax, so rax is usable as the index)
    movaps [rsp + 64], xmm3
    mov eax, [rsp + 64 + rax * 4]
    ; epilogue: restore xmm6-xmm9 and release the frame
    movaps xmm9, [rsp + 48]
    movaps xmm8, [rsp + 32]
    movaps xmm7, [rsp + 16]
    movaps xmm6, [rsp]
    add rsp, 0x58
    ret
endproc_frame

In C++:

// C++ declaration of the assembly routine above (C linkage, no name mangling).
// n:               number of points; per the asm header, a multiple of 4 and at least 4.
// points:          array of 3 pointers {x, y, z}, each to a float array; the asm
//                  reads them with movaps, so they must be 16-byte aligned.
// reference_point: array of 3 floats {x, y, z}.
// Returns the index of the point with the smallest squared distance.
extern "C" int find_closest(int n, float** points, float* reference_point);

You could use a branchless ternary operator, sometimes called bitselect (condition ? true : false).
Just use it for the 2 members, defaulting to doing nothing.
Don't worry about the extra operations; they are nothing compared to the cost of branching in the if statement.

bitselect implementation:

// Branchless select: returns truereturnvalue when condition is non-zero,
// falsereturnvalue otherwise.
//
// The mask is built from (condition != 0), giving 0 or -1 (all bits set).
// Any truthy condition therefore works, not only 0/1. It also avoids negating
// INT_MIN, which is signed overflow and undefined behavior.
inline static int bitselect(int condition, int truereturnvalue, int falsereturnvalue)
{
    const int mask = -(condition != 0); // all ones when true, all zeros when false
    return (truereturnvalue & mask) | (falsereturnvalue & ~mask);
}

// Copies sizeof(float) bytes from src to dst through unsigned char, which may
// alias any object type. This is a well-defined way to move a float's bit
// pattern into an int (and back), unlike reinterpret_cast<int&> on a float.
inline static void bitselect_copy_bits(void* dst, const void* src)
{
    unsigned char* d = static_cast<unsigned char*>(dst);
    const unsigned char* s = static_cast<const unsigned char*>(src);
    for (unsigned k = 0; k < sizeof(float); ++k)
        d[k] = s[k];
}

// Branchless select for floats: returns truereturnvalue when condition is
// non-zero, falsereturnvalue otherwise. It works on the raw bit patterns, so
// any float value (including NaN and -0.0f) is passed through unchanged.
// Assumes int and float have the same size, as the original did.
inline static float bitselect(int condition, float truereturnvalue, float falsereturnvalue)
{
    int at = 0;
    int af = 0;
    bitselect_copy_bits(&at, &truereturnvalue);
    bitselect_copy_bits(&af, &falsereturnvalue);

    const int mask = -(condition != 0); // all ones when true, all zeros when false
    const int res = (at & mask) | (af & ~mask);

    float out = 0.0f;
    bitselect_copy_bits(&out, &res);
    return out;
}

And your loop should look like this:

// Branchless version of the inner loop: both fields are always written,
// keeping their previous values when this centroid is not closer.
// NOTE(review): assumes pt.min_dist was initialised (e.g. to FLT_MAX) before the loop.
for (int i=0; i < num_centroids; ++i)
{
  const ClusterCentroid& cent = centroids[i];
  const float dist = ...;
  const bool isSmaller = dist < pt.min_dist;

  // use same value if not smaller (no if statement in the loop body)
  pt.min_index = bitselect(isSmaller, i, pt.min_index);
  pt.min_dist = bitselect(isSmaller, dist, pt.min_dist);
}

C++ is a high-level language. Your assumption that control flow in the C++ source code translates into branching instructions is flawed. I don't have the definition of some types from your example, so I made a simple test program with similar conditional assignments:

// Declared but intentionally not defined: passing the results to an opaque
// external function keeps the optimizer from deleting the loop.
int g(int, int);

// Find the smallest of arr[0..999] and its index, then pass (min, minIndex)
// to g. The comparison is strict (<), so ties keep the earliest index.
// If no element is below the initial sentinel 10000, g receives (10000, -1).
// NOTE(review): arr is assumed to hold at least 1000 ints.
int f(const int *arr)
{
    int min = 10000, minIndex = -1;
    for ( int i = 0; i < 1000; ++i )
    {
        // Plain if in the source; whether it becomes a branch or a
        // conditional move is up to the compiler.
        if ( arr[i] < min )
        {
            min = arr[i];
            minIndex = i;
        }
    }
    return g(min, minIndex);
}

Note that the use of the undefined "g" is merely to prevent the optimizer from deleting everything. I translated this with G++ 4.9.2 with -O3 and -S into x86_64 assembly (without even having to change the default for -march) and the (not overly surprising) result is that the loop body contains no branches

movl    (%rdi,%rax,4), %ecx
movl    %edx, %r8d
cmpl    %edx, %ecx
cmovle  %ecx, %r8d
cmovl   %eax, %esi
addq    $1, %rax

Apart from that, the assumption that branchless code is necessarily faster may also be flawed, because the probability that a new distance "beats" the old one decreases the more elements you have looked at. It's not a coin toss. The "bitselect" trick was invented when compilers were much less aggressive at generating "as-if" assembly than they are today. I would much rather suggest taking a look at the assembly your compiler actually generates, before either reworking the code so the compiler can optimize it better or using the result as a basis for hand-written assembly. If you want to look into SIMD, I would suggest trying a "minimum of minimums" approach with reduced data dependencies (in my example, the dependencies on "min" are probably a bottleneck).


Firstly, I'd suggest that before you try any code changes you look at the disassembly in an optimized build. Ideally you want to look at the profiler data at an assembly level. This can show up various things, for example:

  1. The compiler may not have generated an actual branch instruction.
  2. The line of code that has the bottleneck may have many more instructions associated with it than you might think - the dist calculation for example.

In addition, there's a standard trick for distances: computing them usually requires a square root, but because the square root is monotonic, you can compare squared distances instead. Take the square root only once, at the end, on the minimum squared value.

SSE can process four values at once, without any branches, using _mm_min_ps. If you really need speed then you want to be using SSE (or AVX) intrinsics. Here's a basic example:

  // Return the smallest value in values[0 .. count-1], or FLT_MAX when
  // count <= 0.
  //
  // Four values are reduced at a time with _mm_min_ps; the last 0-3 elements
  // (or all of them when count < 4) are handled by a scalar loop. Unaligned
  // loads are used, so values needs no particular alignment.
  float MinimumDistance(const float *values, int count)
  {
    // Named vmin rather than min: a local called min hides any min()
    // function and clashes with the min macro from <windows.h>.
    __m128 vmin = _mm_set1_ps(FLT_MAX);
    int i = 0;
    for (; i < count - 3; i += 4)
    {
        vmin = _mm_min_ps(vmin, _mm_loadu_ps(&values[i]));
    }
    // Combine the four separate minimums: after these two steps every lane
    // holds the minimum of all four.
    vmin = _mm_min_ps(vmin, _mm_shuffle_ps(vmin, vmin, _MM_SHUFFLE(2, 3, 0, 1)));
    vmin = _mm_min_ps(vmin, _mm_shuffle_ps(vmin, vmin, _MM_SHUFFLE(1, 0, 3, 2)));

    // When count < 4 the vector loop never ran and vmin is still FLT_MAX,
    // so storing unconditionally gives the same starting value.
    float result;
    _mm_store_ss(&result, vmin);

    // Deal with the last 0-3 elements the slow way. The original called
    // min(), which here named the __m128 local and did not compile.
    for (; i < count; i++)
    {
        if (values[i] < result)
            result = values[i];
    }

    return result;
  }

For best SSE performance you should make sure the loads happen at aligned addresses. You can handle the first few misaligned elements in the same way as the last few in the code above if necessary.

The other thing to watch out for is memory bandwidth. If there are several members of the ClusterCentroid structure that you don't use during that loop, then you'll be reading much more data from memory than you really need to, because memory is read in cache-line-sized chunks, which are typically 64 bytes each.


This might go both ways, but I'd give the following structure a try:

// Two-pass variant: compute all centroid distances for a point into a
// scratch buffer, then pick the smallest one with std::min_element.
std::vector<float> centDists(num_centroids); //<-- one for each thread.
for (size_t p=0; p<num_points; ++p) {
    Point& pt = points[p];
    // Pass 1: distances only; this loop should be easy to vectorize and unroll.
    for (size_t c=0; c<num_centroids; ++c) {
        const float dist = ...;
        centDists[c]=dist;
    }
    // Pass 2: min_element returns the first smallest element, so ties resolve
    // to the lowest centroid index, like a scan with a strict '<'.
    // NOTE(review): assumes num_centroids > 0; on an empty buffer min_element
    // returns end(), which must not be dereferenced.
    std::vector<float>::iterator best = std::min_element(centDists.begin(), centDists.end());
    pt.min_index = best - centDists.begin();
    pt.min_dist = *best;
}

Obviously, you now have to iterate two times over memory, which probably hurts the cache hit to miss ratio (you could also split it into sub ranges) but on the other hand, each of the inner loops should be easy to vectorize and unroll - so you just have to measure whether it is worth it.

And even if you stick to your version, I'd try using local variables to keep track of the minimum index and distance and apply the results to point at the end.
The rationale is that each read or write to pt.min_dist is effectively done through a pointer, which (depending on the compiler optimizations) may or may not decrease your performance.

Another thing that is important for vectorization is to turn an array of structs (AoS; in this case, the centroids) into a struct of arrays (SoA; e.g. one array for each coordinate of the points). That way you don't need extra gather instructions to load the data for use with SIMD instructions. See Eric Brumer's talk for more information on that topic.

EDIT: Some numbers for my system (haswell, clang 3.5):
I did a short test with your benchmark and on my system, above code slowed the algorithm down by about 10% - essentially, nothing could be vectorized.

However, when applying the AoS-to-SoA transformation to your centroids, the distance calculation was vectorized, which led to a reduction of the overall runtime of about 40% compared to your original loop structure with the same AoS-to-SoA transformation applied.


One possible micro-optimization: store min_dist and min_index in local variables. The way you have it written, the compiler may have to write to memory more often; on some architectures this can have a big performance impact. See my answer here for another example.

Adams's suggestion of doing 4 compares at once is also a good one.

However, your best speedup is going to come from reducing the number of centroids you have to check. Ideally, build a kd-tree (or similar) around the centroids, then query that to find the closest point.

If you don't have any tree building code lying around, here's my favorite "poor-man's" closest point search:

Sort the points by one coordinate, e.g. cent.pos[0]
Pick a starting index for the query point (pt)
Iterate forwards through the candidate points until you reach the end, or until abs(pt.pos[0] - cent.pos[0]) > min_dist
Repeat the previous step going the opposite direction.

The extra stopping condition for the search means that you should skip a fair amount of points; you're also guaranteed not to skip any points closer than the best you've already found.

So for your code, this looks something like

// sort centroid by x coordinate.
min_index = -1;
min_dist = numeric_limits<float>::max();

// pick the start index. This works well if the points are evenly distributed.
float min_x = centroids[0].pos[0];
float max_x = centroids[num_centroids-1].pos[0];
float cur_x = pt.pos[0];
float t = (max_x - cur_x) / (max_x - min_x);
// TODO clamp t between 0 and 1
int start_index = int(t * float(num_centroids))

// Forward search
for (int i=start_index ; i < num_centroids; ++i)
{
    const ClusterCentroid& cent = centroids[i];
    if (fabs(cent.pos[0] - pt.pos[0]) > min_i)
        // Everything to the right of this must be further min_dist, so break.
        // This is where the savings comes from!
        break; 
    const float dist = ...;
    if (dist < min_dist)
    {
        min_dist = dist;
        min_index = i;
    }
}

// Backwards search
for (int i=start_index ; i >= 0; --i)
{
    // same as above
}
pt.min_dist = min_dist
pt.min_index = min_index

(Note that this assumes you're computing the distance between points, but your assembly indicates it's the distance squared. Adjust the break condition accordingly).

There's a slight overhead to building the tree or sorting the centroids, but it should be offset by the faster calculations in the bigger loop (over the number of points).