
For comprehension and number of function creation

Recently I had an interview for a Scala developer position and was asked the following question:

// matrix 100x100 (content unimportant)

val matrix = Seq.tabulate(100, 100) { case (x, y) => x + y }

// A

for {

   row <- matrix

   elem <- row

} print(elem)

// B

val func = print _
for {

   row <- matrix

   elem <- row

} func(elem)

and the question was: Which implementation, A or B, is more efficient?

We all know that for comprehensions can be translated to

// A

matrix.foreach(row => row.foreach(elem => print(elem)))

// B

matrix.foreach(row => row.foreach(func))

B can be written as matrix.foreach(row => row.foreach(print _))

The supposedly correct answer is B, because A creates the print function 100 times more often.

I have checked the Language Specification but still fail to understand the answer. Can somebody explain this to me?

goral asked May 19 '14 11:05



3 Answers

In short:

Example A is faster in theory; in practice, though, you shouldn't be able to measure any difference.

Long answer:

As you already found out

for {xs <- xxs; x <- xs} f(x)

is translated to

xxs.foreach(xs => xs.foreach(x => f(x)))

This is explained in §6.19 SLS:

A for loop

for ( p <- e; p' <- e' ... ) e''

where ... is a (possibly empty) sequence of generators, definitions, or guards, is translated to

e .foreach { case p => for ( p' <- e' ... ) e'' }
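The translation rule quoted above can be checked on a small input. The following is a minimal sketch, not taken from the spec: the object name, the sample data, and the buffer used to record the visiting order are all assumptions made for illustration.

```scala
import scala.collection.mutable.ArrayBuffer

object ForDesugar {
  // Hypothetical sample data; any nested Seq would do.
  val xxs: Seq[Seq[Int]] = Seq(Seq(1, 2), Seq(3), Seq(4, 5, 6))

  // The for loop as written in source code.
  def viaFor(): List[Int] = {
    val out = ArrayBuffer.empty[Int]
    for { xs <- xxs; x <- xs } out += x
    out.toList
  }

  // The same loop after applying the translation rule by hand.
  def viaForeach(): List[Int] = {
    val out = ArrayBuffer.empty[Int]
    xxs.foreach { case xs => xs.foreach { case x => out += x } }
    out.toList
  }
}
```

Both methods visit the elements in the same order, which is what the rule predicts.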

Now when one writes a function literal, a new function instance is created every time the literal is evaluated (§6.23 SLS). This means that

xs.foreach(x => f(x))

is equivalent to

xs.foreach(new scala.Function1 { def apply(x: T) = f(x)})
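A runnable version of this equivalence might look as follows. It is only a sketch: `f`, the accumulator, and the object name are hypothetical, and it shows that the two forms behave the same, not how a particular compiler version encodes the literal (newer Scala versions may compile function literals differently from an anonymous class).

```scala
object LiteralAsClass {
  // Hypothetical function standing in for `f` from the text.
  def f(x: Int): Int = x * 2

  // Function literal passed to foreach.
  def viaLiteral(xs: Seq[Int]): Int = {
    var acc = 0
    xs.foreach(x => acc += f(x))
    acc
  }

  // Hand-written anonymous Function1 with the same body.
  def viaAnonClass(xs: Seq[Int]): Int = {
    var acc = 0
    xs.foreach(new Function1[Int, Unit] {
      def apply(x: Int): Unit = acc += f(x)
    })
    acc
  }
}
```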

When you introduce a local function value

val g = f _; xxs.foreach(xs => xs.foreach(x => g(x)))

you are not introducing an optimization because you still pass a function literal to foreach. In fact the code is slower because the inner foreach is translated to

xs.foreach(new scala.Function1 { def apply(x: T) = g.apply(x) })

where an additional call to the apply method of g happens. Though, you can optimize when you write

val g = f _; xxs.foreach(xs => xs.foreach(g))

because the inner foreach now is translated to

xs.foreach(g)

which means that the function g itself is passed to foreach.

This would mean that B is faster in theory, because no anonymous function needs to be created each time the body of the for comprehension is executed. However, the optimization mentioned above (passing the function directly to foreach) is not applied to for comprehensions: as the spec says, the translation includes the creation of function literals, so unnecessary function objects are always created. (The compiler could optimize that as well, but it doesn't, because optimizing for comprehensions is difficult and still does not happen in 2.11.) All in all, A is more efficient as written, but B would be more efficient if it were written without a for comprehension, so that no function literal is created for the innermost function.
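The difference in how often the inner function value is built can be made observable with a counter. This is a minimal sketch under stated assumptions: `counted` is a hypothetical helper that stands in for "a function object is created here", and the sum replaces print so that the result can be checked. It does not measure real allocations or timing.

```scala
object LiteralCount {
  // Hypothetical helper: counts how often a function value is built.
  var created = 0
  def counted[A](f: A => Unit): A => Unit = { created += 1; f }

  val matrix: Seq[Seq[Int]] = Seq.tabulate(100, 100) { case (x, y) => x + y }

  // The inner function expression sits inside the outer lambda,
  // so it is evaluated once per row.
  def perRow(): (Int, Long) = {
    created = 0
    var sum = 0L
    matrix.foreach(row => row.foreach(counted((elem: Int) => sum += elem)))
    (created, sum)
  }

  // The function value is built once and passed directly to foreach.
  def hoisted(): (Int, Long) = {
    created = 0
    var sum = 0L
    val g = counted((elem: Int) => sum += elem)
    matrix.foreach(row => row.foreach(g))
    (created, sum)
  }
}
```

For the 100x100 matrix, `perRow()` reports 100 creations and `hoisted()` reports 1, while both compute the same sum.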

Nevertheless, all of these rules apply only in theory, because in practice the backend of scalac and the JVM itself can both perform optimizations, not to mention optimizations done by the CPU. Furthermore, your example contains a syscall that is executed on every iteration; it is probably the most expensive operation here and outweighs everything else.

kiritsuku answered Nov 12 '22 14:11



I'd agree with sschaef and say that A is the more efficient option.

Looking at the generated class files we get the following anonymous functions and their apply methods:

MethodA:
  anonfun$2            -- row => row.foreach(new anonfun$2$$anonfun$1)
  anonfun$2$$anonfun$1 -- elem => print(elem)

i.e. matrix.foreach(row => row.foreach(elem => print(elem)))

MethodB:
  anonfun$3            -- x => print(x)
  anonfun$4            -- row => row.foreach(new anonfun$4$$anonfun$2)
  anonfun$4$$anonfun$2 -- elem => func(elem)

i.e. matrix.foreach(row => row.foreach(elem => func(elem))) where func is just another indirection before calling to print. In addition func needs to be looked up, i.e. through a method call on an instance (this.func()) for each row.

So for Method B, one extra object is created (func) and there is one additional function call per element.
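The extra call per element can be counted with a small wrapper. The sketch below is an illustration only: `Counting` is a hypothetical Function1 that records its own apply calls, and it models the call structure of the class files listed above, not the generated bytecode itself.

```scala
object ApplyIndirection {
  // Hypothetical Function1 that counts its own apply calls
  // before delegating to `body`.
  final class Counting(body: Int => Unit) extends (Int => Unit) {
    var calls = 0
    def apply(x: Int): Unit = { calls += 1; body(x) }
  }

  val row: Seq[Int] = Seq.range(0, 100)

  // Mirrors `elem => func(elem)`: foreach calls the wrapper,
  // and the wrapper calls func.
  def wrapped(): Int = {
    val func = new Counting(_ => ())
    val wrapper = new Counting(elem => func(elem))
    row.foreach(wrapper)
    func.calls + wrapper.calls
  }

  // Mirrors `row.foreach(func)`: foreach calls func directly.
  def direct(): Int = {
    val func = new Counting(_ => ())
    row.foreach(func)
    func.calls
  }
}
```

For a row of 100 elements, the wrapped form performs 200 apply calls and the direct form 100.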

The most efficient option would be

matrix.foreach(row => row.foreach(func))

as this has the least number of objects created and does exactly as you would expect.

ggovan answered Nov 12 '22 12:11



Benchmark

Summary

Method A is nearly 30% faster than method B.

Link to code: https://gist.github.com/ziggystar/490f693bc39d1396ef8d

Implementation Details

I added method C (two while loops) and D (fold, sum). I also increased the size of the matrix and used an IndexedSeq instead. Also I replaced the print with something less heavy (sum all entries).

Strangely the while construct is not the fastest. But if one uses Array instead of IndexedSeq it becomes the fastest by a large margin (factor 5, no boxing anymore). Using explicitly boxed integers, methods A, B, C are all equally fast. In particular they are faster by 50% compared to the implicitly boxed versions of A, B.
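The Array-based variant is not part of the code excerpt. A minimal sketch of what such a loop could look like, assuming the same 1000x1000 contents stored as Array[Array[Int]] (the object and method names are made up for illustration, and no timing is implied by this sketch):

```scala
object ArraySum {
  // Assumption: same contents as the benchmark matrix, but held in
  // primitive Int arrays so that element reads need no unboxing.
  val matrix: Array[Array[Int]] = Array.tabulate(1000, 1000) { (x, y) => x + y }

  // Two nested while loops, in the style of method C.
  def sumWhile(): Int = {
    var r = 0
    var i1 = 0
    while (i1 < matrix.length) {
      val row = matrix(i1)
      var i2 = 0
      while (i2 < row.length) {
        r += row(i2)
        i2 += 1
      }
      i1 += 1
    }
    r
  }
}
```

The sum agrees with the value printed under "verify results".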

Results

A
4.907797735
4.369745787
4.375195012000001
4.7421321800000005
4.35150636
B
5.955951859000001
5.925475619
5.939570085000001
5.955592247
5.939672226000001
C
5.991946029
5.960122757000001
5.970733164
6.025532582
6.04999499
D
9.278486201
9.265983922
9.228320372
9.255641645
9.22281905
verify results
999000000
999000000
999000000
999000000

>$ scala -version
Scala code runner version 2.11.0 -- Copyright 2002-2013, LAMP/EPFL

Code excerpt

val matrix = IndexedSeq.tabulate(1000, 1000) { case (x, y) => x + y }

def variantA(): Int = {
  var r = 0
  for {
    row <- matrix
    elem <- row
  }{
    r += elem
  }
  r
}

def variantB(): Int = {
  var r = 0
  val f = (x:Int) => r += x
  for {
    row <- matrix
    elem <- row
  } f(elem)
  r
}

def variantC(): Int = {
  var r = 0
  var i1 = 0
  while(i1 < matrix.size){
    var i2 = 0
    val row = matrix(i1)
    while(i2 < row.size){
      r += row(i2)
      i2 += 1
    }
    i1 += 1
  }
  r
}

def variantD(): Int = matrix.foldLeft(0)(_ + _.sum)

ziggystar answered Nov 12 '22 13:11
