
F# tail call optimization with 2 recursive calls?

As I was writing this function I knew that I wouldn't get tail call optimization. I still haven't come up with a good way of handling this and was hoping someone else might offer suggestions.

I've got a tree:

type Heap<'a> =
| E
| T of int * 'a * Heap<'a> * Heap<'a> 

And I want to count how many nodes are in it:

let count h =
    let rec count' h acc =
        match h with 
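        | E -> acc
        | T(_, _, leftChild, rightChild) ->
            // Count this node once: pass acc down the left subtree only,
            // so it is not added twice when the two results are summed.
            let acc = 1 + acc
            (count' leftChild acc) + (count' rightChild 0)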

    count' h 0

This isn't optimized because the counts for the two child nodes are added together after the recursive calls return, so neither call is in tail position. Any ideas on how to make something like this work if the tree has 1 million nodes?

Thanks, Derek


Here is the implementation of count using CPS. It still blew the stack though.

let count h =
    let rec count' h acc cont =
        match h with
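        | E -> cont acc
        | T(_,_,left,right) ->
            // Count this node, then continue with the right subtree
            // once the left subtree's running total (lc) is known.
            let f = (fun lc -> count' right lc cont)
            count' left (1 + acc) f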

    count' h 0 (fun (x: int) -> x)

Maybe I can come up with some way to partition the tree into enough pieces that I can count without blowing the stack?

Someone asked about the code which generates the tree. It is below.

member this.ParallelHeaps threads =
    let rand = new Random()
    let maxVal = 1000000

    let rec heaper i h =
        if i < 1 then
            h
        else
            let heap = LeftistHeap.insert (rand.Next(100,2 * maxVal)) h
            heaper (i - 1) heap

    let heaps = Array.create threads E
    printfn "Creating heap of %d elements, with %d threads" maxVal threads
    let startTime = DateTime.Now
    seq { for i in 0 .. (threads - 1) ->
          async { Array.set heaps i (heaper (maxVal / threads) E) }}
    |> Async.Parallel
    |> Async.RunSynchronously 
    |> ignore

    printfn "Creating %d sub-heaps took %f milliseconds" threads (DateTime.Now - startTime).TotalMilliseconds
    let startTime = DateTime.Now

    Array.length heaps |> should_ equal threads <| "The size of the heaps array should match the number of threads to process the heaps"

    let rec reMerge i h =
        match i with 
        | -1 -> h
        | _  -> 
            printfn "heap[%d].count = %d" i (LeftistHeap.count heaps.[i])
            LeftistHeap.merge heaps.[i] (reMerge (i-1) h)

    let heap = reMerge (threads-1) E
    printfn "Merging %d heaps took %f milliseconds" threads (DateTime.Now - startTime).TotalMilliseconds
    printfn "heap min: %d" (LeftistHeap.findMin heap)

    LeftistHeap.count heap |> should_ equal maxVal <| "The count of the reMerged heap should equal maxVal"

Derek Ealy asked Jun 11 '11 18:06


2 Answers

You can use continuation-passing style (CPS) to solve that problem. See Recursing on Recursion - Continuation Passing by Matthew Podwysocki.

let tree_size_cont tree = 
  let rec size_acc tree acc cont = 
    match tree with 
    | Leaf _ -> cont (1 + acc) 
    | Node(_, left, right) -> 
         size_acc left acc (fun left_size -> 
         size_acc right left_size cont) 

  size_acc tree 0 (fun x -> x)

Note also that tail call optimization is disabled by default in Debug builds. If you don't want to run in Release mode, you can enable it in the project's properties in Visual Studio (the "Generate tail calls" option on the Build page).
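
For the Heap type from the question, the same idea might look like the sketch below. It is an adaptation rather than code from the linked article: the names countCps, loop and sample are made up for illustration, and it counts T nodes instead of leaves. It still depends on the compiler emitting tail calls, so a Debug build with tail calls turned off can overflow the stack on a deep tree.

```fsharp
// Heap type as defined in the question.
type Heap<'a> =
    | E
    | T of int * 'a * Heap<'a> * Heap<'a>

// Count the T nodes in continuation-passing style.
// Every call to loop and cont is in tail position, so with tail calls
// enabled the pending work lives in heap-allocated closures, not stack frames.
let countCps (h: Heap<'a>) : int =
    let rec loop h acc cont =
        match h with
        | E -> cont acc
        | T(_, _, left, right) ->
            // Count this node, walk the left subtree, then hand its
            // running total to the walk of the right subtree.
            loop left (acc + 1) (fun leftTotal -> loop right leftTotal cont)
    loop h 0 id

// Hypothetical sample data: four nodes (not necessarily a valid leftist heap).
let sample = T(1, 10, T(1, 20, E, E), T(1, 30, E, T(1, 40, E, E)))
printfn "%d" (countCps sample) // prints 4
```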

Joh answered Sep 19 '22 17:09


CPS is a good general solution, but you might also consider using an explicit stack, because it will be faster and is arguably simpler:

let count heap =
  // Seed the stack with the root heap.
  let stack = System.Collections.Generic.Stack<_>([heap])
  let mutable n = 0
  while stack.Count > 0 do
    match stack.Pop() with
    | E -> ()
    | T(_, _, heap1, heap2) ->
        n <- n + 1
        stack.Push heap1
        stack.Push heap2
  n
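
As a rough check that this approach does not depend on the call stack, here is a self-contained sketch that runs the same loop on a degenerate chain of one million nodes. The deepChain helper is hypothetical test data (it is not a valid leftist heap and is not from the question); it is only meant to produce a tree deep enough to overflow a naive recursive count.

```fsharp
open System.Collections.Generic

// Heap type as defined in the question.
type Heap<'a> =
    | E
    | T of int * 'a * Heap<'a> * Heap<'a>

// Count nodes with an explicit stack instead of recursion.
let count (heap: Heap<'a>) =
    let stack = Stack<Heap<'a>>()
    stack.Push heap
    let mutable n = 0
    while stack.Count > 0 do
        match stack.Pop() with
        | E -> ()
        | T(_, _, left, right) ->
            n <- n + 1
            stack.Push left
            stack.Push right
    n

// Hypothetical helper: build a left-leaning chain of `depth` nodes
// iteratively, so that construction itself cannot overflow the stack.
let deepChain depth =
    let mutable h = E
    for i in 1 .. depth do
        h <- T(1, i, h, E)
    h

printfn "%d" (count (deepChain 1000000)) // prints 1000000
```

The pending subtrees live in the heap-allocated Stack, so the depth of the tree only affects memory use, and the result does not depend on whether tail calls are enabled.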

J D answered Sep 19 '22 17:09