Fast bitwise operations on a long

I know I can write the following method to compute the indexes of the set bits of a long:

private static List<Integer> bitPositions(long number) {
    final List<Integer> positions = new ArrayList<>();
    int position = 1;
    while (number != 0) {
        if ((number & 1L) != 0) {
            positions.add(position);
        }
        position++;
        number = number >>> 1;
    }
    return positions;
}

My question is: is there a faster way to do this?

Amir Afghani asked Feb 05 '23 09:02


2 Answers

Fastest method

BitBank's answer to this question is about twice as fast as the two methods further down in this answer. The method below borrows the idea in BitBank's answer and runs about 73% faster than it on my machine (about 9 times faster than the question's method). Instead of right-shifting each one bit off the right end and keeping track of how much shifting has occurred, it uses bit twiddling to repeatedly turn off the least-significant one bit.

private static final byte[] bitPositions(long n) {
    final byte[] result = new byte[Long.bitCount(n)];

    for (int i = 0; n != 0L; i++) {
        result[i] = (byte) ((byte) Long.numberOfTrailingZeros(n) + 1);
        n &= n - 1L;  // Change least-significant one bit to a zero bit.
    }

    return result;
}

Improvements on BitBank's answer

  • No need to keep track of how many bits we've skipped past.
  • Quickly turns the last one bit to a zero bit.
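  • The double cast to byte sped things up slightly on my machine. I don't have a solid explanation for this: Java promotes the byte back to an int before the addition, so the arithmetic itself is still int-sized. Treat it as a measurement rather than a rule.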
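
To see why n &= n - 1L turns off exactly the least-significant one bit, it can help to trace it on a small value. The sketch below is only an illustration (the class name ClearLowestBitDemo and the helper trace are made up for this example). It records n in binary before each step, together with the one-based position that would be stored.

```java
import java.util.ArrayList;
import java.util.List;

public class ClearLowestBitDemo {
    // Records the binary form of n before each loop step, paired with the
    // one-based position of the bit that the step is about to clear.
    static List<String> trace(long n) {
        final List<String> steps = new ArrayList<>();
        while (n != 0L) {
            final int position = Long.numberOfTrailingZeros(n) + 1;
            steps.add(Long.toBinaryString(n) + " -> position " + position);
            // n - 1 flips the lowest one bit to zero and the zeros below it to ones;
            // ANDing with n keeps only the bits above the lowest one bit.
            n &= n - 1L;
        }
        return steps;
    }

    public static void main(String[] args) {
        // 0b101100 has bits set at one-based positions 3, 4 and 6.
        for (String step : trace(0b101100L)) {
            System.out.println(step);
        }
    }
}
```

For 0b101100 the loop runs three times, once per set bit, regardless of how many zero bits sit between them.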

Manual fusion

As Durandal pointed out in the question's comments, you can trade the following:

for (int bitPosition : bitPositions(n)) {
    // Do something with `bitPosition`.
}

for a style that skips the method call and does this instead:

long temp = n;
while (temp != 0L) {
    int bitPosition = Long.numberOfTrailingZeros(temp) + 1;
    temp &= temp - 1L;  // Change least-significant one bit to a zero bit.

    // Do something with `bitPosition`.
}

Benefits of fusion

  • No time wasted calling a method.
  • No need to create or garbage collect an array, saving time and memory.
  • The bit position can probably remain in a very fast CPU register the whole time you're using it rather than potentially needing to write it to an array in RAM (which is much slower) and then to read it back from RAM later on.

Drawbacks of fusion

  • It's a bit uglier than making a clearly-named method call and cleanly using an array of results.
  • If you have multiple places in your code where you need to compute the bit positions of a number, you have to repeat code in each place (violates DRY).
  • If you want to iterate multiple separate times through the bit positions of the same number, you have to compute the bit positions over again rather than reusing a previously generated array.

    This might not be an actual drawback, though, if computing a bit position anew is quicker than loading a precomputed one from an array in RAM.
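
One possible middle ground for the DRY drawback, offered as an untested assumption rather than something measured in this answer, is to wrap the fused loop in a small helper that takes a callback. The name forEachBitPosition and the use of java.util.function.IntConsumer are illustrative choices. Whether the JIT inlines the lambda well enough to keep the benefits of fusion is something you would have to benchmark yourself.

```java
import java.util.function.IntConsumer;

public class BitPositionsCallback {
    // Hypothetical helper: calls action once per set bit, lowest bit first,
    // passing the one-based bit position. No array is allocated.
    static void forEachBitPosition(long n, IntConsumer action) {
        while (n != 0L) {
            action.accept(Long.numberOfTrailingZeros(n) + 1);
            n &= n - 1L;  // Change least-significant one bit to a zero bit.
        }
    }

    public static void main(String[] args) {
        final StringBuilder seen = new StringBuilder();
        // 0b10010001 has bits set at one-based positions 1, 5 and 8.
        forEachBitPosition(0b10010001L, position -> seen.append(position).append(' '));
        System.out.println(seen.toString().trim());  // prints "1 5 8"
    }
}
```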


Slowest method

Here's a method that produces the same results as the question's method (just in a byte[] instead of a List<Integer>) and runs about twice as fast as it:

private static final byte[] bitPositions(long n) {
    final byte[] result = new byte[Long.bitCount(n)];

    int i = 0;
    for (byte bit = 1; n != 0L; bit++) {
        if ((n & 1L) != 0) result[i++] = bit;
        n >>>= 1;
    }

    return result;
}

I'd recommend changing byte bit = 1 in the for loop to byte bit = 0 to switch to the traditional method of numbering bit positions starting with zero instead of one.

Improvements

  • Precomputing the needed capacity with Long.bitCount(n) (which typically compiles down to your processor's very speedy "popcount" instruction) speeds up the method quite a bit. You can apply the same idea to your method by creating the list with new ArrayList<>(Long.bitCount(n)).
  • Using an ArrayList<Integer> is slower than a byte[], because:
    • Time must be wasted looking up low-valued (-128 to 127) Integer values from the Integer cache to put them into the ArrayList.
    • Time must be wasted when using the ints stored in the resulting List<Integer> later on because you have to both retrieve the Integer from the List<Integer> and then retrieve the int from the Integer.
  • A byte[] uses about 1/4th (32-bit system) or 1/8th (64-bit system) the memory of an ArrayList<Integer>, since bytes are that much smaller than pointers to Integers.
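
As a sketch of the first point, here is the question's method with only the list capacity changed. The class name PresizedBitPositions is illustrative. The loop itself is the original one, so this keeps the boxing costs described above and only avoids regrowing the list.

```java
import java.util.ArrayList;
import java.util.List;

public class PresizedBitPositions {
    static List<Integer> bitPositions(long number) {
        // Allocate the backing array once at its final size, so the list never has to grow.
        final List<Integer> positions = new ArrayList<>(Long.bitCount(number));
        int position = 1;
        while (number != 0) {
            if ((number & 1L) != 0) {
                positions.add(position);  // Autoboxes the int to an Integer.
            }
            position++;
            number >>>= 1;
        }
        return positions;
    }

    public static void main(String[] args) {
        System.out.println(bitPositions(0b1010L));  // prints [2, 4]
    }
}
```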

Slightly faster than slowest method, but uglier

As suggested by another person's deleted answer, loop unrolling speeds things up a little bit further on my machine (check whether that's true on your machine as well):

private static final byte[] bitPositions(final long n) {
    final byte[] result = new byte[Long.bitCount(n)];

    int i = 0;
    if ((n &                    1L) != 0L) result[i++] = 1;
    if ((n &                    2L) != 0L) result[i++] = 2;
    if ((n &                    4L) != 0L) result[i++] = 3;
    if ((n &                    8L) != 0L) result[i++] = 4;
    if ((n &                   16L) != 0L) result[i++] = 5;
    if ((n &                   32L) != 0L) result[i++] = 6;
    if ((n &                   64L) != 0L) result[i++] = 7;
    if ((n &                  128L) != 0L) result[i++] = 8;
    if ((n &                  256L) != 0L) result[i++] = 9;
    if ((n &                  512L) != 0L) result[i++] = 10;
    if ((n &                 1024L) != 0L) result[i++] = 11;
    if ((n &                 2048L) != 0L) result[i++] = 12;
    if ((n &                 4096L) != 0L) result[i++] = 13;
    if ((n &                 8192L) != 0L) result[i++] = 14;
    if ((n &                16384L) != 0L) result[i++] = 15;
    if ((n &                32768L) != 0L) result[i++] = 16;
    if ((n &                65536L) != 0L) result[i++] = 17;
    if ((n &               131072L) != 0L) result[i++] = 18;
    if ((n &               262144L) != 0L) result[i++] = 19;
    if ((n &               524288L) != 0L) result[i++] = 20;
    if ((n &              1048576L) != 0L) result[i++] = 21;
    if ((n &              2097152L) != 0L) result[i++] = 22;
    if ((n &              4194304L) != 0L) result[i++] = 23;
    if ((n &              8388608L) != 0L) result[i++] = 24;
    if ((n &             16777216L) != 0L) result[i++] = 25;
    if ((n &             33554432L) != 0L) result[i++] = 26;
    if ((n &             67108864L) != 0L) result[i++] = 27;
    if ((n &            134217728L) != 0L) result[i++] = 28;
    if ((n &            268435456L) != 0L) result[i++] = 29;
    if ((n &            536870912L) != 0L) result[i++] = 30;
    if ((n &           1073741824L) != 0L) result[i++] = 31;
    if ((n &           2147483648L) != 0L) result[i++] = 32;
    if ((n &           4294967296L) != 0L) result[i++] = 33;
    if ((n &           8589934592L) != 0L) result[i++] = 34;
    if ((n &          17179869184L) != 0L) result[i++] = 35;
    if ((n &          34359738368L) != 0L) result[i++] = 36;
    if ((n &          68719476736L) != 0L) result[i++] = 37;
    if ((n &         137438953472L) != 0L) result[i++] = 38;
    if ((n &         274877906944L) != 0L) result[i++] = 39;
    if ((n &         549755813888L) != 0L) result[i++] = 40;
    if ((n &        1099511627776L) != 0L) result[i++] = 41;
    if ((n &        2199023255552L) != 0L) result[i++] = 42;
    if ((n &        4398046511104L) != 0L) result[i++] = 43;
    if ((n &        8796093022208L) != 0L) result[i++] = 44;
    if ((n &       17592186044416L) != 0L) result[i++] = 45;
    if ((n &       35184372088832L) != 0L) result[i++] = 46;
    if ((n &       70368744177664L) != 0L) result[i++] = 47;
    if ((n &      140737488355328L) != 0L) result[i++] = 48;
    if ((n &      281474976710656L) != 0L) result[i++] = 49;
    if ((n &      562949953421312L) != 0L) result[i++] = 50;
    if ((n &     1125899906842624L) != 0L) result[i++] = 51;
    if ((n &     2251799813685248L) != 0L) result[i++] = 52;
    if ((n &     4503599627370496L) != 0L) result[i++] = 53;
    if ((n &     9007199254740992L) != 0L) result[i++] = 54;
    if ((n &    18014398509481984L) != 0L) result[i++] = 55;
    if ((n &    36028797018963968L) != 0L) result[i++] = 56;
    if ((n &    72057594037927936L) != 0L) result[i++] = 57;
    if ((n &   144115188075855872L) != 0L) result[i++] = 58;
    if ((n &   288230376151711744L) != 0L) result[i++] = 59;
    if ((n &   576460752303423488L) != 0L) result[i++] = 60;
    if ((n &  1152921504606846976L) != 0L) result[i++] = 61;
    if ((n &  2305843009213693952L) != 0L) result[i++] = 62;
    if ((n &  4611686018427387904L) != 0L) result[i++] = 63;
    if ((n & -9223372036854775808L) != 0L) result[i++] = 64;

    return result;
}

You can also change this to start counting the bit positions at zero instead of one.

Improvements

  • Avoids the need to repeatedly perform >>> on the number.
  • Avoids the need to repeatedly perform ++ on the bit position.
  • Avoids needing to check whether the number has reached zero.
  • Avoids some branch mispredictions.
Chai T. Rex answered Feb 19 '23 14:02


If you don't mind using intrinsics, you can have an even faster version. Long.numberOfTrailingZeros() will use the CPU intrinsic that counts the number of consecutive zero bits starting from the least-significant bit (the BSF instruction on x86 processors, or TZCNT where the CPU and JVM support it).

For a sparse value, this will be faster than all other looping methods because it doesn't have any conditionals or branches within the main loop, it skips runs of any number of 0's with a single iteration, and, for a 64-bit long, the BSF intrinsic has a latency of only 3 clock cycles on Intel Haswell CPUs.

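private static final byte[] bitPositions(long n) {
    final byte[] result = new byte[Long.bitCount(n)];

    byte bitPosition = 0;
    for (int i = 0; n != 0L; i++) {
        final int trailingZeros = Long.numberOfTrailingZeros(n);
        // Shift in two steps. Java uses only the low 6 bits of a long shift distance,
        // so a single shift by 64 (needed when n == Long.MIN_VALUE) would leave n unchanged.
        n = (n >>> trailingZeros) >>> 1;
        bitPosition += trailingZeros + 1;  // One-based position of the bit just shifted out.
        result[i] = bitPosition;
    }

    return result;
}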
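
To make the "skips runs of zeros" point concrete, the sketch below counts loop iterations for a shift-by-one loop and for a trailing-zeros loop on a sparse value. The class and method names are made up for this illustration, and iteration counts are only a rough proxy for running time, so measure on your own machine before drawing conclusions.

```java
public class IterationCountDemo {
    // Iterations of a loop that shifts right one bit at a time.
    static int shiftByOneIterations(long n) {
        int iterations = 0;
        while (n != 0L) {
            n >>>= 1;
            iterations++;
        }
        return iterations;
    }

    // Iterations of a loop that jumps straight to each set bit.
    static int trailingZerosIterations(long n) {
        int iterations = 0;
        while (n != 0L) {
            final int trailingZeros = Long.numberOfTrailingZeros(n);
            // Two shifts, so a total distance of 64 still clears n.
            n = (n >>> trailingZeros) >>> 1;
            iterations++;
        }
        return iterations;
    }

    public static void main(String[] args) {
        final long sparse = (1L << 62) | 1L;  // Only bits 0 and 62 are set.
        System.out.println(shiftByOneIterations(sparse));     // prints 63
        System.out.println(trailingZerosIterations(sparse));  // prints 2
    }
}
```

The trailing-zeros loop runs once per set bit, while the shift-by-one loop runs once per bit up to and including the highest set bit.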
BitBank answered Feb 19 '23 16:02