
In this parameterized Scala function, why do I need the cast?

Tags:

scala

In this parameterized function, why do I need the cast? And how can I get rid of it?

/** Filters `xs` to have only every nth element.
  */
def everyNth[A <% Iterable[B], B](xs: A, n: Int, offset: Int = 0): A =
  (xs.zipWithIndex collect { case (x, i) if (i - offset) % n == 0 => x }).asInstanceOf[A]

If I don't have the cast at the end, I get this error message:

type mismatch; found : Iterable[B] required: A

This function (with the cast) works for all the cases I've tried it on, and I know from typing things like the following at the REPL that Scala is able to infer the result type properly when not in the context of a parameterized function:

scala> val a: Stream[Int] = (Stream.from(0).zipWithIndex collect { case (x, i) if (i + 3) % 5 == 0 => x })
a: Stream[Int] = Stream(2, ?)

scala> a take 10 force
res20: scala.collection.immutable.Stream[Int] = Stream(2, 7, 12, 17, 22, 27, 32, 37, 42, 47)

Please explain!

Douglas asked Aug 04 '12 05:08


2 Answers

Following some suggestions in the comments, I looked into CanBuildFrom, and this is what I came up with:

import scala.collection.IterableLike
import scala.collection.generic.CanBuildFrom

/** Filters `xs` to have only every nth element.
  */
def everyNth[A, It <: Iterable[A]]
        (xs: It with IterableLike[A, It], n: Int, offset: Int = 0)
        (implicit bf: CanBuildFrom[It, A , It]): It = {
  val retval = bf()
  retval ++= xs.zipWithIndex collect { case (x, i) if (i - offset) % n == 0 => x }
  retval.result     
}

Yay, it works!!!

And there's NO cast. As such, it even works for Ranges.

However, having to start with an empty retval and then use "++=" to fill it up seems a bit inelegant, so if anyone has a more elegant solution, I'm all ears.
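
One possible tidier variant, offered as a sketch rather than a definitive answer: ask the CanBuildFrom for a builder with bf(xs) and append the matching elements during a single pass over xs.iterator, so no intermediate zipped collection is built. This assumes Scala 2.12 or earlier, where CanBuildFrom and IterableLike still exist; the name `builder` is only illustrative.

```scala
import scala.collection.IterableLike
import scala.collection.generic.CanBuildFrom

/** Filters `xs` to have only every nth element. */
def everyNth[A, It <: Iterable[A]]
        (xs: It with IterableLike[A, It], n: Int, offset: Int = 0)
        (implicit bf: CanBuildFrom[It, A, It]): It = {
  val builder = bf(xs)  // builder for the same collection type as `xs`
  // Walk the elements once, keeping those whose index matches.
  xs.iterator.zipWithIndex.foreach { case (x, i) =>
    if ((i - offset) % n == 0) builder += x
  }
  builder.result()
}
```

The signature is unchanged from the version above, so existing callers should not need to change.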

Here's another generic function I implemented that was a bit trickier than the above because the return type is not the same as the argument type. I.e., the input is a sequence of A's, but the output is a sequence of (A, A)'s:

def zipWithSelf[A, It[A] <: Iterable[A]]
        (xs: It[A] with IterableLike[A, It[A]])
        (implicit bf: CanBuildFrom[It[A], (A, A), It[(A, A)]]): It[(A, A)] = {
  val retval = bf()
  // `xs.tail` throws on an empty collection, so only zip when there is something to zip.
  if (xs.nonEmpty) retval ++= xs zip xs.tail
  retval.result
}

And here's another:

/** Calls `f(x)` for all x in `xs` and returns an Iterable containing the indexes for
  * which `f(x)` is true.
  *
  * The type of the returned Iterable will match the type of `xs`. 
  */
def findAll[A, It[A] <: Iterable[A]]
        (xs: It[A] with IterableLike[A, It[A]])
        (f: A => Boolean)
        (implicit bf:  CanBuildFrom[It[A], Int, It[Int]]): It[Int] = {
    val retval = bf()
    retval ++= xs.zipWithIndex filter { p => f(p._1) } map { _._2 }
    retval.result
}

I still don't have any deep understanding of the "Like" types and CanBuildFrom, but I get the gist. And it's easy enough in most cases to write the casting version of a generic function as a first pass, and then add the CanBuildFrom and IterableLike boilerplate to make the function more general and fully type-safe.
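
For readers on newer Scala versions: CanBuildFrom and the "Like" traits were removed in the Scala 2.13 collections redesign, so the functions above will not compile there. The following is a minimal sketch of how everyNth might look on 2.13 or later, assuming the argument is a standard collection whose type has the shape CC[A] (List, Vector, LazyList and so on). It is an adaptation, not the code this answer was written against.

```scala
import scala.collection.IterableOps

/** Filters `xs` to have only every nth element (sketch for Scala 2.13+). */
def everyNth[A, CC[X] <: IterableOps[X, CC, CC[X]]]
        (xs: CC[A], n: Int, offset: Int = 0): CC[A] =
  // zipWithIndex and collect both return CC[...], so no cast or builder is needed.
  xs.zipWithIndex.collect { case (x, i) if (i - offset) % n == 0 => x }

// The static result type follows the argument type.
val fromList: List[Int] = everyNth(List(0, 1, 2, 3, 4, 5, 6), 3)
val fromVector: Vector[String] = everyNth(Vector("a", "b", "c", "d"), 2, offset = 1)
```

The bound on CC tells the compiler that transformations of a CC yield another CC, which is the guarantee the CanBuildFrom parameter provided in the older design.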

Douglas answered Oct 31 '22 06:10


There are some cases where collect does not return the same subtype of Iterable as the one it was called on, so the cast fails at runtime. A Range is one example:

scala> everyNth(1 to 10, 2)
java.lang.ClassCastException: scala.collection.immutable.Vector cannot be cast to scala.collection.immutable.Range$Inclusive
        at .<init>(<console>:9)
        at .<clinit>(<console>)
        at .<init>(<console>:11)
        at .<clinit>(<console>)
        at $print(<console>)
        at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
        at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:57)
        at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
        at java.lang.reflect.Method.invoke(Method.java:616)
        at scala.tools.nsc.interpreter.IMain$ReadEvalPrint.call(IMain.scala:704)
        at scala.tools.nsc.interpreter.IMain$Request$$anonfun$14.apply(IMain.scala:920)
        at scala.tools.nsc.interpreter.Line$$anonfun$1.apply$mcV$sp(Line.scala:43)
        at scala.tools.nsc.io.package$$anon$2.run(package.scala:25)
        at java.lang.Thread.run(Thread.java:679)
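
The same thing can be observed without the generic function. The snippet below is a small illustration (the names `range` and `picked` are made up for it): the compiler only promises an IndexedSeq[Int] for the collected result, and the value produced at runtime is not a Range, which is why a cast back to the argument type cannot be trusted.

```scala
val range = 1 to 10

// The static type here is IndexedSeq[Int]; annotating it as Range would not compile.
val picked: IndexedSeq[Int] =
  range.zipWithIndex.collect { case (x, i) if i % 2 == 0 => x }

// The runtime value is a different IndexedSeq implementation, not a Range.
println(picked.isInstanceOf[Range])
```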

Kim Stebel answered Oct 31 '22 07:10