
N-Tree Traversal with Scala Causes Stack Overflow

I am attempting to return a list of widgets from an N-tree data structure. In my unit test, if I have roughly 2000 widgets, each with a single dependency, I encounter a stack overflow. What I think is happening is that the for loop prevents my tree traversal from being tail recursive. What's a better way of writing this in Scala? Here's my function:

protected def getWidgetTree(key: String) : ListBuffer[Widget] = {
  def traverseTree(accumulator: ListBuffer[Widget], current: Widget) : ListBuffer[Widget] = {
    accumulator.append(current)

    if (!current.hasDependencies) {
      accumulator
    }  else {
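      // The recursive call below sits inside the for loop, so it is not in tail
      // position: the loop continues after each call returns, and every level of
      // dependencies adds another stack frame.
      for (dependencyKey <- current.dependencies) {
        if (!accumulator.exists(_.name == dependencyKey)) {
          traverseTree(accumulator, getWidget(dependencyKey))
        }
      }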

      accumulator
    }
  }

  traverseTree(ListBuffer[Widget](), getWidget(key))
}
asked Oct 23 '12 22:10 by Donuts


2 Answers

For the same example as in @dhg's answer, an equivalent tail-recursive function with no mutable state (no ListBuffer) would be:

import scala.annotation.tailrec

case class Widget(name: String, dependencies: List[String])

val getWidget = List(
  Widget("1", List("2", "5")),
  Widget("2", List("3", "4")),
  Widget("3", List()),
  Widget("4", List()),
  Widget("5", List())).map { w => w.name -> w }.toMap

def getWidgetTree(key: String): List[Widget] = {
  def addIfNotAlreadyContained(widgetList: List[Widget], widgetNameToAdd: String): List[Widget] = {
    if (widgetList.find(_.name == widgetNameToAdd).isDefined) widgetList
    else                                                      widgetList :+ getWidget(widgetNameToAdd)
  }

  @tailrec
  def traverseTree(currentWidgets: List[Widget], acc: List[Widget]): List[Widget] = currentWidgets match {
    case Nil                                => {
      // If there are no more widgets in this branch return what we've traversed so far
      acc 
    }
    case Widget(name, Nil) :: rest          => {
      // If the first widget is a leaf traverse the rest and add the leaf to the list of traversed
      traverseTree(rest, addIfNotAlreadyContained(acc, name)) 
    }
    case Widget(name, dependencies) :: rest => {
      // If the first widget is a parent, traverse its children and the rest, and add it to the list of traversed
      traverseTree(dependencies.map(getWidget) ++ rest, addIfNotAlreadyContained(acc, name))
    } 
  }

  val root = getWidget(key)
  traverseTree(root.dependencies.map(getWidget) :+ root, List[Widget]())
}

For the same test case

for (k <- 1 to 5)
  println(getWidgetTree(k.toString).map(_.name))

Gives you:

List(2, 3, 4, 5, 1)
List(3, 4, 2)
List(3)
List(4)
List(5)

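Note that this is not a preorder traversal like @dhg's: the root is listed last, after all of its descendants. It is not strictly postorder either, since a non-root parent (such as 2 in the first result) is still added before its children.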
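A strict postorder traversal, in which every widget is listed after all of its dependencies, can also be written tail-recursively. This is an illustrative sketch, not part of the original answer: the object name PostorderSketch, the helper loop and the (key, childrenPushed) stack entries are hypothetical, and it reuses the example tree from these answers. Each widget is pushed twice, once to expand its children and once to emit it after them; a seen set stops shared or cyclic dependencies from being expanded again.

```scala
import scala.annotation.tailrec

object PostorderSketch {
  case class Widget(name: String, dependencies: List[String])

  // Hypothetical example tree: 1 -> (2, 5), 2 -> (3, 4); 3, 4 and 5 are leaves.
  val getWidget: Map[String, Widget] = List(
    Widget("1", List("2", "5")),
    Widget("2", List("3", "4")),
    Widget("3", List()),
    Widget("4", List()),
    Widget("5", List())).map(w => w.name -> w).toMap

  def postorder(key: String): List[Widget] = {
    // Stack entries are (key, childrenAlreadyPushed). A widget is emitted only
    // when popped with `true`, i.e. after all of its children were handled.
    @tailrec
    def loop(stack: List[(String, Boolean)], seen: Set[String], acc: List[Widget]): List[Widget] =
      stack match {
        case Nil => acc.reverse
        case (k, true) :: rest =>
          loop(rest, seen, getWidget(k) :: acc)
        case (k, false) :: rest if seen(k) =>
          loop(rest, seen, acc) // already expanded via another path (or a cycle)
        case (k, false) :: rest =>
          val children = getWidget(k).dependencies.map(d => (d, false))
          loop(children ::: (k, true) :: rest, seen + k, acc)
      }

    loop(List((key, false)), Set.empty, Nil)
  }
}
```

For key "1" this returns widgets 3, 4, 2, 5, 1.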

answered Sep 30 '22 18:09 by Ratan Sebastian


The reason it's not tail-recursive is that you are making multiple recursive calls inside your function (one per dependency, inside the for loop). To be tail-recursive, every recursive call must be in tail position, i.e. it must be the last thing the function does before returning. After all, the whole point is that it works like a while loop (and, thus, can be transformed into one). A loop can't call itself multiple times within a single iteration.
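To make the loop analogy concrete, here is a hedged sketch (not from either answer) of the same kind of traversal written as a plain while loop with a mutable stack; WhileLoopSketch and the example data are hypothetical. Because the pending work lives in a heap-allocated list rather than in nested calls, the depth of the tree does not affect call-stack usage.

```scala
import scala.collection.mutable

object WhileLoopSketch {
  case class Widget(name: String, dependencies: List[String])

  // Hypothetical example tree: 1 -> (2, 5), 2 -> (3, 4); 3, 4 and 5 are leaves.
  val getWidget: Map[String, Widget] = List(
    Widget("1", List("2", "5")),
    Widget("2", List("3", "4")),
    Widget("3", List()),
    Widget("4", List()),
    Widget("5", List())).map(w => w.name -> w).toMap

  // Preorder traversal driven by an explicit stack of pending keys instead of
  // the call stack, so no recursion happens at all.
  def getWidgetTree(key: String): List[Widget] = {
    val result  = mutable.ListBuffer[Widget]()
    val visited = mutable.Set[String]()
    var stack   = List(key)                // head is the next key to visit
    while (stack.nonEmpty) {
      val currentKey = stack.head
      stack = stack.tail
      if (!visited.contains(currentKey)) { // skip keys reached by another path
        visited += currentKey
        val current = getWidget(currentKey)
        result += current
        // Prepend unvisited dependencies so the first one is visited next.
        stack = current.dependencies.filterNot(visited.contains) ::: stack
      }
    }
    result.toList
  }
}
```

For key "1" this visits 1, 2, 3, 4, 5, the same order as the tail-recursive version.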

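To do a tree traversal like this, you can carry forward the nodes that still need to be visited in an explicit list. The code below calls it a queue, but because new dependencies are prepended to the front, it behaves like a stack and produces a depth-first, preorder traversal.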
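If you actually want a first-in, first-out queue, which visits widgets level by level (breadth-first), `scala.collection.immutable.Queue` works with the same tail-recursive pattern. This is a sketch assuming the example tree used in this answer (1 depends on 2 and 5, 2 depends on 3 and 4); BreadthFirstSketch and getWidgetTreeBfs are hypothetical names.

```scala
import scala.annotation.tailrec
import scala.collection.immutable.Queue

object BreadthFirstSketch {
  case class Widget(name: String, dependencies: List[String])

  // Hypothetical example tree: 1 -> (2, 5), 2 -> (3, 4); 3, 4 and 5 are leaves.
  val getWidget: Map[String, Widget] = List(
    Widget("1", List("2", "5")),
    Widget("2", List("3", "4")),
    Widget("3", List()),
    Widget("4", List()),
    Widget("5", List())).map(w => w.name -> w).toMap

  // Breadth-first traversal with a real FIFO queue. A key is marked as seen
  // when it is enqueued, so no widget is enqueued (or visited) twice.
  def getWidgetTreeBfs(key: String): List[Widget] = {
    @tailrec
    def loop(queue: Queue[String], seen: Set[String], acc: List[Widget]): List[Widget] =
      queue.dequeueOption match {
        case None => acc.reverse // queue exhausted: restore visiting order
        case Some((currentKey, rest)) =>
          val current = getWidget(currentKey)
          val fresh   = current.dependencies.filterNot(seen).distinct
          loop(rest ++ fresh, seen ++ fresh, current :: acc) // enqueue at the back
      }

    loop(Queue(key), Set(key), Nil)
  }
}
```

For key "1" the order is 1, 2, 5, 3, 4: both children of the root come before any grandchild.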

Assume we have this tree:

//        1
//       / \  
//      2   5
//     / \
//    3   4

Represented with this simple data structure:

case class Widget(name: String, dependencies: List[String]) {
  def hasDependencies = dependencies.nonEmpty
}

And we have this map pointing to each node:

val getWidget = List(
  Widget("1", List("2", "5")),
  Widget("2", List("3", "4")),
  Widget("3", List()),
  Widget("4", List()),
  Widget("5", List()))
  .map { w => w.name -> w }.toMap

Now we can rewrite your method to be tail-recursive:

import scala.annotation.tailrec

def getWidgetTree(key: String): List[Widget] = {
  @tailrec
  def traverseTree(queue: List[String], accumulator: List[Widget]): List[Widget] = {
    queue match {
      case currentKey :: queueTail =>        // the queue is not empty
        val current = getWidget(currentKey)  // get the element at the front
        val newQueueItems =                  // drop dependencies already visited or already queued
          current.dependencies.filterNot(dependencyKey =>
            accumulator.exists(_.name == dependencyKey) || queue.contains(dependencyKey))
        traverseTree(newQueueItems ::: queueTail, current :: accumulator) // visit new items next; record current
      case Nil =>                            // the queue is empty
        accumulator.reverse                  // we're done
    }
  }

  traverseTree(key :: Nil, List[Widget]())
}

And test it out:

for (k <- 1 to 5)
  println(getWidgetTree(k.toString).map(_.name))

prints:

List(1, 2, 3, 4, 5)
List(2, 3, 4)
List(3)
List(4)
List(5)
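Both `accumulator.exists` and `queue.contains` scan a list, so each step costs time proportional to the number of widgets seen so far. A possible refinement, sketched here with hypothetical names (SetVisitedSketch, seen) and hypothetical data (a chain of 10,000 widgets, each depending on the next, well past the question's roughly 2000), keeps an immutable Set of keys that have been visited or queued, while still recursing only in tail position.

```scala
import scala.annotation.tailrec

object SetVisitedSketch {
  case class Widget(name: String, dependencies: List[String])

  // Hypothetical data: widget i depends only on widget i + 1.
  val chainLength = 10000
  val getWidget: Map[String, Widget] =
    (1 to chainLength).map { i =>
      val deps = if (i < chainLength) List((i + 1).toString) else Nil
      i.toString -> Widget(i.toString, deps)
    }.toMap

  def getWidgetTree(key: String): List[Widget] = {
    @tailrec
    def traverseTree(pending: List[String], seen: Set[String], acc: List[Widget]): List[Widget] =
      pending match {
        case currentKey :: rest =>
          val current = getWidget(currentKey)
          // Keys in `seen` were already visited or are already pending.
          val fresh = current.dependencies.filterNot(seen).distinct
          traverseTree(fresh ::: rest, seen ++ fresh, current :: acc)
        case Nil =>
          acc.reverse
      }

    traverseTree(List(key), Set(key), Nil)
  }
}
```

The visiting order is the same as in the answer's version; only the membership checks change.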
answered Sep 30 '22 16:09 by dhg