What I want to do is given a DataFrame, take top n elements according to some specified column. The top(self, num) in RDD API is exactly what I want. I wonder if there is equivalent API in DataFrame world ?
My first attempt is the following
def retrieve_top_n(df, n):
    # assume we want to get the most popular n 'key' values in the DataFrame
    return df.groupBy('key').count().orderBy('count', ascending=False).limit(n).select('key')
However, I've realized that this results in non-deterministic behavior (I don't know the exact reason, but I guess limit(n) doesn't guarantee which n rows it takes).
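For reference, the RDD-based fallback I have in mind looks roughly like this (a sketch only, using the same 'key' column as above; n is the number of keys I want):
counts = df.groupBy('key').count()
# RDD.top() sorts by the supplied key function, so ordering is explicit here
top_rows = counts.rdd.top(n, key=lambda row: row['count'])  # list of Row(key=..., count=...)
top_keys = [row['key'] for row in top_rows]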
import numpy as np

def sample_df(num_records):
    def data():
        np.random.seed(42)
        while True:
            yield int(np.random.normal(100., 80.))

    data_iter = iter(data())
    df = sc.parallelize((
        (i, next(data_iter)) for i in range(int(num_records))
    )).toDF(('index', 'key_col'))
    return df
sample_df(1e3).show(n=5)
+-----+-------+
|index|key_col|
+-----+-------+
| 0| 139|
| 1| 88|
| 2| 151|
| 3| 221|
| 4| 81|
+-----+-------+
only showing top 5 rows
from pyspark.sql import Window
from pyspark.sql import functions

def top_df_0(df, key_col, K):
    """
    Using window functions. Handles ties OK.
    """
    window = Window.orderBy(functions.col(key_col).desc())
    return (df
            .withColumn("rank", functions.rank().over(window))
            .filter(functions.col('rank') <= K)
            .drop('rank'))

def top_df_1(df, key_col, K):
    """
    Using limit(K). Does NOT handle ties appropriately.
    """
    return df.orderBy(functions.col(key_col).desc()).limit(K)

def top_df_2(df, key_col, K):
    """
    Using limit(K) and then filtering. Handles ties OK.
    """
    num_records = df.count()
    value_at_k_rank = (df
                       .orderBy(functions.col(key_col).desc())
                       .limit(K)
                       .select(functions.min(key_col).alias('min'))
                       .first()['min'])
    return df.filter(df[key_col] >= value_at_k_rank)
The function top_df_1 is similar to the one you originally implemented. The reason it gives you non-deterministic behavior is that it cannot handle ties nicely. This may be an OK thing to do if you have lots of data and are only interested in an approximate answer for the sake of performance.
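If you want the limit(K) approach to stay deterministic, one sketch (assuming a unique tie-breaking column, such as the 'index' column from sample_df above) is to add a secondary sort key:
def top_df_1_deterministic(df, key_col, K, tie_col='index'):
    """
    Sketch: like top_df_1, but breaks ties on a unique column so that
    limit(K) always picks the same rows. Assumes `tie_col` is unique.
    Note: it still returns exactly K rows, so tied values past K are dropped.
    """
    return (df
            .orderBy(functions.col(key_col).desc(), functions.col(tie_col).asc())
            .limit(K))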
For benchmarking, use a Spark DataFrame with 4 million entries and define a convenience function:
NUM_RECORDS = 4e6
test_df = sample_df(NUM_RECORDS).cache()

def show(func, df, key_col, K):
    func(df, key_col, K).select(
        functions.max(key_col),
        functions.min(key_col),
        functions.count(key_col)
    ).show()
Let's see the verdict:
%timeit show(top_df_0, test_df, "key_col", K=100)
+------------+------------+--------------+
|max(key_col)|min(key_col)|count(key_col)|
+------------+------------+--------------+
| 502| 420| 108|
+------------+------------+--------------+
1 loops, best of 3: 1.62 s per loop
%timeit show(top_df_1, test_df, "key_col", K=100)
+------------+------------+--------------+
|max(key_col)|min(key_col)|count(key_col)|
+------------+------------+--------------+
| 502| 420| 100|
+------------+------------+--------------+
1 loops, best of 3: 252 ms per loop
%timeit show(top_df_2, test_df, "key_col", K=100)
+------------+------------+--------------+
|max(key_col)|min(key_col)|count(key_col)|
+------------+------------+--------------+
| 502| 420| 108|
+------------+------------+--------------+
1 loops, best of 3: 725 ms per loop
(Note that top_df_0 and top_df_2 have 108 entries in the top 100. This is due to the presence of entries tied with the 100th best value. The top_df_1 implementation ignores the tied entries.)
If you want an exact answer, go with top_df_2 (it is about 2x faster than top_df_0). If you want another roughly 2x in performance and are OK with an approximate answer, go with top_df_1.
Options (both sketched below):
1) Use the PySpark SQL row_number() within a window function - relevant SO: spark dataframe grouping, sorting, and selecting top rows for a set of columns
2) Convert the ordered df to an RDD and use the top() function there (hint: this doesn't appear to actually maintain ordering from my quick test, but YMMV)
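A rough sketch of both options (the DataFrame df, the columns 'group' and 'key_col', and n are illustrative assumptions, not names from the question):
from pyspark.sql import Window, functions as F

# Option 1: row_number() over a window partitioned by group, keep the top n rows per group.
w = Window.partitionBy('group').orderBy(F.col('key_col').desc())
top_n_per_group = (df
                   .withColumn('rn', F.row_number().over(w))
                   .filter(F.col('rn') <= n)
                   .drop('rn'))

# Option 2: drop to the RDD and use top(), supplying the sort key explicitly
# instead of relying on the DataFrame's ordering being preserved.
top_n_rows = df.rdd.top(n, key=lambda row: row['key_col'])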