I'm interested in the performance characteristics of running aggregate functions over a window, compared to group by/join. In this case I'm not interested in window functions with custom frame boundaries or ordering, but only as a way to run aggregate functions.
Note that I'm interested in batch (non-streaming) performance for decently sized amounts of data only, so I've disabled broadcast joins for the following.
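For reference, that was done with the following setting (the same line appears in the benchmark code further down):

// Disable automatic broadcast joins so the join is planned as a SortMergeJoin.
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", -1)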
For example let's say we start with the following DataFrame:
val df = Seq(("bob", 10), ("sally", 32), ("mike", 9), ("bob", 18)).toDF("name", "age")
df.show(false)
+-----+---+
|name |age|
+-----+---+
|bob |10 |
|sally|32 |
|mike |9 |
|bob |18 |
+-----+---+
Let's say we want to count the number of times each name appears, and then provide that count on rows with the matching name.
val joinResult = df.join(
  df.groupBy($"name").count,
  Seq("name"),
  "inner"
)
joinResult.show(false)
+-----+---+-----+
|name |age|count|
+-----+---+-----+
|sally|32 |1 |
|mike |9 |1 |
|bob |18 |2 |
|bob |10 |2 |
+-----+---+-----+
joinResult.explain
== Physical Plan ==
*(4) Project [name#5, age#6, count#12L]
+- *(4) SortMergeJoin [name#5], [name#15], Inner
:- *(1) Sort [name#5 ASC NULLS FIRST], false, 0
: +- Exchange hashpartitioning(name#5, 200)
: +- LocalTableScan [name#5, age#6]
+- *(3) Sort [name#15 ASC NULLS FIRST], false, 0
+- *(3) HashAggregate(keys=[name#15], functions=[count(1)])
+- Exchange hashpartitioning(name#15, 200)
+- *(2) HashAggregate(keys=[name#15], functions=[partial_count(1)])
+- LocalTableScan [name#15]
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.{functions => f}
val windowResult = df.withColumn("count", f.count($"*").over(Window.partitionBy($"name")))
windowResult.show(false)
+-----+---+-----+
|name |age|count|
+-----+---+-----+
|sally|32 |1 |
|mike |9 |1 |
|bob |10 |2 |
|bob |18 |2 |
+-----+---+-----+
windowResult.explain
== Physical Plan ==
Window [count(1) windowspecdefinition(name#5, specifiedwindowframe(RowFrame, unboundedpreceding$(), unboundedfollowing$())) AS count#34L], [name#5]
+- *(1) Sort [name#5 ASC NULLS FIRST], false, 0
+- Exchange hashpartitioning(name#5, 200)
+- LocalTableScan [name#5, age#6]
Based on the execution plans it looks like Windowing is more efficient (fewer stages). So my question is whether that's always the case -- should I always use Window functions for this kind of aggregation? Are the two methods going to scale similarly as data grows? What about with extreme skew (i.e. some names are a lot more common than others)?
The execution plan for the windowed function is clearly inferior to the execution plan for the group by (the whole group by execution plan is contained within the windowed function's execution plan). SQL engines in general are better optimized for group by than for an over clause; however, each has its uses.
It depends on the data. More specifically, it depends on the cardinality of the name column. If the cardinality is small, the data will be small after the aggregation and the aggregated result can be broadcast in the join. In that case, the join will be faster than the window. On the other hand, if the cardinality is big and the data is large after the aggregation, the join will be planned with SortMergeJoin, and using a window will be more efficient.
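As a sketch of the low-cardinality case, reusing df from the question, the aggregated side can be broadcast explicitly with the standard broadcast hint, which should force a BroadcastHashJoin even with autoBroadcastJoinThreshold set to -1:

import org.apache.spark.sql.functions.broadcast

// Explicitly broadcast the small aggregated side so the join avoids
// shuffling and sorting the large side.
val broadcastJoinResult = df.join(
  broadcast(df.groupBy($"name").count),
  Seq("name"),
  "inner"
)
broadcastJoinResult.explain  // expect a BroadcastHashJoin instead of a SortMergeJoin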
In the case of the window we have one total shuffle plus one sort. In the case of SortMergeJoin we have the same in the left branch (total shuffle + sort), plus an additional reduced shuffle and sort in the right branch (by reduced I mean that the data is aggregated first). In the right branch of the join we also have an additional scan over the data.
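To make those shuffle counts concrete, one way is to count the Exchange operators in each physical plan. This is only a debugging sketch: it relies on Spark internals (ShuffleExchangeExec, present in Spark 2.3+), not a stable public API.

import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec

// Count the shuffle (Exchange) nodes in the executed physical plan.
def countShuffles(df: DataFrame): Int =
  df.queryExecution.executedPlan.collect { case e: ShuffleExchangeExec => e }.size

println(s"join shuffles:   ${countShuffles(joinResult)}")   // 2: one Exchange per join branch
println(s"window shuffles: ${countShuffles(windowResult)}") // 1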
Also, you can check my video from the Spark Summit where I analyze a similar example.
Disabling the broadcast as you state, generating some data, and timing the two approaches for 1M and 2M randomly generated names (a decent size), the execution time for approach 2 (the window) does indeed appear to be better. The runs used 8, 8, 200 partition sizes on a Databricks cluster (community edition).
The generated plan handles the sort and the counting via the window and, as you say, has fewer stages. That appears to be the clincher. At scale you can use more partitions, but the evidence sways me towards approach 2.
I tried random samples of names (left out age) and got this:
join in 48.361 seconds vs 22.028 seconds for window for 1M records for .count
join in 85.814 seconds vs 50.566 seconds for window for 2M records for .count after cluster restart
join in 96.295 seconds vs 43.875 seconds for window for 2M records for .count
Code used:
import scala.collection.mutable.ListBuffer
import scala.util.Random
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.{functions => f}

// Disable broadcast joins so the join is planned as a SortMergeJoin.
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", -1)

// Random 10-character alphanumeric names.
val alpha = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
val size = alpha.size
def randStr(n: Int) = (1 to n).map(_ => alpha(Random.nextInt(size))).mkString

// Wall-clock timing in seconds.
def timeIt[T](op: => T): Float = {
  val start = System.currentTimeMillis
  val res = op
  val end = System.currentTimeMillis
  (end - start) / 1000f
}

var names = new ListBuffer[String]()
for (i <- 1 to 2000000) {
  names += randStr(10)
}
val namesList = names.toSeq
val df = namesList.toDF("name")

// Approach 1: group by + join. Approach 2: window aggregate.
val joinResult = df.join(df.groupBy($"name").count, Seq("name"), "inner")
val windowResult = df.withColumn("count", f.count($"*").over(Window.partitionBy($"name")))

val time1 = timeIt(joinResult.count)
val time2 = timeIt(windowResult.count)
println(s"join in $time1 seconds vs $time2 seconds for window")
Moreover, the question demonstrates that the Spark optimizer is still somewhat immature.