 

Fast dot product of a bit vector and a floating point vector

I'm trying to compute the dot product between a float vector and a bit vector in the most efficient manner on an i7. In reality, I'm doing this operation on either 128- or 256-dimensional vectors, but for illustration, here is the code for the 64-dimensional case:

#include <stdint.h>

// a has 64 elements. b is a bitvector of 64 dimensions.
float dot(float *restrict a, uint64_t b) {
    float sum = 0;
    for(int i=0; b && i<64; i++, b>>=1) {
        if (b & 1) sum += a[i];
    }
    return sum;
}

This works, of course, but the problem is that this is the time-critical spot of the whole program (it eats up 95% of the CPU time of a 50-minute run), so I desperately need to make it faster.

My guess is that the branching above is the game killer (it prevents out-of-order execution and causes bad branch prediction). I'm not sure whether vector instructions could be used here, or whether they would help. Using gcc 4.8 with -std=c99 -march=native -mtune=native -Ofast -funroll-loops, I'm currently getting this output:

    movl    $4660, %edx
    movl    $5, %ecx
    xorps   %xmm0, %xmm0
    .p2align 4,,10
    .p2align 3
.L4:
    testb   $1, %cl
    je  .L2
    addss   (%rdx), %xmm0
.L2:
    leaq    4(%rdx), %rax
    shrq    %rcx
    testb   $1, %cl
    je  .L8
    addss   4(%rdx), %xmm0
.L8:
    shrq    %rcx
    testb   $1, %cl
    je  .L9
    addss   4(%rax), %xmm0
.L9:
    shrq    %rcx
    testb   $1, %cl
    je  .L10
    addss   8(%rax), %xmm0
.L10:
    shrq    %rcx
    testb   $1, %cl
    je  .L11
    addss   12(%rax), %xmm0
.L11:
    shrq    %rcx
    testb   $1, %cl
    je  .L12
    addss   16(%rax), %xmm0
.L12:
    shrq    %rcx
    testb   $1, %cl
    je  .L13
    addss   20(%rax), %xmm0
.L13:
    shrq    %rcx
    testb   $1, %cl
    je  .L14
    addss   24(%rax), %xmm0
.L14:
    leaq    28(%rax), %rdx
    shrq    %rcx
    cmpq    $4916, %rdx
    jne .L4
    ret

Edit: It's okay to permute the data (as long as the permutation is the same for all parameters); the ordering doesn't matter.

I'm wondering if there is something that will work at more than 3x the speed of Chris Dodd's SSE2 code.

New note: AVX/AVX2 code is also welcome!

Edit 2: Given a bitvector, I have to multiply it with 128 (or 256, if the bitvector is 256-bit) different float vectors (so it's also okay to involve more than a single float vector at a time). This is the whole process; anything that speeds up the whole process is also welcome!
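For concreteness, a rough, hypothetical sketch of that whole process (the names NVEC, dot128, vectors, bits and results are illustrative, not my actual code):

/* Illustrative sketch only: one bitvector is dotted with every float vector. */
#include <stdint.h>

#define NVEC 128                    /* number of float vectors */

/* 128-dimensional analogue of dot() above (definition omitted) */
float dot128(const float *a, const uint64_t b[2]);

void process(const float vectors[NVEC][128], const uint64_t bits[2],
             float results[NVEC]) {
    for (int v = 0; v < NVEC; v++)
        results[v] = dot128(vectors[v], bits);
}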

asked Apr 17 '13 by John Smith

1 Answer

The best bet is going to be to use SSE ps instructions that operate on 4 floats at a time. You can take advantage of the fact that a float 0.0 is all 0 bits to use an andps instruction to mask off the undesired elements:

#include <stdint.h>
#include <xmmintrin.h>

/* mask[n] has all-one bits in 32-bit lane i exactly when bit i of n is set */
union {
    uint32_t i[4];
    __m128   xmm;
} mask[16] = {
 {  0,  0,  0,  0 },
 { ~0,  0,  0,  0 },
 {  0, ~0,  0,  0 },
 { ~0, ~0,  0,  0 },
 {  0,  0, ~0,  0 },
 { ~0,  0, ~0,  0 },
 {  0, ~0, ~0,  0 },
 { ~0, ~0, ~0,  0 },
 {  0,  0,  0, ~0 },
 { ~0,  0,  0, ~0 },
 {  0, ~0,  0, ~0 },
 { ~0, ~0,  0, ~0 },
 {  0,  0, ~0, ~0 },
 { ~0,  0, ~0, ~0 },
 {  0, ~0, ~0, ~0 },
 { ~0, ~0, ~0, ~0 },
};

float dot(__m128 *a, uint64_t b) {
    __m128 sum = { 0.0 };
    /* 16 chunks of 4 bits cover the 64 floats pointed to by a */
    for (int i = 0; i < 16; i++, b>>=4)
        sum += _mm_and_ps(a[i], mask[b&0xf].xmm);   /* += on __m128 is a GCC extension */
    return sum[0] + sum[1] + sum[2] + sum[3];       /* as is subscripting a __m128 */
}

If you expect there to be a lot of 0s in the mask, it might be faster to short-circuit the 0s:

for (int i = 0; b; i++, b >>= 4)
    if (b & 0xf)
        sum += _mm_and_ps(a[i], mask[b&0xf].xmm);

but if b is random, this will be slower.
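
Since AVX/AVX2 code was also welcomed: below is a minimal, untested sketch of the same zero-masking idea with 256-bit vectors. It builds the lane mask on the fly from 8 bits of b instead of using a lookup table; the function name dot256 and the horizontal-sum tail are illustrative additions, not part of the answer above.

#include <immintrin.h>
#include <stdint.h>

/* Untested AVX2 sketch: same masking trick, 8 floats per iteration.
   Lane j of m becomes all-ones exactly when bit j of the current chunk is set. */
float dot256(const float *a, uint64_t b) {   /* a has 64 floats */
    const __m256i lane_bits = _mm256_setr_epi32(1, 2, 4, 8, 16, 32, 64, 128);
    __m256 sum = _mm256_setzero_ps();
    for (int i = 0; i < 8; i++, b >>= 8) {
        __m256i chunk = _mm256_set1_epi32((int)(b & 0xff));
        __m256i m = _mm256_cmpeq_epi32(_mm256_and_si256(chunk, lane_bits), lane_bits);
        sum = _mm256_add_ps(sum, _mm256_and_ps(_mm256_loadu_ps(a + 8*i),
                                               _mm256_castsi256_ps(m)));
    }
    /* horizontal sum of the 8 lanes */
    __m128 s = _mm_add_ps(_mm256_castps256_ps128(sum),
                          _mm256_extractf128_ps(sum, 1));
    s = _mm_add_ps(s, _mm_movehl_ps(s, s));
    s = _mm_add_ss(s, _mm_shuffle_ps(s, s, 1));
    return _mm_cvtss_f32(s);
}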

answered Sep 19 '22 by Chris Dodd