Suppose I have a tree data structure like this:
trait Node { val name: String }
case class BranchNode(name: String, children: List[Node]) extends Node
case class LeafNode(name: String) extends Node
Suppose also I've got a function to map over leaves:
def mapLeaves(root: Node, f: LeafNode => LeafNode): Node = root match {
  case ln: LeafNode => f(ln)
  case bn: BranchNode => BranchNode(bn.name, bn.children.map(ch => mapLeaves(ch, f)))
}
Now I am trying to make this function tail-recursive, but I'm having a hard time figuring out how to do it. I've read this answer, but I still don't know how to make that binary tree solution work for a multiway tree.
How would you rewrite mapLeaves to make it tail-recursive?
"Call stack" and "recursion" are merely popular design patterns that later got incorporated into most programming languages (and thus became mostly "invisible"). There is nothing that prevents you from reimplementing both with heap data structures. So, here is "the obvious" 1960s TAOCP retro-style solution:
trait Node { val name: String }
case class BranchNode(name: String, children: List[Node]) extends Node
case class LeafNode(name: String) extends Node

def mapLeaves(root: Node, f: LeafNode => LeafNode): Node = {
  case class Frame(name: String, mapped: List[Node], todos: List[Node])
  @annotation.tailrec
  def step(stack: List[Frame]): Node = stack match {
    // "return / pop a stack-frame"
    case Frame(name, done, Nil) :: tail => {
      val ret = BranchNode(name, done.reverse)
      tail match {
        case Nil => ret
        case Frame(tn, td, tt) :: more => {
          step(Frame(tn, ret :: td, tt) :: more)
        }
      }
    }
    case Frame(name, done, x :: xs) :: tail => x match {
      // "recursion base"
      case l @ LeafNode(_) => step(Frame(name, f(l) :: done, xs) :: tail)
      // "recursive call"
      case BranchNode(n, cs) => step(Frame(n, Nil, cs) :: Frame(name, done, xs) :: tail)
    }
    case Nil => throw new Error("shouldn't happen")
  }
  root match {
    case l @ LeafNode(_) => f(l)
    case b @ BranchNode(n, cs) => step(List(Frame(n, Nil, cs)))
  }
}
The tail-recursive step function takes a reified stack of "stack frames". A "stack frame" stores the name of the branch node that is currently being processed, a list of child nodes that have already been processed, and the list of the remaining nodes that still must be processed later. This roughly corresponds to an actual stack frame of your recursive mapLeaves function.
With this data structure,
- returning from a recursive call corresponds to deconstructing a Frame object, and either returning the final result, or at least making the stack one frame shorter;
- a recursive call corresponds to a step that prepends a Frame to the stack;
- the base case (invoking f on leaves) does not create or remove any frames.
Once one understands how the usually invisible stack frames are represented explicitly, the translation is straightforward and mostly mechanical.
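To illustrate how mechanical the translation is, here is a sketch of the same explicit-stack idea applied to a simpler, hypothetical countLeaves function (not part of the original question). Because counting needs no result to be reassembled, the "frames" shrink to a plain list of pending nodes: a leaf is the base case, and visiting a branch pushes its children onto the list instead of making a recursive call.

```scala
import scala.annotation.tailrec

// Same tree types as in the question.
trait Node { val name: String }
case class BranchNode(name: String, children: List[Node]) extends Node
case class LeafNode(name: String) extends Node

object ExplicitStackCount {
  // Hypothetical example. The recursive version would be:
  //   def count(n: Node): Int = n match {
  //     case _: LeafNode   => 1
  //     case b: BranchNode => b.children.map(count).sum
  //   }
  // Here the pending work lives in a heap-allocated list instead of the call stack.
  def countLeaves(root: Node): Int = {
    @tailrec
    def loop(todo: List[Node], acc: Int): Int = todo match {
      case Nil                       => acc                     // nothing left: done
      case (_: LeafNode) :: rest     => loop(rest, acc + 1)     // base case: no new entries
      case BranchNode(_, cs) :: rest => loop(cs ::: rest, acc)  // "recursive call": push children
      case _ :: rest                 => loop(rest, acc)         // other Node subtypes (trait is not sealed)
    }
    loop(List(root), 0)
  }
}
```

Since loop is annotated with @tailrec, the compiler rejects it unless every recursive call is in tail position, so the depth of the tree only affects the size of the todo list, not the call stack.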
Example:
val example = BranchNode("x", List(
  BranchNode("y", List(
    LeafNode("a"),
    LeafNode("b")
  )),
  BranchNode("z", List(
    LeafNode("c"),
    BranchNode("v", List(
      LeafNode("d"),
      LeafNode("e")
    ))
  ))
))
println(mapLeaves(example, { case LeafNode(n) => LeafNode(n.toUpperCase) }))
Output (indented):
BranchNode(x,List(
  BranchNode(y,List(
    LeafNode(A),
    LeafNode(B)
  )),
  BranchNode(z,List(
    LeafNode(C),
    BranchNode(v,List(
      LeafNode(D),
      LeafNode(E)
    ))
  ))
))
It might be easier to implement it using a technique called trampolining.
With a trampoline, you can have two functions that call each other (mutual recursion), whereas with @tailrec you are limited to a single function calling itself. As with @tailrec, the call stack does not grow: the trampoline evaluates the pending calls in a plain loop.
The Scala standard library provides trampolines in scala.util.control.TailCalls.
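As a minimal illustration of mutual recursion through TailCalls, here is a sketch with a hypothetical isEven/isOdd pair (not part of the tree problem). Each function wraps its call to the other in tailcall, so nothing runs until .result drives the trampoline loop.

```scala
import scala.util.control.TailCalls.{TailRec, done, tailcall}

object MutualRecursionDemo {
  // isEven and isOdd call each other; tailcall defers each call instead of
  // pushing a JVM stack frame, and done wraps the final value.
  def isEven(n: Int): TailRec[Boolean] =
    if (n == 0) done(true) else tailcall(isOdd(n - 1))

  def isOdd(n: Int): TailRec[Boolean] =
    if (n == 0) done(false) else tailcall(isEven(n - 1))
}
```

Calling MutualRecursionDemo.isEven(1000000).result returns true without a StackOverflowError, because the million alternating calls are evaluated one at a time by the loop inside result.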
import scala.util.control.TailCalls.{TailRec, done, tailcall}

def mapLeaves(root: Node, f: LeafNode => LeafNode): Node = {
  // Two inner functions that call each other (mutual recursion).

  // Maps every node in a list of siblings, preserving their order.
  def iterate(nodes: List[Node]): TailRec[List[Node]] = {
    nodes match {
      case x :: xs =>
        tailcall(deepMap(x)) // mutual recursion: map the first sibling with deepMap
          .flatMap(node => iterate(xs).map(node :: _)) // TailRec supports flatMap and map
      case Nil => done(Nil)
    }
  }

  // Recursively visits all branches, applying f to every leaf.
  def deepMap(node: Node): TailRec[Node] = {
    node match {
      case ln: LeafNode => done(f(ln))
      case bn: BranchNode =>
        tailcall(iterate(bn.children)) // mutual recursion: map the children with iterate
          .map(BranchNode(bn.name, _))
    }
  }

  deepMap(root).result // run the trampoline and unwrap the plain Node
}
Instead of TailCalls, you could also use Eval from Cats or Trampoline from scalaz.
With that implementation, the function worked without problems:
def build(counter: Int): Node = {
  if (counter > 0) {
    BranchNode("branch", List(build(counter - 1)))
  } else {
    LeafNode("leaf")
  }
}
val root = build(4000)
mapLeaves(root, x => x.copy(name = x.name.reverse)) // no problems
When I ran that example with your original implementation, it caused a java.lang.StackOverflowError, as expected.
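The build function above is itself non-tail-recursive, so it limits how deep a test tree can get. The following sketch pushes further, assuming two hypothetical helpers that are not part of the answer: buildIteratively creates a single-child chain with foldLeft, and leafAndDepth inspects the result with a tail-recursive walk. It avoids ==, hashCode and toString on the deep tree, since the generated case-class methods recurse and would overflow on their own.

```scala
import scala.annotation.tailrec
import scala.util.control.TailCalls.{TailRec, done, tailcall}

trait Node { val name: String }
case class BranchNode(name: String, children: List[Node]) extends Node
case class LeafNode(name: String) extends Node

object DeepTreeCheck {
  // The TailCalls-based mapLeaves from this answer, repeated so the block compiles on its own.
  def mapLeaves(root: Node, f: LeafNode => LeafNode): Node = {
    def iterate(nodes: List[Node]): TailRec[List[Node]] = nodes match {
      case x :: xs => tailcall(deepMap(x)).flatMap(node => iterate(xs).map(node :: _))
      case Nil     => done(Nil)
    }
    def deepMap(node: Node): TailRec[Node] = node match {
      case ln: LeafNode   => done(f(ln))
      case bn: BranchNode => tailcall(iterate(bn.children)).map(BranchNode(bn.name, _))
    }
    deepMap(root).result
  }

  // Hypothetical helper: builds a chain of `depth` branches with a loop
  // instead of recursion, so construction cannot overflow the stack.
  def buildIteratively(depth: Int): Node =
    (1 to depth).foldLeft(LeafNode("leaf"): Node)((child, _) => BranchNode("branch", List(child)))

  // Hypothetical helper: walks down a single-child chain and returns the
  // leaf's name together with the number of branches passed on the way.
  @tailrec
  def leafAndDepth(n: Node, depth: Int = 0): (String, Int) = n match {
    case LeafNode(name)              => (name, depth)
    case BranchNode(_, child :: Nil) => leafAndDepth(child, depth + 1)
    case other                       => throw new IllegalArgumentException(s"not a chain at ${other.name}")
  }
}
```

For example, DeepTreeCheck.leafAndDepth(DeepTreeCheck.mapLeaves(DeepTreeCheck.buildIteratively(100000), l => l.copy(name = l.name.reverse))) should give ("fael", 100000): the pending continuations live on the heap rather than on the call stack.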