 

Scala, tail recursion vs. non-tail recursion: why is tail recursion slower?

I was explaining to a friend that I expected non-tail-recursive functions in Scala to be slower than tail-recursive ones, so I decided to verify it. I wrote a good old factorial function both ways and compared the results. Here's the code:

import scala.annotation.tailrec

object Factorials {
  def main(args: Array[String]): Unit = {
    val N = 2000 // not too large, or the non-tail-recursive version overflows the stack
    var spent1: Long = 0
    var spent2: Long = 0
    for (i <- 1 to 100) { // repeat to average out the results
      val t0 = System.nanoTime
      factorial(N)
      val t1 = System.nanoTime
      tailRecFact(N)
      val t2 = System.nanoTime
      spent1 += t1 - t0
      spent2 += t2 - t1
    }
    println(spent1 / 1000000f) // total milliseconds over 100 runs, non-tail-recursive
    println(spent2 / 1000000f) // total milliseconds over 100 runs, tail-recursive
  }

  @tailrec
  def tailRecFact(n: BigInt, s: BigInt = 1): BigInt = if (n == 1) s else tailRecFact(n - 1, s * n)

  def factorial(n: BigInt): BigInt = if (n == 1) 1 else n * factorial(n - 1)
}

The results confuse me. I get output like this:

578.2985

870.22125

That means the non-tail-recursive function is about 30% faster than the tail-recursive one, even though both perform the same number of operations!

What would explain those results?

asked Oct 09 '13 09:10 by marc-antoine


People also ask

Is tail recursion faster than non-tail?

As a rule of thumb, tail-recursive functions are faster if they don't need to reverse the result before returning it, because reversing requires another pass over the whole list. Tail-recursive functions are usually faster when reducing a list to a single value, such as a sum.
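The rule of thumb above is about building lists, which the factorial question does not do, but a small sketch makes the "reverse" cost concrete. The object and method names (ListBuilding, doubleAll, doubleAllTail, sumTail) are hypothetical, chosen only for illustration: the tail-recursive builder collects its output back to front and must reverse it, while the reducer needs no extra pass.

```scala
import scala.annotation.tailrec

object ListBuilding {
  // Non-tail-recursive: the cons (::) happens after the recursive call returns,
  // so the output is already in order, but each element costs a stack frame.
  def doubleAll(xs: List[Int]): List[Int] = xs match {
    case Nil    => Nil
    case h :: t => (h * 2) :: doubleAll(t)
  }

  // Tail-recursive: the accumulator is built back to front,
  // so one extra pass (reverse) is needed before returning.
  def doubleAllTail(xs: List[Int]): List[Int] = {
    @tailrec
    def loop(rest: List[Int], acc: List[Int]): List[Int] = rest match {
      case Nil    => acc.reverse
      case h :: t => loop(t, (h * 2) :: acc)
    }
    loop(xs, Nil)
  }

  // Reducing a list to a single value: tail-recursive and no reverse needed.
  @tailrec
  def sumTail(xs: List[Int], acc: Long = 0L): Long = xs match {
    case Nil    => acc
    case h :: t => sumTail(t, acc + h)
  }
}
```

Both builders return the same list; the tail-recursive one trades stack frames for the extra reverse pass.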

Why tail recursion is better than non-tail recursion?

Tail recursion is generally better than non-tail recursion. Because no work remains after the recursive call, it is easier for the compiler to optimize the code. Normally, each function call pushes a new frame (holding the return address and local variables) onto the stack. With tail recursion, the current frame can be reused, so no new frame needs to be pushed.
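In Scala specifically, the JVM does not eliminate tail calls by itself; scalac rewrites a self-recursive tail call into a jump back to the top of the method, but only when the method cannot be overridden (it is final, private, or defined in an object). The @tailrec annotation makes compilation fail if that rewrite is impossible. A minimal sketch, where the names TailCallDemo, sumTo and sumToTail are hypothetical:

```scala
import scala.annotation.tailrec

object TailCallDemo {
  // Not a tail call: the addition runs after the recursive call returns,
  // so every level keeps its own stack frame. Putting @tailrec on this
  // method would be rejected by the compiler.
  def sumTo(n: Long): Long = if (n == 0) 0 else n + sumTo(n - 1)

  // Tail call: nothing is left to do after the call, so scalac compiles
  // it into a loop (a goto back to the start of the method).
  @tailrec
  def sumToTail(n: Long, acc: Long = 0L): Long =
    if (n == 0) acc else sumToTail(n - 1, acc + n)

  def main(args: Array[String]): Unit = {
    println(sumToTail(1000000L)) // runs in constant stack space
    // sumTo(1000000L) would typically throw StackOverflowError with the default JVM stack size
  }
}
```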

Why are recursive functions slower?

Recursion can be slow. It is actually quite difficult to write a recursive function that beats an iterative function performing the same task in speed or memory use. The reason recursion is slow is that each call requires allocating a new stack frame.

Why is recursion often slower than iteration?

Recursion has much more overhead than iteration. It is usually much slower because every pending call must keep its frame on the stack so that control can return to the caller. Iteration involves no such overhead.
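For comparison with the question's two recursive versions, here is a sketch of a plain iterative factorial (IterativeFactorial and factorialLoop are hypothetical names). It uses a single stack frame for any n and, like the non-tail-recursive version, multiplies the running product by the factors in small-to-big order.

```scala
object IterativeFactorial {
  // Plain while loop: one stack frame regardless of how large n is.
  def factorialLoop(n: Int): BigInt = {
    var acc = BigInt(1)
    var k = 2
    while (k <= n) {
      acc = acc * BigInt(k) // factors applied in increasing order: 2, 3, ..., n
      k += 1
    }
    acc
  }
}
```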


2 Answers

The cause is not where you would first look. In your tail-recursive method, the multiplication does more work: s * n multiplies the big accumulator by the small n, whereas the non-tail version computes n * factorial(n - 1), i.e. small times big. Swap the operands in the recursive call and the times even out:

def tailRecFact(n: BigInt, s: BigInt): BigInt = if (n == 1) s else tailRecFact(n - 1, n * s)
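To check the operand-order claim yourself, a rough sketch like the following times both orders on the same input. The names OperandOrder, accLeft, accRight and timeMs are hypothetical, and System.nanoTime timing without warm-up is crude (a harness such as JMH gives far more reliable numbers), so no particular output is implied here.

```scala
import scala.annotation.tailrec

object OperandOrder {
  // Accumulator on the left: big * small, as in the question's tailRecFact
  @tailrec
  def accLeft(n: BigInt, s: BigInt = BigInt(1)): BigInt =
    if (n == 1) s else accLeft(n - 1, s * n)

  // Accumulator on the right: small * big, as in the swapped version above
  @tailrec
  def accRight(n: BigInt, s: BigInt = BigInt(1)): BigInt =
    if (n == 1) s else accRight(n - 1, n * s)

  // Crude wall-clock timing of `reps` evaluations of `body`, in milliseconds
  def timeMs(reps: Int)(body: => Unit): Double = {
    val t0 = System.nanoTime()
    var i = 0
    while (i < reps) { body; i += 1 }
    (System.nanoTime() - t0) / 1e6
  }

  def main(args: Array[String]): Unit = {
    val n = BigInt(2000)
    // Both orders compute exactly the same value; only the operand order differs
    require(accLeft(n) == accRight(n))
    val left = timeMs(100) { accLeft(n); () }
    val right = timeMs(100) { accRight(n); () }
    println(f"s * n: $left%.1f ms, n * s: $right%.1f ms")
  }
}
```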

Moreover, most of the time in this sample is spent in the BigInt operations, which dwarf the cost of the recursive call itself. If we switch to Ints (compiled to Java primitives), you can see how tail recursion (a goto) compares to method invocation. (With Int, 2000! overflows, so the results are meaningless as factorials, but that doesn't matter when comparing call overhead.)

import scala.annotation.tailrec

object Test extends App {

  val N = 2000

  val t0 = System.nanoTime()
  for ( i <- 1 to 1000 ) {
    factorial(N)
  }
  val t1 = System.nanoTime
  for ( i <- 1 to 1000 ) {
    tailRecFact(N, 1)
  }
  val t2 = System.nanoTime

  println((t1 - t0) / 1000000f) // get milliseconds
  println((t2 - t1) / 1000000f)

  def factorial(n: Int): Int = if (n == 1) 1 else n * factorial(n - 1)

  @tailrec
  final def tailRecFact(n: Int, s: Int): Int = if (n == 1) s else tailRecFact(n - 1, s * n)
}

95.16733
3.987605

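For interest, here is the disassembled bytecode of the two BigInt methods. Note that tailRecFact ends in goto 0 instead of calling itself, while factorial makes a real recursive invokevirtual call: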

  public final scala.math.BigInt tailRecFact(scala.math.BigInt, scala.math.BigInt);
    Code:
       0: aload_1       
       1: iconst_1      
       2: invokestatic  #16                 // Method scala/runtime/BoxesRunTime.boxToInteger:(I)Ljava/lang/Integer;
       5: invokestatic  #20                 // Method scala/runtime/BoxesRunTime.equalsNumObject:(Ljava/lang/Number;Ljava/lang/Object;)Z
       8: ifeq          13
      11: aload_2       
      12: areturn       
      13: aload_1       
      14: getstatic     #26                 // Field scala/math/BigInt$.MODULE$:Lscala/math/BigInt$;
      17: iconst_1      
      18: invokevirtual #30                 // Method scala/math/BigInt$.int2bigInt:(I)Lscala/math/BigInt;
      21: invokevirtual #36                 // Method scala/math/BigInt.$minus:(Lscala/math/BigInt;)Lscala/math/BigInt;
      24: aload_1       
      25: aload_2       
      26: invokevirtual #39                 // Method scala/math/BigInt.$times:(Lscala/math/BigInt;)Lscala/math/BigInt;
      29: astore_2      
      30: astore_1      
      31: goto          0

  public scala.math.BigInt factorial(scala.math.BigInt);
    Code:
       0: aload_1       
       1: iconst_1      
       2: invokestatic  #16                 // Method scala/runtime/BoxesRunTime.boxToInteger:(I)Ljava/lang/Integer;
       5: invokestatic  #20                 // Method scala/runtime/BoxesRunTime.equalsNumObject:(Ljava/lang/Number;Ljava/lang/Object;)Z
       8: ifeq          21
      11: getstatic     #26                 // Field scala/math/BigInt$.MODULE$:Lscala/math/BigInt$;
      14: iconst_1      
      15: invokevirtual #30                 // Method scala/math/BigInt$.int2bigInt:(I)Lscala/math/BigInt;
      18: goto          40
      21: aload_1       
      22: aload_0       
      23: aload_1       
      24: getstatic     #26                 // Field scala/math/BigInt$.MODULE$:Lscala/math/BigInt$;
      27: iconst_1      
      28: invokevirtual #30                 // Method scala/math/BigInt$.int2bigInt:(I)Lscala/math/BigInt;
      31: invokevirtual #36                 // Method scala/math/BigInt.$minus:(Lscala/math/BigInt;)Lscala/math/BigInt;
      34: invokevirtual #47                 // Method factorial:(Lscala/math/BigInt;)Lscala/math/BigInt;
      37: invokevirtual #39                 // Method scala/math/BigInt.$times:(Lscala/math/BigInt;)Lscala/math/BigInt;
      40: areturn   
answered Oct 22 '22 08:10 by sksamuel


In addition to the problem shown by @monkjack (i.e. multiplying small * big is faster than big * small, which accounts for the larger part of the difference), your algorithm differs between the two cases, so they're not really comparable.

In the tail-recursive version you're multiplying big-to-small:

((n * (n-1)) * (n-2)) * ... * 2 * 1

In the non-tail-recursive version you're multiplying small-to-big:

n * ((n-1) * ((n-2) * (... * (2 * 1))))

If you alter the tail-recursive version so it multiplies small-to-big:

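import scala.annotation.tailrec

def tailRecFact2(n: BigInt): BigInt = {
  // x counts up from 1, so each step multiplies a small x into the growing product
  @tailrec
  def loop(x: BigInt, out: BigInt): BigInt =
    if (x > n) out else loop(x + 1, x * out)
  loop(1, 1)
}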

then tail recursion is about 20% faster than normal recursion, rather than 10% slower as it is with only monkjack's correction. This is because multiplying small BigInts together is faster than multiplying large ones.
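One way to see why the order matters without timing anything: after each multiplication, record how many bits the running product has. Since each step multiplies the product by a single small factor, the cost of that step grows roughly with the product's size, so the total bit count is a rough proxy for the work done. This sketch (AccumulatorGrowth and totalBits are hypothetical names) compares the big-to-small order of the question's tailRecFact with the small-to-big order of factorial and tailRecFact2; both end at the same n!, but the big-to-small product is at least as large at every intermediate step.

```scala
object AccumulatorGrowth {
  // Sum of the bit lengths of the running product after each multiplication,
  // a rough proxy for how much BigInt work the multiplications do.
  def totalBits(factors: Seq[Int]): Long = {
    var acc = BigInt(1)
    var bits = 0L
    for (f <- factors) {
      acc = acc * BigInt(f)
      bits += acc.bitLength
    }
    bits
  }

  def main(args: Array[String]): Unit = {
    val n = 2000
    val bigToSmall = totalBits(n to 1 by -1) // order of the question's tailRecFact
    val smallToBig = totalBits(1 to n)       // order of factorial and tailRecFact2
    println(s"big-to-small: $bigToSmall bits, small-to-big: $smallToBig bits")
  }
}
```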

answered Oct 22 '22 08:10 by Luigi Plinge