 

Median / quantiles within PySpark groupBy

I would like to calculate group quantiles on a Spark dataframe (using PySpark). Either an approximate or exact result would be fine. I prefer a solution that I can use within the context of groupBy / agg, so that I can mix it with other PySpark aggregate functions. If this is not possible for some reason, a different approach would be fine as well.

This question is related but does not indicate how to use approxQuantile as an aggregate function.

I also have access to the percentile_approx Hive UDF but I don't know how to use it as an aggregate function.
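For reference, DataFrame.approxQuantile operates on the DataFrame as a whole and returns a plain Python list rather than a Column, which is why it can't simply be dropped into groupBy / agg. A minimal sketch, using the df defined below:

# approxQuantile(col, probabilities, relativeError) runs over the whole DataFrame,
# not per group, and returns a plain Python list such as [3.0]
overall_median = df.approxQuantile('val', [0.5], 0.01)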

For the sake of specificity, suppose I have the following dataframe:

from pyspark import SparkContext
import pyspark.sql.functions as f

sc = SparkContext()

df = sc.parallelize([
    ['A', 1],
    ['A', 2],
    ['A', 3],
    ['B', 4],
    ['B', 5],
    ['B', 6],
]).toDF(('grp', 'val'))

df_grp = df.groupBy('grp').agg(f.magic_percentile('val', 0.5).alias('med_val'))
df_grp.show()

Expected result is:

+----+-------+
| grp|med_val|
+----+-------+
|   A|      2|
|   B|      5|
+----+-------+
asked Oct 20 '17 by abeboparebop



2 Answers

I guess you don't need it anymore. But will leave it here for future generations (i.e. me next week when I forget).

from pyspark.sql import Window
import pyspark.sql.functions as F

grp_window = Window.partitionBy('grp')
magic_percentile = F.expr('percentile_approx(val, 0.5)')

df.withColumn('med_val', magic_percentile.over(grp_window))
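Note that this window version keeps every input row and simply attaches the group's median as an extra column, whereas the groupBy version below collapses the result to one row per group.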

Or to address exactly your question, this also works:

df.groupBy('grp').agg(magic_percentile.alias('med_val')) 

And as a bonus, you can pass an array of percentiles:

quantiles = F.expr('percentile_approx(val, array(0.25, 0.5, 0.75))') 

And you'll get a list in return.
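Worth adding (not part of the original answer): newer PySpark releases expose these as regular aggregate functions, so the expr() string isn't needed, assuming you're on Spark 3.1+ for percentile_approx and 3.4+ for median. A rough sketch:

import pyspark.sql.functions as F

# native approximate percentile aggregate (PySpark 3.1+)
df.groupBy('grp').agg(F.percentile_approx('val', 0.5).alias('med_val'))

# exact median aggregate (PySpark 3.4+)
df.groupBy('grp').agg(F.median('val').alias('med_val'))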

answered Sep 18 '22 by kael


Since you have access to percentile_approx, one simple solution would be to use it in a SQL command:

from pyspark.sql import SQLContext

sqlContext = SQLContext(sc)

df.registerTempTable("df")
df2 = sqlContext.sql("select grp, percentile_approx(val, 0.5) as med_val from df group by grp")
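As an aside (an addition, not part of the answer): on recent Spark versions, where SQLContext and registerTempTable are deprecated, the same query can be run through a SparkSession, for example:

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

df.createOrReplaceTempView("df")
df2 = spark.sql("select grp, percentile_approx(val, 0.5) as med_val from df group by grp")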
answered Sep 18 '22 by Shaido