
scalaz.State stack overflow in deep monadic loop

I'm experimenting with different implementations of depth-first search with scalaz.

The traversal should handle wide as well as deep tree-like structures.

The main idea is that child elements should be generated according to some "state", for example a set of elements already marked as seen, so that they can be skipped later.

Here is the simplest implementation I've come up with:

import scalaz._
import scalaz.std.list._
import scalaz.syntax.monad._
import scalaz.syntax.traverse._
import scalaz.State._

abstract class DepthFirstState[E, S] {
  def build(elem: E): State[S, List[E]]

  def go(start: E): State[S, List[E]] = for {
    xs ← build(start)
    ys ← xs.traverseSTrampoline(go)
  } yield start :: ys.flatten
}

We can create a trivial algorithm to test how it handles deep searches:

class RangeSearchState extends DepthFirstState[Int, Int] {
  def build(elem: Int) = get[Int] map (limit ⇒ if (elem < limit) List(elem + 1) else Nil)
}

It's just a tree degraded to a linked list, where each element i has a single child i+1 until it reaches the limit encoded in the state. Since the state never changes, this is really more a Reader than a State, but that's not the point here.

Now

new RangeSearchState go 1 run 100

successfully builds the list of traversed numbers, while

new RangeSearchState go 1 run 1000

fails with a StackOverflowError.

Is there a way to fix the implementation of DepthFirstState so that it runs without a StackOverflowError even on very deep recursion?

Odomontois asked Dec 15 '22 10:12


1 Answer

The trampolining that happens in traverseSTrampoline protects you from overflowing the stack during the traversal. For example, this explodes:

import scalaz._, scalaz.std.list._, scalaz.syntax.traverse._

(0 to 10000).toList.traverseU(_ => State.get[Unit]).run(())

While this doesn't (note that traverseS is the same as traverseSTrampoline for State):

(0 to 10000).toList.traverseS(_ => State.get[Unit]).run(())

You only get this protection during the traversal, though, and in your case the overflow is caused by the recursive call to go. You can fix this by doing the trampolining manually:

import scalaz._
import scalaz.std.list._
import scalaz.syntax.traverse._

abstract class DepthFirstState[E, S] {
  type TState[s, a] = StateT[Free.Trampoline, s, a]

  def build(elem: E): TState[S, List[E]]

  def go(start: E): TState[S, List[E]] = for {
    xs <- build(start)
    ys <- xs.traverseU(go)
  } yield start :: ys.flatten
}

class RangeSearchState extends DepthFirstState[Int, Int] {
  def build(elem: Int): TState[Int, List[Int]] =
    MonadState[TState, Int].get.map(limit =>
      if (elem < limit) List(elem + 1) else Nil
    )
}

And then:

val (state, result) = (new RangeSearchState).go(1).run(10000).run
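To make it clearer what StateT[Free.Trampoline, S, A] buys you, here is a minimal, dependency-free sketch of the same trampolining idea. It assumes only the Scala standard library: TrampolinedState is a hypothetical hand-rolled state monad (not a scalaz or cats type) that runs every step inside scala.util.control.TailCalls.TailRec, so each recursive call to go hands control back to the trampoline loop instead of adding JVM stack frames. Treat it as an illustration under those assumptions, not as scalaz's actual implementation.

```scala
import scala.util.control.TailCalls._

// Hypothetical hand-rolled state monad (not part of scalaz or cats) whose
// steps run inside the standard library's TailRec trampoline, mirroring
// the idea behind StateT[Free.Trampoline, S, A].
final case class TrampolinedState[S, A](run: S => TailRec[(S, A)]) {
  // Both the current step and the continuation are wrapped in tailcall,
  // so a deep chain of binds is driven by TailRec's loop, not the JVM stack.
  def flatMap[B](f: A => TrampolinedState[S, B]): TrampolinedState[S, B] =
    TrampolinedState[S, B] { s =>
      tailcall(run(s)).flatMap { case (s1, a) => tailcall(f(a).run(s1)) }
    }

  def map[B](f: A => B): TrampolinedState[S, B] =
    flatMap(a => TrampolinedState.pure[S, B](f(a)))
}

object TrampolinedState {
  def pure[S, A](a: A): TrampolinedState[S, A] =
    TrampolinedState[S, A](s => done((s, a)))

  def get[S]: TrampolinedState[S, S] =
    TrampolinedState[S, S](s => done((s, s)))

  // Applies f to each element left to right, threading the state through.
  def traverse[S, A, B](xs: List[A])(f: A => TrampolinedState[S, B]): TrampolinedState[S, List[B]] =
    xs.foldRight(pure[S, List[B]](Nil)) { (a, acc) =>
      for {
        b  <- f(a)
        bs <- acc
      } yield b :: bs
    }
}

abstract class DepthFirstState[E, S] {
  def build(elem: E): TrampolinedState[S, List[E]]

  // Same shape as the question's go: build the children, then recurse into them.
  def go(start: E): TrampolinedState[S, List[E]] = for {
    xs <- build(start)
    ys <- TrampolinedState.traverse(xs)(e => go(e))
  } yield start :: ys.flatten
}

class RangeSearchState extends DepthFirstState[Int, Int] {
  def build(elem: Int): TrampolinedState[Int, List[Int]] =
    TrampolinedState.get[Int].map(limit => if (elem < limit) List(elem + 1) else Nil)
}

object TrampolinedDfsDemo {
  def main(args: Array[String]): Unit = {
    // .result runs the TailRec loop until the whole computation is done.
    val (state, result) = (new RangeSearchState).go(1).run(10000).result
    println(s"state = $state, visited = ${result.length}, last = ${result.last}")
  }
}
```

Running TrampolinedDfsDemo should report a state of 10000 and a result list of 1 to 10000, at a depth ten times greater than the one that overflowed in the question.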

It's worth noting that this stack safety is built into State in cats:

import cats.state.State
import cats.std.function._, cats.std.list._
import cats.syntax.traverse._

abstract class DepthFirstState[E, S] {
  def build(elem: E): State[S, List[E]]

  def go(start: E): State[S, List[E]] = for {
    xs <- build(start)
    ys <- xs.traverseU(go)
  } yield start :: ys.flatten
}

class RangeSearchState extends DepthFirstState[Int, Int] {
  def build(elem: Int): State[Int, List[Int]] =
    State.get[Int].map(limit => if (elem < limit) List(elem + 1) else Nil)
}

val (state, result) = (new RangeSearchState).go(1).run(10000).run

This safe-by-default choice is discussed in some detail here.
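The cats snippet above uses an older cats API: the cats.state and cats.std packages and traverseU no longer exist in current releases. Assuming cats-core 2.2 or later on the classpath, a sketch of the same code would import State from cats.data, use cats.syntax.all._ (whose traverse replaces traverseU), and force the Eval returned by run with .value:

```scala
import cats.data.State
import cats.syntax.all._

// Assumes cats-core 2.2+: State[S, A] is StateT[Eval, S, A], and the
// Traverse/Monad instances it needs are found in implicit scope.
abstract class DepthFirstState[E, S] {
  def build(elem: E): State[S, List[E]]

  def go(start: E): State[S, List[E]] = for {
    xs <- build(start)
    ys <- xs.traverse(go) // plain traverse; partial unification finds State[S, *]
  } yield start :: ys.flatten
}

class RangeSearchState extends DepthFirstState[Int, Int] {
  def build(elem: Int): State[Int, List[Int]] =
    State.get[Int].map(limit => if (elem < limit) List(elem + 1) else Nil)
}

object CatsDfsDemo {
  def main(args: Array[String]): Unit = {
    // run(10000) returns Eval[(Int, List[Int])]; .value forces it.
    val (state, result) = (new RangeSearchState).go(1).run(10000).value
    println(s"state = $state, visited = ${result.length}")
  }
}
```

As in the original answer, no manual trampolining is needed here, because cats' State is built on Eval.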

Travis Brown answered Dec 21 '22 12:12