 

Implementation of bit rotate operators using SIMD in CUDA

I know that Stack Overflow is not meant for asking other people to write code for you, but hear me out.

I am trying to implement some AES functions in CUDA C++ device code. While trying to implement a bytewise left-rotate operator, I was disconcerted to find that there is no native SIMD intrinsic for it. So I began a naive implementation, but... it's huge, and although I haven't tried it yet, it just won't be fast because of the expensive unpacking and repacking. Is there a way to do a per-byte bit rotation that's at least somewhat efficient?

Here's the code if you want to have a look.

__inline__ __device__ uint32_t per_byte_bit_left_rotate(uint32_t input, uint8_t amount) {
    uint32_t s = amount & 7;  // a byte rotated by 8 is unchanged, so only s mod 8 matters
    // Unpack each byte, rotate it in isolation, mask it back to 8 bits, and repack.
    // The & 0xFF after rotating clears the bits that were pushed past bit 7 into the next byte.
    return (((((input >>  0) & 0xFF) << s) | (((input >>  0) & 0xFF) >> (8 - s))) & 0xFF) <<  0 |
           (((((input >>  8) & 0xFF) << s) | (((input >>  8) & 0xFF) >> (8 - s))) & 0xFF) <<  8 |
           (((((input >> 16) & 0xFF) << s) | (((input >> 16) & 0xFF) >> (8 - s))) & 0xFF) << 16 |
           (((((input >> 24) & 0xFF) << s) | (((input >> 24) & 0xFF) >> (8 - s))) & 0xFF) << 24;
}
Sachiko.Shinozaki asked Aug 27 '17 at 00:08



2 Answers

CUDA has a __byte_perm() intrinsic that maps directly to the PRMT instruction at the machine code (SASS) level, which is a byte-wise permute instruction. It can be used to efficiently extract and merge bytes. To effect a byte-wise left rotation, we can double up each byte into a 16-bit byte pair, shift the byte pairs left by the desired amount, then extract and merge the high bytes of the four byte pairs.
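Since __byte_perm() is a device-only intrinsic, the selector arithmetic is easiest to trace on the host. The sketch below uses a hypothetical helper, byte_perm_emulated, that mimics the intrinsic's documented default byte selection (source bytes 0-3 come from x, bytes 4-7 from y, and the low three bits of selector nibble n choose result byte n); it does not model PRMT's other modes. It then runs the doubling trick from this answer through that emulation.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical host-side stand-in for CUDA's __byte_perm(x, y, s), default mode only.
// The eight source bytes are x (bytes 0-3) followed by y (bytes 4-7);
// the low 3 bits of nibble n of the selector pick result byte n.
uint32_t byte_perm_emulated(uint32_t x, uint32_t y, uint32_t s) {
    uint64_t src = (static_cast<uint64_t>(y) << 32) | x;
    uint32_t result = 0;
    for (int n = 0; n < 4; ++n) {
        uint32_t sel = (s >> (4 * n)) & 7;
        uint32_t byte = static_cast<uint32_t>((src >> (8 * sel)) & 0xFF);
        result |= byte << (8 * n);
    }
    return result;
}

// The doubling trick, evaluated with the emulated permute.
uint32_t per_byte_rotl_via_perm(uint32_t input, uint32_t amount) {
    uint32_t s = amount & 7;
    // Selector 0x1100 builds [b1 b1 b0 b0] (most significant byte first):
    // each byte sits next to a copy of itself in a 16-bit pair.
    uint32_t l = byte_perm_emulated(input, 0, 0x1100) << s;
    // Selector 0x3322 builds [b3 b3 b2 b2].
    uint32_t h = byte_perm_emulated(input, 0, 0x3322) << s;
    // The high byte of each pair now holds the rotated byte: gather bytes 1 and 3
    // of l (source bytes 1, 3) and bytes 1 and 3 of h (source bytes 5, 7).
    return byte_perm_emulated(l, h, 0x7531);
}
```

For example, with input 0x02f73804 and amount 4, l becomes 0x83804040 and h becomes 0x202f7f70 after the shifts, and the final 0x7531 gather picks bytes 1 and 3 of each, giving 0x207f8340.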

For byte-wise rotation, we only need the lowest three bits of the shift amount, as a rotation by s is the same as a rotation by s mod 8. For efficiency, it is best to avoid integer types comprising fewer than 32 bits, as C++ semantics require integer types narrower than int to be widened to int before use in expressions. This can and does incur conversion overhead on many architectures, including GPUs.
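Both points in this paragraph can be checked on the host. The sketch below (function names are hypothetical) rotates a single byte using 32-bit arithmetic with the count reduced mod 8, and shows that a uint8_t operand is promoted to int before it is shifted.

```cpp
#include <cassert>
#include <cstdint>

// Rotate one byte left, computed in 32-bit arithmetic.
// Only the low three bits of the count matter: rotating 8 bits by 8 is the identity.
uint32_t rotl8(uint32_t byte, uint32_t amount) {
    uint32_t s = amount & 7;
    byte &= 0xFF;
    // For s == 0, byte >> 8 is 0, so no special case is needed.
    return ((byte << s) | (byte >> (8 - s))) & 0xFF;
}

// Integer promotion: the uint8_t operand is widened to int before the shift,
// so the expression has type int and the bit shifted past bit 7 is kept.
int promoted_shift(uint8_t b, int s) {
    return b << s;
}
```

So promoted_shift(0x80, 1) yields 256, and only the conversion back to uint8_t discards the carried-out bit; keeping values in 32-bit integers and masking explicitly makes that truncation visible in the code instead of implicit.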

The throughput of the PRMT instruction is architecture dependent, so the use of __byte_perm() may lead to code that is faster or slower than use of the classical SIMD-in-a-register method demonstrated in another answer, so be sure to benchmark in the context of your use case prior to deployment.

#include <stdio.h>
#include <stdint.h>
#include <stdlib.h>

__device__ uint32_t per_byte_bit_left_rotate (uint32_t input, uint32_t amount)
{
     uint32_t l = __byte_perm (input, 0, 0x1100) << (amount & 7);
     uint32_t h = __byte_perm (input, 0, 0x3322) << (amount & 7);
     return __byte_perm (l, h, 0x7531);
}

__global__ void rotl_kernel (uint32_t input, uint32_t amount, uint32_t *res)
{
    *res = per_byte_bit_left_rotate (input, amount);
}

uint32_t ref_per_byte_bit_left_rotate (uint32_t input, uint32_t amount)
{
   int s = amount & 7;
   uint8_t b0 = (input >>  0) & 0xff;
   uint8_t b1 = (input >>  8) & 0xff;
   uint8_t b2 = (input >> 16) & 0xff;
   uint8_t b3 = (input >> 24) & 0xff;
   b0 = s ? ((b0 << s) | (b0 >> (8 - s))) : b0;
   b1 = s ? ((b1 << s) | (b1 >> (8 - s))) : b1;
   b2 = s ? ((b2 << s) | (b2 >> (8 - s))) : b2;
   b3 = s ? ((b3 << s) | (b3 >> (8 - s))) : b3;
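   // Widen before shifting: a uint8_t is promoted to int, and shifting a value
   // >= 0x80 left by 24 would overflow int (undefined behavior before C++20)
   return ((uint32_t)b3 << 24) | ((uint32_t)b2 << 16) | ((uint32_t)b1 << 8) | ((uint32_t)b0 << 0);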
}

// Fixes via: Greg Rose, KISS: A Bit Too Simple. http://eprint.iacr.org/2011/007
static unsigned int z=362436069,w=521288629,jsr=362436069,jcong=123456789;
#define znew (z=36969*(z&0xffff)+(z>>16))
#define wnew (w=18000*(w&0xffff)+(w>>16))
#define MWC  ((znew<<16)+wnew)
#define SHR3 (jsr^=(jsr<<13),jsr^=(jsr>>17),jsr^=(jsr<<5)) /* 2^32-1 */
#define CONG (jcong=69069*jcong+13579)                     /* 2^32 */
#define KISS ((MWC^CONG)+SHR3)

// Macro to catch CUDA errors in CUDA runtime calls
#define CUDA_SAFE_CALL(call)                                          \
do {                                                                  \
    cudaError_t err = call;                                           \
    if (cudaSuccess != err) {                                         \
        fprintf (stderr, "Cuda error in file '%s' in line %i : %s.\n",\
                 __FILE__, __LINE__, cudaGetErrorString(err) );       \
        exit(EXIT_FAILURE);                                           \
    }                                                                 \
} while (0)

// Macro to catch CUDA errors in kernel launches
#define CHECK_LAUNCH_ERROR()                                          \
do {                                                                  \
    /* Check synchronous errors, i.e. pre-launch */                   \
    cudaError_t err = cudaGetLastError();                             \
    if (cudaSuccess != err) {                                         \
        fprintf (stderr, "Cuda error in file '%s' in line %i : %s.\n",\
                 __FILE__, __LINE__, cudaGetErrorString(err) );       \
        exit(EXIT_FAILURE);                                           \
    }                                                                 \
    /* Check asynchronous errors, i.e. kernel failed (ULF) */         \
    err = cudaDeviceSynchronize();                                    \
    if (cudaSuccess != err) {                                         \
        fprintf (stderr, "Cuda error in file '%s' in line %i : %s.\n",\
                 __FILE__, __LINE__, cudaGetErrorString( err) );      \
        exit(EXIT_FAILURE);                                           \
    }                                                                 \
} while (0)

int main (void)
{
    uint32_t arg, ref, res = 0, *res_d = 0;
    uint32_t shft;

    CUDA_SAFE_CALL (cudaMalloc ((void**)&res_d, sizeof(*res_d)));
    for (int i = 0; i < 100000; i++) {
        arg  = KISS;
        shft = KISS;
        ref = ref_per_byte_bit_left_rotate (arg, shft);
        rotl_kernel <<<1,1>>>(arg, shft, res_d);
        CHECK_LAUNCH_ERROR();
        CUDA_SAFE_CALL (cudaMemcpy (&res, res_d, sizeof (res), 
                                    cudaMemcpyDeviceToHost));
        if (res != ref) {
            printf ("!!!! arg=%08x shft=%u  res=%08x  ref=%08x\n", 
                    arg, shft, res, ref);
        }
    }
    CUDA_SAFE_CALL (cudaFree (res_d));
    CUDA_SAFE_CALL (cudaDeviceSynchronize());
    return EXIT_SUCCESS;
}
njuffa answered Nov 11 '22 at 15:11



The rotate count is the same for all elements, right?

Shift the whole input left and right, and then AND those with masks that zero all the bits that crossed a byte boundary, for all 4 bytes in one AND. I think amount is always a compile-time constant in AES, so you don't have to worry about the runtime cost of generating the masks on the fly. Just let the compiler do it. (IDK CUDA, but this appears to be the same problem as writing a SWAR bit-hack with 32-bit integers for normal C++)

This is based on the usual (x << count) | (x >> (32-count)) rotate idiom, with masking and a different right-shift count to make it into separate 8-bit rotates.

#include <stdint.h>

#ifdef __CUDACC__
__host__ __device__   // callable from CUDA device code as well as from host code
#endif
inline
uint32_t per_byte_bit_left_rotate(uint32_t input, unsigned amount)
{
    // With constant amount, the left/right masks are constants
    uint32_t rmask = 0xFF >> ((8 - amount) & 7);
    rmask = (rmask<<24 | rmask<<16 | rmask<<8 | rmask);
    uint32_t lmask = ~rmask;

    uint32_t lshift = input << amount;
    lshift &= lmask;
#ifdef __CUDA_ARCH__  // __vadd4 is a device-only intrinsic
    if (amount == 1) {  // special case left-shift by 1 using an in-lane add instead of shift&mask
        lshift = __vadd4(input, input);
    }
#endif
    uint32_t rshift = input >> ((8 - amount) & 7);
    rshift &= rmask;

    uint32_t rotated = lshift | rshift;
    return rotated;
}
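If the count really is a compile-time constant, as this answer expects for AES, one option is to make it a template parameter so the masks are guaranteed constants even in a call that is not inlined. This is a minimal plain C++ sketch under that assumption: per_byte_rotl_const is a hypothetical name, the logic mirrors the function above, and in CUDA code it would additionally need a __device__ (or __host__ __device__) qualifier.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical compile-time variant: Amount is a template parameter, so both
// masks are constant expressions regardless of what the inliner decides.
template <unsigned Amount>
constexpr uint32_t per_byte_rotl_const(uint32_t input) {
    constexpr unsigned s = Amount & 7;      // rotation is mod 8 within a byte
    constexpr unsigned rs = (8 - s) & 7;    // right-shift count (0 when s == 0)
    constexpr uint32_t r8 = 0xFFu >> rs;    // per-byte mask for the right-shifted part
    constexpr uint32_t rmask = r8 << 24 | r8 << 16 | r8 << 8 | r8;
    constexpr uint32_t lmask = ~rmask;      // complementary mask for the left-shifted part
    return ((input << s) & lmask) | ((input >> rs) & rmask);
}

// Evaluated entirely at compile time: rotating by 4 swaps the nibbles of every byte.
static_assert(per_byte_rotl_const<4>(0x02f73804u) == 0x207f8340u, "nibble swap");
```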

It might be even more efficient to mask the input one way before shifting, and mask the output after shifting ((in&lmask)<<amount | ((in>>(8-amount))&rmask), with a different lmask). NVIDIA hardware is in-order superscalar, and shifts have limited throughput. Doing it that way would be more likely to execute as two independent shift+mask pairs.
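A minimal sketch of the mask-before-shift form described above, in plain C++ with hypothetical names (splat8, per_byte_rotl_premask). Here lmask keeps the low (8 - s) bits of each byte, which stay inside their byte after the left shift, and rmask keeps the low s bits after the right shift. Reducing the count with & 7 keeps every shift count in 0..8, so no shift reaches 32 bits. Whether this actually schedules better on NVIDIA hardware is not measured here.

```cpp
#include <cassert>
#include <cstdint>

// Replicate an 8-bit pattern into all four byte lanes.
static inline uint32_t splat8(uint32_t b) {
    b &= 0xFF;
    return b << 24 | b << 16 | b << 8 | b;
}

// Mask-before-shift variant: the two shift+mask pairs do not depend on
// each other until the final OR.
uint32_t per_byte_rotl_premask(uint32_t in, unsigned amount) {
    unsigned s = amount & 7;
    // Keep only the low (8 - s) bits of each byte, so shifting left by s
    // cannot carry anything into the neighbouring byte.
    uint32_t lmask = splat8(0xFFu >> s);
    // After shifting right by (8 - s), the low s bits of each byte hold the
    // wrapped-around bits; everything above them came from the next byte up.
    uint32_t rmask = splat8(0xFFu >> (8 - s));
    return ((in & lmask) << s) | ((in >> (8 - s)) & rmask);
}
```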

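(This assumes 0 <= amount <= 7: for 8 <= amount < 32 the left shift uses the unreduced count and gives wrong results unless amount is a multiple of 8, and amount>=32 is C++ UB. See Best practices for circular shift (rotate) operations in C++. In this case, I think changing to lshift = input << (amount & 7) would do the trick, since the right-shift count and the masks are already reduced mod 8.)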
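A sketch of that fix applied to the function above (only the left-shift count needs & 7), together with a hypothetical exhaustive check against a one-byte-at-a-time reference for every byte value in every lane and every count from 0 to 63. The __vadd4 special case is omitted so it compiles as ordinary host C++; per_byte_rotl_safe, rot8_ref and check_all_counts are illustrative names.

```cpp
#include <cassert>
#include <cstdint>

// The shift-and-mask rotate with the left-shift count reduced mod 8.
uint32_t per_byte_rotl_safe(uint32_t input, unsigned amount) {
    uint32_t rmask = 0xFF >> ((8 - amount) & 7);
    rmask = (rmask << 24 | rmask << 16 | rmask << 8 | rmask);
    uint32_t lmask = ~rmask;
    uint32_t lshift = (input << (amount & 7)) & lmask;
    uint32_t rshift = (input >> ((8 - amount) & 7)) & rmask;
    return lshift | rshift;
}

// Reference: rotate one byte (b in 0..255) left by s (0..7).
static uint32_t rot8_ref(uint32_t b, unsigned s) {
    return ((b << s) | (b >> (8 - s))) & 0xFF;
}

// Every byte value appears in every lane, with a different pattern in the
// neighbouring lanes, so bits leaking across a byte boundary would be caught.
bool check_all_counts() {
    const uint32_t xors[4] = {0x00, 0x5A, 0xA5, 0xFF};
    for (unsigned amount = 0; amount < 64; ++amount) {
        for (uint32_t b = 0; b < 256; ++b) {
            uint32_t in = 0, expect = 0;
            for (int lane = 0; lane < 4; ++lane) {
                uint32_t v = b ^ xors[lane];
                in |= v << (8 * lane);
                expect |= rot8_ref(v, amount & 7) << (8 * lane);
            }
            if (per_byte_rotl_safe(in, amount) != expect) return false;
        }
    }
    return true;
}
```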

To test that this compiles efficiently, I looked at the clang -O3 asm output for x86-64 with a constant amount. The Godbolt compiler explorer has compilers for various architectures (not CUDA though), so you can flip to ARM, MIPS or PowerPC there if you can read those asm languages more easily than x86.

uint32_t rol7(uint32_t a) {
    return per_byte_bit_left_rotate(a, 7);
}
    mov     eax, edi
    shl     eax, 7
    shr     edi
    and     eax, -2139062144   # 0x80808080
    and     edi, 2139062143    # 0x7F7F7F7F
    lea     eax, [rdi + rax]   # ADD = OR when no bits intersect
    ret

Perfect, exactly what I hoped for.

A couple test cases:

uint32_t test_rol() {
    return per_byte_bit_left_rotate(0x02ffff04, 0);
}
    // yup, returns the input with count=0
    // return 0x2FFFF04


uint32_t test2_rol() {
    return per_byte_bit_left_rotate(0x02f73804, 4);
}
    // yup, swaps nibbles
    // return 0x207F8340

This is the same kind of thing you need to do for 8-bit shifts with x86 SSE2 / AVX2, because the smallest shift granularity the hardware supports is 16-bit.
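For comparison, a sketch of the same shift-and-mask trick with SSE2 intrinsics from <emmintrin.h>, assuming an x86 or x86-64 target (SSE2 is baseline on x86-64). It rotates all 16 bytes of a register left by a fixed count of 3, using 16-bit shifts and byte masks; rotl3_epi8 and rotl3_bytes are hypothetical names.

```cpp
#include <cassert>
#include <cstdint>
#include <emmintrin.h>  // SSE2 intrinsics; assumes an x86/x86-64 target

// Rotate each of the 16 bytes in v left by 3. SSE2 has no 8-bit shifts, so
// shift the 16-bit lanes and mask off the bits that crossed into the
// neighbouring byte.
__m128i rotl3_epi8(__m128i v) {
    // Left part: keep bits 3..7 of each byte (its own low 5 bits, moved up).
    __m128i left  = _mm_and_si128(_mm_slli_epi16(v, 3), _mm_set1_epi8(static_cast<char>(0xF8)));
    // Right part: keep bits 0..2 of each byte (its own high 3 bits, moved down).
    __m128i right = _mm_and_si128(_mm_srli_epi16(v, 5), _mm_set1_epi8(0x07));
    return _mm_or_si128(left, right);
}

// Convenience wrapper: rotate 16 bytes in memory.
void rotl3_bytes(const uint8_t in[16], uint8_t out[16]) {
    __m128i v = _mm_loadu_si128(reinterpret_cast<const __m128i*>(in));
    _mm_storeu_si128(reinterpret_cast<__m128i*>(out), rotl3_epi8(v));
}
```

Each AND discards exactly the bits that the 16-bit shift moved across a byte boundary, which is the role lmask and rmask play in the 32-bit version.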

Peter Cordes answered Nov 11 '22 at 15:11
