 

Filtering Scala's Parallel Collections with early abort when desired number of results found

Given a very large instance of collection.parallel.mutable.ParHashMap (or any other parallel collection), how can one abort a filtering parallel scan once a given number of matches (say, 50) has been found?

Attempting to accumulate intermediate matches in a thread-safe "external" data structure, or keeping an external AtomicInteger with the result count, seems to be 2 to 3 times slower on 4 cores than using a regular collection.mutable.HashMap and pegging a single core at 100%.

I am aware that find and exists on Par* collections do abort "on the inside". Is there a way to generalize this to find more than one result?

Here's the code. It still seems to be 2 to 3 times slower on the ParHashMap with ~79,000 entries, and it also has the problem of stuffing more than maxResults results into the results CHM (ConcurrentHashMap). This is probably because a thread can be preempted after incrementAndGet but before break, which allows other threads to add more elements. Update: it seems the slowdown is due to worker threads contending on counter.incrementAndGet(), which of course defeats the purpose of the whole parallel scan :-(

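import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.atomic.AtomicInteger

// Key, Node and parHashMap come from the surrounding code (not shown)
def find(filter: Node => Boolean, maxResults: Int): Iterable[Node] =
{
  val counter = new AtomicInteger(0)   // one counter shared by all worker threads
  val results = new ConcurrentHashMap[Key, Node](maxResults)

  import util.control.Breaks._

  breakable
  {
    // desugars to parHashMap.withFilter(...).foreach(...), run by several workers
    for ((key, node) <- parHashMap if filter(node))
    {
      // put happens before the limit check, and other workers keep running
      // until break propagates, so results can end up with more than maxResults
      results.put(key, node)
      val total = counter.incrementAndGet()   // every worker contends here
      if (total > maxResults) break
    }
  }

  results.values.toArray(new Array[Node](results.size))
}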
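The overshoot described above comes from the window between `results.put` and the `incrementAndGet()` check. One way to rule it out, shown here as a minimal sketch and not as part of the original question, is to reserve a slot with `AtomicInteger.compareAndSet` before inserting, so the count can never exceed the limit. `SlotReserver` and `SlotReserverDemo` are hypothetical names. Note that this still funnels every worker through one shared counter, so it addresses the overshoot, not the contention slowdown.

```scala
import java.util.concurrent.atomic.AtomicInteger

// Hypothetical helper: hands out at most `max` slots across all threads.
final class SlotReserver(max: Int) {
  private val taken = new AtomicInteger(0)

  // Returns true only if this caller obtained a slot; the counter never exceeds max.
  def tryReserve(): Boolean = {
    var current = taken.get() // plain read first, so callers stop CAS-ing once full
    while (current < max) {
      if (taken.compareAndSet(current, current + 1)) return true
      current = taken.get() // another thread won the race: re-read and retry
    }
    false
  }

  def count: Int = taken.get()
}

object SlotReserverDemo {
  def main(args: Array[String]): Unit = {
    val reserver = new SlotReserver(50)
    val granted = new AtomicInteger(0)
    // four plain Java threads race for the 50 slots
    val threads = (1 to 4).map { _ =>
      new Thread(new Runnable {
        def run(): Unit =
          for (_ <- 1 to 1000) { if (reserver.tryReserve()) granted.incrementAndGet() }
      })
    }
    threads.foreach(_.start())
    threads.foreach(_.join())
    println(s"granted=${granted.get}, counter=${reserver.count}") // granted=50, counter=50
  }
}
```

In the question's loop, a worker would call `tryReserve()` first and only `put` (and possibly `break`) when it returns true, so `results` holds at most `maxResults` entries.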
asked Nov 09 '11 23:11 by Alex Kravets


2 Answers

I would first do a parallel scan in which the count against maxResults is kept per thread (thread-local) instead of being shared. This finds up to (maxResults * numberOfThreads) results.

Then I would do a single-threaded pass to reduce that to maxResults.
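A minimal sketch of this two-phase idea, under some assumptions: it splits an `IndexedSeq` into slices and scans each slice in a `scala.concurrent.Future` instead of using a `Par*` collection (parallel collections live in a separate module from Scala 2.13 on), and `findUpTo` and its parameters are hypothetical names. Each task enforces its own limit, so there is no shared counter to contend on; the final `take` is the single-threaded trimming pass.

```scala
import scala.concurrent.{Await, Future}
import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.duration.Duration

object ThreadLocalLimitSketch {
  // Hypothetical helper: phase 1 scans `chunks` slices concurrently, each with
  // its own limit of maxResults (no shared counter); phase 2 trims sequentially.
  def findUpTo[A](data: IndexedSeq[A], filter: A => Boolean,
                  maxResults: Int, chunks: Int): Vector[A] = {
    require(chunks > 0, "chunks must be positive")
    val chunkSize = math.max(1, (data.size + chunks - 1) / chunks)
    val tasks = data.grouped(chunkSize).toVector.map { chunk =>
      Future {
        // iterator.filter(...).take(n) is lazy: this slice's scan stops
        // as soon as it has produced maxResults matches
        chunk.iterator.filter(filter).take(maxResults).toVector
      }
    }
    // up to maxResults * (number of slices) candidates come back
    val perChunk = Await.result(Future.sequence(tasks), Duration.Inf)
    // single-threaded pass that reduces the candidates to maxResults
    perChunk.flatten.take(maxResults)
  }

  def main(args: Array[String]): Unit = {
    val data = (0 until 100000).toVector
    val found = findUpTo[Int](data, _ % 2 == 0, 50, 4)
    println(found.size) // 50
  }
}
```

A slice stops on its own once it has maxResults matches, but the other slices keep running until they fill up or finish, which is the price of not sharing a counter.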

answered Oct 11 '22 13:10 by user482745


I performed an interesting investigation of your case.

Investigation reasoning

I suspected the problem was the mutability of the input map, and I will try to explain why: a hash map implementation organizes its data into different buckets, as described on Wikipedia.

[Figure: Wikipedia HashMap diagram]

The first thread-safe collections in Java, the synchronized collections, synchronized every method around the underlying implementation, which resulted in poor performance. Further research led to the more performant concurrent collections, such as ConcurrentHashMap, whose approach is smarter: why not protect each bucket (or group of buckets) with its own lock?
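To illustrate the difference, here is a toy sketch, not how `ConcurrentHashMap` is actually implemented: instead of one lock around the whole map (the synchronized-collection approach), each stripe of keys gets its own lock, so threads writing to different stripes do not block each other. `StripedMap` and the stripe count are hypothetical; `ConcurrentHashMap` itself historically used lock-striped segments and, since Java 8, uses CAS plus per-bin locking.

```scala
import scala.collection.mutable

// Hypothetical toy map: keys are spread over `stripes` sub-maps, each guarded
// by its own lock (the sub-map's monitor) instead of one lock for everything.
final class StripedMap[K, V](stripes: Int) {
  require(stripes > 0)
  private val parts = Vector.fill(stripes)(mutable.HashMap.empty[K, V])

  // the key's hash picks the stripe; equal keys always land in the same one
  private def partFor(key: K): mutable.HashMap[K, V] =
    parts(Math.floorMod(key.##, stripes))

  def put(key: K, value: V): Unit = {
    val p = partFor(key)
    p.synchronized { p.update(key, value) }
  }

  def get(key: K): Option[V] = {
    val p = partFor(key)
    p.synchronized { p.get(key) }
  }

  // locks one stripe at a time, so under concurrent writes this is only a snapshot
  def size: Int = parts.map(p => p.synchronized { p.size }).sum
}

object StripedMapDemo {
  def main(args: Array[String]): Unit = {
    val map = new StripedMap[Int, Int](16)
    // four threads insert disjoint key ranges concurrently
    val threads = (0 until 4).map { t =>
      new Thread(new Runnable {
        def run(): Unit = for (k <- t * 1000 until (t + 1) * 1000) map.put(k, k * 2)
      })
    }
    threads.foreach(_.start())
    threads.foreach(_.join())
    println(map.size)      // 4000
    println(map.get(1234)) // Some(2468)
  }
}
```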

My hunch was that the performance problem occurs because:

  • when you run your filter in parallel, some threads will try to access the same bucket at the same time and hit the same lock, because your map is mutable.
  • you hold a counter to see how many results you have, while you could simply check the size of your result. If you have a thread-safe way to build a collection, you don't need a thread-safe counter as well.

Investigation result

I wrote a test case and found out I was wrong. The problem is the concurrent nature of the output map: contention happens when you put elements into that map, rather than when you iterate over the input. Additionally, since you only want the values in the result, you don't need the keys, the hashing, or any of the other map features. It would be interesting to see how your version performs if you remove the AtomicInteger and use only the size of the result map to check whether you have collected enough elements.

Please be careful: the following code was written for Scala 2.9.2. I explain in another post why I need two different functions for the parallel and the non-parallel version: "Calling map on a parallel collection via a reference to an ancestor type".

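import scala.collection.immutable.VectorBuilder
import scala.collection.parallel.ParMap
import scala.collection.parallel.immutable.{ParHashMap => ImmutableParHashMap}
import scala.collection.parallel.mutable.{ParHashMap => MutableParHashMap}

object MapPerformance {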

  val size = 100000
  val items = Seq.tabulate(size)( x => (x,x*2))


  val concurrentParallelMap = ImmutableParHashMap(items:_*)

  val concurrentMutableParallelMap = MutableParHashMap(items:_*)

  val unparallelMap = Map(items:_*)


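  // Bounded builder guarded by its own monitor: += refuses new items once
  // maxSize is reached, so its size doubles as the result counter
  class ThreadSafeIndexedSeqBuilder[T](maxSize:Int) {
    val underlyingBuilder = new VectorBuilder[T]()
    var counter = 0
    def sizeHint(hint:Int): Unit = { underlyingBuilder.sizeHint(hint) }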
    def +=(item:T):Boolean ={
      synchronized{
        if(counter>=maxSize)
          false
        else{
          underlyingBuilder+=item
          counter+=1
          true
        }
      }
    }
    def result():Vector[T] = underlyingBuilder.result()

  }

  def find(map:ParMap[Int,Int],filter: Int => Boolean, maxResults: Int): Iterable[Int] =
  {

    // we already know the maximum size
    val resultsBuilder = new ThreadSafeIndexedSeqBuilder[Int](maxResults)
    resultsBuilder.sizeHint(maxResults)

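    import util.control.Breaks._
    breakable
    {
      // only the values are collected; the keys are ignored
      for ((key, node) <- map if filter(node))
      {
        // += returns false once maxResults items are stored, so the builder
        // itself is the stop condition (no separate atomic counter)
        val newItemAdded = resultsBuilder += node
        if (!newItemAdded)
          break()
      }
    }
    resultsBuilder.result()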

  }

  def findUnParallel(map:Map[Int,Int],filter: Int => Boolean, maxResults: Int): Iterable[Int] =
  {

    // we already know the maximum size
    val resultsBuilder = Array.newBuilder[Int]
    resultsBuilder.sizeHint(maxResults)

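    var counter = 0
    // the counter guard only stops adding results; the loop still visits
    // and filters every entry of the map (there is no early exit here)
    for {
      (key, node) <- map if filter(node)
      if counter < maxResults
    }{
      resultsBuilder+=node
      counter+=1
    }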

    resultsBuilder.result()

  }

  def measureTime[K](f: => K):(Long,K) = {
    val startMutable = System.currentTimeMillis()
    val result = f
    val endMutable = System.currentTimeMillis()
    (endMutable-startMutable,result)
  }

  def main(args:Array[String]) = {
    val maxResultSetting=10
    (1 to 10).foreach{
      tryNumber =>
        println("Try number " +tryNumber)
        val (mutableTime, mutableResult) = measureTime(find(concurrentMutableParallelMap,_%2==0,maxResultSetting))
        val (immutableTime, immutableResult) = measureTime(find(concurrentParallelMap,_%2==0,maxResultSetting))
        val (unparallelTime, unparallelResult) = measureTime(findUnParallel(unparallelMap,_%2==0,maxResultSetting))
        assert(mutableResult.size==maxResultSetting)
        assert(immutableResult.size==maxResultSetting)
        assert(unparallelResult.size==maxResultSetting)
        println(" The mutable version has taken " + mutableTime + " milliseconds")
        println(" The immutable version has taken " + immutableTime + " milliseconds")
        println(" The unparallel version has taken " + unparallelTime + " milliseconds")
     }
  }

}

With this code, the parallel versions (with both the mutable and the immutable input map) are systematically about 3.5 times faster than the unparallel one on my machine.

answered Oct 11 '22 15:10 by Edmondo1984