PySpark Window function on entire data frame

Consider a PySpark data frame. I would like to summarize the entire data frame, per column, and append the result for every row.

+-----+----------+-----------+
|index|      col1|       col2|
+-----+----------+-----------+
|  0.0|0.58734024|0.085703015|
|  1.0|0.67304325| 0.17850411|
+-----+----------+-----------+

Expected result

+-----+----------+-----------+--------+---------+--------+---------+
|index|      col1|       col2|col1_min|col1_mean|col2_min|col2_mean|
+-----+----------+-----------+--------+---------+--------+---------+
|  0.0|0.58734024|0.085703015|      -5|      2.3|      -2|      1.4|
|  1.0|0.67304325| 0.17850411|      -5|      2.3|      -2|      1.4|
+-----+----------+-----------+--------+---------+--------+---------+

To my knowledge, I'll need a Window function with the whole data frame as the window, so that the result is kept on every row (instead of, for example, computing the stats separately and then joining them back to replicate them on each row).

My questions are:

  1. How do I write a Window without any partitionBy or orderBy?

    I know the standard Window with partition and order, but not one that takes everything as a single partition:

    w = Window.partitionBy("col1", "col2").orderBy(desc("col1"))
    df = df.withColumn("col1_mean", mean("col1").over(w))
    

    How would I write a Window that treats everything as one partition?

  2. Is there any way to write this dynamically for all columns?

    Let's say I have 500 columns; it does not look great to write this out repeatedly.

    df = (df
        .withColumn("col1_mean", mean("col1").over(w))
        .withColumn("col1_min", min("col1").over(w))
        .withColumn("col2_mean", mean("col2").over(w))
        .....
    )
    

    Let's assume I want multiple stats for each column, so each colx would spawn colx_min, colx_max and colx_mean (see the sketch right after this list).
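
For reference, here is a minimal sketch of the window-based version described above. It is only an illustration under assumptions (it reuses the df from the question; the stat_funcs mapping is made up for this example): an empty partitionBy() puts every row into one window partition, and the per-column stats can be generated in a loop instead of being written out by hand.

from pyspark.sql import Window
import pyspark.sql.functions as F

# An empty partitionBy() treats the whole data frame as a single window
# partition. Spark will log a warning about moving all data to one
# partition, so this does not scale to large data.
w = Window.partitionBy()

# Hypothetical mapping of column-name suffixes to aggregate functions.
stat_funcs = {"min": F.min, "max": F.max, "mean": F.mean}

# Generate colX_min / colX_max / colX_mean for every "col*" column.
for c in [name for name in df.columns if name.startswith("col")]:
    for stat, fn in stat_funcs.items():
        df = df.withColumn(c + "_" + stat, fn(c).over(w))

With hundreds of columns, chaining withColumn calls like this also makes the query plan very large; building a single select with all the expressions at once, or using the cross-join approach in the answer below, tends to scale better.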

Kenny asked Feb 26 '20 16:02



1 Answer

Instead of using a window, you can achieve the same result with a custom aggregation combined with a cross join:

import pyspark.sql.functions as F
from pyspark.sql.functions import broadcast
from itertools import chain

df = spark.createDataFrame([
  [1, 2.3, 1],
  [2, 5.3, 2],
  [3, 2.1, 4],
  [4, 1.5, 5]
], ["index", "col1", "col2"])

# Build (min, max, mean) expressions for every column whose name starts with "col"
agg_cols = [(F.min(c).alias("min_" + c),
             F.max(c).alias("max_" + c),
             F.mean(c).alias("mean_" + c))
            for c in df.columns if c.startswith('col')]

# Flatten the tuples into one argument list and aggregate into a single-row stats data frame
stats_df = df.agg(*list(chain(*agg_cols)))

# There is no performance impact from the crossJoin: the right-hand table has
# only one row, which we broadcast (most likely Spark would broadcast it anyway)
df.crossJoin(broadcast(stats_df)).show() 

# +-----+----+----+--------+--------+---------+--------+--------+---------+
# |index|col1|col2|min_col1|max_col1|mean_col1|min_col2|max_col2|mean_col2|
# +-----+----+----+--------+--------+---------+--------+--------+---------+
# |    1| 2.3|   1|     1.5|     5.3|      2.8|       1|       5|      3.0|
# |    2| 5.3|   2|     1.5|     5.3|      2.8|       1|       5|      3.0|
# |    3| 2.1|   4|     1.5|     5.3|      2.8|       1|       5|      3.0|
# |    4| 1.5|   5|     1.5|     5.3|      2.8|       1|       5|      3.0|
# +-----+----+----+--------+--------+---------+--------+--------+---------+

Note 1: Using broadcast we avoid shuffling, since the broadcasted df will be sent to all the executors.

Note 2: with chain(*agg_cols) we flatten the list of tuples created in the previous step.
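
To make Note 2 concrete, here is a tiny standalone illustration of the flattening step (pairs and flat are names made up for this example only):

from itertools import chain

pairs = [(1, 2, 3), (4, 5, 6)]   # a list of tuples, like agg_cols
flat = list(chain(*pairs))       # [1, 2, 3, 4, 5, 6] -- ready to be unpacked into df.agg(*flat)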

UPDATE:

Here is the execution plan for the above program:

== Physical Plan ==
*(3) BroadcastNestedLoopJoin BuildRight, Cross
:- *(3) Scan ExistingRDD[index#196L,col1#197,col2#198L]
+- BroadcastExchange IdentityBroadcastMode, [id=#274]
   +- *(2) HashAggregate(keys=[], functions=[finalmerge_min(merge min#233) AS min(col1#197)#202, finalmerge_max(merge max#235) AS max(col1#197)#204, finalmerge_avg(merge sum#238, count#239L) AS avg(col1#197)#206, finalmerge_min(merge min#241L) AS min(col2#198L)#208L, finalmerge_max(merge max#243L) AS max(col2#198L)#210L, finalmerge_avg(merge sum#246, count#247L) AS avg(col2#198L)#212])
      +- Exchange SinglePartition, [id=#270]
         +- *(1) HashAggregate(keys=[], functions=[partial_min(col1#197) AS min#233, partial_max(col1#197) AS max#235, partial_avg(col1#197) AS (sum#238, count#239L), partial_min(col2#198L) AS min#241L, partial_max(col2#198L) AS max#243L, partial_avg(col2#198L) AS (sum#246, count#247L)])
            +- *(1) Project [col1#197, col2#198L]
               +- *(1) Scan ExistingRDD[index#196L,col1#197,col2#198L]

Here we see a BroadcastExchange of a SinglePartition broadcasting one single row, since stats_df fits into a single partition. Therefore the only data moved around is that one row (the minimum possible).
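
If you want to reproduce this plan, it can be printed with explain() on the joined data frame (reusing df and stats_df from the snippet above; the generated ids in the output will differ between runs):

# Prints the physical plan; explain(True) additionally shows the parsed,
# analyzed and optimized logical plans.
df.crossJoin(broadcast(stats_df)).explain()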

abiratsis answered Oct 10 '22 12:10