
Spark Window aggregation vs. Group By/Join performance

I'm interested in the performance characteristics of running aggregate functions over a window, compared to group by/join. In this case I'm not interested in window functions with custom frame boundaries or ordering, but only as a way to run aggregate functions.

Note that I'm interested in batch (non-streaming) performance for decently sized amounts of data only, so I've disabled broadcast joins for the following.
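
For reference, broadcast joins can be disabled by setting the auto-broadcast threshold to -1 (the same setting appears in the benchmark code in the second answer below):

// Disable broadcast joins so the planner falls back to a shuffle join.
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", -1)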

For example let's say we start with the following DataFrame:

import spark.implicits._  // for toDF and $; already in scope in spark-shell

val df = Seq(("bob", 10), ("sally", 32), ("mike", 9), ("bob", 18)).toDF("name", "age")
df.show(false)

+-----+---+
|name |age|
+-----+---+
|bob  |10 |
|sally|32 |
|mike |9  |
|bob  |18 |
+-----+---+

Let's say we want to count the number of times each name appears, and then provide that count on rows with the matching name.

Group By/Join

val joinResult = df.join(
    df.groupBy($"name").count,
    Seq("name"),
    "inner"
)
joinResult.show(false)

+-----+---+-----+
|name |age|count|
+-----+---+-----+
|sally|32 |1    |
|mike |9  |1    |
|bob  |18 |2    |
|bob  |10 |2    |
+-----+---+-----+

joinResult.explain
== Physical Plan ==
*(4) Project [name#5, age#6, count#12L]
+- *(4) SortMergeJoin [name#5], [name#15], Inner
   :- *(1) Sort [name#5 ASC NULLS FIRST], false, 0
   :  +- Exchange hashpartitioning(name#5, 200)
   :     +- LocalTableScan [name#5, age#6]
   +- *(3) Sort [name#15 ASC NULLS FIRST], false, 0
      +- *(3) HashAggregate(keys=[name#15], functions=[count(1)])
         +- Exchange hashpartitioning(name#15, 200)
            +- *(2) HashAggregate(keys=[name#15], functions=[partial_count(1)])
               +- LocalTableScan [name#15]

Window

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.{functions => f}

val windowResult = df.withColumn("count", f.count($"*").over(Window.partitionBy($"name")))
windowResult.show(false)

+-----+---+-----+
|name |age|count|
+-----+---+-----+
|sally|32 |1    |
|mike |9  |1    |
|bob  |10 |2    |
|bob  |18 |2    |
+-----+---+-----+

windowResult.explain
== Physical Plan ==
Window [count(1) windowspecdefinition(name#5, specifiedwindowframe(RowFrame, unboundedpreceding$(), unboundedfollowing$())) AS count#34L], [name#5]
+- *(1) Sort [name#5 ASC NULLS FIRST], false, 0
   +- Exchange hashpartitioning(name#5, 200)
      +- LocalTableScan [name#5, age#6]

Based on the execution plans, it looks like windowing is more efficient (fewer stages). So my question is whether that's always the case: should I always use window functions for this kind of aggregation? Will the two methods scale similarly as data grows? What about extreme skew (i.e. some names are a lot more common than others)?

asked Jun 17 '20 by user1302130




2 Answers

It depends on the data. More specifically, here it depends on the cardinality of the name column. If the cardinality is small, the data will be small after the aggregation and the aggregated result can be broadcast in the join; in that case, the join will be faster than the window. On the other hand, if the cardinality is big and the data is large after the aggregation, the join will be planned as a SortMergeJoin, and using the window will be more efficient.
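
For illustration, a minimal sketch of that first case, using the broadcast hint from org.apache.spark.sql.functions (assuming the aggregated side is small enough to broadcast):

import org.apache.spark.sql.functions.broadcast

// Hint Spark to broadcast the (small) aggregated side; the planner then
// uses a BroadcastHashJoin and the large side is never shuffled.
val broadcastJoinResult = df.join(
    broadcast(df.groupBy($"name").count),
    Seq("name"),
    "inner"
)
broadcastJoinResult.explain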

In the case of the window we have one total shuffle plus one sort. In the case of the SortMergeJoin we have the same in the left branch (total shuffle + sort), plus an additional reduced shuffle and sort in the right branch (reduced meaning the data is aggregated first). The right branch of the join also does an additional scan over the data.
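
As a rough way to verify the shuffle counts, a sketch that counts Exchange operators in the physical plan (the ShuffleExchangeExec class name is for Spark 2.3+; older versions name it differently, and adaptive execution in Spark 3 can wrap the plan):

import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec

// Count shuffle (Exchange) nodes in the executed physical plan.
def numShuffles(df: DataFrame): Int =
  df.queryExecution.executedPlan.collect { case e: ShuffleExchangeExec => e }.size

numShuffles(windowResult) // 1: a single hash partitioning by name
numShuffles(joinResult)   // 2: one per side of the SortMergeJoin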

Also, you can check my video from the Spark Summit, where I analyze a similar example.

answered Oct 22 '22 by David Vrba


Disabling broadcast joins as you state, and timing runs over 1M and 2M randomly generated names (a decently sized dataset), the execution time for approach 2 (window) does indeed appear to be better. Partition sizes of 8, 8 and 200 were used on a Databricks (community) cluster.

The generated plan is smart about the sort and counting via the window and, as you say, has fewer stages. That appears to be the clincher. At scale you can use more partitions, but the evidence sways me towards approach 2.
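
For reference, the partition count seen in the plans (the 200 in hashpartitioning) is the shuffle partition setting, which can be raised at scale:

// Raise the shuffle partition count for larger data (default is 200).
spark.conf.set("spark.sql.shuffle.partitions", 400)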

I tried random samples of names (left out age) and got this:

  • join in 48.361 seconds vs 22.028 seconds for window, for 1M records, for .count

  • join in 85.814 seconds vs 50.566 seconds for window, for 2M records, for .count after cluster restart

  • join in 96.295 seconds vs 43.875 seconds for window, for 2M records, for .count

Code used:

import scala.collection.mutable.ListBuffer
import scala.util.Random
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.{functions => f}
import spark.implicits._  // for toDF and $; already in scope in spark-shell

// Disable broadcast joins so the join is planned as a SortMergeJoin.
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", -1)

// Random alphanumeric string of length n.
val alpha = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
val size = alpha.size
def randStr(n: Int) = (1 to n).map(_ => alpha(Random.nextInt(size))).mkString

// Time an operation in seconds (forces evaluation of op).
def timeIt[T](op: => T): Float = {
  val start = System.currentTimeMillis
  val res = op
  val end = System.currentTimeMillis
  (end - start) / 1000f
}

// Generate 2M random names and load them into a DataFrame.
var names = new ListBuffer[String]()
for (i <- 1 to 2000000) {
  names += randStr(10)
}
val namesList = names.toSeq
val df = namesList.toDF("name")

val joinResult = df.join(df.groupBy($"name").count, Seq("name"), "inner")
val windowResult = df.withColumn("count", f.count($"*").over(Window.partitionBy($"name")))
val time1 = timeIt(joinResult.count)
val time2 = timeIt(windowResult.count)

println(s"join in $time1 seconds vs $time2 seconds for window")

Moreover, the question demonstrates that the Spark optimizer is still somewhat immature.

answered Oct 22 '22 by thebluephantom