Given a Spark DataFrame df, I want to find the maximum value in a certain numeric column 'values' and obtain the row(s) where that value is reached. I can of course do this:
# It doesn't matter whether I use Scala or Python,
# since I hope to get this done with the DataFrame API.
import pyspark.sql.functions as F
max_value = df.select(F.max('values')).collect()[0][0]
df.filter(df.values == max_value).show()
but this is inefficient, since it requires two passes through df.
pandas.Series/DataFrame and numpy.array have argmax/idxmax methods that do this efficiently (in one pass). So does standard Python (the built-in function max accepts a key parameter, so it can be used to find the index of the highest value).
What is the right approach in Spark? Note that I don't mind whether I get all the rows where the maximum value is achieved, or just some arbitrary (non-empty!) subset of those rows.
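For reference, a minimal pandas sketch of what I mean (illustrative only; pdf here is a small pandas DataFrame with a 'values' column):
import pandas as pd
pdf = pd.DataFrame({"values": [1, 5, 5, 2]})
pdf.loc[pdf["values"].idxmax()]  # one row where the maximum is reached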
Window functions allow users of Spark SQL to calculate results such as the rank of a given row or a moving average over a range of input rows. They significantly improve the expressiveness of Spark's SQL and DataFrame APIs, and they offer one way to keep every row on which the maximum is reached.
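A minimal sketch of that approach, assuming the df from the question (note that an unpartitioned window pulls all rows into a single partition, so this is most useful when you can also partitionBy some key):
import pyspark.sql.functions as F
from pyspark.sql import Window

# Rank rows by 'values' descending and keep everything ranked first.
w = Window.orderBy(F.desc("values"))
df.withColumn("rnk", F.dense_rank().over(w)) \
  .filter(F.col("rnk") == 1) \
  .drop("rnk") \
  .show()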
If the schema is Orderable (i.e. it contains only atomics, arrays of atomics, or recursively orderable structs), you can use simple aggregations:
Python:
# Put "values" first in the struct: struct comparison is lexicographic by field
# order, so the max is driven by "values" and ties are broken by the other columns.
df.select(F.max(
    F.struct("values", *(x for x in df.columns if x != "values"))
)).first()
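If you want the result back as ordinary columns rather than a nested struct, a small variation on the same idea (same assumptions as above) is to expand the struct before collecting:
row = df.select(
    F.max(F.struct("values", *(c for c in df.columns if c != "values"))).alias("m")
).select("m.*").first()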
Scala:
import org.apache.spark.sql.functions.{col, max, struct}
import spark.implicits._  // for the $"..." column syntax

df.select(max(struct(
  $"values" +: df.columns.collect { case x if x != "values" => col(x) }: _*
))).first
Otherwise you can reduce over a Dataset (Scala only), but it requires additional deserialization:
// Substitute the actual type of the "values" column for Long below.
df.reduce((a, b) => if (a.getAs[Long]("values") > b.getAs[Long]("values")) a else b)
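For a fully typed variant, here is a sketch under the assumption of a hypothetical case class matching the schema:
// Hypothetical schema, for illustration only.
case class Record(name: String, values: Long)

val ds  = df.as[Record]  // needs spark.implicits._ in scope
val top = ds.reduce((a, b) => if (a.values > b.values) a else b)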
You can also orderBy and limit(1) / take(1):
Scala:
import org.apache.spark.sql.functions.desc

df.orderBy(desc("values")).limit(1)
// or
df.orderBy(desc("values")).take(1)
Python:
df.orderBy(F.desc('values')).limit(1)
# or
df.orderBy(F.desc("values")).take(1)
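As a side note (my understanding, not claimed by the answer above): Spark usually plans orderBy followed by limit as a top-K operation rather than a full sort, which you can check by inspecting the physical plan:
df.orderBy(F.desc("values")).limit(1).explain()
# the physical plan typically shows TakeOrderedAndProject instead of a full Sort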
Maybe it's an incomplete answer, but you can use the DataFrame's underlying RDD, apply the max method, and get the maximum record using a chosen key:
a = sc.parallelize([
    ("a", 1, 100),
    ("b", 2, 120),
    ("c", 10, 1000),
    ("d", 14, 1000)
]).toDF(["name", "id", "salary"])

a.rdd.max(key=lambda x: x["salary"])  # Row(name=u'c', id=10, salary=1000)
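Applied to the df from the question, the same idea looks like this (note that max returns a single Row even when several rows tie for the maximum):
df.rdd.max(key=lambda row: row["values"])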