A friend gave me this Clojure code snippet:

(defn sum [coll acc]
  (if (empty? coll)
    acc
    (recur (rest coll) (+ (first coll) acc))))

(time (sum (range 1 9999999) 0))

and asked me how it fares against a similar Scala implementation.
The Scala code I've written looks like this:
def from(n: Int): Stream[Int] = Stream.cons(n, from(n + 1))

val ints = from(1).take(9999998)

def add(a: Stream[Int], b: Long): Long = {
  if (a.isEmpty) b else add(a.tail, b + a.head)
}

val t1 = System.currentTimeMillis()
println(add(ints, 0))
val t2 = System.currentTimeMillis()
println((t2 - t1).asInstanceOf[Float] + " msecs")
Bottom line: the Clojure code runs in about 1.8 seconds on my machine and uses less than 5MB of heap, while the Scala code runs in about 12 seconds and 512MB of heap isn't enough (it finishes the computation if I set the heap to 1GB).
So I'm wondering: why is Clojure so much faster and slimmer in this particular case? Do you have a Scala implementation with similar speed and memory usage?
Please refrain from religious remarks. My interest lies primarily in finding out what makes Clojure so fast in this case, and whether there's a faster implementation of the algorithm in Scala. Thanks.
First, Scala only optimises tail calls if you invoke it with -optimise. Edit: It seems Scala will always optimise tail-call recursion when it can (a method calling itself in tail position, where the method cannot be overridden), even without -optimise.
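A minimal sketch of the behaviour described in the edit, assuming Scala 2.13 or later (the object and method names are illustrative, not from the question): the self-recursive call below is in tail position inside an object, so scalac compiles it into a loop, and about ten million nested calls complete without a StackOverflowError. The @tailrec annotation only asks the compiler to fail if that rewrite is impossible; it does not enable the optimisation itself.

```scala
import scala.annotation.tailrec

object TailCallDemo {
  // Methods in an object cannot be overridden, and the recursive call is the
  // last thing evaluated, so scalac rewrites this into a loop.
  // @tailrec turns "could not optimise" into a compile error.
  @tailrec
  def sumFrom(n: Int, end: Int, acc: Long): Long =
    if (n >= end) acc
    else sumFrom(n + 1, end, acc + n)

  def main(args: Array[String]): Unit = {
    // Roughly ten million calls deep: without tail-call elimination this
    // would almost certainly overflow a default-sized JVM thread stack.
    println(sumFrom(1, 9999999, 0L)) // 49999985000001
  }
}
```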
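Second, Stream and Range are two very different things. A Range has a beginning and an end, and its projection has just a counter and the end. A Stream is a list which is computed on demand and memoized: each element, once computed, stays in the list. Since you are adding the whole ints, you'll compute, and therefore allocate, the whole Stream, and because the val ints keeps a reference to its head, none of it can be garbage-collected while the sum runs.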
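To make the memoization difference concrete, here is a small sketch assuming Scala 2.13, where LazyList replaces the deprecated Stream; the object name and the counters are purely illustrative. Both collections are lazy, but the LazyList caches each element after computing it, while a view recomputes on every traversal. That cache is what fills the heap when a reference such as ints keeps the head alive.

```scala
object MemoDemo {
  // Traverses each collection twice and reports how many times the
  // element-producing function ran: (LazyList count, view count).
  def countEvaluations(): (Int, Int) = {
    var lazyListEvals = 0
    var viewEvals = 0

    // Lazy and memoizing: each element is computed once, then cached.
    val memoized: LazyList[Int] =
      LazyList.range(1, 6).map { n => lazyListEvals += 1; n }

    // Lazy but not memoizing: every traversal recomputes every element.
    val unmemoized = (1 to 5).view.map { n => viewEvals += 1; n }

    val a = memoized.sum
    val b = memoized.sum     // reuses the cached cells
    val c = unmemoized.sum
    val d = unmemoized.sum   // runs the mapping function again
    assert(a == b && c == d) // same sums either way
    (lazyListEvals, viewEvals)
  }

  def main(args: Array[String]): Unit =
    println(countEvaluations()) // (5,10)
}
```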
A closer translation would be:
import scala.annotation.tailrec

def add(r: Range) = {
  // Walks the Range through an Iterator: no elements are stored.
  @tailrec
  def f(i: Iterator[Int], acc: Long): Long =
    if (i.hasNext) f(i, acc + i.next()) else acc
  f(r.iterator, 0)
}

def time(f: => Unit): Unit = {
  val t1 = System.currentTimeMillis()
  f
  val t2 = System.currentTimeMillis()
  println((t2 - t1).toFloat + " msecs")
}
Normal run:
scala> time(println(add(1 to 9999999)))
49999995000000
563.0 msecs
On Scala 2.7 you need "elements" instead of "iterator", and there's no "tailrec" annotation (that annotation is only used to complain if a definition can't be optimized with tail recursion), so you'll need to strip "@tailrec" as well as "import scala.annotation.tailrec" from the code.
Also, some considerations on alternate implementations. The simplest:
scala> time(println(1 to 9999999 reduceLeft (_+_)))
-2014260032
640.0 msecs
On average, over multiple runs here, it is slower. It's also incorrect, because it accumulates in an Int, which overflows. A correct one:
scala> time(println((1 to 9999999 foldLeft 0L)(_+_)))
49999995000000
797.0 msecs
That's slower still, on this machine. I honestly wouldn't have expected it to run slower, but each iteration calls the function being passed. Once you consider that, it's a pretty good time compared to the recursive version.
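The negative result from reduceLeft is plain 32-bit overflow. A small sketch (Scala 2.13 assumed; the object and method names are illustrative) checks that the reduceLeft result is exactly the true sum truncated to an Int, and that the foldLeft result matches the closed form n(n + 1) / 2:

```scala
object OverflowDemo {
  val n = 9999999

  // reduceLeft keeps the element type, so the running total is an Int and
  // silently wraps around once it passes Int.MaxValue (2147483647).
  def intSum: Int = (1 to n).reduceLeft(_ + _)

  // foldLeft with a Long seed accumulates in 64 bits instead.
  def longSum: Long = (1 to n).foldLeft(0L)(_ + _)

  // Closed form n(n + 1) / 2, computed in Long, as a cross-check.
  def closedForm: Long = n.toLong * (n + 1) / 2

  def main(args: Array[String]): Unit = {
    println(intSum)                  // -2014260032
    println(longSum)                 // 49999995000000
    println(longSum == closedForm)   // true
    // Keeping only the low 32 bits of the correct sum gives the wrapped value.
    println(intSum == longSum.toInt) // true
  }
}
```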
Clojure's range does not memoize; Scala's Stream does. They are totally different data structures with totally different results. Scala does have a non-memoizing Range structure, but it's currently kind of awkward to work with in this simple recursive way. Here's my take on the whole thing.
Using Clojure 1.0 on an older, slower box, I get 3.6 seconds:
user=> (defn sum [coll acc] (if (empty? coll) acc (recur (rest coll) (+ (first coll) acc))))
#'user/sum
user=> (time (sum (range 1 9999999) 0))
"Elapsed time: 3651.751139 msecs"
49999985000001
A literal translation to Scala requires me to write some supporting code, starting with a timing helper:
def time[T](x: => T) = {
  val start = System.nanoTime.toDouble
  val result = x
  val duration = System.nanoTime.toDouble - start
  println("Elapsed time " + duration / 1000000.0 + " msecs")
  result
}
It's good to make sure that the helper is right:
scala> time(Thread sleep 1000)
Elapsed time 1000.277967 msecs
Now we need an unmemoized Range with semantics similar to Clojure's:

case class MyRange(start: Int, end: Int) {
  def isEmpty = start >= end
  def first = if (!isEmpty) start else sys.error("empty range")
  // A fresh MyRange per step; nothing already visited is kept.
  def rest = new MyRange(start + 1, end)
}
From that, "add" follows directly:

def add(a: MyRange, b: Long): Long = {
  if (a.isEmpty) b else add(a.rest, b + a.first)
}
And it times much faster than Clojure's on the same box:

scala> time(add(MyRange(1, 9999999), 0))
Elapsed time 252.526784 msecs
res1: Long = 49999985000001
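On newer Scala versions the standard library Range can be walked in the same recursive shape as MyRange. This is a sketch assuming Scala 2.13, where Range.head and Range.tail exist and tail returns a new Range (start advanced by one step, same end) rather than copying elements; check this against your version. The object name is illustrative. Each step still allocates a small Range object, so its timing may differ from MyRange's.

```scala
import scala.annotation.tailrec

object RangeRecursion {
  // Same shape as add(MyRange, Long), but on scala.collection.immutable.Range.
  // Assumption: Range.tail returns a Range, so the recursion type-checks
  // and nothing already visited is retained.
  @tailrec
  def add(r: Range, acc: Long): Long =
    if (r.isEmpty) acc
    else add(r.tail, acc + r.head)

  def main(args: Array[String]): Unit =
    println(add(1 until 9999999, 0L)) // 49999985000001
}
```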
Using Scala's standard library Range, you can do a fold. It's not as fast as simple primitive recursion, but it's less code and still faster than the Clojure recursive version (at least on my box).
scala> time((1 until 9999999 foldLeft 0L)(_ + _))
Elapsed time 1995.566127 msecs
res2: Long = 49999985000001
Contrast with a fold over a memoized Stream:

scala> time((Stream from 1 take 9999998 foldLeft 0L)(_ + _))
Elapsed time 3879.991318 msecs
res3: Long = 49999985000001
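For comparison with the fold and Stream timings, here is roughly the loop that tail-call elimination turns a recursive add into, written out by hand. It is a sketch with illustrative names and makes no timing claims; it can be measured with the same time helper.

```scala
object LoopSum {
  // Sums start until end (exclusive) with a mutable counter and accumulator:
  // no per-element objects, no closures, no retained elements.
  def sumUntil(start: Int, end: Int): Long = {
    var i = start
    var acc = 0L
    while (i < end) {
      acc += i
      i += 1
    }
    acc
  }

  def main(args: Array[String]): Unit =
    println(sumUntil(1, 9999999)) // 49999985000001
}
```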