
Is there a generic way to memoize in Scala?

I wanted to memoize this:

// a recursive method needs an explicit result type to compile
def fib(n: Int): BigInt = if (n <= 1) 1 else fib(n - 1) + fib(n - 2)
println(fib(100)) // times out

So I wrote the following, which to my surprise compiles and works (surprising because fib references itself in its own declaration):

import scala.collection.mutable

case class Memo[A, B](f: A => B) extends (A => B) {
  private val cache = mutable.Map.empty[A, B]
  // compute f(x) only on the first call for a given x, then reuse the cached value
  def apply(x: A) = cache.getOrElseUpdate(x, f(x))
}

val fib: Memo[Int, BigInt] = Memo {
  case 0 => 0
  case 1 => 1
  case n => fib(n-1) + fib(n-2)
}

println(fib(100))     // prints 100th fibonacci number instantly

But when I try to declare fib inside of a def, I get a compiler error:

def foo(n: Int) = {
  val fib: Memo[Int, BigInt] = Memo {
    case 0 => 0
    case 1 => 1
    case n => fib(n-1) + fib(n-2)
  }
  fib(n)
}

The above fails to compile with:

error: forward reference extends over definition of value fib
    case n => fib(n-1) + fib(n-2)

Why does declaring the val fib inside a def fail, while declaring it outside, in the class/object scope, works?

To clarify why I might want to declare the recursive memoized function in def scope, here is my solution to the subset sum problem:

/**
 * Subset sum algorithm - can we achieve sum t using elements from s?
 *
 * @param s set of integers
 * @param t target
 * @return true iff there exists a subset of s that sums to t
 */
def subsetSum(s: Seq[Int], t: Int): Boolean = {
  val max = s.scanLeft(0)((sum, i) => (sum + i) max sum)  // max(i) = largest sum achievable from first i elements
  val min = s.scanLeft(0)((sum, i) => (sum + i) min sum)  // min(i) = smallest sum achievable from first i elements

  val dp: Memo[(Int, Int), Boolean] = Memo {         // dp(i,x) = can we achieve x using the first i elements?
    case (_, 0) => true        // 0 can always be achieved using empty set
    case (0, _) => false       // if empty set, non-zero cannot be achieved
    case (i, x) if min(i) <= x && x <= max(i) => dp(i-1, x - s(i-1)) || dp(i-1, x)  // try with/without s(i-1)
    case _ => false            // outside range otherwise
  }

  dp(s.length, t)
}
pathikrit asked Apr 27 '13 22:04


1 Answer

I found a better way to memoize using Scala:

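import scala.collection.mutable

// the returned HashMap doubles as the function: apply consults the cache first
def memoize[I, O](f: I => O): I => O = new mutable.HashMap[I, O]() {
  override def apply(key: I) = getOrElseUpdate(key, f(key))
}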
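To see the caching at work, here is a minimal sketch that counts how often the wrapped function actually runs. The names `MemoizeDemo`, `calls` and `slowSquare` are hypothetical, and `memoize` is restated in a closure-based form (an assumption made so the block stands alone; it is intended to behave like the version above).

```scala
import scala.collection.mutable

object MemoizeDemo {
  // Closure-based restatement of memoize: the cache is private to the returned function.
  def memoize[I, O](f: I => O): I => O = {
    val cache = mutable.HashMap.empty[I, O]
    key => cache.getOrElseUpdate(key, f(key))
  }

  // Counts how many times the underlying (uncached) function body runs.
  var calls = 0

  val slowSquare: Int => Int = memoize { n =>
    calls += 1
    n * n
  }
}
```

Calling `MemoizeDemo.slowSquare(4)` twice and `MemoizeDemo.slowSquare(5)` once should leave `MemoizeDemo.calls` at 2, since the second call with 4 is served from the cache.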

Now you can write fibonacci as follows:

lazy val fib: Int => BigInt = memoize {
  case 0 => 0
  case 1 => 1
  case n => fib(n-1) + fib(n-2)
}
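The `lazy val` is also what addresses the compile error from the question. Local vals are initialized strictly in declaration order, so the compiler rejects a plain `val` that refers to itself inside a method body, whereas a `lazy val` is only initialized on first use. Members of a class or object are accessed through `this`, which is why the top-level version type-checks. A minimal sketch of the question's `foo` rewritten this way follows; the wrapper object `LocalMemo` is hypothetical, and `memoize` is restated in closure form so the block is self-contained.

```scala
import scala.collection.mutable

object LocalMemo {
  // Closure-based restatement of memoize (assumed equivalent to the answer's version).
  def memoize[I, O](f: I => O): I => O = {
    val cache = mutable.HashMap.empty[I, O]
    key => cache.getOrElseUpdate(key, f(key))
  }

  def foo(n: Int): BigInt = {
    // lazy val defers initialization until first use, so the self-reference
    // in the body no longer counts as an illegal forward reference.
    lazy val fib: Int => BigInt = memoize {
      case 0 => 0
      case 1 => 1
      case k => fib(k - 1) + fib(k - 2)
    }
    fib(n)
  }
}
```

Note that each call to `foo` builds a fresh cache here, so results are shared within one call but not across calls.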

Here's one with multiple arguments (the choose function):

lazy val c: ((Int, Int)) => BigInt = memoize {
  case (_, 0) => 1
  case (n, r) if r > n/2 => c(n, n - r)
  case (n, r) => c(n - 1, r - 1) + c(n - 1, r)
}

And here's the subset sum problem:

// is there a subset of s which has sum = t
def isSubsetSumAchievable(s: Vector[Int], t: Int) = {
  // f is (i, j) => Boolean i.e. can the first i elements of s add up to j
  lazy val f: ((Int, Int)) => Boolean = memoize {
    case (_, 0) => true        // 0 can always be achieved using empty list
    case (0, _) => false       // we can never achieve non-zero if we have empty list
    case (i, j) =>
      val k = i - 1            // try the kth element
      f(k, j - s(k)) || f(k, j)
  }
  f(s.length, t)
}

EDIT: As discussed below, here is a thread-safe version

def memoize[I, O](f: I => O): I => O = new mutable.HashMap[I, O]() { self =>
  override def apply(key: I) = self.synchronized(getOrElseUpdate(key, f(key)))
}
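As a possible alternative to locking the whole map, the cache could be backed by `scala.collection.concurrent.TrieMap`. This is only a sketch under assumptions: `memoizeConcurrent` and `ConcurrentMemo` are hypothetical names, and under contention the wrapped function may be evaluated more than once for the same key (only one result is stored), so it suits pure functions only.

```scala
import scala.collection.concurrent.TrieMap

object ConcurrentMemo {
  // Lock-free variant: TrieMap is safe for concurrent reads and writes,
  // but f(key) may run more than once if two threads miss at the same time.
  def memoizeConcurrent[I, O](f: I => O): I => O = {
    val cache = TrieMap.empty[I, O]
    key => cache.getOrElseUpdate(key, f(key))
  }

  lazy val fib: Int => BigInt = memoizeConcurrent {
    case 0 => 0
    case 1 => 1
    case n => fib(n - 1) + fib(n - 2)
  }
}
```

The synchronized version above guarantees a single evaluation per key at the cost of serializing all callers; this variant trades that guarantee for less blocking.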
pathikrit answered Sep 20 '22 16:09
