argmax in Spark DataFrames: how to retrieve the row with the maximum value

Given a Spark DataFrame df, I want to find the maximum value in a certain numeric column 'values', and obtain the row(s) where that value was reached. I can of course do this:

# it doesn't matter whether I use Scala or Python,
# since I hope to get this done with the DataFrame API
import pyspark.sql.functions as F
max_value = df.select(F.max('values')).collect()[0][0]
df.filter(df.values == max_value).show()

but this is inefficient since it requires two passes through df.

pandas.Series/DataFrame and numpy.array have argmax/idxmax methods that do this efficiently (in one pass). So does standard Python (the built-in max function accepts a key parameter, so it can be used to find the element with the highest value).
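For reference, here is a quick sketch (toy data, not from the question) of those single-pass equivalents outside Spark:

import pandas as pd

pdf = pd.DataFrame({"name": ["a", "b", "c"], "values": [1, 10, 3]})
pdf.loc[pdf["values"].idxmax()]   # pandas: the row where 'values' is maximal

rows = [("a", 1), ("b", 10), ("c", 3)]
max(rows, key=lambda r: r[1])     # built-in max with a key: ('b', 10)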

What is the right approach in Spark? Note that I don't mind whether I get all the rows where the maximum value is achieved, or just some arbitrary (non-empty!) subset of those rows.

asked Aug 07 '16 by max

People also ask

How do you select top 5 rows in PySpark?

In Spark/PySpark you can use the show() action to get the top/first N (5, 10, 100, ...) rows of the DataFrame and display them on a console or in a log. There are also several Spark actions like take(), tail(), collect(), head(), and first() that return the top or last N rows as a list of Rows (Array[Row] in Scala).
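A hedged sketch of those calls against the question's df:

df.show(5)              # prints the first 5 rows to the console
top5 = df.take(5)       # first 5 rows as a list of Row objects
first_row = df.first()  # a single Row (or None for an empty DataFrame)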

How can I show more than 20 rows in Spark?

By default, Spark with Scala, Java, or Python (PySpark) fetches only 20 rows from DataFrame show(), not all rows, and column values are truncated to 20 characters. In order to fetch/display more than 20 rows and full column values from a Spark/PySpark DataFrame, you need to pass arguments to the show() method.
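For example (a sketch, using the question's df):

# show up to 50 rows and do not truncate column values to 20 characters
df.show(50, truncate=False)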

What is window function in Spark?

Window functions allow users of Spark SQL to calculate results such as the rank of a given row or a moving average over a range of input rows. They significantly improve the expressiveness of Spark's SQL and DataFrame APIs.
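As an illustration (my own sketch, not part of the answers below), a window ranking can also solve the original question while keeping every row that ties for the maximum of 'values' (Spark may warn that a window with no partitioning moves all data to a single partition):

from pyspark.sql import functions as F
from pyspark.sql.window import Window

# order the whole DataFrame by 'values' descending; rank 1 marks the maximum rows
w = Window.orderBy(F.desc("values"))
df.withColumn("rnk", F.rank().over(w)).filter("rnk = 1").drop("rnk").show()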


2 Answers

If the schema is orderable (it contains only atomics, arrays of atomics, or recursively orderable structs), you can use a simple aggregation:

Python:

df.select(F.max(
    F.struct("values", *(x for x in df.columns if x != "values"))
)).first()

Scala:

df.select(max(struct(
    $"values" +: df.columns.collect { case x if x != "values" => col(x) }: _*
))).first
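The struct aggregation returns a single Row wrapping a nested Row; a small Python sketch (toy DataFrame and an added alias, both my assumptions) of reading the original columns back out:

import pyspark.sql.functions as F

df = spark.createDataFrame([("a", 1), ("b", 10), ("c", 3)], ["name", "values"])

row = df.select(F.max(
    F.struct("values", *(x for x in df.columns if x != "values"))
).alias("argmax")).first()

row["argmax"]           # Row(values=10, name='b')
row["argmax"]["name"]   # 'b'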

Otherwise you can reduce over Dataset (Scala only) but it requires additional deserialization:

type T = ???  // placeholder for the actual type of the 'values' column

df.reduce((a, b) => if (a.getAs[T]("values") > b.getAs[T]("values")) a else b)

You can also orderBy and limit(1) / take(1):

Scala:

df.orderBy(desc("values")).limit(1)
// or
df.orderBy(desc("values")).take(1)

Python:

df.orderBy(F.desc('values')).limit(1)
# or
df.orderBy(F.desc("values")).take(1)
answered Oct 19 '22 by zero323


Maybe it's an incomplete answer, but you can use the DataFrame's underlying RDD, apply the max method, and get the maximum record using a given key.

a = sc.parallelize([
    ("a", 1, 100),
    ("b", 2, 120),
    ("c", 10, 1000),
    ("d", 14, 1000)
  ]).toDF(["name", "id", "salary"])

a.rdd.max(key=lambda x: x["salary"]) # Row(name=u'c', id=10, salary=1000)
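If you need to continue with the DataFrame API afterwards, a small follow-up sketch (assuming a SparkSession named spark) that wraps the returned Row back into a one-row DataFrame:

best = a.rdd.max(key=lambda x: x["salary"])

# rebuild a single-row DataFrame from the Row, reusing the original schema
spark.createDataFrame([best], schema=a.schema).show()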
answered Oct 19 '22 by Alberto Bonsanto