
Memoize multi-dimensional recursive solutions in Haskell

I was solving a recursive problem in Haskell. Although I could get the solution, I would like to cache the outputs of subproblems, since the problem has the overlapping-subproblems property.

The question is: given a grid of dimensions n*m and an integer k, how many ways are there to reach cell (n, m) from (1, 1), moving only down or right, with no more than k changes of direction?

Here is the code without memoization:

paths :: Int -> Int -> Int -> Int -> Int -> Int -> Integer
paths i j n m k dir
    | i > n || j > m || k < 0 = 0
    | i == n && j == m = 1
    | dir == 0 = paths (i+1) j n m k 1 + paths i (j+1) n m k 2        -- at the start cell (1,1): no direction yet
    | dir == 1 = paths (i+1) j n m k 1 + paths i (j+1) n m (k-1) 2    -- down was the direction taken to reach here
    | dir == 2 = paths (i+1) j n m (k-1) 1 + paths i (j+1) n m k 2    -- right was the direction taken to reach here
    | otherwise = -1                                                   -- invalid dir

Here the varying state is i, j, k and dir. In languages like C++/Java a 4-D DP array (dp[n][m][k][3]) could be used, but in Haskell I can't find a way to implement that.

Adeeb HS asked Feb 03 '23 13:02


2 Answers

"Tying the knot" is a well-known technique for getting the GHC runtime to memoize results for you, if you know ahead of time all the values you will ever need to look up. The idea is to turn your recursive function into a self-referential data structure, and then simply look up the value you actually care about. I chose to use Array for this, but a Map would work as well. In either case, the array or map you use must be lazy/non-strict, because we will be inserting values into it that we aren't ready to compute until the whole array is filled.
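To see the knot-tying idea in isolation, here is a minimal sketch on a smaller problem (Fibonacci numbers, which are not part of the question; `fibMemo` is a hypothetical name, and it assumes n >= 0). The lazy array `table` is defined in terms of itself, and each cell is evaluated at most once, the first time something demands it:

```haskell
import Data.Array (Array, listArray, (!))

-- Memoized Fibonacci by tying the knot: 'table' refers to itself
-- through 'fib', and laziness evaluates each cell at most once.
fibMemo :: Int -> Integer
fibMemo n = table ! n
  where
    -- Cell i holds fib i; its value is a thunk that reads other cells.
    table :: Array Int Integer
    table = listArray (0, n) (map fib [0 .. n])
    -- 'fib' looks up earlier cells instead of recursing directly.
    fib :: Int -> Integer
    fib 0 = 0
    fib 1 = 1
    fib i = table ! (i - 1) + table ! (i - 2)
```

The same shape scales to the 4-tuple index used for paths: only the index type and the recurrence change.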

import Data.Array (array, bounds, inRange, (!))

paths :: Int -> Int -> Int -> Integer
paths m n k = go (1, 1, k, 0)
  where go (i, j, k, dir)
          | i == m && j == n = 1
          | dir == 1 = get (i+1, j, k, 1) + get (i, j+1, k-1, 2)    -- down was the direction taken to reach here
          | dir == 2 = get (i+1, j, k-1, 1) + get (i, j+1, k, 2)    -- right was the direction taken to reach here
          | otherwise = get (i+1, j, k, 1) + get (i, j+1, k, 2)     -- at the start cell (1,1), dir 0
        a = array ((1, 1, 0, 1), (m, n, k, 2))
            [(c, go c) | c <- (,,,) <$> [1..m] <*> [1..n] <*> [0..k] <*> [1..2]]
        get x | inRange (bounds a) x = a ! x
              | otherwise = 0

I simplified your API a bit:

  • The m and n parameters don't change with each iteration, so they shouldn't be part of the recursive call
  • The client shouldn't have to tell you what i, j, and dir start as, so they've been removed from the function signature and implicitly start at 1, 1, and 0 respectively
  • I also swapped the order of m and n, because it's just weird to take an n parameter first. This caused me quite a bit of headache, because I didn't notice for a while that I also needed to change the base case!

Then, as I said earlier, the idea is to fill up the array with all the recursive calls we'll need to make: that's the array call. Notice the cells in array are initialized with a call to go, which (except for the base case!) involves calling get, which involves looking up an element in the array. In this way, a is self-referential or recursive. But we don't have to decide what order to look things up in, or what order to insert them in: we're sufficiently lazy that GHC evaluates the array elements as needed.
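A small sketch of that order-independence, using a made-up `suffixSums` function that is not from the question: every cell depends on the cell *after* it, yet the list handed to `listArray` is in ascending order and nothing specifies an evaluation order.

```haskell
import Data.Array (Array, elems, listArray, (!))

-- Suffix sums: cell i holds xs[i] + xs[i+1] + ... + xs[n].
-- Each cell depends on a later cell; laziness works out the
-- evaluation order on demand.
suffixSums :: [Integer] -> Array Int Integer
suffixSums xs = sums
  where
    n = length xs
    input = listArray (1, n) xs :: Array Int Integer
    sums = listArray (1, n) [cell i | i <- [1 .. n]]
    cell i
      | i == n = input ! i                    -- last cell: no successor
      | otherwise = input ! i + sums ! (i + 1) -- reads the next cell
```

For instance, `elems (suffixSums [1,2,3,4])` gives `[10,9,7,4]`.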

I've also been a bit cheeky by only making space in the array for dir=1 and dir=2, not dir=0. I get away with this because dir=0 only happens on the first call, and I can call go directly for that case, bypassing the bounds-checking in get. One consequence is that paths does not validate its inputs: an m or n less than 1, or a k less than zero, is silently accepted rather than rejected (for example, paths 1 1 (-1) returns 1, while your original function returns 0 for a negative k). You could add a guard for that to paths itself, if you need to handle that case.

And of course, it does indeed work:

> paths 3 3 2
4
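A sketch of the input guard mentioned above, assuming that invalid sizes or a negative budget should simply count as zero ways (that convention and the name `pathsChecked` are assumptions; calling `error` instead would be just as reasonable). The body is the same knot-tied search, with the inner budget renamed `k'` so it no longer shadows the outer `k`:

```haskell
import Data.Array (array, bounds, inRange, (!))

pathsChecked :: Int -> Int -> Int -> Integer
pathsChecked m n k
  | m < 1 || n < 1 || k < 0 = 0  -- assumption: invalid input means 0 ways
  | otherwise = go (1, 1, k, 0)
  where go (i, j, k', dir)
          | i == m && j == n = 1
          | dir == 1 = get (i+1, j, k', 1) + get (i, j+1, k'-1, 2)  -- arrived moving down
          | dir == 2 = get (i+1, j, k'-1, 1) + get (i, j+1, k', 2)  -- arrived moving right
          | otherwise = get (i+1, j, k', 1) + get (i, j+1, k', 2)   -- first move (dir 0)
        -- Self-referential lazy array: each cell is an unevaluated call to go.
        a = array ((1, 1, 0, 1), (m, n, k, 2))
            [(c, go c) | c <- (,,,) <$> [1..m] <*> [1..n] <*> [0..k] <*> [1..2 :: Int]]
        -- Off the grid or out of budget: no ways from here.
        get x | inRange (bounds a) x = a ! x
              | otherwise = 0
```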

One other thing you could do would be to use a real data type for your direction, instead of an Int:

import Data.Array (Ix, array, bounds, inRange, (!))
import Prelude hiding (Right)

data Direction = Neutral | Down | Right deriving (Eq, Ord, Ix)

paths :: Int -> Int -> Int -> Integer
paths m n k = go (1, 1, k, Neutral)
  where go (i, j, k, dir)
          | i == m && j == n = 1
          | otherwise = case dir of
            Neutral -> get (i+1, j, k, Down) + get (i, j+1, k, Right)
            Down -> get (i+1, j, k, Down) + get (i, j+1, k-1, Right)
            Right -> get (i+1, j, k-1, Down) + get (i, j+1, k, Right)
        a = array ((1, 1, 0, Down), (m, n, k, Right))
            [(c, go c) | c <- (,,,) <$> [1..m] <*> [1..n] <*> [0..k] <*> [Down, Right]]
        get x | inRange (bounds a) x = a ! x
              | otherwise = 0

(I and J might be better names than Down and Right; I don't know whether that's easier or harder to remember.) I think this is probably an improvement, since the types have more meaning now, and you no longer have the weird otherwise clause that handles things like dir=7, which ought to be illegal. But it is still a bit wonky because it relies on the ordering of the enum values: it would break if we put Neutral in between Down and Right. (I tried removing the Neutral direction entirely and adding more special-casing for the first step, but this gets ugly in its own way.)
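One way around the ordering dependence, sketched here assuming O(log n) lookups are acceptable, is the lazy Map mentioned earlier as an alternative to Array. A `Data.Map.Lazy` only holds the keys you list, so `Neutral` can sit anywhere in the declaration (`pathsMap` is a hypothetical name):

```haskell
import qualified Data.Map.Lazy as Map
import Prelude hiding (Right)

-- Neutral deliberately sits between Down and Right: a Map has no
-- index range, so the constructor order no longer matters.
data Direction = Down | Neutral | Right deriving (Eq, Ord)

pathsMap :: Int -> Int -> Int -> Integer
pathsMap m n k = go (1, 1, k, Neutral)
  where
    go (i, j, k', dir)
      | i == m && j == n = 1
      | otherwise = case dir of
          Neutral -> get (i+1, j, k', Down) + get (i, j+1, k', Right)
          Down    -> get (i+1, j, k', Down) + get (i, j+1, k'-1, Right)
          Right   -> get (i+1, j, k'-1, Down) + get (i, j+1, k', Right)
    -- Lazy map: each value is an unevaluated call to 'go', which looks
    -- up other entries through 'get' (this is where the knot is tied).
    table = Map.fromList
      [(c, go c) | c <- (,,,) <$> [1..m] <*> [1..n] <*> [0..k] <*> [Down, Right]]
    -- Keys not in the map (off the grid, or budget below 0) count as 0 ways.
    get c = Map.findWithDefault 0 c table
```

The trade-off is that each lookup costs O(log n) instead of the array's O(1).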

amalloy answered Feb 13 '23 01:02


In Haskell these kinds of things aren't the most trivial, indeed. You would really like some in-place mutation going on to save memory and time, so I don't see any better way than reaching for the frightening ST monad.

This could be done over various data structures: arrays, vectors, repa tensors. I chose HashTable from the hashtables package because it is the simplest to use and is performant enough to make sense in my example.
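For comparison, here is a sketch of the array option using a mutable `STArray` from the `array` package, which ships with GHC, so no extra dependency is needed. Each cell holds `Maybe Integer`, with `Nothing` meaning "not computed yet"; the names `pathsST`, `solve` and `Memo` are hypothetical, and the move costs follow the question's `paths`:

```haskell
import Control.Monad.ST (ST, runST)
import Data.Array.ST (STArray, newArray, readArray, writeArray)

-- Memo table indexed by (i, j, k, dir); Nothing means "not computed yet".
type Memo s = STArray s (Int, Int, Int, Int) (Maybe Integer)

-- Same argument order as the question: an n*m grid, target cell (n, m).
pathsST :: Int -> Int -> Int -> Integer
pathsST n m k = runST $ do
  -- 'max k 0' keeps the bounds valid; a negative k is rejected by 'solve'.
  memo <- newArray ((1, 1, 0, 0), (n, m, max k 0, 2)) Nothing
  solve memo n m 1 1 k 0

solve :: Memo s -> Int -> Int -> Int -> Int -> Int -> Int -> ST s Integer
solve memo n m i j k dir
  | i > n || j > m || k < 0 = return 0  -- off the grid or out of budget
  | i == n && j == m = return 1          -- reached the target cell
  | otherwise = do
      cached <- readArray memo (i, j, k, dir)
      case cached of
        Just x -> return x
        Nothing -> do
          -- Changing direction costs 1; the first move (dir 0) is free.
          let downCost  = if dir == 2 then 1 else 0
              rightCost = if dir == 1 then 1 else 0
          x1 <- solve memo n m (i + 1) j (k - downCost) 1
          x2 <- solve memo n m i (j + 1) (k - rightCost) 2
          let x = x1 + x2
          writeArray memo (i, j, k, dir) (Just x)
          return x
```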


First of all, the module header and imports:

{-# LANGUAGE Rank2Types #-}
module Solution where

import Control.Monad.ST (ST, runST)
import qualified Data.HashTable.ST.Basic as HT

Rank2Types can be handy when dealing with ST because of its phantom type parameter s, although simply calling runST does not require it. I picked the Basic variant of the hash table because its authors claim it has the fastest lookups, and we are going to look up a lot.

It is advisable to use a concrete type alias for the table, so here we go:

type Mem s = HT.HashTable s (Int, Int, Int, Int) Integer

A pure (ST-free) entry point that just creates the table and calls our monster:

runpaths :: Int -> Int -> Int -> Int -> Int -> Int -> Integer
runpaths i j n m k dir = runST $ do
  mem <- HT.new
  paths mem i j n m k dir

Here is the memoized version of paths. We first look up the result in the table; if it is not there, we compute it, store it and return it:

mempaths :: Mem s -> Int -> Int -> Int -> Int -> Int -> Int -> ST s Integer
mempaths mem i j n m k dir = do
  res <- HT.lookup mem (i, j, k, dir)
  case res of
    Just x -> return x
    Nothing -> do
      x <- paths mem i j n m k dir
      HT.insert mem (i, j, k, dir) x
      return x
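The lookup, else compute-store-return step in mempaths can also be written once as a generic combinator. A minimal sketch, assuming a `Data.Map` kept in an `STRef` in place of the hash table (the names `memoized` and `pathsMemo` are hypothetical):

```haskell
import Control.Monad.ST (ST, runST)
import Data.STRef (STRef, newSTRef, readSTRef, modifySTRef')
import qualified Data.Map.Strict as Map

-- Generic memo step: return the cached value for 'key' if present,
-- otherwise run 'compute', store its result and return it.
memoized :: Ord key => STRef s (Map.Map key val) -> key -> ST s val -> ST s val
memoized ref key compute = do
  table <- readSTRef ref
  case Map.lookup key table of
    Just v -> return v
    Nothing -> do
      v <- compute
      -- Insert into the current map: 'compute' may have added entries.
      modifySTRef' ref (Map.insert key v)
      return v

-- The question's recurrence, with each recursive call going through 'memoized'.
pathsMemo :: Int -> Int -> Int -> Integer
pathsMemo n m k0 = runST $ do
  ref <- newSTRef Map.empty
  let go i j k dir
        | i > n || j > m || k < 0 = return 0
        | i == n && j == m = return 1
        | otherwise = memoized ref (i, j, k, dir) $ do
            -- A move costs one change unless it continues the previous
            -- direction (1 = down, 2 = right) or is the first move (0).
            let downCost  = if dir == 2 then 1 else 0
                rightCost = if dir == 1 then 1 else 0
            x1 <- go (i + 1) j (k - downCost) 1
            x2 <- go i (j + 1) (k - rightCost) 2
            return (x1 + x2)
  go 1 1 k0 (0 :: Int)
```

The new entry is inserted with modifySTRef' on the current map rather than on the snapshot read at the start, because the recursive calls inside compute may have stored entries in the meantime.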

And here is the brain of the algorithm. It is just a monadic action that calls the memoized version in place of plain recursion:

paths :: Mem s -> Int -> Int -> Int -> Int -> Int -> Int -> ST s Integer
paths mem i j n m k dir
    | i > n || j > m || k < 0 = return 0
    | i == n && j == m = return 1
    | dir == 0 = do
        x1 <- mempaths mem (i+1) j n m k 1
        x2 <- mempaths mem i (j+1) n m k 2        -- at the start cell (1,1): no direction yet
        return $ x1 + x2
    | dir == 1 = do
        x1 <- mempaths mem (i+1) j n m k 1
        x2 <- mempaths mem i (j+1) n m (k-1) 2    -- down was the direction taken to reach here
        return $ x1 + x2
    | dir == 2 = do
        x1 <- mempaths mem (i+1) j n m (k-1) 1
        x2 <- mempaths mem i (j+1) n m k 2        -- right was the direction taken to reach here
        return $ x1 + x2
    | otherwise = return (-1)                     -- invalid dir

radrow answered Feb 13 '23 00:02