I have an aggregated DataFrame with a column created using collect_set. I now need to aggregate over this DataFrame again, and apply collect_set to the values of that column once more. The problem is that I need to apply collect_set over the values of the sets, and so far the only way I see to do that is by exploding the aggregated DataFrame (sketched after the example below). Is there a better way?
Example:
Initial DataFrame:
country | continent | attributes
-------------------------------------
Canada | America | A
Belgium | Europe | Z
USA | America | A
Canada | America | B
France | Europe | Y
France | Europe | X
Aggregated DataFrame (the one I receive as input) - aggregation over country:
country | continent | attributes
-------------------------------------
Canada | America | A, B
Belgium | Europe | Z
USA | America | A
France | Europe | Y, X
My desired output - aggregation over continent:
continent | attributes
-------------------------------------
America | A, B
Europe | X, Y, Z
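For reference, the explode-based workaround I have in mind looks roughly like this (just a sketch; byCountry is my own name for the aggregated DataFrame above):
import org.apache.spark.sql.functions.{collect_set, explode}

// byCountry stands for the aggregated (per-country) DataFrame shown above
byCountry
  .withColumn("attribute", explode($"attributes"))
  .groupBy("continent")
  .agg(collect_set($"attribute") as "attributes")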
Since you can have only a handful of rows at this point, you can just collect the attributes arrays as-is and flatten the result (Spark >= 2.4):
import org.apache.spark.sql.functions.{collect_set, flatten, array_distinct}
import spark.implicits._ // assumes a SparkSession named `spark` (pre-imported in spark-shell); needed for toDF and $

val byState = Seq(
  ("Canada", "America", Seq("A", "B")),
  ("Belgium", "Europe", Seq("Z")),
  ("USA", "America", Seq("A")),
  ("France", "Europe", Seq("Y", "X"))
).toDF("country", "continent", "attributes")

byState
  .groupBy("continent")
  .agg(array_distinct(flatten(collect_set($"attributes"))) as "attributes")
  .show
+---------+----------+
|continent|attributes|
+---------+----------+
| Europe| [Y, X, Z]|
| America| [A, B]|
+---------+----------+
In the general case things are much harder to handle, and in many cases, if you expect large lists with many duplicates and many values per group, the optimal solution* is to just recompute the result from scratch, i.e.
input.groupBy($"continent").agg(collect_set($"attributes") as "attributes")
One possible alternative is to use Aggregator
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.{Encoder, Encoders}
import scala.collection.mutable.{Set => MSet}

// Merges the Seq[U] values extracted by f from each input row into a single
// deduplicated Seq[U], using a mutable Set as the aggregation buffer.
class MergeSets[T, U](f: T => Seq[U])(implicit enc: Encoder[Seq[U]])
    extends Aggregator[T, MSet[U], Seq[U]] with Serializable {

  def zero = MSet.empty[U]

  def reduce(acc: MSet[U], x: T) = {
    for { v <- f(x) } acc.add(v)
    acc
  }

  def merge(acc1: MSet[U], acc2: MSet[U]) = {
    acc1 ++= acc2
  }

  def finish(acc: MSet[U]) = acc.toSeq

  def bufferEncoder: Encoder[MSet[U]] = Encoders.kryo[MSet[U]]

  def outputEncoder: Encoder[Seq[U]] = enc
}
and apply it as follows:
case class CountryAggregate(
  country: String, continent: String, attributes: Seq[String])

byState
  .as[CountryAggregate]
  .groupByKey(_.continent)
  .agg(new MergeSets[CountryAggregate, String](_.attributes).toColumn)
  .toDF("continent", "attributes")
  .show
+---------+----------+
|continent|attributes|
+---------+----------+
| Europe| [X, Y, Z]|
| America| [B, A]|
+---------+----------+
but that's clearly not a Java-friendly option.
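From Java, the flatten / array_distinct route shown above is usually the more practical path, since those are plain column functions available through org.apache.spark.sql.functions.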
See also How to aggregate values into collection after groupBy? (similar, but without uniqueness constraint).
* That's because explode can be quite expensive, especially in older Spark versions, as can access to the external representation of SQL collections.