I'm re-implementing some code (a simple Bayesian inference algorithm, but that's not really important) from Java to Scala. I'd like to implement it in the most performant way possible, while keeping the code clean and functional by avoiding mutability as much as possible.
Here is the snippet of the Java code:
// initialize
double lP = Math.log(prior);
double lPC = Math.log(1-prior);
// accumulate probabilities from each annotation object into lP and lPC
for (Annotation annotation : annotations) {
    float prob = annotation.getProbability();
    if (isValidProbability(prob)) {
        lP += logProb(prob);
        lPC += logProb(1 - prob);
    }
}
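Written out as formulas, the loop computes the two sums below. This assumes that logProb(x) is the natural logarithm ln x (the snippet doesn't show its definition), and V stands for the set of annotations whose probability passes isValidProbability. Both quantities are plain sums, so the order in which the terms are added doesn't matter (up to floating-point rounding). That property is what allows a fold, or a parallel aggregate, to replace the loop.

```latex
\ell_P = \ln(\mathit{prior}) + \sum_{i \in V} \ln p_i,
\qquad
\ell_{PC} = \ln(1 - \mathit{prior}) + \sum_{i \in V} \ln(1 - p_i)

% their difference is the posterior log-odds
\ell_P - \ell_{PC} = \ln\frac{\mathit{prior}}{1 - \mathit{prior}} + \sum_{i \in V} \ln\frac{p_i}{1 - p_i}
```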
Pretty simple, right? So I decided to use Scala foldLeft and map methods for my first try. Since I have two values I'm accumulating over, the accumulator is a tuple:
val initial = (math.log(prior), math.log(1-prior))
val probs = annotations map (_.getProbability)
val (lP,lPC) = probs.foldLeft(initial) ((r,p) => {
  if(isValidProbability(p)) (r._1 + logProb(p), r._2 + logProb(1-p)) else r
})
Unfortunately, this code performs about 5 times slower than Java (using a simple and imprecise metric: I just called the code 10000 times in a loop). One defect is pretty clear: we traverse the list twice, once in the call to map and again in the foldLeft. So here's a version that traverses the list once.
val (lP,lPC) = annotations.foldLeft(initial) ((r,annotation) => {
  val p = annotation.getProbability
  if(isValidProbability(p)) (r._1 + logProb(p), r._2 + logProb(1-p)) else r
})
This is better! It performs about 3 times worse than the Java code. My next hunch was that there is probably some cost involved in creating all the new tuples in each step of the fold. So I decided to try a version that traverses the list twice, but without creating tuples.
val lP = annotations.foldLeft(math.log(prior)) ((r,annotation) => {
  val p = annotation.getProbability
  if(isValidProbability(p)) r + logProb(p) else r
})
val lPC = annotations.foldLeft(math.log(1-prior)) ((r,annotation) => {
  val p = annotation.getProbability
  if(isValidProbability(p)) r + logProb(1-p) else r
})
This performs about the same as the previous version (3 times slower than the Java version). Not really surprising, but I was hopeful.
So my question is, is there a faster way to implement this Java snippet in Scala, while keeping the Scala code clean, avoiding unnecessary mutability and following Scala idioms? I do expect to use this code eventually in a concurrent environment, so the value of keeping immutability may outweigh the slower performance in a single thread.
First, some of your penalty may come from the type of collection you're using. But most of it is probably object creation, which you don't actually avoid by running the loop twice: foldLeft is generic, so even a plain Double accumulator has to be boxed on every step.
Instead, you can create a mutable class that accumulates the values for you:
class LogOdds(var lp: Double = 0, var lpc: Double = 0) {
  def *=(p: Double) = {
    if (isValidProbability(p)) {
      lp += logProb(p)
      lpc += logProb(1-p)
    }
    this // Pass self on so we can fold over the operation
  }
  def toTuple = (lp, lpc)
}
Now although you can use this unsafely, you don't have to. In fact, you can just fold over it.
(annotations.foldLeft(new LogOdds(math.log(prior), math.log(1 - prior))) { (r, ann) => r *= ann.getProbability }).toTuple
If you use this pattern, all the mutable unsafety is tucked away inside the fold; it never escapes.
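Here is a self-contained, runnable sketch of the mutable-accumulator fold. The Annotation case class, isValidProbability (which accepts values strictly between 0 and 1) and logProb (the natural log) are hypothetical stand-ins for the question's definitions, which aren't shown. The LogOdds instance is created inside logOdds and only escapes as an immutable tuple.

```scala
object LogOddsFoldSketch {
  // Hypothetical stand-in for the question's Annotation type.
  final case class Annotation(probability: Float) {
    def getProbability: Float = probability
  }
  // Assumption: a usable probability lies strictly between 0 and 1.
  def isValidProbability(p: Double): Boolean = p > 0 && p < 1
  // Assumption: logProb is the natural logarithm.
  def logProb(p: Double): Double = math.log(p)

  // Mutable accumulator; it is created and mutated only inside logOdds below.
  final class LogOdds(var lp: Double = 0, var lpc: Double = 0) {
    def *=(p: Double): LogOdds = {
      if (isValidProbability(p)) {
        lp += logProb(p)
        lpc += logProb(1 - p)
      }
      this // return self so foldLeft can thread the same instance through
    }
    def toTuple: (Double, Double) = (lp, lpc)
  }

  def logOdds(prior: Double, annotations: Seq[Annotation]): (Double, Double) =
    annotations
      .foldLeft(new LogOdds(math.log(prior), math.log(1 - prior))) { (acc, ann) =>
        acc *= ann.getProbability // Float widens to Double
      }
      .toTuple // only the immutable tuple leaves the method
}
```

For example, LogOddsFoldSketch.logOdds(0.5, Seq(Annotation(0.5f), Annotation(0.25f), Annotation(2.0f))) skips the invalid 2.0 and returns values equal to (ln 0.0625, ln 0.1875), up to rounding.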
Now, you can't do a parallel fold here (fold requires the accumulator type to be a supertype of the element type), but you can do an aggregate, which is like a fold with an extra operation to combine the pieces. So add the method
def **(lo: LogOdds) = new LogOdds(lp + lo.lp, lpc + lo.lpc)
to LogOdds, and then:
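val (sumP, sumPC) = annotations.par.aggregate(new LogOdds())(
  (r, ann) => r *= ann.getProbability,
  (l, r) => l ** r
).toTuple
val (lP, lPC) = (math.log(prior) + sumP, math.log(1 - prior) + sumPC)
and you'll be good to go. Because aggregate takes its zero value by name, each partition starts from its own fresh LogOdds, so the mutation still never escapes. The zero is used once per partition, so it has to start at 0, and the prior is added only once, after combining.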
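The sketch below shows what aggregate does without depending on parallel collections. Since Scala 2.13 those live in the separate scala-parallel-collections module, so the sketch instead splits the input into chunks by hand. Each chunk is folded from the zero value (the seqop), and the partial results are then merged (the combop). It uses an immutable LogOdds with a hypothetical combine method rather than the mutable class, and the same hypothetical stand-ins for Annotation, isValidProbability and logProb. The sketch is sequential and only illustrates the semantics: every chunk size should give the same answer, and the prior must stay out of the zero value.

```scala
object AggregateSketch {
  // Hypothetical stand-ins for the question's types and helpers.
  final case class Annotation(probability: Float) {
    def getProbability: Float = probability
  }
  def isValidProbability(p: Double): Boolean = p > 0 && p < 1
  def logProb(p: Double): Double = math.log(p)

  // Immutable partial result; `zero` is the neutral element of `combine`.
  final case class LogOdds(lp: Double, lpc: Double) {
    def add(p: Double): LogOdds =
      if (isValidProbability(p)) LogOdds(lp + logProb(p), lpc + logProb(1 - p)) else this
    def combine(other: LogOdds): LogOdds = LogOdds(lp + other.lp, lpc + other.lpc)
  }
  val zero: LogOdds = LogOdds(0.0, 0.0)

  // Mimics aggregate: seqop folds each chunk from `zero`, combop merges chunk results.
  def logOdds(prior: Double, annotations: Seq[Annotation], chunkSize: Int): (Double, Double) = {
    val merged = annotations
      .grouped(chunkSize)
      .map(chunk => chunk.foldLeft(zero)((acc, ann) => acc.add(ann.getProbability)))
      .foldLeft(zero)(_ combine _)
    // Putting the prior into `zero` would count it once per chunk, so add it here, once.
    (math.log(prior) + merged.lp, math.log(1 - prior) + merged.lpc)
  }
}
```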
(Feel free to use descriptive method names instead of symbols, but since you're basically multiplying probabilities, the multiplication symbol seemed more likely to convey what is going on than incorporateProbability or some such.)
You could implement a tail-recursive method, which the compiler will convert to a while loop, so it should be as fast as the Java version. Or you could just use a loop: there's no law against it, as long as it only uses local variables in a method (the Scala collections source code uses this pattern extensively, for example).
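// @tailrec makes the compiler fail if it can't turn the recursion into a loop
// (inside a class, the method must also be private or final for that to work)
@annotation.tailrec
def calc(lst: List[Annotation], lP: Double = 0, lPC: Double = 0): (Double, Double) = {
  if (lst.isEmpty) (lP, lPC)
  else {
    val prob = lst.head.getProbability
    if (isValidProbability(prob))
      calc(lst.tail, lP + logProb(prob), lPC + logProb(1 - prob))
    else
      calc(lst.tail, lP, lPC)
  }
}
// usage: calc(annotations, math.log(prior), math.log(1 - prior))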
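For comparison, here is a minimal sketch of the "just use a loop" option, with hypothetical stand-ins for Annotation, isValidProbability and logProb. The vars are local to the method, so nothing mutable is visible to callers. The sketch takes an IndexedSeq so that annotations(i) runs in constant time. On a List, indexing is linear, so you would walk an iterator instead.

```scala
object WhileLoopSketch {
  // Hypothetical stand-ins for the question's types and helpers.
  final case class Annotation(probability: Float) {
    def getProbability: Float = probability
  }
  def isValidProbability(p: Float): Boolean = p > 0 && p < 1
  def logProb(p: Double): Double = math.log(p)

  def logOdds(prior: Double, annotations: IndexedSeq[Annotation]): (Double, Double) = {
    // Local mutable state only; the method returns an immutable tuple.
    var lP = math.log(prior)
    var lPC = math.log(1 - prior)
    var i = 0
    while (i < annotations.length) {
      val p = annotations(i).getProbability
      if (isValidProbability(p)) {
        lP += logProb(p)
        lPC += logProb(1 - p)
      }
      i += 1
    }
    (lP, lPC)
  }
}
```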
The advantage of a fold is that it can be parallelized (with fold or aggregate on a parallel collection; foldLeft itself is inherently sequential), which may make it faster than the Java version on a multi-core machine (see other answers).
As a side note: you can avoid traversing the list twice more idiomatically by using view:
val probs = annotations.view.map(_.getProbability).filter(isValidProbability)
val (lP, lPC) = probs.foldLeft((math.log(prior), math.log(1 - prior))) {
  case ((pa, ca), p) => (pa + logProb(p), ca + logProb(1 - p))
}
This probably isn't going to get you better performance than your third version, but it feels more elegant to me.
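A lazy iterator gives the same single traversal as view: map and filter run element by element as foldLeft pulls values. This is a swapped-in alternative rather than part of the answer above, shown as a runnable sketch with hypothetical stand-ins for Annotation, isValidProbability and logProb. Note that an iterator can be consumed only once.

```scala
object IteratorSketch {
  // Hypothetical stand-ins for the question's types and helpers.
  final case class Annotation(probability: Float) {
    def getProbability: Float = probability
  }
  def isValidProbability(p: Float): Boolean = p > 0 && p < 1
  def logProb(p: Double): Double = math.log(p)

  def logOdds(prior: Double, annotations: Seq[Annotation]): (Double, Double) = {
    // Lazy: no intermediate collection is built; foldLeft drives a single pass.
    val probs = annotations.iterator.map(_.getProbability).filter(isValidProbability)
    probs.foldLeft((math.log(prior), math.log(1 - prior))) { case ((lp, lpc), p) =>
      (lp + logProb(p), lpc + logProb(1 - p))
    }
  }
}
```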
First, let's address the performance issue: there's no way to implement it as fast as Java except by using while loops. Basically, the JVM cannot optimize the closure-based Scala fold to the extent it optimizes the plain Java loop. The reasons for that are a concern even among the JVM folk, because it gets in the way of their parallel library efforts as well.
Now, back to Scala performance: you can also use .view to avoid creating a new collection in the map step, but I think the map step will always lead to worse performance. The thing is, you are converting the collection into one parameterized on Float (the type getProbability returns), whose elements must be boxed and unboxed.
However, there's one possible way of optimizing it: making it parallel. If you call .par on annotations to make it a parallel collection, you can then use fold:
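// Scala 2.13+: .par needs the scala-parallel-collections module and import scala.collection.parallel.CollectionConverters._
val parAnnot = annotations.par
// fold's operator also merges partial results, so it must be associative (a plain sum),
// and its zero is reused for every chunk, so the prior is added once, outside the fold
val lP = math.log(prior) + parAnnot.map { a =>
  val p = a.getProbability
  if (isValidProbability(p)) logProb(p) else 0.0
}.fold(0.0)(_ + _)
val lPC = math.log(1 - prior) + parAnnot.map { a =>
  val p = a.getProbability
  if (isValidProbability(p)) logProb(1 - p) else 0.0
}.fold(0.0)(_ + _)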
To avoid a separate map step, use aggregate instead of fold, as suggested by Rex.
For bonus points, you could use Future to make both computations run in parallel. I suspect you'll get better performance by bringing the tuples back and running it in one go, though. You'll have to benchmark this stuff to see what works better.
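A minimal sketch of the Future idea, using scala.concurrent from the standard library and hypothetical stand-ins for Annotation, isValidProbability and logProb. Each sum runs as its own task on the global ExecutionContext, and zip pairs the two results. Await is used here only to keep the sketch synchronous. Code that is already asynchronous would return the Future instead. Whether this beats a single tuple-accumulating pass is exactly the kind of thing to benchmark.

```scala
import scala.concurrent.{Await, Future}
import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.duration._

object FutureSketch {
  // Hypothetical stand-ins for the question's types and helpers.
  final case class Annotation(probability: Float) {
    def getProbability: Float = probability
  }
  def isValidProbability(p: Float): Boolean = p > 0 && p < 1
  def logProb(p: Double): Double = math.log(p)

  def logOdds(prior: Double, annotations: Seq[Annotation]): (Double, Double) = {
    val probs = annotations.map(_.getProbability).filter(isValidProbability)
    // Two independent tasks; each only reads the immutable `probs`.
    val lPF = Future(probs.foldLeft(math.log(prior))((acc, p) => acc + logProb(p)))
    val lPCF = Future(probs.foldLeft(math.log(1 - prior))((acc, p) => acc + logProb(1 - p)))
    // zip completes when both futures have; block only to keep the sketch simple.
    Await.result(lPF.zip(lPCF), 10.seconds)
  }
}
```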
On parallel collections, it might pay off to filter out the invalid annotations first. Or, perhaps, use collect.
val parAnnot = annotations.par.view.map(_.getProbability).filter(isValidProbability(_)).force
or
val parAnnot = annotations.par collect { case annot if isValidProbability(annot.getProbability) => annot.getProbability }
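Here is a sequential sketch of the filter-first idea with collect, using hypothetical stand-ins for Annotation, isValidProbability and logProb. Once the invalid values are gone, each total is a plain sum. A sum is associative and has 0.0 as its neutral element, so the same shape should carry over to a parallel collection without the zero-value pitfalls of folding the prior in.

```scala
object CollectSketch {
  // Hypothetical stand-ins for the question's types and helpers.
  final case class Annotation(probability: Float) {
    def getProbability: Float = probability
  }
  def isValidProbability(p: Float): Boolean = p > 0 && p < 1
  def logProb(p: Double): Double = math.log(p)

  def logOdds(prior: Double, annotations: Seq[Annotation]): (Double, Double) = {
    // collect filters and extracts the probability in one pass.
    val probs = annotations.collect {
      case a if isValidProbability(a.getProbability) => a.getProbability
    }
    // With invalid values removed, each total is a plain sum; the prior is added once.
    val lP = math.log(prior) + probs.map(p => logProb(p)).sum
    val lPC = math.log(1 - prior) + probs.map(p => logProb(1 - p)).sum
    (lP, lPC)
  }
}
```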
Anyway, benchmark.