Why does pattern matching in Spark not work the same as in Scala? See the example below. Function f() tries to pattern match on class, which works in the Scala REPL but fails in Spark, where every result is "???". f2() is a workaround that gets the desired result in Spark using .isInstanceOf, but I understand that to be bad form in Scala.
Any help on pattern matching the correct way in this scenario in Spark would be greatly appreciated.
abstract class a extends Serializable {val a: Int}
case class b(a: Int) extends a
case class bNull(a: Int=0) extends a
val x: List[a] = List(b(0), b(1), bNull())
val xRdd = sc.parallelize(x)
Attempt at pattern matching, which works in the Scala REPL but fails in Spark:
def f(x: a) = x match {
case b(n) => "b"
case bNull(n) => "bnull"
case _ => "???"
}
Workaround that works in Spark, but is bad form (I think):
def f2(x: a) = {
if (x.isInstanceOf[b]) {
"b"
} else if (x.isInstanceOf[bNull]) {
"bnull"
} else {
"???"
}
}
Results:
xRdd.map(f).collect // does not work in Spark
// result: Array("???", "???", "???")
xRdd.map(f2).collect // works in Spark
// result: Array("b", "b", "bnull")
x.map(f(_)) // works in Scala REPL
// result: List("b", "b", "bnull")
Versions used: the Spark results were run in spark-shell (Spark 1.6 on AWS EMR-4.3); the Scala REPL results in SBT 0.13.9 (Scala 2.10.5).
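For reference, the isInstanceOf chain in f2() is usually written in idiomatic Scala with type patterns (`case _: T =>`). The following is a plain-Scala sketch, meant to be run outside spark-shell, that reuses the question's class names; `TypePatterns` and `f3` are hypothetical names. It has not been checked against the Spark 1.6 REPL, where it may be affected by the same problem as f().

```scala
// Same hierarchy as in the question (plain Scala, not spark-shell).
abstract class a extends Serializable { val a: Int }
case class b(a: Int) extends a
case class bNull(a: Int = 0) extends a

// Hypothetical wrapper object for the rewritten function.
object TypePatterns {
  // Type patterns replace the if/else chain of isInstanceOf checks in f2.
  def f3(x: a): String = x match {
    case _: b     => "b"
    case _: bNull => "bnull"
    case _        => "???"
  }
}

// Applied to the question's list this yields List("b", "b", "bnull").
val classified = List(b(0), b(1), bNull()).map(TypePatterns.f3)
```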
This is a known issue with the Spark REPL; you can find more details in SPARK-2620. It affects multiple operations in the Spark REPL, including most transformations on pair RDDs. For example:
case class Foo(x: Int)
val foos = Seq(Foo(1), Foo(1), Foo(2), Foo(2))
foos.distinct.size
// Int = 2
val foosRdd = sc.parallelize(foos, 4)
foosRdd.distinct.count
// Long = 4
foosRdd.map((_, 1)).reduceByKey(_ + _).collect
// Array[(Foo, Int)] = Array((Foo(1),1), (Foo(1),1), (Foo(2),1), (Foo(2),1))
foosRdd.first == foos.head
// Boolean = false
Foo.unapply(foosRdd.first) == Foo.unapply(foos.head)
// Boolean = true
What makes it even worse is that the results depend on the data distribution:
sc.parallelize(foos, 1).distinct.count
// Long = 2
sc.parallelize(foos, 1).map((_, 1)).reduceByKey(_ + _).collect
// Array[(Foo, Int)] = Array((Foo(2),2), (Foo(1),2))
The simplest workaround is to define and package the required case classes outside the REPL. Any code submitted directly using spark-submit should work as well.
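A minimal sketch of what defining the case classes outside the REPL might look like: a source file compiled into a jar (for example with sbt package), then added to spark-shell with --jars or run with spark-submit. The names `A`, `B`, `BNull` and `Classify` are hypothetical capitalized stand-ins for the question's classes, and the Spark wiring is omitted so the snippet runs on its own. Making the base class sealed also lets the compiler check the match for exhaustiveness, which removes the need for the "???" fallback.

```scala
// Hypothetical standalone source file, compiled outside the REPL.
// A package clause (e.g. `package bar`) would normally go at the top;
// it is omitted here so the snippet also runs as a script.

sealed abstract class A extends Serializable { val a: Int }
final case class B(a: Int) extends A
final case class BNull(a: Int = 0) extends A

object Classify {
  // A is sealed, so the compiler warns if a subclass is left unmatched.
  def f(x: A): String = x match {
    case B(_)     => "b"
    case BNull(_) => "bnull"
  }

  def main(args: Array[String]): Unit = {
    // In a Spark job this would be rdd.map(Classify.f); Spark setup omitted.
    println(List(B(0), B(1), BNull()).map(f))
  }
}
```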
In Scala 2.11+ you can create a package directly in the REPL with :paste -raw:
scala> :paste -raw
// Entering paste mode (ctrl-D to finish)
package bar
case class Bar(x: Int)
// Exiting paste mode, now interpreting.
scala> import bar.Bar
import bar.Bar
scala> sc.parallelize(Seq(Bar(1), Bar(1), Bar(2), Bar(2))).distinct.collect
res1: Array[bar.Bar] = Array(Bar(1), Bar(2))