
Find maximum row per group in Spark DataFrame

I'm trying to use Spark DataFrames instead of RDDs, since they appear to be higher-level than RDDs and tend to produce more readable code.

On a 14-node Google Dataproc cluster, I have about 6 million names that are translated to IDs by two different systems: sa and sb. Each Row contains name, id_sa and id_sb. My goal is to produce a mapping from id_sa to id_sb such that, for each id_sa, the corresponding id_sb is the most frequent ID among all names attached to that id_sa.

Let's try to clarify with an example. If I have the following rows:

[Row(name='n1', id_sa='a1', id_sb='b1'),
 Row(name='n2', id_sa='a1', id_sb='b2'),
 Row(name='n3', id_sa='a1', id_sb='b2'),
 Row(name='n4', id_sa='a2', id_sb='b2')]

My goal is to produce a mapping from a1 to b2. Indeed, the names associated with a1 are n1, n2 and n3, which map respectively to b1, b2 and b2, so b2 is the most frequent mapping among the names associated with a1. In the same way, a2 will be mapped to b2. It's OK to assume that there will always be a winner: no need to break ties.
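The intended logic can be sketched in plain Python (no Spark) using the sample rows above and collections.Counter, relying on the stated assumption that ties never occur:

```python
from collections import Counter

rows = [
    ("n1", "a1", "b1"),
    ("n2", "a1", "b2"),
    ("n3", "a1", "b2"),
    ("n4", "a2", "b2"),
]

# Count occurrences of each id_sb within each id_sa group
counts = {}
for _name, id_sa, id_sb in rows:
    counts.setdefault(id_sa, Counter())[id_sb] += 1

# Pick the most frequent id_sb for each id_sa
mapping = {id_sa: c.most_common(1)[0][0] for id_sa, c in counts.items()}
# mapping == {'a1': 'b2', 'a2': 'b2'}
```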

I was hoping that I could use groupBy(df.id_sa) on my dataframe, but I don't know what to do next. I was hoping for an aggregation that could produce, in the end, the following rows:

[Row(id_sa='a1', max_id_sb='b2'),
 Row(id_sa='a2', max_id_sb='b2')]

But maybe I'm trying to use the wrong tool and I should just go back to using RDDs.

asked Feb 05 '16 by Quentin Pradet



1 Answer

Using a join (in case of ties it will return more than one row per group):

import pyspark.sql.functions as F
from pyspark.sql.functions import count, col

cnts = df.groupBy("id_sa", "id_sb").agg(count("*").alias("cnt")).alias("cnts")
maxs = cnts.groupBy("id_sa").agg(F.max("cnt").alias("mx")).alias("maxs")

cnts.join(maxs,
  (col("cnt") == col("mx")) & (col("cnts.id_sa") == col("maxs.id_sa"))
).select(col("cnts.id_sa"), col("cnts.id_sb"))
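Here is a plain-Python analogue of this count-then-join approach, using the sample data from the question, so the two aggregation steps and the join condition are easy to follow outside Spark:

```python
rows = [("n1", "a1", "b1"), ("n2", "a1", "b2"), ("n3", "a1", "b2"), ("n4", "a2", "b2")]

# cnts: count per (id_sa, id_sb) pair, like groupBy("id_sa", "id_sb").count()
cnts = {}
for _name, sa, sb in rows:
    cnts[(sa, sb)] = cnts.get((sa, sb), 0) + 1

# maxs: maximum count per id_sa, like groupBy("id_sa").agg(max("cnt"))
maxs = {}
for (sa, _sb), c in cnts.items():
    maxs[sa] = max(maxs.get(sa, 0), c)

# join: keep pairs whose count equals the group maximum (ties would keep all such rows)
result = {sa: sb for (sa, sb), c in cnts.items() if c == maxs[sa]}
# result == {'a1': 'b2', 'a2': 'b2'}
```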

Using window functions (will drop ties):

from pyspark.sql.functions import col, row_number
from pyspark.sql.window import Window

w = Window().partitionBy("id_sa").orderBy(col("cnt").desc())

(cnts
  .withColumn("rn", row_number().over(w))
  .where(col("rn") == 1)
  .select("id_sa", "id_sb"))
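A plain-Python sketch of the window logic, starting from the (id_sa, id_sb, cnt) counts of the sample data: order rows by cnt descending and keep the first row per partition, which is what row_number() == 1 selects:

```python
# Counts per (id_sa, id_sb) pair for the sample data in the question
cnts = [("a1", "b1", 1), ("a1", "b2", 2), ("a2", "b2", 1)]

# Order by cnt descending, then keep the first row seen for each id_sa
# (the equivalent of rn == 1 inside each window partition)
best = {}
for sa, sb, c in sorted(cnts, key=lambda r: -r[2]):
    best.setdefault(sa, sb)
# best == {'a1': 'b2', 'a2': 'b2'}
```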

Using struct ordering:

import pyspark.sql.functions as F
from pyspark.sql.functions import col, struct

(cnts
  .groupBy("id_sa")
  .agg(F.max(struct(col("cnt"), col("id_sb"))).alias("max"))
  .select(col("id_sa"), col("max.id_sb")))
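The struct trick works because structs compare field by field, so max(struct(cnt, id_sb)) picks the row with the highest cnt first. The same idea in plain Python is max over tuples, shown here on the counts for the sample data:

```python
# Counts per (id_sa, id_sb) pair for the sample data in the question
cnts = [("a1", "b1", 1), ("a1", "b2", 2), ("a2", "b2", 1)]

# Group (cnt, id_sb) tuples by id_sa; max() compares tuples element by
# element, mirroring Spark's struct ordering (cnt first, then id_sb)
groups = {}
for sa, sb, c in cnts:
    groups.setdefault(sa, []).append((c, sb))
mapping = {sa: max(pairs)[1] for sa, pairs in groups.items()}
# mapping == {'a1': 'b2', 'a2': 'b2'}
```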

See also How to select the first row of each group?

answered Oct 02 '22 by zero323