 

Spark Dataset equivalent for Scala's "collect" taking a partial function

Regular Scala collections have a nifty collect method that lets me do a filter-map operation in one pass using a partial function. Is there an equivalent operation on Spark Datasets?


I'd like it for two reasons:

  • syntactic simplicity
  • it reduces filter-map style operations to a single pass (although in Spark I am guessing there are optimizations which spot these things for you)

Here is an example to show what I mean. Suppose I have a sequence of options and I want to extract and double just the defined integers (those in a Some):

val input = Seq(Some(3), None, Some(-1), None, Some(4), Some(5)) 

Method 1 - collect

input.collect {
  case Some(value) => value * 2
} 
// List(6, -2, 8, 10)

The collect makes this quite neat syntactically and does one pass.

Method 2 - filter-map

input.filter(_.isDefined).map(_.get * 2)

I can carry this kind of pattern over to Spark because Datasets and DataFrames have analogous methods.
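
For concreteness, here is a rough sketch of that pattern on a Dataset (assuming a SparkSession in scope as spark, with its implicits imported):

// a sketch, assuming an existing SparkSession named spark
import spark.implicits._

val ds = Seq(Some(3), None, Some(-1), None, Some(4), Some(5)).toDS()

// the same filter-map pattern carried over to the Dataset API
val doubled = ds.filter(_.isDefined).map(_.get * 2)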

But I don't like this so much because isDefined and get seem like code smells to me. There's an implicit assumption that map is receiving only Somes, and the compiler can't verify it. In a bigger example that assumption would be harder for a developer to spot, and a developer might, for example, swap the filter and the map around without getting a syntax error.

Method 3 - fold* operations

input.foldRight[List[Int]](Nil) {
  case (nextOpt, acc) => nextOpt match {
    case Some(next) => next*2 :: acc
    case None => acc
  }
}

I haven't used Spark enough to know whether fold has an equivalent, so this might be a bit tangential.

Anyway, the pattern match, the fold boilerplate and the rebuilding of the list all get jumbled together, and it's hard to read.


So overall I find the collect syntax the nicest and I'm hoping spark has something like this.

asked Jan 25 '17 by rmin

2 Answers

The answers here are incorrect, at least with the current version of Spark.

RDDs do in fact have a collect method that takes a partial function and applies a filter & map to the data. This is completely different from the parameterless .collect() method. See the Spark source code RDD.scala @ line 955:

/**
 * Return an RDD that contains all matching values by applying `f`.
 */
def collect[U: ClassTag](f: PartialFunction[T, U]): RDD[U] = withScope {
  val cleanF = sc.clean(f)
  filter(cleanF.isDefinedAt).map(cleanF)
}

This does not materialize the data from the RDD, unlike the parameterless .collect() method in RDD.scala @ line 923:

/**
 * Return an array that contains all of the elements in this RDD.
 */
def collect(): Array[T] = withScope {
  val results = sc.runJob(this, (iter: Iterator[T]) => iter.toArray)
  Array.concat(results: _*)
}

In the documentation, notice how the

def collect[U](f: PartialFunction[T, U]): RDD[U]

method does not have a warning associated with it about the data being loaded into the driver's memory:

https://spark.apache.org/docs/latest/api/scala/index.html#org.apache.spark.rdd.RDD@collect[U](f:PartialFunction[T,U])(implicitevidence$29:scala.reflect.ClassTag[U]):org.apache.spark.rdd.RDD[U]

It's very confusing for Spark to have these overloaded methods doing completely different things.


edit: My mistake! I misread the question; we're talking about Datasets, not RDDs. Still, the accepted answer says that

"the Spark documentation points out, however, "this method should only be used if the resulting array is expected to be small, as all the data is loaded into the driver's memory."

That is incorrect! The data is not loaded into the driver's memory when calling the partial-function version of .collect() - only when calling the parameterless version. Calling .collect(partial_function) should have about the same performance as calling .filter() and .map() sequentially, as shown in the source code above.
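
To make the distinction concrete, here is a minimal sketch (assuming a SparkSession named spark):

// a sketch, assuming an existing SparkSession named spark
val rdd = spark.sparkContext.parallelize(Seq(Some(3), None, Some(4)))

// partial-function overload: returns a new, lazily evaluated RDD;
// nothing is shipped to the driver at this point
val doubled = rdd.collect { case Some(v) => v * 2 }

// parameterless overload: runs a job and materializes the result on the driver
val onDriver: Array[Int] = doubled.collect()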

answered Nov 09 '22 by shoffing


Just for the sake of completeness:

The RDD API does have such a method, so it's always an option to convert a given Dataset / DataFrame to an RDD, perform the collect operation and convert back, e.g.:

val dataset = Seq(Some(1), None, Some(2)).toDS()
val dsResult = dataset.rdd.collect { case Some(i) => i * 2 }.toDS()

However, this will probably perform worse than using a map and filter on the Dataset (for the reason explained in @stefanobaghino's answer).

As for DataFrames, this particular example (using Option) is somewhat misleading, as the conversion into a DataFrame actually does the "flattening" of Options into their values (or null for None), so the equivalent expression would be:

val dataframe = Seq(Some(1), None, Some(2)).toDF("opt")
dataframe.withColumn("opt", $"opt".multiply(2)).filter(not(isnull($"opt")))

This, I think, suffers less from your concern about having the map operation "assume" anything about its input.
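
For completeness, the DataFrame snippet above assumes the usual Spark SQL imports (the session name spark is illustrative):

// assumed imports for the DataFrame version above
import spark.implicits._                              // toDF and the $"..." column syntax
import org.apache.spark.sql.functions.{not, isnull}   // column functions used in the filter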

answered Nov 09 '22 by Tzach Zohar