Logo Questions Linux Laravel Mysql Ubuntu Git Menu
 

How to implement lane crossing logical bit-wise shift/rotate (left and right) in AVX2

Tags:

c++

c

avx2

As per this answer, I've created the following test program:

#include <iso646.h>
#include <immintrin.h>

#include <stdio.h>

// Defines shift_left_N ( A ), which shifts the whole 256-bit vector A left by N *bytes* (not bits),
// carrying bytes across the 128-bit lane boundary. N has to be a compile-time constant because it
// becomes the immediate operand of _mm256_alignr_epi8 / _mm256_slli_si256.
//
// _mm256_permute2x128_si256 ( A, A, _MM_SHUFFLE ( 0, 0, 2, 0 ) ) (immediate 0x08) yields a vector whose
// low 128-bit lane is zero and whose high lane is A's low lane; that is the data that has to cross lanes.
//
// Every continuation backslash must be the very last character on its line: a backslash followed by
// a space is not a portable line splice.
#define SHIFT_LEFT( N ) \
\
    inline __m256i shift_left_##N ( __m256i A  ) { \
\
    if ( N == 0 ) return A; \
    else if ( N <  16 ) return _mm256_alignr_epi8 ( A, _mm256_permute2x128_si256 ( A, A, _MM_SHUFFLE ( 0, 0, 2, 0 ) ), ( uint8_t ) ( 16 - N ) ); \
    else if ( N == 16 ) return _mm256_permute2x128_si256 ( A, A, _MM_SHUFFLE ( 0, 0, 2, 0 ) ); \
    else return _mm256_slli_si256 ( _mm256_permute2x128_si256 ( A, A, _MM_SHUFFLE ( 0, 0, 2, 0 ) ), ( uint8_t ) ( N - 16 ) ); \
}

// Writes the bits of n to stdout as '0' / '1' characters, most significant bit first, in groups of
// two bits separated by a space. The probe constant below assumes size_t is 64 bits wide.
void print ( const size_t n ) {

    // Single-bit probe that walks from bit 63 down to bit 0.
    size_t i = 0x8000000000000000;

    while ( i ) {

        // Test the bit and emit exactly '0' or '1'. Adding the raw masked value to '0' is only
        // correct for bit 0; any higher set bit would produce a different character.
        putchar ( ( n & i ) ? '1' : '0' );
        i >>= 1;
        putchar ( ( n & i ) ? '1' : '0' );
        i >>= 1;

        putchar ( ' ' );
    }
}

// Instantiate shift_left_2 ( ).
SHIFT_LEFT ( 2 );

// Shifts a vector whose lowest 64-bit element is 3 with shift_left_2 ( ) and prints the four
// 64-bit elements of the result, highest element first, on one line.
int main ( ) {

    const __m256i input   = _mm256_set_epi64x ( 0x00, 0x00, 0x00, 0x03 );
    const __m256i shifted = shift_left_2 ( input );

    // View the result as four 64-bit words so that print ( ) can consume them.
    const size_t * words = ( const size_t * ) &shifted;

    for ( int k = 3; k >= 0; --k ) {

        print ( words [ k ] );
    }

    putchar ( '\n' );

    return 0;
}

The above program does not give the expected (by me) output, as far as I can see. I'm stumped as to how these functions work together (read the descriptions). Am I doing something wrong, or is the implementation of shift_left() wrong?

EDIT1: I came to realize (and it was confirmed in the comments) that this code is only meant to shift by at most 32, and that the shift amount is in bytes, not bits, so it does not satisfy my goal. That leaves the question: "How to implement lane crossing logical bit-wise shift (left and right) in AVX2".

EDIT2: Fast forward: in the meantime, I have become less stumped as to how it works and have coded what I needed. I've posted the code (shift and rotate) and accepted that as the answer.

like image 644
degski Avatar asked Mar 19 '18 16:03

degski


2 Answers

Probably not the kind of answer that you're expecting. But here's a reasonably efficient solution that actually works for a run-time shift amount.

The costs are:

  • Preprocess: ~12 - 14 instructions
  • Rotation: 5 instructions
  • Shift: 6 instructions

In order to shift or rotate anything, you must first preprocess the shift amount. Once you have that, you can efficiently perform shifts/rotations.

Because the preprocessing step is so expensive, this solution utilizes an object to hold the preprocessed shift amount so that it can be reused many times when shifting by the same amount.

For efficiency, the object should be on the stack in the same scope as the code that does the shifting. This allows the compiler to promote all the fields of the object into registers. Furthermore, it's recommended to force-inline all the methods of the class.

#include <stdint.h>
#include <immintrin.h>

//  Shifts or rotates a 256-bit vector left by a run-time number of bits, treating
//  32-bit element 0 as the least significant word.
//
//  The constructor turns the bit count into permutation indices, in-word shift
//  counts and a clearing mask once, so that rotate()/shift() stay cheap when the
//  same amount is applied to many vectors.
class LeftShifter_AVX2{
public:
    //  bits: number of bits to shift/rotate by.
    LeftShifter_AVX2(uint32_t bits){
        const uint32_t word_shift = bits / 32;  //  whole 32-bit elements to move
        const uint32_t bit_shift  = bits % 32;  //  remaining shift inside an element

        const __m256i lane_ids = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
        const __m256i zero     = _mm256_setzero_si256();

        //  Source element feeding the upper part of each destination element,
        //  and the element just below it, which supplies the carried-in bits.
        permL = _mm256_sub_epi32(lane_ids, _mm256_set1_epi32(word_shift));
        permR = _mm256_sub_epi32(permL, _mm256_set1_epi32(1));

        shiftL = _mm_cvtsi32_si128(bit_shift);
        shiftR = _mm_cvtsi32_si128(32 - bit_shift);

        //  A negative index means the source wrapped around the end of the
        //  vector. Those bits belong to a rotation but must be cleared for a
        //  shift: the whole element for permL, only the carried-in low bits
        //  for permR.
        const __m256i wrappedL = _mm256_cmpgt_epi32(zero, permL);
        const __m256i wrappedR = _mm256_cmpgt_epi32(zero, permR);
        mask = _mm256_or_si256(wrappedL, _mm256_srl_epi32(wrappedR, shiftR));
    }

    //  Returns x rotated left by the configured bit count.
    __m256i rotate(__m256i x) const{
        //  The permute only looks at the low 3 bits of each index, so negative
        //  indices wrap around, which is what makes this a rotation.
        const __m256i upper = _mm256_sll_epi32(_mm256_permutevar8x32_epi32(x, permL), shiftL);
        const __m256i lower = _mm256_srl_epi32(_mm256_permutevar8x32_epi32(x, permR), shiftR);
        return _mm256_or_si256(upper, lower);
    }

    //  Returns x shifted left by the configured bit count: the rotation with
    //  the wrapped-around bits cleared.
    __m256i shift(__m256i x) const{
        const __m256i rotated = rotate(x);
        return _mm256_andnot_si256(mask, rotated);
    }

private:
    __m256i permL;   //  per-element source index for the left-shifted part
    __m256i permR;   //  per-element source index for the carried-in part (permL - 1)
    __m128i shiftL;  //  bits % 32
    __m128i shiftR;  //  32 - bits % 32
    __m256i mask;    //  bits that shift() clears from the rotation
};

Test Program:

#include <iostream>
using namespace std;

//  Prints the 32 bytes of x as unsigned decimal numbers, lowest byte first,
//  each followed by a space, and ends the line.
void print_u8(__m256i x){
    //  Copy the vector into a byte array with an unaligned store. Reading the
    //  bytes back through a union member other than the one last written is
    //  not defined behaviour in standard C++.
    uint8_t s[32];
    _mm256_storeu_si256((__m256i*)s, x);

    for (int c = 0; c < 32; c++){
        std::cout << (int)s[c] << " ";
    }
    std::cout << std::endl;
}

//  Fills a vector with the byte pattern 0, 1, ..., 31 and prints it unshifted
//  and after left shifts by 0, 8, 32 and 40 bits.
int main(){
    char buffer[32];
    for (int c = 0; c < 32; c++){
        buffer[c] = (char)c;
    }

    //  Load the bytes with an unaligned load rather than aliasing the buffer
    //  and the vector through a union (not defined behaviour in standard C++).
    __m256i x = _mm256_loadu_si256((const __m256i*)buffer);

    print_u8(x);
    print_u8(LeftShifter_AVX2(0).shift(x));
    print_u8(LeftShifter_AVX2(8).shift(x));
    print_u8(LeftShifter_AVX2(32).shift(x));
    print_u8(LeftShifter_AVX2(40).shift(x));
}

Output:

0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 
0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 
0 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 
0 0 0 0 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 
0 0 0 0 0 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26

Right-shift is very similar. I'll leave that as an exercise for the reader.

like image 193
Mysticial Avatar answered Sep 20 '22 01:09

Mysticial


The following code implements lane-crossing logical bit-wise shift/rotate (left and right) in AVX2:

// Prototypes...

// Public entry points. Each treats the argument as one 256-bit value and takes a bit count n;
// the implementations dispatch on the ranges 0-63, 64-127, 128-191 and 192-255.
__m256i _mm256_sli_si256 ( __m256i, int ); // logical shift left
__m256i _mm256_sri_si256 ( __m256i, int ); // logical shift right
__m256i _mm256_rli_si256 ( __m256i, int ); // rotate left
__m256i _mm256_rri_si256 ( __m256i, int ); // rotate right


// Implementations...

// Logical left shift of the whole 256-bit value by n bits; the dispatcher calls this with 0 <= n < 64.
// Author's cost annotation: 6 (presumably the instruction count).
__m256i left_shift_000_063 ( __m256i a, int n ) {

    // Bits that stay inside their own 64-bit element.
    const __m256i kept = _mm256_slli_epi64 ( a, n );

    // Bits pushed out of the top of each element, moved down to bit 0 (all zero when n == 0).
    const __m256i spilled = _mm256_srli_epi64 ( a, 64 - n );

    // Move every spill one element up; element 0 only receives a duplicate ...
    const __m256i moved = _mm256_permute4x64_epi64 ( spilled, _MM_SHUFFLE ( 2, 1, 0, 0 ) );

    // ... which is blanked here. The blend mask has one bit per 32-bit lane, so each 2-bit
    // _MM_SHUFFLE field set to 3 keeps one 64-bit element and a field of 0 zeroes it.
    const __m256i carry = _mm256_blend_epi32 ( _mm256_setzero_si256 ( ), moved, _MM_SHUFFLE ( 3, 3, 3, 0 ) );

    return _mm256_or_si256 ( kept, carry );
}

// Logical left shift of the whole 256-bit value by 64 + n bits; the dispatcher passes n = shift % 64.
// Author's cost annotation: 7 (presumably the instruction count).
__m256i left_shift_064_127 ( __m256i a, int n ) {

    const __m256i zero = _mm256_setzero_si256 ( );

    // Per-element pieces: what remains after shifting by n, and what spills into the next element.
    const __m256i kept    = _mm256_slli_epi64 ( a, n );
    const __m256i spilled = _mm256_srli_epi64 ( a, 64 - n );

    // The kept part moves up one element, the spill moves up two.
    const __m256i kept_up    = _mm256_permute4x64_epi64 ( kept,    _MM_SHUFFLE ( 2, 1, 0, 0 ) );
    const __m256i spilled_up = _mm256_permute4x64_epi64 ( spilled, _MM_SHUFFLE ( 1, 0, 0, 0 ) );

    // Zero the low elements, which only hold duplicates from the permutes
    // (one element for the kept part, two for the spill).
    const __m256i kept_part  = _mm256_blend_epi32 ( zero, kept_up,    _MM_SHUFFLE ( 3, 3, 3, 0 ) );
    const __m256i carry_part = _mm256_blend_epi32 ( zero, spilled_up, _MM_SHUFFLE ( 3, 3, 0, 0 ) );

    return _mm256_or_si256 ( kept_part, carry_part );
}

// Logical left shift of the whole 256-bit value by 128 + n bits; the dispatcher passes n = shift % 64.
// With input elements a0 (lowest) .. a3, the result is [ 0, 0, a0 << n, ( a1 << n ) | ( a0 >> ( 64 - n ) ) ].
// Author's cost annotation: 7 (presumably the instruction count).
__m256i left_shift_128_191 ( __m256i a, int n ) {

    // Part of each element that remains after shifting by n; it moves up two elements.
    __m256i b = _mm256_slli_epi64 ( a, n );
    __m256i d = _mm256_permute4x64_epi64 ( b, _MM_SHUFFLE ( 1, 0, 0, 0 ) );

    // Part of each element that spills into the next one; it moves up three elements, so the only
    // spill that survives is the one from element 0, which lands in element 3. The permute must
    // therefore put element 0 (not element 1) into position 3.
    __m256i c = _mm256_srli_epi64 ( a, 64 - n );
    __m256i e = _mm256_permute4x64_epi64 ( c, _MM_SHUFFLE ( 0, 0, 0, 0 ) );

    // Keep elements 2-3 of the shifted part and element 3 of the carry; everything else is zero.
    __m256i f = _mm256_blend_epi32 ( _mm256_setzero_si256 ( ), d, _MM_SHUFFLE ( 3, 3, 0, 0 ) );
    __m256i g = _mm256_blend_epi32 ( _mm256_setzero_si256 ( ), e, _MM_SHUFFLE ( 3, 0, 0, 0 ) );

    return _mm256_or_si256 ( f, g );
}

// Logical left shift of the whole 256-bit value by 192 + n bits; the dispatcher passes n = shift % 64.
// Only element 0 of the input can contribute, and it ends up in element 3.
// Author's cost annotation: 5 (presumably the instruction count).
__m256i left_shift_192_255 ( __m256i a, int n ) {

    // Broadcast element 0 to all four elements, then shift each by n.
    const __m256i lowest  = _mm256_permute4x64_epi64 ( a, _MM_SHUFFLE ( 0, 0, 0, 0 ) );
    const __m256i shifted = _mm256_slli_epi64 ( lowest, n );

    // Keep only element 3; the other three become zero.
    return _mm256_blend_epi32 ( _mm256_setzero_si256 ( ), shifted, _MM_SHUFFLE ( 3, 0, 0, 0 ) );
}

// Logical left shift of the 256-bit value a by n bits.
// For 0 <= n <= 255 the work is delegated to the helper for the matching 64-bit range. Any other
// n (negative, or 256 and above) shifts every bit out, so the result is all zeros; without this
// guard such counts would be reduced by n % 64 and produce an unrelated shift.
__m256i _mm256_sli_si256 ( __m256i a, int n ) {

    // One unsigned comparison rejects both negative counts and counts >= 256.
    if ( ( unsigned ) n >= 256u ) return _mm256_setzero_si256 ( );

    if ( n < 128 ) return n <  64 ? left_shift_000_063 ( a, n ) : left_shift_064_127 ( a, n % 64 );
    else           return n < 192 ? left_shift_128_191 ( a, n % 64 ) : left_shift_192_255 ( a, n % 64 );
}


// Logical right shift of the whole 256-bit value by n bits; the dispatcher calls this with 0 <= n < 64.
// Author's cost annotation: 6 (presumably the instruction count).
__m256i right_shift_000_063 ( __m256i a, int n ) {

    // Bits that stay inside their own 64-bit element.
    const __m256i kept = _mm256_srli_epi64 ( a, n );

    // Bits pushed out of the bottom of each element, moved up to the top (all zero when n == 0).
    const __m256i spilled = _mm256_slli_epi64 ( a, 64 - n );

    // Move every spill one element down; element 3 only receives a wrapped-around value ...
    const __m256i moved = _mm256_permute4x64_epi64 ( spilled, _MM_SHUFFLE ( 0, 3, 2, 1 ) );

    // ... which is blanked here (each 2-bit mask field covers one 64-bit element).
    const __m256i carry = _mm256_blend_epi32 ( _mm256_setzero_si256 ( ), moved, _MM_SHUFFLE ( 0, 3, 3, 3 ) );

    return _mm256_or_si256 ( kept, carry );
}

// Logical right shift of the whole 256-bit value by 64 + n bits; the dispatcher passes n = shift % 64.
// Author's cost annotation: 7 (presumably the instruction count).
__m256i right_shift_064_127 ( __m256i a, int n ) {

    const __m256i zero = _mm256_setzero_si256 ( );

    // Per-element pieces: what remains after shifting by n, and what spills into the element below.
    const __m256i kept    = _mm256_srli_epi64 ( a, n );
    const __m256i spilled = _mm256_slli_epi64 ( a, 64 - n );

    // The kept part moves down one element, the spill moves down two.
    const __m256i kept_down    = _mm256_permute4x64_epi64 ( kept,    _MM_SHUFFLE ( 3, 3, 2, 1 ) );
    const __m256i spilled_down = _mm256_permute4x64_epi64 ( spilled, _MM_SHUFFLE ( 3, 3, 3, 2 ) );

    // Zero the high elements, which only hold duplicates from the permutes
    // (one element for the kept part, two for the spill).
    const __m256i kept_part  = _mm256_blend_epi32 ( zero, kept_down,    _MM_SHUFFLE ( 0, 3, 3, 3 ) );
    const __m256i carry_part = _mm256_blend_epi32 ( zero, spilled_down, _MM_SHUFFLE ( 0, 0, 3, 3 ) );

    return _mm256_or_si256 ( kept_part, carry_part );
}

// Logical right shift of the whole 256-bit value by 128 + n bits; the dispatcher passes n = shift % 64.
// Author's cost annotation: 7 (presumably the instruction count).
__m256i right_shift_128_191 ( __m256i a, int n ) {

    const __m256i zero = _mm256_setzero_si256 ( );

    // Per-element pieces: what remains after shifting by n, and what spills into the element below.
    const __m256i kept    = _mm256_srli_epi64 ( a, n );
    const __m256i spilled = _mm256_slli_epi64 ( a, 64 - n );

    // The kept part moves down two elements; of the spill, only element 3's matters and it lands in element 0.
    const __m256i kept_down    = _mm256_permute4x64_epi64 ( kept,    _MM_SHUFFLE ( 3, 2, 3, 2 ) );
    const __m256i spilled_down = _mm256_permute4x64_epi64 ( spilled, _MM_SHUFFLE ( 3, 2, 1, 3 ) );

    // Keep elements 0-1 of the kept part and element 0 of the carry; everything else is zero.
    const __m256i kept_part  = _mm256_blend_epi32 ( zero, kept_down,    _MM_SHUFFLE ( 0, 0, 3, 3 ) );
    const __m256i carry_part = _mm256_blend_epi32 ( zero, spilled_down, _MM_SHUFFLE ( 0, 0, 0, 3 ) );

    return _mm256_or_si256 ( kept_part, carry_part );
}

// Logical right shift of the whole 256-bit value by 192 + n bits; the dispatcher passes n = shift % 64.
// Only element 3 of the input can contribute, and it ends up in element 0.
// Author's cost annotation: 5 (presumably the instruction count).
__m256i right_shift_192_255 ( __m256i a, int n ) {

    // Bring element 3 down to position 0, then shift each element by n.
    const __m256i highest = _mm256_permute4x64_epi64 ( a, _MM_SHUFFLE ( 0, 0, 0, 3 ) );
    const __m256i shifted = _mm256_srli_epi64 ( highest, n );

    // Keep only element 0; the other three become zero.
    return _mm256_blend_epi32 ( _mm256_setzero_si256 ( ), shifted, _MM_SHUFFLE ( 0, 0, 0, 3 ) );
}

// Logical right shift of the 256-bit value a by n bits.
// For 0 <= n <= 255 the work is delegated to the helper for the matching 64-bit range. Any other
// n (negative, or 256 and above) shifts every bit out, so the result is all zeros; without this
// guard such counts would be reduced by n % 64 and produce an unrelated shift.
__m256i _mm256_sri_si256 ( __m256i a, int n ) {

    // One unsigned comparison rejects both negative counts and counts >= 256.
    if ( ( unsigned ) n >= 256u ) return _mm256_setzero_si256 ( );

    if ( n < 128 ) return n <  64 ? right_shift_000_063 ( a, n ) : right_shift_064_127 ( a, n % 64 );
    else           return n < 192 ? right_shift_128_191 ( a, n % 64 ) : right_shift_192_255 ( a, n % 64 );
}


// Rotates the whole 256-bit value left by n bits; the dispatcher calls this with 0 <= n < 64.
// Author's cost annotation: 5 (presumably the instruction count).
__m256i left_rotate_000_063 ( __m256i a, int n ) {

    // Bits that stay inside their own 64-bit element.
    const __m256i kept = _mm256_slli_epi64 ( a, n );

    // Bits leaving the top of each element (all zero when n == 0) ...
    const __m256i spilled = _mm256_srli_epi64 ( a, 64 - n );

    // ... enter the next element up; the spill of element 3 wraps around into element 0.
    const __m256i carry = _mm256_permute4x64_epi64 ( spilled, _MM_SHUFFLE ( 2, 1, 0, 3 ) );

    return _mm256_or_si256 ( kept, carry );
}

// Rotates the whole 256-bit value left by 64 + n bits; the dispatcher passes n = rotation % 64.
// Author's cost annotation: 6 (presumably the instruction count).
__m256i left_rotate_064_127 ( __m256i a, int n ) {

    // Per-element pieces: what remains after shifting by n, and what spills into the next element.
    const __m256i kept    = _mm256_slli_epi64 ( a, n );
    const __m256i spilled = _mm256_srli_epi64 ( a, 64 - n );

    // Rotate the kept part up by one element and the spill up by two, wrapping around.
    const __m256i kept_rot    = _mm256_permute4x64_epi64 ( kept,    _MM_SHUFFLE ( 2, 1, 0, 3 ) );
    const __m256i spilled_rot = _mm256_permute4x64_epi64 ( spilled, _MM_SHUFFLE ( 1, 0, 3, 2 ) );

    return _mm256_or_si256 ( kept_rot, spilled_rot );
}

// Rotates the whole 256-bit value left by 128 + n bits; the dispatcher passes n = rotation % 64.
// Author's cost annotation: 6 (presumably the instruction count).
__m256i left_rotate_128_191 ( __m256i a, int n ) {

    // Per-element pieces: what remains after shifting by n, and what spills into the next element.
    const __m256i kept    = _mm256_slli_epi64 ( a, n );
    const __m256i spilled = _mm256_srli_epi64 ( a, 64 - n );

    // Rotate the kept part up by two elements and the spill up by three, wrapping around.
    const __m256i kept_rot    = _mm256_permute4x64_epi64 ( kept,    _MM_SHUFFLE ( 1, 0, 3, 2 ) );
    const __m256i spilled_rot = _mm256_permute4x64_epi64 ( spilled, _MM_SHUFFLE ( 0, 3, 2, 1 ) );

    return _mm256_or_si256 ( kept_rot, spilled_rot );
}

// Rotates the whole 256-bit value left by 192 + n bits; the dispatcher passes n = rotation % 64.
// Author's cost annotation: 5 (presumably the instruction count).
__m256i left_rotate_192_255 ( __m256i a, int n ) {

    // The spill would move up four elements, i.e. it ends up in the element it came from.
    const __m256i spilled = _mm256_srli_epi64 ( a, 64 - n );

    // The kept part is rotated up by three elements (equivalently, down by one).
    const __m256i kept     = _mm256_slli_epi64 ( a, n );
    const __m256i kept_rot = _mm256_permute4x64_epi64 ( kept, _MM_SHUFFLE ( 0, 3, 2, 1 ) );

    return _mm256_or_si256 ( spilled, kept_rot );
}

// Rotates the 256-bit value a left by n bits.
// The count is first reduced modulo 256, so any n is accepted: a count of 256 or more wraps
// around, and a negative count acts as the matching rotation in the opposite direction
// (two's-complement int assumed). The reduced count selects the helper for its 64-bit range.
__m256i _mm256_rli_si256 ( __m256i a, int n ) {

    // Rotating by 256 bits is the identity, so only the low 8 bits of the count matter.
    n &= 255;

    if ( n < 128 ) return n <  64 ? left_rotate_000_063 ( a, n ) : left_rotate_064_127 ( a, n % 64 );
    else           return n < 192 ? left_rotate_128_191 ( a, n % 64 ) : left_rotate_192_255 ( a, n % 64 );
}


// Rotates the whole 256-bit value right by n bits; the dispatcher calls this with 0 <= n < 64.
// Author's cost annotation: 5 (presumably the instruction count).
__m256i right_rotate_000_063 ( __m256i a, int n ) {

    // Bits that stay inside their own 64-bit element.
    const __m256i kept = _mm256_srli_epi64 ( a, n );

    // Bits leaving the bottom of each element (all zero when n == 0) ...
    const __m256i spilled = _mm256_slli_epi64 ( a, 64 - n );

    // ... enter the next element down; the spill of element 0 wraps around into element 3.
    const __m256i carry = _mm256_permute4x64_epi64 ( spilled, _MM_SHUFFLE ( 0, 3, 2, 1 ) );

    return _mm256_or_si256 ( kept, carry );
}

// Rotates the whole 256-bit value right by 64 + n bits; the dispatcher passes n = rotation % 64.
// Author's cost annotation: 6 (presumably the instruction count).
__m256i right_rotate_064_127 ( __m256i a, int n ) {

    // Per-element pieces: what remains after shifting by n, and what spills into the element below.
    const __m256i kept    = _mm256_srli_epi64 ( a, n );
    const __m256i spilled = _mm256_slli_epi64 ( a, 64 - n );

    // Rotate the kept part down by one element and the spill down by two, wrapping around.
    const __m256i kept_rot    = _mm256_permute4x64_epi64 ( kept,    _MM_SHUFFLE ( 0, 3, 2, 1 ) );
    const __m256i spilled_rot = _mm256_permute4x64_epi64 ( spilled, _MM_SHUFFLE ( 1, 0, 3, 2 ) );

    return _mm256_or_si256 ( kept_rot, spilled_rot );
}

// Rotates the whole 256-bit value right by 128 + n bits; the dispatcher passes n = rotation % 64.
// Author's cost annotation: 6 (presumably the instruction count).
__m256i right_rotate_128_191 ( __m256i a, int n ) {

    // Per-element pieces: what remains after shifting by n, and what spills into the element below.
    const __m256i kept    = _mm256_srli_epi64 ( a, n );
    const __m256i spilled = _mm256_slli_epi64 ( a, 64 - n );

    // Rotate the kept part down by two elements and the spill down by three, wrapping around.
    const __m256i kept_rot    = _mm256_permute4x64_epi64 ( kept,    _MM_SHUFFLE ( 1, 0, 3, 2 ) );
    const __m256i spilled_rot = _mm256_permute4x64_epi64 ( spilled, _MM_SHUFFLE ( 2, 1, 0, 3 ) );

    return _mm256_or_si256 ( kept_rot, spilled_rot );
}
// Rotates the whole 256-bit value right by 192 + n bits; the dispatcher passes n = rotation % 64.
// Author's cost annotation: 5 (presumably the instruction count).
__m256i right_rotate_192_255 ( __m256i a, int n ) {

    // The spill would move down four elements, i.e. it ends up in the element it came from.
    const __m256i spilled = _mm256_slli_epi64 ( a, 64 - n );

    // The kept part is rotated down by three elements (equivalently, up by one).
    const __m256i kept     = _mm256_srli_epi64 ( a, n );
    const __m256i kept_rot = _mm256_permute4x64_epi64 ( kept, _MM_SHUFFLE ( 2, 1, 0, 3 ) );

    return _mm256_or_si256 ( spilled, kept_rot );
}

// Rotates the 256-bit value a right by n bits.
// The count is first reduced modulo 256, so any n is accepted: a count of 256 or more wraps
// around, and a negative count acts as the matching rotation in the opposite direction
// (two's-complement int assumed). The reduced count selects the helper for its 64-bit range.
__m256i _mm256_rri_si256 ( __m256i a, int n ) {

    // Rotating by 256 bits is the identity, so only the low 8 bits of the count matter.
    n &= 255;

    if ( n < 128 ) return n <  64 ? right_rotate_000_063 ( a, n      ) : right_rotate_064_127 ( a, n % 64 );
    else           return n < 192 ? right_rotate_128_191 ( a, n % 64 ) : right_rotate_192_255 ( a, n % 64 );
}

I have tried to make the _mm256_permute4x64_epi64 ops (in the cases where two of them are needed anyway) partially overlap, which should keep the overall latency to a minimum.

Most of the suggestions and/or clues given by commenters were helpful in putting together the code; thanks to them. Obviously, improvements and/or any other comments are welcome.

I think that Mysticial's answer is interesting, but too complicated to be used effectively for generalized shifting/rotating, e.g. in a library.

like image 35
degski Avatar answered Sep 19 '22 01:09

degski