In Scala, what does "extends (A => B)" on a case class mean?

In researching how to do memoization in Scala, I've found some code I didn't grok. I've tried to look this particular "thing" up, but I don't know what to call it, i.e. the term by which to refer to it. Additionally, it's not easy to search for a symbol, ugh!

I saw the following code to do memoization in Scala here:

import scala.collection.mutable

case class Memo[A,B](f: A => B) extends (A => B) {
  private val cache = mutable.Map.empty[A, B]
  def apply(x: A) = cache getOrElseUpdate (x, f(x))
}

And it's what the case class is extending that is confusing me, the extends (A => B) part. First, what is happening? Secondly, why is it even needed? And finally, what do you call this kind of inheritance; i.e. is there some specific name or term I can use to refer to it?

Next, I am seeing Memo used in this way to calculate a Fibonacci number here:

  val fibonacci: Memo[Int, BigInt] = Memo {
    case 0 => 0
    case 1 => 1
    case n => fibonacci(n-1) + fibonacci(n-2)
  }

It's probably that I'm not seeing all of the "simplifications" being applied. But I am not able to figure out the end of the val line, = Memo {. If this were typed out more verbosely, perhaps I would understand the "leap" being made in how the Memo is constructed.

Any assistance on this is greatly appreciated. Thank you.

asked Oct 23 '13 at 17:10 by chaotic3quilibrium




4 Answers

A => B is short for Function1[A, B], so your Memo extends a function from A to B. The main consequence is that Memo must implement the abstract method apply(x: A): B that Function1 declares.

Because of the "infix" notation, you need to put parentheses around the type, i.e. (A => B). You could also write

case class Memo[A, B](f: A => B) extends Function1[A, B] ...

or

case class Memo[A, B](f: Function1[A, B]) extends Function1[A, B] ...
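To see why the extends clause matters in practice, here is a minimal sketch (assuming Scala 2.13; the names square and FunctionSubtypeDemo are hypothetical) showing that a Memo can be used anywhere an A => B is expected, because it is one:

```scala
import scala.collection.mutable

// The Memo from the question: extending (A => B) makes every Memo a Function1[A, B].
case class Memo[A, B](f: A => B) extends (A => B) {
  private val cache = mutable.Map.empty[A, B]
  def apply(x: A): B = cache.getOrElseUpdate(x, f(x))
}

object FunctionSubtypeDemo {
  // Hypothetical function to memoize; any A => B would do.
  val square: Memo[Int, Int] = Memo((n: Int) => n * n)

  // Memo[Int, Int] is a subtype of Int => Int, so it can be handed to map
  // and composed with andThen like any other function value.
  val squares: List[Int] = List(1, 2, 3, 2).map(square)
  val squarePlusOne: Int => Int = square.andThen(_ + 1)

  def main(args: Array[String]): Unit = {
    println(squares)          // List(1, 4, 9, 4)
    println(squarePlusOne(3)) // 10
  }
}
```

Without the extends clause, Memo would still have an apply method, so square(3) would work, but List(1, 2, 3).map(square) would not compile, because Memo would not be a Function1.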
answered Sep 22 '22 at 14:09 by 0__


To complete 0__'s answer, fibonacci is being instantiated through the apply method of Memo's companion object, which the compiler generates automatically since Memo is a case class.

This means that the following code is generated for you:

object Memo {
  def apply[A, B](f: A => B): Memo[A, B] = new Memo(f)
}

Scala has special handling for the apply method: its name need not be typed when calling it. The following two calls are strictly equivalent:

Memo((a: Int) => a * 2)

Memo.apply((a: Int) => a * 2)
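As a sketch of what the generated factory amounts to, here is a hand-written equivalent for a non-case class. PlainMemo and ApplySugarDemo are hypothetical names used only for illustration; the companion the compiler generates for a case class also contains other members (such as unapply) that are omitted here.

```scala
import scala.collection.mutable

// Hypothetical non-case class with the same body as Memo.
class PlainMemo[A, B](val f: A => B) extends (A => B) {
  private val cache = mutable.Map.empty[A, B]
  def apply(x: A): B = cache.getOrElseUpdate(x, f(x))
}

// Hand-written companion: roughly the factory a case class gets for free.
object PlainMemo {
  def apply[A, B](f: A => B): PlainMemo[A, B] = new PlainMemo(f)
}

object ApplySugarDemo {
  val double: Int => Int = (a: Int) => a * 2

  val m1: PlainMemo[Int, Int] = PlainMemo(double)       // sugar for PlainMemo.apply(double)
  val m2: PlainMemo[Int, Int] = PlainMemo.apply(double) // the same call, spelled out
  val m3: PlainMemo[Int, Int] = new PlainMemo(double)   // what apply does internally

  // Calling an instance is sugar for its own apply method, too.
  val r1: Int = m1(21)       // same as m1.apply(21)
  val r2: Int = m1.apply(21)
}
```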

The case block is a pattern-matching anonymous function. What the compiler builds from it depends on the expected type. If a PartialFunction is expected, it generates a partial function, that is, a function that is defined for some of its input parameters but not necessarily all of them, with an isDefinedAt method derived from the patterns. If a plain Function1 is expected, as it is here (Memo.apply takes an A => B), it generates an ordinary Function1 whose apply method runs the match and throws a MatchError on unmatched input. I'll not go into the details of partial functions as it's beside the point (this is a memo I wrote to myself on that topic, if you're keen).

PartialFunction extends Function1, which is the expected argument type of Memo.apply, so the block can also be written out by hand as a PartialFunction.
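A small sketch (hypothetical names, assuming Scala 2.13) of how the expected type decides what a case block becomes:

```scala
object CaseBlockDemo {
  // Expected type is Int => String: the compiler builds a plain Function1
  // whose apply runs the match; an unmatched input throws a MatchError.
  val describe: Int => String = {
    case 0          => "zero"
    case n if n > 0 => "positive"
  }

  // Expected type is PartialFunction: the compiler also derives isDefinedAt
  // from the patterns, so callers can test an input before applying it.
  val describePf: PartialFunction[Int, String] = {
    case 0          => "zero"
    case n if n > 0 => "positive"
  }

  // A PartialFunction is still a Function1, so it fits where Int => String is expected.
  val widened: Int => String = describePf

  def main(args: Array[String]): Unit = {
    println(describe(5))                // positive
    println(describePf.isDefinedAt(-1)) // false
    println(describePf.lift(-1))        // None
  }
}
```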

So that bit of code, once desugared (if that's a word), is roughly equivalent to:

lazy val fibonacci: Memo[Int, BigInt] = Memo.apply(new PartialFunction[Int, BigInt] {
  override def apply(v: Int): BigInt =
    if (v == 0)      0
    else if (v == 1) 1
    else             fibonacci(v - 1) + fibonacci(v - 2)

  override def isDefinedAt(v: Int): Boolean = true
})

Note that I've vastly simplified the way the pattern matching is handled, but I thought that starting a discussion about unapply and unapplySeq would be off topic and confusing.
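For comparison, a sketch of the shape the compiler produces here, since the expected type is Int => BigInt: an ordinary function literal whose body matches on its argument. It reuses the Memo case class from the question; the BigInt(...) literals and explicit type arguments are added only to keep type inference simple, and DesugaredFib is a hypothetical name.

```scala
import scala.collection.mutable

case class Memo[A, B](f: A => B) extends (A => B) {
  private val cache = mutable.Map.empty[A, B]
  def apply(x: A): B = cache.getOrElseUpdate(x, f(x))
}

object DesugaredFib {
  // Roughly what Memo { case 0 => ...; case 1 => ...; case n => ... } becomes
  // when the expected argument type is Int => BigInt: a function literal
  // whose body is a match on its single parameter.
  lazy val fibonacci: Memo[Int, BigInt] = Memo.apply[Int, BigInt](
    (v: Int) => v match {
      case 0 => BigInt(0)
      case 1 => BigInt(1)
      case n => fibonacci(n - 1) + fibonacci(n - 2)
    }
  )
}
```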

answered Sep 20 '22 at 14:09 by Nicolas Rinaudo


I am the original author of this way of doing memoization. You can see some sample usages in that same file. It also works really well when you want to memoize on multiple arguments, because Scala auto-tuples them: a call like c(10, 3) is passed to the single-parameter apply as the tuple (10, 3).

/**
 * @return memoized function to calculate C(n,r)
 * see http://mathworld.wolfram.com/BinomialCoefficient.html
 */
val c: Memo[(Int, Int), BigInt] = Memo {
  case (_, 0) => 1
  case (n, r) if r > n/2 => c(n, n-r)
  case (n, r) => c(n-1, r-1) + c(n-1, r)
}
// note how I can invoke a memoized function on multiple args too
val x = c(10, 3)
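A self-contained version of the binomial example, assuming Scala 2.13 (where passing two arguments to a one-parameter apply that takes a tuple is auto-tupled). It shows that c(10, 3) and c((10, 3)) reach the same apply method; Binomial, viaTwoArgs and viaTuple are hypothetical names.

```scala
import scala.collection.mutable

case class Memo[A, B](f: A => B) extends (A => B) {
  private val cache = mutable.Map.empty[A, B]
  def apply(x: A): B = cache.getOrElseUpdate(x, f(x))
}

object Binomial {
  // C(n, r) memoized on the pair (n, r).
  val c: Memo[(Int, Int), BigInt] = Memo[(Int, Int), BigInt] {
    case (_, 0)              => BigInt(1)
    case (n, r) if r > n / 2 => c(n, n - r)                   // auto-tupled to c((n, n - r))
    case (n, r)              => c(n - 1, r - 1) + c(n - 1, r)
  }

  // Both call forms end up in the same apply(x: (Int, Int)) and share one cache.
  val viaTwoArgs: BigInt = c(10, 3)
  val viaTuple: BigInt = c((10, 3))
}
```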
answered Sep 19 '22 at 14:09 by pathikrit


This answer is a synthesis of the partial answers provided by both 0__ and Nicolas Rinaudo.

Summary:

Several convenient (but highly intertwined) conventions of the Scala compiler are at work here:

  1. Scala treats extends (A => B) as synonymous with extends Function1[A, B] (ScalaDoc for Function1[-T1, +R])
  2. A concrete implementation of Function1's inherited abstract method apply(x: A): B must be provided; def apply(x: A): B = cache.getOrElseUpdate(x, f(x))
  3. Scala treats the block of case clauses following = Memo { as an anonymous function whose body is an implied match on its argument
  4. Scala passes that block as the argument to Memo.apply, the factory method generated in the companion object of the Memo case class, which in turn calls the constructor
  5. Because Memo.apply expects an A => B, here Int => BigInt, the compiler turns the block into a Function1[Int, BigInt] whose apply() runs the match. Writing it out by hand as a PartialFunction[Int, BigInt], with the match as the override for apply() plus an override for isDefinedAt(), also works, because PartialFunction extends Function1.

Details:

The first code block, defining the case class Memo, can be written more verbosely as follows:

import scala.collection.mutable

case class Memo[A,B](f: A => B) extends Function1[A, B] {    //replaced (A => B) with what it's translated to mean by the Scala compiler
  private val cache = mutable.Map.empty[A, B]
  def apply(x: A): B = cache.getOrElseUpdate(x, f(x))  //concrete implementation of the abstract method declared in the parent trait, Function1
}
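The memoization itself rests on getOrElseUpdate from scala.collection.mutable.Map, whose second parameter is by-name, so f(x) is evaluated only on a cache miss. A minimal sketch of that behaviour, with a hypothetical compute helper standing in for an expensive f:

```scala
import scala.collection.mutable

object GetOrElseUpdateDemo {
  val cache = mutable.Map.empty[Int, String]
  var computations = 0 // how many times compute actually ran

  // Hypothetical stand-in for an expensive f(x).
  def compute(x: Int): String = {
    computations += 1
    s"value-$x"
  }

  // The second argument is by-name (=> V): compute(x) runs only when x is
  // not cached yet; on a hit the stored value is returned untouched.
  def lookup(x: Int): String = cache.getOrElseUpdate(x, compute(x))
}
```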

The second code block, defining the val fibonacci, can be written more verbosely as follows:

lazy val fibonacci: Memo[Int, BigInt] = {
  Memo.apply(
    new PartialFunction[Int, BigInt] {
      override def apply(x: Int): BigInt = {
        x match {
          case 0 => 0
          case 1 => 1
          case n => fibonacci(n-1) + fibonacci(n-2)
        }
      }
      override def isDefinedAt(x: Int): Boolean = true
    }
  )
}

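I had to add lazy to the second code block's val to deal with the self-reference in the line case n => fibonacci(n-1) + fibonacci(n-2). This is required when the val is local to a method or block, where Scala rejects a val that refers to itself; as a member of a class or object, a plain val also works, because fibonacci is only looked up when the memoized function is actually called.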
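A sketch of both placements (assuming Scala 2.13; LocalFib and MemberFib are hypothetical names): the local definition needs lazy, while the object member compiles as a plain val.

```scala
import scala.collection.mutable

case class Memo[A, B](f: A => B) extends (A => B) {
  private val cache = mutable.Map.empty[A, B]
  def apply(x: A): B = cache.getOrElseUpdate(x, f(x))
}

object LocalFib {
  def fibonacciOf(n: Int): BigInt = {
    // Local definition: without `lazy`, the compiler rejects the
    // self-reference as a forward reference. Note that this cache is
    // rebuilt on every call to fibonacciOf.
    lazy val fibonacci: Memo[Int, BigInt] = Memo[Int, BigInt] {
      case 0 => BigInt(0)
      case 1 => BigInt(1)
      case k => fibonacci(k - 1) + fibonacci(k - 2)
    }
    fibonacci(n)
  }
}

object MemberFib {
  // Object member: a plain val compiles, because the reference to
  // fibonacci inside the function body is only evaluated when it is called.
  val fibonacci: Memo[Int, BigInt] = Memo[Int, BigInt] {
    case 0 => BigInt(0)
    case 1 => BigInt(1)
    case k => fibonacci(k - 1) + fibonacci(k - 2)
  }
}
```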

And finally, an example usage of fibonacci is:

val x:BigInt = fibonacci(20) //returns 6765 (almost instantly)
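To make "almost instantly" concrete, a sketch that counts how often each version's body runs for n = 20. CountingFib, bodyRuns, naive and naiveRuns are hypothetical names added for illustration.

```scala
import scala.collection.mutable

case class Memo[A, B](f: A => B) extends (A => B) {
  private val cache = mutable.Map.empty[A, B]
  def apply(x: A): B = cache.getOrElseUpdate(x, f(x))
}

object CountingFib {
  var bodyRuns = 0  // times the memoized body executes (cache misses)
  var naiveRuns = 0 // times the uncached function is entered

  lazy val fibonacci: Memo[Int, BigInt] = Memo[Int, BigInt] { n =>
    bodyRuns += 1
    n match {
      case 0 => BigInt(0)
      case 1 => BigInt(1)
      case k => fibonacci(k - 1) + fibonacci(k - 2)
    }
  }

  // Same recurrence without a cache, so subproblems are recomputed.
  def naive(n: Int): BigInt = {
    naiveRuns += 1
    if (n < 2) BigInt(n) else naive(n - 1) + naive(n - 2)
  }
}
```

With the cache, each of the 21 values fibonacci(0) through fibonacci(20) is computed once; the naive recursion makes 21891 calls for the same result.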
answered Sep 20 '22 at 14:09 by chaotic3quilibrium