Scala Memoization: How does this Scala memo work?

The following code is from Pathikrit's Dynamic Programming repository. I'm mystified by both its beauty and peculiarity.

    def subsetSum(s: List[Int], t: Int) = {
      type DP = Memo[(List[Int], Int), (Int, Int), Seq[Seq[Int]]]
      implicit def encode(key: (List[Int], Int)) = (key._1.length, key._2)

      lazy val f: DP = Memo {
        case (Nil, 0) => Seq(Nil)
        case (Nil, _) => Nil
        case (a :: as, x) => (f(as, x - a) map {_ :+ a}) ++ f(as, x)
      }

      f(s, t)
    }

The type Memo is implemented in another file:

    case class Memo[I <% K, K, O](f: I => O) extends (I => O) {
      import collection.mutable.{Map => Dict}
      val cache = Dict.empty[K, O]
      override def apply(x: I) = cache getOrElseUpdate (x, f(x))
    }

My questions are:

  1. Why is type K declared as (Int, Int) in subsetSum?

  2. What do the two Ints in (Int, Int) stand for, respectively?

  3. How does (List[Int], Int) implicitly convert to (Int, Int)? I see no implicit def foo(x: (List[Int], Int)) = (x._1.toInt, x._2), not even in the Implicits.scala file it imports.

Edit: Well, I missed this:

    implicit def encode(key: (List[Int], Int)) = (key._1.length, key._2)

I enjoy Pathikrit's library scalgos very much. There are a lot of Scala pearls in it. Please help me with this so I can appreciate Pathikrit's wit. Thank you. (:

1 Answer

I am the author of the above code.

    /**
     * Generic way to create memoized functions (even recursive and multiple-arg ones)
     *
     * @param f the function to memoize
     * @tparam I input to f
     * @tparam K the keys we should use in cache instead of I
     * @tparam O output of f
     */
    case class Memo[I <% K, K, O](f: I => O) extends (I => O) {
      import collection.mutable.{Map => Dict}
      type Input = I
      type Key = K
      type Output = O
      val cache = Dict.empty[K, O]
      override def apply(x: I) = cache getOrElseUpdate (x, f(x))
    }

    object Memo {
      /**
       * Type of a simple memoized function e.g. when I = K
       */
      type ==>[I, O] = Memo[I, I, O]
    }

In Memo[I <% K, K, O]:

    I: input
    K: key to look up in the cache
    O: output

The view bound I <% K means that I is viewable as K, i.e. an implicit conversion from I to K must be in scope.
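As a rough sketch of what the view bound amounts to: the compiler treats `I <% K` as an extra implicit parameter of type `I => K`. The class below spells that out. `ExplicitMemo` and `toKey` are hypothetical names for illustration, not part of the library, and the conversion is applied by hand so the mechanics are visible.

```scala
import scala.collection.mutable

object ViewBoundSketch {
  // Hypothetical variant of Memo: the view bound `I <% K` is written out
  // as an implicit parameter, and the conversion is called explicitly.
  class ExplicitMemo[I, K, O](f: I => O)(implicit toKey: I => K) extends (I => O) {
    val cache = mutable.Map.empty[K, O]
    override def apply(x: I): O = cache.getOrElseUpdate(toKey(x), f(x))
  }

  // Key strings by their length; the key function is passed explicitly here.
  val upper = new ExplicitMemo[String, Int, String](_.toUpperCase)(_.length)
}
```

With this key, `upper("abc")` computes "ABC" and caches it under key 3, and a later `upper("xyz")` returns the cached "ABC" as well. Lookups go through K, so the key function has to distinguish every input that should produce a different result.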

In most cases, I and K should be the same type. For example, if you are writing fibonacci, which is a function of type Int => Int, it is fine to cache by the Int itself.

But sometimes when you are memoizing, you do not want to cache by the input itself (I) but rather by a function of the input (K). For example, the subsetSum algorithm has input of type (List[Int], Int); you do not want to use the List[Int] as part of the cache key, but rather its size. This works because every list the recursion sees is a suffix of the original list, so its length identifies it uniquely.
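To make the keying concrete, here is a minimal sketch that does the same thing by hand, without Memo: a local cache keyed by (length of the remaining list, remaining target). `countSubsets` and `go` are names invented for this illustration, and it counts the matching subsets rather than listing them.

```scala
import scala.collection.mutable

object SuffixKeySketch {
  // Counts the subsets of s (chosen by position) that sum to t.
  def countSubsets(s: List[Int], t: Int): Int = {
    // Key is (length of the remaining suffix, remaining target), not the list itself.
    val cache = mutable.Map.empty[(Int, Int), Int]

    def go(rest: List[Int], x: Int): Int =
      cache.getOrElseUpdate((rest.length, x), rest match {
        case Nil     => if (x == 0) 1 else 0
        case a :: as => go(as, x - a) + go(as, x) // with a, without a
      })

    go(s, t)
  }
}
```

For example, `SuffixKeySketch.countSubsets(List(1, 2, 3), 3)` is 2, for the subsets {1, 2} and {3}.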

So, here's a concrete case:

    /**
     * Subset sum algorithm - can we achieve sum t using elements from s?
     * O(s.map(abs).sum * s.length)
     *
     * @param s set of integers
     * @param t target
     * @return true iff there exists a subset of s that sums to t
     */
    def isSubsetSumAchievable(s: List[Int], t: Int): Boolean = {
      type I = (List[Int], Int)     // input type
      type K = (Int, Int)           // cache key i.e. (list.size, int)
      type O = Boolean              // output type

      type DP = Memo[I, K, O]

      // encode the input as a key in the cache i.e. make K implicitly convertible from I
      implicit def encode(input: DP#Input): DP#Key = (input._1.length, input._2)

      lazy val f: DP = Memo {
        case (Nil, x) => x == 0      // an empty sequence can only achieve a sum of zero
        case (a :: as, x) => f(as, x - a) || f(as, x)      // try with/without the head a
      }

      f(s, t)
    }

You can of course shorten all of these into a single line: type DP = Memo[(List[Int], Int), (Int, Int), Boolean]

For the common case (when I = K), you can simply define type ==>[I, O] = Memo[I, I, O] and use it like this to calculate the binomial coefficient with recursive memoization:

    /**
     * http://mathworld.wolfram.com/Combination.html
     * @return memoized function to calculate C(n,r)
     */
    val c: (Int, Int) ==> BigInt = Memo {
      case (_, 0) => 1
      case (n, r) if r > n/2 => c(n, n - r)
      case (n, r) => c(n - 1, r - 1) + c(n - 1, r)
    }
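The type `(Int, Int) ==> BigInt` in the snippet above relies on Scala allowing a type with two parameters to be written infix. Here is a tiny sketch of just that feature, using a hypothetical alias over Map instead of Memo so that it stands alone:

```scala
object InfixTypeSketch {
  // Hypothetical alias for illustration only; the library's ==> aliases Memo instead.
  type ==>[I, O] = Map[I, O]

  val prefixForm: ==>[(Int, Int), String] = Map((1, 2) -> "a")
  // Infix form of the same type, so this assignment type-checks.
  val infixForm: (Int, Int) ==> String = prefixForm
}
```

Both vals have the same type; the infix spelling only changes how it reads.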

For details on how the above syntax works, please refer to this question.

Here is a full example which calculates editDistance by encoding both the parameters of the input (Seq, Seq) to (Seq.length, Seq.length):

    /**
     * Calculate edit distance between 2 sequences
     * O(s1.length * s2.length)
     *
     * @return Minimum cost to convert s1 into s2 using delete, insert and replace operations
     */
    def editDistance[A](s1: Seq[A], s2: Seq[A]) = {

      type DP = Memo[(Seq[A], Seq[A]), (Int, Int), Int]
      implicit def encode(key: DP#Input): DP#Key = (key._1.length, key._2.length)

      lazy val f: DP = Memo {
        case (a, Nil) => a.length
        case (Nil, b) => b.length
        case (a :: as, b :: bs) if a == b => f(as, bs)
        case (a, b) => 1 + (f(a, b.tail) min f(a.tail, b) min f(a.tail, b.tail))
      }

      f(s1, s2)
    }

And lastly, the canonical fibonacci example:

    lazy val fib: Int ==> BigInt = Memo {
      case 0 => 0
      case 1 => 1
      case n if n > 1 => fib(n-1) + fib(n-2)
    }

    println(fib(100))
