Scala State monad - combining different state types

I'm wrapping my head around the State monad. Trivial examples are easy to understand. I'm now moving to a real-world case where the domain objects are composite. For example, with the following domain objects (they don't make much sense; they're just an example):

case class Master(workers: Map[String, Worker])
case class Worker(elapsed: Long, result: Vector[String])
case class Message(workerId: String, work: String, elapsed: Long)

Treating Worker as the S type in the State[S, +A] monad, it's quite easy to write a few combinators like these:

type WorkerState[+A] = State[Worker, A]
def update(message: Message): WorkerState[Unit] = State.modify { w =>
    w.copy(elapsed = w.elapsed + message.elapsed,
           result = w.result :+ message.work)
}
def getWork: WorkerState[Vector[String]] = State { w => (w.result, w) }
def getElapsed: WorkerState[Long] = State { w => (w.elapsed, w) }
def updateAndGetElapsed(message: Message): WorkerState[Long] = for {
    _ <- update(message)
    elapsed <- getElapsed
} yield elapsed
// etc.
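The question doesn't say which State implementation it uses. The State[S, +A] signature, the (value, state) tuple order and the .run(w) call match the State from "Functional Programming in Scala", so the sketch below assumes a minimal stand-in along those lines (not any particular library's API) and runs the composed worker action once:

```scala
// Assumed minimal State in the style of "Functional Programming in Scala":
// run returns (value, newState). Not a specific library's API.
case class State[S, +A](run: S => (A, S)) {
  def map[B](f: A => B): State[S, B] =
    State[S, B] { s => val (a, s1) = run(s); (f(a), s1) }
  // Thread the state produced by this action into the next one.
  def flatMap[B](f: A => State[S, B]): State[S, B] =
    State[S, B] { s => val (a, s1) = run(s); f(a).run(s1) }
}

object State {
  // Replace the state with f(state); the value is just ().
  def modify[S](f: S => S): State[S, Unit] = State[S, Unit](s => ((), f(s)))
}

case class Worker(elapsed: Long, result: Vector[String])
case class Message(workerId: String, work: String, elapsed: Long)

type WorkerState[A] = State[Worker, A]

def update(message: Message): WorkerState[Unit] = State.modify[Worker] { w =>
  w.copy(elapsed = w.elapsed + message.elapsed, result = w.result :+ message.work)
}
def getWork: WorkerState[Vector[String]] = State[Worker, Vector[String]](w => (w.result, w))
def getElapsed: WorkerState[Long] = State[Worker, Long](w => (w.elapsed, w))
def updateAndGetElapsed(message: Message): WorkerState[Long] = for {
  _ <- update(message)
  elapsed <- getElapsed
} yield elapsed

// Run the composed action against a starting Worker.
val (total, finalWorker) =
  updateAndGetElapsed(Message("w1", "job", 5L)).run(Worker(10L, Vector("init")))
```

Because flatMap threads the Worker produced by update into getElapsed, the value returned is the new elapsed time (15 here), not the original 10.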

What is the idiomatic way to combine these with the Master state combinators? For example:

type MasterState[+A] = State[Master, A]
def updateAndGetElapsedTime(message: Message): MasterState[Option[Long]]

I can implement this like so:

def updateAndGetElapsedTime(message: Message): MasterState[Option[Long]] =
    State { m =>
        m.workers.get(message.workerId) match {
            case None => (None, m)
            case Some(w) =>
                // Run the worker-level action by hand, then write the new worker back.
                val (t, newW) = updateAndGetElapsed(message).run(w)
                (Some(t), m.copy(workers = m.workers.updated(message.workerId, newW)))
        }
    }

What I don't like is that I have to run the worker State action manually inside the last combinator. My real-world example is a bit more involved, and with this approach it quickly gets messy.

Is there a more idiomatic way to do this sort of incremental update?

asked by ak. on Jun 02 '15 at 04:06


1 Answer

It's possible to do this pretty nicely by combining lenses with the State monad. First, the setup (I've edited yours lightly to get it to compile with Scalaz 7.1):

case class Master(workers: Map[String, Worker])
case class Worker(elapsed: Long, result: Vector[String])
case class Message(workerId: String, work: String, elapsed: Long)

import scalaz._, Scalaz._

type WorkerState[A] = State[Worker, A]

def update(message: Message): WorkerState[Unit] = State.modify { w =>
  w.copy(
    elapsed = w.elapsed + message.elapsed,
    result = w.result :+ message.work
  )
}

def getWork: WorkerState[Vector[String]] = State.gets(_.result)
def getElapsed: WorkerState[Long] = State.gets(_.elapsed)
def updateAndGetElapsed(message: Message): WorkerState[Long] = for {
  _ <- update(message)
  elapsed <- getElapsed
} yield elapsed

And now for a couple of general-purpose lenses that allow us to look inside a Master:

val workersLens: Lens[Master, Map[String, Worker]] = Lens.lensu(
  (m, ws) => m.copy(workers = ws),
  _.workers
)

def workerLens(workerId: String): PLens[Master, Worker] =
  workersLens.partial andThen PLens.mapVPLens(workerId)
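To make the lens composition concrete without Scalaz, here is a plain-Scala approximation. SimpleLens, SimplePLens and mapKey are hypothetical names standing in for Lens, PLens and PLens.mapVPLens (they are not the Scalaz API); a partial lens either finds nothing or returns the focus together with a function that rebuilds the outer value around a new focus:

```scala
// Hypothetical plain-Scala stand-ins for Scalaz's Lens / PLens (not that API).
case class Worker(elapsed: Long, result: Vector[String])
case class Master(workers: Map[String, Worker])

// Partial lens: no focus at all, or the focus B plus a way to rebuild A around a new B.
case class SimplePLens[A, B](run: A => Option[(B, B => A)]) {
  def get(a: A): Option[B] = run(a).map(_._1)
  // Look through this lens, then through `that`; fails if either focus is missing.
  def andThen[C](that: SimplePLens[B, C]): SimplePLens[A, C] =
    SimplePLens[A, C] { a =>
      run(a).flatMap { case (b, rebuildA) =>
        that.run(b).map { case (c, rebuildB) => (c, (c2: C) => rebuildA(rebuildB(c2))) }
      }
    }
}

// Total lens: the focus always exists, so it can always be viewed as a partial one.
case class SimpleLens[A, B](get: A => B, set: (A, B) => A) {
  def partial: SimplePLens[A, B] =
    SimplePLens[A, B](a => Some((get(a), (b: B) => set(a, b))))
}

// Partial lens onto the value stored under key k (similar in spirit to PLens.mapVPLens).
def mapKey[K, V](k: K): SimplePLens[Map[K, V], V] =
  SimplePLens[Map[K, V], V](m => m.get(k).map(v => (v, (v2: V) => m.updated(k, v2))))

val workersLens: SimpleLens[Master, Map[String, Worker]] =
  SimpleLens[Master, Map[String, Worker]](_.workers, (m, ws) => m.copy(workers = ws))

def workerLens(workerId: String): SimplePLens[Master, Worker] =
  workersLens.partial.andThen(mapKey[String, Worker](workerId))
```

Composing a total lens (made partial) with a map-key lens can only yield a partial lens, which is why workerLens returns a PLens rather than a Lens: the worker ID might not be in the map.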

And then we're basically done:

def updateAndGetElapsedTime(message: Message): State[Master, Option[Long]] =
  workerLens(message.workerId) %%= updateAndGetElapsed(message)

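Here %%= runs the given worker-level state action on whichever worker the lens focuses on and writes the updated worker back into the Master. Because workerLens is a partial lens, the result is wrapped in Option: if there's no worker with that ID, you get None and the Master is left unchanged.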
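Conceptually, %%= does what the hand-written version in the question does: look up the focus, run the inner State on it, and put the result back. A self-contained sketch of that idea follows; State, SimplePLens and zoom are hypothetical minimal stand-ins (not the Scalaz types), State.run uses the question's (value, newState) order, and updateAndGetElapsed is collapsed into a single step for brevity:

```scala
// Hypothetical minimal State, using the question's (value, newState) order.
case class State[S, +A](run: S => (A, S)) {
  def map[B](f: A => B): State[S, B] =
    State[S, B] { s => val (a, s1) = run(s); (f(a), s1) }
  def flatMap[B](f: A => State[S, B]): State[S, B] =
    State[S, B] { s => val (a, s1) = run(s); f(a).run(s1) }
}

// Hypothetical partial lens: no focus, or the focus plus a way to rebuild the whole.
case class SimplePLens[A, B](run: A => Option[(B, B => A)])

// The idea behind `lens %%= action`: run `action` on the focus and write it back.
// With no focus, return None and leave the outer state untouched.
def zoom[A, B, C](lens: SimplePLens[A, B])(action: State[B, C]): State[A, Option[C]] =
  State[A, Option[C]] { a =>
    lens.run(a) match {
      case None => (None, a)
      case Some((b, rebuild)) =>
        val (c, b2) = action.run(b)
        (Some(c), rebuild(b2))
    }
  }

case class Worker(elapsed: Long, result: Vector[String])
case class Master(workers: Map[String, Worker])
case class Message(workerId: String, work: String, elapsed: Long)

// Worker-level action, collapsed into one step for brevity.
def updateAndGetElapsed(message: Message): State[Worker, Long] =
  State[Worker, Long] { w =>
    val w2 = w.copy(elapsed = w.elapsed + message.elapsed, result = w.result :+ message.work)
    (w2.elapsed, w2)
  }

// Partial lens from Master to one worker (plays the role of workerLens above).
def workerLens(workerId: String): SimplePLens[Master, Worker] =
  SimplePLens[Master, Worker] { m =>
    m.workers.get(workerId).map { w =>
      (w, (w2: Worker) => m.copy(workers = m.workers.updated(workerId, w2)))
    }
  }

// Same shape as `workerLens(message.workerId) %%= updateAndGetElapsed(message)`.
def updateAndGetElapsedTime(message: Message): State[Master, Option[Long]] =
  zoom(workerLens(message.workerId))(updateAndGetElapsed(message))
```

With Scalaz the lens and zoom plumbing comes for free, so each Master-level operation stays a one-line %%= call; the sketch only shows where the Option in the result type comes from.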

answered by Travis Brown on Oct 13 '22 at 14:10