
Applicative vs. monadic combinators and the free monad in Scalaz

A couple of weeks ago Dragisa Krsmanovic asked a question here about how to use the free monad in Scalaz 7 to avoid stack overflows in this situation (I've adapted his code a bit):

import scalaz._, Scalaz._

def setS(i: Int): State[List[Int], Unit] = modify(i :: _)

val s = (1 to 100000).foldLeft(state[List[Int], Unit](())) {
  case (st, i) => st.flatMap(_ => setS(i))
}

s(Nil)

I thought that just lifting a trampoline into StateT should work:

import Free.Trampoline

val s = (1 to 100000).foldLeft(state[List[Int], Unit](()).lift[Trampoline]) {
  case (st, i) => st.flatMap(_ => setS(i).lift[Trampoline])
}

s(Nil).run

But it still blows the stack, so I just posted it as a comment.

Dave Stevens just pointed out that sequencing with the applicative *> instead of the monadic flatMap actually works just fine:

val s = (1 to 100000).foldLeft(state[List[Int], Unit](()).lift[Trampoline]) {
  case (st, i) => st *> setS(i).lift[Trampoline]
}

s(Nil).run

(Well, it's super slow of course, because that's the price you pay for doing anything interesting like this in Scala, but at least there's no stack overflow.)

What's going on here? I don't think there could be a principled reason for this difference, but really I have no idea what could be going on in the implementation and don't have time to dig around at the moment. But I'm curious and it would be cool if someone else knows.

Travis Brown asked Jun 10 '14 22:06


3 Answers

Mandubian is correct: StateT's flatMap doesn't let you bypass stack accumulation, because it creates a new StateT immediately before calling the wrapped monad's bind (which would be a Free[Function0] in your case).

So Trampoline can't help, but the Free Monad over the functor for State is one way to ensure stack safety.

We want to go from State[List[Int],Unit] to Free[({ type l[a] = State[List[Int],a] })#l, Unit], so that our flatMap call is Free's flatMap (which does nothing other than build the Free data structure).

val s = (1 to 100000).foldLeft(
  Free.liftF[({ type l[a] = State[List[Int], a] })#l, Unit](state[List[Int], Unit](()))
) {
  case (st, i) =>
    st.flatMap(_ => Free.liftF[({ type l[a] = State[List[Int], a] })#l, Unit](setS(i)))
}

Now we have a Free data structure built that we can easily thread a state through as such:

s.foldRun(List[Int]())( (a,b) => b(a) )
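
To make the idea concrete without depending on Scalaz, here is a minimal sketch of the same technique in plain Scala. Prog, Pure, Modify, Bind, loop and this foldRun are hypothetical names invented for illustration; they are not Scalaz's Free API, and the state type is fixed to List[Int] to keep it short. The point it tries to show is that flatMap only allocates a node, and all the work happens later in a tail-recursive interpreter that threads the state through.

```scala
import scala.annotation.tailrec

object FreeStateSketch {
  // Hypothetical mini "free" program over a List[Int] state (not Scalaz's Free).
  sealed trait Prog[+A] {
    // flatMap only allocates a Bind node; nothing is executed here.
    def flatMap[B](f: A => Prog[B]): Prog[B] = Bind(this, f)
  }
  final case class Pure[+A](a: A) extends Prog[A]
  final case class Modify(f: List[Int] => List[Int]) extends Prog[Unit]
  final case class Bind[X, +A](src: Prog[X], k: X => Prog[A]) extends Prog[A]

  // Tail-recursive interpreter. Types are erased to Any internally to keep
  // the sketch short; foldRun restores the result type with a cast.
  @tailrec
  private def loop(p: Prog[Any], s: List[Int]): (List[Int], Any) = p match {
    case Pure(a)   => (s, a)
    case Modify(f) => (f(s), ())
    case Bind(src, k) =>
      val kk = k.asInstanceOf[Any => Prog[Any]]
      src match {
        case Pure(x)   => loop(kk(x), s)
        case Modify(f) => loop(kk(()), f(s))
        case Bind(src2, g) =>
          // Re-associate left-nested binds so the interpreter never recurses deeply.
          val gg = g.asInstanceOf[Any => Prog[Any]]
          loop(Bind[Any, Any](src2, (y: Any) => Bind[Any, Any](gg(y), kk)), s)
      }
  }

  def foldRun[A](prog: Prog[A], init: List[Int]): (List[Int], A) = {
    val (s, a) = loop(prog, init)
    (s, a.asInstanceOf[A])
  }

  def setS(i: Int): Prog[Unit] = Modify(i :: _)

  // Same shape as the question: 100000 left-nested flatMaps.
  val program: Prog[Unit] =
    (1 to 100000).foldLeft(Pure(()): Prog[Unit]) { (st, i) => st.flatMap(_ => setS(i)) }

  val (finalState, _) = foldRun(program, Nil)
}
```

Building program costs only heap allocations, and the interpreter runs in a loop, so the stack depth should stay constant regardless of the number of steps.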

Calling liftF is fairly ugly, so I have a PR in to make it easier for the State and Kleisli monads; hopefully in the future there won't be a need for type lambdas.

Edit: PR accepted so now we have

val s = (1 to 100000).foldLeft(state[List[Int], Unit](()).liftF) {
  case (st, i) => st.flatMap(_ => setS(i).liftF)
}

Vincent answered Oct 23 '22 20:10


There is a principled intuition for this difference.

The applicative operator *> evaluates its left argument only for its side effects, and always ignores the result. This is similar (in some cases equivalent) to Haskell's >> function for monads. Here's the source for *>:

/** Combine `self` and `fb` according to `Apply[F]` with a function that discards the `A`s */
final def *>[B](fb: F[B]): F[B] = F.apply2(self,fb)((_,b) => b)

and Apply#apply2:

def apply2[A, B, C](fa: => F[A], fb: => F[B])(f: (A, B) => C): F[C] =
  ap(fb)(map(fa)(f.curried))

In general, flatMap depends on the result of the left argument (it must, as it is the input for the function in the right argument). Even though in this specific case you are ignoring the left result, flatMap doesn't know that.
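
As a small illustration of that difference in shape, here is a sketch using the standard library's Option rather than Scalaz. The helper names apply2, discardLeft and viaFlatMap are hypothetical and only mirror the signatures quoted above. Both versions return the same value, but in the applicative version the right-hand effect is a plain value known up front, while in the monadic version it is hidden inside a function that must be called with the left result.

```scala
object ApplyVsBindSketch {
  // Hypothetical Option-only analogue of Apply#apply2.
  def apply2[A, B, C](fa: Option[A], fb: Option[B])(f: (A, B) => C): Option[C] =
    (fa, fb) match {
      case (Some(a), Some(b)) => Some(f(a, b))
      case _                  => None
    }

  // Analogue of *>: both effects are given as values; the left result is discarded.
  def discardLeft[A, B](fa: Option[A], fb: Option[B]): Option[B] =
    apply2(fa, fb)((_, b) => b)

  // Monadic sequencing: the right effect is only reachable by applying a function
  // to the left result, even though that result is ignored here.
  def viaFlatMap[A, B](fa: Option[A], fb: Option[B]): Option[B] =
    fa.flatMap(_ => fb)
}
```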

It seems likely, given your results, that the implementation for *> is optimized for the case where the result of the left argument is unneeded. However, flatMap cannot perform this optimization, so each call grows the stack by retaining the unused left result.

It's possible that this could be optimized at the compiler (scalac) or JIT (HotSpot) level (Haskell's GHC certainly performs this optimization), but for now this seems like a missed optimization opportunity.

cdk answered Oct 23 '22 19:10


Just to add to the discussion...

In StateT, you have:

  def flatMap[S3, B](f: A => IndexedStateT[F, S2, S3, B])(implicit F: Bind[F]): IndexedStateT[F, S1, S3, B] =
    IndexedStateT(s => F.bind(apply(s)) {
      case (s1, a) => f(a)(s1)
    })

The call to apply(s) captures a reference to the current state step inside the next one.
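
A rough way to see the effect of that nested call is to count how deep it goes. The sketch below is a simplified, hypothetical State (named St here) with no wrapped monad at all, so it is not Scalaz's IndexedStateT; the depth and maxDepth counters are instrumentation added purely for illustration. It uses 1000 steps instead of 100000 so that it finishes without overflowing.

```scala
object DepthSketch {
  // Instrumentation only: tracks how many flatMap-built steps are active at once.
  var depth = 0
  var maxDepth = 0

  // Hypothetical simplified State; the inner run(s) plays the role of apply(s).
  final case class St[S, A](run: S => (S, A)) {
    def flatMap[B](f: A => St[S, B]): St[S, B] = St[S, B] { (s: S) =>
      depth += 1
      maxDepth = math.max(maxDepth, depth)
      val (s1, a) = run(s) // the previous step runs inside this call, not after it
      depth -= 1
      f(a).run(s1)
    }
  }

  def setS(i: Int): St[List[Int], Unit] = St[List[Int], Unit](s => (i :: s, ()))

  val n = 1000
  val prog: St[List[Int], Unit] =
    (1 to n).foldLeft(St[List[Int], Unit](s => (s, ()))) { (st, i) =>
      st.flatMap(_ => setS(i))
    }

  val (finalState, _) = prog.run(Nil)
}
```

With left-nested flatMaps, every step calls into the one before it, so the number of simultaneously active calls grows with the number of steps.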

The definition of bind evaluates its parameters eagerly, capturing that reference because it needs it:

  def bind[A, B](fa: F[A])(f: A => F[B]): F[B]

Unlike ap, whose by-name parameters might not need to be evaluated:

  def ap[A, B](fa: => F[A])(f: => F[A => B]): F[B]

With this code, Trampoline can't help StateT's flatMap (or map)...
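
The strict versus by-name distinction in those two signatures can be seen in isolation with a tiny sketch. The names effect, strictIgnore and byNameIgnore are hypothetical and unrelated to Scalaz; the counter only records how often the argument expression is evaluated.

```scala
object ByNameSketch {
  var evaluations = 0

  // An argument expression with an observable side effect.
  def effect(): Option[Int] = { evaluations += 1; Some(1) }

  // Strict parameter (like fa in bind): evaluated at the call site, even if unused.
  def strictIgnore[A](fa: Option[A]): String = "ignored"

  // By-name parameter (like the parameters of ap): evaluated only if the body uses it.
  def byNameIgnore[A](fa: => Option[A]): String = "ignored"

  val r1 = strictIgnore(effect())
  val afterStrict = evaluations // the strict call evaluated its argument once

  val r2 = byNameIgnore(effect())
  val afterByName = evaluations // unchanged: the by-name argument was never forced
}
```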

mandubian answered Oct 23 '22 21:10