Consider this example:
import pyspark
import pyspark.sql.functions as f

with pyspark.SparkContext(conf=pyspark.SparkConf().setMaster('local[*]')) as sc:
    spark = pyspark.sql.SQLContext(sc)
    df = spark.createDataFrame([
        [2020, 1, 1, 1.0],
        [2020, 1, 2, 2.0],
        [2020, 1, 3, 3.0],
    ], schema=['year', 'id', 't', 'value'])
    df = df.groupBy(['year', 'id']).agg(f.collect_list('value'))
    df = df.where(f.col('year') == 2020)
    df.explain()
which yields the following plan:
== Physical Plan ==
*(2) Filter (isnotnull(year#0L) AND (year#0L = 2020))
+- ObjectHashAggregate(keys=[year#0L, id#1L], functions=[collect_list(value#3, 0, 0)])
   +- Exchange hashpartitioning(year#0L, id#1L, 200), true, [id=#23]
      +- ObjectHashAggregate(keys=[year#0L, id#1L], functions=[partial_collect_list(value#3, 0, 0)])
         +- *(1) Project [year#0L, id#1L, value#3]
            +- *(1) Scan ExistingRDD[year#0L,id#1L,t#2L,value#3]
I would like Spark to push the filter year = 2020 to before the hashpartitioning. If the aggregation function is sum, Spark does it, but it does not do it for collect_list.
Any ideas as to why this is not the case, and whether there is a way to address this?
The reason for doing this is that, without filter pushdown, a statement covering 3 years (e.g. year IN (2020, 2019, 2018)) performs a shuffle between them. Also, I need to express the filter after the groupBy in code.
More importantly, I am trying to understand why Spark does not push the filter down for some aggregations, but it does for others.
Let's have a look at the aggregate function that you are using, collect_list. From its doc comment:
/**
* Aggregate function: returns a list of objects with duplicates.
*
* @note The function is non-deterministic because the order of collected results depends
* on the order of the rows which may be non-deterministic after a shuffle.
*
* @group agg_funcs
* @since 1.6.0
*/
def collect_list(columnName: String): Column = collect_list(Column(columnName))
collect_list is a non-deterministic operation and its result depends on the order of rows.
Now look at PushPredicateThroughNonJoin in Optimizer.scala:
// SPARK-13473: We can't push the predicate down when the underlying projection output non-
// deterministic field(s). Non-deterministic expressions are essentially stateful. This
// implies that, for a given input row, the output are determined by the expression's initial
// state and all the input rows processed before. In another word, the order of input rows
// matters for non-deterministic expressions, while pushing down predicates changes the order.
// This also applies to Aggregate.
Since collect_list is non-deterministic, i.e. its result depends on the order of rows in the underlying DataFrame, Spark cannot push the predicate below the aggregate, because doing so changes the order of rows.
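To make the optimizer comment more concrete, here is a minimal plain-Python sketch (not Spark code) of a stateful expression whose output changes when a filter is moved below it. The helper names project_with_row_number and keep_2020 and the sample rows are made up for illustration; the counter stands in for any expression whose result depends on "all the input rows processed before".

```python
from itertools import count

# Toy input: one 2019 row followed by two 2020 rows.
rows = [
    {"year": 2019, "value": 1.0},
    {"year": 2020, "value": 2.0},
    {"year": 2020, "value": 3.0},
]


def project_with_row_number(rows):
    """Hypothetical stateful expression: tag each row with the number of rows seen before it."""
    counter = count()
    return [{**row, "n": next(counter)} for row in rows]


def keep_2020(rows):
    """The predicate from the question: year == 2020."""
    return [row for row in rows if row["year"] == 2020]


# Plan as written: evaluate the stateful expression, then filter.
filter_after = keep_2020(project_with_row_number(rows))

# Plan with the predicate pushed down: filter, then evaluate the expression.
filter_before = project_with_row_number(keep_2020(rows))

print([r["n"] for r in filter_after])   # [1, 2]
print([r["n"] for r in filter_before])  # [0, 1]
```

The two plans disagree, so the rewrite is not safe for this kind of expression. Spark does not inspect what collect_list actually does with its input; as far as the comment above indicates, the non-deterministic flag alone is enough for the rule to leave the filter where it is.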