I have a problem for which I have eight elements that can contain 0, 1, or 2. I can easily represent this in 16 bits, but for SIMD efficiency reasons, I need it to occupy 13 bits (it is not the only thing present in the lane).
Fortunately, 2^13==8192, and 3^8==6561, so the states I want can fit. However, here's where things get interesting. Naively, I would just represent these states by counting the ternary numeral states. For example, to represent the tritmask 0t12211012 (I'll use this as an example throughout), I could just write 0t12211012 = 2*3^0+1*3^1+0*3^2+1*3^3+1*3^4+2*3^5+2*3^6+1*3^7 = 4244 = 0b1000010010100.
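To make the encoding concrete, here is a minimal sketch of the packing step. The name packTrits and the use of std::array are assumptions made for illustration; trit 0 is taken to be the least-significant one, as in the sum above.

```cpp
#include <array>
#include <cstdint>

// Hypothetical helper: packs 8 trits (index 0 = least significant)
// into a base-3 integer in the range [0, 6561).
uint16_t packTrits(const std::array<uint8_t, 8>& trits) {
    uint16_t value = 0;
    // Horner's rule, starting from the most-significant trit.
    for (int i = 7; i >= 0; --i) {
        value = static_cast<uint16_t>(value * 3 + trits[i]);
    }
    return value;
}
```

Calling it with the trits of 0t12211012 listed from the least-significant end, {2,1,0,1,1,2,2,1}, should give 4244.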
I have a set of operations I need to support:
- If I have 0t12211012 and I wish to place a 2 in the position holding a zero, I can simply add 0t200=18. (Note that the conversion to tritspace is easy, because I only have 8 trits, so I can store the base powers in a register and index it with pshufw.)
- Given 0t12211012, I want to be able to extract the bitmask for 0, which is 0b00000100, for 1, which is 0b10011010, and for 2, which is 0b01100001. This I have not figured out how to do, and is what I would like assistance with. How can I do this in a small number of operations suitable for x86 SIMD?

Thank you!
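For reference, the first operation amounts to a table lookup and an add. The sketch below is scalar and uses hypothetical names (kPow3, placeTrit); it assumes the target position currently holds 0 and does not check that.

```cpp
#include <cstdint>

// Powers of 3 for the 8 trit positions (position 0 is least significant).
static const uint16_t kPow3[8] = {1, 3, 9, 27, 81, 243, 729, 2187};

// Hypothetical helper: writes `digit` (1 or 2) into position `pos`,
// assuming that position currently holds 0. No check is performed.
uint16_t placeTrit(uint16_t packed, unsigned pos, unsigned digit) {
    return static_cast<uint16_t>(packed + digit * kPow3[pos]);
}
```

With the running example, placeTrit(4244, 2, 2) adds 0t200 = 18 and yields 4262, which is 0t12211212.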
Edit 11/18/20: To give an example of an approach I consider too slow: we can iteratively find the value mod 3 and divide by 3 to pull trits off the least-significant end of the representation, then assemble the mask that way. C++ snippet:
uint32_t trits = <something>;
uint8_t mask0 = 0, mask1 = 0, mask2 = 0;
for (uint8_t shift = 0; shift < 8; ++shift) {
    const uint32_t remainder = trits % 3;
    mask0 |= (!remainder) << shift;
    mask1 |= (remainder == 1) << shift;
    mask2 |= (remainder == 2) << shift;
    trits /= 3;
}
When actually writing this in a SIMD language, we would use the standard multiply-and-shift trick for division by a constant. But you can see it's linear in the number of trits, and has a lot of ops per iteration. We could code-golf this down a bit, but I think it is fundamentally the wrong approach. It should ideally be possible to do something in parallel for each trit... but I don't see it.
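For completeness, this is roughly what the multiply-and-shift replacement for division by 3 looks like on 16-bit values. The names div3 and mod3 are illustrative. The constant works because 3 * 0xAAAB == 2^17 + 1, which keeps the rounding error below 1/3 for every 16-bit input.

```cpp
#include <cstdint>

// Unsigned division by 3 without a divide instruction.
// 0xAAAB == (2^17 + 1) / 3, so (x * 0xAAAB) >> 17 == x / 3 for all 16-bit x.
inline uint16_t div3(uint16_t x) {
    return static_cast<uint16_t>((static_cast<uint32_t>(x) * 0xAAABu) >> 17);
}

// Remainder derived from the quotient: x - 3 * (x / 3).
inline uint16_t mod3(uint16_t x) {
    return static_cast<uint16_t>(x - 3 * div3(x));
}
```

In a 16-bit SIMD lane the same computation would presumably be a high-half multiply (pmulhuw, i.e. _mm_mulhi_epu16) followed by a right shift by 1, but each trit still costs a multiply, a shift, and a subtract.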
Edit 11/20/20: I've made a halfhearted effort to apply Aha to this problem without success. Maybe an interesting subproblem to solve instead is - is there a short sequence of bitwise ops under the same constraints as above that acts as a 'ternary bitwise AND'? That is, an op that compares two encoded numbers in tritspace and returns a bitmask that is 1 when the corresponding trits are equal and zero otherwise? That would be a primitive from which we could build up the ops needed. We have left and right shift in tritspace (just multiply or divide by 3); and we have +/- a value. So what we are missing is the ability to test if trits are particular values...
The classic solution for a base change is divide and conquer: recursively transform the given number from base b^(2^K) to b^(2^(K-1)), ..., b^2, b^1. Here, K is chosen so that b^(2^K) is larger than the number n that we wish to convert.
This recursive algorithm has the added advantage that at any given step k<K we have decomposed the original problem into 2^k subproblems of identical size and divisor, revealing a highly parallel operation.
In this case, the number of trits is a power of 2 (8 = 2^3), and both our original base-3 encoding and its related base-4 "deflated" encoding fit neatly inside a 16-bit word, which leads to a straightforward and very regular implementation:
#include <cstdint>

// converts a compact (base-3) 8-trit number to a sequence of concatenated trits in binary
uint16_t deflateTrits (uint16_t trits8) {
    // split 8 trits into 2x4 blocks by representing them in base pow(3,4)==81
    uint16_t trits2x4 = ((trits8 / 81) << 8)
                      | (trits8 % 81);
    // split 2x4 trits into 4x2 blocks by representing them in base pow(3,2)==9
    uint16_t trits4x2 = 0;
    for (int k=0; k<2; ++k) {
        uint8_t src = (trits2x4 >> (8*k)) & 0x00FF;
        uint8_t dst = ( (src/9)<<4 ) | (src % 9);
        trits4x2 |= dst<<(8*k);
    }
    // every 4-bit block now contains the binary representation of a 2-trit number
    // split 4x2 trits into 8x1 blocks by representing them in base pow(3,1)==3
    uint16_t trits8x1 = 0;
    for (int k=0; k<4; ++k) {
        uint8_t src = (trits4x2 >> (4*k)) & 0x0F;
        uint8_t dst = ( (src/3)<<2 ) | (src % 3);
        trits8x1 |= dst<<(4*k);
    }
    return trits8x1;
}
This may look like substantially more work, but notice how the computations inside the for loops are completely independent, meaning they can be computed in parallel with SIMD or any other means. The only exception is the final accumulation over OR, which is associative and commutative and can therefore be reduced from its partial contributions in O(lg N) vector-like operations.
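When experimenting with this, it helps to have a deliberately slow oracle that defines the expected output layout: trit k of the base-3 value lands in bits 2k and 2k+1. The sketch below is such an oracle. The name deflateTritsReference is an assumption for illustration, and the function is not meant to be fast.

```cpp
#include <cstdint>

// Reference (slow) unpacking, intended only as a test oracle:
// trit k of the base-3 value ends up in bits 2k and 2k+1 of the result.
uint16_t deflateTritsReference(uint16_t trits8) {
    uint16_t out = 0;
    for (int k = 0; k < 8; ++k) {
        out |= static_cast<uint16_t>((trits8 % 3) << (2 * k));
        trits8 /= 3;
    }
    return out;
}
```

For the question's example, 4244 (0t12211012) unpacks to 0x6946, that is 01 10 10 01 01 00 01 10 in binary. Comparing any faster variant against this oracle for all inputs in [0, 6561) is a cheap exhaustive test.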
The divide-and-conquer conversion can be sped up by computing the quotients and remainders of every 2^n trit block in parallel with some bit wizardry based on modular arithmetic. I doubt this improves the run time for such short word lengths, but it may actually pay off when dealing with 32-trit chains and 64-bit words or longer, or if fully vectorizing the code.
// converts a compact (base-3) 8-trit number to a sequence of concatenated trits in binary
uint16_t deflateTritsFast (uint16_t trits8) {
    uint16_t trits2x4, trits4x2, trits8x1;
    // split 8 trits into 2x4 blocks by representing them in base pow(3,4)==81
    {
        uint16_t hi, lo;
        hi = (trits8 / 81);
        lo = (trits8 % 81);
        trits2x4 = (hi<<8) | lo;
    }
    // split 2x4 trits into 4x2 blocks by representing them in base pow(3,2)==9
    // perform both divisions with remainder in parallel
    {
        uint16_t hi, lo;
        lo = trits2x4; // lo bits == remainders mod 9
        // 64*a+8*b+c == a-b+c mod 9
        lo = 0x0909 // add 9 to avoid negative remainders
           + ( lo & 0x0707)
           + ( (lo>>6) & 0x0101)
           - ( (lo>>3) & 0x0707);
        // r' <- r mod 9 , with 0<=r<=17
        lo -= 0x09 * (
            ( ( (lo>>3) & ( lo | (lo>>1) | (lo>>2)) ) |
              (lo>>4)
            ) & 0x0101);
        // divisions are now exact within each block, so results don't spill over
        hi = (trits2x4 - lo) / 9;
        // interleave
        trits4x2 = (hi << 4) | lo;
    }
    // every 4-bit block now contains the binary representation of a 2-trit number
    // split 4x2 trits into 8x1 blocks by representing them in base pow(3,1)==3
    {
        // compute remainders for each nibble in parallel without dividing
        uint16_t hi, lo;
        lo = trits4x2;
        // (4 a + b) == a+b mod 3
        lo = (lo & 0x3333) + ( (lo>>2) & 0x3333);
        // repeat the operation in case the previous sum produced a carry
        lo = (lo & 0x3333) + ( (lo>>2) & 0x3333);
        // 3==0 mod 3
        // one last mod3 operation to get rid of 3's within 2-bit ints
        // detect any 3s, generate a mask, then mask off the values
        lo ^= ( 0x03 * ( (lo) & (lo>>1) & 0x1111) );
        // divisions are now exact within each block, so results don't spill over
        hi = (trits4x2 - lo) / 3;
        // interleave
        trits8x1 = (hi << 2) | lo;
    }
    return trits8x1;
}
Again, this may look like a lot more computational work, but most of the operations here are cheap bit manipulations (shifts and boolean expressions), which typically take a single cycle each and can often be issued several at a time, plus a few additions/subtractions, each of which may also execute in a single cycle depending on the CPU architecture.
The real gain here is strength reduction and parallelism: a dependency chain of N-1=7 sequential divisions is reduced to a more parallel, less expensive chain of log2(N)=3 divisions. This gain would get a lot more dramatic for longer trit chains. The case of 32 packed trits in a 64-bit integer, for instance, would result in a reduction from 31 divisions down to 5.
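The two division-free remainder steps can be isolated and checked exhaustively. The sketch below pulls them out into standalone functions. The names mod9PerByte and mod3PerNibble are made up for this illustration, and the input ranges stated in the comments are assumptions inherited from the surrounding algorithm.

```cpp
#include <cstdint>

// Remainder mod 9 of two bytes packed in a 16-bit word, each byte in [0, 80].
// Uses 64 == 1 and 8 == -1 (mod 9), so 64*a + 8*b + c == a - b + c (mod 9).
uint16_t mod9PerByte(uint16_t x) {
    uint16_t r = static_cast<uint16_t>(
        0x0909 + (x & 0x0707) + ((x >> 6) & 0x0101) - ((x >> 3) & 0x0707));
    // Each byte of r is now in [2, 17]; flag the bytes that are >= 9.
    uint16_t ge9 = static_cast<uint16_t>(
        (((r >> 3) & (r | (r >> 1) | (r >> 2))) | (r >> 4)) & 0x0101);
    return static_cast<uint16_t>(r - 9 * ge9);
}

// Remainder mod 3 of four nibbles packed in a 16-bit word, each in [0, 8].
// Uses 4 == 1 (mod 3), so 4*a + b == a + b (mod 3).
uint16_t mod3PerNibble(uint16_t x) {
    uint16_t r = static_cast<uint16_t>((x & 0x3333) + ((x >> 2) & 0x3333));
    // Second pass: the first sum can be as large as 4.
    r = static_cast<uint16_t>((r & 0x3333) + ((r >> 2) & 0x3333));
    // A nibble equal to 3 is congruent to 0; clear it.
    r ^= static_cast<uint16_t>(0x03 * (r & (r >> 1) & 0x1111));
    return r;
}
```

Both functions only use shifts, masks, adds, and a multiply by a small constant, which is what makes them candidates for running across many lanes at once.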
Once the base-3 encoded number is unpacked, creating the 0-, 1- or 2-mask is straightforward:
// find the mask of ternary digits d in the trit sequence
uint16_t generateMask (uint16_t trits, uint8_t d) {
    // unpacked trits
    uint16_t up3s = deflateTritsFast (trits);
    d &= 0x03; // only consider lowest 2 bits
    uint16_t dx = d * 0x5555; // copy (scatter) the digit to every position
    dx = ~(up3s ^ dx); // negated difference mask
    dx &= (dx>>1) & 0x5555; // set each block to 1 only if both bits are the same as d
    return dx;
}
The ternary masks generated here contain interleaved 0s, which is more convenient than the alternative for sequence reconstruction and manipulation:
#include <cassert>

int main () {
    const uint16_t trits = 0x12A4; // random example
    // the "deflated" trit representation can be reconstructed from the ternary masks
    // (note the parentheses: == binds tighter than |)
    assert( deflateTritsFast(trits) == (1*generateMask(trits,1) | 2*generateMask(trits,2)) );
    // check that the union of all three masks spans all the ternary digit positions
    assert( (generateMask(trits,0) | generateMask(trits,1) | generateMask(trits,2)) == 0x5555 );
    return 0;
}
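If the compact 8-bit masks from the question are needed (for example 0b00000100 for the zeros of 0t12211012), the interleaved result can be squeezed with a few shift-and-mask steps. This is a sketch of the usual even-bit compaction, and compactMask is a hypothetical name.

```cpp
#include <cstdint>

// Gathers the even-numbered bits of `m` (bits 0, 2, ..., 14) into the low byte,
// so bit 2k of the interleaved mask becomes bit k of the result.
uint8_t compactMask(uint16_t m) {
    uint16_t x = m & 0x5555;
    x = (x | (x >> 1)) & 0x3333;
    x = (x | (x >> 2)) & 0x0F0F;
    x = (x | (x >> 4)) & 0x00FF;
    return static_cast<uint8_t>(x);
}
```

For 0t12211012 the interleaved masks for digits 0, 1 and 2 are 0x0010, 0x4144 and 0x1401, which compact to 0b00000100, 0b10011010 and 0b01100001 respectively. On scalar x86 with BMI2, _pext_u32(m, 0x5555) should give the same result in one instruction, though that does not help inside SIMD lanes.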
As a final note, the code presented here has multiple hotspots for further optimization, which I didn't address in the interest of clarity and conciseness.