
How to convert _mm_shuffle_ps SSE intrinsic to NEON intrinsic?

Tags: simd, arm, sse, neon
I am trying to convert code written with SSE intrinsics to NEON SIMD and got stuck on the _mm_shuffle_ps intrinsic. Here is the code:

b = _mm_shuffle_ps(a, b, 136);

Here a, b, and c are all __m128 variables.
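
For reference, the immediate 136 is 0b10001000, i.e. _MM_SHUFFLE(2, 0, 2, 0), so (treating the inputs and the result as plain float[4] arrays just for illustration) the shuffle computes:

    result[0] = a[0];  // 1st lane of a -> 1st lane of result
    result[1] = a[2];  // 3rd lane of a -> 2nd lane of result
    result[2] = b[0];  // 1st lane of b -> 3rd lane of result
    result[3] = b[2];  // 3rd lane of b -> 4th lane of result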

Now I want to implement the same operation with NEON. Assume there are three float32x4_t vectors: x, y, z. I want to assign the 1st and 3rd lanes of x to the 1st and 2nd lanes of z respectively, and the 1st and 3rd lanes of y to the 3rd and 4th lanes of z respectively.

I can't find an efficient way to implement this and would appreciate some help.

asked Sep 12 '15 by CJZ

1 Answer

There's no direct NEON equivalent of _mm_shuffle_ps, but you can use vtbl for the general case.

For DirectXMath, I use VTBL for the generic XMVectorSwizzle and XMVectorPermute. I then specialize the templates for ARM/ARM64 with the patterns that can be done efficiently in ARM NEON.

XMVectorSwizzle

inline XMVECTOR XMVectorSwizzle(FXMVECTOR V, uint32_t E0, uint32_t E1, uint32_t E2, uint32_t E3)
{
    assert( (E0 < 4) && (E1 < 4) && (E2 < 4) && (E3 < 4) );

    static const uint32_t ControlElement[ 4 ] =
    {
        0x03020100, // XM_SWIZZLE_X
        0x07060504, // XM_SWIZZLE_Y
        0x0B0A0908, // XM_SWIZZLE_Z
        0x0F0E0D0C, // XM_SWIZZLE_W
    };

    // Build a two-entry byte table from the halves of V, then use VTBL to
    // pick out the requested 32-bit lanes byte-by-byte.
    uint8x8x2_t tbl;
    tbl.val[0] = vreinterpret_u8_f32( vget_low_f32(V) );
    tbl.val[1] = vreinterpret_u8_f32( vget_high_f32(V) );

    uint32x2_t idx = vcreate_u32( ((uint64_t)ControlElement[E0]) | (((uint64_t)ControlElement[E1]) << 32) );
    const uint8x8_t rL = vtbl2_u8( tbl, vreinterpret_u8_u32(idx) );

    idx = vcreate_u32( ((uint64_t)ControlElement[E2]) | (((uint64_t)ControlElement[E3]) << 32) );
    const uint8x8_t rH = vtbl2_u8( tbl, vreinterpret_u8_u32(idx) );

    return vcombine_f32( vreinterpret_f32_u8(rL), vreinterpret_f32_u8(rH) );
}

template<uint32_t SwizzleX, uint32_t SwizzleY, uint32_t SwizzleZ, uint32_t SwizzleW>
inline XMVECTOR XMVectorSwizzle(FXMVECTOR V)
{
    return XMVectorSwizzle( V, SwizzleX, SwizzleY, SwizzleZ, SwizzleW );
}

template<> inline XMVECTOR XMVectorSwizzle<0,1,2,3>(FXMVECTOR V) { return V; }
template<> inline XMVECTOR XMVectorSwizzle<0,0,0,0>(FXMVECTOR V) { return vdupq_lane_f32( vget_low_f32(V), 0); }
template<> inline XMVECTOR XMVectorSwizzle<1,1,1,1>(FXMVECTOR V) { return vdupq_lane_f32( vget_low_f32(V), 1); }
template<> inline XMVECTOR XMVectorSwizzle<2,2,2,2>(FXMVECTOR V) { return vdupq_lane_f32( vget_high_f32(V), 0); }
template<> inline XMVECTOR XMVectorSwizzle<3,3,3,3>(FXMVECTOR V) { return vdupq_lane_f32( vget_high_f32(V), 1); }

template<> inline XMVECTOR XMVectorSwizzle<1,0,3,2>(FXMVECTOR V) { return vrev64q_f32(V); }

template<> inline XMVECTOR XMVectorSwizzle<0,1,0,1>(FXMVECTOR V) { float32x2_t vt = vget_low_f32(V); return vcombine_f32( vt, vt ); }
template<> inline XMVECTOR XMVectorSwizzle<2,3,2,3>(FXMVECTOR V) { float32x2_t vt = vget_high_f32(V); return vcombine_f32( vt, vt ); }
template<> inline XMVECTOR XMVectorSwizzle<1,0,1,0>(FXMVECTOR V) { float32x2_t vt = vrev64_f32( vget_low_f32(V) ); return vcombine_f32( vt, vt ); }
template<> inline XMVECTOR XMVectorSwizzle<3,2,3,2>(FXMVECTOR V) { float32x2_t vt = vrev64_f32( vget_high_f32(V) ); return vcombine_f32( vt, vt ); }

template<> inline XMVECTOR XMVectorSwizzle<0,1,3,2>(FXMVECTOR V) { return vcombine_f32( vget_low_f32(V), vrev64_f32( vget_high_f32(V) ) ); }
template<> inline XMVECTOR XMVectorSwizzle<1,0,2,3>(FXMVECTOR V) { return vcombine_f32( vrev64_f32( vget_low_f32(V) ), vget_high_f32(V) ); }
template<> inline XMVECTOR XMVectorSwizzle<2,3,1,0>(FXMVECTOR V) { return vcombine_f32( vget_high_f32(V), vrev64_f32( vget_low_f32(V) ) ); }
template<> inline XMVECTOR XMVectorSwizzle<3,2,0,1>(FXMVECTOR V) { return vcombine_f32( vrev64_f32( vget_high_f32(V) ), vget_low_f32(V) ); }
template<> inline XMVECTOR XMVectorSwizzle<3,2,1,0>(FXMVECTOR V) { return vcombine_f32( vrev64_f32( vget_high_f32(V) ), vrev64_f32( vget_low_f32(V) ) ); }

template<> inline XMVECTOR XMVectorSwizzle<0,0,2,2>(FXMVECTOR V) { return vtrnq_f32(V,V).val[0]; }
template<> inline XMVECTOR XMVectorSwizzle<1,1,3,3>(FXMVECTOR V) { return vtrnq_f32(V,V).val[1]; }

template<> inline XMVECTOR XMVectorSwizzle<0,0,1,1>(FXMVECTOR V) { return vzipq_f32(V,V).val[0]; }
template<> inline XMVECTOR XMVectorSwizzle<2,2,3,3>(FXMVECTOR V) { return vzipq_f32(V,V).val[1]; }

template<> inline XMVECTOR XMVectorSwizzle<0,2,0,2>(FXMVECTOR V) { return vuzpq_f32(V,V).val[0]; }
template<> inline XMVECTOR XMVectorSwizzle<1,3,1,3>(FXMVECTOR V) { return vuzpq_f32(V,V).val[1]; }

template<> inline XMVECTOR XMVectorSwizzle<1,2,3,0>(FXMVECTOR V) { return vextq_f32(V, V, 1); }
template<> inline XMVECTOR XMVectorSwizzle<2,3,0,1>(FXMVECTOR V) { return vextq_f32(V, V, 2); }
template<> inline XMVECTOR XMVectorSwizzle<3,0,1,2>(FXMVECTOR V) { return vextq_f32(V, V, 3); }
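
A usage sketch: when the requested pattern matches one of the specializations above, the swizzle compiles down to a single NEON instruction instead of the VTBL fallback.

    XMVECTOR v = XMVectorSet(1.0f, 2.0f, 3.0f, 4.0f);
    // <0,2,0,2> hits the specialization above, so this is a single VUZP
    // (taking .val[0]) rather than the generic table lookup:
    XMVECTOR evens = XMVectorSwizzle<0, 2, 0, 2>(v);   // {1, 3, 1, 3}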

XMVectorPermute

inline XMVECTOR XMVectorPermute(FXMVECTOR V1, FXMVECTOR V2, uint32_t PermuteX, uint32_t PermuteY, uint32_t PermuteZ, uint32_t PermuteW)
{
    assert( PermuteX <= 7 && PermuteY <= 7 && PermuteZ <= 7 && PermuteW <= 7 );

    static const uint32_t ControlElement[ 8 ] =
    {
        0x03020100, // XM_PERMUTE_0X
        0x07060504, // XM_PERMUTE_0Y
        0x0B0A0908, // XM_PERMUTE_0Z
        0x0F0E0D0C, // XM_PERMUTE_0W
        0x13121110, // XM_PERMUTE_1X
        0x17161514, // XM_PERMUTE_1Y
        0x1B1A1918, // XM_PERMUTE_1Z
        0x1F1E1D1C, // XM_PERMUTE_1W
    };

    // Build a four-entry byte table covering both vectors, then use VTBL to
    // select the requested 32-bit lanes byte-by-byte.
    uint8x8x4_t tbl;
    tbl.val[0] = vreinterpret_u8_f32( vget_low_f32(V1) );
    tbl.val[1] = vreinterpret_u8_f32( vget_high_f32(V1) );
    tbl.val[2] = vreinterpret_u8_f32( vget_low_f32(V2) );
    tbl.val[3] = vreinterpret_u8_f32( vget_high_f32(V2) );

    uint32x2_t idx = vcreate_u32( ((uint64_t)ControlElement[PermuteX]) | (((uint64_t)ControlElement[PermuteY]) << 32) );
    const uint8x8_t rL = vtbl4_u8( tbl, vreinterpret_u8_u32(idx) );

    idx = vcreate_u32( ((uint64_t)ControlElement[PermuteZ]) | (((uint64_t)ControlElement[PermuteW]) << 32) );
    const uint8x8_t rH = vtbl4_u8( tbl, vreinterpret_u8_u32(idx) );

    return vcombine_f32( vreinterpret_f32_u8(rL), vreinterpret_f32_u8(rH) );
}

template<uint32_t PermuteX, uint32_t PermuteY, uint32_t PermuteZ, uint32_t PermuteW>
inline XMVECTOR XMVectorPermute(FXMVECTOR V1, FXMVECTOR V2)
{
    return XMVectorPermute( V1, V2, PermuteX, PermuteY, PermuteZ, PermuteW );
}

template<> inline XMVECTOR XMVectorPermute<0,1,2,3>(FXMVECTOR V1, FXMVECTOR V2) { (V2); return V1; }
template<> inline XMVECTOR XMVectorPermute<4,5,6,7>(FXMVECTOR V1, FXMVECTOR V2) { (V1); return V2; }

template<> inline XMVECTOR XMVectorPermute<0,1,4,5>(FXMVECTOR V1, FXMVECTOR V2) { return vcombine_f32( vget_low_f32(V1), vget_low_f32(V2) ); }
template<> inline XMVECTOR XMVectorPermute<1,0,4,5>(FXMVECTOR V1, FXMVECTOR V2) { return vcombine_f32( vrev64_f32( vget_low_f32(V1) ), vget_low_f32(V2) ); }
template<> inline XMVECTOR XMVectorPermute<0,1,5,4>(FXMVECTOR V1, FXMVECTOR V2) { return vcombine_f32( vget_low_f32(V1), vrev64_f32( vget_low_f32(V2) ) ); }
template<> inline XMVECTOR XMVectorPermute<1,0,5,4>(FXMVECTOR V1, FXMVECTOR V2) { return vcombine_f32( vrev64_f32( vget_low_f32(V1) ), vrev64_f32( vget_low_f32(V2) ) ); }

template<> inline XMVECTOR XMVectorPermute<2,3,6,7>(FXMVECTOR V1, FXMVECTOR V2) { return vcombine_f32( vget_high_f32(V1), vget_high_f32(V2) ); }
template<> inline XMVECTOR XMVectorPermute<3,2,6,7>(FXMVECTOR V1, FXMVECTOR V2) { return vcombine_f32( vrev64_f32( vget_high_f32(V1) ), vget_high_f32(V2) ); }
template<> inline XMVECTOR XMVectorPermute<2,3,7,6>(FXMVECTOR V1, FXMVECTOR V2) { return vcombine_f32( vget_high_f32(V1), vrev64_f32( vget_high_f32(V2) ) ); }
template<> inline XMVECTOR XMVectorPermute<3,2,7,6>(FXMVECTOR V1, FXMVECTOR V2) { return vcombine_f32( vrev64_f32( vget_high_f32(V1) ), vrev64_f32( vget_high_f32(V2) ) ); }

template<> inline XMVECTOR XMVectorPermute<0,1,6,7>(FXMVECTOR V1, FXMVECTOR V2) { return vcombine_f32( vget_low_f32(V1), vget_high_f32(V2) ); }
template<> inline XMVECTOR XMVectorPermute<1,0,6,7>(FXMVECTOR V1, FXMVECTOR V2) { return vcombine_f32( vrev64_f32( vget_low_f32(V1) ), vget_high_f32(V2) ); }
template<> inline XMVECTOR XMVectorPermute<0,1,7,6>(FXMVECTOR V1, FXMVECTOR V2) { return vcombine_f32( vget_low_f32(V1), vrev64_f32( vget_high_f32(V2) ) ); }
template<> inline XMVECTOR XMVectorPermute<1,0,7,6>(FXMVECTOR V1, FXMVECTOR V2) { return vcombine_f32( vrev64_f32( vget_low_f32(V1) ), vrev64_f32( vget_high_f32(V2) ) ); }

template<> inline XMVECTOR XMVectorPermute<3,2,4,5>(FXMVECTOR V1, FXMVECTOR V2) { return vcombine_f32( vrev64_f32( vget_high_f32(V1) ), vget_low_f32(V2) ); }
template<> inline XMVECTOR XMVectorPermute<2,3,5,4>(FXMVECTOR V1, FXMVECTOR V2) { return vcombine_f32( vget_high_f32(V1), vrev64_f32( vget_low_f32(V2) ) ); }
template<> inline XMVECTOR XMVectorPermute<3,2,5,4>(FXMVECTOR V1, FXMVECTOR V2) { return vcombine_f32( vrev64_f32( vget_high_f32(V1) ), vrev64_f32( vget_low_f32(V2) ) ); }

template<> inline XMVECTOR XMVectorPermute<0,4,2,6>(FXMVECTOR V1, FXMVECTOR V2) { return vtrnq_f32(V1,V2).val[0]; }
template<> inline XMVECTOR XMVectorPermute<1,5,3,7>(FXMVECTOR V1, FXMVECTOR V2) { return vtrnq_f32(V1,V2).val[1]; }

template<> inline XMVECTOR XMVectorPermute<0,4,1,5>(FXMVECTOR V1, FXMVECTOR V2) { return vzipq_f32(V1,V2).val[0]; }
template<> inline XMVECTOR XMVectorPermute<2,6,3,7>(FXMVECTOR V1, FXMVECTOR V2) { return vzipq_f32(V1,V2).val[1]; }

template<> inline XMVECTOR XMVectorPermute<0,2,4,6>(FXMVECTOR V1, FXMVECTOR V2) { return vuzpq_f32(V1,V2).val[0]; }
template<> inline XMVECTOR XMVectorPermute<1,3,5,7>(FXMVECTOR V1, FXMVECTOR V2) { return vuzpq_f32(V1,V2).val[1]; }

template<> inline XMVECTOR XMVectorPermute<1,2,3,4>(FXMVECTOR V1, FXMVECTOR V2) { return vextq_f32(V1, V2, 1); }
template<> inline XMVECTOR XMVectorPermute<2,3,4,5>(FXMVECTOR V1, FXMVECTOR V2) { return vextq_f32(V1, V2, 2); }
template<> inline XMVECTOR XMVectorPermute<3,4,5,6>(FXMVECTOR V1, FXMVECTOR V2) { return vextq_f32(V1, V2, 3); }
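
Mapping this back to the question: _mm_shuffle_ps(a, b, 136) keeps the even lanes of each input (136 == _MM_SHUFFLE(2, 0, 2, 0)), which is exactly the <0,2,4,6> permute above. In raw NEON, a minimal sketch using the asker's x, y, z names:

    // z = { x[0], x[2], y[0], y[2] }
    // vuzpq_f32 de-interleaves two vectors; .val[0] holds the even lanes.
    float32x4_t z = vuzpq_f32(x, y).val[0];
    // On AArch64 the same thing can be written as: z = vuzp1q_f32(x, y);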
answered Oct 16 '22 by Chuck Walbourn