Tail recursive fold on a binary tree in Scala

I am trying to find a tail-recursive fold function for a binary tree. Given the following definitions:

// From the book "Functional Programming in Scala", page 45
sealed trait Tree[+A]
case class Leaf[A](value: A) extends Tree[A]
case class Branch[A](left: Tree[A], right: Tree[A]) extends Tree[A]

Implementing a non-tail-recursive function is quite straightforward:

def fold[A, B](t: Tree[A])(map: A => B)(red: (B, B) => B): B =
  t match {
    case Leaf(v)      => map(v)
    case Branch(l, r) => 
      red(fold(l)(map)(red), fold(r)(map)(red))
  }

But now I am struggling to find a tail-recursive fold function so that the annotation @annotation.tailrec can be used.

During my research I have found several examples where tail-recursive functions on a tree can, e.g., compute the sum of all leaves using their own stack, which is then basically a List[Tree[Int]]. But as far as I understand, this only works for addition because there it does not matter whether you first evaluate the left-hand or the right-hand side of the operator. For a generalised fold, however, that order is quite relevant. To show my intention, here are some example trees:

val leafs = Branch(Leaf(1), Leaf(2))
val left = Branch(Branch(Leaf(1), Leaf(2)), Leaf(3))
val right = Branch(Leaf(1), Branch(Leaf(2), Leaf(3)))
val bal = Branch(Branch(Leaf(1), Leaf(2)), Branch(Leaf(3), Leaf(4)))
val cmb = Branch(right, Branch(bal, Branch(leafs, left)))
val trees = List(leafs, left, right, bal, cmb)

Based on those trees I want to create a deep copy with the given fold method like:

val oldNewPairs = 
  trees.map(t => (t, fold(t)(Leaf(_): Tree[Int])(Branch(_, _))))

And then check that equality holds for all created copies:

val conditionHolds = oldNewPairs.forall(p => {
  if (p._1 == p._2) true
  else {
    println(s"Original:\n${p._1}\nNew:\n${p._2}")
    false
  }
})
println("Condition holds: " + conditionHolds)

Could someone give me some pointers, please?

You can find the code used in this question at ScalaFiddle: https://scalafiddle.io/sf/eSKJyp2/15

asked Jan 03 '17 09:01 by Henrik Sachse

1 Answer

You can reach a tail-recursive solution if you stop using the function call stack and instead use a stack managed by your own code, together with an accumulator:

import scala.annotation.tailrec

def fold[A, B](t: Tree[A])(map: A => B)(red: (B, B) => B): B = {

  // Marker pushed after a branch's children: "reduce the last two results"
  case object BranchStub extends Tree[Nothing]

  @tailrec
  def foldImp(toVisit: List[Tree[A]], acc: Vector[B]): Vector[B] =
    if (toVisit.isEmpty) acc
    else {
      toVisit.head match {
        case Leaf(v) =>
          // Append the mapped leaf so results stay in left-to-right order
          foldImp(toVisit.tail, acc :+ map(v))
        case Branch(l, r) =>
          // Visit left, then right, then combine their two results
          foldImp(l :: r :: BranchStub :: toVisit.tail, acc)
        case BranchStub =>
          // Replace the last two results with their reduction
          foldImp(toVisit.tail, acc.dropRight(2) :+ acc.takeRight(2).reduce(red))
      }
    }

  foldImp(t :: Nil, Vector.empty).head

}

The idea is to accumulate values from left to right and to keep track of the parent-child relation by introducing a stub node. Whenever the traversal reaches a stub node, the last two elements of the accumulator are the results of that branch's left and right subtrees, so they are replaced by the result of applying your red function to them.
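To see that the stub keeps the grouping intact, here is a minimal self-contained sketch of the same approach, run with a reducer that is neither commutative nor associative. The wrapper object `VectorFold` and the string-rendering reducer are assumptions made for illustration only; the tree definitions are the ones from the question.

```scala
import scala.annotation.tailrec

sealed trait Tree[+A]
case class Leaf[A](value: A) extends Tree[A]
case class Branch[A](left: Tree[A], right: Tree[A]) extends Tree[A]

// Hypothetical wrapper object, used only so the sketch compiles on its own.
object VectorFold {
  def fold[A, B](t: Tree[A])(map: A => B)(red: (B, B) => B): B = {
    // Marker meaning "combine the two newest results".
    case object BranchStub extends Tree[Nothing]

    @tailrec
    def foldImp(toVisit: List[Tree[A]], acc: Vector[B]): Vector[B] =
      toVisit match {
        case Nil                  => acc
        case Leaf(v) :: rest      => foldImp(rest, acc :+ map(v))
        case Branch(l, r) :: rest => foldImp(l :: r :: BranchStub :: rest, acc)
        case BranchStub :: rest   =>
          foldImp(rest, acc.dropRight(2) :+ acc.takeRight(2).reduce(red))
      }

    foldImp(t :: Nil, Vector.empty).head
  }
}

object VectorFoldDemo {
  def main(args: Array[String]): Unit = {
    val left  = Branch(Branch(Leaf(1), Leaf(2)), Leaf(3))
    val right = Branch(Leaf(1), Branch(Leaf(2), Leaf(3)))
    // Rendering with parentheses makes the grouping visible.
    println(VectorFold.fold(left)(_.toString)((l, r) => s"($l $r)"))  // ((1 2) 3)
    println(VectorFold.fold(right)(_.toString)((l, r) => s"($l $r)")) // (1 (2 3))
  }
}
```

The two trees hold the same leaves in the same order, yet they render differently, which is the property a plain stack-based sum would not preserve.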

This solution could be optimized, but it is already a tail-recursive implementation.

EDIT:

It can be slightly simplified by changing the accumulator data structure to a list seen as a stack:

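import scala.annotation.tailrec

def fold[A, B](t: Tree[A])(map: A => B)(red: (B, B) => B): B = {

  // Marker pushed after a branch's children: "reduce the top two results"
  case object BranchStub extends Tree[Nothing]

  @tailrec
  def foldImp(toVisit: List[Tree[A]], acc: List[B]): List[B] =
    if (toVisit.isEmpty) acc
    else {
      toVisit.head match {
        case Leaf(v) =>
          // Push the mapped leaf onto the result stack
          foldImp(toVisit.tail, map(v) :: acc)
        case Branch(l, r) =>
          // Visit right first so the left result ends up on top of the stack
          foldImp(r :: l :: BranchStub :: toVisit.tail, acc)
        case BranchStub =>
          // Top of the stack is the left result, below it the right one
          foldImp(toVisit.tail, acc.take(2).reduce(red) :: acc.drop(2))
      }
    }

  foldImp(t :: Nil, Nil).head

}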
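
One possible further variation, sketched here under assumptions rather than taken from the answer: keep the work items in their own small type instead of adding a stub to Tree, and visit the left subtree first so that map is also applied in left-to-right order. The names `StackSafeFold`, `Task`, `Visit`, `Combine` and `loop` are hypothetical.

```scala
import scala.annotation.tailrec

sealed trait Tree[+A]
case class Leaf[A](value: A) extends Tree[A]
case class Branch[A](left: Tree[A], right: Tree[A]) extends Tree[A]

object StackSafeFold {
  // Hypothetical work items: visit a subtree, or combine the two newest results.
  private sealed trait Task[+A]
  private case class Visit[A](tree: Tree[A]) extends Task[A]
  private case object Combine extends Task[Nothing]

  def fold[A, B](t: Tree[A])(map: A => B)(red: (B, B) => B): B = {
    @tailrec
    def loop(tasks: List[Task[A]], results: List[B]): B =
      tasks match {
        case Visit(Leaf(v)) :: rest =>
          loop(rest, map(v) :: results)
        case Visit(Branch(l, r)) :: rest =>
          // Left is visited first, so the right result ends up on top.
          loop(Visit(l) :: Visit(r) :: Combine :: rest, results)
        case Combine :: rest =>
          results match {
            case r :: l :: others => loop(rest, red(l, r) :: others)
            case _ => throw new IllegalStateException("fewer than two results to combine")
          }
        case Nil =>
          results.head
      }

    loop(List(Visit(t)), Nil)
  }
}

object StackSafeFoldDemo {
  def main(args: Array[String]): Unit = {
    // A left-leaning tree with 100001 leaves, built iteratively.
    val deep = (1 to 100000).foldLeft(Leaf(0): Tree[Int])((acc, i) => Branch(acc, Leaf(i)))
    // Counting leaves avoids comparing deep trees with ==, which is itself recursive.
    println(StackSafeFold.fold(deep)(_ => 1)(_ + _))
  }
}
```

The demo counts leaves instead of checking a deep copy for equality because the generated equals of nested case classes recurses and may overflow the stack on a tree this deep, independently of how fold is written.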

answered Sep 21 '22 15:09 by Pablo Francisco Pérez Hidalgo