
Select column name per row for max value in PySpark

I have a DataFrame like the one below. Only two value columns are shown here, but the original DataFrame has many more.

data = [(("ID1", 3, 5)), (("ID2", 4, 12)), (("ID3", 8, 3))]
df = spark.createDataFrame(data, ["ID", "colA", "colB"])
df.show()

+---+----+----+
| ID|colA|colB|
+---+----+----+
|ID1|   3|   5|
|ID2|   4|  12|
|ID3|   8|   3|
+---+----+----+

For each row, I want to extract the name of the column that holds the max value. The expected output looks like this:

+---+----+----+-------+
| ID|colA|colB|Max_col|
+---+----+----+-------+
|ID1|   3|   5|   colB|
|ID2|   4|  12|   colB|
|ID3|   8|   3|   colA|
+---+----+----+-------+

In case of a tie, where colA and colB have the same value, choose the first column.

How can I achieve this in PySpark?

Hardik Gupta asked May 31 '19 06:05




4 Answers

You can use a UDF for row-wise computation and use struct to pass multiple columns to the UDF. Hope this helps.

from pyspark.sql import functions as F
from pyspark.sql.types import IntegerType, StringType, StructField, StructType
from operator import itemgetter

data = [(("ID1", 3, 5,78)), (("ID2", 4, 12,45)), (("ID3", 70, 3,67))]
df = spark.createDataFrame(data, ["ID", "colA", "colB","colC"])
df.show()

+---+----+----+----+
| ID|colA|colB|colC|
+---+----+----+----+
|ID1|   3|   5|  78|
|ID2|   4|  12|  45|
|ID3|  70|   3|  67|
+---+----+----+----+
cols = df.columns

# to get max of values in a row
maxcol = F.udf(lambda row: max(row), IntegerType())
maxDF = df.withColumn("maxval", maxcol(F.struct([df[x] for x in df.columns[1:]])))
maxDF.show()

+---+----+----+----+------+
|ID |colA|colB|colC|maxval|
+---+----+----+----+------+
|ID1|3   |5   |78  |78    |
|ID2|4   |12  |45  |45    |
|ID3|70  |3   |67  |70    |
+---+----+----+----+------+

# to get max of value & corresponding column name

schema=StructType([StructField('maxval',IntegerType()),StructField('maxval_colname',StringType())])

maxcol = F.udf(lambda row: max(row,key=itemgetter(0)), schema)
maxDF = df.withColumn('maxfield', maxcol(F.struct([F.struct(df[x],F.lit(x)) for x in df.columns[1:]]))).\
select(df.columns+['maxfield.maxval','maxfield.maxval_colname'])

+---+----+----+----+------+--------------+
| ID|colA|colB|colC|maxval|maxval_colname|
+---+----+----+----+------+--------------+
|ID1| 3  | 5  | 78 | 78   | colC         |
|ID2| 4  | 12 | 45 | 45   | colC         |
|ID3| 70 | 3  | 67 | 70   | colA         |
+---+----+----+----+------+--------------+
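
The tie-breaking behaviour of the second UDF comes from plain Python, so it can be checked without Spark. Below is a minimal sketch, assuming each row reaches the UDF as a tuple of (value, column name) pairs (pyspark.sql.Row is a tuple subclass, so plain tuples stand in for it here). The name pick_max is illustrative and not part of PySpark.

```python
from operator import itemgetter


def pick_max(pairs):
    """Return the (value, name) pair with the largest value.

    max() returns the first maximal element it encounters, so with
    key=itemgetter(0) a tie goes to the earliest column.
    """
    return max(pairs, key=itemgetter(0))


# colA and colB tie on the value 7
row = ((7, "colA"), (7, "colB"), (3, "colC"))

print(pick_max(row))  # (7, 'colA')
# Without the key, whole tuples are compared, so the name breaks the tie
print(max(row))       # (7, 'colB')
```

Keeping key=itemgetter(0) is therefore what makes the UDF honour the "first column wins" rule from the question.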
Suresh answered Nov 10 '22 06:11


There are multiple ways to achieve this. I am providing an example of one and hints for the rest.

from pyspark.sql import functions as F
from pyspark.sql.window import Window as W
from pyspark.sql import types as T

data = [(("ID1", 3, 5)), (("ID2", 4, 12)), (("ID3", 8, 3))]
df = spark.createDataFrame(data, ["ID", "colA", "colB"])
df.show()

+---+----+----+
| ID|colA|colB|
+---+----+----+
|ID1|   3|   5|
|ID2|   4|  12|
|ID3|   8|   3|
+---+----+----+

# F.array builds an array of [column name, value] pairs such as [['colA', 3], ['colB', 5]];
# F.explode then turns each pair into its own row.

df = df.withColumn(
    "max_val",
    F.explode(
        F.array([
            F.array([F.lit(cl), F.col(cl)]) for cl in df.columns[1:]
        ])
    )
)
df.show()
+---+----+----+----------+
| ID|colA|colB|   max_val|
+---+----+----+----------+
|ID1|   3|   5| [colA, 3]|
|ID1|   3|   5| [colB, 5]|
|ID2|   4|  12| [colA, 4]|
|ID2|   4|  12|[colB, 12]|
|ID3|   8|   3| [colA, 8]|
|ID3|   8|   3| [colB, 3]|
+---+----+----+----------+

# Then split each pair so that the column name and the value are in separate columns
df = df.select(
    "ID", 
    "colA", 
    "colB", 
    F.col("max_val").getItem(0).alias("col_name"),
    F.col("max_val").getItem(1).cast(T.IntegerType()).alias("col_value"),
)
df.show()
+---+----+----+--------+---------+
| ID|colA|colB|col_name|col_value|
+---+----+----+--------+---------+
|ID1|   3|   5|    colA|        3|
|ID1|   3|   5|    colB|        5|
|ID2|   4|  12|    colA|        4|
|ID2|   4|  12|    colB|       12|
|ID3|   8|   3|    colA|        8|
|ID3|   8|   3|    colB|        3|
+---+----+----+--------+---------+

# Rank the values within each ID in descending order (rank() gives tied values the same rank)
df = df.withColumn(
    "rank",
    F.rank().over(W.partitionBy("ID").orderBy(F.col("col_value").desc()))
)
df.show()
+---+----+----+--------+---------+----+
| ID|colA|colB|col_name|col_value|rank|
+---+----+----+--------+---------+----+
|ID2|   4|  12|    colB|       12|   1|
|ID2|   4|  12|    colA|        4|   2|
|ID3|   8|   3|    colA|        8|   1|
|ID3|   8|   3|    colB|        3|   2|
|ID1|   3|   5|    colB|        5|   1|
|ID1|   3|   5|    colA|        3|   2|
+---+----+----+--------+---------+----+

# Finally, keep rank = 1: the max value has rank 1 because the values were ranked in descending order
df.where("rank=1").show()
+---+----+----+--------+---------+----+
| ID|colA|colB|col_name|col_value|rank|
+---+----+----+--------+---------+----+
|ID2|   4|  12|    colB|       12|   1|
|ID3|   8|   3|    colA|        8|   1|
|ID1|   3|   5|    colB|        5|   1|
+---+----+----+--------+---------+----+

Other options are:

  • Use a UDF on your base df that returns the name of the column with the max value.
  • In the same example, after creating the col_name and col_value columns, group by ID and take the max of col_value instead of ranking, then join the result with the previous df.
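
Because rank() assigns rank 1 to every tied value, an ID whose columns tie would come out of the rank = 1 filter more than once. One way to keep a single row per ID is to order by value descending and then by column position. The sketch below models that ordering in plain Python on the exploded (ID, col_name, col_value) rows; it is not Spark code, and top_column_per_id and column_order are illustrative names. In Spark, the rough equivalent would be row_number() with an extra column-position term in the orderBy, which is an assumption about how you would adapt the example above.

```python
def top_column_per_id(long_rows, column_order):
    """Keep one (id, col_name, col_value) row per id:
    largest value first, earliest column on ties."""
    position = {name: i for i, name in enumerate(column_order)}
    # Sort by id, then value descending, then column position ascending
    ordered = sorted(long_rows, key=lambda r: (r[0], -r[2], position[r[1]]))
    best = {}
    for row_id, col_name, col_value in ordered:
        # The first row seen for each id is the winner
        best.setdefault(row_id, (row_id, col_name, col_value))
    return list(best.values())


long_rows = [
    ("ID1", "colA", 3), ("ID1", "colB", 5),
    ("ID2", "colA", 4), ("ID2", "colB", 12),
    ("ID3", "colB", 8), ("ID3", "colA", 8),  # tie between colA and colB
]

print(top_column_per_id(long_rows, ["colA", "colB"]))
# [('ID1', 'colB', 5), ('ID2', 'colB', 12), ('ID3', 'colA', 8)]
```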
Rakesh Kumar answered Nov 10 '22 06:11


Extending what Suresh has done so that it returns the appropriate column name:

from pyspark.sql import functions as f
from pyspark.sql.types import StringType

data = [(("ID1", 3, 5,78)), (("ID2", 4, 12,45)), (("ID3", 68, 3,67))]
df = spark.createDataFrame(data, ["ID", "colA", "colB","colC"])
df.show()

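cols = df.columns
# row holds the values of cols[1:], so add 1 to skip the ID column;
# index() returns the first match, so a tie goes to the earliest column
maxcol = f.udf(lambda row: cols[row.index(max(row)) +1], StringType())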

maxDF = df.withColumn("Max_col", maxcol(f.struct([df[x] for x in df.columns[1:]])))
maxDF.show(truncate=False)

+---+----+----+----+-------+
|ID |colA|colB|colC|Max_col|
+---+----+----+----+-------+
|ID1|3   |5   |78  |colC   |
|ID2|4   |12  |45  |colC   |
|ID3|68  |3   |67  |colA   |
+---+----+----+----+-------+
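
In Python 3, max(row) raises a TypeError when a row mixes None with numbers, and None is how Spark nulls reach a Python UDF. Below is a minimal sketch of a null-tolerant function that could be wrapped in f.udf(..., StringType()) in place of the lambda. The name first_max_name and its arguments are illustrative, and returning None for an all-null row is an assumption about the desired behaviour.

```python
def first_max_name(values, names):
    """Name of the first column holding the largest non-null value."""
    best_name, best_value = None, None
    for name, value in zip(names, values):
        if value is None:
            continue  # skip nulls instead of comparing them
        # Strict '>' keeps the earlier column when values tie
        if best_value is None or value > best_value:
            best_name, best_value = name, value
    return best_name


names = ["colA", "colB", "colC"]

print(first_max_name((68, 3, 67), names))         # colA
print(first_max_name((None, 12, 12), names))      # colB
print(first_max_name((None, None, None), names))  # None
```

With the DataFrame above it could be applied as f.udf(lambda row: first_max_name(row, cols[1:]), StringType()).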
Padmaraj Bhat answered Nov 10 '22 04:11


Try the following:

from pyspark.sql import functions as F
data = [(("ID1", 3, 5)), (("ID2", 4, 12)), (("ID3", 8, 3))]
df = spark.createDataFrame(data, ["ID", "colA", "colB"])
# >= so that colA, the first column, wins a tie
df.withColumn('max_col',
   F.when(F.col('colA') >= F.col('colB'), 'colA').
     otherwise('colB')).show()

Yields:

+---+----+----+-------+
| ID|colA|colB|max_col|
+---+----+----+-------+
|ID1|   3|   5|   colB|
|ID2|   4|  12|   colB|
|ID3|   8|   3|   colA|
+---+----+----+-------+
Elior Malul answered Nov 10 '22 05:11