
Comparing columns in Pyspark

I am working on a PySpark DataFrame with n columns. I have a set of m columns (m < n), and my task is to choose the column with the max values in it.

For example:

Input: PySpark DataFrame containing:

col_1 = [1,2,3], col_2 = [2,1,4], col_3 = [3,2,5]

Output:

col_4 = max(col_1, col_2, col_3) = [3,2,5]

There is something similar in pandas as explained in this question.

Is there any way of doing this in PySpark, or should I convert my PySpark df to a pandas df and then perform the operations?

Hemant asked Jun 07 '16 07:06




3 Answers

You can reduce using SQL expressions over a list of columns:

from pyspark.sql.functions import max as max_, col, when
from functools import reduce

def row_max(*cols):
    return reduce(
        lambda x, y: when(x > y, x).otherwise(y),
        [col(c) if isinstance(c, str) else c for c in cols]
    )

# sc is an existing SparkContext (e.g. the one provided by the pyspark shell)
df = (sc.parallelize([(1, 2, 3), (2, 1, 2), (3, 4, 5)])
    .toDF(["a", "b", "c"]))

df.select(row_max("a", "b", "c").alias("max"))

Spark 1.5+ also provides least and greatest:

from pyspark.sql.functions import greatest

df.select(greatest("a", "b", "c"))

If you want to keep the name of the column that holds the max, you can use structs:

from pyspark.sql.functions import struct, lit

def row_max_with_name(*cols):
    # Structs compare field by field, so "value" decides and "col" breaks ties
    cols_ = [struct(col(c).alias("value"), lit(c).alias("col")) for c in cols]
    return greatest(*cols_).alias("greatest({0})".format(",".join(cols)))

maxs = df.select(row_max_with_name("a", "b", "c").alias("maxs"))

Finally, you can use the above to select the "top" column, i.e. the one that most often holds the row maximum:

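from pyspark.sql.functions import max as max_  # avoid shadowing the built-in max

# Count how often each column holds the row maximum, then take the most frequent
((_, c), ) = (maxs
    .groupBy(col("maxs")["col"].alias("col"))
    .count()
    .agg(max_(struct(col("count"), col("col"))))
    .first())

df.select(c)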
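
The struct trick relies on structs being ordered field by field, much like Python tuples. A plain-Python sketch of the same logic may make the tie-breaking easier to see. No Spark is involved here; rows, names, and the helper row_max_with_name_py are illustrative stand-ins, not part of any PySpark API:

```python
from collections import Counter

# Same data as the example DataFrame, one tuple per row
rows = [(1, 2, 3), (2, 1, 2), (3, 4, 5)]
names = ["a", "b", "c"]

def row_max_with_name_py(row, names):
    # Pair each value with its column name; tuples compare left to right,
    # so the value decides first and the name only breaks ties
    return max(zip(row, names))

# Per-row (value, column) pairs, analogous to the "maxs" struct column
maxs = [row_max_with_name_py(r, names) for r in rows]

# Count how often each column wins, then take the largest (count, name) pair
counts = Counter(name for _, name in maxs)
top_count, top_col = max((n, name) for name, n in counts.items())

print(maxs)     # [(3, 'c'), (2, 'c'), (5, 'c')]
print(top_col)  # c
```

Note that in the second row columns a and c are tied at 2, and the tie goes to "c" only because it sorts after "a". If ties should be resolved differently, the second struct field would need to encode that preference.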
zero323 answered Oct 12 '22 07:10


We can use greatest

Creating DataFrame

df = spark.createDataFrame(
    [[1,2,3], [2,1,2], [3,4,5]], 
    ['col_1','col_2','col_3']
)
df.show()
+-----+-----+-----+
|col_1|col_2|col_3|
+-----+-----+-----+
|    1|    2|    3|
|    2|    1|    2|
|    3|    4|    5|
+-----+-----+-----+

Solution

from pyspark.sql.functions import greatest
df2 = df.withColumn('max_by_rows', greatest('col_1', 'col_2', 'col_3'))

#Only if you need col
#from pyspark.sql.functions import col
#df2 = df.withColumn('max', greatest(col('col_1'), col('col_2'), col('col_3')))
df2.show()

+-----+-----+-----+-----------+
|col_1|col_2|col_3|max_by_rows|
+-----+-----+-----+-----------+
|    1|    2|    3|          3|
|    2|    1|    2|          2|
|    3|    4|    5|          5|
+-----+-----+-----+-----------+
ansev answered Oct 12 '22 08:10


You can also use the PySpark built-in least to get the row-wise minimum:

from pyspark.sql.functions import least, col
# c1, c2, c3 stand for the names of the columns you want to compare
df = df.withColumn('min', least(col('c1'), col('c2'), col('c3')))
mattexx answered Oct 12 '22 08:10