Logo Questions Linux Laravel Mysql Ubuntu Git Menu
 

How can I efficiently do bitwise majority voting on 3, 5, 7, or 9 inputs with SSE/SSE2/AVX/...?

I have several (e.g. 3, 5, 7 or 9) equally sized huge blocks of data (e.g. 100KB-100MB), and want to do bitwise majority voting, to get just one block in which each bit holds the value that occurs most frequently at that bit position across the input blocks. To speed this up, I would like to use SSE/SSE2/AVX/NEON/... CPU extensions.

I just tried it manually bitwise, and the result was very slow.

like image 524
Philipp Gühring Avatar asked Oct 26 '25 14:10

Philipp Gühring


1 Answers

If the number of source data streams is 3, 5, 7, or 9, one can take advantage of proven optimal or near optimal computation of majority-of-n via a majority-of-3 primitive, a.k.a. the ternary median operator ⟨xyz⟩. Knuth, TAOCP Vol. 4a, points out that any monotone Boolean function can be expressed entirely in terms of the ternary median operator and the constants 0 and 1.

Recent literature (see comments in code below) shows how to construct majority-of-7 in a proven optimal way from majority-of-3, requiring seven instances of the latter. The optimal construction of majority-of-9 in this way is still an open research problem, but a fairly efficient construction using thirteen majority-of-3 instances was found recently. The ISO-C99 code below was used to explore this approach.

Compiled with recent x86-64 toolchains and -march=skylake-avx512 the data throughput achieved is decent (tens of GB/sec) when run on recent x86-64 platforms, but not yet approaching system memory throughput, which would be the ultimate goal. The reason for this is that compilers are not yet capable of reliably mapping the majority-of-3 primitive to the vpternlogq instruction available with AVX512, where a majority-of-3 operation is expressible with exactly one such instruction. One would have to work around this by use of intrinsics or use of inline assembly.

#include <stdio.h>
#include <stdlib.h>
#include <stdint.h>

/* Benchmark configuration. */
/* Number of input streams voted on; selects which majN routine main() uses. */
#define NBR_STREAMS (9)         // 3, 5, 7, or 9
/* Number of 64-bit words allocated for, and processed in, each stream. */
#define NBR_BLOCKS  (10000000)
/* Number of times the timed loop is repeated; the reported throughput is
   derived from the timestamps of the final repetition. */
#define BENCH_ITERS (3)

#if defined(_WIN32)
#if !defined(WIN32_LEAN_AND_MEAN)
#define WIN32_LEAN_AND_MEAN
#endif
#include <windows.h>
double second (void)
{
    LARGE_INTEGER t;
    static double oofreq;
    static int checkedForHighResTimer;
    static BOOL hasHighResTimer;

    if (!checkedForHighResTimer) {
        hasHighResTimer = QueryPerformanceFrequency (&t);
        oofreq = 1.0 / (double)t.QuadPart;
        checkedForHighResTimer = 1;
    }
    if (hasHighResTimer) {
        QueryPerformanceCounter (&t);
        return (double)t.QuadPart * oofreq;
    } else {
        return (double)GetTickCount() * 1.0e-3;
    }
}
#elif defined(__linux__) || defined(__APPLE__)
#include <stddef.h>
#include <sys/time.h>
/*
  Return the current time in seconds as a double, combining the whole-second
  and microsecond fields reported by gettimeofday(). Intended for measuring
  elapsed time by subtracting two readings.
*/
double second (void)
{
    struct timeval now;

    gettimeofday (&now, NULL);
    /* microseconds are scaled to seconds and added to the whole seconds */
    return (double)now.tv_usec * 1.0e-6 + (double)now.tv_sec;
}
#else
#error unsupported platform
#endif

/*
  Bitwise majority-of-3 (ternary median): each bit of the result is set
  exactly when at least two of the corresponding bits of a, b, c are set.
  All 64 bit positions are processed independently.
*/
uint64_t maj3 (uint64_t a, uint64_t b, uint64_t c)
{
    uint64_t both_bc = b & c;     /* b and c agree on 1: result is 1 */
    uint64_t either_bc = b | c;   /* at least one of b, c is 1: a decides */
    return both_bc | (a & either_bc);
}

/*
  Bitwise majority-of-5 built from four majority-of-3 operations.
  Construction from Knuth, TAOCP Vol. 4a, p. 64.
*/
uint64_t maj5 (uint64_t a, uint64_t b, uint64_t c, uint64_t d, uint64_t e)
{
    uint64_t m_cde = maj3 (c, d, e);
    uint64_t m_bde = maj3 (b, d, e);
    uint64_t m_bcx = maj3 (b, c, m_bde);
    return maj3 (a, m_cde, m_bcx);
}

/*
  Bitwise majority-of-7 built from seven majority-of-3 operations.

  Construction from:
  Eleonora Testa, et al., "Mapping Monotone Boolean Functions into Majority."
  IEEE Transactions on Computers, Vol. 68, No. 5, May 2019, pp. 791-797.
*/
uint64_t maj7 (uint64_t a, uint64_t b, uint64_t c, uint64_t d,
               uint64_t e, uint64_t f, uint64_t g)
{
    /* first level: two disjoint triples */
    uint64_t m_acd = maj3 (a, c, d);
    uint64_t m_efg = maj3 (e, f, g);
    /* second level: each triple result is mixed with members of the other */
    uint64_t m_fg_acd = maj3 (f, g, m_acd);
    uint64_t m_ac_efg = maj3 (a, c, m_efg);
    uint64_t lhs = maj3 (e, m_acd, m_fg_acd);
    uint64_t rhs = maj3 (d, m_efg, m_ac_efg);
    return maj3 (b, lhs, rhs);
}

/*
  Bitwise majority-of-9 built from thirteen majority-of-3 operations.

  Construction from:
  Thomas Häner, Damian S. Steiger, Helmut G. Katzgraber, "Parallel Tempering
  for Logic Synthesis." arXiv.2311.12394, Nov. 21, 2023
*/
uint64_t maj9 (uint64_t a, uint64_t b, uint64_t c, uint64_t d, uint64_t e,
               uint64_t f, uint64_t g, uint64_t h, uint64_t i)
{
    /* gates that depend only on the inputs */
    uint64_t n1 = maj3 (g, d, c);
    uint64_t n2 = maj3 (g, e, b);
    uint64_t n3 = maj3 (i, f, a);
    uint64_t n6 = maj3 (c, d, h);
    /* gates that depend on earlier gates */
    uint64_t n4 = maj3 (n1, n2, h);
    uint64_t n5 = maj3 (d, h, n3);
    uint64_t n7 = maj3 (i, a, n4);
    uint64_t n8 = maj3 (c, n5, n3);
    uint64_t n9 = maj3 (n8, e, g);
    /* the three operands of the output gate */
    uint64_t n10 = maj3 (n7, n4, f);
    uint64_t n11 = maj3 (n9, b, n8);
    uint64_t n12 = maj3 (n2, n6, n3);
    return maj3 (n10, n11, n12);
}

/*
  https://groups.google.com/forum/#!original/comp.lang.c/qFv18ql_WlU/IK8KGZZFJx4J
  From: geo <[email protected]>
  Newsgroups: sci.math,comp.lang.c,comp.lang.fortran
  Subject: 64-bit KISS RNGs
  Date: Sat, 28 Feb 2009 04:30:48 -0800 (PST)

  This 64-bit KISS RNG has three components, each nearly
  good enough to serve alone.    The components are:
  Multiply-With-Carry (MWC), period (2^121+2^63-1)
  Xorshift (XSH), period 2^64-1
  Congruential (CNG), period 2^64
*/
/* Generator state: kiss64_x and kiss64_c belong to the multiply-with-carry
   component, kiss64_y to the xorshift component, kiss64_z to the congruential
   component. kiss64_t is scratch storage used inside MWC64. */
static uint64_t kiss64_x = 1234567890987654321ULL;
static uint64_t kiss64_c = 123456123456123456ULL;
static uint64_t kiss64_y = 362436362436362436ULL;
static uint64_t kiss64_z = 1066149217761810ULL;
static uint64_t kiss64_t;
/* MWC64: comma expression that advances kiss64_x / kiss64_c (using kiss64_t
   as a temporary) and evaluates to the updated kiss64_x. The comparison
   (kiss64_x < kiss64_t) yields 1 when the preceding addition wrapped around. */
#define MWC64  (kiss64_t = (kiss64_x << 58) + kiss64_c, \
                kiss64_c = (kiss64_x >> 6), kiss64_x += kiss64_t, \
                kiss64_c += (kiss64_x < kiss64_t), kiss64_x)
/* XSH64: three shift-and-xor steps on kiss64_y; evaluates to the updated
   kiss64_y. */
#define XSH64  (kiss64_y ^= (kiss64_y << 13), kiss64_y ^= (kiss64_y >> 17), \
                kiss64_y ^= (kiss64_y << 43))
/* CNG64: one multiply-and-add step on kiss64_z (unsigned arithmetic wraps
   modulo 2^64); evaluates to the updated kiss64_z. */
#define CNG64  (kiss64_z = 6906969069ULL * kiss64_z + 1234567ULL)
/* KISS64: advances all three components and evaluates to the sum of their
   new values. Each component modifies a disjoint set of state variables. */
#define KISS64 (MWC64 + XSH64 + CNG64)

/*
  Benchmark and self-test driver.

  1. Allocates NBR_STREAMS input streams plus one output stream, each holding
     NBR_BLOCKS 64-bit words, and fills all of them with KISS64 output.
  2. Computes the output stream as the bitwise majority of the input streams,
     BENCH_ITERS times, and reports the throughput of the final repetition
     (bytes of all streams, inputs plus output, divided by elapsed time).
  3. Verifies every output bit against a straightforward bit-counting
     reference.

  Returns EXIT_SUCCESS when the check passes, EXIT_FAILURE on allocation
  failure or on the first mismatching bit. All buffers are released on every
  exit path.
*/
int main (void)
{
    double start = 0.0, stop = 0.0, elapsed, datasize;
    uint64_t *stream[NBR_STREAMS+1] = { NULL };
    int status = EXIT_FAILURE;

    /* generate test data */
    printf ("starting data generation\n"); fflush(stdout);
    for (int i = 0; i <= NBR_STREAMS; i++) {
        stream[i] = malloc (sizeof (*stream[i]) * NBR_BLOCKS);
        if (stream[i] == NULL) {
            printf ("allocation for stream %d failed\n", i);
            goto cleanup;
        }
        /* the output stream (index NBR_STREAMS) is filled as well; it is
           overwritten by the timed loop below */
        for (int j = 0; j < NBR_BLOCKS; j++) {
            stream[i][j] = KISS64;
        }
    }
    printf ("data generation complete\n");
    /* compute bits of output stream as majority of bits of input streams */
    printf ("generate output stream; timed portion of code\n"); fflush(stdout);
    for (int n = 0; n < BENCH_ITERS; n++) {
        /* start/stop are overwritten each repetition, so only the last
           repetition contributes to the reported throughput */
        start = second();
        for (int j = 0; j < NBR_BLOCKS; j++) {
#if (NBR_STREAMS == 3)
            stream[3][j] = maj3 (stream[0][j], stream[1][j], stream[2][j]);
#elif (NBR_STREAMS == 5)
            stream[5][j] = maj5 (stream[0][j], stream[1][j], stream[2][j],
                                 stream[3][j], stream[4][j]);
#elif (NBR_STREAMS == 7)
            stream[7][j] = maj7 (stream[0][j], stream[1][j], stream[2][j],
                                 stream[3][j], stream[4][j], stream[5][j],
                                 stream[6][j]);
#elif (NBR_STREAMS == 9)
            stream[9][j] = maj9 (stream[0][j], stream[1][j], stream[2][j],
                                 stream[3][j], stream[4][j], stream[5][j],
                                 stream[6][j], stream[7][j], stream[8][j]);
#else
#error unsupported N
#endif
        }
        stop = second();
    }
    elapsed = stop - start;
    datasize = (double)sizeof (uint64_t) * NBR_BLOCKS * (NBR_STREAMS + 1);
    if (elapsed > 0.0) {
        printf ("processed at %.3f GB/sec\n", datasize * 1e-9 / elapsed);
    } else {
        /* timer resolution too coarse for this run; avoid dividing by zero */
        printf ("elapsed time too short to measure throughput\n");
    }
    printf ("checking output stream\n"); fflush(stdout);
    /* check result stream, the slow way */
    for (int j = 0; j < NBR_BLOCKS; j++) {
        uint64_t t[NBR_STREAMS+1];
        for (int i = 0; i <= NBR_STREAMS; i++) {
            t[i] = stream[i][j];
        }
        for (int k = 0; k < 64; k++) {
            int bitcount = 0;
            for (int i = 0; i < NBR_STREAMS; i++) {
                bitcount += (int)((t[i] >> k) & 1);
            }
            /* reference: bit is set when more than half the inputs have it */
            int reference = bitcount > (NBR_STREAMS / 2);
            /* result: the bit actually produced by the majN routine */
            int result = (int)((t[NBR_STREAMS] >> k) & 1);
            if (result != reference) {
                printf ("error at block %d bit %d res=%d ref=%d\n",
                        j, k, result, reference);
                goto cleanup;
            }
        }
    }
    printf ("test passed\n");
    status = EXIT_SUCCESS;

cleanup:
    /* free(NULL) is a no-op, so partially allocated sets are handled too */
    for (int i = 0; i <= NBR_STREAMS; i++) {
        free (stream[i]);
    }
    return status;
}
like image 65
njuffa Avatar answered Oct 29 '25 07:10

njuffa