I've written this very simple Rust function:
fn iterate(nums: &Box<[i32]>) -> i32 {
    let mut total = 0;
    let len = nums.len();
    for i in 0..len {
        if nums[i] > 0 {
            total += nums[i];
        } else {
            total -= nums[i];
        }
    }
    total
}
I've written a basic benchmark that invokes the method with an ordered array and a shuffled one:
use criterion::{criterion_group, criterion_main, Criterion};
use rand::seq::SliceRandom;
use rand::thread_rng;

fn criterion_benchmark(c: &mut Criterion) {
    const SIZE: i32 = 1024 * 1024;
    let mut group = c.benchmark_group("Branch Prediction");

    // set up benchmarking for an ordered array
    let mut ordered_nums: Vec<i32> = vec![];
    for i in 0..SIZE {
        ordered_nums.push(i - SIZE / 2);
    }
    let ordered_nums = ordered_nums.into_boxed_slice();
    group.bench_function("ordered", |b| b.iter(|| iterate(&ordered_nums)));

    // set up benchmarking for a shuffled array
    let mut shuffled_nums: Vec<i32> = vec![];
    for i in 0..SIZE {
        shuffled_nums.push(i - SIZE / 2);
    }
    let mut rng = thread_rng();
    let mut shuffled_nums = shuffled_nums.into_boxed_slice();
    shuffled_nums.shuffle(&mut rng);
    group.bench_function("shuffled", |b| b.iter(|| iterate(&shuffled_nums)));

    group.finish();
}

criterion_group!(benches, criterion_benchmark);
criterion_main!(benches);
I'm surprised that the two benchmarks have almost exactly the same runtime, while a similar benchmark in Java shows a distinct difference between the two, presumably due to branch prediction failure in the shuffled case.
I've seen mention of conditional move instructions, but if I run otool -tv on the executable (I'm running on a Mac), I don't see any in the output for the iterate function.
Can anyone shed light on why there's no perceptible performance difference between the ordered and the unordered cases in Rust?
Summary: LLVM was able to remove/hide the branch by using either the cmov instruction or a really clever combination of SIMD instructions.
I used Godbolt to view the full assembly (with -C opt-level=3). I will explain the important parts of the assembly below.
It starts like this:
    mov     r9, qword ptr [rdi + 8]     ; r9 = nums.len()
    test    r9, r9                      ; if len == 0
    je      .LBB0_1                     ;     goto LBB0_1
    mov     rdx, qword ptr [rdi]        ; rdx = base pointer (first element)
    cmp     r9, 7                       ; if len > 7
    ja      .LBB0_5                     ;     goto LBB0_5
    xor     eax, eax                    ; eax = 0
    xor     esi, esi                    ; esi = 0
    jmp     .LBB0_4                     ; goto LBB0_4
.LBB0_1:
    xor     eax, eax                    ; return 0
    ret
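Here, the function differentiates between 3 different "states":
- The slice is empty: return 0 immediately (LBB0_1)
- The slice length is at most 7: use the standard sequential algorithm (LBB0_4)
- The slice length is greater than 7: use the SIMD algorithm (LBB0_5)
So let's take a look at the two different kinds of algorithms!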
Remember that rsi (esi) and rax (eax) were set to 0 and that rdx is the base pointer to the data.
.LBB0_4:
    mov     ecx, dword ptr [rdx + 4*rsi]    ; ecx = nums[rsi]
    add     rsi, 1                          ; rsi += 1
    mov     edi, ecx                        ; edi = ecx
    neg     edi                             ; edi = -edi
    cmovl   edi, ecx                        ; if ecx > 0 { edi = ecx }
    add     eax, edi                        ; eax += edi
    cmp     r9, rsi                         ; if rsi != len
    jne     .LBB0_4                         ;     goto LBB0_4
    ret                                     ; return eax
This is a simple loop iterating over all elements of nums. In the loop's body there is a little trick though: from the original element ecx, a negated value is stored in edi. By using cmovl, edi is overwritten with the original value if that original value is positive. That means that edi will always turn out non-negative (i.e. contain the absolute value of the original element). Then it is added to eax (which is returned in the end).
So your if branch was hidden in the cmov instruction. As you can see in this benchmark, the time required to execute a cmov instruction is independent of the probability of the condition. It's a pretty amazing instruction!
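For readers who prefer Rust to assembly, the loop body can be mirrored roughly as follows. This is only a sketch of what the generated code computes, not the source the compiler was given. The name sum_abs_select is made up for illustration, and there is no guarantee that a particular compiler version turns the if expression into a cmov.

```rust
/// Hypothetical helper mirroring the scalar loop: negate first, then pick
/// the original value back if it was positive.
fn sum_abs_select(nums: &[i32]) -> i32 {
    let mut total: i32 = 0;
    for &x in nums {
        // Corresponds to `neg edi`: wrapping negation, like the instruction.
        let negated = x.wrapping_neg();
        // Corresponds to `cmovl edi, ecx`: choose between two values that
        // have both already been computed.
        let magnitude = if x > 0 { x } else { negated };
        // Corresponds to `add eax, edi`: wrapping add, as the `add`
        // instruction does.
        total = total.wrapping_add(magnitude);
    }
    total
}
```

Calling sum_abs_select(&[3, -4, 0]) yields 7. Both candidate values exist before the selection happens, which is why no jump is needed.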
The SIMD version consists of quite a few instructions that I won't fully paste here. The main loop handles 16 integers at once!
    movdqu  xmm5, xmmword ptr [rdx + 4*rdi]
    movdqu  xmm3, xmmword ptr [rdx + 4*rdi + 16]
    movdqu  xmm0, xmmword ptr [rdx + 4*rdi + 32]
    movdqu  xmm1, xmmword ptr [rdx + 4*rdi + 48]
They are loaded from memory into the registers xmm0, xmm1, xmm3 and xmm5. Each of those registers contains four 32 bit values, but to follow along more easily, just imagine each register contains exactly one value. All following instructions operate on each value of those SIMD registers individually, so that mental model is fine! My explanation below will also sound as if each xmm register contained only a single value.
The main trick is now in the following instructions (which handle xmm5):
    movdqa  xmm6, xmm5    ; xmm6 = xmm5 (make a copy)
    psrad   xmm6, 31      ; arithmetic right shift by 31 bits (see below)
    paddd   xmm5, xmm6    ; xmm5 += xmm6
    pxor    xmm5, xmm6    ; xmm5 ^= xmm6
The arithmetic right shift fills the "empty high-order bits" (the ones "shifted in" on the left) with the value of the sign bit. By shifting by 31, we end up with only the sign bit in every position! So any non-negative number will turn into 32 zeroes and any negative number will turn into 32 ones. So xmm6 is now either 000...000 (if xmm5 is non-negative) or 111...111 (if xmm5 is negative).
Next this artificial xmm6 is added to xmm5. If xmm5 was non-negative, xmm6 is 0, so adding it won't change xmm5. If xmm5 was negative, however, we add 111...111 which is equivalent to subtracting 1. Finally, we xor xmm5 with xmm6. Again, if xmm5 was non-negative in the beginning, we xor with 000...000 which does not have an effect. If xmm5 was negative in the beginning we xor with 111...111, meaning we flip all the bits. So for both cases:
- If the element was non-negative, nothing changed (the add and xor didn't have any effect)
- If the element was negative, we subtracted 1 and flipped all the bits, which is exactly a two's complement negation
So with these 4 instructions we calculated the absolute value of xmm5! Here again, there is no branch because of this bit-fiddling trick. And remember that xmm5 actually contains 4 integers, so it's quite speedy!
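The same shift, add and xor sequence can be written for a single value in Rust, which makes it easy to check by hand. This is a scalar sketch of the trick, not the vectorized code itself, and abs_bit_trick is a name invented for this example.

```rust
/// Hypothetical scalar version of the movdqa/psrad/paddd/pxor sequence.
fn abs_bit_trick(x: i32) -> i32 {
    // `>>` on a signed integer is an arithmetic shift in Rust, so the mask
    // is 0 for non-negative x and -1 (all ones) for negative x.
    let mask = x >> 31;
    // Adding -1 subtracts one; xor with all ones flips every bit.
    // Together that is a two's complement negation.
    x.wrapping_add(mask) ^ mask
}
```

For example, abs_bit_trick(-5) computes mask = -1, then -5 + -1 = -6, and flipping all bits of -6 gives 5. For i32::MIN the result is i32::MIN again, because its absolute value does not fit into an i32.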
This absolute value is now added to an accumulator and the same is done with the three other xmm registers that contain values from the slice. (We won't discuss the remaining code in detail.)
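To make the overall shape of the SIMD path more concrete, here is a plain Rust sketch that processes 16 elements per iteration and handles the leftovers with a scalar loop. It is an illustration under assumptions: the function name sum_abs_chunked is invented, and using one accumulator per register is a guess, since the part of the assembly that shows the accumulators is not reproduced here.

```rust
/// Hypothetical sketch: 16 elements per iteration, four 4-lane accumulators.
fn sum_abs_chunked(nums: &[i32]) -> i32 {
    // Branch-free absolute value (shift, add, xor).
    fn abs_lane(x: i32) -> i32 {
        let mask = x >> 31;
        x.wrapping_add(mask) ^ mask
    }

    // Four accumulators, each standing in for one xmm register (4 lanes).
    // ASSUMPTION: one accumulator per register.
    let mut acc = [[0i32; 4]; 4];
    let chunks = nums.chunks_exact(16);
    let tail = chunks.remainder();

    for chunk in chunks {
        for (reg, quad) in chunk.chunks_exact(4).enumerate() {
            for (lane, &x) in quad.iter().enumerate() {
                acc[reg][lane] = acc[reg][lane].wrapping_add(abs_lane(x));
            }
        }
    }

    // Horizontal reduction of all 16 lanes into one scalar.
    let mut total = acc
        .iter()
        .flatten()
        .fold(0i32, |sum, &lane| sum.wrapping_add(lane));

    // Leftover elements (fewer than 16) go through a plain scalar loop.
    for &x in tail {
        total = total.wrapping_add(abs_lane(x));
    }
    total
}
```

Because wrapping addition is associative and commutative, summing lane by lane and reducing at the end gives the same result as a single running total.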
If we allow LLVM to emit AVX2 instructions (via -C target-feature=+avx2), it can even use the vpabsd instruction instead of the four "hacky" instructions:
    vpabsd  ymm2, ymmword ptr [rdx + 4*rdi]
It loads the values directly from memory, calculates the absolute value and stores it in ymm2 in one instruction! And remember that ymm registers are twice as large as xmm registers (fitting eight 32 bit values)!
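Since the function boils down to a sum of absolute values, it can also be written so that the intent is explicit in the source. The following is a minimal sketch with an invented name, sum_abs_iter. Whether it compiles to vpabsd, to the shift/add/xor sequence or to something else depends on the compiler version and target features, so it is worth checking the assembly rather than assuming.

```rust
/// Hypothetical idiomatic variant: states the intent (sum of absolute
/// values) and leaves instruction selection to the compiler.
fn sum_abs_iter(nums: &[i32]) -> i32 {
    nums.iter()
        // wrapping_abs returns i32::MIN for i32::MIN instead of panicking.
        .map(|x| x.wrapping_abs())
        // wrapping_add avoids an overflow panic in debug builds.
        .fold(0i32, |sum, x| sum.wrapping_add(x))
}
```

The wrapping variants are used so that the behavior is the same in debug and release builds, even for large inputs whose sum overflows an i32.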