collect_list while preserving order based on another variable

I am trying to create a new column of lists in PySpark using a groupBy aggregation on an existing set of columns. An example input DataFrame is provided below:

    ------------------------
    id | date       | value
    ------------------------
    1  | 2014-01-03 | 10
    1  | 2014-01-04 | 5
    1  | 2014-01-05 | 15
    1  | 2014-01-06 | 20
    2  | 2014-02-10 | 100
    2  | 2014-03-11 | 500
    2  | 2014-04-15 | 1500

The expected output is:

    id | value_list
    ------------------------
    1  | [10, 5, 15, 20]
    2  | [100, 500, 1500]

The values within a list are sorted by the date.
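As a reference for the target semantics only (plain Python, not Spark), the sketch below groups the sample rows by id and orders each group by date before keeping the values. It assumes the dates are ISO-8601 strings, which sort chronologically as text; the helper name expected_value_lists is hypothetical.

```python
from collections import defaultdict

# Sample rows from the question, deliberately shuffled: (id, date, value)
rows = [
    (2, "2014-03-11", 500), (1, "2014-01-05", 15), (1, "2014-01-03", 10),
    (2, "2014-02-10", 100), (1, "2014-01-06", 20), (2, "2014-04-15", 1500),
    (1, "2014-01-04", 5),
]

def expected_value_lists(rows):
    """Hypothetical helper: id -> values ordered by date (plain Python, not Spark)."""
    groups = defaultdict(list)
    for id_, date, value in rows:
        groups[id_].append((date, value))
    # ISO-8601 date strings sort chronologically when compared as text
    return {id_: [value for _, value in sorted(pairs)]
            for id_, pairs in sorted(groups.items())}

print(expected_value_lists(rows))  # {1: [10, 5, 15, 20], 2: [100, 500, 1500]}
```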

I tried using collect_list as follows:

    from pyspark.sql import functions as F

    ordered_df = input_df.orderBy(['id', 'date'], ascending=True)
    grouped_df = ordered_df.groupby("id").agg(F.collect_list("value"))

But collect_list doesn't guarantee any order, even if I sort the input DataFrame by date before the aggregation.

Could someone help me with how to do the aggregation while preserving the order given by a second variable (date)?

asked Oct 05 '17 07:10 by Ravi





1 Answer

    from pyspark.sql import functions as F
    from pyspark.sql import Window

    w = Window.partitionBy('id').orderBy('date')

    sorted_list_df = input_df.withColumn(
                'sorted_list', F.collect_list('value').over(w)
            )\
            .groupBy('id')\
            .agg(F.max('sorted_list').alias('sorted_list'))

Window examples provided by users often don't really explain what is going on, so let me dissect it for you.

As you know, using collect_list together with groupBy results in a list whose order is not guaranteed. Depending on how your data is partitioned, Spark appends values to each list in whatever order it encounters the group's rows, so the final order depends on how Spark plans the aggregation over the executors. Sorting beforehand does not help, because the shuffle introduced by groupBy is not guaranteed to preserve the earlier orderBy.
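To make this concrete, here is a toy model in plain Python (not Spark's actual implementation) of a two-stage collect_list: each input partition builds partial lists in its local row order, and the partial results are then concatenated in whatever order they arrive. The round-robin split over two partitions and the helper names partial_collect and final_merge are hypothetical.

```python
from collections import defaultdict

# Rows already sorted by (id, date), as in the question: (id, date, value)
rows = [
    (1, "2014-01-03", 10), (1, "2014-01-04", 5),
    (1, "2014-01-05", 15), (1, "2014-01-06", 20),
    (2, "2014-02-10", 100), (2, "2014-03-11", 500),
    (2, "2014-04-15", 1500),
]

def partial_collect(partition):
    """Per-partition partial aggregation: append values in local row order."""
    acc = defaultdict(list)
    for id_, _date, value in partition:
        acc[id_].append(value)
    return acc

def final_merge(partials):
    """Concatenate the partial lists in the order the partials arrive."""
    result = defaultdict(list)
    for partial in partials:
        for id_, values in partial.items():
            result[id_].extend(values)
    return dict(result)

# Hypothetical round-robin split of the sorted rows over two partitions
partitions = [rows[0::2], rows[1::2]]
partials = [partial_collect(p) for p in partitions]

# The same data merged in two possible arrival orders; neither is date order
print(final_merge(partials))        # {1: [10, 15, 5, 20], 2: [100, 1500, 500]}
print(final_merge(partials[::-1]))  # {1: [5, 20, 10, 15], 2: [500, 100, 1500]}
```

The input was sorted, yet the merged lists are not, because ordering within a group now depends on how rows were distributed and on the merge order.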

A Window function allows you to control that situation, grouping rows by a certain value so you can perform an operation over each of the resultant groups:

    w = Window.partitionBy('id').orderBy('date')
  • partitionBy - you want groups/partitions of rows with the same id
  • orderBy - you want each row in the group to be sorted by date

Once you have defined the scope of your window ("rows with the same id, sorted by date"), you can perform an operation over it, in this case a collect_list:

    F.collect_list('value').over(w)
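A plain-Python model of what this column holds may help. It assumes, as the PySpark Window documentation describes, that an ordered window uses a growing frame by default (RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW), so each row sees every row of its partition up to and including its own date. The helper running_collect_list is hypothetical and only mimics Spark's result.

```python
from itertools import groupby
from operator import itemgetter

# Sample rows, deliberately unsorted: (id, date, value)
rows = [
    (2, "2014-03-11", 500), (1, "2014-01-05", 15), (1, "2014-01-03", 10),
    (2, "2014-02-10", 100), (1, "2014-01-06", 20), (2, "2014-04-15", 1500),
    (1, "2014-01-04", 5),
]

def running_collect_list(rows):
    """Mimic collect_list('value').over(partitionBy('id').orderBy('date'))."""
    out = []
    # partitionBy('id') + orderBy('date'): sort by (id, date), then split by id
    for id_, group in groupby(sorted(rows, key=itemgetter(0, 1)), key=itemgetter(0)):
        group = list(group)
        for _, date, _ in group:
            # RANGE frame: all rows in the partition whose date <= current date
            frame = [v for _, d, v in group if d <= date]
            out.append((id_, date, frame))
    return out

for row in running_collect_list(rows):
    print(row)
```

For id 1 this prints the running lists [10], [10, 5], [10, 5, 15] and [10, 5, 15, 20], one per input row.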

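At this point you have a new column, sorted_list, but you still have one row per input row, and each row's list is cumulative: because the window is ordered, Spark's default frame runs from the start of the partition to the current row, so for id 1 the rows hold [10], [10, 5], [10, 5, 15] and [10, 5, 15, 20]. To collapse these into one row per id, groupBy id and keep the max sorted_list in each group. Arrays are compared element by element, and a list that is a prefix of a longer one compares as smaller, so the max is the complete, date-ordered list: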

    .groupBy('id')\
    .agg(F.max('sorted_list').alias('sorted_list'))
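Why max returns the full list can be checked with plain Python lists, which compare element by element and treat a strict prefix as smaller. As far as I know, Spark orders arrays by the same lexicographic rule. The running lists below are the ones the window step produces for the sample data, and max_per_id is a hypothetical helper standing in for groupBy plus max.

```python
# Running lists produced by the window step for the sample data: (id, list)
running = [
    (1, [10]), (1, [10, 5]), (1, [10, 5, 15]), (1, [10, 5, 15, 20]),
    (2, [100]), (2, [100, 500]), (2, [100, 500, 1500]),
]

def max_per_id(rows):
    """Hypothetical stand-in for .groupBy('id').agg(F.max('sorted_list'))."""
    best = {}
    for id_, lst in rows:
        # List comparison is lexicographic; a strict prefix sorts first
        if id_ not in best or lst > best[id_]:
            best[id_] = lst
    return best

print(max_per_id(running))  # {1: [10, 5, 15, 20], 2: [100, 500, 1500]}
```

This relies on every per-row list being a prefix of the final one; for unrelated arrays, max would not in general pick the longest.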
answered Oct 01 '22 17:10 by TMichel
