 

Retrieve top n in each group of a DataFrame in pyspark

There's a DataFrame in pyspark with data as below:

user_id  object_id  score
user_1   object_1   3
user_1   object_1   1
user_1   object_2   2
user_2   object_1   5
user_2   object_2   2
user_2   object_2   6

What I expect is to return the 2 records with the highest score in each group of records sharing the same user_id. Consequently, the result should look like the following:

user_id  object_id  score
user_1   object_1   3
user_1   object_2   2
user_2   object_2   6
user_2   object_1   5

I'm really new to pyspark. Could anyone give me a code snippet or point me to the related documentation for this problem? Thanks a lot!

asked Jul 15 '16 by KAs


People also ask

How do you get top 5 values in PySpark?

In Spark/PySpark, you can use the show() action to get the top/first N (5, 10, 100, ...) rows of the DataFrame and display them on a console or in a log. There are also several Spark actions like take(), tail(), collect(), head(), and first() that return the top and last n rows as a list of Rows (Array[Row] for Scala).
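A minimal sketch of those actions, assuming df is the user_id/object_id/score DataFrame built further below:

df.show(5)                # prints the first 5 rows to the console, returns None
rows = df.take(5)         # first 5 rows as a list of Row objects
last = df.tail(2)         # last 2 rows as a list of Row objects (Spark 3.0+)
everything = df.collect() # every row as a list; avoid on large DataFrames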

How do you filter the first row in PySpark?

Using the PySpark filter(), just select rows where the rank column == 1, which returns the first row of each group. Finally, if the rank column is not needed, just drop it.
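A minimal sketch of that pattern, assuming the same df and a hypothetical helper column named rn:

from pyspark.sql.window import Window
from pyspark.sql.functions import row_number, col

w = Window.partitionBy('user_id').orderBy(col('score').desc())

first_per_group = (df.withColumn('rn', row_number().over(w))  # number rows within each group
                     .filter(col('rn') == 1)                  # keep only the first row per group
                     .drop('rn'))                             # drop the helper column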

How do I get the first row of Spark DataFrame?

The head() operator returns the first row of the Spark DataFrame. If you need the first n records, you can use head(n).
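For example, assuming the same df:

row = df.head()    # a single Row object (or None if the DataFrame is empty)
rows = df.head(3)  # a list of the first 3 Row objects
row['score']       # fields are accessible by name or by attribute: row.score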


1 Answer

I believe you need to use window functions to rank each row by score within its user_id group, and then filter the results to keep only the top two values in each group.

from pyspark.sql.window import Window
from pyspark.sql.functions import rank, col

# Rank rows within each user_id group by descending score.
window = Window.partitionBy(df['user_id']).orderBy(df['score'].desc())

# Wrap the chain in parentheses so the line continuations are valid Python.
(df.select('*', rank().over(window).alias('rank'))
   .filter(col('rank') <= 2)
   .show())

#+-------+---------+-----+----+
#|user_id|object_id|score|rank|
#+-------+---------+-----+----+
#| user_1| object_1|    3|   1|
#| user_1| object_2|    2|   2|
#| user_2| object_2|    6|   1|
#| user_2| object_1|    5|   2|
#+-------+---------+-----+----+
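One caveat worth noting: rank() leaves gaps on ties, so two rows with an equal score in the same group both get rank 1, and the filter may return more than two rows per group. If you need exactly two rows per user_id regardless of ties, row_number() is a drop-in replacement; a minimal sketch, reusing the window and col defined above:

from pyspark.sql.functions import row_number

# row_number() assigns a unique, consecutive number within each partition,
# so the filter keeps exactly two rows per user_id even when scores tie.
(df.select('*', row_number().over(window).alias('rank'))
   .filter(col('rank') <= 2)
   .show())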

In general, the official programming guide is a good place to start learning Spark.

Data

rdd = sc.parallelize([("user_1", "object_1", 3),
                      ("user_1", "object_2", 2),
                      ("user_2", "object_1", 5),
                      ("user_2", "object_2", 2),
                      ("user_2", "object_2", 6)])
df = sqlContext.createDataFrame(rdd, ["user_id", "object_id", "score"])
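As a side note, on Spark 2.0+ the sc/sqlContext entry points are typically replaced by a SparkSession; an equivalent setup would be something like:

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame(
    [("user_1", "object_1", 3),
     ("user_1", "object_2", 2),
     ("user_2", "object_1", 5),
     ("user_2", "object_2", 2),
     ("user_2", "object_2", 6)],
    ["user_id", "object_id", "score"])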
answered Sep 22 '22 by mtoto