 

Why is Clojure much faster than Scala on a recursive add function?


A friend gave me this code snippet in Clojure:

```clojure
(defn sum [coll acc]
  (if (empty? coll)
    acc
    (recur (rest coll) (+ (first coll) acc))))

(time (sum (range 1 9999999) 0))
```

and asked me how it fares against a similar Scala implementation.

The Scala code I've written looks like this:

```scala
def from(n: Int): Stream[Int] = Stream.cons(n, from(n + 1))

val ints = from(1).take(9999998)

def add(a: Stream[Int], b: Long): Long = {
  if (a.isEmpty) b else add(a.tail, b + a.head)
}

val t1 = System.currentTimeMillis()
println(add(ints, 0))
val t2 = System.currentTimeMillis()
println((t2 - t1).asInstanceOf[Float] + " msecs")
```

Bottom line: the Clojure code runs in about 1.8 seconds on my machine and uses less than 5MB of heap, while the Scala code runs in about 12 seconds, and 512MB of heap isn't enough (it only finishes the computation if I set the heap to 1GB).

So I'm wondering why Clojure is so much faster and slimmer in this particular case. Do you have a Scala implementation that behaves similarly in terms of speed and memory usage?

Please refrain from religious remarks; my interest lies primarily in finding out what makes Clojure so fast in this case, and whether there's a faster implementation of the algorithm in Scala. Thanks.

asked Aug 31 '09 at 19:08 by Li Lo



2 Answers

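First, Scala only optimises tail calls if you invoke it with -optimise. Edit: It seems Scala will always optimise self-recursive tail calls when it can, even without -optimise; it can do so when the method cannot be overridden (a local function, or a final, private, or object method).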
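As a minimal sketch of the corrected claim, assuming a plain compile with no special flags: a self-recursive call in tail position, inside a method that cannot be overridden, is turned into a loop, so summing ten million integers does not overflow the stack. `TailCallDemo` and `sumRange` are hypothetical names.

```scala
import scala.annotation.tailrec

object TailCallDemo {
  // Sums the integers in [from, to] with the recursive call in tail position.
  // Because this is an object member (not overridable), scalac compiles it into
  // a loop; @tailrec only makes the compiler report an error if it cannot.
  @tailrec
  def sumRange(from: Int, to: Int, acc: Long): Long =
    if (from > to) acc else sumRange(from + 1, to, acc + from)

  def main(args: Array[String]): Unit =
    println(sumRange(1, 9999999, 0L)) // 49999995000000
}
```

Without the loop rewrite, ten million nested calls would exhaust a default JVM thread stack.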

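Second, Stream and Range are two very different things. A Range has a beginning and an end, and its projection holds just a counter and the end. A Stream is a list that is computed on demand and memoized: every element, once computed, is kept. Since you are adding up all of ints, you compute, and therefore allocate, the whole Stream; and because the val ints still refers to its head, none of those cells can be garbage-collected while add runs.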
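A small sketch of the memoization point (the `computed` counter and the `sumTwice` helper are hypothetical instrumentation, not part of the question's code): while something holds the head of a Stream, every cell it has computed stays reachable, so a second traversal recomputes nothing. `Stream` is deprecated since Scala 2.13 in favour of `LazyList`, which memoizes the same way; `Stream` is kept here to match the question.

```scala
object StreamMemoDemo {
  // Hypothetical instrumentation: counts calls to `from`, i.e. cells created.
  var computed = 0

  // Same shape as the question's `from`; the tail is passed by name, so it is lazy.
  def from(n: Int): Stream[Int] = {
    computed += 1
    Stream.cons(n, from(n + 1))
  }

  // Sums the same Stream twice and reports how many cells each pass computed.
  def sumTwice(count: Int): (Long, Int, Int) = {
    computed = 0
    val s = from(1).take(count) // `s` keeps a reference to the head, like `ints`
    val total = s.foldLeft(0L)(_ + _)
    val firstPass = computed
    s.foldLeft(0L)(_ + _)       // second traversal reuses the memoized cells
    (total, firstPass, computed - firstPass)
  }

  def main(args: Array[String]): Unit = {
    val (total, firstPass, secondPass) = sumTwice(1000)
    println(s"total=$total firstPass=$firstPass secondPass=$secondPass")
  }
}
```

The second pass computes zero new cells, which is exactly why the whole 9999998-element Stream ends up resident in memory.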

A closer equivalent would be:

```scala
import scala.annotation.tailrec

def add(r: Range) = {
  @tailrec
  def f(i: Iterator[Int], acc: Long): Long =
    if (i.hasNext) f(i, acc + i.next()) else acc

  f(r.iterator, 0)
}

def time(f: => Unit): Unit = {
  val t1 = System.currentTimeMillis()
  f
  val t2 = System.currentTimeMillis()
  println((t2 - t1).toFloat + " msecs")
}
```

Normal run:

```
scala> time(println(add(1 to 9999999)))
49999995000000
563.0 msecs
```

On Scala 2.7 you need "elements" instead of "iterator", and there's no "tailrec" annotation (that annotation is only used to complain if a definition can't be optimized with tail recursion), so you'll need to strip "@tailrec" as well as "import scala.annotation.tailrec" from the code.

Also, some considerations on alternate implementations. The simplest:

```
scala> time(println(1 to 9999999 reduceLeft (_+_)))
-2014260032
640.0 msecs
```

On average, over multiple runs here, it is slower. It's also incorrect, because it accumulates in an Int, which overflows. A correct one:
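The wrong result above is plain 32-bit wraparound. A quick worked check (the object name is hypothetical): the exact sum is n(n + 1)/2, and since Int addition wraps modulo 2^32, an Int accumulator ends up holding only the low 32 bits of that value.

```scala
object OverflowCheck {
  val n: Long = 9999999L
  // Closed form 1 + 2 + ... + n = n(n + 1) / 2, computed in Long (no overflow here).
  val exact: Long = n * (n + 1) / 2
  // Long.toInt keeps the low 32 bits, which is what an Int accumulator holds.
  val wrapped: Int = exact.toInt

  def main(args: Array[String]): Unit = {
    println(exact)   // 49999995000000
    println(wrapped) // -2014260032
  }
}
```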

```
scala> time(println((1 to 9999999 foldLeft 0L)(_+_)))
49999995000000
797.0 msecs
```

That's slower still, running here. I honestly wouldn't have expected it to be slower, but on each iteration it calls the function being passed. Once you consider that, it's a pretty good time compared to the recursive version.
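The per-element call can be made visible by counting how often `foldLeft` invokes the function it receives. A small sketch; `FoldCalls`, `countCalls` and `step` are hypothetical names, and `step` plays the role of `_+_` above.

```scala
object FoldCalls {
  // Returns the sum of 1..n and how many times foldLeft called `step`.
  def countCalls(n: Int): (Long, Int) = {
    var calls = 0
    val step: (Long, Int) => Long = (acc, x) => {
      calls += 1 // one invocation of the function value per element
      acc + x
    }
    val total = (1 to n).foldLeft(0L)(step)
    (total, calls)
  }

  def main(args: Array[String]): Unit = {
    val (total, calls) = countCalls(10)
    println(s"total=$total calls=$calls") // total=55 calls=10
  }
}
```

The recursive version performs the addition inline instead; how much of the extra call cost the JIT removes in practice has to be measured.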

answered Oct 29 '22 at 16:10 by Daniel C. Sobral



Clojure's range does not memoize; Scala's Stream does. They are totally different data structures with totally different results. Scala does have a non-memoizing Range structure, but it's currently kind of awkward to work with in this simple recursive way. Here's my take on the whole thing.
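In more recent Scala versions (this sketch assumes 2.13 or later), the standard Range can be used in this first/rest style directly: calling `tail` on a Range yields another Range holding just a new start, end and step, so nothing is memoized. `RangeRecursion` and its `add` are hypothetical names.

```scala
import scala.annotation.tailrec

object RangeRecursion {
  // Clojure-style first/rest recursion. When `r` is a Range, `r.tail`
  // is another small Range object, not a memoized cell as with a Stream.
  @tailrec
  def add(r: IndexedSeq[Int], acc: Long): Long =
    if (r.isEmpty) acc else add(r.tail, acc + r.head)

  def main(args: Array[String]): Unit =
    println(add(1 until 9999999, 0L)) // 49999985000001, same as Clojure's result
}
```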

Using Clojure 1.0 on an older, slow box, I get 3.6 seconds:

```
user=> (defn sum [coll acc] (if (empty? coll) acc (recur (rest coll) (+ (first coll) acc))))
#'user/sum
user=> (time (sum (range 1 9999999) 0))
"Elapsed time: 3651.751139 msecs"
49999985000001
```

A literal translation to Scala requires a bit of supporting code. First, a timing function:

```scala
def time[T](x: => T) = {
  val start = System.nanoTime.toDouble
  val result = x
  val duration = System.nanoTime.toDouble - start
  println("Elapsed time " + duration / 1000000.0 + " msecs")
  result
}
```

It's good to make sure that it's right:

```
scala> time (Thread sleep 1000)
Elapsed time 1000.277967 msecs
```

Now we need an unmemoized Range with semantics similar to Clojure's:

```scala
case class MyRange(start: Int, end: Int) {
  def isEmpty = start >= end
  // sys.error replaces the old Predef `error`, which later Scala versions removed
  def first = if (!isEmpty) start else sys.error("empty range")
  def rest = new MyRange(start + 1, end)
}
```

From that, "add" follows directly:

```scala
def add(a: MyRange, b: Long): Long = {
  if (a.isEmpty) b else add(a.rest, b + a.first)
}
```
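To have the compiler confirm that this `add` really is compiled into a loop, the same definitions can be placed in an object and annotated with `@tailrec`. A self-contained sketch; `MyRangeSum` is a hypothetical name, and `sys.error` stands in for the old Predef `error`.

```scala
import scala.annotation.tailrec

// Repeats MyRange so this block compiles on its own.
case class MyRange(start: Int, end: Int) {
  def isEmpty: Boolean = start >= end
  def first: Int = if (!isEmpty) start else sys.error("empty range")
  def rest: MyRange = MyRange(start + 1, end)
}

object MyRangeSum {
  // Same body as `add`; as an object member it cannot be overridden, and
  // @tailrec makes compilation fail if the recursive call is not a tail call.
  @tailrec
  def add(a: MyRange, b: Long): Long =
    if (a.isEmpty) b else add(a.rest, b + a.first)

  def main(args: Array[String]): Unit =
    println(add(MyRange(1, 9999999), 0L)) // 49999985000001
}
```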

And it times much faster than the Clojure version on the same box:

```
scala> time(add(MyRange(1, 9999999), 0))
Elapsed time 252.526784 msecs
res1: Long = 49999985000001
```

Using Scala's standard library Range, you can do a fold. It's not as fast as simple primitive recursion, but it's less code and still faster than the Clojure recursive version (at least on my box).

```
scala> time((1 until 9999999 foldLeft 0L)(_ + _))
Elapsed time 1995.566127 msecs
res2: Long = 49999985000001
```

Contrast with a fold over a memoized Stream:

```
scala> time((Stream from 1 take 9999998 foldLeft 0L)(_ + _))
Elapsed time 3879.991318 msecs
res3: Long = 49999985000001
```
answered Oct 29 '22 at 16:10 by James Iry
