
Avoid performance impact of a single partition mode in Spark window functions


My question is triggered by the use case of calculating the differences between consecutive rows in a Spark DataFrame.

For example, I have:

>>> df.show()
+-----+----------+
|index|      col1|
+-----+----------+
|  0.0|0.58734024|
|  1.0|0.67304325|
|  2.0|0.85154736|
|  3.0| 0.5449719|
+-----+----------+
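(For anyone wanting to reproduce this, a DataFrame with these values can be built roughly as follows; this construction sketch is mine, not part of the original question.)

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# Rebuild the sample data shown above (values copied from the printed output)
df = spark.createDataFrame(
    [(0.0, 0.58734024), (1.0, 0.67304325), (2.0, 0.85154736), (3.0, 0.5449719)],
    ["index", "col1"],
)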

If I choose to calculate these using "Window" functions, then I can do that like so:

>>> winSpec = Window.partitionBy(df.index >= 0).orderBy(df.index.asc())
>>> import pyspark.sql.functions as f
>>> df.withColumn('diffs_col1', f.lag(df.col1, -1).over(winSpec) - df.col1).show()
+-----+----------+-----------+
|index|      col1| diffs_col1|
+-----+----------+-----------+
|  0.0|0.58734024|0.085703015|
|  1.0|0.67304325| 0.17850411|
|  2.0|0.85154736|-0.30657548|
|  3.0| 0.5449719|       null|
+-----+----------+-----------+

Question: I have explicitly partitioned the DataFrame into a single partition. What is the performance impact of this and, if there is one, why is that so and how could I avoid it? I ask because when I do not specify a partition, I get the following warning:

16/12/24 13:52:27 WARN WindowExec: No Partition Defined for Window operation! Moving all data to a single partition, this can cause serious performance degradation. 
asked Dec 24 '16 by Ytsen de Boer

1 Answer

In practice the performance impact will be almost the same as if you omitted the partitionBy clause altogether. All records will be shuffled to a single partition, sorted locally, and iterated over sequentially, one by one.
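One way to see this (my own check, not part of the original answer) is to look at the physical plan: for a window with no partitionBy, it should contain an exchange that moves everything to a single partition, typically shown as Exchange SinglePartition, feeding the Window operator.

import pyspark.sql.functions as f
from pyspark.sql import Window

w = Window.orderBy("index")

# The plan for an unpartitioned window should include something like
# "Exchange SinglePartition" below the Window node (the exact text
# varies between Spark versions).
df.withColumn("diffs_col1", f.lag("col1", 1).over(w) - f.col("col1")).explain()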

The difference is only in the number of partitions created in total. Let's illustrate that with an example using a simple dataset with 10 partitions and 1000 records:

df = spark.range(0, 1000, 1, 10).toDF("index").withColumn("col1", f.randn(42)) 
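(The last argument to spark.range is the number of partitions; as a quick sanity check of my own, you can confirm the input layout before any shuffle happens.)

# The data still sits in the 10 input partitions at this point
df.rdd.getNumPartitions()
# 10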

If you define a frame without a partitionBy clause

w_unpart = Window.orderBy(f.col("index").asc()) 

and use it with lag

df_lag_unpart = df.withColumn(
    "diffs_col1", f.lag("col1", 1).over(w_unpart) - f.col("col1")
)

there will be only one partition in total:

df_lag_unpart.rdd.glom().map(len).collect() 
[1000] 

Compare that to a frame definition with a dummy partitioning key (simplified a bit compared to your code):

w_part = Window.partitionBy(f.lit(0)).orderBy(f.col("index").asc()) 

which will use a number of partitions equal to spark.sql.shuffle.partitions:

spark.conf.set("spark.sql.shuffle.partitions", 11)

df_lag_part = df.withColumn(
    "diffs_col1", f.lag("col1", 1).over(w_part) - f.col("col1")
)

df_lag_part.rdd.glom().count()
11 

with only one non-empty partition:

df_lag_part.rdd.glom().filter(lambda x: x).count() 
1 
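If you map partition sizes the same way as before (again my own check, not part of the original answer), all 1000 rows should land in that one partition while the remaining ten stay empty; which slot ends up non-empty depends on the hash of the constant key.

# All rows share the same dummy key, so they hash to a single shuffle
# partition; the other spark.sql.shuffle.partitions - 1 partitions are empty.
df_lag_part.rdd.glom().map(len).collect()
# e.g. [0, 0, 0, 0, 0, 0, 1000, 0, 0, 0, 0]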

Unfortunately, there is no universal solution that can be used to address this problem in PySpark. It is simply an inherent consequence of the implementation combined with the distributed processing model.

Since the index column is sequential, you could generate an artificial partitioning key with a fixed number of records per block:

rec_per_block = df.count() // int(spark.conf.get("spark.sql.shuffle.partitions"))

df_with_block = df.withColumn(
    "block", (f.col("index") / rec_per_block).cast("int")
)
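(As a quick check of my own, you can count records per block key before applying the window; with 1000 records and rec_per_block = 90 this should give 12 blocks: eleven with 90 rows each and a final one with the remaining 10.)

df_with_block.groupBy("block").count().orderBy("block").show()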

You can then use this block key to define the frame specification:

w_with_block = Window.partitionBy("block").orderBy("index")

df_lag_with_block = df_with_block.withColumn(
    "diffs_col1", f.lag("col1", 1).over(w_with_block) - f.col("col1")
)

This will use the expected number of partitions:

df_lag_with_block.rdd.glom().count() 
11 

with roughly uniform data distribution (we cannot avoid hash collisions):

df_lag_with_block.rdd.glom().map(len).collect() 
[0, 180, 0, 90, 90, 0, 90, 90, 100, 90, 270] 
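The skew comes from how the shuffle assigns keys to partitions: each block is placed at pmod(hash(block), numShufflePartitions), so distinct block values can collide in the same shuffle partition. The sketch below approximates that mapping; it assumes the exchange uses the same Murmur3 hash that the built-in hash SQL function exposes.

# Approximate which shuffle partition each block key lands in, assuming
# hash partitioning computes pmod(murmur3_hash(key), numPartitions).
(df_with_block
    .select("block", f.expr("pmod(hash(block), 11)").alias("shuffle_partition"))
    .distinct()
    .orderBy("block")
    .show())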

However, there are a number of gaps at the block boundaries:

df_lag_with_block.where(f.col("diffs_col1").isNull()).count() 
12 

Since the boundaries are easy to compute:

from itertools import chain

boundary_idxs = sorted(chain.from_iterable(
    # Here we depend on sequential identifiers.
    # This could be generalized to any monotonically increasing
    # id by taking min and max per block.
    (idx - 1, idx) for idx in
    df_lag_with_block.groupBy("block").min("index")
        .drop("block").rdd.flatMap(lambda x: x)
        .collect()))[2:]  # The first boundary doesn't carry useful information.

you can always select:

missing = df_with_block.where(f.col("index").isin(boundary_idxs)) 
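With the setup above there are 11 interior block boundaries, so missing should hold just 22 rows, which is why the unpartitioned window in the next step is cheap (this count is my own observation, not from the original answer):

missing.count()
# 22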

These can then be filled separately:

# We use a window without partitions here. Since the number of records
# will be small this won't be a performance issue, but it will generate
# the "Moving all data to a single partition" warning.
missing_with_lag = missing.withColumn(
    "diffs_col1", f.lag("col1", 1).over(w_unpart) - f.col("col1")
).select("index", f.col("diffs_col1").alias("diffs_fill"))

and joined back:

combined = (df_lag_with_block
    .join(missing_with_lag, ["index"], "leftouter")
    .withColumn("diffs_col1", f.coalesce("diffs_col1", "diffs_fill")))

to get the desired result:

mismatched = combined.join(df_lag_unpart, ["index"], "outer").where(
    combined["diffs_col1"] != df_lag_unpart["diffs_col1"]
)
assert mismatched.count() == 0
answered Sep 29 '22 by zero323