
PySpark data skewness with Window Functions

I have a huge PySpark dataframe and I'm doing a series of Window functions over partitions defined by my key.

The issue is that my partitions get skewed by this key, which results in an Event Timeline that looks something like this:

[screenshot: Spark UI Event Timeline showing the skewed tasks]

I know that I can use the salting technique to solve this issue when I'm doing a join. But how can I solve it when I'm using Window functions?
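(For reference, this is roughly what I mean by salting a join; big_df, small_df, key and spark are just illustrative names here, not my actual code.)

import pyspark.sql.functions as F

n_salts = 10

# spread the hot key: add a random salt on the big side, replicate the small
# side once per salt value, then join on (key, salt) instead of key alone
big = big_df.withColumn('salt', (F.rand() * n_salts).cast('int'))
small = small_df.crossJoin(
    spark.range(n_salts).withColumnRenamed('id', 'salt')
)
joined = big.join(small, on=['key', 'salt']).drop('salt')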

I'm using functions like lag, lead, etc. in the Window functions. I can't do the processing with a salted key, because I'd get wrong results.

How can I solve the skewness in this case?

I'm looking for a dynamic way of repartitioning my dataframe without skewness.

Update based on the answer from @jxc

I tried creating a sample df and running the code over it:

import numpy as np
import pandas as pd
import pyspark.sql.functions as F
from pyspark.sql import Window

# build a skewed sample: every even id is mapped to 100, so key 100 dominates
df = pd.DataFrame()
df['id'] = np.random.randint(1, 1000, size=150000)
df['id'] = df['id'].map(lambda x: 100 if x % 2 == 0 else x)
df['timestamp'] = pd.date_range(start=pd.Timestamp('2020-01-01'), periods=len(df), freq='60s')

sdf = sc.createDataFrame(df)   # `sc` is assumed to be the SparkSession (createDataFrame lives there)
sdf = sdf.withColumn("amt", F.rand()*100)
w = Window.partitionBy("id").orderBy("timestamp")

sdf = sdf.withColumn("new_col", F.lag("amt").over(w) + F.lead("amt").over(w))
x = sdf.toPandas()

This gave me an event timeline like this:

[screenshot: Event Timeline for the sample run]

I then tried the code from @jxc's answer:

sdf = sc.createDataFrame(df)
sdf = sdf.withColumn("amt", F.rand()*100)

# pid buckets the timestamp into N-second intervals (N = 2 years here)
N = 24*3600*365*2
sdf_1 = sdf.withColumn('pid', F.ceil(F.unix_timestamp('timestamp')/N))

w1 = Window.partitionBy('id', 'pid').orderBy('timestamp')
w2 = Window.partitionBy('id', 'pid')

sdf_2 = sdf_1.select(
    '*',
    F.count('*').over(w2).alias('cnt'),
    F.row_number().over(w1).alias('rn'),
    (F.lag('amt',1).over(w1) + F.lead('amt',1).over(w1)).alias('new_val')
)

# fix the boundary rows using the original per-id window `w` defined above
sdf_3 = sdf_2.filter('rn in (1, 2, cnt-1, cnt)') \
    .withColumn('new_val', F.lag('amt',1).over(w) + F.lead('amt',1).over(w)) \
    .filter('rn in (1,cnt)')

df_new = sdf_2.filter('rn not in (1,cnt)').union(sdf_3)

x = df_new.toPandas()

I ended up with one additional stage, and the event timeline looked even more skewed:

[screenshot: Event Timeline for the modified code]

Also, the runtime increased a bit with the new code.

asked Nov 23 '20 by Sreeram TP



1 Answer

To process a large partition, you can try splitting it based on the orderBy column (most likely a numeric column, or a date/timestamp column that can be converted to numeric) so that all new sub-partitions maintain the correct order of rows. Process the rows with the new partitioner; for calculations using the lag and lead functions, only rows around the boundaries between sub-partitions need to be post-processed. (Task-2 below also discusses how to merge small partitions.)

Using your example sdf, assume we have the following WindowSpec and a simple calculation:

w = Window.partitionBy('id').orderBy('timestamp')
sdf.withColumn('new_amt', F.lag('amt',1).over(w) + F.lead('amt',1).over(w))

Task-1: split large partitions:

Try the following:

  1. select an N to split the timestamp and set up an additional partitionBy column pid (using ceil, int, floor, etc.):

    # N to cover 35-day intervals
    N = 24*3600*35
    df1 = sdf.withColumn('pid', F.ceil(F.unix_timestamp('timestamp')/N))
    
  2. add pid into partitionBy (see w1), then calculate row_number(), lag() and lead() over w1. Also find the number of rows (cnt) in each new partition to help identify the end of partitions (rn == cnt). The resulting new_amt will be fine for the majority of rows, except those on the boundaries of each partition.

    w1 = Window.partitionBy('id', 'pid').orderBy('timestamp')
    w2 = Window.partitionBy('id', 'pid')
    
    df2 = df1.select(
        '*',
        F.count('*').over(w2).alias('cnt'),
        F.row_number().over(w1).alias('rn'),
        (F.lag('amt',1).over(w1) + F.lead('amt',1).over(w1)).alias('new_amt')
    )
    

    Below is an example df2 showing the boundary rows.

    [screenshot: example df2 showing the boundary rows]

  3. process the boundary rows: select the rows which are on the boundaries (rn in (1, cnt)) plus those whose values are used in the calculation (rn in (2, cnt-1)), do the same calculation of new_amt over w, and save the result for the boundary rows only.

    df3 = df2.filter('rn in (1, 2, cnt-1, cnt)') \
        .withColumn('new_amt', F.lag('amt',1).over(w) + F.lead('amt',1).over(w)) \
        .filter('rn in (1,cnt)')
    

    Below is the resulting df3 from the above df2:

    [screenshot: the resulting df3]

  4. merge df3 back into df2 to update the boundary rows (rn in (1, cnt)):

    df_new = df2.filter('rn not in (1,cnt)').union(df3)
    

    The screenshot below shows the final df_new around the boundary rows:

    [screenshot: the final df_new around the boundary rows]

    # drop columns which are used to implement logic only
    df_new = df_new.drop('cnt', 'rn')
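
As a quick sanity check (a minimal sketch, assuming timestamps are unique within each id as in the sample data), the split-and-patch output should match the plain single-window calculation row for row:

# optional check: df_new (minus the helper pid column) vs. the straightforward window
sdf.cache()   # keeps the F.rand()-generated `amt` column stable across both plans
expected = sdf.withColumn('new_amt', F.lag('amt',1).over(w) + F.lead('amt',1).over(w))
print(df_new.drop('pid').exceptAll(expected).count())   # expect 0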
    

Some Notes:

  1. the following 3 WindowSpecs are defined:

    w = Window.partitionBy('id').orderBy('timestamp')          <-- fix boundary rows
    w1 = Window.partitionBy('id', 'pid').orderBy('timestamp')  <-- calculate internal rows
    w2 = Window.partitionBy('id', 'pid')                       <-- find #rows in a partition
    

    note: strictly speaking, we'd better use the following w to fix the boundary rows, to avoid issues with tied timestamps around the boundaries.

    w = Window.partitionBy('id').orderBy('pid', 'rn')          <-- fix boundary rows
    
  2. if you know which partitions are skewed, just divide those and skip the others. The existing method might split a small partition into 2 or even more pieces if its rows are sparsely distributed:

    df1 = df.withColumn('pid', F.when(F.col('id').isin('a','b'), F.ceil(F.unix_timestamp('timestamp')/N)).otherwise(1))
    

    If, for each partition, you can retrieve count (the number of rows) and min_ts = min(timestamp), then try something more dynamic for pid (below, M is the threshold number of rows above which to split); a fuller end-to-end sketch is shown after this list:

    F.expr(f"IF(count>{M}, ceil((unix_timestamp(timestamp)-unix_timestamp(min_ts))/{N}), 1)")
    

    note: skewness inside a single partition will require more complex functions to generate pid.

  3. if only the lag(1) function is used, just post-process the left boundaries: filter by rn in (1, cnt) and update only rn == 1:

    df3 = df2.filter('rn in (1, cnt)') \
        .withColumn('new_amt', F.lag('amt',1).over(w)) \
        .filter('rn = 1')
    

    similarly for the lead function, when we only need to fix the right boundaries and update rn == cnt.

  4. if only lag(2) is used, then filter and update more rows with df3:

    df3 = df2.filter('rn in (1, 2, cnt-1, cnt)') \
        .withColumn('new_amt', F.lag('amt',2).over(w)) \
        .filter('rn in (1,2)')
    

    You can extend the same method to mixed cases with both lag and lead having different offsets.
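
Picking up note 2, one way to wire the dynamic pid up end-to-end could look like the minimal sketch below, where M and N are assumed thresholds you would tune for your data (the per-id count and min_ts are computed once and joined back):

M = 50000            # assumed row-count threshold above which an id gets split
N = 24*3600*35       # 35-day buckets, as in Task-1

# per-id row count and minimum timestamp, joined back onto the data
stats = sdf.groupBy('id').agg(
    F.count('*').alias('count'),
    F.min('timestamp').alias('min_ts')
)

df1 = sdf.join(stats, 'id') \
    .withColumn('pid', F.expr(
        f"IF(count > {M}, ceil((unix_timestamp(timestamp) - unix_timestamp(min_ts))/{N}), 1)")) \
    .drop('count', 'min_ts')

From here, df1 feeds into the same w1/w2 windows and boundary post-processing as in Task-1.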

Task-2: merge small partitions:

Based on the number of records in a partition, count, we can set up a threshold M so that if count > M, the id holds its own partition; otherwise we merge partitions so that the total number of records per pid is less than M (the method below has an edge case of 2*M-2).

M = 20000

# create pandas df with columns `id`, `count` and `f`, sort rows so that rows with count>=M are located on top
d2 = pd.DataFrame([ e.asDict() for e in sdf.groupby('id').count().collect() ]) \
    .assign(f=lambda x: x['count'].lt(M)) \
    .sort_values('f')    

# add a pid column to merge smaller partitions so that the total row count per pid is less than or around M.
# potentially there could be up to `2*M-2` records for the same pid; to make sure strictly count < M,
# use a for-loop to iterate over d2 and set pid (see the sketch after the output below):
d2['pid'] = (d2.mask(d2['count'].gt(M),M)['count'].shift(fill_value=0).cumsum()/M).astype(int)

# add pid to sdf. In case the join is too heavy, try using a map instead
sdf_1 = sdf.join(spark.createDataFrame(d2).alias('d2'), ["id"]) \
    .select(sdf["*"], F.col("d2.pid"))

# check pid: # of records and # of distinct ids
sdf_1.groupby('pid').agg(F.count('*').alias('count'), F.countDistinct('id').alias('cnt_ids')).orderBy('pid').show()
+---+-----+-------+                                                             
|pid|count|cnt_ids|
+---+-----+-------+
|  0|74837|      1|
|  1|20036|    133|
|  2|20052|    134|
|  3|20010|    133|
|  4|15065|    100|
+---+-----+-------+
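
As mentioned in the code comment above, the cumsum-based pid can leave a merged pid with close to 2*M-2 rows; a minimal sketch of the for-loop alternative (illustrative, using the same d2 and M) that keeps every merged pid strictly below M:

# greedy assignment over the sorted pandas frame d2: ids with count >= M keep
# their own pid; smaller ids are packed together until adding the next one
# would reach M, at which point a new pid is started
pid, running = 0, 0
pids = []
for cnt in d2['count']:
    if cnt >= M or running + cnt >= M:
        pid += 1
        running = 0
    pids.append(pid)
    running += cnt
d2['pid'] = pids

The join back onto sdf and the rest of Task-2 stay the same.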

Now, the new Window should be partitioned by pid alone, with id moved into orderBy; see below:

w3 = Window.partitionBy('pid').orderBy('id','timestamp')

Customize the lag/lead functions based on the above w3 WindowSpec, and then calculate new_val:

lag_w3  = lambda col,n=1: F.when(F.lag('id',n).over(w3) == F.col('id'), F.lag(col,n).over(w3))
lead_w3 = lambda col,n=1: F.when(F.lead('id',n).over(w3) == F.col('id'), F.lead(col,n).over(w3))

sdf_new = sdf_1.withColumn('new_val', lag_w3('amt',1) + lead_w3('amt',1))
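
Note that lag_w3/lead_w3 have no .otherwise branch, so they return NULL whenever the neighbouring row belongs to a different id, which matches what a plain per-id lag/lead returns at a partition boundary. As a quick sanity check (a minimal sketch, assuming unique timestamps within each id as in the sample data):

# sdf_new should agree row for row with the plain per-id window
w_plain = Window.partitionBy('id').orderBy('timestamp')
expected = sdf_1.withColumn('new_val',
                            F.lag('amt', 1).over(w_plain) + F.lead('amt', 1).over(w_plain))
print(sdf_new.exceptAll(expected).count())   # expect 0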
answered Oct 23 '22 by jxc