Regular Scala collections have a nifty collect method which lets me do a filter-map operation in one pass using a partial function. Is there an equivalent operation on Spark Datasets?
I'd like it for two reasons: it keeps the syntax simple, and it reduces filter-map style operations to a single pass (although in Spark I am guessing there are optimizations which spot these things for you).
Here is an example to show what I mean. Suppose I have a sequence of options and I want to extract and double just the defined integers (those in a Some):
val input = Seq(Some(3), None, Some(-1), None, Some(4), Some(5))
Method 1 - collect
input.collect {
  case Some(value) => value * 2
}
// List(6, -2, 8, 10)
The collect makes this quite neat syntactically and does one pass.
Method 2 - filter-map
input.filter(_.isDefined).map(_.get * 2)
I can carry this kind of pattern over to spark because datasets and data frames have analogous methods.
But I don't like this so much because isDefined and get seem like code smells to me. There's an implicit assumption that map is receiving only Somes. The compiler can't verify this. In a bigger example, that assumption would be harder for a developer to spot and the developer might swap the filter and map around for example without getting a syntax error.
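To make the concern concrete, here is a minimal sketch on a plain Scala collection (no Spark involved). It assumes the filter gets dropped during a refactor, which still compiles but fails at runtime on the first None. The object and value names are made up for illustration.

```scala
import scala.util.Try

object FilterMapHazard {
  val input: Seq[Option[Int]] = Seq(Some(3), None, Some(-1), None, Some(4), Some(5))

  // Intended order: filter first, so get is only ever called on a Some.
  val filterThenMap: Seq[Int] = input.filter(_.isDefined).map(_.get * 2)

  // Hypothetical refactor that loses the filter: this type-checks,
  // but None.get throws NoSuchElementException when the map runs.
  val withoutFilter: Try[Seq[Int]] = Try(input.map(_.get * 2))
}
```

The collect version cannot go wrong in this way, because the pattern match itself decides which elements reach the doubling.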
Method 3 - fold* operations
input.foldRight[List[Int]](Nil) {
  case (nextOpt, acc) => nextOpt match {
    case Some(next) => next * 2 :: acc
    case None => acc
  }
}
I haven't used Spark enough to know whether fold has an equivalent, so this might be a bit tangential.
Anyway, the pattern match, the fold boilerplate and the rebuilding of the list all get jumbled together and it's hard to read.
So overall I find the collect syntax the nicest and I'm hoping Spark has something like this.
The answers here are incorrect, at least with the current version of Spark.
RDDs do in fact have a collect method that takes a partial function and applies a filter & map to the data. This is completely different from the parameterless .collect() method. See the Spark source code RDD.scala @ line 955:
/**
 * Return an RDD that contains all matching values by applying `f`.
 */
def collect[U: ClassTag](f: PartialFunction[T, U]): RDD[U] = withScope {
  val cleanF = sc.clean(f)
  filter(cleanF.isDefinedAt).map(cleanF)
}
This does not materialize the data from the RDD, as opposed to the parameterless .collect() method in RDD.scala @ line 923:
/**
 * Return an array that contains all of the elements in this RDD.
 */
def collect(): Array[T] = withScope {
  val results = sc.runJob(this, (iter: Iterator[T]) => iter.toArray)
  Array.concat(results: _*)
}
In the documentation, notice how the
def collect[U](f: PartialFunction[T, U]): RDD[U]
method does not have a warning associated with it about the data being loaded into the driver's memory:
https://spark.apache.org/docs/latest/api/scala/index.html#org.apache.spark.rdd.RDD@collect[U](f:PartialFunction[T,U])(implicitevidence$29:scala.reflect.ClassTag[U]):org.apache.spark.rdd.RDD[U]
It's very confusing for Spark to have these overloaded methods doing completely different things.
edit: My mistake! I misread the question; we're talking about Datasets, not RDDs. Still, the accepted answer says that
"the Spark documentation points out, however, 'this method should only be used if the resulting array is expected to be small, as all the data is loaded into the driver's memory.'"
Which is incorrect! The data is not loaded into the driver's memory when calling the partial function version of .collect() - only when calling the parameterless version. Calling .collect(partial_function) should have about the same performance as calling .filter() and .map() sequentially, as shown in the source code above.
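The same equivalence can be checked on a plain Scala collection, whose collect takes a partial function in the same way. This is only a sketch of the filter(isDefinedAt).map(f) shape from the RDD source quoted above, not Spark code. The object and value names are made up for illustration.

```scala
object CollectEquivalence {
  val input: Seq[Option[Int]] = Seq(Some(3), None, Some(-1), None, Some(4), Some(5))

  // The partial function is only defined for Some values.
  val doubleDefined: PartialFunction[Option[Int], Int] = { case Some(value) => value * 2 }

  // One call with the partial function...
  val viaCollect: Seq[Int] = input.collect(doubleDefined)

  // ...versus the same shape as the RDD implementation: filter on isDefinedAt, then map.
  val viaFilterMap: Seq[Int] = input.filter(doubleDefined.isDefinedAt).map(doubleDefined)
}
```

Both values hold the same elements, and neither form involves gathering a distributed result on a driver.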
Just for the sake of completeness:
The RDD API does have such a method, so it's always an option to convert a given Dataset / DataFrame to RDD, perform the collect operation and convert back, e.g.:
// Assumes a SparkSession named spark is in scope (as in spark-shell)
import spark.implicits._
val dataset = Seq(Some(1), None, Some(2)).toDS()
val dsResult = dataset.rdd.collect { case Some(i) => i * 2 }.toDS()
However, this will probably perform worse than using a map and filter on the Dataset (for the reason explained in @stefanobaghino's answer).
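For comparison, one way to keep the partial function while staying in the Dataset API is to lift it and pass it to flatMap. This is a minimal sketch rather than something proposed in the answers here: the object name, the method name, the local[*] master and the app name are all assumptions for illustration, and it needs Spark on the classpath.

```scala
import org.apache.spark.sql.{Dataset, SparkSession}

object DatasetCollectSketch {
  def doubleDefinedValues(spark: SparkSession): Array[Int] = {
    import spark.implicits._

    // The same kind of partial function one would pass to collect on a Scala collection.
    val doubleDefined: PartialFunction[Option[Int], Int] = { case Some(i) => i * 2 }

    val dataset: Dataset[Option[Int]] = Seq(Option(1), None, Option(2)).toDS()

    // lift turns the partial function into Option[Int] => Option[Int];
    // flatMap then drops the unmatched rows and unwraps the matched ones.
    val result: Dataset[Int] = dataset.flatMap(x => doubleDefined.lift(x).toList)

    // Parameterless collect: brings the (small) result back to the driver.
    result.collect()
  }

  def main(args: Array[String]): Unit = {
    // Local session for illustration only.
    val spark = SparkSession.builder().master("local[*]").appName("collect-sketch").getOrCreate()
    try println(doubleDefinedValues(spark).mkString(", "))
    finally spark.stop()
  }
}
```

As with any typed lambda on a Dataset, this function is opaque to the optimizer, so it is a readability choice rather than a performance one.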
As for DataFrames, this particular example (using Option) is somewhat misleading, as the conversion into a DataFrame actually does the "flattening" of Options into their values (or null for None), so the equivalent expression would be:
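// Assumes a SparkSession named spark is in scope and spark.implicits._ is imported
import org.apache.spark.sql.functions.{isnull, not}
val dataframe = Seq(Some(1), None, Some(2)).toDF("opt")
dataframe.withColumn("opt", $"opt".multiply(2)).filter(not(isnull($"opt")))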
Which, I think, suffers less from your concern about the map operation "assuming" anything about its input.