 

Can compilers (specifically rustc) really simplify triangle-summation to avoid a loop? How?

On page 322 of Programming Rust by Blandy and Orendorff is this claim:

...Rust...recognizes that there's a simpler way to sum the numbers from one to n: the sum is always equal to n * (n+1) / 2.

This is of course a fairly well-known equivalence, but how does the compiler recognize it? I'm guessing it's in an LLVM optimization pass, but is LLVM somehow deriving the equivalence from first principles, or does it just have some set of "common loop computations" that can be simplified to arithmetic operations?
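As a baseline for the claim, here is a minimal Rust sketch of the two computations the quote says are equivalent. The function names are illustrative only. `u64` is used so that `n * (n + 1)` cannot overflow for the modest values of `n` exercised here.

```rust
/// Sum 1..=n with an explicit loop.
fn triangle_loop(n: u64) -> u64 {
    let mut total = 0;
    for i in 1..=n {
        total += i;
    }
    total
}

/// Sum 1..=n with the closed form n * (n + 1) / 2.
/// One of n and n + 1 is always even, so the division is exact.
fn triangle_formula(n: u64) -> u64 {
    n * (n + 1) / 2
}
```

For example, both `triangle_loop(100)` and `triangle_formula(100)` return 5050.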

Kyle Strand asked Apr 22 '19 20:04



1 Answer

First of all, let's demonstrate that this actually happens.

Starting with this code:

pub fn sum(start: i32, end: i32) -> i32 {
    let mut result = 0;
    for i in start..end {
        result += i;
    }
    result
}

Compiling it in Release mode and looking at the emitted LLVM IR, we get:

; playground::sum
; Function Attrs: nounwind nonlazybind readnone uwtable
define i32 @_ZN10playground3sum17h41f12649b0533596E(i32 %start1, i32 %end) {
start:
    %0 = icmp slt i32 %start1, %end
    br i1 %0, label %bb5.preheader, label %bb6

bb5.preheader:                                    ; preds = %start
    %1 = xor i32 %start1, -1
    %2 = add i32 %1, %end
    %3 = add i32 %start1, 1
    %4 = mul i32 %2, %3
    %5 = zext i32 %2 to i33
    %6 = add i32 %end, -2
    %7 = sub i32 %6, %start1
    %8 = zext i32 %7 to i33
    %9 = mul i33 %5, %8
    %10 = lshr i33 %9, 1
    %11 = trunc i33 %10 to i32
    %12 = add i32 %4, %start1
    %13 = add i32 %12, %11
    br label %bb6

bb6:                                              ; preds = %bb5.preheader, %start
    %result.0.lcssa = phi i32 [ 0, %start ], [ %13, %bb5.preheader ]
    ret i32 %result.0.lcssa
}

We can indeed observe that there is no longer any loop.

This confirms the claim by Blandy and Orendorff.
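To see which closed form LLVM actually chose, the IR can be translated back into Rust by hand. The sketch below is my own reading of the IR, not code LLVM produces. It writes n for the trip count `end - start`, uses wrapping arithmetic to match the plain `add`/`mul` instructions, and stands in `u64` for the `i33` multiply. Algebraically the IR computes n * start + n * (n - 1) / 2, rearranged as (n - 1) * (start + 1) + start + (n - 1) * (n - 2) / 2, i.e. in terms of n - 1 (what LLVM calls the backedge-taken count). The product is widened before the `lshr` so that the halving happens before the top bit is truncated away; a plain 32-bit `x * y / 2` would give the wrong answer once the product wraps.

```rust
/// The original loop, with wrapping addition so that it matches
/// release-mode semantics (no overflow checks).
fn sum_loop(start: i32, end: i32) -> i32 {
    let mut result: i32 = 0;
    for i in start..end {
        result = result.wrapping_add(i);
    }
    result
}

/// Hand translation of the optimized IR; comments name the IR values.
fn sum_closed_form(start: i32, end: i32) -> i32 {
    // %0: the loop body never runs if start >= end.
    if start >= end {
        return 0;
    }
    // %1, %2: !start == -start - 1, so this is end - start - 1 == n - 1.
    let n_minus_1 = end.wrapping_add(!start);
    // %3, %4: (n - 1) * (start + 1)
    let a = n_minus_1.wrapping_mul(start.wrapping_add(1));
    // %6, %7: end - 2 - start == n - 2
    let n_minus_2 = end.wrapping_add(-2).wrapping_sub(start);
    // %5, %8, %9: zero-extend and multiply in a wider type (i33 in the IR).
    let product = (n_minus_1 as u32 as u64) * (n_minus_2 as u32 as u64);
    // %10, %11: halve, then truncate back to 32 bits.
    let half = (product >> 1) as u32 as i32;
    // %12, %13: a + start + half
    a.wrapping_add(start).wrapping_add(half)
}
```

For instance, `sum_closed_form(1, 101)` returns 5050, the same as the loop, and the two also agree on ranges whose sum wraps around `i32`.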


As for how this occurs, my understanding is that it all happens in ScalarEvolution.cpp in LLVM. Unfortunately, that file is a 12,000+ line monstrosity, so navigating it is a tad complicated; still, the head comment hints that we are in the right place, and points to the papers it used, which mention optimizing loops and closed-form functions¹:

 //===----------------------------------------------------------------------===//
 //
 // There are several good references for the techniques used in this analysis.
 //
 //  Chains of recurrences -- a method to expedite the evaluation
 //  of closed-form functions
 //  Olaf Bachmann, Paul S. Wang, Eugene V. Zima
 //
 //  On computational properties of chains of recurrences
 //  Eugene V. Zima
 //
 //  Symbolic Evaluation of Chains of Recurrences for Loop Optimization
 //  Robert A. van Engelen
 //
 //  Efficient Symbolic Analysis for Optimizing Compilers
 //  Robert A. van Engelen
 //
 //  Using the chains of recurrences algebra for data dependence testing and
 //  induction variable substitution
 //  MS Thesis, Johnie Birch
 //
 //===----------------------------------------------------------------------===//

According to this blog article by Krister Walfridsson, LLVM builds up chains of recurrences, which can be used to obtain a closed-form formula for each induction variable.

This is a mid-point between full reasoning and full hardcoding:

  • Pattern-matching is used to build the chains of recurrences, so LLVM may not recognize all ways of expressing a certain computation.
  • A large variety of formulas can be optimized, not only the triangle sum.
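To illustrate the chains-of-recurrences idea, here is a toy Rust model. The `Chrec` type is hypothetical and far simpler than LLVM's actual SCEV classes. A chain {c0, +, c1, +, ..., +, ck} describes a value that starts at c0 and, on each iteration, is incremented by the value of the next chain down. Stepping it one iteration at a time is what the loop does; its value after n iterations also has the closed form c0 * C(n, 0) + c1 * C(n, 1) + ... + ck * C(n, k), with binomial coefficients C(n, j). In the question's loop, `i` is {start, +, 1} and `result` is {0, +, start, +, 1}, so after n iterations `result` equals n * start + n * (n - 1) / 2. With start = 1 that is n + n * (n - 1) / 2 = n * (n + 1) / 2, the triangle sum.

```rust
/// Toy chain of recurrences {c0, +, c1, +, ..., +, ck} (hypothetical type).
struct Chrec {
    coeffs: Vec<i64>,
}

impl Chrec {
    /// Value at iteration `n`, computed by simulating the loop:
    /// each step adds the old c[j + 1] to c[j] (ascending order keeps
    /// c[j + 1] at its pre-step value when it is read).
    fn eval_by_stepping(&self, n: u64) -> i64 {
        let mut c = self.coeffs.clone();
        for _ in 0..n {
            for j in 0..c.len().saturating_sub(1) {
                c[j] += c[j + 1];
            }
        }
        c.first().copied().unwrap_or(0)
    }

    /// Value at iteration `n` via sum of c[j] * C(n, j): no loop over n.
    fn eval_closed_form(&self, n: u64) -> i64 {
        let n = n as i64;
        let mut total = 0i64;
        let mut binom = 1i64; // C(n, 0)
        for (j, &c) in self.coeffs.iter().enumerate() {
            total += c * binom;
            // C(n, j + 1) = C(n, j) * (n - j) / (j + 1); the division is exact,
            // and the value drops to 0 (and stays there) once j reaches n.
            let j = j as i64;
            binom = binom * (n - j) / (j + 1);
        }
        total
    }
}
```

For example, `Chrec { coeffs: vec![0, 1, 1] }` models `result` for the loop starting at 1, and both evaluators give 5050 at n = 100.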

The article also notes that the optimization may end up pessimizing the code: a loop that runs only a few iterations can be faster than the "optimized" version if the closed form requires more operations than the loop body.

¹ n * (n+1) / 2 is the closed-form function to compute the sum of the numbers in [0, n].

Matthieu M. answered Oct 18 '22 04:10
