
Caching in Haskell and explicit parallelism

I'm currently trying to optimize my solution to problem 14 at Project Euler. I really enjoy Haskell and I think it's a very good fit for this kind of problem. Here are three different solutions I've tried:

import Data.List (unfoldr, maximumBy)
import Data.Maybe (fromJust, isNothing)
import Data.Ord (comparing)
import Control.Parallel

next :: Integer -> Maybe (Integer)
next 1 = Nothing
next n
  | even n = Just (div n 2)
  | odd n  = Just (3 * n + 1)

get_sequence :: Integer -> [Integer]
get_sequence n = n : unfoldr (pack . next) n
  where pack n = if isNothing n then Nothing else Just (fromJust n, fromJust n)

get_sequence_length :: Integer -> Integer
get_sequence_length n
    | isNothing (next n) = 1
    | otherwise = 1 + (get_sequence_length $ fromJust (next n))

-- 8 seconds
main1 = print $ maximumBy (comparing length) $ map get_sequence [1..1000000]

-- 5 seconds
main2 = print $ maximum $ map (\n -> (get_sequence_length n, n)) [1..1000000]

-- Never finishes
main3 = print solution
  where
    s1 = maximumBy (comparing length) $ map get_sequence [1..500000]
    s2 = maximumBy (comparing length) $ map get_sequence [500001..10000000]
    solution = (s1 `par` s2) `pseq` max s1 s2

Now if you look at the actual problem there's a lot of potential for caching, as most new sequences will contain subsequences that have already been calculated before.

For comparison, I wrote a version in C too:
Running time with caching: 0.03 seconds
Running time without caching: 0.3 seconds

That's just insane! Sure, caching reduced the time by a factor of 10, but even without caching it's still at least 17 times faster than my Haskell code.

What's wrong with my code? Why doesn't Haskell cache the function calls for me? Since the functions are pure, shouldn't caching be trivial, only a matter of available memory?

What's the problem with my third parallel version? Why doesn't it finish?

Regarding Haskell as a language, does the compiler automatically parallelize some code (folds, maps etc.), or does it always have to be done explicitly using Control.Parallel?

Edit: I stumbled upon this similar question. The answers there mentioned that the asker's function wasn't tail recursive. Is my get_sequence_length tail recursive? If not, how can I make it so?

Edit2:
To Daniel:
Thanks a lot for the reply, really awesome. I've been playing around with your improvements and I've found some really bad gotchas.

I'm running the tests on Windows 7 (64-bit), on a 3.3 GHz quad-core machine with 8 GB RAM.
The first thing I did was, as you suggested, replace all Integer with Int, but whenever I ran any of the mains I ran out of memory, even with +RTS kSize -RTS set ridiculously high.

Eventually I found this (Stack Overflow is awesome...), which explains that since all Haskell programs on Windows are run as 32-bit, the Ints were overflowing, causing infinite recursion. Just wow...

I ran the tests in a Linux virtual machine (with the 64-bit ghc) instead and got similar results.

user1599468 asked Aug 15 '12 00:08


1 Answer

Alright, let's start from the top. The first important thing is to give the exact command line you're using to compile and run; for my answer, I'll use this line for the timings of all programs:

ghc -O2 -threaded -rtsopts test && time ./test +RTS -N

Next up: since timings vary greatly from machine to machine, we'll give some baseline timings for my machine and your programs. Here's the output of uname -a for my computer:

Linux sorghum 3.4.4-2-ARCH #1 SMP PREEMPT Sun Jun 24 18:59:47 CEST 2012 x86_64 Intel(R) Core(TM)2 Quad CPU Q6600 @ 2.40GHz GenuineIntel GNU/Linux

The highlights are: quad-core, 2.4GHz, 64-bit.

Using main1: 30.42s user 2.61s system 149% cpu 22.025 total
Using main2: 21.42s user 1.18s system 129% cpu 17.416 total
Using main3: 22.71s user 2.02s system 220% cpu 11.237 total

Actually, I modified main3 in two ways: first, by removing one of the zeros from the end of the range in s2, and second, by changing max s1 s2 to maximumBy (comparing length) [s1, s2], since the former only accidentally computes the right answer. =)
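For reference, here is a minimal sketch of main3 with both changes applied. The helper names longest_in and longest_up_to are made up for illustration, the limit is a parameter so small inputs can be checked quickly, and par and pseq are imported from GHC.Conc in base (Control.Parallel exports the same two functions). It also uses the usual a `par` (b `pseq` ...) arrangement instead of the original grouping, which is an assumption about what was intended.

```haskell
import Data.List (maximumBy, unfoldr)
import Data.Ord (comparing)
import GHC.Conc (par, pseq) -- Control.Parallel exports these same functions

-- Collatz step; Nothing once the sequence has reached 1.
next :: Integer -> Maybe Integer
next 1 = Nothing
next n
  | even n    = Just (n `div` 2)
  | otherwise = Just (3 * n + 1)

-- Full sequence starting at n and ending in 1.
get_sequence :: Integer -> [Integer]
get_sequence n = n : unfoldr (fmap (\m -> (m, m)) . next) n

-- Longest sequence among the starting values lo..hi (hypothetical helper).
-- Assumes lo <= hi; maximumBy fails on an empty list.
longest_in :: Integer -> Integer -> [Integer]
longest_in lo hi = maximumBy (comparing length) (map get_sequence [lo .. hi])

-- Split 1..limit in half, spark the first half, evaluate the second half,
-- then compare the two winners by length rather than with max.
-- Assumes limit >= 2 so that neither half is empty.
longest_up_to :: Integer -> [Integer]
longest_up_to limit = s1 `par` (s2 `pseq` best)
  where
    mid  = limit `div` 2
    s1   = longest_in 1 mid
    s2   = longest_in (mid + 1) limit
    best = maximumBy (comparing length) [s1, s2]
```

Adding main = print (longest_up_to 1000000) and compiling with the command line above should exercise both halves; no timing is claimed for this variant.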

I'll now focus on serial speed. (To answer one of your direct questions: no, GHC does not automatically parallelize or memoize your programs. Both of those things have overheads that are very difficult to estimate, and consequently it's very difficult to decide when doing them will be beneficial. I have no idea why even the serial solutions in this answer are getting >100% CPU utilization; perhaps some garbage collection is happening in another thread or some such thing.) We'll start from main2, since it was the faster of the two serial implementations. The cheapest way to get a little boost is to change all the type signatures from Integer to Int:

Using Int: 11.17s user 0.50s system 129% cpu 8.986 total (about twice as fast)
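One caveat with this step: the Haskell report only guarantees that Int covers at least 30 bits, and the intermediate values of these sequences grow far beyond their starting points. The sketch below is one way to check, using Integer, whether a given start stays inside the 32-bit signed range. The names peak, fits_in_int32 and overflowing32 are hypothetical helpers, not part of the programs timed here.

```haskell
import Data.Int (Int32)

-- Largest value reached by the Collatz sequence starting at n,
-- computed with Integer so the check itself cannot overflow.
peak :: Integer -> Integer
peak n = maximum (takeWhile (/= 1) (iterate step n) ++ [1])
  where
    step m = if even m then m `div` 2 else 3 * m + 1

-- True when every value in the sequence fits in a 32-bit signed integer.
fits_in_int32 :: Integer -> Bool
fits_in_int32 n = peak n <= fromIntegral (maxBound :: Int32)

-- Starting values in 1..limit whose sequence leaves the 32-bit signed range.
overflowing32 :: Integer -> [Integer]
overflowing32 limit = filter (not . fits_in_int32) [1 .. limit]
```

If overflowing32 1000000 comes back non-empty, the Int version is only safe where Int is wider than 32 bits, which is the case for a 64-bit GHC.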

The next boost comes from reducing allocation in the inner loop (eliminating the intermediate Maybe values).

import Data.List
import Data.Ord

get_sequence_length :: Int -> Int
get_sequence_length 1 = 1
get_sequence_length n
    | even n = 1 + get_sequence_length (n `div` 2)
    | odd  n = 1 + get_sequence_length (3 * n + 1)

lengths :: [(Int,Int)]
lengths = map (\n -> (get_sequence_length n, n)) [1..1000000]

main = print (maximumBy (comparing fst) lengths)

Using this: 4.84s user 0.03s system 101% cpu 4.777 total

The next boost comes from using faster operations than even and div:

import Data.Bits
import Data.List
import Data.Ord

even' n = n .&. 1 == 0

get_sequence_length :: Int -> Int
get_sequence_length 1 = 1
get_sequence_length n = 1 + get_sequence_length next where
    next = if even' n then n `quot` 2 else 3 * n + 1

lengths :: [(Int,Int)]
lengths = map (\n -> (get_sequence_length n, n)) [1..1000000]

main = print (maximumBy (comparing fst) lengths)

Using this: 1.27s user 0.03s system 105% cpu 1.232 total

For those following along at home, this is about 17 times faster than the main2 that we started with, an improvement competitive with switching to C.
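On the tail-recursion question from the first edit: none of the versions so far are tail recursive, because the 1 + ... is applied after the recursive call returns. A minimal sketch of an accumulator-based variant follows. The name get_sequence_length_acc is made up for illustration, and it was not part of the timings above, so no speed claim is attached to it.

```haskell
{-# LANGUAGE BangPatterns #-}

-- Counts the terms of the Collatz sequence starting at n (including n and 1).
-- The running count is carried in a strict accumulator, so each recursive
-- call is in tail position and no chain of pending (1 +) is built up.
get_sequence_length_acc :: Int -> Int
get_sequence_length_acc = go 1
  where
    go !acc 1 = acc
    go !acc n
      | even n    = go (acc + 1) (n `quot` 2)
      | otherwise = go (acc + 1) (3 * n + 1)
```

It can be dropped into lengths in place of get_sequence_length, since it has the same type and counts the same thing.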

For memoization, there are a few choices. The simplest is to use a pre-existing package like data-memocombinators to create an immutable array and read from it. The timings are fairly sensitive to choosing a good size for this array; for this problem, I found 50000 to be a pretty good upper bound.

import Data.Bits
import Data.MemoCombinators
import Data.List
import Data.Ord

even' n = n .&. 1 == 0

pre_length :: (Int -> Int) -> (Int -> Int)
pre_length f 1 = 1
pre_length f n = 1 + f next where
    next = if even' n then n `quot` 2 else 3 * n + 1

get_sequence_length :: Int -> Int
get_sequence_length = arrayRange (1,50000) (pre_length get_sequence_length)

lengths :: [(Int,Int)]
lengths = map (\n -> (get_sequence_length n, n)) [1..1000000]

main = print (maximumBy (comparing fst) lengths)

With this: 0.53s user 0.10s system 149% cpu 0.421 total
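To make the idea concrete without the extra dependency, here is a sketch of the same technique written directly against Data.Array: a lazy immutable table for small arguments, with a fall-back to plain recursion outside it. This illustrates the approach and is not the library's implementation; mk_memo_length and collatz_next are hypothetical names, and this variant was not timed.

```haskell
import Data.Array (Array, inRange, listArray, (!))

collatz_next :: Int -> Int
collatz_next n = if even n then n `quot` 2 else 3 * n + 1

-- Builds a length function backed by a lazy table for 1..size.
-- Each table entry is a thunk that is evaluated at most once, on first use;
-- arguments outside the table are computed without caching.
mk_memo_length :: Int -> (Int -> Int)
mk_memo_length size = lookup_len
  where
    table :: Array Int Int
    table = listArray (1, size) (map compute [1 .. size])

    lookup_len n
      | inRange (1, size) n = table ! n
      | otherwise           = compute n

    compute 1 = 1
    compute n = 1 + lookup_len (collatz_next n)
```

The table is only shared between calls if the function returned by mk_memo_length is bound once and reused, for example with let len = mk_memo_length 50000 before mapping len over the inputs.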

The fastest of all is to use a mutable, unboxed array for the memoization bit. It's much less idiomatic, but it's bare-metal speed. The speed is much less sensitive to the size of this array, so long as the array is about as large as the biggest thing you want the answer for.

import Control.Monad
import Control.Monad.ST
import Data.Array.Base
import Data.Array.ST
import Data.Bits
import Data.List
import Data.Ord

even' n = n .&. 1 == 0
next  n = if even' n then n `quot` 2 else 3 * n + 1

get_sequence_length :: STUArray s Int Int -> Int -> ST s Int
get_sequence_length arr n = do
    bounds@(lo,hi) <- getBounds arr
    if not (inRange bounds n) then (+1) `fmap` get_sequence_length arr (next n) else do
        let ix = n-lo
        v <- unsafeRead arr ix
        if v > 0 then return v else do
            v' <- get_sequence_length arr (next n)
            unsafeWrite arr ix (v'+1)
            return (v'+1)

maxLength :: (Int,Int)
maxLength = runST $ do
    arr <- newArray (1,1000000) 0
    writeArray arr 1 1
    loop arr 1 1 1000000
    where
    loop arr n len 1  = return (n,len)
    loop arr n len n' = do
        len' <- get_sequence_length arr n'
        if len' > len then loop arr n' len' (n'-1) else loop arr n len (n'-1)

main = print maxLength

With this: 0.16s user 0.02s system 138% cpu 0.130 total (which is competitive with the memoized C version)

Daniel Wagner answered Sep 25 '22 06:09