 

Scala/Spark dataframes: find the column name corresponding to the max

In Scala/Spark, given a dataframe:

val dfIn = sqlContext.createDataFrame(Seq(
  ("r0", 0, 2, 3),
  ("r1", 1, 0, 0),
  ("r2", 0, 2, 2))).toDF("id", "c0", "c1", "c2")

I would like to compute a new column maxCol holding the name of the column corresponding to the max value (for each row). With this example, the output should be:

+---+---+---+---+------+
| id| c0| c1| c2|maxCol|
+---+---+---+---+------+
| r0|  0|  2|  3|    c2|
| r1|  1|  0|  0|    c0|
| r2|  0|  2|  2|    c1|
+---+---+---+---+------+

Actually the dataframe has more than 60 columns, so a generic solution is required.

The equivalent in Python Pandas (yes, I know, I should compare with pyspark...) could be:

dfOut = pd.concat([dfIn, dfIn.idxmax(axis=1).rename('maxCol')], axis=1) 
asked Feb 27 '17 by ivankeller



1 Answer

With a small trick you can use the greatest function. Required imports:

import org.apache.spark.sql.functions.{col, greatest, lit, struct}

First, let's create a list of structs where the first element is the value and the second one is the column name:

val structs = dfIn.columns.tail.map(
  c => struct(col(c).as("v"), lit(c).as("k"))
)

A structure like this can be passed to greatest as follows:

dfIn.withColumn("maxCol", greatest(structs: _*).getItem("k"))
+---+---+---+---+------+
| id| c0| c1| c2|maxCol|
+---+---+---+---+------+
| r0|  0|  2|  3|    c2|
| r1|  1|  0|  0|    c0|
| r2|  0|  2|  2|    c2|
+---+---+---+---+------+
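
This works because Spark compares structs field by field, so greatest picks the struct with the largest "v", and the matching column name rides along in "k". If you want to inspect the intermediate struct before extracting the name, a small illustrative check (the max_struct alias is only for display) is:

// Illustrative only: show the winning (v, k) struct per row before .getItem("k")
dfIn.select(col("id"), greatest(structs: _*).as("max_struct")).show(false)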

Please note that in case of ties it takes the element that occurs later in the sequence (lexicographically, (x, "c2") > (x, "c1")). If for some reason this is not acceptable, you can explicitly reduce with when:

import org.apache.spark.sql.functions.when

val max_col = structs.reduce(
  (c1, c2) => when(c1.getItem("v") >= c2.getItem("v"), c1).otherwise(c2)
).getItem("k")

dfIn.withColumn("maxCol", max_col)
+---+---+---+---+------+
| id| c0| c1| c2|maxCol|
+---+---+---+---+------+
| r0|  0|  2|  3|    c2|
| r1|  1|  0|  0|    c0|
| r2|  0|  2|  2|    c1|
+---+---+---+---+------+

In case of nullable columns you have to adjust this, for example by coalescing the values to -Inf.
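
For example, a minimal sketch of that adjustment (assuming the value columns can be cast to double; structsNullSafe is just an illustrative name):

import org.apache.spark.sql.functions.coalesce

// Null-safe variant: nulls become -Infinity so they never win the comparison
val structsNullSafe = dfIn.columns.tail.map(c =>
  struct(
    coalesce(col(c).cast("double"), lit(Double.NegativeInfinity)).as("v"),
    lit(c).as("k")
  )
)

dfIn.withColumn("maxCol", greatest(structsNullSafe: _*).getItem("k")).show()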

answered Sep 19 '22 by zero323