While researching how to do memoization in Scala, I found some code I didn't grok. I've tried to look this particular "thing" up, but I don't know what it's called, i.e. the term by which to refer to it. Additionally, it's not easy to search for a symbol, ugh!
I saw the following code to do memoization in Scala here:
import scala.collection.mutable
case class Memo[A,B](f: A => B) extends (A => B) {
  private val cache = mutable.Map.empty[A, B]
  def apply(x: A) = cache getOrElseUpdate (x, f(x))
}
What confuses me is what the case class is extending, the extends (A => B) part. First, what is happening? Second, why is it needed at all? And finally, what do you call this kind of inheritance, i.e. is there a specific name or term I can use to refer to it?
Next, I see Memo used in this way to calculate a Fibonacci number here:
val fibonacci: Memo[Int, BigInt] = Memo {
  case 0 => 0
  case 1 => 1
  case n => fibonacci(n-1) + fibonacci(n-2)
}
It's probably that I'm not seeing all of the "simplifications" being applied, but I can't figure out the end of the val line, = Memo {. If this were typed out more verbosely, perhaps I would understand the "leap" being made in how the Memo is constructed.
Any assistance on this is greatly appreciated. Thank you.
=> is syntactic sugar around functions: in a type it denotes a function type, and in an expression it creates a function instance. Recall that every function value in Scala is an object, an instance of one of the FunctionN traits. For example, the type Int => String is equivalent to the type Function1[Int, String], i.e. a function that takes an argument of type Int and returns a String.
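As a small sketch (the object and value names are illustrative, not from the question), the two spellings of the type below are interchangeable, and the lambda on the right-hand side is an ordinary object implementing Function1:

```scala
// Illustrative only: Int => String and Function1[Int, String] are the same type.
object FunctionTypes {
  // A function literal: the compiler creates an instance of Function1[Int, String].
  val f: Int => String = (i: Int) => i.toString

  // The same object can be given the spelled-out type, because the types are identical.
  val g: Function1[Int, String] = f

  // g(42) is shorthand for g.apply(42).
  val viaSugar: String = g(42)
  val viaApply: String = g.apply(42)
}
```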
The first thing you inherit from can either be a trait or a class, using the extends keyword. You can define further inherited traits (and only traits) using the with keyword.
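For example, a class can extend a function type directly (in parentheses) and mix in further traits with with. The trait Named and the class Square here are hypothetical, made up only for illustration:

```scala
// Hypothetical trait, used only to show a second parent added with `with`.
trait Named {
  def name: String
}

// The function type is parenthesized after `extends`; further traits follow with `with`.
class Square extends (Int => Int) with Named {
  def name: String = "square"
  def apply(x: Int): Int = x * x // implements the abstract apply inherited from Function1
}

object SquareDemo {
  val sq = new Square
  // A Square can be passed anywhere an Int => Int is expected, e.g. to map.
  val squares: List[Int] = List(1, 2, 3).map(sq)
}
```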
A => B is short for Function1[A, B], so your Memo extends a function from A to B, most prominently defined through the method apply(x: A): B, which must be implemented.
Because of the "infix" notation, you need to put parentheses around the type after extends, i.e. (A => B). You could also write
case class Memo[A, B](f: A => B) extends Function1[A, B] ...
or
case class Memo[A, B](f: Function1[A, B]) extends Function1[A, B] ...
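A sketch of the Function1[A, B] spelling, assuming Scala 2.13 or 3 (the object MemoAsFunction and its values are illustrative). Because Memo really is a Function1, it also inherits combinators such as andThen and can be passed wherever an A => B is expected:

```scala
import scala.collection.mutable

// The question's Memo, with the parent type spelled out as Function1[A, B].
case class Memo[A, B](f: A => B) extends Function1[A, B] {
  private val cache = mutable.Map.empty[A, B]
  def apply(x: A): B = cache.getOrElseUpdate(x, f(x))
}

object MemoAsFunction {
  val twice: Memo[Int, Int] = Memo((x: Int) => x * 2)

  // Inherited from Function1: compose this function with another one.
  val twiceThenShow: Int => String = twice.andThen(n => s"result: $n")

  // Accepted wherever an Int => Int is expected.
  val doubled: List[Int] = List(1, 2, 3).map(twice)
}
```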
To complete 0__'s answer, fibonacci is being instantiated through the apply method of Memo's companion object, which the compiler generates automatically because Memo is a case class.
This means that the following code is generated for you:
object Memo {
  def apply[A, B](f: A => B): Memo[A, B] = new Memo(f)
}
Scala has special handling for the apply method: its name need not be written when calling it. The following two calls are strictly equivalent:
Memo((a: Int) => a * 2)
Memo.apply((a: Int) => a * 2)
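To see what the case class gives you for free, here is a sketch of a hypothetical non-case variant, PlainMemo, with the companion apply written out by hand so that it mirrors the generated code above (the ApplySugar object is also illustrative):

```scala
import scala.collection.mutable

// Hypothetical non-case version of Memo: without `case`, no companion apply is generated.
class PlainMemo[A, B](f: A => B) extends (A => B) {
  private val cache = mutable.Map.empty[A, B]
  def apply(x: A): B = cache.getOrElseUpdate(x, f(x))
}

// Hand-written equivalent of the factory the compiler generates for a case class.
object PlainMemo {
  def apply[A, B](f: A => B): PlainMemo[A, B] = new PlainMemo(f)
}

object ApplySugar {
  // The same call twice; the first form omits the name `apply`.
  val m1: PlainMemo[Int, Int] = PlainMemo((a: Int) => a * 2)
  val m2: PlainMemo[Int, Int] = PlainMemo.apply((a: Int) => a * 2)

  // The sugar works on instances too: m1(21) means m1.apply(21).
  val r1: Int = m1(21)
  val r2: Int = m1.apply(21)
}
```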
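The case block is known as a pattern-matching anonymous function. When its expected type is a PartialFunction, the compiler turns it into a partial function, that is, a function that is defined for some of its input parameters, but not necessarily all of them. I won't go into the details of partial functions as they're beside the point (this is a memo I wrote to myself on that topic, if you're keen). Strictly speaking, Memo.apply expects a plain A => B, so here the compiler builds an ordinary Function1 whose body is a match; but it's convenient to picture the case block as an instance of PartialFunction, and the code would behave the same if it were one.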
If you follow that link, you'll see that PartialFunction extends Function1, which is the expected argument type of Memo.apply.
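A short sketch of how the expected type decides what a case block becomes (the object and value names are illustrative):

```scala
object CaseBlocks {
  // Expected type is PartialFunction: the compiler builds one, including isDefinedAt.
  val small: PartialFunction[Int, String] = {
    case 0 => "zero"
    case 1 => "one"
  }

  // Expected type is a plain function: the compiler builds a Function1 whose body is a match.
  val total: Int => String = {
    case 0 => "zero"
    case n => n.toString
  }

  // A PartialFunction is a Function1, so it is accepted where Int => String is expected.
  val widened: Int => String = small
}
```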
So, once desugared (if that's a word), that bit of code is roughly equivalent to:
lazy val fibonacci: Memo[Int, BigInt] = Memo.apply(new PartialFunction[Int, BigInt] {
  override def apply(v: Int): BigInt =
    if (v == 0) 0
    else if (v == 1) 1
    else fibonacci(v - 1) + fibonacci(v - 2)
  override def isDefinedAt(v: Int): Boolean = true
})
Note that I've vastly simplified the way the pattern matching is handled, but I thought that starting a discussion about unapply and unapplySeq would be off topic and confusing.
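For comparison, here is a sketch of the desugaring the Scala language specification describes for a block of cases whose expected type is a plain function type: an anonymous function whose body matches on its argument. The wrapping object Desugared is illustrative:

```scala
import scala.collection.mutable

case class Memo[A, B](f: A => B) extends (A => B) {
  private val cache = mutable.Map.empty[A, B]
  def apply(x: A): B = cache.getOrElseUpdate(x, f(x))
}

object Desugared {
  // Memo.apply expects Int => BigInt, so `Memo { case ... }` corresponds to
  // an ordinary lambda that matches on its single argument.
  val fibonacci: Memo[Int, BigInt] = Memo.apply((x: Int) => x match {
    case 0 => BigInt(0)
    case 1 => BigInt(1)
    case n => fibonacci(n - 1) + fibonacci(n - 2)
  })
}
```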
I am the original author of this way of doing memoization. You can see some sample usages in that same file. It also works well when you want to memoize on multiple arguments, because Scala adapts a multi-argument call into a single tuple argument (auto-tupling):
/**
* @return memoized function to calculate C(n,r)
* see http://mathworld.wolfram.com/BinomialCoefficient.html
*/
val c: Memo[(Int, Int), BigInt] = Memo {
  case (_, 0) => 1
  case (n, r) if r > n/2 => c(n, n-r)
  case (n, r) => c(n-1, r-1) + c(n-1, r)
}
// note how I can invoke a memoized function on multiple args too
val x = c(10, 3)
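A self-contained sketch of the binomial example above (the object name Binomial and the extra values are illustrative), checking that the two-argument call and the explicit-tuple call give the same result:

```scala
import scala.collection.mutable

case class Memo[A, B](f: A => B) extends (A => B) {
  private val cache = mutable.Map.empty[A, B]
  def apply(x: A): B = cache.getOrElseUpdate(x, f(x))
}

object Binomial {
  // The memoized function takes ONE argument: an (Int, Int) tuple.
  val c: Memo[(Int, Int), BigInt] = Memo {
    case (_, 0) => BigInt(1)
    case (n, r) if r > n / 2 => c(n, n - r)
    case (n, r) => c(n - 1, r - 1) + c(n - 1, r)
  }

  // c(10, 3) is adapted by the compiler to c((10, 3)).
  val viaTwoArgs: BigInt = c(10, 3)
  val viaTuple: BigInt = c((10, 3))
}
```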
This answer is a synthesis of the partial answers provided by both 0__ and Nicolas Rinaudo.
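Summary:
There are many convenient (but also highly intertwined) assumptions being made by the Scala compiler.
1. It treats extends (A => B) as synonymous with extends Function1[A, B] (ScalaDoc for Function1[-T1, +R]).
2. A concrete implementation of Function1's inherited abstract method apply(x: A): B must be provided; def apply(x: A): B = cache.getOrElseUpdate(x, f(x)).
3. It assumes an implied match for the code block starting with = Memo {.
4. It passes the content between the {} started in item 3 as the argument to Memo.apply, the factory generated in the companion object of the Memo case class, which in turn calls the constructor.
5. The content between the {} started in item 3 can be read as a PartialFunction[Int, BigInt], with the "match" code block as the override for the PartialFunction method apply() and an additional override for the PartialFunction method isDefinedAt(). Strictly, because Memo.apply expects Int => BigInt, the compiler builds a plain Function1 here, which has only the apply() part.
Details: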
The first code block defining the case class Memo can be written more verbosely as follows:
import scala.collection.mutable
case class Memo[A,B](f: A => B) extends Function1[A, B] { //replaced (A => B) with what it's translated to mean by the Scala compiler
  private val cache = mutable.Map.empty[A, B]
  def apply(x: A): B = cache.getOrElseUpdate(x, f(x)) //concrete implementation of the abstract method declared in the parent trait, Function1
}
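To see the cache at work, here is a sketch that wraps a cheap function and counts how often it actually runs; the calls counter and the CacheDemo object are hypothetical additions for illustration only:

```scala
import scala.collection.mutable

case class Memo[A, B](f: A => B) extends Function1[A, B] {
  private val cache = mutable.Map.empty[A, B]
  // f(x) is passed by name, so it is evaluated only on a cache miss.
  def apply(x: A): B = cache.getOrElseUpdate(x, f(x))
}

object CacheDemo {
  // Hypothetical counter, only here to observe how often the wrapped function runs.
  var calls = 0

  val square: Memo[Int, Int] = Memo { (x: Int) =>
    calls += 1
    x * x
  }

  // The first call computes and stores 9; the second is answered from the cache.
  val first: Int = square(3)
  val second: Int = square(3)
}
```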
The second code block defining the val fibonacci can be written more verbosely as follows:
lazy val fibonacci: Memo[Int, BigInt] = {
  Memo.apply(
    new PartialFunction[Int, BigInt] {
      override def apply(x: Int): BigInt = {
        x match {
          case 0 => 0
          case 1 => 1
          case n => fibonacci(n-1) + fibonacci(n-2)
        }
      }
      override def isDefinedAt(x: Int): Boolean = true
    }
  )
}
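I had to add lazy to the second code block's val to deal with a self-reference problem in the line case n => fibonacci(n-1) + fibonacci(n-2). This matters when the val is local, e.g. inside a method body; as a member of a class or object, a plain val compiles, because the recursive calls only run when the function is later invoked.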
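A sketch of the local case, assuming the memoized function lives inside a method (the method fib and the object LocalFib are illustrative names):

```scala
import scala.collection.mutable

case class Memo[A, B](f: A => B) extends (A => B) {
  private val cache = mutable.Map.empty[A, B]
  def apply(x: A): B = cache.getOrElseUpdate(x, f(x))
}

object LocalFib {
  def fib(n: Int): BigInt = {
    // Inside a method, a plain `val` that refers to itself is rejected by the compiler
    // (forward reference); `lazy val` defers initialization, so this compiles.
    lazy val memo: Memo[Int, BigInt] = Memo {
      case 0 => BigInt(0)
      case 1 => BigInt(1)
      case k => memo(k - 1) + memo(k - 2)
    }
    memo(n)
  }
}
```

Note that each call to fib builds a fresh Memo, so the cache only lives for the duration of that call.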
And finally, an example usage of fibonacci is:
val x:BigInt = fibonacci(20) //returns 6765 (almost instantly)
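As a rough illustration of why it is fast, the sketch below counts how many times the function body runs for fibonacci(20) with and without Memo. The counters and the CallCounts object are hypothetical, added only for measurement:

```scala
import scala.collection.mutable

case class Memo[A, B](f: A => B) extends (A => B) {
  private val cache = mutable.Map.empty[A, B]
  def apply(x: A): B = cache.getOrElseUpdate(x, f(x))
}

object CallCounts {
  var naiveCalls = 0
  var memoCalls = 0

  // Plain recursion: the same subproblems are recomputed over and over.
  def naive(n: Int): BigInt = {
    naiveCalls += 1
    if (n < 2) BigInt(n) else naive(n - 1) + naive(n - 2)
  }

  // Memoized: the body runs once per distinct argument.
  lazy val memoized: Memo[Int, BigInt] = Memo { (n: Int) =>
    memoCalls += 1
    if (n < 2) BigInt(n) else memoized(n - 1) + memoized(n - 2)
  }

  val naiveResult: BigInt = naive(20)
  val memoResult: BigInt = memoized(20)
}
```

Plain recursion runs the body 21891 times here, because every call for n >= 2 spawns two more; the memoized version runs it once for each n from 0 to 20, i.e. 21 times.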