I'm experimenting with different implementations of depth-first search with scalaz.
The traversal should handle wide as well as deep tree-like structures.
The main idea is that child elements should be generated according to some "state", for example a set of elements already marked as seen, so that they can be skipped later.
Here is the simplest implementation I've come up with:
import scalaz._
import scalaz.syntax.monad._
import scalaz.State._
abstract class DepthFirstState[E, S] {
  def build(elem: E): State[S, List[E]]
  def go(start: E): State[S, List[E]] = for {
    xs ← build(start)
    ys ← xs.traverseSTrampoline(go)
  } yield start :: ys.flatten
}
We can create the simplest possible algorithm to test how it handles a deep search:
class RangeSearchState extends DepthFirstState[Int, Int] {
  def build(elem: Int) = get[Int] map (limit ⇒ if (elem < limit) List(elem + 1) else Nil)
}
It's just a tree degenerated into a linked list, where each element i has a single child i+1 until it reaches the limit encoded in the state. Since the state never changes, this is more Reader than State, but that's not the point here.
Now
new RangeSearchState go 1 run 100
successfully builds the list of traversed numbers, while
new RangeSearchState go 1 run 1000
fails with a StackOverflowError.
Is there a way to fix the implementation of DepthFirstState so that it runs without a stack overflow even on very deep recursion?
The trampolining that happens in traverseSTrampoline protects you from overflowing the stack during the traversal. So for example this explodes:
import scalaz._, scalaz.std.list._, scalaz.syntax.traverse._
(0 to 10000).toList.traverseU(_ => State.get[Unit]).run(())
While this doesn't (note that traverseS is the same as traverseSTrampoline for State):
(0 to 10000).toList.traverseS(_ => State.get[Unit]).run(())
You only get this protection during the traversal, though, and in your case the overflow is happening because of the recursive call. You can fix this by doing the trampolining manually:
import scalaz._
import scalaz.std.list._
import scalaz.syntax.traverse._
abstract class DepthFirstState[E, S] {
  type TState[s, a] = StateT[Free.Trampoline, s, a]
  def build(elem: E): TState[S, List[E]]
  def go(start: E): TState[S, List[E]] = for {
    xs <- build(start)
    ys <- xs.traverseU(go)
  } yield start :: ys.flatten
}
class RangeSearchState extends DepthFirstState[Int, Int] {
  def build(elem: Int): TState[Int, List[Int]] =
    MonadState[TState, Int].get.map(limit =>
      if (elem < limit) List(elem + 1) else Nil
    )
}
And then:
val (state, result) = (new RangeSearchState).go(1).run(10000).run
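If it helps to see what the trampoline is doing without scalaz in the picture, here is a rough sketch that uses only the standard library's scala.util.control.TailCalls. The TState class and its pure, get and traverse helpers are made up for illustration (they are not scalaz's StateT), and DepthFirst, RangeSearch and demo are likewise hypothetical names. The point is only that every step of running the state is suspended with tailcall, so the recursive go builds a chain of heap-allocated continuations instead of stack frames:

```scala
import scala.util.control.TailCalls._

// Hypothetical minimal trampolined state: a function S => TailRec[(S, A)].
final class TState[S, A](val run: S => TailRec[(S, A)]) {
  // Both running this step and running the continuation are suspended,
  // so nothing recurses on the JVM stack while the chain is interpreted.
  def flatMap[B](f: A => TState[S, B]): TState[S, B] =
    new TState[S, B](s =>
      tailcall(run(s)).flatMap { case (s1, a) => tailcall(f(a).run(s1)) })

  def map[B](f: A => B): TState[S, B] =
    flatMap(a => TState.pure[S, B](f(a)))
}

object TState {
  def pure[S, A](a: A): TState[S, A] = new TState[S, A](s => done((s, a)))

  def get[S]: TState[S, S] = new TState[S, S](s => done((s, s)))

  // Left-to-right traversal of a list, threading the state through each element.
  def traverse[S, A, B](xs: List[A])(f: A => TState[S, B]): TState[S, List[B]] =
    xs.foldRight(pure[S, List[B]](Nil)) { (a, acc) =>
      f(a).flatMap(b => acc.map(bs => b :: bs))
    }
}

abstract class DepthFirst[E, S] {
  def build(elem: E): TState[S, List[E]]

  // Same shape as go in the scalaz version, written with explicit flatMap/map.
  def go(start: E): TState[S, List[E]] =
    build(start).flatMap { xs =>
      TState.traverse[S, E, List[E]](xs)(go).map(ys => start :: ys.flatten)
    }
}

object RangeSearch extends DepthFirst[Int, Int] {
  def build(elem: Int): TState[Int, List[Int]] =
    TState.get[Int].map(limit => if (elem < limit) List(elem + 1) else Nil)

  // Runs the search from 1 with the given limit as the initial state;
  // result is what actually interprets the trampoline.
  def demo(limit: Int): (Int, List[Int]) = go(1).run(limit).result
}
```

With this sketch, RangeSearch.demo(10000) should finish without a StackOverflowError, since the recursion depth lives in the TailRec structure rather than on the call stack. The scalaz StateT[Free.Trampoline, S, A] used above follows the same general idea, with Free.Trampoline playing the role of TailRec.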
It's worth noting that this stack safety is built into State in cats:
import cats.state.State
import cats.std.function._, cats.std.list._
import cats.syntax.traverse._
abstract class DepthFirstState[E, S] {
  def build(elem: E): State[S, List[E]]
  def go(start: E): State[S, List[E]] = for {
    xs <- build(start)
    ys <- xs.traverseU(go)
  } yield start :: ys.flatten
}
class RangeSearchState extends DepthFirstState[Int, Int] {
  def build(elem: Int): State[Int, List[Int]] =
    State.get[Int].map(limit => if (elem < limit) List(elem + 1) else Nil)
}
val (state, result) = (new RangeSearchState).go(1).run(10000).run
This safe-by-default choice is discussed in some detail here.