There's a DataFrame in pyspark with data as below:
user_id  object_id  score
user_1   object_1   3
user_1   object_1   1
user_1   object_2   2
user_2   object_1   5
user_2   object_2   2
user_2   object_2   6
What I want is to return, for each group of rows sharing the same user_id, the 2 records with the highest score. The result should look like this:
user_id  object_id  score
user_1   object_1   3
user_1   object_2   2
user_2   object_2   6
user_2   object_1   5
I'm really new to PySpark. Could anyone give me a code snippet or point me to the relevant documentation for this problem? Many thanks!
In Spark/PySpark, you can use the show() action to display the first N (5, 10, 100, ...) rows of a DataFrame on the console or in a log. Several other actions return rows to the driver as Row objects (Array[Row] in Scala): take(n) and head(n) return the first n rows, tail(n) returns the last n rows, first() returns the first row, and collect() returns all rows.
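Once each row has a row-number column within its group (for example, a column named row computed with row_number() over a window), use the PySpark filter() to keep only the rows where row == 1, which returns just the first row of each group. Finally, if the row column is no longer needed, drop it.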
The head() method returns the first row of a Spark DataFrame. If you need the first n rows, use head(n), which returns them as a list.
I believe you need to use window functions to attain the rank of each row based on user_id and score, and subsequently filter your results to only keep the first two values.
    from pyspark.sql.window import Window
    from pyspark.sql.functions import rank, col

    # Rank the rows within each user_id by score, highest first
    window = Window.partitionBy(df['user_id']).orderBy(df['score'].desc())

    # The outer parentheses let the method chain span several lines
    (df.select('*', rank().over(window).alias('rank'))
       .filter(col('rank') <= 2)
       .show())
    #+-------+---------+-----+----+
    #|user_id|object_id|score|rank|
    #+-------+---------+-----+----+
    #| user_1| object_1|    3|   1|
    #| user_1| object_2|    2|   2|
    #| user_2| object_2|    6|   1|
    #| user_2| object_1|    5|   2|
    #+-------+---------+-----+----+
In general, the official programming guide is a good place to start learning Spark.
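    # Sample data used for the example above.
    # sc and sqlContext are assumed to be already defined, as in the PySpark shell.
    rdd = sc.parallelize([("user_1", "object_1", 3),
                          ("user_1", "object_2", 2),
                          ("user_2", "object_1", 5),
                          ("user_2", "object_2", 2),
                          ("user_2", "object_2", 6)])
    df = sqlContext.createDataFrame(rdd, ["user_id", "object_id", "score"])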