I am trying to write a tail-recursive fold function for a binary tree, given the following definitions:
// From the book "Functional Programming in Scala", page 45
sealed trait Tree[+A]
case class Leaf[A](value: A) extends Tree[A]
case class Branch[A](left: Tree[A], right: Tree[A]) extends Tree[A]
Implementing a non-tail-recursive fold is quite straightforward:
def fold[A, B](t: Tree[A])(map: A => B)(red: (B, B) => B): B =
  t match {
    case Leaf(v) => map(v)
    case Branch(l, r) =>
      red(fold(l)(map)(red), fold(r)(map)(red))
  }
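In this version `red` runs only after both recursive calls have returned, so neither call is in tail position and `@tailrec` would reject it. The sketch below (assuming Scala 2.13 and the same script style as the question; the helpers `size`, `sum` and `render` are hypothetical, not from the book) shows what the generic fold buys you. `render` uses a `red` whose result depends on argument order and grouping, which is exactly the property a tail-recursive rewrite has to preserve.

```scala
sealed trait Tree[+A]
case class Leaf[A](value: A) extends Tree[A]
case class Branch[A](left: Tree[A], right: Tree[A]) extends Tree[A]

// Non-tail-recursive fold from the question: each Branch waits for both recursive calls.
def fold[A, B](t: Tree[A])(map: A => B)(red: (B, B) => B): B =
  t match {
    case Leaf(v)      => map(v)
    case Branch(l, r) => red(fold(l)(map)(red), fold(r)(map)(red))
  }

// Hypothetical helpers, each a single call to fold.
def size[A](t: Tree[A]): Int = fold(t)(_ => 1)(_ + _)
def sum(t: Tree[Int]): Int   = fold(t)(v => v)(_ + _)
// This red depends on argument order and grouping, so the output reveals the tree's shape.
def render[A](t: Tree[A]): String = fold(t)(_.toString)((l, r) => s"($l $r)")

val left  = Branch(Branch(Leaf(1), Leaf(2)), Leaf(3))
val right = Branch(Leaf(1), Branch(Leaf(2), Leaf(3)))

println(render(left))            // ((1 2) 3)
println(render(right))           // (1 (2 3))
println(sum(left) == sum(right)) // true
```

`sum` cannot tell `left` and `right` apart, while `render` can; a correct tail-recursive fold must give the same `render` output as this one.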
But now I am struggling to write a tail-recursive version so that the @annotation.tailrec annotation can be used.
During my research I found several examples where tail-recursive functions on a tree compute, e.g., the sum of all leaves using their own stack, which is then basically a List[Tree[Int]]. But as far as I understand, this only works for addition, because it does not matter whether the left or the right operand is evaluated first. For a generalized fold, however, the order is quite relevant. To show my intention, here are some example trees:
val leafs = Branch(Leaf(1), Leaf(2))
val left = Branch(Branch(Leaf(1), Leaf(2)), Leaf(3))
val right = Branch(Leaf(1), Branch(Leaf(2), Leaf(3)))
val bal = Branch(Branch(Leaf(1), Leaf(2)), Branch(Leaf(3), Leaf(4)))
val cmb = Branch(right, Branch(bal, Branch(leafs, left)))
val trees = List(leafs, left, right, bal, cmb)
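For comparison, a stack-based leaf sum of the kind described above might look like the following sketch (assuming Scala 2.13; `sumLeaves` and `loop` are hypothetical names). It keeps a single running total and uses an explicit `List[Tree[Int]]` instead of the call stack, so the recursion is in tail position.

```scala
import scala.annotation.tailrec

sealed trait Tree[+A]
case class Leaf[A](value: A) extends Tree[A]
case class Branch[A](left: Tree[A], right: Tree[A]) extends Tree[A]

// Hypothetical helper: sums all leaves, using an explicit List[Tree[Int]] as the stack.
def sumLeaves(t: Tree[Int]): Int = {
  @tailrec
  def loop(stack: List[Tree[Int]], acc: Int): Int = stack match {
    case Nil                  => acc
    // A leaf is added to the running total as soon as it is popped.
    case Leaf(v) :: rest      => loop(rest, acc + v)
    // A branch is replaced by its children; l stays on top, so leaves are seen left to right.
    case Branch(l, r) :: rest => loop(l :: r :: rest, acc)
  }
  loop(List(t), 0)
}

val bal = Branch(Branch(Leaf(1), Leaf(2)), Branch(Leaf(3), Leaf(4)))
println(sumLeaves(bal)) // 10
```

The running total computes ((((0 + 1) + 2) + 3) + 4) rather than the tree-shaped (1 + 2) + (3 + 4). The two agree only because + is associative and has 0 as its identity; a general `red` (for example one that builds `Branch` nodes) has neither guarantee, which is why this single-accumulator trick does not carry over to fold.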
Based on those trees, I want to create a deep copy with the given fold method, like this:
val oldNewPairs =
  trees.map(t => (t, fold(t)(Leaf(_): Tree[Int])(Branch(_, _))))
And then prove that equality holds for all of the created copies:
val conditionHolds = oldNewPairs.forall(p => {
  if (p._1 == p._2) true
  else {
    println(s"Original:\n${p._1}\nNew:\n${p._2}")
    false
  }
})
println("Condition holds: " + conditionHolds)
Could someone give me some pointers, please?
You can find the code used in this question at ScalaFiddle: https://scalafiddle.io/sf/eSKJyp2/15
You can reach a tail-recursive solution if you stop relying on the function call stack and instead use a stack managed by your own code, plus an accumulator:
import scala.annotation.tailrec

// BranchStub extends the sealed trait Tree, so this fold has to be in the same source file as Tree.
def fold[A, B](t: Tree[A])(map: A => B)(red: (B, B) => B): B = {
  // Marker node meaning "combine the last two results in acc"
  case object BranchStub extends Tree[Nothing]

  @tailrec
  def foldImp(toVisit: List[Tree[A]], acc: Vector[B]): Vector[B] =
    if (toVisit.isEmpty) acc
    else {
      toVisit.head match {
        case Leaf(v) =>
          val leafRes = map(v)
          foldImp(
            toVisit.tail,
            acc :+ leafRes
          )
        case Branch(l, r) =>
          foldImp(l :: r :: BranchStub :: toVisit.tail, acc)
        case BranchStub =>
          foldImp(toVisit.tail, acc.dropRight(2) ++ Vector(acc.takeRight(2).reduce(red)))
      }
    }

  foldImp(t :: Nil, Vector.empty).head
}
The idea is to accumulate values from left to right, keep track of the parent-child relation by introducing a stub node, and, whenever a stub node is reached during the traversal, reduce the last two elements of the accumulator with your red function.
This solution could be optimized, but it is already a tail-recursive implementation.
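To make the stub mechanism visible, here is a hypothetical instrumented variant of the Vector-based fold (`foldTrace` is not part of the answer). It uses the same traversal and the same stub handling, but also returns a snapshot of the accumulator after every step that changes it. It is a sketch assuming Scala 2.13 and the question's script style.

```scala
import scala.annotation.tailrec

sealed trait Tree[+A]
case class Leaf[A](value: A) extends Tree[A]
case class Branch[A](left: Tree[A], right: Tree[A]) extends Tree[A]

// Hypothetical instrumented variant of the Vector-based fold: same traversal and
// same stub handling, plus a snapshot of acc after every step that changes it.
def foldTrace[A, B](t: Tree[A])(map: A => B)(red: (B, B) => B): (B, List[Vector[B]]) = {
  case object BranchStub extends Tree[Nothing]

  @tailrec
  def loop(toVisit: List[Tree[A]], acc: Vector[B], trace: List[Vector[B]]): (Vector[B], List[Vector[B]]) =
    toVisit match {
      case Nil => (acc, trace.reverse)
      case Leaf(v) :: rest =>
        // A leaf appends its mapped value on the right of the accumulator.
        val next = acc :+ map(v)
        loop(rest, next, next :: trace)
      case Branch(l, r) :: rest =>
        // A branch schedules l, then r, then a stub; acc is unchanged.
        loop(l :: r :: BranchStub :: rest, acc, trace)
      case BranchStub :: rest =>
        // A stub replaces the last two values (l's result, r's result) with red of them.
        val next = acc.dropRight(2) :+ acc.takeRight(2).reduce(red)
        loop(rest, next, next :: trace)
    }

  val (acc, trace) = loop(List(t), Vector.empty, Nil)
  (acc.head, trace)
}

val right = Branch(Leaf(1), Branch(Leaf(2), Leaf(3)))
val (result, steps) = foldTrace(right)(_.toString)((l, r) => s"($l $r)")
steps.foreach(println)
println(result)
```

For the `right` tree the accumulator grows to Vector(1, 2, 3), then collapses to Vector(1, (2 3)) and finally Vector((1 (2 3))): each stub combines exactly the two results produced by the branch that scheduled it.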
EDIT:
It can be simplified slightly by changing the accumulator to a List used as a stack:
import scala.annotation.tailrec

def fold[A, B](t: Tree[A])(map: A => B)(red: (B, B) => B): B = {
  case object BranchStub extends Tree[Nothing]

  @tailrec
  def foldImp(toVisit: List[Tree[A]], acc: List[B]): List[B] =
    if (toVisit.isEmpty) acc
    else {
      toVisit.head match {
        case Leaf(v) =>
          foldImp(
            toVisit.tail,
            map(v) :: acc
          )
        case Branch(l, r) =>
          // r is visited before l, so l's result ends up on top: acc = lRes :: rRes :: ...
          foldImp(r :: l :: BranchStub :: toVisit.tail, acc)
        case BranchStub =>
          // take(2) is List(lRes, rRes), so this computes red(lRes, rRes)
          foldImp(toVisit.tail, acc.take(2).reduce(red) :: acc.drop(2))
      }
    }

  foldImp(t :: Nil, Nil).head
}
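A sketch that runs the question's own deep-copy check against the List-based version, assuming Scala 2.13 and the question's script style. The left-deep tree `deep` is a hypothetical addition: with 100,001 leaves, the call-stack version from the question would typically overflow under default JVM stack settings, while this version keeps its pending work in heap-allocated lists.

```scala
import scala.annotation.tailrec

sealed trait Tree[+A]
case class Leaf[A](value: A) extends Tree[A]
case class Branch[A](left: Tree[A], right: Tree[A]) extends Tree[A]

// Copy of the List-based fold from the EDIT above.
def fold[A, B](t: Tree[A])(map: A => B)(red: (B, B) => B): B = {
  case object BranchStub extends Tree[Nothing]
  @tailrec
  def foldImp(toVisit: List[Tree[A]], acc: List[B]): List[B] =
    if (toVisit.isEmpty) acc
    else toVisit.head match {
      case Leaf(v)      => foldImp(toVisit.tail, map(v) :: acc)
      case Branch(l, r) => foldImp(r :: l :: BranchStub :: toVisit.tail, acc)
      case BranchStub   => foldImp(toVisit.tail, acc.take(2).reduce(red) :: acc.drop(2))
    }
  foldImp(t :: Nil, Nil).head
}

// The trees and the deep-copy check from the question.
val leafs = Branch(Leaf(1), Leaf(2))
val left  = Branch(Branch(Leaf(1), Leaf(2)), Leaf(3))
val right = Branch(Leaf(1), Branch(Leaf(2), Leaf(3)))
val bal   = Branch(Branch(Leaf(1), Leaf(2)), Branch(Leaf(3), Leaf(4)))
val cmb   = Branch(right, Branch(bal, Branch(leafs, left)))
val trees = List(leafs, left, right, bal, cmb)

val conditionHolds = trees.forall(t => fold(t)(Leaf(_): Tree[Int])(Branch(_, _)) == t)
println("Condition holds: " + conditionHolds)

// Hypothetical left-deep tree with leaves 0 to 100000, built iteratively.
val deep = (1 to 100000).foldLeft(Leaf(0): Tree[Int])((acc, i) => Branch(acc, Leaf(i)))
println(fold(deep)(_.toLong)(_ + _)) // 5000050000
```

Only `fold` is applied to `deep`; the case-class `==`, `hashCode` and `toString` are themselves recursive, so calling them on a tree that deep could still overflow the stack.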