
scala: memoize a function no matter how many arguments the function takes?

Tags:

scala

I want to write a memoize function in Scala that can be applied to any function object, whatever its arity. I want to do so with a single implementation of memoize. I'm flexible about the syntax, but ideally memoize appears very close to the declaration of the function rather than after it. I'd also like to avoid declaring the original function first and then adding a second declaration for the memoized version.

So some ideal syntax might be this:

def slowFunction(<some args left intentionally vague>) = memoize {
  // the original implementation of slow function
}

Or even this would be acceptable:

def slowFunction = memoize { <some args left intentionally vague> => {
  // the original implementation of slow function
}}

I've seen ways to do this where memoize must be redefined for each function arity, but I want to avoid this approach. The reason is that I will need to implement dozens of functions similar to memoize (i.e. other decorators), and it's too much to ask to copy each one for each arity.

One way to do memoize that does require you to repeat the memoize declarations (so it's no good) is at "What type to use to store an in-memory mutable data table in Scala?".

Heinrich Schmetterling asked May 03 '11 21:05



2 Answers

You can use a type-class approach to deal with the arity issue. You will still need to deal with each function arity you want to support, but not for every arity/decorator combination:

/**
 * A type class that can tuple and untuple function types.
 * @tparam U an untupled function type
 * @tparam T a tupled function type
 */
sealed class Tupler[U, T](val tupled: U => T,
                          val untupled: T => U)

object Tupler {
   implicit def function0[R]: Tupler[() => R, Unit => R] =
      new Tupler((f: () => R) => (_: Unit) => f(),
                 (f: Unit => R) => () => f(()))
   implicit def function1[T, R]: Tupler[T => R, T => R] = 
      new Tupler(identity, identity)
   implicit def function2[T1, T2, R]: Tupler[(T1, T2) => R, ((T1, T2)) => R] = 
      new Tupler(_.tupled, Function.untupled[T1, T2, R]) 
   // ... more tuplers
}

You can then implement the decorator as follows:

/**
 * A memoized unary function.
 *
 * @param f A unary function to memoize
 * @tparam T the argument type
 * @tparam R the return type
 */
class Memoize1[-T, +R](f: T => R) extends (T => R) {
   // memoization implementation
}

object Memoize {
   /**
    * Memoize a function.
    *
    * @param f the function to memoize
    */
   def memoize[T, R, F](f: F)(implicit e: Tupler[F, T => R]): F = 
      e.untupled(new Memoize1(e.tupled(f)))
}
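
To make the pieces above concrete, here is a minimal sketch that fills in the Memoize1 body with a mutable HashMap cache and memoizes a two-argument function. The cache implementation, the Demo object and the calls counter are assumptions added for illustration and are not part of the answer. Tupler is repeated with only two arities so the block compiles on its own, and the cache is not thread-safe.

```scala
import scala.collection.mutable

// Tupler as described above, cut down to two arities for brevity.
sealed class Tupler[U, T](val tupled: U => T, val untupled: T => U)

object Tupler {
  implicit def function1[T, R]: Tupler[T => R, T => R] =
    new Tupler[T => R, T => R](identity, identity)

  implicit def function2[T1, T2, R]: Tupler[(T1, T2) => R, ((T1, T2)) => R] =
    new Tupler[(T1, T2) => R, ((T1, T2)) => R](
      f => f.tupled,
      g => (a, b) => g((a, b)))
}

// One possible Memoize1 body (an assumption): cache results by argument.
// Relies on T having sensible equals/hashCode; not thread-safe.
class Memoize1[-T, +R](f: T => R) extends (T => R) {
  private[this] val cache = mutable.HashMap.empty[T, R]
  def apply(x: T): R = cache.getOrElseUpdate(x, f(x))
}

object Memoize {
  def memoize[T, R, F](f: F)(implicit e: Tupler[F, T => R]): F =
    e.untupled(new Memoize1(e.tupled(f)))
}

// Hypothetical usage: count how often the underlying function really runs.
object Demo {
  var calls = 0
  val slowAdd: (Int, Int) => Int = Memoize.memoize { (a: Int, b: Int) =>
    calls += 1
    a + b
  }
}
```

Calling Demo.slowAdd(1, 2) twice runs the wrapped function once; the two arguments are packed into a tuple, which serves as the cache key.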

Your "ideal" syntax won't work because the compiler would assume that the block passed into memoize is a 0-argument lexical closure. You can, however, use your latter syntax:

// edit: this was originally (and incorrectly) a def
lazy val slowFn = memoize { (n: Int) => 
   // compute the prime decomposition of n
}
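
The difference between def and lazy val noted in the comment above is easy to miss: a def re-evaluates its right-hand side on every reference, so each call would build a fresh memoized function with an empty cache. The sketch below illustrates this with a hypothetical unary-only memoize; the names DefVsVal, square, viaDef, viaVal and evaluations are assumptions made up for the demonstration.

```scala
import scala.collection.mutable

object DefVsVal {
  // Hypothetical unary-only stand-in for memoize, just for this demonstration.
  def memoize[T, R](f: T => R): T => R = {
    val cache = mutable.HashMap.empty[T, R]
    x => cache.getOrElseUpdate(x, f(x))
  }

  // Counts how many times the underlying computation actually runs.
  var evaluations = 0
  private def square(n: Int): Int = { evaluations += 1; n * n }

  // A def builds a new memoized function (and a new, empty cache) on every reference.
  def viaDef: Int => Int = memoize((n: Int) => square(n))

  // A lazy val is evaluated once, so every call shares the same cache.
  lazy val viaVal: Int => Int = memoize((n: Int) => square(n))
}
```

With this setup, two calls to DefVsVal.viaDef(3) evaluate square twice, while two calls to DefVsVal.viaVal(3) evaluate it only once.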

Edit:

To eliminate a lot of the boilerplate for defining new decorators, you can create a trait:

trait FunctionDecorator {
   final def apply[T, R, F](f: F)(implicit e: Tupler[F, T => R]): F = 
      e.untupled(decorate(e.tupled(f)))

   protected def decorate[T, R](f: T => R): T => R
}

This allows you to redefine the Memoize decorator as

object Memoize extends FunctionDecorator {
   /**
    * Memoize a function.
    *
    * @param f the function to memoize
    */
   protected def decorate[T, R](f: T => R) = new Memoize1(f)
}

Rather than invoking a memoize method on the Memoize object, you apply the Memoize object directly:

// edit: this was originally (and incorrectly) a def
lazy val slowFn = Memoize(primeDecomposition _)

or

lazy val slowFn = Memoize { (n: Int) =>
   // compute the prime decomposition of n
}
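
The payoff of FunctionDecorator is that each further decorator only has to handle the unary case. As a rough sketch, here is a second, hypothetical decorator named Traced that records each call; Traced, its log buffer and TracedDemo are assumptions invented for this example. Tupler and FunctionDecorator are repeated (Tupler with two arities only) so the block stands alone.

```scala
import scala.collection.mutable

sealed class Tupler[U, T](val tupled: U => T, val untupled: T => U)

object Tupler {
  implicit def function1[T, R]: Tupler[T => R, T => R] =
    new Tupler[T => R, T => R](identity, identity)

  implicit def function2[T1, T2, R]: Tupler[(T1, T2) => R, ((T1, T2)) => R] =
    new Tupler[(T1, T2) => R, ((T1, T2)) => R](
      f => f.tupled,
      g => (a, b) => g((a, b)))
}

trait FunctionDecorator {
  final def apply[T, R, F](f: F)(implicit e: Tupler[F, T => R]): F =
    e.untupled(decorate(e.tupled(f)))

  protected def decorate[T, R](f: T => R): T => R
}

// Hypothetical decorator: appends "argument -> result" to a log on every call.
// Only the unary case is written; Tupler adapts it to other arities.
object Traced extends FunctionDecorator {
  val log = mutable.ArrayBuffer.empty[String]

  protected def decorate[T, R](f: T => R): T => R = { x =>
    val y = f(x)
    log += s"$x -> $y"
    y
  }
}

object TracedDemo {
  val add: (Int, Int) => Int = Traced { (a: Int, b: Int) => a + b }
}
```

For a two-argument function the decorator sees the arguments as a single tuple, so TracedDemo.add(1, 2) logs the entry "(1,2) -> 3".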
Aaron Novstrup answered Oct 17 '22 07:10


Library

Use Scalaz's scalaz.Memo

Manual

Below is a solution similar to Aaron Novstrup's answer and this blog, with some corrections and improvements, made shorter and easier to copy and paste :)

import scala.Predef._

class Memoized[-T, +R](f: T => R) extends (T => R) {

  import scala.collection.mutable

  private[this] val vals = mutable.Map.empty[T, R]

  def apply(x: T): R = vals.getOrElse(x, {
      val y = f(x)
      vals += ((x, y))
      y
    })
}

// TODO Use macros
// See si9n.com/treehugger/
// http://stackoverflow.com/questions/11400705/code-generation-with-scala
object Tupler {
  implicit def t0t[R]: (() => R) => (Unit) => R = (f: () => R) => (_: Unit) => f()

  implicit def t1t[T, R]: ((T) => R) => (T) => R = identity

  implicit def t2t[T1, T2, R]: ((T1, T2) => R) => ((T1, T2)) => R = (_: (T1, T2) => R).tupled

  implicit def t3t[T1, T2, T3, R]: ((T1, T2, T3) => R) => ((T1, T2, T3)) => R = (_: (T1, T2, T3) => R).tupled

  implicit def t0u[R]: ((Unit) => R) => () => R = (f: Unit => R) => () => f(())

  implicit def t1u[T, R]: ((T) => R) => (T) => R = identity

  implicit def t2u[T1, T2, R]: (((T1, T2)) => R) => ((T1, T2) => R) = Function.untupled[T1, T2, R]

  implicit def t3u[T1, T2, T3, R]: (((T1, T2, T3)) => R) => ((T1, T2, T3) => R) = Function.untupled[T1, T2, T3, R]
}

object Memoize {
  final def apply[T, R, F](f: F)(implicit tupled: F => (T => R), untupled: (T => R) => F): F =
    untupled(new Memoized(tupled(f)))

  // I haven't made the implicit tupling magic for this yet
  def recursive[T, R](f: (T, T => R) => R) = {
    var yf: T => R = null
    yf = Memoize(f(_, yf))
    yf
  }
}

object ExampleMemoize extends App {

  val facMemoizable: (BigInt, BigInt => BigInt) => BigInt = (n: BigInt, f: BigInt => BigInt) => {
    if (n == 0) 1
    else n * f(n - 1)
  }

  val facMemoized = Memoize.recursive(facMemoizable)

  override def main(args: Array[String]): Unit = {
    def myMethod(s: Int, i: Int, d: Double): Double = {
      println("myMethod ran")
      s + i + d
    }

    val myMethodMemoizedFunction: (Int, Int, Double) => Double = Memoize(myMethod _)

    def myMethodMemoized(s: Int, i: Int, d: Double): Double = myMethodMemoizedFunction(s, i, d)

    println("myMemoizedMethod(10, 5, 2.2) = " + myMethodMemoized(10, 5, 2.2))
    println("myMemoizedMethod(10, 5, 2.2) = " + myMethodMemoized(10, 5, 2.2))

    println("myMemoizedMethod(5, 5, 2.2) = " + myMethodMemoized(5, 5, 2.2))
    println("myMemoizedMethod(5, 5, 2.2) = " + myMethodMemoized(5, 5, 2.2))

    val myFunctionMemoized: (Int, Int, Double) => Double = Memoize((s: Int, i: Int, d: Double) => {
      println("myFunction ran")
      s * i + d + 3
    })

    println("myFunctionMemoized(10, 5, 2.2) = " + myFunctionMemoized(10, 5, 2.2))
    println("myFunctionMemoized(10, 5, 2.2) = " + myFunctionMemoized(10, 5, 2.2))

    println("myFunctionMemoized(7, 6, 3.2) = " + myFunctionMemoized(7, 6, 3.2))
    println("myFunctionMemoized(7, 6, 3.2) = " + myFunctionMemoized(7, 6, 3.2))
  }
}

When you run ExampleMemoize you will get:

myMethod ran
myMemoizedMethod(10, 5, 2.2) = 17.2
myMemoizedMethod(10, 5, 2.2) = 17.2
myMethod ran
myMemoizedMethod(5, 5, 2.2) = 12.2
myMemoizedMethod(5, 5, 2.2) = 12.2
myFunction ran
myFunctionMemoized(10, 5, 2.2) = 55.2
myFunctionMemoized(10, 5, 2.2) = 55.2
myFunction ran
myFunctionMemoized(7, 6, 3.2) = 48.2
myFunctionMemoized(7, 6, 3.2) = 48.2
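
The listing defines a recursive helper and facMemoized but never calls it. The idea behind it is that the function takes "itself" as an extra parameter, so recursive calls go through the cache instead of straight back into the uncached body. Below is a minimal standalone sketch of that idea for unary functions. RecursiveMemo, fib and the evaluations counter are assumptions for illustration, and this version uses its own HashMap rather than the Memoize object above.

```scala
import scala.collection.mutable

object RecursiveMemo {
  // Hypothetical unary-only variant: f receives the memoized function
  // as its second argument and must recurse through it.
  def recursive[T, R](f: (T, T => R) => R): T => R = {
    val cache = mutable.HashMap.empty[T, R]
    lazy val self: T => R = x =>
      cache.get(x) match {
        case Some(y) => y
        case None =>
          val y = f(x, self) // recursive calls inside f hit the cache
          cache.update(x, y)
          y
      }
    self
  }

  // Counts how many times the Fibonacci body is actually evaluated.
  var evaluations = 0

  val fib: Int => BigInt = recursive[Int, BigInt] { (n, self) =>
    evaluations += 1
    if (n < 2) BigInt(n) else self(n - 1) + self(n - 2)
  }
}
```

Because every intermediate result is cached, RecursiveMemo.fib(30) evaluates the body once for each n from 0 to 30 instead of an exponential number of times.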
samthebest answered Oct 17 '22 06:10