 

Thread-safely transforming a value in a mutable map

Suppose I want to use a mutable map in Scala to keep track of the number of times I've seen some strings. In a single-threaded context, this is easy:

import scala.collection.mutable.{ Map => MMap }

class Counter {
  val counts = MMap.empty[String, Int].withDefaultValue(0)

  def add(s: String): Unit = counts(s) += 1
}

Unfortunately this isn't thread-safe, since the get and the update don't happen atomically.

Concurrent maps (scala.collection.concurrent.Map) add a few atomic operations, such as putIfAbsent and a conditional replace, to the mutable map API, but not the one I need, which would look something like this:

def replace(k: A, f: B => B): Option[B]
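For comparison, the atomic operations that scala.collection.concurrent.Map does offer can already be combined into a compare-and-set retry loop. The following is a minimal sketch, not a library API: it assumes TrieMap (in the standard library since Scala 2.10) as the concrete map, uses only putIfAbsent and the three-argument replace(k, oldValue, newValue), and retries immediately without any backoff. CasCounter is a hypothetical name.

```scala
import scala.annotation.tailrec
import scala.collection.concurrent.TrieMap

// Hypothetical counter built only from scala.collection.concurrent.Map's atomic primitives.
class CasCounter {
  val counts = TrieMap.empty[String, Int]

  // Increments the count for `s` and returns the new count.
  @tailrec
  final def add(s: String): Int = counts.get(s) match {
    case None =>
      // putIfAbsent returns None only if our insert of 1 won the race.
      if (counts.putIfAbsent(s, 1).isEmpty) 1 else add(s)
    case Some(old) =>
      // The conditional replace succeeds only if the value is still `old`.
      if (counts.replace(s, old, old + 1)) old + 1 else add(s)
  }
}
```

A failed putIfAbsent or replace only means another thread updated the key first, so the loop re-reads the current value and tries again.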

I know I can use ScalaSTM's TMap:

import scala.concurrent.stm._

class Counter {
  val counts = TMap.empty[String, Int]

  def add(s: String): Unit = atomic { implicit txn =>
    counts(s) = counts.get(s).getOrElse(0) + 1
  }
}

But (for now) that's still an extra dependency. Other options would include actors (another dependency), synchronization (potentially less efficient), or Java's atomic references (less idiomatic).
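The Java atomic-reference option can be sketched as an immutable Map held in a java.util.concurrent.atomic.AtomicReference and updated with a compare-and-set retry loop, which also gives a general "transform with a function" operation. This is a minimal sketch under those assumptions; AtomicRefCounter and transform are hypothetical names. Every write builds a new immutable map and all writers contend on the one reference, so it suits modest write rates.

```scala
import java.util.concurrent.atomic.AtomicReference
import scala.annotation.tailrec

// Hypothetical counter: an immutable Map swapped atomically via compareAndSet.
class AtomicRefCounter {
  private val ref = new AtomicReference(Map.empty[String, Int])

  // A consistent snapshot of all counts.
  def counts: Map[String, Int] = ref.get()

  // Applies f to the current value (None if the key is absent), stores the
  // result, and returns it; retries if another thread swapped the map meanwhile.
  @tailrec
  final def transform(k: String)(f: Option[Int] => Int): Int = {
    val old = ref.get()
    val v = f(old.get(k))
    if (ref.compareAndSet(old, old.updated(k, v))) v else transform(k)(f)
  }

  def add(s: String): Int = transform(s)(_.getOrElse(0) + 1)
}
```

Because f runs again whenever compareAndSet fails, it should be free of side effects.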

In general I'd avoid mutable maps in Scala, but I've occasionally needed this kind of thing, and most recently I've used the STM approach (instead of just crossing my fingers and hoping I don't get bitten by the naïve solution).

I know there are a number of trade-offs here (extra dependencies vs. performance vs. clarity, etc.), but is there anything like a "right" answer to this problem in Scala 2.10?

asked Aug 09 '13 14:08 by Travis Brown


4 Answers

How about this one? It assumes you don't really need a general replace method right now, just a counter.

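import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.atomic.AtomicInteger

object CountedMap {
  private val counts = new ConcurrentHashMap[String, AtomicInteger]

  def add(key: String): Int = {
    val zero = new AtomicInteger(0)
    // putIfAbsent returns null if our `zero` was inserted, otherwise the counter already in the map.
    val value = Option(counts.putIfAbsent(key, zero)).getOrElse(zero)
    // The increment itself is atomic, so concurrent adds for the same key are never lost.
    value.incrementAndGet()
  }
}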

You get better performance than synchronizing on the whole map, and you also get atomic increments.

answered Sep 30 '22 18:09 by Ionuț G. Stan


The simplest solution is definitely synchronization. If there is not too much contention, performance might not be that bad.

Otherwise, you could try to roll your own STM-like replace implementation. Something like this might do:

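import scala.util.Random

object ConcurrentMapSyntax {
  // An implicit class must be nested in an object, class, or trait, and may not
  // share its name with another object in scope, so the constants live here.
  private val rng = new Random
  private val MaxReplaceRetryCount = 10
  private val MinReplaceBackoffTime: Long = 1
  private val MaxReplaceBackoffTime: Long = 20

  // Usage: import ConcurrentMapSyntax._
  implicit class ConcurrentMapOps[A, B](val m: scala.collection.concurrent.Map[A, B]) {
    // Sleep for a random time between MinReplaceBackoffTime and MaxReplaceBackoffTime ms.
    private def replaceBackoff(): Unit = {
      Thread.sleep((MinReplaceBackoffTime + rng.nextFloat() * (MaxReplaceBackoffTime - MinReplaceBackoffTime)).toLong) // A bit crude, I know
    }

    // Returns the previous value, or None if the key is absent.
    def replace(k: A, f: B => B): Option[B] = {
      var retryCount = 0
      while (retryCount <= MaxReplaceRetryCount) {
        // Re-read the current value on every attempt: retrying with a stale
        // `old` could never succeed once another thread has changed the entry.
        m.get(k) match {
          case None => return None
          case Some(old) =>
            if (m.replace(k, old, f(old))) return Some(old)
            retryCount += 1
            replaceBackoff()
        }
      }
      sys.error("Could not concurrently modify map")
    }
  }
}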

Note that collisions are localized to a given key. If two threads access the same map but work on distinct keys, there are no collisions and the replace operation always succeeds on the first try. If a collision is detected, we wait a bit (a random amount of time, so as to minimize the likelihood of threads fighting forever over the same key) and try again.

I cannot guarantee that this is production-ready (I just threw it together), but it might do the trick.

UPDATE: Of course (as Ionuț G. Stan pointed out), if all you want is to increment or decrement a value, Java's ConcurrentHashMap with AtomicInteger values already provides those operations without locking the whole map. My solution above applies if you need a more general replace method that takes the transformation function as a parameter.
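On Java 8 or later, ConcurrentHashMap.computeIfPresent performs the whole read-transform-write for a key as one atomic invocation, which covers the general replace(k, f) case without a retry loop. A minimal sketch, where ComputeReplace is a hypothetical name; unlike the retry-loop version it returns the new value rather than the old one, and values are restricted to reference types because Java reports an absent key as null.

```scala
import java.util.concurrent.ConcurrentHashMap

// Hypothetical helper (Java 8+): replace(k, f) built on computeIfPresent.
object ComputeReplace {
  // Returns Some(newValue), or None if k was absent. Per the ConcurrentHashMap
  // docs, f should be short and must not update other mappings of this map.
  def replace[A, B <: AnyRef](m: ConcurrentHashMap[A, B], k: A, f: B => B): Option[B] =
    Option(m.computeIfPresent(k, (_: A, v: B) => f(v)))
}
```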

answered Sep 30 '22 19:09 by Régis Jean-Gilles


You're asking for trouble if your map is just sitting there as a public val, since any caller can then read or update it without synchronization. If it meets your use case, I'd recommend something like

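import scala.collection.mutable.{ Map => MMap }

class Counter {
  // Private, so every access has to go through the synchronized methods below.
  private[this] val myCounts = MMap.empty[String, Int].withDefaultValue(0)
  def counts(s: String): Int = myCounts.synchronized { myCounts(s) }
  def add(s: String): Unit = myCounts.synchronized { myCounts(s) += 1 }
  // Returns an immutable snapshot, so callers never see the map mid-update.
  def getCounts: Map[String, Int] = myCounts.synchronized { Map[String, Int]() ++ myCounts }
}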

for low-contention usage. For high-contention usage, you should use a concurrent map designed for it (e.g. java.util.concurrent.ConcurrentHashMap) and wrap the values in an atomic type such as AtomicInteger or AtomicLong.
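On Java 8 or later, java.util.concurrent.atomic.LongAdder is another candidate for the values: its Javadoc suggests a ConcurrentHashMap<String, LongAdder> as a scalable frequency map, since a LongAdder keeps a set of internal variables that can grow when updates are contended. A minimal Scala sketch of that idiom, where FrequencyCounter, add, and count are hypothetical names:

```scala
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.atomic.LongAdder

// Hypothetical frequency counter; requires Java 8+ for LongAdder and computeIfAbsent.
class FrequencyCounter {
  private val freqs = new ConcurrentHashMap[String, LongAdder]

  // computeIfAbsent creates at most one adder per key; increment() is then a
  // cheap update that never takes a map-level lock.
  def add(s: String): Unit = freqs.computeIfAbsent(s, (_: String) => new LongAdder).increment()

  // sum() is exact once concurrent updates have finished, but only an
  // approximate reading while other threads are still incrementing.
  def count(s: String): Long = Option(freqs.get(s)).fold(0L)(_.sum())
}
```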

answered Sep 30 '22 20:09 by Rex Kerr


If you're OK with a Future-based interface, you can confine all access to the map to a single thread:

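import java.util.concurrent.Executors
import scala.collection.mutable.{ Map => MMap }
import scala.concurrent.{ ExecutionContext, Future }

trait SingleThreadedExecutionContext {
  // Every task submitted to `ec` runs sequentially on one dedicated thread.
  val ec = ExecutionContext.fromExecutor(Executors.newSingleThreadExecutor())
}

class Counter extends SingleThreadedExecutionContext {
  // Only ever touched from the thread behind `ec`, so no locking is needed.
  private val counts = MMap.empty[String, Int].withDefaultValue(0)

  def get(s: String): Future[Int] = Future(counts(s))(ec)

  def add(s: String): Future[Unit] = Future(counts(s) += 1)(ec)
}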

A test (written with specs2) might look like this:

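import org.specs2.mutable.Specification
import scala.concurrent.{ Await, ExecutionContext, Future }
import scala.concurrent.duration._

class MutableMapSpec extends Specification {

  "thread safe" in {

    import ExecutionContext.Implicits.global

    val c = new Counter
    val testData = Seq.fill(16)("1")
    // Submit 16 increments of the same key, then wait for all of them to complete.
    Await.result(Future.traverse(testData)(c.add), 5.seconds)
    Await.result(c.get("1"), 5.seconds) mustEqual 16
  }
}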

answered Sep 30 '22 18:09 by Mushtaq Ahmed