
Spark: Prevent shuffle/exchange when joining two identically partitioned dataframes

I have two dataframes df1 and df2 and I want to join these tables many times on a high-cardinality field called visitor_id. I would like to perform only one initial shuffle and have all the joins take place without shuffling/exchanging data between Spark executors.

To do so, I have created another column called visitor_partition that consistently assigns each visitor_id a random value between [0, 1000). I have used a custom partitioner to ensure that df1 and df2 are partitioned such that each partition contains rows from exactly one value of visitor_partition. This initial repartition is the only time I want to shuffle the data.
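Roughly, that setup looks something like this (simplified; the hash-based assignment shown here is only illustrative, the actual code uses a custom partitioner):

from pyspark.sql import functions as F

# Derive a stable partition value in [0, 1000) from visitor_id
# (assumed hash-based here for illustration).
df1 = df1.withColumn("visitor_partition", F.expr("pmod(hash(visitor_id), 1000)"))
df2 = df2.withColumn("visitor_partition", F.expr("pmod(hash(visitor_id), 1000)"))

# Shuffle once so each partition holds exactly one visitor_partition value.
df1 = df1.repartition(1000, "visitor_partition")
df2 = df2.repartition(1000, "visitor_partition")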

I have saved each dataframe to parquet in s3, partitioning by visitor_partition -- for each dataframe, this creates 1000 files organized in df1/visitor_partition=0, df1/visitor_partition=1, ..., df1/visitor_partition=999.
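The write step looks roughly like this (the s3 paths are placeholders):

# Write each dataframe partitioned by visitor_partition, producing
# df1/visitor_partition=0 ... df1/visitor_partition=999 under the target path.
df1.write.partitionBy("visitor_partition").parquet("s3://my-bucket/df1")
df2.write.partitionBy("visitor_partition").parquet("s3://my-bucket/df2")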

Now I load each dataframe from the parquet and register them as temp views via df1.createOrReplaceTempView('df1') (and the same for df2), and then run the following query:

SELECT
   ...
FROM
  df1 FULL JOIN df2 ON
    df1.visitor_partition = df2.visitor_partition AND
    df1.visitor_id = df2.visitor_id
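
For completeness, the load-and-register step above looks roughly like this (paths are placeholders):

df1 = spark.read.parquet("s3://my-bucket/df1")
df2 = spark.read.parquet("s3://my-bucket/df2")
df1.createOrReplaceTempView('df1')
df2.createOrReplaceTempView('df2')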

In theory, the query execution planner should realize that no shuffling is necessary here. E.g., a single executor could load in the data from df1/visitor_partition=1 and df2/visitor_partition=1 and join the rows there. However, in practice Spark 2.4.4's query planner performs a full data shuffle here.

Is there some way I can prevent this shuffle from taking place?

asked Nov 25 '19 by conradlee

1 Answer

You can use the bucketBy method of the DataFrameWriter.

In the following example, the value of the column VisitorID will be hashed into 500 buckets. Normally, for the join Spark would perform an exchange phase based on the hash of VisitorID; in this case, however, the data is already pre-partitioned by that hash.

# Assumes an active SparkSession `spark` and SparkContext `sc`.
from pyspark.sql.functions import col
from pyspark.sql.types import StructType, StructField, IntegerType

# Build a sample dataframe with a VisitorID column.
inputRdd = sc.parallelize(list((i, i % 200) for i in range(0, 1000000)))

schema = StructType([StructField("VisitorID", IntegerType(), True),
                     StructField("visitor_partition", IntegerType(), True)])

inputdf = inputRdd.toDF(schema)

# Write the data hash-bucketed into 500 buckets on VisitorID.
# bucketBy requires saveAsTable (a metastore table), not a plain save().
inputdf.write.bucketBy(500, "VisitorID").saveAsTable("bucketed_table")

# Read the bucketed table twice and join on VisitorID.
inputDf1 = spark.sql("select * from bucketed_table")
inputDf2 = spark.sql("select * from bucketed_table")
inputDf3 = inputDf1.alias("df1").join(inputDf2.alias("df2"),
                                      col("df1.VisitorID") == col("df2.VisitorID"))

Sometimes the Spark query optimizer will still choose a broadcast exchange, so for this example let's disable automatic broadcasting:

spark.conf.set("spark.sql.autoBroadcastJoinThreshold", -1)
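
You can then check the result yourself with the standard explain() method:

# Print the physical plan to verify there is no Exchange step for the join
inputDf3.explain()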

The physical plan would look as follows:

== Physical Plan ==
*(3) SortMergeJoin [VisitorID#351], [VisitorID#357], Inner
:- *(1) Sort [VisitorID#351 ASC NULLS FIRST], false, 0
:  +- *(1) Project [VisitorID#351, visitor_partition#352]
:     +- *(1) Filter isnotnull(VisitorID#351)
:        +- *(1) FileScan parquet default.bucketed_6[VisitorID#351,visitor_partition#352] Batched: true, DataFilters: [isnotnull(VisitorID#351)], Format: Parquet, Location: InMemoryFileIndex[dbfs:/user/hive/warehouse/bucketed_6], PartitionFilters: [], PushedFilters: [IsNotNull(VisitorID)], ReadSchema: struct<VisitorID:int,visitor_partition:int>, SelectedBucketsCount: 500 out of 500
+- *(2) Sort [VisitorID#357 ASC NULLS FIRST], false, 0
   +- *(2) Project [VisitorID#357, visitor_partition#358]
      +- *(2) Filter isnotnull(VisitorID#357)
         +- *(2) FileScan parquet default.bucketed_6[VisitorID#357,visitor_partition#358] Batched: true, DataFilters: [isnotnull(VisitorID#357)], Format: Parquet, Location: InMemoryFileIndex[dbfs:/user/hive/warehouse/bucketed_6], PartitionFilters: [], PushedFilters: [IsNotNull(VisitorID)], ReadSchema: struct<VisitorID:int,visitor_partition:int>, SelectedBucketsCount: 500 out of 500

Doing something like:

inputdf.write.partitionBy("visitor_partition").saveAsTable("partitionBy_2")

does indeed create the directory structure with a folder for each partition. But it does not help here: the Spark join is based on the hash of the join key and cannot leverage your custom directory layout.
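
If you want to see the difference, join that directory-partitioned table with itself and look at the plan; unlike the bucketed case, it still contains an Exchange hashpartitioning step (this is an illustrative check, reusing the table written above):

# Read the directory-partitioned table back and self-join on VisitorID.
p1 = spark.table("partitionBy_2").alias("p1")
p2 = spark.table("partitionBy_2").alias("p2")

# The plan still shows Exchange hashpartitioning(VisitorID, ...),
# because directory partitions carry no hash-bucket metadata.
p1.join(p2, col("p1.VisitorID") == col("p2.VisitorID")).explain()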

Edit: I misunderstood your example. I believe you were talking about something like partitionBy, not repartition as mentioned in the previous version.

answered Sep 19 '22 by LizardKing