 

Suggestions to optimize a simple Scala foldLeft over multiple values?

I'm re-implementing some code (a simple Bayesian inference algorithm, but that's not really important) from Java to Scala. I'd like to implement it in the most performant way possible, while keeping the code clean and functional by avoiding mutability as much as possible.

Here is the snippet of the Java code:

    // initialize
    double lP  = Math.log(prior);
    double lPC = Math.log(1-prior);

    // accumulate probabilities from each annotation object into lP and lPC
    for (Annotation annotation : annotations) {
        float prob = annotation.getProbability();
        if (isValidProbability(prob)) {
            lP  += logProb(prob);
            lPC += logProb(1 - prob);
        }
    } 

Pretty simple, right? So I decided to use Scala foldLeft and map methods for my first try. Since I have two values I'm accumulating over, the accumulator is a tuple:

    val initial  = (math.log(prior), math.log(1-prior))
    val probs    = annotations map (_.getProbability)
    val (lP,lPC) = probs.foldLeft(initial) ((r,p) => {
      if(isValidProbability(p)) (r._1 + logProb(p), r._2 + logProb(1-p)) else r
    })

Unfortunately, this code runs about 5 times slower than the Java version (by a simple and imprecise metric: I just called the code 10000 times in a loop). One defect is clear: the list is traversed twice, once in the call to map and again in the foldLeft. So here's a version that traverses the list only once.

    val (lP,lPC) = annotations.foldLeft(initial) ((r,annotation) => {
      val p = annotation.getProbability
      if(isValidProbability(p)) (r._1 + logProb(p), r._2 + logProb(1-p)) else r
    })

This is better! It performs about 3 times worse than the Java code. My next hunch was that there is probably some cost involved in creating all the new tuples in each step of the fold. So I decided to try a version that traverses the list twice, but without creating tuples.

    val lP = annotations.foldLeft(math.log(prior)) ((r,annotation) => {
      val p = annotation.getProbability
      if(isValidProbability(p)) r + logProb(p) else r
    })
    val lPC = annotations.foldLeft(math.log(1-prior)) ((r,annotation) => {
      val p = annotation.getProbability
      if(isValidProbability(p)) r + logProb(1-p) else r
    })

This performs about the same as the previous version (3 times slower than the Java version). Not really surprising, but I was hopeful.

So my question is, is there a faster way to implement this Java snippet in Scala, while keeping the Scala code clean, avoiding unnecessary mutability and following Scala idioms? I do expect to use this code eventually in a concurrent environment, so the value of keeping immutability may outweigh the slower performance in a single thread.

asked Feb 02 '12 17:02 by Raj B


4 Answers

First, some of your penalty may come from the type of collection you're using. But most of it is probably object creation, which running the loop twice does not actually avoid: foldLeft is generic, so the Double accumulator has to be boxed at every step.

Instead, you can create a mutable class that accumulates the values for you:

    class LogOdds(var lp: Double = 0, var lpc: Double = 0) {
      def *=(p: Double) = {
        if (isValidProbability(p)) {
          lp += logProb(p)
          lpc += logProb(1-p)
        }
        this  // Pass self on so we can fold over the operation
      }
      def toTuple = (lp, lpc)
    }

Now although you can use this unsafely, you don't have to. In fact, you can just fold over it.

    annotations.foldLeft(new LogOdds(math.log(prior), math.log(1 - prior))) {
      (r, ann) => r *= ann.getProbability
    }.toTuple

If you use this pattern, all the mutable unsafety is tucked away inside the fold; it never escapes.
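As a minimal runnable sketch of this pattern, the object below wraps the LogOdds class together with hypothetical stand-ins for the question's Annotation, isValidProbability and logProb (the validity range (0, 1) and logProb = math.log are assumptions, not the asker's code). The accumulator starts from the priors, as in the question, and is only ever mutated inside the fold.

```scala
object LogOddsFold {
  // Hypothetical stand-ins for the question's Annotation type and helpers.
  final case class Annotation(getProbability: Float)
  def isValidProbability(p: Double): Boolean = p > 0.0 && p < 1.0
  def logProb(p: Double): Double = math.log(p)

  // Mutable accumulator, as in the answer above.
  class LogOdds(var lp: Double = 0, var lpc: Double = 0) {
    def *=(p: Double): LogOdds = {
      if (isValidProbability(p)) {
        lp += logProb(p)
        lpc += logProb(1 - p)
      }
      this // return self so the fold can thread it through
    }
    def toTuple: (Double, Double) = (lp, lpc)
  }

  // The LogOdds instance is created and mutated only inside this call,
  // so callers (including concurrent ones) never see intermediate state.
  def logOdds(prior: Double, annotations: List[Annotation]): (Double, Double) =
    annotations.foldLeft(new LogOdds(math.log(prior), math.log(1 - prior))) {
      (acc, ann) => acc *= ann.getProbability
    }.toTuple
}
```

Each call allocates its own accumulator, so the method as a whole behaves like a pure function even though LogOdds itself is mutable.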

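Now, you can't do a parallel fold here, because fold requires the accumulator type to be a supertype of the element type (and LogOdds is not a supertype of Annotation), but you can do an aggregate, which is like a fold with an extra operation to combine pieces. So you add the method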

    def **(lo: LogOdds) = new LogOdds(lp + lo.lp, lpc + lo.lpc)

to LogOdds and then

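    // aggregate may evaluate the zero once per chunk, so keep it neutral and
    // add math.log(prior) and math.log(1 - prior) to the result afterwards.
    annotations.par.aggregate(new LogOdds())(
      (r, ann) => r *= ann.getProbability,
      (l, r) => l ** r
    ).toTuple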

and you'll be good to go.

(Feel free to use non-mathematical symbols for this, but since you're basically multiplying probabilities, the multiplication symbol seemed more likely to give an intuitive idea of what is going on than incorporateProbability or some such.)
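To see why the zero handed to aggregate has to be a neutral LogOdds(), here is a sequential sketch that imitates a parallel aggregate by splitting the input into chunks by hand, so it runs without parallel collections. chunkedAggregate and the helper stubs are hypothetical illustrations, not library API. Each chunk folds into its own fresh zero, the partial results are merged with `**`, and the priors are added once at the end, so the result does not depend on the chunk size (up to floating-point rounding).

```scala
object LogOddsAggregate {
  // Hypothetical stand-ins for the question's helpers.
  def isValidProbability(p: Double): Boolean = p > 0.0 && p < 1.0
  def logProb(p: Double): Double = math.log(p)

  class LogOdds(var lp: Double = 0, var lpc: Double = 0) {
    def *=(p: Double): LogOdds = {
      if (isValidProbability(p)) {
        lp += logProb(p)
        lpc += logProb(1 - p)
      }
      this
    }
    // Merging two partial results allocates a fresh instance.
    def **(lo: LogOdds): LogOdds = new LogOdds(lp + lo.lp, lpc + lo.lpc)
    def toTuple: (Double, Double) = (lp, lpc)
  }

  // Mimics a parallel aggregate: every chunk starts from its own zero
  // (zero is by-name, so each use builds a new instance), is folded with
  // seqop, and the partial results are merged with combop.
  // chunkSize must be positive (grouped rejects 0).
  def chunkedAggregate[A, B](xs: List[A], chunkSize: Int)(zero: => B)(
      seqop: (B, A) => B, combop: (B, B) => B): B =
    xs.grouped(chunkSize).map(_.foldLeft(zero)(seqop)).foldLeft(zero)(combop)

  def logOdds(prior: Double, probs: List[Double], chunkSize: Int): (Double, Double) = {
    val (lp, lpc) = chunkedAggregate(probs, chunkSize)(new LogOdds())(
      (acc, p) => acc *= p,
      (l, r) => l ** r
    ).toTuple
    // The zero is neutral, so the priors are counted exactly once, here.
    (math.log(prior) + lp, math.log(1 - prior) + lpc)
  }
}
```

Starting each chunk from `LogOdds(math.log(prior), math.log(1 - prior))` instead would add the priors once per chunk, which is exactly the mistake the neutral zero avoids.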

answered Oct 21 '22 15:10 by Rex Kerr


You could implement a tail-recursive method, which the compiler will convert to a while loop, so it should be as fast as the Java version. Or you could just use a loop; there's no law against it, as long as it only mutates local variables in a method (the Scala collections source code uses this pattern extensively, for example).

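    import scala.annotation.tailrec

    // Start the accumulators at the priors:
    //   calc(annotations, math.log(prior), math.log(1 - prior))
    // @tailrec makes the compiler fail if the call can't become a loop
    // (inside a class the method must also be final or private).
    @tailrec
    def calc(lst: List[Annotation], lP: Double = 0, lPC: Double = 0): (Double, Double) = {
      if (lst.isEmpty) (lP, lPC)
      else {
        val prob = lst.head.getProbability
        if (isValidProbability(prob))
          calc(lst.tail, lP + logProb(prob), lPC + logProb(1 - prob))
        else
          calc(lst.tail, lP, lPC)
      }
    }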
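The plain-loop alternative mentioned above might look like the following sketch, which mirrors the Java code with local vars. Annotation, isValidProbability and logProb are hypothetical stand-ins, and annotations is assumed to be an IndexedSeq (such as a Vector or a wrapped Array) so that indexing is constant-time; with a List you would walk an iterator instead.

```scala
object LogOddsLoop {
  // Hypothetical stand-ins for the question's Annotation type and helpers.
  final case class Annotation(getProbability: Float)
  def isValidProbability(p: Double): Boolean = p > 0.0 && p < 1.0
  def logProb(p: Double): Double = math.log(p)

  // All mutation happens in local vars, so the method is still a pure
  // function from its inputs and is safe to call from several threads.
  def logOdds(prior: Double, annotations: IndexedSeq[Annotation]): (Double, Double) = {
    var lP = math.log(prior)
    var lPC = math.log(1 - prior)
    var i = 0
    while (i < annotations.length) {
      val p = annotations(i).getProbability
      if (isValidProbability(p)) {
        lP += logProb(p)
        lPC += logProb(1 - p)
      }
      i += 1
    }
    (lP, lPC)
  }
}
```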

The advantage of folding is that it's parallelizable (with fold or aggregate on a parallel collection; foldLeft itself is always sequential), which may make it faster than the Java version on a multi-core machine (see other answers).

answered Oct 21 '22 17:10 by Luigi Plinge


As a kind of side note: you can avoid traversing the list twice more idiomatically by using view:

    val probs = annotations.view.map(_.getProbability).filter(isValidProbability)

    val (lP, lPC) = ((math.log(prior), math.log(1 - prior)) /: probs) {
      case ((pa, ca), p) => (pa + logProb(p), ca + logProb(1 - p))
    }

This probably isn't going to get you better performance than your third version, but it feels more elegant to me.
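A small hypothetical harness makes the laziness visible: getProbability increments a counter, building the view calls it zero times, and the fold then calls it once per annotation. The counter, the Annotation class and the helper stubs are illustrations only. Note that a view re-runs map and filter on every traversal, so folding the same view twice would call getProbability again.

```scala
object ViewLaziness {
  // Counts calls to getProbability (illustration only; not thread-safe).
  var calls = 0

  // Hypothetical stand-in for the question's Annotation type.
  final class Annotation(prob: Float) {
    def getProbability: Float = { calls += 1; prob }
  }

  def isValidProbability(p: Float): Boolean = p > 0f && p < 1f
  def logProb(p: Double): Double = math.log(p)

  // Only records the map and filter steps; nothing is evaluated yet.
  def probsView(annotations: List[Annotation]): Iterable[Float] =
    annotations.view.map(_.getProbability).filter(isValidProbability)

  // The fold pulls each element through map and filter in a single pass.
  def logOdds(prior: Double, probs: Iterable[Float]): (Double, Double) =
    probs.foldLeft((math.log(prior), math.log(1 - prior))) {
      case ((pa, ca), p) => (pa + logProb(p), ca + logProb(1 - p))
    }
}
```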

answered Oct 21 '22 17:10 by Travis Brown


First, let's address the performance issue: there's no way to make it as fast as Java except by using while loops. Basically, the JVM cannot optimize the closure-based Scala loop to the extent it optimizes the Java one. This is even a concern among the JVM folks, because it gets in the way of their parallel library efforts as well.

Now, back to Scala performance: you can also use .view to avoid creating a new collection in the map step, but I think the map step will always lead to worse performance. The thing is, you are converting the collection into one parameterized on a primitive type (Float, since getProbability returns a float), whose elements must be boxed and unboxed.

However, there's one possible way of optimizing it: making it parallel. If you call .par on annotations to make it a parallel collection, you can then use fold:

    val parAnnot = annotations.par
    val probs = parAnnot.map(_.getProbability).filter(p => isValidProbability(p))
    // fold's zero may be used once per chunk and its op also merges partial
    // results, so fold plain log terms with 0.0 and + and add the prior once.
    val lP  = math.log(prior) + probs.map(p => logProb(p)).fold(0.0)(_ + _)
    val lPC = math.log(1 - prior) + probs.map(p => logProb(1 - p)).fold(0.0)(_ + _)

To avoid a separate map step, use aggregate instead of fold, as suggested by Rex.

For bonus points, you could use Future to make both computations run in parallel. I suspect you'll get better performance by bringing the tuples back and running it in one go, though. You'll have to benchmark this stuff to see what works better.
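One way to run the two sums concurrently with Future, as a sketch: the stand-in helpers are hypothetical, the global ExecutionContext and the 10-second timeout are arbitrary choices, and Await is used only to keep the example short (in real code you would usually keep composing the futures rather than block).

```scala
import scala.concurrent.{Await, Future}
import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.duration._

object LogOddsFutures {
  // Hypothetical stand-ins for the question's Annotation type and helpers.
  final case class Annotation(getProbability: Float)
  def isValidProbability(p: Float): Boolean = p > 0f && p < 1f
  def logProb(p: Double): Double = math.log(p)

  def logOdds(prior: Double, annotations: List[Annotation]): (Double, Double) = {
    val probs = annotations.map(_.getProbability).filter(isValidProbability)
    // Both sums are submitted to the global execution context and may run concurrently.
    val fP  = Future(probs.foldLeft(math.log(prior))((r, p) => r + logProb(p)))
    val fPC = Future(probs.foldLeft(math.log(1 - prior))((r, p) => r + logProb(1 - p)))
    // zip pairs the two results; blocking here is only for the sketch.
    Await.result(fP.zip(fPC), 10.seconds)
  }
}
```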

On parallel collections, it might pay off to filter out the invalid annotations first, or perhaps to use collect.

    val parAnnot = annotations.par.view.map(_.getProbability).filter(isValidProbability(_)).force

or

    val parAnnot = annotations.par collect { case annot if isValidProbability(annot.getProbability) => annot.getProbability }

Anyway, benchmark.

answered Oct 21 '22 17:10 by Daniel C. Sobral