I'm new to using intrinsics, but I wanted to write a function that takes a vector of 4 doubles and computes a > 1e-5 ? std::sqrt(a) : 0.0 for each element.
My first instinct was to write this as follows:
#include <immintrin.h>
__m256d f(__m256d a)
{
__m256d is_valid = a > _mm256_set1_pd(1e-5);
__m256d sqrt_val = _mm256_sqrt_pd(a);
return is_valid * sqrt_val;
}
which, according to gcc.godbolt.com, compiles to the following:
f(double __vector(4)):
vsqrtpd ymm1, ymm0
vcmpgtpd ymm0, ymm0, YMMWORD PTR .LC0[rip]
vmulpd ymm0, ymm1, ymm0
ret
.LC0:
.long 2296604913
.long 1055193269
.long 2296604913
.long 1055193269
.long 2296604913
.long 1055193269
.long 2296604913
.long 1055193269
But I'm worried about what will happen if sqrt_val contains a NaN. I don't think 0.0 * NaN will work. What are the best practices here?
Edit
After reading the comments from @ChrisCooper and @njuffa, I was linked to another Stack Overflow answer, so I will test the result for self-equality and then AND that mask with my result.
#include <immintrin.h>
__m256d f(__m256d a)
{
__m256d is_valid = a > _mm256_set1_pd(1e-5);
__m256d sqrt_val = _mm256_sqrt_pd(a);
__m256d result = is_valid * sqrt_val;
__m256d cmpeq = result == result;
return _mm256_and_pd(cmpeq, result);
}
which compiles to the following
f(double __vector(4)):
vsqrtpd ymm1, ymm0
vcmpgtpd ymm0, ymm0, YMMWORD PTR .LC0[rip]
vmulpd ymm0, ymm1, ymm0
vcmpeqpd ymm1, ymm0, ymm0
vandpd ymm0, ymm1, ymm0
ret
.LC0:
.long 2296604913
.long 1055193269
.long 2296604913
.long 1055193269
.long 2296604913
.long 1055193269
.long 2296604913
.long 1055193269
I haven't programmed with AVX intrinsics before, so I gathered information from the documentation to quickly put together the code below. It seems to work as desired for the one test case I provided.
The relevant observation is that the comparison instructions return a mask of all 1s (if the result is TRUE) or all 0s (if the result is FALSE). This mask can then be used to conditionally set the result of the square root to zero by ANDing the mask with the result from vsqrtpd. The binary representation of 0.0 in IEEE-754 double precision is all 0s, so a lane ANDed with an all-0s mask becomes exactly 0.0, even if the square root in that lane is a NaN.
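To make the masking step concrete, here is a minimal scalar sketch of what the AND does to a single 64-bit lane. It is plain C rather than AVX, and the helper names lane_mask, and_lane and check_masking are hypothetical, made up for this illustration. It also shows why multiplying by 0.0 would not get rid of a NaN, while ANDing with an all-0s mask does.

```c
#include <assert.h>
#include <stdint.h>
#include <string.h>

/* Hypothetical helper: the per-lane mask a vector compare would produce,
   all 1s for TRUE and all 0s for FALSE. */
static uint64_t lane_mask(int is_true)
{
    return is_true ? UINT64_MAX : 0;
}

/* Hypothetical helper: bitwise AND of a double with a 64-bit mask,
   emulating one lane of vandpd. */
static double and_lane(double x, uint64_t mask)
{
    uint64_t bits;
    memcpy(&bits, &x, sizeof bits);   /* view the double as raw bits */
    bits &= mask;
    memcpy(&x, &bits, sizeof x);      /* view the masked bits as a double */
    return x;
}

/* Usage sketch: masking versus multiplying when a lane holds a NaN. */
static void check_masking(void)
{
    volatile double zero = 0.0;
    double qnan = zero / zero;        /* a quiet NaN, produced at run time */
    double prod = 0.0 * qnan;         /* multiplying by 0.0 keeps the NaN */

    assert(and_lane(2.5, lane_mask(1)) == 2.5);   /* TRUE mask keeps the value */
    assert(and_lane(2.5, lane_mask(0)) == 0.0);   /* FALSE mask gives 0.0 */
    assert(and_lane(qnan, lane_mask(0)) == 0.0);  /* FALSE mask removes a NaN */
    assert(prod != prod);                         /* only a NaN differs from itself */
}
```

Calling check_masking() should pass all four assertions on an IEEE-754 platform.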
Not having used these intrinsics before, I found the comparison predicates tricky to use. From what I understand, here we want to use the ordered comparison to get the desired behavior with respect to NaNs (that is, a comparison with a NaN should result in FALSE), so the 'O' variant. We also do not want a NaN input to trigger an exception (that is, we want the comparison to be quiet in that case), so the 'Q' variant. This means we want to use _CMP_GT_OQ.
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <math.h>
#include <immintrin.h>
__m256d f (__m256d a)
{
double em5 = 1e-5;
__m256d v_em5 = _mm256_broadcast_sd (&em5);
__m256d v_sqrt = _mm256_sqrt_pd (a);
__m256d v_mask = _mm256_cmp_pd (a, v_em5, _CMP_GT_OQ);
__m256d v_res = _mm256_and_pd (v_sqrt, v_mask);
return v_res;
}
int main (void)
{
__m256d arg, res;
double args[4] = {2e-5, sqrt(-1.0), 1e-6, -1.0};
double ress [4] = {0};
memcpy (&arg, args, sizeof(arg));
res = f (arg);
memcpy (ress, &res, sizeof(res));
printf ("args = % 23.16e % 23.16e % 23.16e % 23.16e\n",
args[0], args[1], args[2], args[3]);
printf ("ress = % 23.16e % 23.16e % 23.16e % 23.16e\n",
ress[0], ress[1], ress[2], ress[3]);
return EXIT_SUCCESS;
}
I compiled the above program with the Intel C compiler; the output looks like this:
args = 2.0000000000000002e-005 -1.#IND000000000000e+000 9.9999999999999995e-007 -1.0000000000000000e+000
ress = 4.4721359549995798e-003 0.0000000000000000e+000 0.0000000000000000e+000 0.0000000000000000e+000
Here, -1.#IND000000000000e+000 is a specific QNaN called INDEFINITE.
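The effect of the ordered versus unordered predicate choice discussed above can also be checked one lane at a time in plain C. The sketch below is only a scalar stand-in, not AVX code, and the names gt_oq, nle_uq and check_predicates are hypothetical. It relies on the C99 macros isgreater and islessequal, which are quiet comparisons that return false when either operand is a NaN, and it contrasts _CMP_GT_OQ with an unordered predicate, _CMP_NLE_UQ, that would let a NaN input through.

```c
#include <assert.h>
#include <math.h>

/* Scalar stand-in for one lane of _CMP_GT_OQ: ordered, so it is FALSE
   when either operand is a NaN; isgreater() does not raise an invalid
   exception for quiet NaNs. */
static int gt_oq(double a, double b)
{
    return isgreater(a, b);
}

/* Scalar stand-in for one lane of _CMP_NLE_UQ ("not less-or-equal"):
   unordered, so it is TRUE when either operand is a NaN. */
static int nle_uq(double a, double b)
{
    return !islessequal(a, b);
}

/* Usage sketch: the four inputs of the test program above, compared
   against the 1e-5 threshold. */
static void check_predicates(void)
{
    assert( gt_oq(2e-5, 1e-5));   /* above the threshold: keep the sqrt */
    assert(!gt_oq(NAN,  1e-5));   /* NaN input: mask is FALSE, lane is zeroed */
    assert(!gt_oq(1e-6, 1e-5));   /* below the threshold */
    assert(!gt_oq(-1.0, 1e-5));   /* negative input */

    assert( nle_uq(NAN,  1e-5));  /* the unordered variant would keep the NaN */
    assert(!nle_uq(-1.0, 1e-5));
}
```

For ordinary numbers the two predicates agree; they differ only in the NaN lane, which is why the ordered variant is the one that matches a > 1e-5 ? sqrt(a) : 0.0.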