 

GroupBy column and filter rows with maximum value in Pyspark

I am almost certain this has been asked before, but a search through Stack Overflow did not answer my question. It is not a duplicate of [2], since I want the maximum value, not the most frequent item. I am new to PySpark and trying to do something really simple: I want to groupBy column "A" and then keep only the row of each group that has the maximum value in column "B". Like this:

df_cleaned = df.groupBy("A").agg(F.max("B")) 

Unfortunately, this throws away all other columns: df_cleaned contains only the column "A" and the max value of "B". How do I instead keep the full rows ("A", "B", "C", ...)?
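
The only workaround I can think of is to aggregate the maxima and join them back onto the original DataFrame (a rough sketch, reusing the column names from above), but that feels clumsy:

import pyspark.sql.functions as F

# Aggregate the max of B per group A, aliased back to "B" so the join matches on it
max_b = df.groupBy("A").agg(F.max("B").alias("B"))

# Join back onto the original DataFrame to retain all columns of the max rows
df_cleaned = df.join(max_b, on=["A", "B"], how="inner")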

asked Feb 16 '18 by Thomas


People also ask

How do you select top 5 rows in PySpark?

In Spark/PySpark you can use the show() action to print the top/first N rows (5, 10, 100, ...) of a DataFrame to the console or a log. There are also several actions such as take(), tail(), collect(), head(), and first() that return the top or last N rows as a list of Row objects (Array[Row] in Scala).
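
A minimal sketch of these calls (assuming an existing DataFrame named df; the variable names are illustrative):

df.show(5)                # print the first 5 rows to the console (returns None)
first_five = df.take(5)   # first 5 rows as a list of Row objects
head_five = df.head(5)    # same idea via head()
last_five = df.tail(5)    # last 5 rows as a list of Row objects (Spark 3.0+)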

How can I get maximum salary in Spark?

Compute a row number over a window partitioned by the grouping column and ordered by salary descending, then use the Spark filter() to keep only row == 1, which returns the maximum salary of each group. Finally, if the row-number column is not needed, just drop it.
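
A hedged sketch of that pattern (the DataFrame emp and the columns dept and salary are made-up names for illustration):

from pyspark.sql import Window
import pyspark.sql.functions as F

# Rank rows within each department, highest salary first
w = Window.partitionBy("dept").orderBy(F.col("salary").desc())

# Keep only the top-ranked row per department, then drop the helper column
max_salary = (emp.withColumn("row", F.row_number().over(w))
                 .filter(F.col("row") == 1)
                 .drop("row"))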

How can I show more than 20 rows in Spark?

By default, Spark (with Scala, Java, or Python/PySpark) displays only 20 rows from DataFrame show(), and column values are truncated to 20 characters. To display more than 20 rows and the full column values, pass arguments to the show() method.
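
For example (a sketch, assuming a DataFrame named df):

df.show(50)                    # display up to 50 rows
df.show(50, truncate=False)    # display up to 50 rows without truncating column values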


1 Answer

You can do this without a udf using a Window.

Consider the following example:

import pyspark.sql.functions as f

data = [
    ('a', 5),
    ('a', 8),
    ('a', 7),
    ('b', 1),
    ('b', 3)
]
df = sqlCtx.createDataFrame(data, ["A", "B"])
df.show()
#+---+---+
#|  A|  B|
#+---+---+
#|  a|  5|
#|  a|  8|
#|  a|  7|
#|  b|  1|
#|  b|  3|
#+---+---+

Create a Window to partition by column A and use it to compute the maximum of each group. Then keep only the rows where the value in column B equals that maximum.

from pyspark.sql import Window

w = Window.partitionBy('A')
df.withColumn('maxB', f.max('B').over(w))\
    .where(f.col('B') == f.col('maxB'))\
    .drop('maxB')\
    .show()
#+---+---+
#|  A|  B|
#+---+---+
#|  a|  8|
#|  b|  3|
#+---+---+

Or equivalently using pyspark-sql:

df.registerTempTable('table')
q = "SELECT A, B FROM (SELECT *, MAX(B) OVER (PARTITION BY A) AS maxB FROM table) M WHERE B = maxB"
sqlCtx.sql(q).show()
#+---+---+
#|  A|  B|
#+---+---+
#|  b|  3|
#|  a|  8|
#+---+---+
answered Oct 11 '22 by pault