 

Why does Spark not push down a filter to before a groupBy with collect_list?


Consider this example:

import pyspark.sql.functions as f
from pyspark.sql import SparkSession

# Local session using all cores (SparkSession replaces the deprecated SQLContext).
spark = SparkSession.builder.master('local[*]').getOrCreate()

df = spark.createDataFrame([
    [2020, 1, 1, 1.0],
    [2020, 1, 2, 2.0],
    [2020, 1, 3, 3.0],
], schema=['year', 'id', 't', 'value'])

# Aggregate first, then filter on a grouping column.
df = df.groupBy(['year', 'id']).agg(f.collect_list('value'))
df = df.where(f.col('year') == 2020)
df.explain()

spark.stop()

which yields the following plan:

== Physical Plan ==
*(2) Filter (isnotnull(year#0L) AND (year#0L = 2020))
+- ObjectHashAggregate(keys=[year#0L, id#1L], functions=[collect_list(value#3, 0, 0)])
   +- Exchange hashpartitioning(year#0L, id#1L, 200), true, [id=#23]
      +- ObjectHashAggregate(keys=[year#0L, id#1L], functions=[partial_collect_list(value#3, 0, 0)])
         +- *(1) Project [year#0L, id#1L, value#3]
            +- *(1) Scan ExistingRDD[year#0L,id#1L,t#2L,value#3]

I would like Spark to push the filter year = 2020 below the hash-partitioning exchange. If the aggregation function is sum, Spark does this, but it does not do it for collect_list.
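A small pure-Python model (not Spark code; the aggregate helper and the sample rows are hypothetical) shows why the requested pushdown is semantically safe here: because year is a grouping key, filtering rows before grouping yields the same groups as filtering the groups afterwards, for both sum and a list-collecting aggregate. In this single-threaded model row order is preserved, so even the lists match; in Spark, list order after a shuffle is not guaranteed either way.

```python
from collections import defaultdict

def aggregate(rows, agg):
    """Group (year, id, value) rows by (year, id) and apply `agg` to each group's values."""
    groups = defaultdict(list)
    for year, id_, value in rows:
        groups[(year, id_)].append(value)
    return {key: agg(values) for key, values in groups.items()}

# Hypothetical sample rows: (year, id, value)
rows = [(2019, 1, 5.0), (2020, 1, 1.0), (2020, 1, 2.0), (2020, 2, 3.0)]

for agg in (sum, list):
    # Plan as written: aggregate all rows, then keep only the 2020 groups.
    after = {key: v for key, v in aggregate(rows, agg).items() if key[0] == 2020}
    # Plan with the filter pushed down: drop non-2020 rows, then aggregate.
    before = aggregate([r for r in rows if r[0] == 2020], agg)
    print(agg.__name__, after == before)  # prints "sum True", then "list True"
```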

Any ideas as to why this is not the case, and whether there is a way to address this?

The reason is that, without filter pushdown, a query over three years (e.g. year IN (2020, 2019, 2018)) performs a shuffle across all of them. Also, in my code the filter has to be expressed after the groupBy.

More importantly, I am trying to understand why Spark does not push the filter down for some aggregations, but it does for others.

Jorge Leitao asked Jul 04 '20 06:07


1 Answer

Let's have a look at the aggregate function you are using, collect_list. Its documentation says:

  /**
   * Aggregate function: returns a list of objects with duplicates.
   *
   * @note The function is non-deterministic because the order of collected results depends
   * on the order of the rows which may be non-deterministic after a shuffle.
   *
   * @group agg_funcs
   * @since 1.6.0
   */
  def collect_list(columnName: String): Column = collect_list(Column(columnName))

So collect_list is treated as non-deterministic: the order of the elements in its result depends on the order in which rows arrive.
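To make "depends on the order of rows" concrete, here is a pure-Python model (not Spark; the helper names and the reversed list standing in for the arbitrary order a shuffle may produce are assumptions for illustration). Collecting values into a list gives different results for two orderings of the same rows, while summing gives the same result.

```python
from collections import defaultdict

def group_collect_list(rows):
    """Group (key, value) rows and collect the values per key in arrival order."""
    out = defaultdict(list)
    for key, value in rows:
        out[key].append(value)
    return dict(out)

def group_sum(rows):
    """Group (key, value) rows and sum the values per key."""
    out = defaultdict(float)
    for key, value in rows:
        out[key] += value
    return dict(out)

rows = [((2020, 1), 1.0), ((2020, 1), 2.0), ((2020, 1), 3.0)]
shuffled = list(reversed(rows))  # stand-in for a different arrival order after a shuffle

print(group_collect_list(rows))      # {(2020, 1): [1.0, 2.0, 3.0]}
print(group_collect_list(shuffled))  # {(2020, 1): [3.0, 2.0, 1.0]}
print(group_sum(rows) == group_sum(shuffled))  # True
```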

Now look at the PushPredicateThroughNonJoin rule in Optimizer.scala:

    // SPARK-13473: We can't push the predicate down when the underlying projection output non-
    // deterministic field(s).  Non-deterministic expressions are essentially stateful. This
    // implies that, for a given input row, the output are determined by the expression's initial
    // state and all the input rows processed before. In another word, the order of input rows
    // matters for non-deterministic expressions, while pushing down predicates changes the order.
    // This also applies to Aggregate.

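Because collect_list is non-deterministic (its result depends on the order of the rows of the underlying DataFrame), and, per the comment above, pushing a predicate down changes that order, Spark does not push the filter below the aggregation. sum, by contrast, is deterministic, so the rule applies and the filter is pushed down.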
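The decision can be sketched as a toy predicate in Python. This is a simplified model of the Aggregate case of PushPredicateThroughNonJoin, not Spark's code: AggExpr and can_push_filter_below_aggregate are hypothetical names, and the model ignores alias handling and how Spark splits AND-ed predicates. It assumes a predicate is pushable only when every aggregate expression is deterministic and the predicate reads nothing but grouping columns.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class AggExpr:
    """Hypothetical stand-in for an aggregate expression in the logical plan."""
    name: str
    deterministic: bool

def can_push_filter_below_aggregate(predicate_columns, grouping_columns, agg_exprs):
    """Simplified model: may a filter on `predicate_columns` move below the Aggregate?"""
    # Only a grouped Aggregate whose expressions are all deterministic qualifies.
    if not grouping_columns or not all(e.deterministic for e in agg_exprs):
        return False
    # The predicate must reference only grouping columns (and at least one column).
    return bool(predicate_columns) and set(predicate_columns) <= set(grouping_columns)

grouping = ["year", "id"]
print(can_push_filter_below_aggregate(["year"], grouping, [AggExpr("sum(value)", True)]))            # True
print(can_push_filter_below_aggregate(["year"], grouping, [AggExpr("collect_list(value)", False)]))  # False
```

The single non-deterministic flag on collect_list is enough to block the pushdown, even though the filter itself only touches the grouping key year.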

Som answered Sep 20 '22 11:09