
Spark aggregation for array column

I have a DataFrame with an array column.

val json = """[
{"id": 1, "value": [11, 12, 18]},
{"id": 2, "value": [23, 21, 29]}
]"""

val df = spark.read.json(Seq(json).toDS)

scala> df.show
+---+------------+
| id|       value|
+---+------------+
|  1|[11, 12, 18]|
|  2|[23, 21, 29]|
+---+------------+

Now I need to apply different aggregate functions to the value column. I can call explode and then groupBy, for example:

df.select($"id", explode($"value").as("value")).groupBy($"id").agg(max("value"), avg("value")).show

+---+----------+------------------+
| id|max(value)|        avg(value)|
+---+----------+------------------+
|  1|        18|13.666666666666666|
|  2|        29|24.333333333333332|
+---+----------+------------------+

What bothers me here is that I explode my DataFrame into a bigger one and then reduce it back to the original size by calling groupBy.

Is there a better (i.e. more efficient) way to call aggregate functions on an array column? I could probably implement a UDF, but I don't want to implement all the aggregation UDFs myself.

EDIT: Someone referenced this SO question, but it doesn't work in my case. size works fine:

scala> df.select($"id", size($"value")).show
+---+-----------+
| id|size(value)|
+---+-----------+
|  1|          3|
|  2|          3|
+---+-----------+

But avg and max do not work.

asked Jan 27 '23 by Oleg Pavliv

1 Answer

The short answer is no: you have to implement your own UDF to aggregate over an array column, at least in the latest version of Spark (2.3.1 at the time of writing). Which, as you correctly assert, is not very efficient, since it forces you to either explode the rows or pay the serialization and deserialization cost of working within the Dataset API.
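For illustration, here is what the typed Dataset route might look like (a sketch; Rec is an assumed case class matching the question's schema, and spark.implicits._ is in scope as in the shell):

case class Rec(id: Long, value: Seq[Long])

df.as[Rec]
  .map(r => (r.id, r.value.max, r.value.sum.toDouble / r.value.size)) // per-row array stats
  .toDF("id", "max(value)", "avg(value)")
  .show

This avoids explode, but every row is deserialized into a JVM object and serialized back, which is the cost mentioned above.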

For others who might find this question: to write aggregations in a type-safe way with Datasets, you can use the Aggregator API, which is admittedly not well documented and is quite messy to work with, as the type signatures become rather verbose (see the sketch below).
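As a taste of that verbosity, here is a minimal Aggregator sketch (reusing the hypothetical Rec case class above) that averages all array elements within each id group:

import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.expressions.Aggregator

// Buffer is (running sum, running count); every type and encoder must be spelled out.
val arrayAvg = new Aggregator[Rec, (Long, Long), Double] {
  def zero: (Long, Long) = (0L, 0L)
  def reduce(b: (Long, Long), r: Rec): (Long, Long) =
    (b._1 + r.value.sum, b._2 + r.value.size)
  def merge(b1: (Long, Long), b2: (Long, Long)): (Long, Long) =
    (b1._1 + b2._1, b1._2 + b2._2)
  def finish(b: (Long, Long)): Double = b._1.toDouble / b._2
  def bufferEncoder: Encoder[(Long, Long)] =
    Encoders.tuple(Encoders.scalaLong, Encoders.scalaLong)
  def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}

df.as[Rec].groupByKey(_.id).agg(arrayAvg.toColumn).show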

The longer answer is that this functionality is coming soon(?) in Apache Spark 2.4.

The parent issue SPARK-23899 adds:

  • array_max
  • array_min
  • aggregate
  • map
  • array_distinct
  • array_remove
  • array_join

and many others

[Screencap: slide 11 of "Extending Spark SQL API with Easier to Use Array Types Operations"]

This talk, "Extending Spark SQL API with Easier to Use Array Types Operations", was presented at the June 2018 Spark + AI Summit and covers the new functionality.
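Once those functions land, the max aggregation from the question becomes a single call, along the lines of (a sketch against the question's df):

df.select($"id", array_max($"value")).show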

If it were released, that would allow you to use the array_max function as in your example; average, however, is a little trickier. Strangely, array_sum is not present, but it could be built from the aggregate function. It would probably look something like:

import org.apache.spark.sql.Column

def sum_array(array_col: Column) = aggregate(array_col, lit(0L), (s, x) => s + x, s => s)

df.select(sum_array($"value")).show

where the zero value is the initial state of the aggregation buffer.
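Note that in Spark 2.4 itself the higher-order aggregate function is exposed through SQL expressions only (the Scala Column-based variant sketched above arrived in the functions API in 3.0), so an equivalent via selectExpr would be:

df.selectExpr("id", "aggregate(value, 0L, (s, x) -> s + x) AS value_sum").show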

As you pointed out, size can already obtain the length of the array, which means an average could be calculated from the sum and the size, as sketched below.
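Putting the pieces together (avg_array is an illustrative name, not a built-in; Spark's / on integral columns yields a double):

def avg_array(array_col: Column) = sum_array(array_col) / size(array_col)

df.select($"id", avg_array($"value")).show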

answered Feb 05 '23 by Wade Jensen