Logo Questions Linux Laravel Mysql Ubuntu Git Menu
 

Extract bit sequences of arbitrary length from byte[] array efficiently

I'm looking for the most efficient way of extracting (unsigned) bit sequences of arbitrary length (0 <= length <= 16) at an arbitrary position. The skeleton class shows how my current implementation essentially handles the problem:

/**
 * Micro-benchmark comparing ways to extract unsigned bit sequences of 0..16 bits
 * from a byte buffer, most significant bit first.
 *
 * <p>Each nested {@code VersionN} supplies {@link #prepareBitGet(int, int)}, which is
 * called once after the buffer has been filled and may do nothing, and
 * {@link #getBits(int)}, which returns the next bits and advances {@link #bitGet}.
 * Reads are not bounds-checked.
 */
public abstract class BitArray {

    /** Input buffer; bit 0 is the most significant bit of {@code bytes[0]}. */
    byte[] bytes = new byte[2048];

    /** Read cursor; each version decides exactly how it encodes the current bit position. */
    int bitGet;

    public BitArray() {
    }

    /**
     * Fills {@code bytes[initialBitGet >> 3] .. bytes[count]} with synthetic data (a
     * stand-in for reading from an input stream), then calls {@link #prepareBitGet(int, int)}.
     *
     * @param initialBitGet bit index reading will start from
     * @param count         index of the last byte to fill (inclusive)
     */
    public void readNextBlock(int initialBitGet, int count) {
        // substitute for reading from an input stream
        for (int i = (initialBitGet >> 3); i <= count; ++i) {
            bytes[i] = (byte) i;
        }
        prepareBitGet(initialBitGet, count);
    }

    /**
     * Positions the read cursor at {@code initialBitGet} and performs any preprocessing of
     * the buffer that {@link #getBits(int)} relies on.
     *
     * @param initialBitGet bit index of the first bit to return
     * @param count         index of the last filled byte in {@code bytes}
     */
    public abstract void prepareBitGet(int initialBitGet, int count);

    /**
     * Returns the next {@code count} bits (0..16) as a non-negative int and advances the
     * read cursor by that many bits.
     */
    public abstract int getBits(int count);

    /** Baseline: only advances the cursor and always returns 0, to measure harness overhead. */
    static class Version0 extends BitArray {
        @Override
        public void prepareBitGet(int initialBitGet, int count) {
            bitGet = initialBitGet;
        }

        @Override
        public int getBits(int len) {
            // intentionally gives meaningless result
            bitGet += len;
            return 0;
        }
    }

    /**
     * Reads a 24-bit big-endian window of three bytes per call. The cursor is stored as
     * (bit position - 1). Combined with the constant 23 (instead of 24), this makes the
     * window start one byte early when the position is byte-aligned. The requested bits
     * therefore always fit in the window.
     * NOTE(review): with initialBitGet == 0 the first call indexes bytes[-1]; callers must
     * start at bit 1 or later.
     */
    static class Version1 extends BitArray {
        @Override
        public void prepareBitGet(int initialBitGet, int count) {
            bitGet = initialBitGet - 1;
        }

        @Override
        public int getBits(int len) {
            int byteIndex = bitGet; // holds (bit position - 1) until shifted below
            bitGet = byteIndex + len;
            int shift = 23 - (byteIndex & 7) - len;
            int mask = (1 << len) - 1;
            byteIndex >>= 3;
            return (((bytes[byteIndex] << 16) |
                    ((bytes[++byteIndex] & 0xFF) << 8) |
                     (bytes[++byteIndex] & 0xFF)) >> shift) & mask;
        }
    }

    /**
     * Reads one, two or three bytes depending on how far the requested bits extend past
     * the starting byte. Only bytes that actually contain requested bits are touched.
     */
    static class Version2 extends BitArray {
        static final int[] mask = { 0x0, 0x1, 0x3, 0x7, 0xF, 0x1F, 0x3F, 0x7F, 0xFF,
                0x1FF, 0x3FF, 0x7FF, 0xFFF, 0x1FFF, 0x3FFF, 0x7FFF, 0xFFFF };

        @Override
        public void prepareBitGet(int initialBitGet, int count) {
            bitGet = initialBitGet;
        }

        @Override
        public int getBits(int len) {
            int offset = bitGet;
            bitGet = offset + len;
            int byteIndex = offset >> 3; // originally used /8
            int bitIndex = offset & 7;   // originally used %8
            // One past the last requested bit, measured from the MSB of bytes[byteIndex].
            // The branches must test this (not the absolute offset); otherwise a 1-byte
            // read at offset >= 8 still touches the next byte and can run past the array.
            int end = bitIndex + len;
            if (end > 16) {
                return ((bytes[byteIndex] << 16 |
                        (bytes[byteIndex + 1] & 0xFF) << 8 |
                        (bytes[byteIndex + 2] & 0xFF)) >> (24 - end)) & mask[len];
            } else if (end > 8) {
                return ((bytes[byteIndex] << 8 |
                        (bytes[byteIndex + 1] & 0xFF)) >> (16 - end)) & mask[len];
            } else {
                return (bytes[byteIndex] >> (8 - end)) & mask[len];
            }
        }
    }

    /**
     * Precomputes {@code ints[k] = bytes[k..k+3]} as a big-endian int for every start
     * byte up to about {@code count}. getBits then needs a single array load per call.
     */
    static class Version3 extends BitArray {
        int[] ints = new int[2048];

        @Override
        public void prepareBitGet(int initialBitGet, int count) {
            bitGet = initialBitGet;
            int put_i = (initialBitGet >> 3) - 1;
            int get_i = put_i;
            int buf;
            // prime the sliding window with the first three bytes
            buf = ((bytes[++get_i] & 0xFF) << 16) |
                  ((bytes[++get_i] & 0xFF) <<  8) |
                   (bytes[++get_i] & 0xFF);
            do {
                buf = (buf << 8) | (bytes[++get_i] & 0xFF);
                ints[++put_i] = buf;
            } while (get_i < count);
        }

        @Override
        public int getBits(int len) {
            int bit_idx = bitGet;
            bitGet = bit_idx + len;
            int shift = 32 - (bit_idx & 7) - len;
            int mask = (1 << len) - 1;
            int int_idx = bit_idx >> 3;
            return (ints[int_idx] >> shift) & mask;
        }
    }

    /**
     * Like Version3, but stores one int per 16 bits: {@code ints[k] = bytes[2k..2k+3]}.
     * This halves the preprocessing work and the buffer size.
     */
    static class Version4 extends BitArray {
        int[] ints = new int[1024];

        @Override
        public void prepareBitGet(int initialBitGet, int count) {
            bitGet = initialBitGet;
            // Start at the even byte of the 16-bit unit containing initialBitGet so that
            // ints[initialBitGet >> 4] really holds bytes[2k..2k+3]. Starting at
            // initialBitGet >> 3 would shift every int by one byte whenever that index is odd.
            int g = (initialBitGet >> 4) << 1;
            int p = (initialBitGet >> 4) - 1;
            final byte[] b = bytes;
            // the sign bits of b[g] are shifted out by the first (t << 16)
            int t = (b[g] << 8) | (b[++g] & 0xFF);
            final int[] i = ints;
            do {
                i[++p] = (t = (t << 16) | ((b[++g] & 0xFF) << 8) | (b[++g] & 0xFF));
            } while (g < count);
        }

        @Override
        public int getBits(final int len) {
            final int i;
            bitGet = (i = bitGet) + len;
            return (ints[i >> 4] >> (32 - len - (i & 15))) & ((1 << len) - 1);
        }
    }

    /**
     * Fills the buffer once, then times 2^18 passes. Each pass calls prepareBitGet once
     * and getBits 2048 times, with lengths cycling 15..0. Prints the elapsed
     * milliseconds and a checksum, which keeps the results from being optimized away.
     */
    public void benchmark(String label) {
        int checksum = 0;
        readNextBlock(32, 1927);
        long time = System.nanoTime();
        for (int pass = 1 << 18; pass > 0; --pass) {
            prepareBitGet(32, 1927);
            for (int i = 2047; i >= 0; --i) {
                checksum += getBits(i & 15);
            }
        }
        time = System.nanoTime() - time;
        System.out.println(label+" took "+Math.round(time/1E6D)+" ms, checksum="+checksum);
        try { // avoid having the console interfere with our next measurement
            Thread.sleep(369);
        } catch (InterruptedException e) {
            // don't swallow the interrupt: restore the flag so callers can observe it
            Thread.currentThread().interrupt();
        }
    }

    /** Runs all versions' benchmarks forever, one round after another. */
    public static void main(String[] argv) {
        BitArray test;
        // for the sake of getting a little less influence from the OS for stable measurement
        Thread.currentThread().setPriority(Thread.MAX_PRIORITY);
        while (true) {
            test = new Version0();
            test.benchmark("no implementaion");
            test = new Version1();
            test.benchmark("Durandal's (original)");
            test = new Version2();
            test.benchmark("blitzpasta's (adapted)");
            test = new Version3();
            test.benchmark("MSN's (posted)");
            test = new Version4();
            test.benchmark("MSN's (half-buffer modification)");
            System.out.println("--- next pass ---");
        }
    }
}

This works, but I'm looking for a more efficient solution (performance wise). The byte array is guaranteed to be relatively small, between a few bytes up to a max of ~1800 bytes. The array is read exactly once (completely) between each call to the read method. There is no need for any error checking in getBits(), such as exceeding the array etc.


It seems my initial question above isn't clear enough. A "bit sequence" of N bits forms an integer of N bits, and I need to extract those integers with minimal overhead. I have no use for strings, as the values are either used as lookup indices or are directly fed into some computation. So basically, the skeleton shown above is a real class and getBits() signature shows how the rest of the code interacts with it.


Extended the example code into a microbenchmark and included blitzpasta's solution (with the missing byte masking fixed). On my old AMD box it turns out as ~11400ms vs ~38000ms. FYI: It's the divide and modulo operations that kill the performance. If you replace /8 with >>3 and %8 with &7, both solutions are pretty close to each other (jdk1.7.0ea104).


There seemed to be a bit of confusion about how and what to work on. The first, original post of the example code included a read() method to indicate where and when the byte buffer was filled. This got lost when the code was turned into the microbenchmark. I re-introduced it (now readNextBlock()) to make this a little clearer. The idea is to beat all existing versions by adding another subclass of BitArray, which needs to implement getBits() and prepareBitGet(); the latter may be empty. Do not change the benchmark to give your solution an advantage; the same could be done for all the existing solutions, making this a completely moot optimization! (really!!)

I added a Version0, which does nothing but increment the bitGet state. It always returns 0, to give a rough idea of how big the benchmark overhead is. It's only there for comparison.

Also, an adaptation of MSN's idea was added (Version3). To keep things fair and comparable for all competitors, the byte array filling is now part of the benchmark, as well as a preparatory step (see above). Originally MSN's solution did not do so well; there was a lot of overhead in preparing the int[] buffer. I took the liberty of optimizing that step a little, which turned it into a fierce competitor :) You might also find that I de-convoluted your code a little. Your getBits() could be condensed into a 3-liner, probably shaving off one or two percent. I deliberately left it expanded to keep the code readable, and because the other versions aren't as condensed as possible either (again for readability).


Conclusion (the code example above has been updated to include versions based on all applicable contributions). On my old AMD box (Sun JRE 1.6.0_21), they come out as:

V0 no implementaion took 5384 ms
V1 Durandal's (original) took 10283 ms
V2 blitzpasta's (adapted) took 12212 ms
V3 MSN's (posted) took 11030 ms
V4 MSN's (half-buffer modification) took 9700 ms

Notes: In this benchmark an average of 7.5 bits is fetched per call to getBits(), and each bit is only read once. Since V3/V4 have to pay a high initialization cost, they tend to show better runtime behavior with more, shorter fetches (and consequently worse behavior the closer the average fetch size gets to the maximum of 16). Still, V4 stays slightly ahead of all others in all scenarios. In an actual application, cache contention must be taken into account, since the extra space needed for V3/V4 may increase cache misses to a point where V1 would be a better choice. If the array is to be traversed more than once, V4 should be favored, since it fetches faster than every other version and the costly initialization is amortized after the first pass.

like image 356
Durandal Avatar asked Oct 02 '10 17:10

Durandal


1 Answers

If you just want the unsigned bit sequence as an int.

static final int[] lookup = {0x0, 0x1, 0x3, 0x7, 0xF, 0x1F, 0x3F, 0x7F, 0xFF, 0x1FF, 0x3FF, 0x7FF, 0xFFF, 0x1FFF, 0x3FFF, 0x7FFF, 0xFFFF };

/**
 * Extracts an unsigned bit sequence of up to 16 bits as an int.
 *
 * <p>Only the bytes that actually contain requested bits are read. A sequence ending at
 * the last bit of the array therefore stays in bounds. Nothing else is checked for
 * overflow.
 *
 * @param bytes  byte array, with the bits indexed from 0 (MSB of bytes[0]) to
 *               (bytes.length * 8 - 1) (LSB of the last byte)
 * @param offset index of the MSB of the bit sequence; must be >= 0
 * @param len    length of the bit sequence, in the range [0, 16]
 * @return the {@code len} bits starting at {@code offset}, as a non-negative int
 */
static int getBitSeqAsInt(byte[] bytes, int offset, int len){
    if (len == 0) {
        return 0; // nothing to read; also avoids indexing when offset is at the very end
    }

    int byteIndex = offset >>> 3; // offset / 8 for offset >= 0, without a division
    int bitIndex = offset & 7;    // offset % 8 for offset >= 0
    // One past the last requested bit, counted from the MSB of bytes[byteIndex]. The
    // branch choice must depend on this, not on the absolute offset.
    int end = bitIndex + len;
    int val;

    // Every byte is masked with 0xFF: Java sign-extends a negative byte when it widens it
    // to int, and the extra 1-bits would corrupt the neighbouring bytes in the OR.
    if (end > 16) {
        val = ((bytes[byteIndex] & 0xFF) << 16
                | (bytes[byteIndex + 1] & 0xFF) << 8
                | (bytes[byteIndex + 2] & 0xFF)) >>> (24 - end);
    } else if (end > 8) {
        val = ((bytes[byteIndex] & 0xFF) << 8
                | (bytes[byteIndex + 1] & 0xFF)) >>> (16 - end);
    } else {
        val = (bytes[byteIndex] & 0xFF) >>> (8 - end);
    }

    return val & lookup[len];
}

If you want it as a String (modification of Margus' answer).

/**
 * Returns a bit sequence as a string of '0'/'1' characters, most significant bit first.
 *
 * @param bytes  source bytes; bit 0 is the MSB of bytes[0]
 * @param offset index of the first bit to emit; must be >= 0
 * @param len    number of bits to emit; a value <= 0 yields an empty string
 * @return up to {@code len} characters; fewer if the array ends before {@code len} bits
 */
static String getBitSequence(byte[] bytes, int offset, int len){

    int byteIndex = offset / 8;
    int bitIndex = offset % 8;
    StringBuilder result = new StringBuilder(Math.max(len, 0));

    for (int i = byteIndex; i < bytes.length && result.length() < len; ++i) {
        // walk a single-bit mask from the current bit down to the LSB of bytes[i]
        for (int j = 1 << (7 - bitIndex); j > 0 && result.length() < len; j >>= 1) {
            // test the byte currently being walked (bytes[i]), not the starting byte
            result.append((bytes[i] & j) == 0 ? '0' : '1');
        }
        bitIndex = 0; // every byte after the first starts at its MSB
    }
    return result.toString();
}
like image 177
blizpasta Avatar answered Nov 01 '22 00:11

blizpasta