 

Functional-style early exit from depth-first recursion

I have a question about writing recursive algorithms in a functional style. I will use Scala for my example here, but the question applies to any functional language.

I am doing a depth-first enumeration of an n-ary tree where each node has a label and a variable number of children. Here is a simple implementation that returns the labels of the leaf nodes.

case class Node[T](label: T, ns: Node[T]*)

// Collect the labels of the leaf nodes in depth-first order
def dfs[T](r: Node[T]): Seq[T] = {
    if (r.ns.isEmpty) Seq(r.label) else for (n <- r.ns; c <- dfs(n)) yield c
}
val r = Node('a, Node('b, Node('d), Node('e, Node('f))), Node('c))
dfs(r) // returns Seq[Symbol] = ArrayBuffer('d, 'f, 'c)

Now say that I sometimes want to give up on traversing oversize trees by throwing an exception. Is this possible in a functional language? Specifically, is it possible without using mutable state? That seems to depend on what you mean by "oversize". Here is a purely functional version of the algorithm that throws an exception when it tries to handle a tree with a depth of 3 or greater.

def dfs[T](r: Node[T], d: Int = 0): Seq[T] = {
    require(d < 3) // throws IllegalArgumentException once the recursion reaches depth 3
    if (r.ns.isEmpty) Seq(r.label) else for (n <- r.ns; c <- dfs(n, d + 1)) yield c
}

But what if a tree is oversized because it is too broad rather than too deep? Specifically, what if I want to throw an exception the n-th time the dfs() function is called recursively, regardless of how deep the recursion goes? The only way I can see to do this is with a mutable counter that is incremented on each call; I can't see how to do it without a mutable variable.
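For concreteness, here is a minimal sketch of the mutable-counter version just described, so it is clear exactly which piece of state a purely functional version would have to pass around explicitly. `dfsWithCounter` and `maxCalls` are hypothetical names, and the labels are Strings rather than Symbols; the leaf-collecting logic is the same as in the first dfs().

```scala
// Sketch only: the mutable-counter approach described above.
// `dfsWithCounter` and `maxCalls` are hypothetical names.
case class Node[T](label: T, ns: Node[T]*)

def dfsWithCounter[T](root: Node[T], maxCalls: Int): Seq[T] = {
  var calls = 0 // mutable state shared by every recursive call of `go`

  def go(r: Node[T]): Seq[T] = {
    calls += 1
    // Fails on the maxCalls-th call, however deep or broad the tree is
    require(calls < maxCalls, s"dfs called $maxCalls times")
    if (r.ns.isEmpty) Seq(r.label) else r.ns.flatMap(go)
  }

  go(root)
}
```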

I'm new to functional programming and have been working under the assumption that anything you can do with mutable state can be done without, but I don't see the answer here. The only thing I can think to do is write a version of dfs() that returns a view over all the nodes in the tree in depth-first order.

def dfs[T](r: Node[T]): TraversableView[T, Traversable[_]] = ...

Then I could impose my limit by saying dfs(r).take(n), but I don't see how to write this function. In Python I'd just create a generator by yielding nodes as I visit them, but I don't see how to achieve the same effect in Scala. (Scala's equivalent of a Python-style yield statement appears to be a visitor function passed in as a parameter, but I can't figure out how to write one that generates a sequence view.)

EDIT: Getting close to the answer.

Here is a function that returns a Stream of nodes in depth-first order.

def dfs[T](r: Node[T]): Stream[Node[T]] = {
    // fold (/:) over the children, starting from a one-element stream holding r
    (r #:: Stream.empty /: r.ns)(_ ++ dfs(_))
}

That is almost it. The only problem is that Stream memoizes all results, which is a waste of memory. I want a traversable view. The following is the idea, but it does not compile.

def dfs[T](r: Node[T]): TraversableView[Node[T], Traversable[Node[T]]] = {
    (Traversable(r).view /: r.ns)(_ ++ dfs(_))
}

It gives a "found TraversableView[Node[T], Traversable[Node[T]]], required TraversableView[Node[T], Traversable[_]]" error for the ++ operator. If I change the return type to TraversableView[Node[T], Traversable[_]], I get the same problem with the "found" and "required" clauses switched. So there's some magic type-variance incantation I haven't hit upon yet, but this is close.
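For comparison, Scala's Iterator behaves much like a Python generator: it produces elements on demand and, unlike Stream, does not memoize them. The sketch below yields nodes in the same depth-first pre-order and imposes the node limit with take; `dfsIter` and `leavesUpTo` are hypothetical names, and the labels are Strings instead of Symbols. One caveat about the Stream version above, as far as I can tell: foldLeft evaluates dfs(_) for every child while it folds, so all the recursive calls happen before the first element is read, whereas here a child's subtree is only visited when the iterator reaches it.

```scala
// Sketch: a lazy, non-memoizing pre-order traversal built on Iterator.
case class Node[T](label: T, ns: Node[T]*)

// The node itself, then each child's subtree, produced only as the caller asks for them
def dfsIter[T](r: Node[T]): Iterator[Node[T]] =
  Iterator.single(r) ++ r.ns.iterator.flatMap(child => dfsIter(child))

// Hypothetical helper: leaf labels, failing if the tree has more than `maxNodes` nodes.
// At most maxNodes + 1 nodes are ever pulled from the iterator.
def leavesUpTo[T](r: Node[T], maxNodes: Int): Seq[T] = {
  val visited = dfsIter(r).take(maxNodes + 1).toVector
  require(visited.length <= maxNodes, s"tree has more than $maxNodes nodes")
  visited.filter(_.ns.isEmpty).map(_.label)
}
```

An Iterator can only be consumed once, so calling dfsIter(r) again is how you restart the traversal.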

asked Nov 20 '12 23:11 by W.P. McNeill


2 Answers

It can be done: you just have to write some code to actually iterate through the children in the way you want (rather than relying on a for comprehension).

More explicitly, you'll have to write code that iterates through a list of children and checks whether the "depth" has crossed your threshold. Here's some Haskell code (sorry, I'm not fluent in Scala, but it should be easy to transliterate):

http://ideone.com/O5gvhM

In this code, I've basically replaced the for loop with an explicit recursive version. This lets me stop the recursion once the number of visited nodes is already too large (i.e., limit is no longer positive). When I recurse into the next child, I subtract the number of nodes visited by the dfs of the previous child and pass the result as the limit for the next child.
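A rough Scala sketch of that idea, assuming the "print all nodes" variant this answer ends up using, where the number of nodes a sub-call visited is just the length of the list it returns. `dfsAll` and `traverseAll` are illustrative names, not the code behind the ideone link:

```scala
// Sketch: depth-first traversal with a shared node budget, collecting all labels.
case class Node[T](label: T, children: List[Node[T]])

// Pre-order labels of at most `limit` nodes of the subtree rooted at `node`
def dfsAll[T](node: Node[T], limit: Int): List[T] =
  if (limit <= 0) Nil
  else node.label :: traverseAll(node.children, limit - 1) // this node costs 1

// Visit siblings left to right; each child gets whatever budget the earlier ones left
def traverseAll[T](children: List[Node[T]], limit: Int): List[T] =
  children match {
    case first :: rest if limit > 0 =>
      val visited = dfsAll(first, limit)
      // the length of the returned list is the number of nodes that child used
      visited ++ traverseAll(rest, limit - visited.length)
    case _ => Nil // no children left, or budget used up
  }
```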

Functional languages are fun, but they're a huge leap from imperative programming. It really makes you pay attention to the concept of state, because all of it is excruciatingly explicit in the arguments when you go functional.

EDIT: Explaining this a bit more.

I ended up converting from "print just the leaf nodes" (which was the original algorithm from the OP) to "print all nodes". That gave me the number of nodes each sub-call visited as the length of the resulting list. If you want to stick to the leaf nodes, you'll have to carry around how many nodes you have already visited:

http://ideone.com/cIQrna

EDIT again: To clean up this answer, I've put all the Haskell code on ideone and transliterated it to Scala, so this can stay here as the definitive answer to the question:

case class Node[T](label:T, children:Seq[Node[T]])

// How many nodes were visited, plus the leaf labels collected on the way
case class TraversalResult[T](num_visited:Int, labels:Seq[T])

// Visit at most `limit` nodes of the subtree rooted at `node`
def dfs[T](node:Node[T], limit:Int):TraversalResult[T] =
    if (limit <= 0) TraversalResult(0, Nil)  // budget used up
    else node.children match {
        case Seq() => TraversalResult(1, List(node.label))  // leaf
        case children => {
            // count this node, and give the children what is left
            val result = traverse(children, limit - 1)
            TraversalResult(result.num_visited + 1, result.labels)
        }
    }

// Visit siblings left to right, sharing a budget of `limit` nodes
def traverse[T](children:Seq[Node[T]], limit:Int):TraversalResult[T] =
    if (limit <= 0) TraversalResult(0, Nil)
    else children match {
        case first +: rest => {
            val trav_first = dfs(first, limit)
            // the remaining siblings only get what the first child did not use
            val trav_rest =
                traverse(rest, limit - trav_first.num_visited)
            TraversalResult(
                trav_first.num_visited + trav_rest.num_visited,
                trav_first.labels ++ trav_rest.labels
            )
        }
        case _ => TraversalResult(0, Nil)  // no siblings left
    }

val n = Node(0, List(
    Node(1, List(Node(2, Nil), Node(3, Nil))),
    Node(4, List(Node(5, List(Node(6, Nil))))),
    Node(7, Nil)
))
for (i <- 1 to 8)
    println(dfs(n, i))

Output:

TraversalResult(1,List())
TraversalResult(2,List())
TraversalResult(3,List(2))
TraversalResult(4,List(2, 3))
TraversalResult(5,List(2, 3))
TraversalResult(6,List(2, 3))
TraversalResult(7,List(2, 3, 6))
TraversalResult(8,List(2, 3, 6, 7))

P.S. This is my first attempt at Scala, so the above is probably horribly non-idiomatic. I'm sorry.

answered Nov 11 '22 09:11 by Cesar Kawakami


You can convert breadth into depth by passing along an index or taking the tail:

// Sum a List by recursing on its tail
def suml(xs: List[Int], total: Int = 0): Int = xs match {
  case Nil => total
  case x :: rest => suml(rest, total + x)
}

// Sum an Array by recursing on an index
def suma(xs: Array[Int], from: Int = 0, total: Int = 0): Int = {
  if (from >= xs.length) total
  else suma(xs, from + 1, total + xs(from))
}

In the latter case, you already have an index you can use to limit the breadth if you want; in the former, just add a width counter or some such.
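For example, adding a hypothetical `width` budget to suml turns the tail recursion into a breadth limit. This sketch throws once more than `width` elements would be consumed, mirroring the exception the question asks for; the same parameter could be threaded through a recursion over a node's children.

```scala
// Sketch: suml with a hypothetical `width` budget. Each element consumed costs 1,
// and running out of budget while elements remain throws, instead of any mutable counter.
def sumlLimited(xs: List[Int], width: Int, total: Int = 0): Int = xs match {
  case Nil => total
  case _ if width <= 0 => throw new IllegalArgumentException("list too wide")
  case x :: rest => sumlLimited(rest, width - 1, total + x)
}
```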

answered Nov 11 '22 07:11 by Rex Kerr