I have a Spark SQL DataFrame:
user1 item1 rating1
user1 item2 rating2
user1 item3 rating3
user2 item1 rating4
...
How can I group by user and then return the top N items from each group using Scala?
Similar code in Python:
df.groupby("user").apply(the_func_get_TopN)
You can use the rank window function as follows:
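import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{col, desc, rank}
// Number of top items to keep per user (set as needed)
val n: Int = 3
// Window definition: one partition per user, highest rating first
val w = Window.partitionBy(col("user")).orderBy(desc("rating"))
// Rank the rows within each user and keep those ranked n or better
df.withColumn("rank", rank().over(w)).where(col("rank") <= n)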
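To make clear what the window query computes, here is a minimal sketch of the same top-N-per-user logic on plain Scala collections, without Spark. The Rating case class and the TopNRank.topNPerUser helper are hypothetical names for illustration only. Like rank, it gives each row the rank 1 + (number of rows in the same user's group with a strictly higher rating) and keeps the rows whose rank is at most n.

```scala
// Hypothetical record type mirroring the (user, item, rating) rows in the question
final case class Rating(user: String, item: String, rating: Double)

object TopNRank {
  // For each user, keep every row whose rank is <= n, where rank follows the
  // SQL RANK() convention: 1 + number of rows with a strictly higher rating.
  // Quadratic per group; fine for illustration, not meant for large data.
  def topNPerUser(rows: Seq[Rating], n: Int): Map[String, Seq[Rating]] =
    rows.groupBy(_.user).map { case (user, group) =>
      val ranked = group.map(r => (r, 1 + group.count(_.rating > r.rating)))
      // Return the surviving rows with the highest rating first
      user -> ranked.filter(_._2 <= n).map(_._1).sortBy(r => -r.rating)
    }
}
```

For example, with ratings user1 → (item1: 5.0, item2: 3.0, item3: 4.0) and n = 2, user1 keeps item1 and item3.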
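If you don't care about ties, you can replace rank with row_number. rank gives tied ratings the same rank, so a user can get more than n rows back when there are ties at position n; row_number numbers the rows consecutively, so each user gets at most n rows, with ties broken arbitrarily.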
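The difference only shows up when ratings tie. The sketch below, again on plain Scala collections and using a hypothetical RankVsRowNumber object with made-up sample ratings, numbers one user's rows (already sorted by rating, highest first) both ways. Here ties are broken by input order; Spark's row_number does not guarantee which of the tied rows gets the lower number.

```scala
object RankVsRowNumber {
  // Made-up ratings for a single user, already sorted by rating descending
  val ratings: Seq[(String, Int)] = Seq("a" -> 5, "b" -> 4, "c" -> 4, "d" -> 3)

  // Like SQL RANK(): tied values share a rank and the next rank skips ahead
  def ranks(values: Seq[Int]): Seq[Int] =
    values.map(v => 1 + values.count(_ > v))

  // Like SQL ROW_NUMBER(): consecutive numbers regardless of ties
  def rowNumbers(values: Seq[Int]): Seq[Int] =
    values.indices.map(_ + 1)

  // Keep the items whose number is <= n under the given numbering
  def topN(n: Int, numbering: Seq[Int]): Seq[String] =
    ratings.zip(numbering).collect { case ((item, _), k) if k <= n => item }
}
```

With n = 2, the rank numbering (1, 2, 2, 4) keeps three items (a, b, c), while the row_number numbering (1, 2, 3, 4) keeps exactly two (a, b).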