I was explaining to a friend that I expected non-tail-recursive functions in Scala to be slower than tail-recursive ones, so I decided to verify it. I wrote a good old factorial function both ways and attempted to compare the results. Here's the code:
import scala.annotation.tailrec

def main(args: Array[String]): Unit = {
  val N = 2000 // not too large, or else the plain recursion overflows the stack
  var spent1: Long = 0
  var spent2: Long = 0
  for (i <- 1 to 100) { // repeat to average the results
    val t0 = System.nanoTime
    factorial(N)
    val t1 = System.nanoTime
    tailRecFact(N)
    val t2 = System.nanoTime
    spent1 += t1 - t0
    spent2 += t2 - t1
  }
  println(spent1 / 1000000f) // nanoseconds to milliseconds
  println(spent2 / 1000000f)
}

@tailrec
def tailRecFact(n: BigInt, s: BigInt = 1): BigInt = if (n == 1) s else tailRecFact(n - 1, s * n)

def factorial(n: BigInt): BigInt = if (n == 1) 1 else n * factorial(n - 1)
The results confuse me. I get this kind of output:
578.2985
870.22125
This means the non-tail-recursive function is about 30% faster than the tail-recursive one, even though the number of operations is the same!
What would explain those results?
As a rule of thumb, tail-recursive functions are faster if they don't need to reverse the result before returning it, because reversing requires another iteration over the whole list. Tail-recursive functions are usually faster at reducing a list to a single value.
Tail recursion is generally better than non-tail recursion. Because no work is left after the recursive call, it is easier for the compiler to optimize the code. When a function is called, its return address is stored on the stack; with a tail call, storing a new return address is not needed.
Recursion can be slow. It is actually pretty difficult to write a recursive function that is faster and uses less memory than an iterative function completing the same task. The reason is that each recursive call requires the allocation of a new stack frame.
Recursion has a large amount of overhead compared to iteration. It is usually slower because every pending call must be kept on the stack so that control can return to its caller. Iteration does not involve any such overhead.
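To make the stack-frame argument concrete, here is a minimal sketch, assuming a simple running sum rather than the factorial from the question. The names `TailCallSketch`, `sumTail` and `sumLoop` are made up for illustration. With `@tailrec`, scalac either turns the self-call into a jump or rejects the method at compile time, so the tail-recursive version should behave like the hand-written loop next to it.

```scala
import scala.annotation.tailrec

object TailCallSketch {
  // Tail-recursive sum of 1..n: the recursive call is the last action,
  // so the compiler can replace it with a jump back to the start of the method.
  @tailrec
  def sumTail(n: Long, acc: Long = 0L): Long =
    if (n == 0) acc else sumTail(n - 1, acc + n)

  // Hand-written loop doing the same work, for comparison.
  def sumLoop(n: Long): Long = {
    var i = n
    var acc = 0L
    while (i > 0) {
      acc += i
      i -= 1
    }
    acc
  }
}
```

Because no new frame is pushed per step, `sumTail` can run a million iterations without a stack overflow, which a plain recursive version at that depth would likely not survive with default JVM stack settings.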
It's actually not where you would first look. The reason is that your tail-recursive method does more work in its multiplication. Try swapping the order of the operands n and s in the multiplication (n * s instead of s * n) and it will even out.
def tailRecFact(n: BigInt, s: BigInt): BigInt = if (n == 1) s else tailRecFact(n - 1, n * s)
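The operand-order effect can be checked in isolation, outside any recursion. The following is only a rough sketch: `OperandOrderSketch`, `timeMillis`, and the choice of 2000! and 1999 as operands are assumptions made for illustration, and whether any difference shows up at all may depend on the JVM version and its BigInteger implementation.

```scala
object OperandOrderSketch {
  // 2000! stands in for the large accumulator, 1999 for a small factor.
  val big: BigInt = (1 to 2000).map(BigInt(_)).product
  val small: BigInt = BigInt(1999)

  // Times `reps` evaluations of `op` and returns the elapsed milliseconds.
  // The bit lengths are summed so that every product is actually used.
  def timeMillis(reps: Int)(op: () => BigInt): Double = {
    var sink = 0L
    val start = System.nanoTime
    var i = 0
    while (i < reps) { sink += op().bitLength; i += 1 }
    val elapsed = (System.nanoTime - start) / 1e6
    if (sink < 0) println(sink) // never true; only keeps `sink` live
    elapsed
  }

  def main(args: Array[String]): Unit = {
    // Unmeasured warm-up pass so both orders run as compiled code.
    timeMillis(20000)(() => small * big)
    timeMillis(20000)(() => big * small)
    println(f"small * big: ${timeMillis(20000)(() => small * big)}%.3f ms")
    println(f"big * small: ${timeMillis(20000)(() => big * small)}%.3f ms")
  }
}
```

Both orders produce the same product, so any gap in the printed times comes from how the multiplication is carried out, not from the result.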
Moreover, most of the time in this sample is taken up by the BigInt operations, which dwarf the cost of the recursive call. If we switch these over to Ints (compiled to Java primitives), you can see how tail recursion (a goto) compares to method invocation. The Int result overflows for N = 2000, but only the call overhead matters here.
import scala.annotation.tailrec

object Test extends App {
  val N = 2000
  val t0 = System.nanoTime
  for (i <- 1 to 1000) {
    factorial(N)
  }
  val t1 = System.nanoTime
  for (i <- 1 to 1000) {
    tailRecFact(N, 1)
  }
  val t2 = System.nanoTime
  println((t1 - t0) / 1000000f) // nanoseconds to milliseconds
  println((t2 - t1) / 1000000f)

  def factorial(n: Int): Int = if (n == 1) 1 else n * factorial(n - 1)

  @tailrec
  final def tailRecFact(n: Int, s: Int): Int = if (n == 1) s else tailRecFact(n - 1, s * n)
}
95.16733
3.987605
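Timings like these are sensitive to JIT warm-up, since whichever loop runs first also pays for interpretation and compilation. A hedged sketch of the same Int comparison with an unmeasured warm-up phase is below. `WarmBenchSketch` and `bench` are hypothetical names, the iteration counts are arbitrary, and a dedicated harness such as JMH would give more trustworthy numbers.

```scala
import scala.annotation.tailrec

object WarmBenchSketch {
  def factorial(n: Int): Int = if (n == 1) 1 else n * factorial(n - 1)

  @tailrec
  def tailRecFact(n: Int, s: Int): Int = if (n == 1) s else tailRecFact(n - 1, s * n)

  // Runs `f` `warmup` times unmeasured, then times `reps` further runs.
  // Returns (elapsed milliseconds, sum of all results); the sum is returned
  // to make it less likely that the JIT drops the calls as dead code.
  def bench(warmup: Int, reps: Int)(f: () => Int): (Double, Int) = {
    var sink = 0
    var i = 0
    while (i < warmup) { sink += f(); i += 1 }
    val start = System.nanoTime
    i = 0
    while (i < reps) { sink += f(); i += 1 }
    val millis = (System.nanoTime - start) / 1e6
    (millis, sink)
  }

  def main(args: Array[String]): Unit = {
    val n = 2000
    val (plainMs, _) = bench(10000, 1000)(() => factorial(n))
    val (tailMs, _) = bench(10000, 1000)(() => tailRecFact(n, 1))
    println(f"plain recursion: $plainMs%.3f ms")
    println(f"tail recursion:  $tailMs%.3f ms")
  }
}
```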
For interest, here is the decompiled output:
public final scala.math.BigInt tailRecFact(scala.math.BigInt, scala.math.BigInt);
Code:
0: aload_1
1: iconst_1
2: invokestatic #16 // Method scala/runtime/BoxesRunTime.boxToInteger:(I)Ljava/lang/Integer;
5: invokestatic #20 // Method scala/runtime/BoxesRunTime.equalsNumObject:(Ljava/lang/Number;Ljava/lang/Object;)Z
8: ifeq 13
11: aload_2
12: areturn
13: aload_1
14: getstatic #26 // Field scala/math/BigInt$.MODULE$:Lscala/math/BigInt$;
17: iconst_1
18: invokevirtual #30 // Method scala/math/BigInt$.int2bigInt:(I)Lscala/math/BigInt;
21: invokevirtual #36 // Method scala/math/BigInt.$minus:(Lscala/math/BigInt;)Lscala/math/BigInt;
24: aload_1
25: aload_2
26: invokevirtual #39 // Method scala/math/BigInt.$times:(Lscala/math/BigInt;)Lscala/math/BigInt;
29: astore_2
30: astore_1
31: goto 0
public scala.math.BigInt factorial(scala.math.BigInt);
Code:
0: aload_1
1: iconst_1
2: invokestatic #16 // Method scala/runtime/BoxesRunTime.boxToInteger:(I)Ljava/lang/Integer;
5: invokestatic #20 // Method scala/runtime/BoxesRunTime.equalsNumObject:(Ljava/lang/Number;Ljava/lang/Object;)Z
8: ifeq 21
11: getstatic #26 // Field scala/math/BigInt$.MODULE$:Lscala/math/BigInt$;
14: iconst_1
15: invokevirtual #30 // Method scala/math/BigInt$.int2bigInt:(I)Lscala/math/BigInt;
18: goto 40
21: aload_1
22: aload_0
23: aload_1
24: getstatic #26 // Field scala/math/BigInt$.MODULE$:Lscala/math/BigInt$;
27: iconst_1
28: invokevirtual #30 // Method scala/math/BigInt$.int2bigInt:(I)Lscala/math/BigInt;
31: invokevirtual #36 // Method scala/math/BigInt.$minus:(Lscala/math/BigInt;)Lscala/math/BigInt;
34: invokevirtual #47 // Method factorial:(Lscala/math/BigInt;)Lscala/math/BigInt;
37: invokevirtual #39 // Method scala/math/BigInt.$times:(Lscala/math/BigInt;)Lscala/math/BigInt;
40: areturn
In addition to the problem shown by @monkjack (i.e. multiplying small * big is faster than big * small, which does account for the greater chunk of the difference), your algorithm is different in each case, so they're not really comparable.
In the tail-recursive version you're multiplying big-to-small:
((n * (n-1)) * (n-2)) * ... * 2 * 1
In the non-tail recursive version you're multiplying small-to-big:
n * ((n-1) * ((n-2) * (... * (2 * 1))))
If you alter the tail-recursive version so it multiplies small-to-big:
def tailRecFact2(n: BigInt) = {
  def loop(x: BigInt, out: BigInt): BigInt =
    if (x > n) out else loop(x + 1, x * out)
  loop(1, 1)
}
then tail-recursion is about 20% faster than normal-recursion, rather than 10% slower as it is if you just make monkjack's correction. This is because multiplying together small BigInts is faster than multiplying large ones.
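Before comparing timings, it is worth confirming that the variants compute the same value and differ only in multiplication order. The sketch below simply gathers the three definitions from this thread into one object; the name `FactorialVariants` is made up for illustration, and `@tailrec` has been added to the inner `loop` as an assumption about how it would normally be written.

```scala
import scala.annotation.tailrec

object FactorialVariants {
  // Plain recursion: the products are formed on the way back up, small-to-big.
  def factorial(n: BigInt): BigInt = if (n == 1) 1 else n * factorial(n - 1)

  // Accumulator version from the question: multiplies big-to-small.
  @tailrec
  def tailRecFact(n: BigInt, s: BigInt = 1): BigInt =
    if (n == 1) s else tailRecFact(n - 1, s * n)

  // Accumulator version that counts upward: multiplies small-to-big.
  def tailRecFact2(n: BigInt): BigInt = {
    @tailrec
    def loop(x: BigInt, out: BigInt): BigInt =
      if (x > n) out else loop(x + 1, x * out)
    loop(1, 1)
  }
}
```

Calling all three with the same argument, for example `FactorialVariants.factorial(20)`, `FactorialVariants.tailRecFact(20)` and `FactorialVariants.tailRecFact2(20)`, should give identical results, so any timing gap between them comes from the order of the multiplications and the call mechanism.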