
Implementing a memoization function in Haskell

I'm fairly new to Haskell, and I'm trying to implement a basic memoization function that uses a Data.Map to store computed values. My example is Project Euler Problem 15, which asks for the number of possible paths from one corner of a 20x20 grid to the opposite corner.

This is what I have so far. I haven't tried compiling yet because I know it won't compile. I'll explain below.

import qualified Data.Map as Map

main = print getProblem15Value

getProblem15Value :: Integer
getProblem15Value = getNumberOfPaths 20 20

getNumberOfPaths :: Integer -> Integer -> Integer
getNumberOfPaths x y = memoize getNumberOfPaths' (x,y)
  where getNumberOfPaths' mem (0,_) = 1
        getNumberOfPaths' mem (_,0) = 1
        getNumberOfPaths' mem (x,y) = (mem (x-1,y)) + (mem (x,y-1))

memoize :: Ord a => ((a -> b) -> a -> b) -> a -> b
memoize func x = fst $ memoize' Map.empty func x
    where memoize' map func' x' = case (Map.lookup x' map) of (Just y) -> (y, map)
                                                              Nothing -> (y', map'')
           where y' = func' mem x'
                 mem x'' = y''
                 (y'', map') = memoize' map func' x''
                 map'' = Map.insert x' y' map'

So basically, the way I have this structured is that memoize is a combinator (by my understanding). The memoization works because memoize provides a function (in this case getNumberOfPaths') with a function to call (mem) for recursion, instead of having getNumberOfPaths' call itself, which would remove the memoization after the first iteration.

My implementation of memoize takes a function (in this case getNumberOfPaths') and an initial value (in this case a tuple (x,y) representing the number of grid cell distances from the other corner of the grid). It calls memoize' which has the same structure, but includes an empty Map to hold values, and returns a tuple containing the return value and a new computed Map. memoize' does a map lookup and returns the value and the original map if there is a value present. If there is no value present, it returns the computed value and a new map.

This is where my algorithm breaks down. To compute the new value, I call func' (getNumberOfPaths') with mem and x'. mem simply returns y'', where y'' is contained in the result of calling memoize' again. memoize' also returns a new map, to which we then add the new value and use as the return value of memoize'.

The issue here is that the line (y'', map') = memoize' map func' x'' should be under mem because it's dependent on x'', which is a parameter of mem. I can certainly do that, but then I will lose the map' value, which I need because it contains memoized values from intermediate computations. However, I don't want to introduce the Map into the return value of mem because then the function passed to memoize will have to handle the Map.

Sorry if that sounded confusing. A lot of this ultra-high-order functional stuff is confusing to me.

I'm sure that there is a way to do this. What I want is a generic memoize function that allows recursive calling exactly like in the definition of getNumberOfPaths, where the computation logic doesn't have to care exactly how the memoization is done.

asked May 18 '26 05:05 by jchitel



1 Answer

Provided your inputs are small enough, you can allocate the memo table as an Array instead of a Map. The array has a slot for every result up front, but each element is calculated lazily, only when it is first demanded:

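import Data.Array ((!), array)

-- w and h are the table dimensions (lattice points per side), so a
-- 20x20 grid of cells corresponds to numPaths 21 21.
numPaths :: Integer -> Integer -> Integer
numPaths w h = get (w - 1) (h - 1)
  where
    -- One lazily evaluated cell per (x, y), stored in row-major order.
    table = array (0, w * h - 1)
      [ (y * w + x, go x y)
      | y <- [0 .. h - 1]
      , x <- [0 .. w - 1]
      ]

    -- Look up a cell; forcing it evaluates go for that cell at most once.
    get x y = table ! fromInteger (y * w + x)

    go 0 _ = 1
    go _ 0 = 1
    go x y = get (x - 1) y + get x (y - 1)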
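As a variation on the same lazy-table idea, Data.Array also accepts tuples as indices, which avoids the manual y * w + x arithmetic. The following is only a minimal sketch: gridPaths is a hypothetical name, and it assumes w and h count grid cells (as in the question), so the table spans the lattice points (0, 0) to (w, h).

```haskell
import Data.Array (Array, array, (!))

-- Hypothetical name. Counts the paths across a grid of w-by-h cells,
-- i.e. from lattice point (0, 0) to (w, h), moving one step at a time
-- along either axis.
gridPaths :: Int -> Int -> Integer
gridPaths w h = table ! (w, h)
  where
    -- Every cell is a thunk; it is evaluated at most once, on first use.
    table :: Array (Int, Int) Integer
    table = array ((0, 0), (w, h))
      [ ((x, y), go x y) | x <- [0 .. w], y <- [0 .. h] ]

    -- The recursion goes through the table rather than calling go directly.
    go 0 _ = 1
    go _ 0 = 1
    go x y = table ! (x - 1, y) + table ! (x, y - 1)
```

With this convention, gridPaths 2 2 evaluates to 6 and gridPaths 20 20 to 137846528820.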

You can also split this into separate functions if you prefer:

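numPaths w h = withTable w h go (w - 1) (h - 1)
  where
    -- go recurses only through mem, so every recursive call hits the table.
    go mem 0 _ = 1
    go mem _ 0 = 1
    go mem x y = mem (x - 1) y + mem x (y - 1)

-- Tie the knot: f' reads from the table, and the table is filled by f'.
withTable w h f = f'
  where
    f' = f get
    get x y = table ! fromInteger (y * w + x)
    table = makeTable w h f'

-- Build a lazy row-major table with one cell per (x, y).
makeTable w h f = array (0, w * h - 1)
  [ (y * w + x, f x y)
  | y <- [0 .. h - 1]
  , x <- [0 .. w - 1]
  ]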
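The same knot-tying trick can be sketched with Data.Map, which keeps the pure ((a -> b) -> a -> b) interface the question asks for. This is a minimal sketch under one assumption: the full set of keys is known up front. memoizeOver is a hypothetical helper name, not a library function.

```haskell
import qualified Data.Map as Map

-- Hypothetical helper: memoize an open-recursive function over a known,
-- finite list of keys. Data.Map is lazy in its values, so each entry is
-- a thunk that is evaluated at most once, on first lookup.
memoizeOver :: Ord a => [a] -> ((a -> b) -> a -> b) -> a -> b
memoizeOver keys f = mem
  where
    table = Map.fromList [ (k, f mem k) | k <- keys ]
    mem k = table Map.! k   -- raises an error if k is not one of the keys

getNumberOfPaths :: Integer -> Integer -> Integer
getNumberOfPaths x y = memoizeOver keys go (x, y)
  where
    -- Assumption: every argument go can reach lies in this rectangle.
    keys = [ (i, j) | i <- [0 .. x], j <- [0 .. y] ]

    go _   (0, _) = 1
    go _   (_, 0) = 1
    go mem (i, j) = mem (i - 1, j) + mem (i, j - 1)
```

The computation logic in go never sees the Map; it only calls mem. The trade-off is that the key list must cover every argument the recursion can reach, whereas a state-threading memoizer could discover keys on demand.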

And I won’t spoil it for you, but there’s also a non-recursive formula for the answer.

answered May 21 '26 06:05 by Jon Purdy
