 

Spark windowing function vs. group by performance issue

I have a DataFrame like this:

| id | date       | KPI_1 | ... | KPI_n
| 1  | 2012-12-12 | 0.1   | ... | 0.5
| 2  | 2012-12-12 | 0.2   | ... | 0.4
| 3  | 2012-12-12 | 0.66  | ... | 0.66
| 1  | 2012-12-13 | 0.2   | ... | 0.46
| 4  | 2012-12-14 | 0.2   | ... | 0.45
| ...
| 55 | 2013-03-15 | 0.5   | ... | 0.55

We have:

  • X identifiers
  • a row for every identifier for a given date
  • n KPIs

I have to calculate some derived KPIs for every row, and each derived KPI depends on the previous values for the same id. Say my derived KPI is a diff; the result would be:

| id | date       | KPI_1 | ... | KPI_n | KPI_1_diff | KPI_n_diff
| 1  | 2012-12-12 | 0.1   | ... | 0.5   | 0.1        | 0.5
| 2  | 2012-12-12 | 0.2   | ... | 0.4   | 0.2        | 0.4
| 3  | 2012-12-12 | 0.66  | ... | 0.66  | 0.66       | 0.66
| 1  | 2012-12-13 | 0.2   | ... | 0.46  | 0.2 - 0.1  | 0.46 - 0.5
| 4  | 2012-12-14 | 0.2   | ... | 0.45  | 0.2        | 0.45
| ...
| 55 | 2013-03-15 | 0.5   | ... | 0.55  | ...        | ...
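To pin down the diff semantics in the table above, here is a minimal plain-Scala model (no Spark involved). It assumes the rows arrive already sorted by date and carry a single KPI; `KpiRow` and `RunningDiff` are hypothetical names. Each row subtracts the last value seen for its own id, and the first row of an id keeps its raw value, i.e. a "previous row per id" lookup.

```scala
// Hypothetical row model: one KPI per row, for brevity.
final case class KpiRow(id: Int, date: String, kpi: Double)

object RunningDiff {
  // Assumes `rows` are already ordered by date.
  // For each row, subtract the previous KPI value seen for the same id;
  // the first row of an id has no predecessor and keeps its own value.
  def diffs(rows: Seq[KpiRow]): Seq[Double] = {
    val lastSeen = scala.collection.mutable.Map.empty[Int, Double]
    rows.map { r =>
      val d = lastSeen.get(r.id).fold(r.kpi)(prev => r.kpi - prev)
      lastSeen(r.id) = r.kpi // remember this id's latest value
      d
    }
  }
}
```

Fed the KPI_n values of the first four rows (0.5, 0.4, 0.66, then 0.46 for id 1), it returns 0.5, 0.4, 0.66 and roughly -0.04.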

My first approach would be:

import org.apache.spark.sql.functions.{col, collect_list, struct}

val groupedDF = myDF.groupBy("id").agg(
    collect_list(struct(col("date"), col("KPI_1"))).as("wrapped_KPI_1"),
    collect_list(struct(col("date"), col("KPI_2"))).as("wrapped_KPI_2")
    // ... up to the nth KPI
)

I would get aggregated data such as (for id 1, KPI_1):

[("2012-12-12",0.1), ("2012-12-13",0.2), ...]

Then I would sort the wrapped data, unwrap it, and map over the aggregated result with a UDF to produce the output (diffs and other statistics).
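The group, sort, unwrap and map steps can be sketched with plain Scala collections standing in for `groupBy`/`collect_list` and the UDF. This is an in-memory illustration, not Spark code; `GroupThenDiff` and `diffsPerId` are hypothetical names. Rows are `(id, date, value)` triples for a single KPI, and dates are assumed to be ISO `yyyy-MM-dd` strings so that string order matches date order.

```scala
object GroupThenDiff {
  // Mirrors the groupBy/collect_list approach in plain collections:
  // gather (date, value) pairs per id, sort them by date, then diff consecutive values.
  def diffsPerId(rows: Seq[(Int, String, Double)]): Map[Int, Seq[(String, Double)]] =
    rows.groupBy(_._1).map { case (id, group) =>
      // Collected order is not guaranteed, so sort by date (ISO strings sort correctly).
      val sorted = group.map(r => (r._2, r._3)).sortBy(_._1)
      val values = sorted.map(_._2)
      // The first value has no predecessor and is kept as-is; the rest are cur - prev.
      val diffs = values.take(1) ++ values.zip(values.drop(1)).map { case (prev, cur) => cur - prev }
      // Re-attach each diff to its date ("unwrap").
      id -> sorted.map(_._1).zip(diffs)
    }
}
```

One design consequence worth keeping in mind: in Spark, `collect_list` puts an id's entire history into a single row, so very long histories per id make this approach memory-heavy.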

Another approach is to use window functions, for example:

import org.apache.spark.sql.expressions.Window

val window = Window.partitionBy(col("id")).orderBy(col("date")).rowsBetween(Window.unboundedPreceding, Window.currentRow)

and then:

val windowedDF = df.select(
  col("id"),
  col("date"),
  col("KPI_1"),
  collect_list(struct(col("date"), col("KPI_1"))).over(window).as("wrapped_KPI_1"),
  collect_list(struct(col("date"), col("KPI_2"))).over(window).as("wrapped_KPI_2")
)

This way I get (for id 1, KPI_1):

[("2012-12-12",0.1)]
[("2012-12-12",0.1), ("2012-12-13",0.2)]
...
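The unbounded-preceding frame gives each row the whole history of its id up to and including itself. A small plain-Scala model (the `RunningFrame` object is hypothetical and is not how Spark implements frames internally) makes the output cost visible: for n rows of one id, the collected lists hold n(n + 1)/2 elements in total, so they grow quadratically with history length.

```scala
object RunningFrame {
  // Models collect_list(...) over an unbounded-preceding-to-current-row frame for ONE id:
  // row i receives every element from the first row up to and including row i.
  def prefixes[A](sortedRows: Seq[A]): Seq[Seq[A]] =
    sortedRows.indices.map(i => sortedRows.take(i + 1))

  // Total number of elements materialized across all rows: n(n + 1) / 2 for n rows.
  def materialized[A](sortedRows: Seq[A]): Int =
    prefixes(sortedRows).map(_.size).sum
}
```

For example, 100 dates for one id already produce 5050 collected elements.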

The windowed output looks nicer to process, but I suspect that repeating the window for every KPI would cause unnecessary grouping and sorting for each one.
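When the derived KPI only needs the previous row (as a diff does), the `lag` window function is an alternative to collecting lists at all; this is a different technique from the `collect_list` approach above. The sketch assumes Spark 2.x or later and columns named `id`, `date` and `KPI_*` as in the question; `LagDiffs` and `withDiffs` are hypothetical names. `lag` brings its own frame (the previous row), so the window here has no `rowsBetween`. As far as I understand Spark's analyzer, window expressions sharing the same partitioning and ordering are planned into a single Window operator, so repeating the window per KPI should not by itself add sorts or shuffles; `explain()` is the way to confirm this on your version.

```scala
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{coalesce, col, lag}

object LagDiffs {
  // Hypothetical helper: appends a "<kpi>_diff" column for each KPI column name given.
  def withDiffs(df: DataFrame, kpis: Seq[String]): DataFrame = {
    // No rowsBetween here: lag() defines its own one-row-back frame.
    val byIdDate = Window.partitionBy(col("id")).orderBy(col("date"))
    val diffCols = kpis.map { k =>
      // The first row of each id has no predecessor: lag yields null,
      // so coalesce falls back to the raw value (matching the diff table).
      coalesce(col(k) - lag(col(k), 1).over(byIdDate), col(k)).as(s"${k}_diff")
    }
    // A single select in which every window expression uses the same spec.
    df.select((col("*") +: diffCols): _*)
  }
}
```

Usage would be something like `LagDiffs.withDiffs(myDF, Seq("KPI_1", "KPI_2")).explain()`, checking the plan for a single `Window` node and a single `Exchange`.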

So here are the questions:

  1. Should I go for the grouping approach?
  2. Should I go for the window approach? If so, what is the most efficient way to do it?
asked Jan 23 '19 17:01 by JayZee




1 Answer

I believe the window approach is the better solution, but before applying the window functions you should repartition the DataFrame by id. This shuffles the data only once, and all the window functions are then executed on the already-shuffled DataFrame. I hope it helps.

The code would look something like this:

// Repartition by id once, before applying the window functions.
val windowedDF = df.repartition(col("id"))
  .select(
    col("id"),
    col("date"),
    col("KPI_1"),
    col("KPI_2"),
    collect_list(struct(col("date"), col("KPI_1"))).over(window).as("wrapped_KPI_1"),
    collect_list(struct(col("date"), col("KPI_2"))).over(window).as("wrapped_KPI_2")
  )

@Raphael Roth

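Here, we are aggregating over a single window, which is why you might be seeing the same execution plan. The example below aggregates over multiple windows that share the batchDate partition key. In the plans below, agg1df needs two Exchange (shuffle) steps, one per window, while agg2df, repartitioned by batchDate first, needs only one: hash partitioning on batchDate already places all rows of each (batchDate, c1) and (batchDate, c2) group in the same partition, so both windows reuse a single shuffle.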

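import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.collect_list
import spark.implicits._ // needed for toDF and $"..."

val list = Seq(( "2", null, 1, 11, 1, 1 ),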
  ( "2", null, 1, 22, 2, 2 ),
  ( "2", null, 1, 11, 1, 3 ),
  ( "2", null, 1, 22, 2, 1 ),
  ( "2", null, 1, 33, 1, 2 ),
  ( null, "3", 3, 33, 1, 2 ),
  ( null, "3", 3, 33, 2, 3 ),
  ( null, "3", 3, 11, 1, 1 ),
  ( null, "3", 3, 22, 2, 2 ),
  ( null, "3", 3, 11, 1, 3 )
)

val df = spark.sparkContext.parallelize(list).toDF("c1","c2","batchDate","id", "pv" , "vv")

val c1Window = Window.partitionBy("batchDate", "c1")
val c2Window = Window.partitionBy("batchDate", "c2")

val agg1df = df.withColumn("c1List",collect_list("pv").over(c1Window))
  .withColumn("c2List", collect_list("pv").over(c2Window))

val agg2df = df.repartition($"batchDate")
  .withColumn("c1List",collect_list("pv").over(c1Window))
  .withColumn("c2List", collect_list("pv").over(c2Window))


agg1df.explain()
== Physical Plan ==
Window [collect_list(pv#18, 0, 0) windowspecdefinition(batchDate#16, c2#15, ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS c2List#38], [batchDate#16, c2#15]
+- *Sort [batchDate#16 ASC NULLS FIRST, c2#15 ASC NULLS FIRST], false, 0
   +- Exchange hashpartitioning(batchDate#16, c2#15, 1)
      +- Window [collect_list(pv#18, 0, 0) windowspecdefinition(batchDate#16, c1#14, ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS c1List#28], [batchDate#16, c1#14]
         +- *Sort [batchDate#16 ASC NULLS FIRST, c1#14 ASC NULLS FIRST], false, 0
            +- Exchange hashpartitioning(batchDate#16, c1#14, 1)
               +- *Project [_1#7 AS c1#14, _2#8 AS c2#15, _3#9 AS batchDate#16, _4#10 AS id#17, _5#11 AS pv#18, _6#12 AS vv#19]
                  +- *SerializeFromObject [staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(input[0, scala.Tuple6, true])._1, true) AS _1#7, staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(input[0, scala.Tuple6, true])._2, true) AS _2#8, assertnotnull(input[0, scala.Tuple6, true])._3 AS _3#9, assertnotnull(input[0, scala.Tuple6, true])._4 AS _4#10, assertnotnull(input[0, scala.Tuple6, true])._5 AS _5#11, assertnotnull(input[0, scala.Tuple6, true])._6 AS _6#12]
                     +- Scan ExternalRDDScan[obj#6]

agg2df.explain()
== Physical Plan ==
Window [collect_list(pv#18, 0, 0) windowspecdefinition(batchDate#16, c2#15, ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS c2List#60], [batchDate#16, c2#15]
+- *Sort [batchDate#16 ASC NULLS FIRST, c2#15 ASC NULLS FIRST], false, 0
   +- Window [collect_list(pv#18, 0, 0) windowspecdefinition(batchDate#16, c1#14, ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS c1List#50], [batchDate#16, c1#14]
      +- *Sort [batchDate#16 ASC NULLS FIRST, c1#14 ASC NULLS FIRST], false, 0
         +- Exchange hashpartitioning(batchDate#16, 1)
            +- *Project [_1#7 AS c1#14, _2#8 AS c2#15, _3#9 AS batchDate#16, _4#10 AS id#17, _5#11 AS pv#18, _6#12 AS vv#19]
               +- *SerializeFromObject [staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(input[0, scala.Tuple6, true])._1, true) AS _1#7, staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(input[0, scala.Tuple6, true])._2, true) AS _2#8, assertnotnull(input[0, scala.Tuple6, true])._3 AS _3#9, assertnotnull(input[0, scala.Tuple6, true])._4 AS _4#10, assertnotnull(input[0, scala.Tuple6, true])._5 AS _5#11, assertnotnull(input[0, scala.Tuple6, true])._6 AS _6#12]
                  +- Scan ExternalRDDScan[obj#6]
answered Sep 29 '22 12:09 by Apurba Pandey