
How can I convert this binary recursive function into a tail-recursive form?

There is a clear way to convert binary recursion to tail recursion for sets closed under an operation, e.g. the integers under addition, as in the Fibonacci sequence:

(Using Haskell)

fib :: Int -> Int
fib n = fib' 0 1 n

fib' :: Int -> Int -> Int
fib' x y n
    | n < 1 = y  
    | otherwise = fib' y (x + y) (n - 1)

This works because we have our desired value, y, and our operation, x + y, where x + y returns an integer just like y does.
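One Haskell-specific detail: the accumulators in fib' are lazy, so an unoptimized build can pile up a chain of unevaluated (x + y) thunks that is only forced when y is finally demanded; GHC's strictness analysis usually removes this under -O. A minimal sketch that makes the strictness explicit with the BangPatterns extension (the names fibStrict and go are illustrative, not from the question). As with fib above, the result for n is the (n+1)-th Fibonacci number under the F(0) = 0 convention, so the sequence starts 1, 1, 2, ...

```haskell
{-# LANGUAGE BangPatterns #-}

-- Same recursion as fib/fib', but the bang patterns force both
-- accumulators at every step, so no chain of (x + y) thunks builds up.
fibStrict :: Int -> Int
fibStrict n = go 0 1 n
  where
    go :: Int -> Int -> Int -> Int
    go !x !y k
      | k < 1     = y
      | otherwise = go y (x + y) (k - 1)

main :: IO ()
main = print (map fibStrict [0 .. 6])  -- prints [1,1,2,3,5,8,13]
```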

However, what if I want to use a set that is not closed under the operation? I want a function that splits a list into two lists and then does the same to each of those two lists (like recursively building a binary tree), stopping whenever another ("magic") function, looking at the resulting split, says to stop:

[1, 2, 3, 4, 5] -> [[1, 3, 4], [2, 5]] -> [[1, 3], [4], [2], [5]]

That is,

splitList :: [Int] -> [[Int]]
splitList intList
    | length intList < 2    = [intList]
    | magicFunction x y > 0 = splitList x ++ splitList y
    | otherwise             = [intList]
  where
    -- hypothetical helper: partitions intList into two sublists x and y
    (x, y) = splitSomeHow intList
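To make the pseudocode runnable, here is a sketch with hypothetical stand-ins for the unspecified parts: splitSomeHow splits a list in half, and magicFunction keeps splitting only while the two parts together hold more than two elements. Both definitions are assumptions chosen for illustration; the recursion itself is the question's.

```haskell
-- Hypothetical stand-in: split the list in half.
splitSomeHow :: [Int] -> ([Int], [Int])
splitSomeHow xs = splitAt (length xs `div` 2) xs

-- Hypothetical stand-in: positive (keep splitting) while more than two elements remain.
magicFunction :: [Int] -> [Int] -> Int
magicFunction x y = length x + length y - 2

-- The binary (non-tail) recursion from the question.
splitList :: [Int] -> [[Int]]
splitList intList
  | length intList < 2    = [intList]
  | magicFunction x y > 0 = splitList x ++ splitList y  -- neither call is in tail position
  | otherwise             = [intList]
  where
    (x, y) = splitSomeHow intList

main :: IO ()
main = print (splitList [1, 2, 3, 4, 5])  -- prints [[1,2],[3],[4,5]]
```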

Now, how can this binary recursion be converted to tail recursion? The prior method won't work directly: for Fibonacci, (+) :: Int -> Int -> Int returns the same type as its inputs, but the split takes [Int] to [[Int]], which is not the input type. So I assume the accumulator would need to change.
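Following that assumption, one way to change the accumulator is to carry two lists: a stack of sublists still to examine, and the finished pieces collected in reverse. A minimal sketch, reusing the same hypothetical splitSomeHow and magicFunction stand-ins (halving, and splitting while more than two elements remain):

```haskell
-- Hypothetical stand-ins for the question's unspecified pieces.
splitSomeHow :: [Int] -> ([Int], [Int])
splitSomeHow xs = splitAt (length xs `div` 2) xs

magicFunction :: [Int] -> [Int] -> Int
magicFunction x y = length x + length y - 2

-- Tail-recursive version: the accumulator is the pair
-- (sublists still to examine, finished pieces in reverse order).
splitListAcc :: [Int] -> [[Int]]
splitListAcc intList = go [intList] []
  where
    go :: [[Int]] -> [[Int]] -> [[Int]]
    go [] done = reverse done
    go (xs : pending) done
      | length xs < 2         = go pending (xs : done)
      | magicFunction x y > 0 = go (x : y : pending) done  -- examine x before y
      | otherwise             = go pending (xs : done)
      where
        (x, y) = splitSomeHow xs

main :: IO ()
main = print (splitListAcc [1, 2, 3, 4, 5])  -- prints [[1,2],[3],[4,5]]
```

Pushing x before y keeps the left-to-right order of splitList x ++ splitList y; the final reverse undoes the order in which finished pieces were consed onto done.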

asked May 11 '13 04:05 by user1104160


2 Answers

There is a general trick to make any function tail recursive: rewrite it in continuation-passing style (CPS). The basic idea behind CPS is that every function takes an additional parameter: a function to call when it is done. Then, instead of returning a value, the original function calls the function that was passed in. This latter function is called a "continuation" because it continues the computation on to its next step.

To illustrate this idea, I'm just going to use your function as an example. Note the changes to the type signature as well as the structure of the code:

splitListCPS :: [Int] -> ([[Int]] -> r) -> r
splitListCPS intList cont
  | length intList < 2    = cont [intList]
  | magicFunction x y > 0 = splitListCPS x $ \ r₁ -> 
                              splitListCPS y $ \ r₂ -> 
                                cont $ r₁ ++ r₂
  | otherwise             = cont [intList]

You can then wrap this up into a normal-looking function as follows:

splitList :: [Int] -> [[Int]]
splitList intList = splitListCPS intList (\ r -> r)

If you follow the slightly convoluted logic, you'll see that this splitList is equivalent to the original one. The tricky bit is the recursive case. There, we immediately call splitListCPS with x. The function \ r₁ -> ... tells splitListCPS what to do when it's done: in this case, call splitListCPS with the next argument (y). Finally, once we have both results, we combine them and pass the combination to the original continuation (cont). So in the end we get the same result as before (namely splitList x ++ splitList y), but instead of returning it, we hand it to the continuation.
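A runnable sketch of the CPS version, completed with a where clause for x and y and checked against the direct recursion. As before, splitSomeHow (halving) and magicFunction (split while more than two elements remain) are hypothetical stand-ins; splitListDirect is an illustrative name for the original non-CPS function.

```haskell
-- Hypothetical stand-ins for the question's unspecified pieces.
splitSomeHow :: [Int] -> ([Int], [Int])
splitSomeHow xs = splitAt (length xs `div` 2) xs

magicFunction :: [Int] -> [Int] -> Int
magicFunction x y = length x + length y - 2

-- The CPS version: every call is a tail call, pending work lives in closures.
splitListCPS :: [Int] -> ([[Int]] -> r) -> r
splitListCPS intList cont
  | length intList < 2    = cont [intList]
  | magicFunction x y > 0 = splitListCPS x $ \r1 ->
                              splitListCPS y $ \r2 ->
                                cont (r1 ++ r2)
  | otherwise             = cont [intList]
  where
    (x, y) = splitSomeHow intList

splitList :: [Int] -> [[Int]]
splitList intList = splitListCPS intList id

-- The original direct recursion, kept only to compare results.
splitListDirect :: [Int] -> [[Int]]
splitListDirect intList
  | length intList < 2    = [intList]
  | magicFunction x y > 0 = splitListDirect x ++ splitListDirect y
  | otherwise             = [intList]
  where
    (x, y) = splitSomeHow intList

main :: IO ()
main = do
  print (splitList [1, 2, 3, 4, 5])  -- prints [[1,2],[3],[4,5]]
  print (all (\n -> splitList [1 .. n] == splitListDirect [1 .. n]) [0 .. 20])  -- prints True
```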

Also, if you look through the above code, you'll note that all the recursive calls are in tail position. At each step, our last action is always either a recursive call or using the continuation. With a clever compiler, this sort of code can actually be fairly efficient.

In a certain sense, this technique is actually similar to what you did for fib; however, instead of maintaining an accumulator value we sort of maintain an accumulator of the computation we're doing.
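The "accumulator of the computation" can be made literal with defunctionalization, a standard technique that is not part of this answer: replace each continuation closure with a constructor that records its free variables, and replace calling a continuation with an apply function. A sketch under the same hypothetical splitSomeHow and magicFunction stand-ins; Cont, go, and apply are illustrative names.

```haskell
-- Hypothetical stand-ins for the question's unspecified pieces.
splitSomeHow :: [Int] -> ([Int], [Int])
splitSomeHow xs = splitAt (length xs `div` 2) xs

magicFunction :: [Int] -> [Int] -> Int
magicFunction x y = length x + length y - 2

-- One constructor per kind of continuation used by splitListCPS.
data Cont
  = Done                    -- \r -> r
  | ThenRight [Int] Cont    -- \r1 -> splitListCPS y (\r2 -> cont (r1 ++ r2))
  | AppendTo [[Int]] Cont   -- \r2 -> cont (r1 ++ r2)

-- 'go' plays the role of splitListCPS; every call below is a tail call.
go :: [Int] -> Cont -> [[Int]]
go xs k
  | length xs < 2         = apply k [xs]
  | magicFunction x y > 0 = go x (ThenRight y k)
  | otherwise             = apply k [xs]
  where
    (x, y) = splitSomeHow xs

-- 'apply' plays the role of calling a continuation.
apply :: Cont -> [[Int]] -> [[Int]]
apply Done            r  = r
apply (ThenRight y k) r1 = go y (AppendTo r1 k)
apply (AppendTo r1 k) r2 = apply k (r1 ++ r2)

splitList :: [Int] -> [[Int]]
splitList intList = go intList Done

main :: IO ()
main = print (splitList [1, 2, 3, 4, 5])  -- prints [[1,2],[3],[4,5]]
```

Cont is, in effect, a stack of pending work, which is exactly the accumulator the fib version did not need because its state fit in two Ints.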

answered Nov 15 '22 07:11 by Tikhon Jelvis


You don't generally want tail recursion in Haskell. What you do want is productive corecursion, describing what SICP calls an iterative process.

You can fix the type inconsistency in your function by enclosing the initial input in a list. In your example

[1, 2, 3, 4, 5] -> [[1, 3, 4], [2, 5]] -> [[1, 3], [4], [2], [5]]

only the first arrow is inconsistent, so change it into

[[1, 2, 3, 4, 5]] -> [[1, 3, 4], [2, 5]] -> [[1, 3], [4], [2], [5]]

which illustrates the process of iteratively applying concatMap splitList1, where

   splitList1 xs 
      | null $ drop 1 xs = [xs]
      | magic a b > 0    = [a,b]    -- (B)
      | otherwise        = [xs]
     where (a,b) = splitSomeHow xs

You stop once an iteration fires no (B) case, i.e. when no piece gets split any further.
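That iteration can be sketched directly. Each piece yields one or two pieces per round, so a round in which no (B) case fires is exactly a round that leaves the number of pieces unchanged. The definitions of magic and splitSomeHow below are hypothetical stand-ins (halving, and splitting while more than two elements remain), and splitListIter is an illustrative name.

```haskell
-- Hypothetical stand-ins for the unspecified pieces.
splitSomeHow :: [Int] -> ([Int], [Int])
splitSomeHow xs = splitAt (length xs `div` 2) xs

magic :: [Int] -> [Int] -> Int
magic a b = length a + length b - 2

splitList1 :: [Int] -> [[Int]]
splitList1 xs
  | null (drop 1 xs) = [xs]
  | magic a b > 0    = [a, b]   -- (B)
  | otherwise        = [xs]
  where
    (a, b) = splitSomeHow xs

-- Repeatedly apply concatMap splitList1, stopping when a round splits nothing.
splitListIter :: [Int] -> [[Int]]
splitListIter xs = loop [xs]
  where
    loop xss
      | length next == length xss = xss
      | otherwise                 = loop next
      where
        next = concatMap splitList1 xss

main :: IO ()
main = print (splitListIter [1, 2, 3, 4, 5])  -- prints [[1,2],[3],[4,5]]
```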


But it is much better to produce the portions of the output that are ready, as soon as possible:

splitList :: [Int] -> [[Int]]
splitList xs = g [xs]   -- explicate the stack
  where
    g []                  = []
    g (xs : t)
       | null $ drop 1 xs = xs : g t
       | magic a b > 0    = g (a : b : t)
       | otherwise        = xs : g t
     where (a,b) = splitSomeHow xs 
           -- magic a b = 1
           -- splitSomeHow = splitAt 1

Don't forget to compile with the -O2 flag.
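To see the productivity, here is a sketch that runs the explicit-stack version on an infinite input. It assumes magic always returns 1 and splitSomeHow = splitAt 1, so every split strictly shrinks the first piece; both choices are illustrative. Because each finished piece is emitted as soon as it is known, take can consume a prefix of the output without the rest ever being computed.

```haskell
-- Hypothetical stand-ins: always split, peeling off one element at a time.
magic :: [Int] -> [Int] -> Int
magic _ _ = 1

splitSomeHow :: [Int] -> ([Int], [Int])
splitSomeHow = splitAt 1

-- Explicit-stack version: finished pieces are produced lazily, one by one.
splitList :: [Int] -> [[Int]]
splitList xs = g [xs]
  where
    g [] = []
    g (ys : t)
      | null (drop 1 ys) = ys : g t       -- emit a finished piece immediately
      | magic a b > 0    = g (a : b : t)  -- push both halves back on the stack
      | otherwise        = ys : g t
      where
        (a, b) = splitSomeHow ys

main :: IO ()
main = print (take 3 (splitList [1 ..]))  -- prints [[1],[2],[3]]
```

Note the use of null (drop 1 ys) rather than length: length would never return on the infinite remainder.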

answered Nov 15 '22 09:11 by Will Ness