Using dynamic programming in Haskell? [Warning: ProjectEuler 31 solution inside]

In solving projecteuler.net's problem #31 [SPOILERS AHEAD] (counting the number of ways to make £2 with British coins), I wanted to use dynamic programming. I started with OCaml and wrote the following short and very efficient program:

open Num

let make_dyn_table amount coins =
  let t = Array.make_matrix (Array.length coins) (amount+1) (Int 1) in
  for i = 1 to (Array.length t) - 1 do
    for j = 0 to amount do
      if j < coins.(i) then
        t.(i).(j) <- t.(i-1).(j)
      else
        t.(i).(j) <- t.(i-1).(j) +/ t.(i).(j - coins.(i))
    done
  done;
  t

let _ =
  let t = make_dyn_table 200 [|1;2;5;10;20;50;100;200|] in
  let last_row = Array.length t - 1 in
  let last_col = Array.length t.(last_row) - 1 in
  Printf.printf "%s\n" (string_of_num (t.(last_row).(last_col)))
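For reference, the table that make_dyn_table fills can be written as a recurrence. Here c_i is the i-th coin, A the target amount and n the number of coins, and the answer is t_{n-1,A}. Row 0 is simply left at 1. That is only correct because the first coin is 1: with a first coin of 2, for example, odd amounts would be miscounted.

```latex
t_{0,j} = 1 \quad (0 \le j \le A), \qquad
t_{i,j} =
\begin{cases}
  t_{i-1,j}                 & \text{if } j < c_i, \\
  t_{i-1,j} + t_{i,\,j-c_i} & \text{otherwise,}
\end{cases}
\qquad 1 \le i \le n-1.
```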

This executes in ~8ms on my laptop. If I increase the amount from 200 pence to one million, the program still finds an answer in less than two seconds.

I translated the program to Haskell (which was definitely not fun in itself), and though it terminates with the right answer for 200 pence, if I increase that number to 10000, my laptop comes to a screeching halt (lots of thrashing). Here's the code:

import Data.Array

createDynTable :: Int -> Array Int Int -> Array (Int, Int) Int
createDynTable amount coins =
    let numCoins = (snd . bounds) coins
        t = array ((0, 0), (numCoins, amount))
            [((i, j), 1) | i <- [0 .. numCoins], j <- [0 .. amount]]
    in t

populateDynTable :: Array (Int, Int) Int -> Array Int Int -> Array (Int, Int) Int
populateDynTable t coins =
    go t 1 0
        where go t i j
                 | i > maxX = t
                 | j > maxY = go t (i+1) 0
                 | j < coins ! i = go (t // [((i, j), t ! (i-1, j))]) i (j+1)
                 | otherwise = go (t // [((i, j), t!(i-1,j) + t!(i, j - coins!i))]) i (j+1)
              ((_, _), (maxX, maxY)) = bounds t

changeCombinations amount coins =
    let coinsArray = listArray (0, length coins - 1) coins
        dynTable = createDynTable amount coinsArray
        dynTable' = populateDynTable dynTable coinsArray
        ((_, _), (i, j)) = bounds dynTable
    in
      dynTable' ! (i, j)

main =
    print $ changeCombinations 200 [1,2,5,10,20,50,100,200]

I'd love to hear from somebody who knows Haskell well why the performance of this solution is so bad.

asked Dec 14 '12 at 01:12 by gnuvince


1 Answer

Haskell is pure. The purity means that values are immutable, and thus in the step

j < coins ! i = go (t // [((i, j), t ! (i-1, j))]) i (j+1)

you create an entire new array for each entry you update. That's already very expensive for a small amount like £2, but it becomes utterly obscene for an amount of £100.
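Here is a minimal illustration of that copying. The array and values are arbitrary. The (//) operator builds and returns a fresh array and leaves the original untouched, so a single-cell update costs time proportional to the size of the whole array. Filling an m-by-n table one cell at a time this way therefore does work on the order of (m*n)^2.

There is a second cost as well. The stored values, such as t ! (i-1, j), are lazy, so each new copy can keep references to older copies alive until those values are forced. That is consistent with the thrashing described in the question.

```haskell
import Data.Array (Array, listArray, elems, (//))

-- An immutable five-element array.
original :: Array Int Int
original = listArray (0, 4) [1, 1, 1, 1, 1]

-- (//) never mutates: it allocates and returns a fresh copy with the
-- update applied, leaving `original` exactly as it was.
updated :: Array Int Int
updated = original // [(2, 42)]

main :: IO ()
main = do
  print (elems original)  -- [1,1,1,1,1]
  print (elems updated)   -- [1,1,42,1,1]
```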

Furthermore, the arrays are boxed, which means they contain pointers to the entries. This worsens locality and uses more storage. It also allows thunks to build up, and those are slower to evaluate when they are finally forced.
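Laziness is also what lets a boxed array be defined in terms of itself. The sketch below fills the same table without any updates: every cell is a thunk that refers to earlier cells. This pure lazy-array technique is a swap-in and is not the approach taken in this answer, and changeLazy is a hypothetical name. Unlike the OCaml code, row 0 here does not assume that the first coin is 1.

It gives correct results for small amounts. However, each cell is a heap-allocated thunk, and forcing the last one evaluates a long chain of them. That is exactly the overhead just described, so it is not expected to scale like a mutable version.

```haskell
import Data.Array (Array, array, listArray, (!))

-- Pure DP table built lazily: `array` is lazy in its elements, so cells
-- may refer to other cells of the same table (hypothetical helper).
changeLazy :: Int -> [Int] -> Integer
changeLazy amount coins = table ! (n - 1, amount)
  where
    n = length coins
    cs :: Array Int Int
    cs = listArray (0, n - 1) coins
    table :: Array (Int, Int) Integer
    table = array ((0, 0), (n - 1, amount))
      [ ((i, j), cell i j) | i <- [0 .. n - 1], j <- [0 .. amount] ]
    cell :: Int -> Int -> Integer
    cell i j
      | j == 0     = 1                                   -- one way to make 0
      | i == 0     = if j `mod` (cs ! 0) == 0 then 1 else 0
      | j < cs ! i = table ! (i - 1, j)                  -- coin i too large
      | otherwise  = table ! (i - 1, j) + table ! (i, j - cs ! i)

main :: IO ()
main = print (changeLazy 200 [1, 2, 5, 10, 20, 50, 100, 200])
```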

The algorithm used depends on a mutable data structure for its efficiency, but the mutability is confined to the computation. So we can use what Haskell provides for safely encapsulated computations with temporarily mutable data: the ST state transformer monad family and its associated arrays (unboxed, for efficiency).
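The following is a minimal sketch of the ST idea in isolation. prefixSums is a hypothetical example and is not part of the change-counting code.

runSTUArray from Data.Array.ST runs a computation that allocates and mutates an STUArray, then returns it frozen as an immutable UArray. The type parameter s prevents the mutable array from escaping. As a result, prefixSums is an ordinary pure function from the caller's point of view.

```haskell
import Control.Monad (forM_)
import Data.Array.ST (newArray, readArray, writeArray, runSTUArray)
import Data.Array.Unboxed (UArray, elems)

-- arr!k holds the sum of the first k elements; the mutation is local.
prefixSums :: [Int] -> UArray Int Int
prefixSums xs = runSTUArray $ do
  arr <- newArray (0, length xs) 0
  forM_ (zip [1 ..] xs) $ \(k, x) -> do
    prev <- readArray arr (k - 1)
    writeArray arr k (prev + x)
  return arr

main :: IO ()
main = print (elems (prefixSums [1, 2, 3, 4]))  -- [0,1,3,6,10]
```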

Give me half an hour or so to translate the algorithm into code using STUArrays, and you'll get a Haskell version that is not too ugly and ought to perform comparably to the O'Caml version. Some more or less constant factor difference is to be expected; whether it's larger or smaller than 1, I don't know.

Here it is:

module Main (main) where

import System.Environment (getArgs)

import Data.Array.ST
import Control.Monad.ST
import Data.Array.Unboxed

standardCoins :: [Int]
standardCoins = [1,2,5,10,20,50,100,200]

changeCombinations :: Int -> [Int] -> Int
changeCombinations amount coins = runST $ do
    let coinBound = length coins - 1
        coinsArray :: UArray Int Int
        coinsArray = listArray (0, coinBound) coins
    table <- newArray((0,0),(coinBound, amount)) 1 :: ST s (STUArray s (Int,Int) Int)
    let go i j
            | i > coinBound = readArray table (coinBound,amount)
            | j > amount   = go (i+1) 0
            | j < coinsArray ! i = do
                v <- readArray table (i-1,j)
                writeArray table (i,j) v
                go i (j+1)
            | otherwise = do
                v <- readArray table (i-1,j)
                w <- readArray table (i, j - coinsArray!i)
                writeArray table (i,j) (v+w)
                go i (j+1)
    go 1 0

main :: IO ()
main = do
    args <- getArgs
    let amount = case args of
                   a:_ -> read a
                   _   -> 200
    print $ changeCombinations amount standardCoins

runs in not too shabby time,

$ time ./mutArr
73682

real    0m0.002s
user    0m0.000s
sys     0m0.001s
$ time ./mutArr 1000000
986687212143813985

real    0m0.439s
user    0m0.128s
sys     0m0.310s

and it uses checked array accesses; with unchecked accesses, the time could be reduced somewhat.
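Here is a sketch of what unchecked access could look like, using the unsafeRead and unsafeWrite methods exported by Data.Array.Base. Both take a zero-based linear offset instead of an (i, j) index. This version therefore stores the table as a flat, row-major STUArray; changeCombinationsUnsafe is a hypothetical name.

Like the code above, it leaves row 0 at 1 and therefore assumes that the first coin is 1. An out-of-range offset is not caught, so this trades safety for speed. No timing is claimed here.

```haskell
import Control.Monad (forM_)
import Control.Monad.ST (ST, runST)
import Data.Array.ST (STUArray, newArray)
import Data.Array.Base (unsafeRead, unsafeWrite)

-- Same DP as above, but cell (i, j) lives at offset i * (amount + 1) + j
-- and no bounds checks are performed (assumes coins start with 1).
changeCombinationsUnsafe :: Int -> [Int] -> Int
changeCombinationsUnsafe amount coins = runST $ do
  let numCoins = length coins
      width    = amount + 1
      idx i j  = i * width + j
  table <- newArray (0, numCoins * width - 1) 1 :: ST s (STUArray s Int Int)
  forM_ (zip [1 ..] (drop 1 coins)) $ \(i, c) ->
    forM_ [0 .. amount] $ \j -> do
      v <- unsafeRead table (idx (i - 1) j)
      if j < c
        then unsafeWrite table (idx i j) v
        else do
          w <- unsafeRead table (idx i (j - c))
          unsafeWrite table (idx i j) (v + w)
  unsafeRead table (idx (numCoins - 1) amount)

main :: IO ()
main = print (changeCombinationsUnsafe 200 [1, 2, 5, 10, 20, 50, 100, 200])
```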


Ah, I just learned that your O'Caml code uses arbitrary-precision integers, so using Int in Haskell puts O'Caml at an unfair disadvantage. The changes necessary to calculate the results with arbitrary-precision Integers are minimal,

$ diff mutArr.hs mutArrIgr.hs
12c12
< changeCombinations :: Int -> [Int] -> Int
---
> changeCombinations :: Int -> [Int] -> Integer
17c17
<     table <- newArray((0,0),(coinBound, amount)) 1 :: ST s (STUArray s (Int,Int) Int)
---
>     table <- newArray((0,0),(coinBound, amount)) 1 :: ST s (STArray s (Int,Int) Integer)
28c28
<                 writeArray table (i,j) (v+w)
---
>                 writeArray table (i,j) $! (v+w)

only two type signatures needed to be adapted. Since the array necessarily becomes boxed, we also need to make sure we're not writing thunks to the array in line 28, and

$ time ./mutArrIgr 
73682

real    0m0.002s
user    0m0.000s
sys     0m0.002s
$ time ./mutArrIgr 1000000
99341140660285639188927260001

real    0m1.314s
user    0m1.157s
sys     0m0.156s

the computation with the large result that overflowed for Ints takes noticeably longer but, as expected, is comparable to the O'Caml version.


Having spent some time understanding the O'Caml, I can offer a closer, slightly shorter, and arguably nicer translation:

module Main (main) where

import System.Environment (getArgs)

import Data.Array.ST
import Control.Monad.ST
import Data.Array.Unboxed
import Control.Monad (forM_)

standardCoins :: [Int]
standardCoins = [1,2,5,10,20,50,100,200]

changeCombinations :: Int -> [Int] -> Integer
changeCombinations amount coins = runST $ do
    let coinBound = length coins - 1
        coinsArray :: UArray Int Int
        coinsArray = listArray (0, coinBound) coins
    table <- newArray((0,0),(coinBound, amount)) 1 :: ST s (STArray s (Int,Int) Integer)
    forM_ [1 .. coinBound] $ \i ->
        forM_ [0 .. amount] $ \j ->
            if j < coinsArray!i
              then do
                  v <- readArray table (i-1,j)
                  writeArray table (i,j) v
              else do
                v <- readArray table (i-1,j)
                w <- readArray table (i, j - coinsArray!i)
                writeArray table (i,j) $! (v+w)
    readArray table (coinBound,amount)

main :: IO ()
main = do
    args <- getArgs
    let amount = case args of
                   a:_ -> read a
                   _   -> 200
    print $ changeCombinations amount standardCoins

that runs about equally fast:

$ time ./mutArrIgrM 1000000
99341140660285639188927260001

real    0m1.440s
user    0m1.273s
sys     0m0.164s
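Row i only reads row i-1 at the same column and row i itself at smaller columns. The whole table can therefore be collapsed into a single row that is updated in place, which reduces memory from O(coins * amount) to O(amount).

The following is a sketch of that variant; it is not the answer's code, and changeCombinations1D is a hypothetical name. It assumes positive coin values and a non-negative amount. Unlike the 2D versions, it does not need the first coin to be 1.

```haskell
import Control.Monad (forM_)
import Control.Monad.ST (ST, runST)
import Data.Array.ST (STArray, newArray, readArray, writeArray)

-- ways!j holds the number of ways to make j with the coins processed so far.
changeCombinations1D :: Int -> [Int] -> Integer
changeCombinations1D amount coins = runST $ do
  ways <- newArray (0, amount) 0 :: ST s (STArray s Int Integer)
  writeArray ways 0 1                   -- one way to make 0: use no coins
  forM_ coins $ \c ->
    forM_ [c .. amount] $ \j -> do      -- ascending j lets coin c be reused
      old  <- readArray ways j
      prev <- readArray ways (j - c)
      writeArray ways j $! old + prev   -- force the sum so no thunks are stored
  readArray ways amount

main :: IO ()
main = print (changeCombinations1D 200 [1, 2, 5, 10, 20, 50, 100, 200])
```

Iterating j upward is what lets each coin be used any number of times. Iterating downward would instead count combinations that use each coin at most once.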
answered Nov 03 '22 at 15:11 by Daniel Fischer