
Inform the C or C++ compiler that a loop length is a multiple of 8

Tags: c++, c, gcc, clang

I want to write the following function in C++ (compiled with GCC 11.1 using -O3 -mavx -std=c++17):

#include <cstdint> // int64_t

void f(float * __restrict__ a, float * __restrict__ b, float * __restrict__ c, int64_t n) {
    for (int64_t i = 0; i != n; ++i) {
        a[i] = b[i] + c[i];
    }
}

This generates about 60 lines of assembly, many of which deal with the case where n is not a multiple of 8. https://godbolt.org/z/61MYPG7an

I know that n is always a multiple of 8. One way I could change this code is to replace for (int64_t i = 0; i != n; ++i) with for (int64_t i = 0; i != (n / 8 * 8); ++i). This generates only about 20 assembly instructions. https://godbolt.org/z/vhvdKMfE9

However, on line 5 of the second godbolt link, there is an instruction that zeroes the lowest three bits of n. If there were a way to inform the compiler that n is always a multiple of 8, this instruction could be omitted with no change in behavior. Does anyone know of a way to do this in any C or C++ compiler (especially GCC or Clang)? In my case this doesn't actually matter, but I'm interested and not sure where to look.

Henry Heffan asked Aug 20 '21 03:08



1 Answer

Declare the assumption with __builtin_unreachable:

void f(float *__restrict__ a, float *__restrict__ b, float *__restrict__ c, int64_t n) {
    if(n % 8 != 0) __builtin_unreachable(); // control flow cannot reach this branch so the condition is not necessary and is optimized out
    for (int64_t i = 0; i != n; ++i) { // if control flow reaches this point n is a multiple of 8
        a[i] = b[i] + c[i];
    }
}

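Because the compiler may now assume that n is a multiple of 8, it can omit the code that handles leftover elements, which produces much shorter code.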
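If the same hint is needed in several places, it can be wrapped in a small macro. The following is only a sketch, assuming GCC or Clang: the names ASSUME and f_assumed are made up for illustration and are not part of the answer above. Without NDEBUG the macro checks the condition with assert, so a wrong assumption is caught during testing; with NDEBUG it becomes an optimizer hint (Clang's __builtin_assume, or the __builtin_unreachable pattern on GCC). Whether a particular compiler version actually exploits the hint is best verified in the generated assembly.

```cpp
#include <cassert>
#include <cstdint>

// ASSUME is a hypothetical helper macro, not a standard or compiler-provided name.
#if !defined(NDEBUG)
   // Debug builds: verify the assumption at run time.
#  define ASSUME(cond) assert(cond)
#elif defined(__clang__)
   // Clang provides a dedicated builtin for optimizer assumptions.
#  define ASSUME(cond) __builtin_assume(cond)
#elif defined(__GNUC__)
   // GCC: mark the violating branch as unreachable, as in the answer above.
#  define ASSUME(cond) do { if (!(cond)) __builtin_unreachable(); } while (0)
#else
   // Unknown compiler: no hint, the code is still correct.
#  define ASSUME(cond) ((void)0)
#endif

// f_assumed is a hypothetical name; the body matches the function in the question.
void f_assumed(float* __restrict__ a, const float* __restrict__ b,
               const float* __restrict__ c, std::int64_t n) {
    ASSUME(n % 8 == 0); // the caller promises that n is a multiple of 8
    for (std::int64_t i = 0; i != n; ++i) {
        a[i] = b[i] + c[i];
    }
}
```

Calling f_assumed with an n that is not a multiple of 8 in an NDEBUG build is undefined behavior, which is exactly what gives the optimizer the freedom to drop the remainder handling.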

HTNW answered Oct 18 '22 21:10
