
Why isn't there a branch prediction failure penalty in this Rust code?

I've written this very simple Rust function:

fn iterate(nums: &Box<[i32]>) -> i32 {
    let mut total = 0;
    let len = nums.len();
    for i in 0..len {
        if nums[i] > 0 {
            total += nums[i];
        } else {
            total -= nums[i];
        }
    }

    total
}

I've written a basic benchmark that invokes the method with an ordered array and a shuffled one:

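// Imports assumed for the criterion and rand (0.7/0.8) crates; not shown in the original post.
use criterion::{criterion_group, criterion_main, Criterion};
use rand::{seq::SliceRandom, thread_rng};

fn criterion_benchmark(c: &mut Criterion) {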
    const SIZE: i32 = 1024 * 1024;

    let mut group = c.benchmark_group("Branch Prediction");

    // setup benchmarking for an ordered array
    let mut ordered_nums: Vec<i32> = vec![];
    for i in 0..SIZE {
        ordered_nums.push(i - SIZE/2);
    }
    let ordered_nums = ordered_nums.into_boxed_slice();
    group.bench_function("ordered", |b| b.iter(|| iterate(&ordered_nums)));

    // setup benchmarking for a shuffled array
    let mut shuffled_nums: Vec<i32> = vec![];
    for i in 0..SIZE {
        shuffled_nums.push(i - SIZE/2);
    }
    let mut rng = thread_rng();
    let mut shuffled_nums = shuffled_nums.into_boxed_slice();
    shuffled_nums.shuffle(&mut rng);
    group.bench_function("shuffled", |b| b.iter(|| iterate(&shuffled_nums)));

    group.finish();
}

criterion_group!(benches, criterion_benchmark);
criterion_main!(benches);

I'm surprised that the two benchmarks have almost exactly the same runtime, while a similar benchmark in Java shows a distinct difference between the two, presumably due to branch prediction failure in the shuffled case.

I've seen mention of conditional move instructions, but if I otool -tv the executable (I'm running on a Mac), I don't see any in the iterate method output.

Can anyone shed light on why there's no perceptible performance difference between the ordered and the unordered cases in Rust?

Dathan asked Jan 04 '20 06:01


1 Answer

Summary: LLVM was able to remove/hide the branch by using either the cmov instruction or a really clever combination of SIMD instructions.


I used Godbolt to view the full assembly (with -C opt-level=3). I will explain the important parts of the assembly below.

It starts like this:

        mov     r9, qword ptr [rdi + 8]         ; r9 = nums.len()
        test    r9, r9                          ; if len == 0
        je      .LBB0_1                         ;     goto LBB0_1
        mov     rdx, qword ptr [rdi]            ; rdx = base pointer (first element)
        cmp     r9, 7                           ; if len > 7
        ja      .LBB0_5                         ;     goto LBB0_5
        xor     eax, eax                        ; eax = 0
        xor     esi, esi                        ; esi = 0
        jmp     .LBB0_4                         ; goto LBB0_4

.LBB0_1:
        xor     eax, eax                        ; return 0
        ret

Here, the function differentiates between 3 different "states":

  • Slice is empty → return 0 immediately
  • Slice length is ≤ 7 → use standard sequential algorithm (LBB0_4)
  • Slice length is > 7 → use SIMD algorithm (LBB0_5)

So let's take a look at the two different kinds of algorithms!


Standard sequential algorithm

Remember that rsi (esi) and rax (eax) were set to 0 and that rdx is the base pointer to the data.

.LBB0_4:
        mov     ecx, dword ptr [rdx + 4*rsi]    ; ecx = nums[rsi]
        add     rsi, 1                          ; rsi += 1
        mov     edi, ecx                        ; edi = ecx
        neg     edi                             ; edi = -edi
        cmovl   edi, ecx                        ; if ecx > 0 { edi = ecx }
        add     eax, edi                        ; eax += edi
        cmp     r9, rsi                         ; if rsi != len
        jne     .LBB0_4                         ;     goto LBB0_4
        ret                                     ; return eax

This is a simple loop iterating over all elements of nums. In the loop's body there is a little trick though: from the original element ecx, a negated value is stored in edi. By using cmovl, edi is overwritten with the original value if that original value is positive. That means that edi ends up holding the absolute value of the original element (the one exception is i32::MIN, whose negation wraps back to itself). Then it is added to eax (which is returned in the end).
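As a rough Rust model of what this loop body computes, here is a minimal sketch. The names select_abs and sum_abs are made up for illustration, and nothing guarantees that the compiler lowers this particular source to a cmov; it only mirrors the logic of the assembly above.

```rust
/// Models `mov edi, ecx; neg edi; cmovl edi, ecx` for a single element.
/// Hypothetical helper, not part of the original post.
fn select_abs(x: i32) -> i32 {
    // neg edi: wraps for i32::MIN, just like the hardware instruction.
    let negated = x.wrapping_neg();
    // cmovl: keep the original value when it is positive. Both candidates
    // already exist, so the choice is a data dependency, not a jump.
    if x > 0 { x } else { negated }
}

/// Computes the same sum as `iterate` from the question (with wrapping
/// arithmetic, as in a release build), written over a plain slice.
fn sum_abs(nums: &[i32]) -> i32 {
    nums.iter()
        .fold(0i32, |total, &x| total.wrapping_add(select_abs(x)))
}
```

Calling sum_abs(&[-3, 0, 5, -7]) adds 3 + 0 + 5 + 7.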

So your if branch was hidden in the cmov instruction. As you can see in this benchmark, the time required to execute a cmov instruction is independent of the probability of the condition. It's a pretty amazing instruction!


SIMD algorithm

The SIMD version consists of quite a few instructions that I won't fully paste here. The main loop handles 16 integers at once!

        movdqu  xmm5, xmmword ptr [rdx + 4*rdi]
        movdqu  xmm3, xmmword ptr [rdx + 4*rdi + 16]
        movdqu  xmm0, xmmword ptr [rdx + 4*rdi + 32]
        movdqu  xmm1, xmmword ptr [rdx + 4*rdi + 48]

They are loaded from memory into the registers xmm0, xmm1, xmm3 and xmm5. Each of those registers contains four 32 bit values, but to follow along more easily, just imagine each register contains exactly one value. All following instructions operate on each value of those SIMD registers individually, so that mental model is fine! My explanation below is also written as if each xmm register contained only a single value.

The main trick is now in the following instructions (which handle xmm5):

        movdqa  xmm6, xmm5      ; xmm6 = xmm5 (make a copy)
        psrad   xmm6, 31        ; arithmetic right shift 31 bits (see below)
        paddd   xmm5, xmm6      ; xmm5 += xmm6
        pxor    xmm5, xmm6      ; xmm5 ^= xmm6

The arithmetic right shift fills the "empty high-order bits" (the ones "shifted in" on the left) with the value of the sign bit. (A logical right shift would fill them with zeroes instead.) By shifting by 31, we end up with only the sign bit in every position! So any non-negative number will turn into 32 zeroes and any negative number will turn into 32 ones. So xmm6 is now either 000...000 (if xmm5 is non-negative) or 111...111 (if xmm5 is negative).

Next this artificial xmm6 is added to xmm5. If xmm5 was positive, xmm6 is 0, so adding it won't change xmm5. If xmm5 was negative, however, we add 111...111 which is equivalent to subtracting 1. Finally, we xor xmm5 with xmm6. Again, if xmm5 was positive in the beginning, we xor with 000...000 which does not have an effect. If xmm5 was negative in the beginning we xor with 111...111, meaning we flip all the bits. So for both cases:

  • If the element was positive, we change nothing (the add and xor didn't have any effect)
  • If the element was negative, we subtracted 1 and flipped all bits. This is a two's complement negation!

So with these 4 instructions we calculated the absolute value of xmm5! Here again, there is no branch because of this bit-fiddling trick. And remember that xmm5 actually contains 4 integers, so it's quite speedy!
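The same trick can be written in scalar Rust, which may make it easier to check by hand. This is only a sketch: abs_bit_trick and abs_lanes are hypothetical names, and the [i32; 4] array merely stands in for one xmm register.

```rust
/// Scalar model of `psrad 31; paddd; pxor` for one 32-bit lane.
/// Hypothetical helper for illustration only.
fn abs_bit_trick(x: i32) -> i32 {
    // `>>` on i32 is an arithmetic shift: the result is 0 or -1 (all ones).
    let mask = x >> 31;
    // For negative x this subtracts 1 and then flips every bit;
    // for non-negative x both steps are no-ops.
    x.wrapping_add(mask) ^ mask
}

/// Applies the trick to four lanes, mimicking one xmm register.
fn abs_lanes(v: [i32; 4]) -> [i32; 4] {
    v.map(abs_bit_trick)
}
```

Like the hardware instructions, this wraps for i32::MIN, which has no positive counterpart in 32 bits.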

This absolute value is now added to an accumulator and the same is done with the three other xmm registers that contain values from the slice. (We won't discuss the remaining code in detail.)
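To make the overall shape of the vectorized loop more concrete, here is a sketch in plain Rust that processes 16 elements per iteration using four 4-lane accumulators. The function name sum_abs_chunked is made up, and the tail handling and the final reduction are assumptions on my part; the actual generated code for those parts is not shown in this answer.

```rust
/// Sketch of a 16-elements-per-iteration sum of absolute values.
/// Four `[i32; 4]` accumulators stand in for four xmm registers.
/// Hypothetical illustration, not the code LLVM actually emits.
fn sum_abs_chunked(nums: &[i32]) -> i32 {
    let mut acc = [[0i32; 4]; 4];
    let mut chunks = nums.chunks_exact(16);

    for chunk in &mut chunks {
        // Each group of four values models one xmm register.
        for (reg, lanes) in chunk.chunks_exact(4).enumerate() {
            for (lane, &x) in lanes.iter().enumerate() {
                let mask = x >> 31; // 0 or all ones
                let abs = x.wrapping_add(mask) ^ mask; // branch-free abs
                acc[reg][lane] = acc[reg][lane].wrapping_add(abs);
            }
        }
    }

    // Assumed tail handling: leftover elements are summed one by one.
    let mut total = chunks
        .remainder()
        .iter()
        .fold(0i32, |t, &x| t.wrapping_add(x.wrapping_abs()));

    // Assumed final reduction: add up all 16 accumulator lanes.
    for reg in acc.iter() {
        for &lane in reg.iter() {
            total = total.wrapping_add(lane);
        }
    }
    total
}
```

Keeping several independent accumulators means the additions within one iteration do not depend on each other, which is presumably why the compiler unrolls the loop this way.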


SIMD with AVX2

If we allow LLVM to emit AVX2 instructions (via -C target-feature=+avx2), it can even use the vpabsd instruction (the VEX-encoded form of pabsd) instead of the four "hacky" instructions:

vpabsd  ymm2, ymmword ptr [rdx + 4*rdi]

It loads the values directly from memory, calculates the absolute value and stores the result in ymm2 in one instruction! And remember that ymm registers are twice as large as xmm registers (fitting eight 32 bit values)!
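If you want to reach this instruction explicitly rather than rely on auto-vectorization, the std::arch intrinsic _mm256_abs_epi32 corresponds to vpabsd on ymm registers. The sketch below is an illustration under assumptions: abs8 and abs8_avx2 are hypothetical names, and the scalar fallback is my own addition for CPUs or targets without AVX2.

```rust
/// Absolute value of eight i32 values, using `vpabsd` when AVX2 is
/// available at runtime. Hypothetical helper for illustration.
fn abs8(input: [i32; 8]) -> [i32; 8] {
    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("avx2") {
            // SAFETY: AVX2 support was just checked at runtime.
            return unsafe { abs8_avx2(input) };
        }
    }
    // Assumed fallback for other targets or older CPUs.
    input.map(i32::wrapping_abs)
}

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn abs8_avx2(input: [i32; 8]) -> [i32; 8] {
    use std::arch::x86_64::{
        __m256i, _mm256_abs_epi32, _mm256_loadu_si256, _mm256_storeu_si256,
    };
    let mut out = [0i32; 8];
    // Unaligned load of eight 32-bit values into one ymm register.
    let v = _mm256_loadu_si256(input.as_ptr() as *const __m256i);
    // One instruction computes all eight absolute values.
    _mm256_storeu_si256(out.as_mut_ptr() as *mut __m256i, _mm256_abs_epi32(v));
    out
}
```

The runtime check keeps the binary usable on machines without AVX2, whereas -C target-feature=+avx2 makes the whole build assume it.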

Lukas Kalbertodt answered Sep 17 '22 09:09