A couple of weeks ago Dragisa Krsmanovic asked a question here about how to use the free monad in Scalaz 7 to avoid stack overflows in this situation (I've adapted his code a bit):
import scalaz._, Scalaz._
def setS(i: Int): State[List[Int], Unit] = modify(i :: _)
val s = (1 to 100000).foldLeft(state[List[Int], Unit](())) {
  case (st, i) => st.flatMap(_ => setS(i))
}
s(Nil)
I thought that just lifting a trampoline into StateT should work:
import Free.Trampoline
val s = (1 to 100000).foldLeft(state[List[Int], Unit](()).lift[Trampoline]) {
  case (st, i) => st.flatMap(_ => setS(i).lift[Trampoline])
}
s(Nil).run
But it still blows the stack, so I just posted it as a comment.
Dave Stevens just pointed out that sequencing with the applicative *> instead of the monadic flatMap actually works just fine:
val s = (1 to 100000).foldLeft(state[List[Int], Unit](()).lift[Trampoline]) {
  case (st, i) => st *> setS(i).lift[Trampoline]
}
s(Nil).run
(Well, it's super slow of course, because that's the price you pay for doing anything interesting like this in Scala, but at least there's no stack overflow.)
What's going on here? I don't think there could be a principled reason for this difference, but really I have no idea what could be going on in the implementation and don't have time to dig around at the moment. But I'm curious and it would be cool if someone else knows.
Mandubian is correct: the flatMap of StateT doesn't let you bypass stack accumulation, because it creates a new StateT immediately before calling the wrapped monad's bind (which would be a Free[Function0] in your case).
So Trampoline can't help, but the Free Monad over the functor for State is one way to ensure stack safety.
We want to go from State[List[Int],Unit] to Free[({ type l[a] = State[List[Int],a]})#l,Unit], so that our flatMap call goes to Free's flatMap (which does nothing other than build the Free data structure).
val s = (1 to 100000).foldLeft(
  Free.liftF[({ type l[a] = State[List[Int],a]})#l,Unit](state[List[Int], Unit](()))) {
  case (st, i) => st.flatMap(_ =>
    Free.liftF[({ type l[a] = State[List[Int],a]})#l,Unit](setS(i)))
}
Now we have built a Free data structure that we can easily thread a state through, like so:
s.foldRun(List[Int]())( (a,b) => b(a) )
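To make the "build a data structure first, thread the state through it afterwards" idea concrete without depending on a particular Scalaz version, here is a minimal sketch in plain Scala. Everything in it (Prog, Pure, Modify, Bind, ProgDemo) is a hypothetical stand-in invented for illustration; it is not Scalaz's Free, and the interpreter loop is only an approximation of what foldRun does for you.

```scala
// Hypothetical mini "free" program over state operations; NOT Scalaz's Free.
sealed trait Prog[+A] {
  // flatMap only allocates a node; nothing runs until the interpreter loop.
  def flatMap[B](f: A => Prog[B]): Prog[B] = Bind(this, f)
}
final case class Pure[A](value: A) extends Prog[A]
final case class Modify(f: List[Int] => List[Int]) extends Prog[Unit]
final case class Bind[A, B](prog: Prog[A], next: A => Prog[B]) extends Prog[B]

object ProgDemo {
  // Interpreter: a while loop with an explicit list of pending continuations,
  // so the JVM call stack stays flat however deeply the Bind nodes are nested.
  def run[A](program: Prog[A], init: List[Int]): (List[Int], A) = {
    var state = init
    var current: Prog[Any] = program
    var pending: List[Any => Prog[Any]] = Nil
    var result: Option[Any] = None
    while (result.isEmpty) {
      current match {
        case Pure(v) =>
          pending match {
            case k :: rest => pending = rest; current = k(v)
            case Nil       => result = Some(v)
          }
        case Modify(f) =>
          state = f(state)
          current = Pure(())
        case Bind(p, k) =>
          // Remember the continuation and descend into the left-hand program.
          pending = k.asInstanceOf[Any => Prog[Any]] :: pending
          current = p
      }
    }
    (state, result.get.asInstanceOf[A])
  }

  // Counterpart of setS from the question.
  def push(i: Int): Prog[Unit] = Modify(i :: _)

  // Same left-nested shape as the foldLeft in the question.
  val program: Prog[Unit] =
    (1 to 100000).foldLeft(Pure(()): Prog[Unit]) {
      case (st, i) => st.flatMap(_ => push(i))
    }

  def main(args: Array[String]): Unit = {
    val (finalState, _) = run(program, Nil)
    println(finalState.size) // 100000
  }
}
```

The casts are the price of keeping the sketch short; a real Free implementation hides them behind a better-typed API.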
Calling liftF is fairly ugly, so I have a PR in to make it easier for the State and Kleisli monads; hopefully in the future there won't be a need for type lambdas.
Edit: PR accepted so now we have
val s = (1 to 100000).foldLeft(state[List[Int], Unit](()).liftF) {
  case (st, i) => st.flatMap(_ => setS(i).liftF)
}
There is a principled intuition for this difference.
The applicative operator *> evaluates its left argument only for its side effects, and always ignores the result. This is similar (in some cases equivalent) to Haskell's >> function for monads. Here's the source for *>:
/** Combine `self` and `fb` according to `Apply[F]` with a function that discards the `A`s */
final def *>[B](fb: F[B]): F[B] = F.apply2(self,fb)((_,b) => b)
and Apply#apply2:
def apply2[A, B, C](fa: => F[A], fb: => F[B])(f: (A, B) => C): F[C] =
  ap(fb)(map(fa)(f.curried))
In general, flatMap depends on the result of the left argument (it must, as it is the input for the function in the right argument). Even though in this specific case you are ignoring the left result, flatMap doesn't know that.
It seems likely, given your results, that the implementation for *> is optimized for the case where the result of the left argument is unneeded. However, flatMap cannot perform this optimization, and so each call grows the stack by retaining the unused left result.
It's possible that this could be optimized at the compiler (scalac) or JIT (HotSpot) level (Haskell's GHC certainly performs this optimization), but for now this seems like a missed optimization opportunity.
Just to add to the discussion...
In StateT, you have:
def flatMap[S3, B](f: A => IndexedStateT[F, S2, S3, B])(implicit F: Bind[F]): IndexedStateT[F, S1, S3, B] =
  IndexedStateT(s => F.bind(apply(s)) {
    case (s1, a) => f(a)(s1)
  })
The apply(s) call fixes the current state reference in the next state.
The definition of bind evaluates its parameters eagerly, capturing that reference because it requires it:
def bind[A, B](fa: F[A])(f: A => F[B]): F[B]
This is unlike ap, which takes its parameters by name and so might not need to evaluate one of them:
def ap[A, B](fa: => F[A])(f: => F[A => B]): F[B]
With this code, the Trampoline can't help for StateT flatMap (and also map)...
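To illustrate why the eager apply(s) matters, here is a rough sketch that uses the standard library's scala.util.control.TailCalls in place of Scalaz's Trampoline. SafeState and SafeStateDemo are hypothetical names made up for this example, not Scalaz types. The only point is that when the inner run is suspended (tailcall takes its argument by name) rather than evaluated immediately inside flatMap, the left-nested chain from the question runs in constant stack.

```scala
import scala.util.control.TailCalls._

// Hypothetical stand-in for a state transformer over a trampoline;
// NOT Scalaz's StateT.
final case class SafeState[S, A](run: S => TailRec[(S, A)]) {
  def flatMap[B](f: A => SafeState[S, B]): SafeState[S, B] =
    SafeState { s =>
      // tailcall takes its argument by name, so run(s) is NOT evaluated here;
      // it is suspended and executed later by the trampoline's loop.
      tailcall(run(s)).flatMap { case (s1, a) => tailcall(f(a).run(s1)) }
    }
}

object SafeStateDemo {
  // A program that leaves the state untouched.
  def unit[S]: SafeState[S, Unit] = SafeState(s => done((s, ())))

  // Counterpart of setS from the question.
  def push(i: Int): SafeState[List[Int], Unit] =
    SafeState(s => done((i :: s, ())))

  // Same left-nested shape as the foldLeft in the question.
  val program: SafeState[List[Int], Unit] =
    (1 to 100000).foldLeft(unit[List[Int]]) {
      case (st, i) => st.flatMap(_ => push(i))
    }

  def main(args: Array[String]): Unit = {
    // .result runs the trampoline loop.
    val (finalState, _) = program.run(Nil).result
    println(finalState.size) // 100000
  }
}
```

If tailcall(run(s)) is replaced by a plain run(s), each flatMap layer calls the previous layer's run directly, which is the kind of stack growth described above.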