Optimizing horizontal boolean reduction in ARM NEON

Tags: simd, arm, neon

I'm experimenting with a cross-platform SIMD library à la ecmascript_simd, aka SIMD.js, and part of this is providing a few "horizontal" SIMD operations. In particular, the API that library offers includes any(<boolN x M>) -> bool and all(<boolN x M>) -> bool functions, where <T x K> is a vector of K elements of type T and boolN is an N-bit boolean, i.e. all ones or all zeros, as SSE and NEON return for their comparison operations.

For example, let v be a <bool32 x 4> (a 128-bit vector), such as the result of VCLT.S32. I'd like to compute all(v) = v[0] && v[1] && v[2] && v[3] and any(v) = v[0] || v[1] || v[2] || v[3].

This is easy with SSE: movmskps extracts the high bit of each element, so all for the type above becomes (with C intrinsics):

#include <xmmintrin.h>
int all(__m128 x) {
    // movmskps packs the sign bit of each 32-bit lane into bits 0..3
    return _mm_movemask_ps(x) == 8 + 4 + 2 + 1;
}

and similarly for any.
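A minimal sketch of both reductions with movmskps, for comparison with the NEON versions below. The function names `all_lt4`/`any_lt4` are hypothetical, the `a < b` comparison merely stands in for whatever produced the mask, and the scalar `#else` branch is an assumption added only so the sketch also compiles on targets without SSE.

```c
#include <stdbool.h>

#if defined(__SSE__)
#include <xmmintrin.h>

/* movmskps packs the sign bit of each 32-bit lane into bits 0..3. */
static inline bool all4(__m128 mask) { return _mm_movemask_ps(mask) == 0xF; }
static inline bool any4(__m128 mask) { return _mm_movemask_ps(mask) != 0; }
#endif

/* Hypothetical helpers: all(a < b) and any(a < b) over four floats. */
bool all_lt4(const float a[4], const float b[4]) {
#if defined(__SSE__)
    return all4(_mm_cmplt_ps(_mm_loadu_ps(a), _mm_loadu_ps(b)));
#else
    /* scalar fallback, only so the sketch builds without SSE */
    return a[0] < b[0] && a[1] < b[1] && a[2] < b[2] && a[3] < b[3];
#endif
}

bool any_lt4(const float a[4], const float b[4]) {
#if defined(__SSE__)
    return any4(_mm_cmplt_ps(_mm_loadu_ps(a), _mm_loadu_ps(b)));
#else
    return a[0] < b[0] || a[1] < b[1] || a[2] < b[2] || a[3] < b[3];
#endif
}
```

For `any`, a non-zero movemask is all that is needed, so it is a single test against zero.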

I'm struggling to find obvious/nice/efficient ways to implement this with NEON, which doesn't support an instruction like movmskps. One approach is simply to extract each element and compute with scalars (the naive method below); another is to use the "horizontal" operations NEON supports, like VPMAX and VPMIN.

#include <arm_neon.h>

int all_naive(uint32x4_t v) {
    return v[0] && v[1] && v[2] && v[3];
}
int all_horiz(uint32x4_t v) {
    // Pairwise min folds the four lanes into two, then into one.
    uint32x2_t x = vpmin_u32(vget_low_u32(v),
                             vget_high_u32(v));
    uint32x2_t y = vpmin_u32(x, x);
    return vget_lane_u32(y, 0) != 0;  // test the fully reduced lane
}

(One can do a similar thing for the latter with VPADD, which may be faster, but it's fundamentally the same idea.)
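Following the same horizontal idea, `any` can be sketched with VPMAX, along with a VPADD variant. The VPADD version relies on every lane being 0 or 0xFFFFFFFF: the wrapping sum of k true lanes is 2^32 - k, which is non-zero for k = 1..4. The array-taking wrappers (`any_mask32x4_pmax`, `any_mask32x4_padd`) and their scalar fallback are hypothetical additions so the sketch can be called and built without NEON; in real code you would pass the comparison result directly.

```c
#include <stdbool.h>
#include <stdint.h>

#if defined(__ARM_NEON)
#include <arm_neon.h>

/* any via pairwise max: 4 lanes -> 2 -> 1, result in lane 0. */
static inline bool any_horiz(uint32x4_t v) {
    uint32x2_t x = vpmax_u32(vget_low_u32(v), vget_high_u32(v));
    uint32x2_t y = vpmax_u32(x, x);
    return vget_lane_u32(y, 0) != 0;
}

/* any via pairwise add: the wrapping sum of k all-ones lanes is -k,
   which is never 0 for k = 1..4. */
static inline bool any_padd(uint32x4_t v) {
    uint32x2_t x = vpadd_u32(vget_low_u32(v), vget_high_u32(v));
    uint32x2_t y = vpadd_u32(x, x);
    return vget_lane_u32(y, 0) != 0;
}
#endif

/* Hypothetical wrappers: m holds a <bool32 x 4> mask as four uint32_t lanes. */
bool any_mask32x4_pmax(const uint32_t m[4]) {
#if defined(__ARM_NEON)
    return any_horiz(vld1q_u32(m));
#else
    return (m[0] | m[1] | m[2] | m[3]) != 0;
#endif
}

bool any_mask32x4_padd(const uint32_t m[4]) {
#if defined(__ARM_NEON)
    return any_padd(vld1q_u32(m));
#else
    return (m[0] | m[1] | m[2] | m[3]) != 0;
#endif
}
```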

Are there other tricks one can use to implement this?


Yes, I know that horizontal operations are not great with SIMD vector units. But sometimes they are useful: e.g. many SIMD implementations of Mandelbrot operate on 4 points at once and bail out of the inner loop when all of them are out of range, which requires doing a comparison and then a horizontal and.

asked Jul 03 '15 01:07 by huon


1 Answer

This is my current solution, as implemented in the eve library.

If your backend has C++20 support, you can just use the library: it has implementations for arm-v7, arm-v8 (only little endian at the moment) and all x86 from SSE2 to AVX-512. It's open source, MIT licensed, and in beta at the moment. Feel free to reach out (for example with an issue) if you try out the library.

Take everything with a grain of salt: I don't have the ARM benchmarks set up yet.

NOTE: On top of basic all and any, we also have a movemask equivalent for more complex operations like first_true. That wasn't part of the question and it's not amazing, but the code can be found here.

ARM-v7, 8-byte register

Now, arm-v7 is a 32-bit architecture, so we try to get to 32-bit elements where we can.

  • any

Use a pairwise 32-bit max. If any element is true, the max is true.

// cast to dwords
dwords = vpmax_u32(dwords, dwords);
return vget_lane_u32(dwords, 0);

  • all

Use a pairwise min instead of max. What you test against also changes: with 4-byte elements, just test for true (non-zero). With shorts or chars, a 32-bit lane can be only partly true, so you need to test for -1 (all ones).

// cast to dwords
dwords = vpmin_u32(dwords, dwords);
std::uint32_t combined = vget_lane_u32(dwords, 0);

// Assuming T is your scalar type
if constexpr ( sizeof(T) >= 4 ) return combined;

// I decided that !~ reads better than == -1; the compiler will figure it out.
return !~combined; 
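Putting the two fragments together for an 8-byte register gives roughly the following. This is a C sketch, not eve's code: the names are hypothetical, the mask is viewed as `uint8x8_t` regardless of element size, and `all` always compares against all ones, which is correct for every element size (for 4-byte elements it is equivalent to the non-zero test above). The array wrappers and the `memcpy` fallback exist only so the sketch builds and runs off ARM.

```c
#include <stdbool.h>
#include <stdint.h>
#include <string.h>

#if defined(__ARM_NEON)
#include <arm_neon.h>

/* any: pairwise 32-bit max; non-zero iff some byte of the mask is set. */
static inline bool any_d(uint8x8_t mask) {
    uint32x2_t dwords = vreinterpret_u32_u8(mask); /* cast to dwords */
    dwords = vpmax_u32(dwords, dwords);
    return vget_lane_u32(dwords, 0) != 0;
}

/* all: pairwise 32-bit min. With 8- or 16-bit elements a dword can be only
   partly true, so compare against all ones rather than against zero. */
static inline bool all_d(uint8x8_t mask) {
    uint32x2_t dwords = vreinterpret_u32_u8(mask);
    dwords = vpmin_u32(dwords, dwords);
    return vget_lane_u32(dwords, 0) == UINT32_MAX;
}
#endif

/* Hypothetical wrappers: m is an 8-byte mask, every element 0 or all ones. */
bool any_mask8(const uint8_t m[8]) {
#if defined(__ARM_NEON)
    return any_d(vld1_u8(m));
#else
    uint64_t bits;
    memcpy(&bits, m, sizeof bits);
    return bits != 0;
#endif
}

bool all_mask8(const uint8_t m[8]) {
#if defined(__ARM_NEON)
    return all_d(vld1_u8(m));
#else
    uint64_t bits;
    memcpy(&bits, m, sizeof bits);
    return bits == UINT64_MAX;
#endif
}
```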

ARM-v7, 16-byte register

For anything bigger than chars, just narrow the register to a 64-bit one: each element is all ones or all zeros, so keeping its low half preserves the mask. Here is the list of vector narrow integer conversions.

For chars, the best I found is to reinterpret as uint32 and do an extra check: compare for == -1 for all and > 0 for any. That seemed to give nicer asm than splitting into two 8-byte registers.

Then just do all/any on that dword register.
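A sketch of one reading of this 16-byte ARM-v7 recipe: for 32-bit elements, VMOVN (`vmovn_u32`) narrows the mask to 64 bits; for chars, the 16 bytes are viewed as four dwords, `vcgtq_u32` (> 0) or `vceqq_u32` (== -1) turns each dword into a full 32-bit mask, which is then narrowed and reduced like the 8-byte case. Function names, the array wrappers and the scalar fallback are hypothetical; 16-bit elements would use `vmovn_u16` in the same way.

```c
#include <stdbool.h>
#include <stdint.h>

#if defined(__ARM_NEON)
#include <arm_neon.h>

/* Dword any/all on a 64-bit mask, as in the 8-byte case. */
static inline bool any_d(uint32x2_t d) {
    d = vpmax_u32(d, d);
    return vget_lane_u32(d, 0) != 0;
}
static inline bool all_d(uint32x2_t d) {
    d = vpmin_u32(d, d);
    return vget_lane_u32(d, 0) == UINT32_MAX;
}

/* 32-bit elements: keep the low 16 bits of each lane (0 or 0xFFFF). */
static inline uint32x2_t narrow32(uint32x4_t m) {
    return vreinterpret_u32_u16(vmovn_u32(m));
}
#endif

/* Hypothetical wrappers for a <bool32 x 4> mask. */
bool any_mask32x4(const uint32_t m[4]) {
#if defined(__ARM_NEON)
    return any_d(narrow32(vld1q_u32(m)));
#else
    return (m[0] | m[1] | m[2] | m[3]) != 0;
#endif
}
bool all_mask32x4(const uint32_t m[4]) {
#if defined(__ARM_NEON)
    return all_d(narrow32(vld1q_u32(m)));
#else
    return (m[0] & m[1] & m[2] & m[3]) == UINT32_MAX;
#endif
}

/* Hypothetical wrappers for a <bool8 x 16> mask: view the bytes as four
   dwords, make each a full 32-bit mask with the extra compare, then narrow. */
bool any_mask8x16(const uint8_t m[16]) {
#if defined(__ARM_NEON)
    uint32x4_t dw = vreinterpretq_u32_u8(vld1q_u8(m));
    return any_d(narrow32(vcgtq_u32(dw, vdupq_n_u32(0))));          /* > 0 */
#else
    for (int i = 0; i < 16; i++)
        if (m[i]) return true;
    return false;
#endif
}
bool all_mask8x16(const uint8_t m[16]) {
#if defined(__ARM_NEON)
    uint32x4_t dw = vreinterpretq_u32_u8(vld1q_u8(m));
    return all_d(narrow32(vceqq_u32(dw, vdupq_n_u32(UINT32_MAX)))); /* == -1 */
#else
    for (int i = 0; i < 16; i++)
        if (m[i] != 0xFF) return false;
    return true;
#endif
}
```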

ARM-v8, 8-byte register

ARM-v8 has 64-bit support, so you can just read the whole register as a single 64-bit lane, which is trivially testable.
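On AArch64 the 8-byte case reduces to reading one 64-bit lane, for example with `vget_lane_u64` after `vreinterpret_u64_u8`. The names, the array wrappers and the `memcpy` fallback are hypothetical additions for portability; since true elements are all ones, `all` is a comparison against `UINT64_MAX` for any element size.

```c
#include <stdbool.h>
#include <stdint.h>
#include <string.h>

#if defined(__aarch64__)
#include <arm_neon.h>

/* Read the whole 8-byte mask register as one 64-bit lane. */
static inline uint64_t bits64(uint8x8_t v) {
    return vget_lane_u64(vreinterpret_u64_u8(v), 0);
}
#endif

/* Hypothetical helper: the 8-byte mask as a uint64_t. */
static uint64_t load_bits64(const uint8_t m[8]) {
#if defined(__aarch64__)
    return bits64(vld1_u8(m));
#else
    uint64_t bits;
    memcpy(&bits, m, sizeof bits); /* same bits, without NEON */
    return bits;
#endif
}

/* any: some bit set; all: every bit set (true elements are all ones). */
bool any_mask8(const uint8_t m[8]) { return load_bits64(m) != 0; }
bool all_mask8(const uint8_t m[8]) { return load_bits64(m) == UINT64_MAX; }
```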

ARM-v8, 16-byte register

For any we use vmaxvq_u32, since there is no 64-bit version, and for all vminvq_u32, vminvq_u16 or vminvq_u8, depending on the element size. (This is similar to glibc's strlen.)
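A sketch of the 16-byte AArch64 variant using the across-vector reductions `vmaxvq_u32`, `vminvq_u8` and `vminvq_u32` (`vminvq_u16` works the same way for 16-bit elements). Reducing with the mask's own element width means the minimum is itself either 0 or all ones, so testing for non-zero is enough. The names, the array wrappers and the scalar fallback are hypothetical.

```c
#include <stdbool.h>
#include <stdint.h>

#if defined(__aarch64__)
#include <arm_neon.h>

/* any: a set byte makes its dword non-zero (there is no vmaxvq_u64). */
static inline bool any_q(uint8x16_t v) {
    return vmaxvq_u32(vreinterpretq_u32_u8(v)) != 0;
}

/* all: the across-vector min at the element width is 0 or all ones. */
static inline bool all_q8(uint8x16_t v)  { return vminvq_u8(v) != 0; }
static inline bool all_q32(uint32x4_t v) { return vminvq_u32(v) != 0; }
#endif

/* Hypothetical wrappers over masks stored in memory. */
bool any_mask8x16(const uint8_t m[16]) {
#if defined(__aarch64__)
    return any_q(vld1q_u8(m));
#else
    for (int i = 0; i < 16; i++)
        if (m[i]) return true;
    return false;
#endif
}

bool all_mask8x16(const uint8_t m[16]) {
#if defined(__aarch64__)
    return all_q8(vld1q_u8(m));
#else
    for (int i = 0; i < 16; i++)
        if (!m[i]) return false;
    return true;
#endif
}

bool all_mask32x4(const uint32_t m[4]) {
#if defined(__aarch64__)
    return all_q32(vld1q_u32(m));
#else
    return m[0] && m[1] && m[2] && m[3];
#endif
}
```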

Conclusion

The lack of benchmarks definitely worries me: some instructions can be problematic in ways I don't know about. Regardless, that's the best I've got, so far at least.

answered Nov 12 '22 06:11 by Denis Yaroshevskiy