 

How do I initialize a SIMD vector with a range from 0 to N?

I have the following function that I'm trying to write an AVX version of:

void
hashids_shuffle(char *str, size_t str_length, char *salt, size_t salt_length)
{
    size_t i, j, v, p;
    char temp;

    if (!salt_length) {
        return;
    }

    for (i = str_length - 1, v = 0, p = 0; i > 0; --i, ++v) {
        v %= salt_length;
        p += salt[v];
        j = (salt[v] + v + p) % i;

        temp = str[i];
        str[i] = str[j];
        str[j] = temp;
    }
}

I'm trying to vectorize v %= salt_length;.
I want to initialize a vector that contains the numbers from 0 to str_length, so that I can use SVML's _mm_rem_epu64 to calculate v for each loop iteration.
How do I initialize the vector correctly?

asked Mar 12 '23 06:03 by the_drow



1 Answer

Asking just how to initialize a vector is basically asking for a tutorial. Look up some of Intel's guides to using intrinsics. I'm going to pretend the question wasn't trivial, and instead answer about how to implement this function efficiently. It's definitely not the kind of function you'd try to vectorize as a total beginner.

See the x86 tag wiki for links to docs, esp. Intel's intrinsics guide. See the sse tag wiki for a link to a very nice intro to SIMD programming with SSE intrinsics, and how to use SIMD effectively, among other links.

Contents summary:

  • Unrolling/vectorizing takes care of v % salt_length for free.
  • How you could vectorize v++; v %= loop_invariant; if salt_length weren't a power of 2 or a compile-time constant. Includes an answer to the title question about using _mm256_setr_epi8 or other ways of initializing a vector for this purpose.
  • How to start vectorizing a complex loop like this: start with unrolling to find serial dependencies.
  • An untested version of the complete function that vectorizes everything except the % i and the swap (i.e. it vectorizes all the operations that were cheap anyway, like you asked).

    • (v + salt[v] + p) (and everything leading up to it) vectorizes to two vpaddw instructions. The prefix-sum setup outside the loop for vectorizing p was tricky, but I eventually vectorized it, too.

    • The vast majority of the function's run time will be in the scalar inner loop over a vector of j elements, bottlenecking on div (or whatever SVML can do), and/or cache misses with very large strings.


The entire loop can't easily vectorize because the swaps with pseudo-random indices create an unpredictable serial dependency. Using AVX512 gather + shuffle + scatter, with AVX512CD to find conflict bitmasks, might be possible, but that would have to be a separate question. I'm not sure how hard it would be to do this efficiently, or if you'd often end up repeating a vector shuffle many times, only making progress in one non-conflicting element.


Since salt_length = sizeof(size_t) is a compile-time constant and a power of 2 smaller than your vector length, v++ and v%=salt_length don't require any code inside the loop at all: they happen for free as a side-effect of effectively unrolling the loop to do multiple v values in parallel.

(Using a platform-dependent salt size means that a 32-bit build won't be able to process data created with 64-bit salt. Even the x32 ABI has 32-bit size_t, so changing to uint64_t would seem to make sense, unless you never need to share salted hashes between machines.)

In the scalar loop, v follows a repeating pattern of 0 1 2 3 0 1 2 3 ... (or 0..7, for 64-bit). In vector code, we're doing maybe 8 v values at once with 4B elements in a 32-byte vector, or 16 iterations at once with 2B elements.

So v becomes a loop-invariant constant vector. Interestingly, so does salt[v], so we never need to do any salt table lookups inside the loop. In fact, v+salt[v] can be pre-computed for scalar and vector.
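To make the loop-invariance concrete, here is a portable scalar model (plain arrays, not intrinsics) of one 32-byte vector of 16 uint16_t lanes, where lane k of vector iteration n stands for scalar iteration 16*n + k. The helper names are hypothetical. It checks that the v lanes and the v+salt[v] lanes come out identical for every vector iteration whenever salt_length divides the lane count, and not otherwise. Reading salt bytes as unsigned char is an assumption; the question's code uses plain char.

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>
#include <string.h>

/* Model of one 32-byte vector holding 16 uint16_t lanes. */
enum { LANES = 16 };

/* v for scalar iterations base .. base+15: (base + lane) % salt_length. */
static void v_lanes(uint16_t out[LANES], size_t base, size_t salt_length)
{
    for (size_t lane = 0; lane < LANES; lane++)
        out[lane] = (uint16_t)((base + lane) % salt_length);
}

/* v + salt[v] for the same iterations (salt read as unsigned char: an assumption). */
static void v_plus_salt_lanes(uint16_t out[LANES], size_t base,
                              const unsigned char *salt, size_t salt_length)
{
    for (size_t lane = 0; lane < LANES; lane++) {
        size_t v = (base + lane) % salt_length;
        out[lane] = (uint16_t)(v + salt[v]);
    }
}

/* Returns 1 if every one of n_vectors vector iterations sees the same
   v and v+salt[v] lanes as the first one, i.e. both are loop-invariant. */
static int lanes_are_loop_invariant(const unsigned char *salt,
                                    size_t salt_length, size_t n_vectors)
{
    uint16_t first_v[LANES], first_vs[LANES], cur[LANES];
    v_lanes(first_v, 0, salt_length);
    v_plus_salt_lanes(first_vs, 0, salt, salt_length);
    for (size_t n = 1; n < n_vectors; n++) {
        v_lanes(cur, n * LANES, salt_length);
        if (memcmp(cur, first_v, sizeof cur) != 0) return 0;
        v_plus_salt_lanes(cur, n * LANES, salt, salt_length);
        if (memcmp(cur, first_vs, sizeof cur) != 0) return 0;
    }
    return 1;
}
```

With salt_length 4 or 8 both patterns are constants that can be hoisted out of the loop; with salt_length 3 they are not, which is the case the strength-reduction approach further down has to handle.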

The scalar version should pre-compute v+salt[v] and unroll by 4 or 8, too, removing the LUT lookup so all the memory/cache throughput is available for the actual swaps. The compiler probably won't do this for you, so you probably need to unroll manually and write extra code to handle the leftover string bytes when the length isn't a multiple of the unroll factor. (Without unrolling, you could still pre-compute a lookup table of v+salt[v], with a type wide enough not to wrap around.)
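A minimal sketch of that unrolled scalar idea, assuming salt_length is only known at run time: the hypothetical hashids_shuffle_unrolled precomputes, for each position k in the salt period, k + salt[k] plus the running salt sum up to k. Inside one salt period, p is just p_base plus that running sum, so the loop body needs no salt[] reads, and the only cross-iteration state is p_base and i. The question's function is included (with an added guard for str_length < 2) as the reference it is checked against.

```c
#include <assert.h>
#include <stdlib.h>
#include <string.h>

/* The question's function, used as the reference (plus a guard for str_length < 2). */
static void hashids_shuffle_ref(char *str, size_t str_length,
                                const char *salt, size_t salt_length)
{
    if (!salt_length || str_length < 2) return;
    for (size_t i = str_length - 1, v = 0, p = 0; i > 0; --i, ++v) {
        v %= salt_length;
        p += salt[v];
        size_t j = (salt[v] + v + p) % i;
        char temp = str[i]; str[i] = str[j]; str[j] = temp;
    }
}

/* Hypothetical rewrite: each pass of the inner loop covers one full salt period. */
static void hashids_shuffle_unrolled(char *str, size_t str_length,
                                     const char *salt, size_t salt_length)
{
    if (!salt_length || str_length < 2) return;
    size_t *lut = malloc(salt_length * sizeof *lut);
    if (!lut) return;
    size_t prefix = 0;                      /* salt[0] + ... + salt[k] */
    for (size_t k = 0; k < salt_length; k++) {
        prefix += salt[k];
        lut[k] = salt[k] + k + prefix;      /* v + salt[v] + (p - p_base), for v == k */
    }
    size_t p_base = 0;                      /* value of p before each salt period */
    size_t i = str_length - 1;
    while (i > 0) {
        for (size_t k = 0; k < salt_length && i > 0; k++, i--) {
            size_t j = (lut[k] + p_base) % i;
            char temp = str[i]; str[i] = str[j]; str[j] = temp;
        }
        p_base += prefix;                   /* add sum(salt[0 .. salt_length-1]) */
    }
    free(lut);
}
```

With a constant salt_length, the inner for loop has a fixed trip count, so a compiler could fully unroll it and keep lut[] in registers.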

Even just making salt_length known at compile time would allow much better code. v %= compile_time_constant is cheaper than a div instruction, and extremely cheap when the constant is a power of 2 (it just turns into v &= 7). The compiler might do this for you if the scalar version can inline, or if you used salt_length = sizeof(size_t) instead of passing it as a function arg at all.
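For comparison, a sketch with the salt length as a compile-time constant. SALT_LEN is a hypothetical macro (8, i.e. sizeof(size_t) on a 64-bit target), and the salt is passed by value in a uint64_t as suggested above. Because SALT_LEN is a power of 2, v %= SALT_LEN becomes v & (SALT_LEN - 1). Packing salt[0] into the least-significant byte of salt_byval is an assumption about how the caller builds it.

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>
#include <string.h>

#define SALT_LEN 8  /* hypothetical compile-time salt length; must be a power of 2 */

/* Reference: the question's loop, with salt bytes read as unsigned char. */
static void shuffle_ref(char *str, size_t str_length,
                        const unsigned char *salt, size_t salt_length)
{
    if (!salt_length || str_length < 2) return;
    for (size_t i = str_length - 1, v = 0, p = 0; i > 0; --i, ++v) {
        v %= salt_length;
        p += salt[v];
        size_t j = (salt[v] + v + p) % i;
        char t = str[i]; str[i] = str[j]; str[j] = t;
    }
}

/* Constant power-of-2 salt length: the wrap of v is a single AND, no div. */
static void shuffle_const_len(char *str, size_t str_length, uint64_t salt_byval)
{
    if (str_length < 2) return;
    size_t p = 0;
    for (size_t i = str_length - 1, v = 0; i > 0; --i, v = (v + 1) & (SALT_LEN - 1)) {
        size_t s = (salt_byval >> (8 * v)) & 0xFF;  /* salt[v]: byte v, least-significant first (assumed) */
        p += s;
        size_t j = (s + v + p) % i;
        char t = str[i]; str[i] = str[j]; str[j] = t;
    }
}
```

Only the % i remains a real division; the salt lookup became a shift and mask on a register.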


If you didn't already know salt_length (i.e. what @harold was suggesting before you revealed the critical information about salt_length):

Since we know v < salt_length to start with, we only ever need at most one v -= salt_length to wrap it back into the right range and maintain that invariant. This is called a "strength reduction" operation, because subtraction is a weaker (and cheaper) operation than division.

// The scalar loop would benefit from this transformation, too.
// or better, unroll the scalar loop by 8 so everything becomes constant
v++;
if( v >= salt_length)
    v-= salt_length;

To vectorize just this: let's pretend that all we know is salt_length <= 16, so we can use a vector of 32 uint8_t values. (And we can use pshufb to vectorize the salt[v] LUT lookup).

// untested  // Vectorizing  v++; v %= unknown_loop_invariant_value;

if (!salt_length) return;
assert(salt_length <= 16);  // so we can use pshufb for the salt[v] step

__m256i vvec = _mm256_setr_epi8(  // setr: lowest element first, unlike set
   0%salt_length,  1%salt_length,  2%salt_length,  3%salt_length, 
   4%salt_length,  5%salt_length,  6%salt_length,  7%salt_length,
   8%salt_length,  9%salt_length, 10%salt_length, 11%salt_length,
  12%salt_length, 13%salt_length, 14%salt_length, 15%salt_length,
  16%salt_length, 17%salt_length, 18%salt_length, 19%salt_length,
  20%salt_length, 21%salt_length, 22%salt_length, 23%salt_length,
  24%salt_length, 25%salt_length, 26%salt_length, 27%salt_length,
  28%salt_length, 29%salt_length, 30%salt_length, 31%salt_length);
__m256i v_increment = _mm256_set1_epi8(32 % salt_length);
__m256i vmodulus    = _mm256_set1_epi8(salt_length);

// salt_lut = _mm256_set1_epi64x(salt_byval);  // for known salt length. (pass it by val in a size_t arg, instead of by char*).

// duplicate the salt into both lanes of a vector.  Garbage beyond salt_length isn't looked at.
__m256i salt_lut = _mm256_broadcastsi128_si256(_mm_loadu_si128((const __m128i*)salt));  // never mind that this could segfault if salt is short and at the end of a page.

//__m256i v_plus_salt_lut = _mm256_add_epi8(vvec, salt_lut); // not safe with 8-bit elements: could wrap
// We could use 16-bit elements and AVX512 vpermw (or vpermi2w to support longer salts)

for (...) {
    vvec = _mm256_add_epi8(vvec, v_increment);         // ++v;

    // if(!(salt_length > v)) { v-= salt_length; }
    __m256i lessthan = _mm256_cmpgt_epi8(vmodulus, vvec);  // all-ones where salt_length > v: no subtraction needed.
    //  all-zero where salt_length <= v, where we need to subtract salt_length.  (Signed compare is fine: all values are < 32.)
    __m256i conditional_sub = _mm256_andnot_si256(lessthan, vmodulus);  // ~mask & vmodulus: salt_length where v >= salt_length, else 0
    vvec = _mm256_sub_epi8(vvec, conditional_sub);   // subtract 0 or salt_length

    // salt[v] lookup:
    __m256i saltv = _mm256_shuffle_epi8(salt_lut, vvec);  // salt[v]

   // then maybe pmovzx and vextracti128+pmovzx to zero-extend to 16-bit elements?    Maybe vvec should only be a 16-bit vector?
   // or unpack lo/hi with zeros (but that behaves differently from pmovzx at the lane boundary)
   // or  have vvec already holding 16-bit elements with the upper half of each one always zero.  mask after the pshufb to re-zero,
   //   or do something clever with `vvec`, `v_increment` and `vmodulus` so `vvec` can have `0xff` in the odd bytes, so pshufb zeros those elements.
}
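The conditional subtraction can be checked without AVX2 hardware. This portable sketch models the 32 uint8_t lanes as an array; the names are hypothetical, and the per-lane loops stand in for the intrinsics named in the comments. Each step adds 32 % salt_length to every lane, then subtracts salt_length only in lanes where v >= salt_length, which is what the and-not of the compare mask with vmodulus selects.

```c
#include <assert.h>
#include <stdint.h>

enum { NLANES = 32 };   /* one __m256i of uint8_t */

/* One loop iteration, lane by lane: v += 32 % n; if (v >= n) v -= n; */
static void step_lanes(uint8_t v[NLANES], uint8_t n)
{
    uint8_t inc = (uint8_t)(NLANES % n);                /* v_increment */
    for (int k = 0; k < NLANES; k++) {
        uint8_t x = (uint8_t)(v[k] + inc);              /* _mm256_add_epi8 */
        uint8_t lt_mask = (n > x) ? 0xFF : 0x00;        /* _mm256_cmpgt_epi8(vmodulus, vvec) */
        uint8_t sub = (uint8_t)(~lt_mask & n);          /* _mm256_andnot_si256(mask, vmodulus) */
        v[k] = (uint8_t)(x - sub);                      /* _mm256_sub_epi8 */
    }
}

/* Returns 1 if, for nsteps iterations, every lane equals (32*step + lane) % n. */
static int check_modulus(uint8_t n, int nsteps)
{
    uint8_t v[NLANES];
    for (int k = 0; k < NLANES; k++) v[k] = (uint8_t)(k % n);   /* the setr initializer */
    for (int s = 1; s <= nsteps; s++) {
        step_lanes(v, n);
        for (int k = 0; k < NLANES; k++)
            if (v[k] != (NLANES * s + k) % n) return 0;
    }
    return 1;
}
```

Since v < n before the add and the increment is also < n, the sum is below 2n, so one conditional subtraction per step is always enough.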

Of course, if we knew salt_length was a power of 2, we should have just masked off all but the relevant low bits in each element:

vvec = _mm256_add_epi8(vvec, _mm256_set1_epi8(32));                 // v += 32 (32 elements per vector)
vvec = _mm256_and_si256(vvec, _mm256_set1_epi8(salt_length - 1));  // v &= salt_length - 1; // aka v%=salt_length; (a no-op overall when salt_length divides 32)

Noticing that we started with the wrong element size is when we realize that vectorizing only one line at a time was a bad idea: now we have to go back and change all the code we already wrote to use wider elements, sometimes requiring a different strategy to do the same thing.

Of course, you do need to start with a rough outline, mental or written down, for how you might do each step separately. It's in the process of thinking this through that you see how the different parts might fit together.

For complex loops, a useful first step might be trying to manually unroll the scalar loop. That will help find serial dependencies, and things that simplify with unrolling.


(stuff) % i: The hard part

We need elements wide enough to hold the maximum value of i, because i is not a power of 2, and not constant, so that modulo operation takes work. Any wider is a waste and cuts our throughput. If we could vectorize the whole rest of the loop, it might even be worth specializing the function with different versions for different ranges of str_length. (Or maybe looping with 64b elements until i<= UINT32_MAX, then loop until i<=UINT16_MAX, etc). If you know you don't need to handle strings > 4GiB, you can speed up the common case by only using 32-bit math. (64-bit division is slower than 32-bit division, even when the upper bits are all zero).

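Actually, I think we need elements as wide as the maximum p, since it keeps accumulating forever (until it wraps at 2^64 in the original scalar code). Unlike with a constant modulus, we can't just use p %= i to keep it in check. Modulo does distribute over addition ((a + b) % n == ((a % n) + (b % n)) % n), but only for a fixed n, and here the divisor shrinks every iteration, so a p reduced modulo one i doesn't give the right remainder for later ones: (123 % 33) % (33-16) != 123 % (33-16). Even aligning to 16 doesn't help: 12345 % 32 != 12345 % 48 % 32.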

This will quickly make p too large for repeated conditional subtraction of i (until the condition mask is all false), even for fairly large i values.
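The counterexamples above can be checked directly. In this sketch (the helpers are hypothetical), reduce_then_mod_ok tests whether reducing p modulo one divisor still gives the right remainder for another. That only works when the first divisor is a multiple of the second, which is why aligning to 16 doesn't help (48 is not a multiple of 32). subtract_until_less shows the cost of repeated conditional subtraction: about p / i rounds, while p grows by one salt byte every iteration.

```c
#include <assert.h>
#include <stdint.h>

/* Does reducing p modulo i first still give p % d for a different divisor d? */
static int reduce_then_mod_ok(uint64_t p, uint64_t i, uint64_t d)
{
    return (p % i) % d == p % d;
}

/* Repeated conditional subtraction, as a SIMD loop would do it until the
   compare mask is all false.  Returns p % i and stores the number of rounds. */
static uint64_t subtract_until_less(uint64_t p, uint64_t i, uint64_t *steps)
{
    *steps = 0;
    while (p >= i) { p -= i; ++*steps; }
    return p;
}
```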

There are tricks for modulo by known integer constants (see http://libdivide.com/), but AFAIK working out the multiplicative modular inverse for a nearby divisor (even with a power-of-two stride like 16) isn't any easier than for a totally separate number. So we couldn't cheaply just adjust the constants for the next vector of i values.

The law of small numbers perhaps makes it worthwhile to peel off the last couple of vector iterations, with pre-computed vectors of multiplicative modular inverses so the % i can be done with vectors. Once we're close to the end of the string, it's probably hot in L1 cache, so we're totally bottlenecked on div, not the swap loads/stores. For this, we'd maybe use a prologue to reach an i value that was a multiple of 16, so the last few vectors as we approach i=0 always have the same alignment of i values. Or else we'd have a LUT of constants for a range of i values, and simply do unaligned loads from it. That means we don't have to rotate salt_v and p.
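As a sketch of the precomputed-constants idea, here is the "fastmod" remainder-by-multiplication method of Lemire, Kaser and Kurz, swapped in for illustration (it is neither libdivide's API nor necessarily what SVML does). For 16-bit numerators and divisors, one 32-bit constant per divisor is enough, so a small table indexed by i could serve the last few vectors. All names, and the TAIL_MAX table size, are hypothetical.

```c
#include <assert.h>
#include <stdint.h>

enum { TAIL_MAX = 64 };   /* hypothetical: constants for the last 64 values of i */

/* M(d) = ceil(2^32 / d), computed as UINT32_MAX / d + 1.
   For d == 1 this wraps to 0, which still gives the right answer (a % 1 == 0). */
static uint32_t fastmod_constant(uint16_t d)
{
    return (uint32_t)(UINT32_MAX / d + 1u);
}

/* a % d for 16-bit a and d using two multiplies and no division. */
static uint16_t fastmod_u16(uint16_t a, uint32_t M, uint16_t d)
{
    uint32_t lowbits = M * a;                          /* approx. frac(a / d) in 32-bit fixed point */
    return (uint16_t)(((uint64_t)lowbits * d) >> 32);  /* scale the fraction back up by d */
}

/* Table of constants, filled once; index 0 is unused because i > 0 in the loop. */
static uint32_t tail_lut[TAIL_MAX + 1];

static void init_tail_lut(void)
{
    for (uint16_t d = 1; d <= TAIL_MAX; d++)
        tail_lut[d] = fastmod_constant(d);
}
```

The constants for neighbouring divisors are unrelated (fastmod_constant(16) is 268435456, fastmod_constant(17) is 252645136), which fits the point above that they can't cheaply be adjusted from one i to the next. In the tail, each element would compute fastmod_u16(jbuf[elem], tail_lut[i], i).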

Possibly converting to FP would be useful, because recent Intel CPUs (especially Skylake) have very powerful FP division hardware with significant pipelining (throughput : latency ratio). If we can get exact results with the right choice of rounding, this would be great. (float can exactly represent every integer up to 2^24, and double every integer up to 2^53.)
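For 16-bit operands the FP route can be exact even in single precision. Both operands are exactly representable in float, and when a/d is not an integer it sits at least 1/d below the next integer, a gap far larger than float's rounding error at that magnitude, so a correctly rounded quotient never rounds up across an integer and truncating it gives the exact integer quotient. This sketch assumes IEEE-754 float arithmetic with the default round-to-nearest; the helper name is hypothetical. A SIMD version would presumably use _mm256_cvtepi32_ps, _mm256_div_ps and _mm256_cvttps_epi32 on 8 lanes at a time.

```c
#include <assert.h>
#include <stdint.h>

/* a % d through one single-precision division; exact for 16-bit a and d (d != 0). */
static uint16_t rem_u16_via_float(uint16_t a, uint16_t d)
{
    float q = (float)a / (float)d;   /* correctly rounded; cannot round up to the next integer */
    uint32_t qi = (uint32_t)q;       /* truncation == floor for non-negative values */
    return (uint16_t)(a - qi * d);   /* remainder from the exact integer quotient */
}
```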

I guess it's worth trying Intel's _mm_rem_epu16 (with a vector of i values that you decrement with a vector of set1(16)). If they use FP to get accurate results, that's great. If it just unpacks to scalar and does integer division, it would waste time getting values back in a vector.

Anyway, certainly the easiest solution is to iterate over the vector elements with a scalar loop. Until you come up with something extremely fancy using AVX512CD for the swaps, that seems reasonable, but it's probably about an order of magnitude slower than just the swaps would be, if they're all hot in L1 cache.


(Untested) partially-vectorized version of the function:

Here's the code on the Godbolt compiler explorer, with full design-notes comments, including the diagrams I made while figuring out the SIMD prefix-sum algo. I eventually remembered I'd seen a narrower version of this as a building block in @ZBoson's floating point SSE Prefix sum answer, but not until after mostly re-inventing it myself.
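The SIMD prefix sum in init_p_and_increment is a log-step shift-and-add scheme: shift by 1, 2, then 4 elements, adding each time, so 8 elements take 3 steps instead of 7 dependent adds. This portable sketch models the 8 uint16_t lanes of the __m128i as an array (the function name is hypothetical) and checks the result against a sequential running sum.

```c
#include <assert.h>
#include <stdint.h>
#include <string.h>

enum { PS_LANES = 8 };   /* 8 uint16_t lanes of one __m128i */

/* In-place inclusive prefix sum in log2(8) = 3 shift+add steps. */
static void prefix_sum_log_steps(uint16_t x[PS_LANES])
{
    for (int shift = 1; shift < PS_LANES; shift *= 2) {
        uint16_t shifted[PS_LANES] = {0};        /* like _mm_bslli_si128(x, 2*shift): zeros enter at the low end */
        for (int k = shift; k < PS_LANES; k++)
            shifted[k] = x[k - shift];
        for (int k = 0; k < PS_LANES; k++)
            x[k] = (uint16_t)(x[k] + shifted[k]); /* like _mm_add_epi16 */
    }
}
```

After the three steps, lane k holds salt[0] + ... + salt[k], which is the value p has in scalar iteration k.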

// See the godbolt link for full design-notes comments,
// including comments about what makes nice asm or not.
#include <stdint.h>
#include <stdlib.h>
#include <immintrin.h>
#include <assert.h>

static inline
__m256i init_p_and_increment(size_t salt_length, __m256i *p_increment, __m256i saltv_u16, __m128i saltv_u8)
{  // return initial p vector (for first 16 i values).
   // return increment vector by reference.

  if (salt_length == 4) {
    assert(0); // unimplemented
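    // should be about the same as length == 8.  Can maybe factor out some common parts, like up to psum2
    *p_increment = _mm256_setzero_si256();  // placeholder so this path still returns a value
    return _mm256_setzero_si256();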
  } else {
    assert(salt_length == 8);

    // SIMD prefix sum for n elements in a vector in O(log2(n)) steps.
    __m128i sv     = _mm256_castsi256_si128(saltv_u16);
    __m128i pshift1 = _mm_bslli_si128(sv, 2);        // 1 elem (uint16_t)
    __m128i psum1   = _mm_add_epi16(pshift1, sv);
    __m128i pshift2 = _mm_bslli_si128(psum1, 4);     // 2 elem
    __m128i psum2   = _mm_add_epi16(pshift2, psum1);
    __m128i pshift3 = _mm_bslli_si128(psum2, 8);     // 4 elem
    __m128i psum3   = _mm_add_epi16(pshift3, psum2); // p_initial low 128.  2^3 = 8 elements = salt_length
    // psum3 = the repeating pattern of p values.  Later values just add sum(salt[0..7]) to every element
     __m128i p_init_low = psum3;

    __m128i sum8_low = _mm_sad_epu8(saltv_u8, _mm_setzero_si128());  // sum(s0..s7) in each 64-bit half
    // alternative:
    //        sum8_low = _mm_bsrli_si128(p_init_low, 14); // has to wait for psum3 to be ready: lower ILP than doing psadbw separately
    __m256i sum8 = _mm256_broadcastw_epi16(sum8_low);

    *p_increment = _mm256_slli_epi16(sum8, 1);   // set1_epi16(2*sum(salt[0..7]))

    __m128i p_init_high = _mm_add_epi16(p_init_low, _mm256_castsi256_si128(sum8));
    __m256i p_init = _mm256_castsi128_si256(p_init_low);
    p_init = _mm256_inserti128_si256(p_init, p_init_high, 1);
      // not supported by older gcc versions: _mm256_set_m128i(p_init_high, psum3);

    return p_init;
  }

}

void
hashids_shuffle_simd(char *restrict str, size_t str_length, size_t salt_byval)
{
    //assert(salt_length <= 16); // so we can use pshufb for the salt[v] step for non-constant salt length.

    // platform-dependent salt size seems weird. Why not uint64_t?
    size_t salt_length = sizeof(size_t);
    if (str_length < 2) return;   // nothing to shuffle; also keeps str_length-1 from wrapping for 0

    assert(str_length-1 < UINT16_MAX);   // we do p + v + salt[v] in 16-bit elements
    // TODO: assert((str_length-1)/salt_length * p_increment < UINT16_MAX);

    __m128i saltv_u8;
    __m256i v, saltv;
    if(salt_length == 4) {
          v = _mm256_set1_epi64x(0x0003000200010000);   // `v%salt_length` is 0 1 2 3 repeating
      saltv_u8 = _mm_set1_epi32( salt_byval );
      saltv = _mm256_cvtepu8_epi16( saltv_u8 );         // salt[v] repeats with the same pattern: expand it to 16b elements with pmovzx
    } else {
        assert(salt_length == 8);
            v = _mm256_cvtepu8_epi16( _mm_set1_epi64x(0x0706050403020100) );
        saltv_u8 = _mm_set1_epi64x( salt_byval );
        saltv = _mm256_cvtepu8_epi16( saltv_u8 );
    }

    __m256i v_saltv = _mm256_add_epi16(v, saltv);

    __m256i p_increment;
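    __m256i p = init_p_and_increment(salt_length, &p_increment, saltv, saltv_u8);
    p = _mm256_sub_epi16(p, p_increment);  // the loop adds p_increment before first use, so start one step back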


    for (unsigned i=str_length-1; i>0 ; /*i-=16 */){
        // 16 uint16_t j values per iteration.  i-- happens inside the scalar shuffle loop.
        p = _mm256_add_epi16(p, p_increment);    // p += salt[v]; with serial dependencies accounted for, prefix-sum style

        __m256i j_unmodded = _mm256_add_epi16(v_saltv, p);

        // size_t j = (v + saltv[v] + p) % i;
        //////// scalar loop over 16 j elements, doing the modulo and swap
        // alignas(32) uint16_t jbuf[16];   // portable C++11 syntax
        uint16_t jbuf[16] __attribute__((aligned(32)));  // GNU C syntax
        _mm256_store_si256((__m256i*)jbuf, j_unmodded);

        const int jcount = sizeof(jbuf)/sizeof(jbuf[0]);
        for (int elem = 0 ; elem < jcount && i > 0 ; elem++, i--) {
          // use i before decrementing it, like the scalar loop; once i reaches 0 the outer loop exits too.

              // 32-bit division is significantly faster than 64-bit division
          unsigned j = jbuf[elem] % (uint32_t)i;
          // doubtful that vectorizing this with Intel SVML _mm_rem_epu16 would be a win
          // since there's no hardware support for it.  Until AVX512CD, we need each element in a gp reg as an array index anyway.

          char temp = str[i];
          str[i] = str[j];
          str[j] = temp;
        }

    }
}

This compiles to asm that looks about right, but I haven't run it.

Clang makes a fairly sensible inner loop. This is with -fno-unroll-loops for readability. Leave that out for performance, although it won't matter here since loop overhead isn't the bottleneck.

 # The loop part of clang3.8.1's output.  -march=haswell -fno-unroll-loops (only for human readability.  Normally it unrolls by 2).
.LBB0_6:  # outer loop                  #   in Loop: Header=BB0_3 Depth=1
    add     esi, 1
.LBB0_3:  # first iteration entry point # =>This Loop Header: Depth=1
    vpaddw  ymm2, ymm2, ymm1           # p += p_increment
    vpaddw  ymm3, ymm0, ymm2           # v+salt[v] + p
    vmovdqa ymmword ptr [rsp], ymm3    # store jbuf
    add     esi, -1
    lea     r8, [rdi + rsi]
    mov     ecx, 1
.LBB0_4:  # inner loop                  #   Parent Loop BB0_3 Depth=1
    # gcc's version fully unrolls the inner loop, leading to code bloat
    test    esi, esi                            # if(i==0) return
    je      .LBB0_8
    movzx   eax, word ptr [rsp + 2*rcx - 2]     # load jbuf
    xor     edx, edx
    div     esi
    mov     r9b, byte ptr [r8]                  # swap
    mov     al, byte ptr [rdi + rdx]
    mov     byte ptr [r8], al
    mov     byte ptr [rdi + rdx], r9b
    add     esi, -1
    add     r8, -1
    cmp     rcx, 16                     # silly clang, not macro-fusing cmp/jl because it wants to use a weird way to increment.
    lea     rcx, [rcx + 1]
    jl      .LBB0_4                     # inner loop
    jmp     .LBB0_6                     # outer loop
answered Mar 24 '23 00:03 by Peter Cordes
