
tail recursion recognition

Tags: haskell, ghc

I'm trying to learn Haskell and I stumbled upon the following:

myAdd (x:xs) = x + myAdd xs
myAdd null = 0

f = let n = 10000000 in myAdd [1 .. n]

main = do
 putStrLn (show f)

When compiling with GHC, this yields a stack overflow. As a C/C++ programmer, I would have expected the compiler to do tail call optimization.

I don't like that I would have to "help" the compiler in simple cases like these, but what options are there? I think it is reasonable to require that the calculation given above be done without using O(n) memory, and without deferring to specialized functions.

If I cannot state my problem naturally (even on a toy problem such as this), and expect reasonable performance in terms of time & space, much of the appeal of Haskell would be lost.

reddish asked Dec 29 '11 15:12


2 Answers

Firstly, make sure you're compiling with -O2. It makes a lot of performance problems just go away :)

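The first problem I can see is that null is just a variable name there: as a pattern it binds to any argument instead of referring to the Prelude function null. You want []. It's equivalent here because the only options are x:xs and [], and the x:xs clause comes first, but it won't always be.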
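A minimal sketch of why the pattern matters, using two hypothetical functions (describe and describe') that are not part of the question's code. A lowercase name in a pattern introduces a fresh variable, so it matches any argument, and GHC will typically warn that the later clause is redundant:

```haskell
-- Hypothetical names (describe, describe') chosen for illustration only.

-- Here `null` is a pattern variable that shadows Prelude's `null`
-- and matches every list, so the second clause is never reached.
describe :: [Int] -> String
describe null    = "matched by the variable pattern"
describe (_ : _) = "non-empty"

-- With the empty-list pattern, each clause matches only what it says.
describe' :: [Int] -> String
describe' []      = "empty"
describe' (_ : _) = "non-empty"

main :: IO ()
main = do
  putStrLn (describe [1, 2, 3])   -- prints: matched by the variable pattern
  putStrLn (describe' [1, 2, 3])  -- prints: non-empty
```

In the question the variable clause comes last, which is why it only ever sees the empty list there.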

The issue here is simple: when you call myAdd [1,2,3,4], it expands to this:

1 + (2 + (3 + (4 + 0)))

without ever reducing any of these additions to a number, because of Haskell's non-strict semantics. The solution is simple:

myAdd = myAdd' 0
  where myAdd' !total [] = total
        myAdd' !total (x:xs) = myAdd' (total + x) xs

(You'll need {-# LANGUAGE BangPatterns #-} at the top of your source file to compile this.)

This accumulates the addition in another parameter, and is actually tail recursive (yours isn't; + is in tail position rather than myAdd). But in fact, it's not quite tail recursion we care about in Haskell; that distinction is mainly relevant in strict languages. The secret here is the bang pattern on total: it forces it to be evaluated every time myAdd' is called, so no unevaluated additions build up, and it runs in constant space. In this case, GHC can actually figure this out with -O2 thanks to its strictness analysis, but I think it's usually best to be explicit about what you want strict and what you don't.
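For comparison, here is a sketch of the same accumulator written with and without forcing, using seq in place of the BangPatterns extension. The names sumLazyAcc and sumStrictAcc are made up for illustration, and how much memory the lazy version actually uses depends on the GHC version and optimisation flags:

```haskell
-- Hypothetical names; a sketch, not the answer's exact code.

-- Lazy accumulator: tail recursive, but without optimisation each step
-- stores `total + x` as an unevaluated thunk instead of a number.
sumLazyAcc :: [Integer] -> Integer
sumLazyAcc = go 0
  where
    go total []       = total
    go total (x : xs) = go (total + x) xs

-- Strict accumulator without BangPatterns: `seq` evaluates `total`
-- before the recursive call, so at most one addition is ever pending.
sumStrictAcc :: [Integer] -> Integer
sumStrictAcc = go 0
  where
    go total []       = total
    go total (x : xs) = total `seq` go (total + x) xs

main :: IO ()
main = print (sumStrictAcc [1 .. 10000000])  -- prints 50000005000000
```

Both functions are tail recursive; only the second one keeps the accumulator evaluated, which is the property that matters for space.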

Note that if addition were lazy, your myAdd definition would work fine; the problem is that you're doing a lazy traversal of the list with a strict operation, which ends up causing the stack overflow. This mostly comes up with arithmetic, which is strict for the standard numeric types (Int, Integer, Float, Double, etc.).

This is quite ugly, and it would be a pain to write something like this every time we want to write a strict fold. Thankfully, Haskell has an abstraction ready for this!

myAdd = foldl' (+) 0

(You'll need to add import Data.List to compile this.)

foldl' (+) 0 [a, b, c, d] is just like (((0 + a) + b) + c) + d, except that at each application of (+) (which is how we refer to the binary operator + as a function value), the intermediate result is forced to be evaluated. The resulting code is cleaner, faster, and easier to read (once you know how the list folds work, you can understand any definition written in terms of them more easily than a recursive definition).
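As a sketch, the question's program rewritten around foldl' might look like the following. The explicit type signature and the scanl line are additions for illustration; scanl simply lists the successive accumulator values that a left fold passes through:

```haskell
import Data.List (foldl')

-- Strict left fold: the running total is evaluated at every step.
myAdd :: [Integer] -> Integer
myAdd = foldl' (+) 0

main :: IO ()
main = do
  -- Successive accumulator values of a left fold over [1,2,3,4].
  print (scanl (+) 0 [1, 2, 3, 4])  -- prints [0,1,3,6,10]
  let n = 10000000
  print (myAdd [1 .. n])            -- prints 50000005000000
```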

Basically, the problem here is not that the compiler can't figure out how to make your program efficient; it's that making it as efficient as you like could change its semantics, which an optimisation should never do. Haskell's non-strict semantics certainly pose a learning curve to programmers coming from more "traditional" languages like C, but it gets easier over time, and once you see the power and abstraction that Haskell's non-strictness offers, you'll never want to go back :)

ehird answered Sep 18 '22 06:09


Expanding the example ehird hinted at in the comments:

data Peano = Z | S Peano
  deriving (Eq, Show)

instance Ord Peano where
    compare (S a) (S b) = compare a b
    compare Z Z = EQ
    compare Z _ = LT
    compare _ _ = GT

instance Num Peano where
    Z + n = n
    (S a) + n = S (a + n)
    -- omit others
    fromInteger 0 = Z
    fromInteger n
        | n < 0 = error "Peano: fromInteger requires non-negative argument"
        | otherwise = S (fromInteger (n-1))

instance Enum Peano where
    succ = S
    pred (S a) = a
    pred _ = error "Peano: no predecessor"
    toEnum n
        | n < 0 = error "toEnum: invalid argument"
        | otherwise = fromInteger (toInteger n)
    fromEnum Z = 0
    fromEnum (S a) = 1 + fromEnum a
    enumFrom = iterate S
    enumFromTo a b = takeWhile (<= b) $ enumFrom a
    -- omit others

infinity :: Peano
infinity = S infinity

result :: Bool
result = 3 < myAdd [1 .. infinity]

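result is True by the definition of myAdd, because the comparison only needs the first few S constructors of the lazily built sum. If the compiler transformed myAdd into a tail-recursive loop, it would have to consume the whole infinite list first and would never terminate. So that transformation is not only a change in efficiency, but also in semantics, hence a compiler must not do it.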
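A smaller analogue of the same point, assuming Bool and (||) in place of Peano and (+). The names anyTrue and anyTrueAcc are hypothetical; the first has the shape of the question's myAdd, the second is the accumulator form a tail-call transformation would produce:

```haskell
-- Hypothetical names; a reduced analogue of the Peano example.

-- Same shape as the question's myAdd. (||) ignores its second
-- argument when the first is True, so the recursion can stop early.
anyTrue :: [Bool] -> Bool
anyTrue []       = False
anyTrue (x : xs) = x || anyTrue xs

-- Accumulator version: tail recursive, but it must reach the end of
-- the list before it can return anything.
anyTrueAcc :: [Bool] -> Bool
anyTrueAcc = go False
  where
    go acc []       = acc
    go acc (x : xs) = go (acc || x) xs

main :: IO ()
main = do
  print (anyTrue (cycle [False, True]))    -- infinite list, prints True
  print (anyTrueAcc [False, True, False])  -- finite list, prints True
  -- anyTrueAcc (cycle [False, True]) would never return.
```

The two definitions agree on every finite list and differ only on infinite ones, which is exactly the kind of difference an optimiser is not allowed to introduce.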

Daniel Fischer answered Sep 18 '22 06:09
