How to compose two different `State Monad`?

While learning the State monad, I'm not sure how to compose two functions that return States with different state types.

State Monad definition:

case class State[S, A](runState: S => (S, A)) {

  def flatMap[B](f: A => State[S, B]): State[S, B] = {
    State(s => {
      val (s1, a) = runState(s)
      val (s2, b) = f(a).runState(s1)
      (s2, b)
    })
  }

  def map[B](f: A => B): State[S, B] = {
    flatMap(a => {
      State(s => (s, f(a)))
    })
  }

}

Two different State types:

type AppendBang[A] = State[Int, A]

type AddOne[A] = State[String, A]

Two methods with different State return types:

def addOne(n: Int): AddOne[Int] = State(s => (s + ".", n + 1))

def appendBang(str: String): AppendBang[String] = State(s => (s + 1, str + " !!!"))

Here is a function that uses the two methods above:

def myAction(n: Int) = for {
  a <- addOne(n)
  b <- appendBang(a.toString)
} yield (a, b)

I'd like to use it like this:

println(myAction(1))

The problem is that myAction doesn't compile; it reports an error like this:

Error:(14, 7) type mismatch;
 found   : state_monad.State[Int,(Int, String)]
 required: state_monad.State[String,?]
    b <- appendBang(a.toString)
      ^

How can I fix it? Do I have to define some Monad transformers?


Update: The question may not be clear, so let me give an example.

Say I want to define another function that uses addOne and appendBang internally. Since both need an initial state, I have to pass both states in:

def myAction(n: Int)(addOneState: String, appendBangState: Int): ((String, Int), String) = {
  val (addOneState2, n2) = addOne(n).runState(addOneState)
  val (appendBangState2, n3) = appendBang(n2.toString).runState(appendBangState)
  ((addOneState2, appendBangState2), n3)
}

I have to run addOne and appendBang one after the other, passing in and collecting the states and results manually.

I found that it can return another State instead, but the code doesn't improve much:

def myAction(n: Int): State[(String, Int), String] = State {
  case (addOneState: String, appendBangState: Int) =>
    val (addOneState2, n2) = addOne(n).runState(addOneState)
    val (appendBangState2, n3) = appendBang(n2.toString).runState(appendBangState)
    ((addOneState2, appendBangState2), n3)
}

Since I'm not very familiar with these, I'm wondering whether there's any way to improve it. Ideally I'd like to use a for comprehension, but I'm not sure that's possible.

Freewind asked Sep 27 '22 20:09


1 Answer

As I mentioned in my first comment, you can't use a for comprehension to do this directly, because it can't change the type of the state (S).

Remember that a for comprehension is translated into a combination of flatMap and withFilter calls and a final map. Your State.flatMap takes a function f: A => State[S, B] and turns a State[S, A] into a State[S, B]. We can use flatMap and map (and thus a for comprehension) to chain together operations on the same state type, but we can't change the type of the state within such a chain.
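To make the translation concrete, here is a small sketch built on the question's State class (flatMap is condensed but behaves the same). Two addOne steps both produce a State[String, _], so the for comprehension compiles and matches the flatMap/map chain it desugars to. Replacing the second step with appendBang would hand flatMap a State[Int, _] where a State[String, _] is expected, which is the mismatch in the question's error message.

```scala
// The question's State class (flatMap condensed, same behaviour).
case class State[S, A](runState: S => (S, A)) {
  def flatMap[B](f: A => State[S, B]): State[S, B] =
    State { s =>
      val (s1, a) = runState(s)
      f(a).runState(s1)
    }

  def map[B](f: A => B): State[S, B] =
    flatMap(a => State(s => (s, f(a))))
}

type AddOne[A] = State[String, A]

def addOne(n: Int): AddOne[Int] = State(s => (s + ".", n + 1))

// Both generators produce a State[String, _], so flatMap's S never changes.
val sugared: AddOne[(Int, Int)] = for {
  a <- addOne(1)
  b <- addOne(a)
} yield (a, b)

// The same action, written as the flatMap/map chain the compiler generates.
val desugared: AddOne[(Int, Int)] =
  addOne(1).flatMap(a => addOne(a).map(b => (a, b)))

println(sugared.runState(""))   // (..,(2,3))
println(desugared.runState("")) // (..,(2,3))
```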

We could generalize your last definition of myAction to combine two functions that use states of different types. We can implement this generalized compose method directly in our State class (although it is probably too specific to belong there). If we compare State.flatMap and myAction, we can see some similarities:

  • We first call runState on our existing State instance.
  • We then call runState again, on a State created from the first result.

In myAction we use the result n2 to create a State[Int, String] (AppendBang[String], or State[S2, B]) with the second function (appendBang, or f), on which we then call runState. But our result n2 is of type Int (A) and appendBang needs a String (B), so we need a function to convert A into B.

case class State[S, A](runState: S => (S, A)) {
  // flatMap and map

  def compose[B, S2](f: B => State[S2, B], convert: A => B) : State[(S, S2), B] =
    State( ((s: S, s2: S2) => {
      val (sNext, a) = runState(s)
      val (s2Next, b) = f(convert(a)).runState(s2)
      ((sNext, s2Next), b)
    }).tupled)
}

You could then define myAction as:

def myAction(i: Int) = addOne(i).compose(appendBang, _.toString)

val twoStates = myAction(1)
// State[(String, Int),String] = State(<function1>)

twoStates.runState(("", 1))
// ((String, Int), String) = ((.,2),2 !!!)
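For a quick check, here is the compose approach assembled into one self-contained snippet: the question's flatMap and map fill in the `// flatMap and map` placeholder, and myAction is given an explicit return type (an addition for readability, not part of the answer).

```scala
case class State[S, A](runState: S => (S, A)) {
  def flatMap[B](f: A => State[S, B]): State[S, B] =
    State { s =>
      val (s1, a) = runState(s)
      f(a).runState(s1)
    }

  def map[B](f: A => B): State[S, B] =
    flatMap(a => State(s => (s, f(a))))

  // Run this action on the first half of a paired state, convert its result
  // with `convert`, then run `f` on the second half of the pair.
  def compose[B, S2](f: B => State[S2, B], convert: A => B): State[(S, S2), B] =
    State(((s: S, s2: S2) => {
      val (sNext, a) = runState(s)
      val (s2Next, b) = f(convert(a)).runState(s2)
      ((sNext, s2Next), b)
    }).tupled)
}

type AddOne[A] = State[String, A]
type AppendBang[A] = State[Int, A]

def addOne(n: Int): AddOne[Int] = State(s => (s + ".", n + 1))
def appendBang(str: String): AppendBang[String] = State(s => (s + 1, str + " !!!"))

def myAction(i: Int): State[(String, Int), String] =
  addOne(i).compose(appendBang, _.toString)

println(myAction(1).runState(("", 1))) // ((.,2),2 !!!)
```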

If you don't want this function in your State class, you can define it as an external function:

def combineStateFunctions[S1, S2, A, B](
  a: A => State[S1, A], 
  b: B => State[S2, B], 
  convert: A => B
)(input: A): State[(S1, S2), B] = State( 
  ((s1: S1, s2: S2) => {
    val (s1Next, temp) = a(input).runState(s1)
    val (s2Next, result) = b(convert(temp)).runState(s2)
    ((s1Next, s2Next), result)
  }).tupled
)

def myAction(i: Int) = 
  combineStateFunctions(addOne, appendBang, (_: Int).toString)(i)

Edit: Following Bergi's idea, we can create two functions that lift a State[A, X] or a State[B, X] into a State[(A, B), X].

object State {  
  def onFirst[A, B, X](s: State[A, X]): State[(A, B), X] = {
    val runState = (a: A, b: B) => {
      val (nextA, x) = s.runState(a)
      ((nextA, b), x)
    }
    State(runState.tupled)
  }

  def onSecond[A, B, X](s: State[B, X]): State[(A, B), X] = {
    val runState = (a: A, b: B) => {
      val (nextB, x) = s.runState(b)
      ((a, nextB), x)
    }
    State(runState.tupled)
  }
}

This way you can use a for comprehension, since the type of the state stays the same ((A, B)).

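def myAction(i: Int): State[(String, Int), String] = for {
  // Type arguments are explicit: the half of the state a step doesn't
  // touch can't be inferred from its argument.
  x <- State.onFirst[String, Int, Int](addOne(i))
  y <- State.onSecond[String, Int, String](appendBang(x.toString))
} yield y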

myAction(1).runState(("", 1))
// ((String, Int), String) = ((.,2),2 !!!)
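Here is the lifting approach as one self-contained sketch, reusing the question's definitions and the answer's onFirst and onSecond. The type arguments to onFirst and onSecond are spelled out; this assumes the compiler cannot infer the untouched half of the state (B in onFirst, A in onSecond), since it does not appear in the argument and would most likely be inferred as Nothing.

```scala
case class State[S, A](runState: S => (S, A)) {
  def flatMap[B](f: A => State[S, B]): State[S, B] =
    State { s =>
      val (s1, a) = runState(s)
      f(a).runState(s1)
    }

  def map[B](f: A => B): State[S, B] =
    flatMap(a => State(s => (s, f(a))))
}

object State {
  // Run `s` on the first component of a paired state; keep the second as is.
  def onFirst[A, B, X](s: State[A, X]): State[(A, B), X] = {
    val runState = (a: A, b: B) => {
      val (nextA, x) = s.runState(a)
      ((nextA, b), x)
    }
    State(runState.tupled)
  }

  // Run `s` on the second component of a paired state; keep the first as is.
  def onSecond[A, B, X](s: State[B, X]): State[(A, B), X] = {
    val runState = (a: A, b: B) => {
      val (nextB, x) = s.runState(b)
      ((a, nextB), x)
    }
    State(runState.tupled)
  }
}

type AddOne[A] = State[String, A]
type AppendBang[A] = State[Int, A]

def addOne(n: Int): AddOne[Int] = State(s => (s + ".", n + 1))
def appendBang(str: String): AppendBang[String] = State(s => (s + 1, str + " !!!"))

// Both steps are lifted into State[(String, Int), _], so flatMap's S is fixed.
def myAction(i: Int): State[(String, Int), String] = for {
  x <- State.onFirst[String, Int, Int](addOne(i))
  y <- State.onSecond[String, Int, String](appendBang(x.toString))
} yield y

println(myAction(1).runState(("", 1))) // ((.,2),2 !!!)
```

Because each lifted step only touches its own component, further addOne or appendBang steps can be added to the same for comprehension without any manual state passing.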
Peter Neyens answered Oct 03 '22 19:10