
Optimizing a Haskell program

I started looking at Haskell yesterday with the goal of actually learning it. I've written some trivial programs with it in programming language courses, but none of them really cared about efficiency. I'm trying to understand how to improve the running time of the following program.

My program solves the following toy problem (I know it's simple to compute the answer by hand if you know what a factorial is, but I'm doing it the brute force way with a successor function):

http://projecteuler.net/problem=24

My algorithm for the successor function for lexicographic ordering given a list of finite length is the following:

  1. If the list is already in decreasing order, then we have the maximal element in the lexicographic ordering, so there's no successor.

  2. Given a list h : t, either t is maximal in the lexicographic ordering or it's not. In the latter case compute the successor of t. In the former case proceed as follows.

  3. Pick the smallest element d in t larger than h.

  4. Replace d with h in t giving a new list t'. The next element in the ordering is d : (sort t')

My program that implements this is the following (lots of these functions are probably in the standard library):

max_list :: (Ord a) => [a] -> a
max_list []     = error "Empty list has no maximum!"
max_list (h:[]) = h
max_list (h:t)  = max h (max_list t)

min_list :: (Ord a) => [a] -> a
min_list []     = error "Empty list has no minimum!"
min_list (h:[]) = h
min_list (h:t)  = min h (min_list t)

-- replaces first occurrence of x in list with y
replace :: (Eq a) => a -> a -> [a] -> [a]
replace _ _ []  = []
replace x y (h:t)
    | h == x    = y : t
    | otherwise = h : (replace x y t)

-- sort in increasing order
sort_list :: (Ord a) => [a] -> [a]
sort_list []    = []
sort_list (h:t) = (sort_list (filter (\x -> x <= h) t))
               ++ [h]
               ++ (sort_list (filter (\x -> x > h) t))

-- checks if list is in descending order
descending :: (Ord a) => [a] -> Bool
descending []     = True
descending (h:[]) = True
descending (h:t)
    | h > (max_list t) = descending t
    | otherwise        = False

succ_list :: (Ord a) => [a] -> [a]
succ_list []      = []
succ_list (h:[])  = [h]
succ_list (h:t)
    | descending (h:t)   = (h:t)
    | not (descending t) = h : succ_list t
    | otherwise = next_h : sort_list (replace next_h h t)
    where next_h = min_list (filter (\x -> x > h) t)

-- apply function n times
apply_times :: (Integral n) => n -> (a -> a) -> a -> a
apply_times n _ a
    | n <= 0      = a
apply_times n f a = apply_times (n-1) f (f a)

main = putStrLn (show (apply_times 999999 succ_list [0,1,2,3,4,5,6,7,8,9]))

Now the actual question. After noticing that my program took a while to run, I wrote an equivalent C program for comparison. My guess is that the lazy evaluation of Haskell causes the apply_times function to build a huge list in memory before it actually starts evaluating the result. I had to increase the runtime stack size for it to run. Since efficient Haskell programming seems to be about tricks, are there any nice tricks that could be used to minimize memory consumption? What about ways to minimize copying and garbage collection, since lists keep getting created over and over while a C implementation would do everything in place?

Since Haskell is supposedly efficient, I guess there has to be a way? One cool thing that I have to say about Haskell though is that the program worked correctly the first time it compiled, so that part of the language does seem to fulfill its promise.

Edvard Fagerholm asked Jan 02 '13 04:01


1 Answer

lots of these functions are probably in the standard library

Indeed. Importing Data.List makes sort available, and maximum and minimum are available from the Prelude. The sort from Data.List is, all in all, more efficient than the quasi-quicksort, in particular since you have a lot of sorted chunks in the lists here.
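As a rough sketch of what that looks like, here is the question's successor function with the hand-written helpers swapped for sort, maximum and minimum. The names succList and replaceFirst are only illustrative renamings of succ_list and replace, and the code assumes the list elements are distinct, as they are in the problem.

```haskell
import Data.List (sort)

-- Replace the first occurrence of x with y (the question's `replace`).
replaceFirst :: Eq a => a -> a -> [a] -> [a]
replaceFirst _ _ [] = []
replaceFirst x y (h:t)
    | h == x    = y : t
    | otherwise = h : replaceFirst x y t

-- Same (still quadratic) check as in the question, using Prelude's maximum.
descending :: Ord a => [a] -> Bool
descending (h:t@(_:_)) = h > maximum t && descending t
descending _           = True

-- Successor in lexicographic order; the maximal list is returned unchanged.
succList :: Ord a => [a] -> [a]
succList (h:t@(_:_))
    | descending (h:t)   = h : t
    | not (descending t) = h : succList t
    | otherwise          = nextH : sort (replaceFirst nextH h t)
  where
    -- smallest element of the tail that is larger than the head
    nextH = minimum (filter (> h) t)
succList xs = xs
```

Only the helpers change here; the algorithm itself is the one from the question.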

descending :: (Ord a) => [a] -> Bool
descending []     = True
descending (h:[]) = True
descending (h:t)
    | h > (max_list t) = descending t
    | otherwise        = False

is inefficient - O(n²) - since it traverses the entire remaining tail in each step, even though, if the list is descending, the maximum of the tail must be its head. That inefficiency has a nice consequence here, though: it prevents the build-up of thunks, since the first guard of the third equation of succ_list forces the list to be completely evaluated. However, that could be done more efficiently by explicitly forcing the list once. Replacing the third equation with

descending (h:t@(ht:_)) = h > ht && descending t

would make it linear.
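Written out in full, a linear version could look like the following sketch. Folding the empty and singleton cases into one catch-all equation is my own choice here, not something the original code requires.

```haskell
-- Linear check: each element is compared only with its immediate successor.
descending :: Ord a => [a] -> Bool
descending (h:t@(ht:_)) = h > ht && descending t
descending _            = True   -- empty and one-element lists
```

Because > is transitive, comparing neighbours is enough to establish that every element is larger than everything after it.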

After noticing that my program took a while to run, I wrote an equivalent C program for comparison.

That would be unusual. Few would even go so far as to use a linked list in C, and implementing lazy evaluation on top of that would be quite an undertaking.

Writing an equivalent programme in C would be extremely unidiomatic. In C, the natural way to implement the algorithm would use an array and in-place mutation. That is automatically much more efficient here.

My guess is that the lazy evaluation of Haskell causes the apply_times function to build a huge list in memory before it actually starts evaluating the result.

Not quite. What it builds is a huge thunk:

apply_times 999999 succ_list [0,1,2,3,4,5,6,7,8,9]
~> apply_times 999998 succ_list (succ_list [0 .. 9])
~> apply_times 999997 succ_list (succ_list (succ_list [0 .. 9]))
~> apply_times 999996 succ_list (succ_list (succ_list (succ_list [0 .. 9])))
...
succ_list (succ_list (succ_list ... (succ_list [0 .. 9])...))

and, after that thunk has been built, it must be evaluated. To evaluate the outermost call, the next one must be evaluated far enough to find out which pattern matches in the outermost call. So the outermost call is pushed on a stack, and evaluation of the next call begins. For that, it must be determined which pattern matches, so part of the result of the third call is needed. Thus the second call is pushed on the stack, and so on. At the end, you have 999998 calls on the stack and start to evaluate the innermost call. Then you play a bit of ping-pong between each call and the next outer call (at least; the dependencies might spread a bit further) while bubbling up and popping calls from the stack.

are there any nice tricks that could be used to minimize memory consumption

Yes, force the intermediate lists to be evaluated before they become the argument of apply_times. You need complete evaluation here, so the vanilla seq, which only evaluates to the outermost constructor, is not good enough:

import Control.DeepSeq

apply_times' :: (NFData a, Integral n) => n -> (a -> a) -> a -> a
apply_times' 0 _ x = x
apply_times' k f x = apply_times' (k-1) f $!! f x

That prevents the build-up of thunks, and thus you don't need more memory than for a few short lists constructed in succ_list, and the counter.
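If you would rather not depend on the deepseq package, the same effect can be had for lists of simple elements with a small hand-written helper. This is only a sketch: forceList and applyTimesStrict are hypothetical names, and forceList evaluates each element only to weak head normal form, which amounts to full evaluation for Int elements but not for nested structures.

```haskell
-- Walk the whole spine and evaluate every element to WHNF, then return the list.
forceList :: [a] -> [a]
forceList xs = foldr seq () xs `seq` xs

-- Apply f k times, forcing each intermediate list before the next iteration.
applyTimesStrict :: Int -> ([a] -> [a]) -> [a] -> [a]
applyTimesStrict k f xs
    | k <= 0    = xs
    | otherwise = applyTimesStrict (k - 1) f $! forceList (f xs)
```

The $! matters: without it, forceList (f xs) would itself be passed along as an unevaluated thunk and nothing would be gained.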

What about ways to minimize copying and garbage collection, since lists keep getting created over and over while a C implementation would do everything in place.

Right, that would still allocate (and garbage collect) a lot. Now, GHC is very good at allocating and garbage collecting short-lived data (on my box, it can easily allocate at a rate of 2GB per MUT second without being slow), but still, not allocating all those lists would be faster.

So, if you want to push it, use in-place mutation. Work on an

STUArray s Int Int

or an unboxed mutable Vector. I prefer the interface provided by the array package, but most people prefer the vector interface. In terms of performance, the vector package has a lot of optimisations built in, whereas with the array package you have to write the fast code yourself; well-written code performs equally well with either for all practical purposes.


I've done a bit of testing now. I have not tested the original lazy apply_times, only the one deepseqing each application of f, and have fixed the type of all involved entities as Int.

With that set-up, replacing sort_list with Data.List.sort reduced the running time from 1.82 seconds to 1.65 (but increased the number of allocated bytes). Not too much of a difference, but the lists are not long enough to make the bad cases for the quasi-quicksort really bite.

The big difference then comes from changing descending as proposed. That brought the time down to 0.48 seconds, with an alloc rate of 2,170,566,037 bytes per MUT second and 0.01 seconds GC time (and then using sort_list instead of sort brings the time up to 0.58 seconds).

Replacing the sorting of the ending segment of the list with a simpler reverse - the algorithm guarantees that the segment is in descending order at the point where it gets sorted - brings the time down to 0.43 seconds.
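To make that last step concrete, here is a sketch of the list version with reverse in place of the sort. The names succListRev and replaceFirst are illustrative, descending is the linear variant, and distinct elements are assumed. Putting h where nextH used to be keeps the tail descending, because everything before that position is larger than nextH and everything after it is smaller than h.

```haskell
-- Linear check for strictly descending order.
descending :: Ord a => [a] -> Bool
descending (h:t@(ht:_)) = h > ht && descending t
descending _            = True

-- Replace the first occurrence of x with y.
replaceFirst :: Eq a => a -> a -> [a] -> [a]
replaceFirst _ _ [] = []
replaceFirst x y (h:t)
    | h == x    = y : t
    | otherwise = h : replaceFirst x y t

-- Successor in lexicographic order; the maximal list is returned unchanged.
succListRev :: Ord a => [a] -> [a]
succListRev (h:t@(_:_))
    | descending (h:t)   = h : t
    | not (descending t) = h : succListRev t
      -- the tail is descending here, so reversing it sorts it ascending
    | otherwise          = nextH : reverse (replaceFirst nextH h t)
  where
    nextH = minimum (filter (> h) t)
succListRev xs = xs
```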

A fairly direct translation of the algorithm to use unboxed mutable arrays,

{-# LANGUAGE BangPatterns #-}
module Main (main) where

import Data.Array.ST
import Data.Array.Base
import Control.Monad.ST
import Control.Monad (when, replicateM_)

sortPart :: STUArray s Int Int -> Int -> Int -> ST s ()
sortPart a lo hi
   | lo < hi   = do
       let lscan !p h i
               | i < h = do
                   v <- unsafeRead a i
                   if p < v then return i else lscan p h (i+1)
               | otherwise = return i
           rscan !p l i
               | l < i = do
                   v <- unsafeRead a i
                   if v < p then return i else rscan p l (i-1)
               | otherwise = return i
           swap i j = do
               v <- unsafeRead a i
               unsafeRead a j >>= unsafeWrite a i
               unsafeWrite a j v
           sloop !p l h
               | l < h = do
                   l1 <- lscan p h l
                   h1 <- rscan p l1 h
                   if (l1 < h1) then (swap l1 h1 >> sloop p l1 h1) else return l1
               | otherwise = return l
       piv <- unsafeRead a hi
       i <- sloop piv lo hi
       swap i hi
       sortPart a lo (i-1)
       sortPart a (i+1) hi
   | otherwise = return ()

descending :: STUArray s Int Int -> Int -> Int -> ST s Bool
descending arr lo hi
    | lo < hi   = do
        let check i !v
                | hi < i    = return True
                | otherwise = do
                    w <- unsafeRead arr i
                    if w < v
                      then check (i+1) w
                      else return False
        x <- unsafeRead arr lo
        check (lo+1) x
    | otherwise = return True

findAndReplace :: STUArray s Int Int -> Int -> Int -> ST s ()
findAndReplace arr lo hi
    | lo < hi   = do
        x <- unsafeRead arr lo
        let go !mi !mv i
                | hi < i    = when (lo < mi) $ unsafeWrite arr mi x >> unsafeWrite arr lo mv
                | otherwise = do
                    w <- unsafeRead arr i
                    if x < w && w < mv
                      then go i w (i+1)
                      else go mi mv (i+1)
            look i
                | hi < i    = return ()
                | otherwise = do
                    w <- unsafeRead arr i
                    if x < w
                      then go i w (i+1)
                      else look (i+1)
        look (lo+1)
    | otherwise = return ()

succArr :: STUArray s Int Int -> Int -> Int -> ST s ()
succArr arr lo hi
    | lo < hi   = do
        end <- descending arr lo hi
        if end
          then return ()
          else do
              needSwap <- descending arr (lo+1) hi
              if needSwap
                then do
                    findAndReplace arr lo hi
                    sortPart arr (lo+1) hi
                else succArr arr (lo+1) hi
    | otherwise = return ()

solution :: [Int]
solution = runST $ do
    arr <- newListArray (0,9) [0 .. 9]
    replicateM_ 999999 $ succArr arr 0 9
    getElems arr

main :: IO ()
main = print solution

completes in 0.15 seconds. Replacing the sorting with a simpler reversing of the part brings it down to 0.11.
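The reversing variant is not shown above; a minimal sketch of what could stand in for sortPart is below. The name reversePart and the little demo are mine, and the sketch uses the bounds-checked readArray and writeArray rather than the unsafe variants, so it is not the exact code that was timed.

```haskell
import Control.Monad.ST (ST, runST)
import Data.Array.ST (STUArray, newListArray, getElems, readArray, writeArray)

-- Reverse the segment lo..hi (inclusive) in place by swapping from both ends.
reversePart :: STUArray s Int Int -> Int -> Int -> ST s ()
reversePart arr lo hi
    | lo < hi   = do
        x <- readArray arr lo
        y <- readArray arr hi
        writeArray arr lo y
        writeArray arr hi x
        reversePart arr (lo + 1) (hi - 1)
    | otherwise = return ()

-- Small demonstration: reverse the descending segment at indices 1..4.
demo :: [Int]
demo = runST $ do
    arr <- newListArray (0, 4) [7, 9, 5, 3, 1]
    reversePart arr 1 4
    getElems arr
```

In succArr, the call sortPart arr (lo+1) hi would then become reversePart arr (lo+1) hi.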

Splitting the algorithm into small top-level functions that each do one task makes it more readable, but that comes at a price. More parameters need to be passed between the functions, consequently not all can be passed in registers, and some of the passed parameters - the array bounds and element count - are not used at all, so that's dead weight being passed. Making all other functions local functions in solution reduces the overall allocation and running time somewhat (0.13 seconds with sorting, 0.09 with reversing), since now only the necessary parameters need to be passed.

Deviating further from the given algorithm and making it work back to front,

module Main (main) where

import Data.Array.ST
import Data.Array.Base
import Data.Array.Unboxed
import Control.Monad.ST
import Control.Monad (when)
import Data.Bits

lexPerm :: Int -> Int -> [Int]
lexPerm idx num = elems (runSTUArray $ do
    arr <- unsafeNewArray_ (0,num)
    let fill i
            | num < i   = return ()
            | otherwise = unsafeWrite arr i i >> fill (i+1)
        swap i j = do
            x <- unsafeRead arr i
            y <- unsafeRead arr j
            unsafeWrite arr j x
            unsafeWrite arr i y
        flop i j
            | i < j     = do
                swap i j
                flop (i+1) (j-1)
            | otherwise = return ()
        binsearch v a b = go a b
          where
            go i j
              | i < j     = do
                let m = (i+j+1) `unsafeShiftR` 1
                w <- unsafeRead arr m
                if w < v
                  then go i (m-1)
                  else go m j
              | otherwise = swap a i
        upstep k j
            | k < 1     = return ()
            | j == num-1 = unsafeRead arr num >>= flip (back k) (num-1)
            | otherwise  = nextP k (num-1)
        back k v i
            | i < 0     = return ()
            | otherwise = do
                w <- unsafeRead arr i
                if w < v
                  then nextP k i
                  else back k w (i-1)
        nextP k up
            | k < 1 || up < 0   = return ()
            | otherwise = do
                v <- unsafeRead arr up
                binsearch v up num
                flop (up+1) num
                upstep (k-1) up
    fill 0
    nextP (idx-1) (num-1)
    return arr)

main :: IO ()
main = print $ lexPerm 1000000 9

we can complete the task in 0.02 seconds.

The clever algorithm alluded to in the question, however, solves the task with far less code in much less time.
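For completeness, one way to write that direct computation is sketched below. The name nthPerm is hypothetical, the index is 0-based, and the input is assumed to be a sorted list of distinct elements with the index smaller than the number of permutations.

```haskell
import Data.List (delete)

-- n-th (0-based) permutation in lexicographic order of a sorted list of
-- distinct elements. Each choice of first element accounts for (len-1)!
-- permutations, so dividing n by that block size selects the first element
-- and the remainder is the index within that block.
nthPerm :: Eq a => Int -> [a] -> [a]
nthPerm _ [] = []
nthPerm n xs = x : nthPerm r (delete x xs)
  where
    block  = product [1 .. length xs - 1]
    (q, r) = n `quotRem` block
    x      = xs !! q
```

The millionth permutation has index 999999, so nthPerm 999999 [0 .. 9] gives [2,7,8,3,9,1,5,4,6,0].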

Daniel Fischer answered Oct 12 '22 04:10