
Using .where() on pyspark.sql.functions.max().over(window) on Spark 2.4 throws Java exception

I followed a post on StackOverflow about returning the maximum of a column grouped by another column, and got an unexpected Java exception.

Here is the test data:

import pyspark.sql.functions as f
data = [('a', 5), ('a', 8), ('a', 7), ('b', 1), ('b', 3)]
df = spark.createDataFrame(data, ["A", "B"])
df.show()

+---+---+
|  A|  B|
+---+---+
|  a|  5|
|  a|  8|
|  a|  7|
|  b|  1|
|  b|  3|
+---+---+

Here is the solution that allegedly works for other users:

from pyspark.sql import Window
w = Window.partitionBy('A')
df.withColumn('maxB', f.max('B').over(w))\
    .where(f.col('B') == f.col('maxB'))\
    .drop('maxB').show()

which should produce this output:

#+---+---+
#|  A|  B|
#+---+---+
#|  a|  8|
#|  b|  3|
#+---+---+

Instead, I get:

java.lang.UnsupportedOperationException: Cannot evaluate expression: max(input[2, bigint, false]) windowspecdefinition(input[0, string, true], specifiedwindowframe(RowFrame, unboundedpreceding$(), unboundedfollowing$()))

I have only tried this on Spark 2.4 on Databricks. I tried the equivalent SQL syntax and got the same error.
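
The equivalent SQL would look roughly like this (a sketch, assuming the DataFrame is registered as a temporary view named df):

df.createOrReplaceTempView("df")
spark.sql("""
    SELECT A, B FROM (
        SELECT A, B, max(B) OVER (PARTITION BY A) AS maxB FROM df
    ) t
    WHERE B = maxB
""").show()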

People also ask

What is the use of SparkSession in PySpark?

from pyspark.sql import SparkSession — a SparkSession is the entry point to the Dataset and DataFrame API. It can be used to create DataFrames, register a DataFrame as a table, execute SQL over tables, cache tables, and read Parquet files.
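
A minimal sketch of creating and using a session (the application name below is only illustrative):

from pyspark.sql import SparkSession

# Build (or reuse) the session; the app name is arbitrary
spark = SparkSession.builder.appName("max-over-window-demo").getOrCreate()

# The same session creates DataFrames and executes SQL
df = spark.createDataFrame([('a', 5), ('b', 3)], ["A", "B"])
spark.sql("SELECT 1 AS one").show()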

How do you calculate Max in PySpark?

Method 1: using the select() method. With the max() function we can get the maximum value of a column. Import it from the pyspark.sql.functions module, apply it inside select(), and finally use collect() to retrieve the value from the resulting DataFrame.
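
For example, on the sample DataFrame from the question (a sketch):

import pyspark.sql.functions as f

# Select the maximum of column B, then pull the single value back to the driver
max_b = df.select(f.max('B')).collect()[0][0]
print(max_b)  # 8 for the sample data above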

What is window function in PySpark?

A PySpark window function performs statistical operations such as rank, row number, etc. on a group, frame, or collection of rows and returns a result for each row individually. Window functions are also widely used for general data transformations.
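
A small illustration on the sample data, ranking rows within each value of A (not from the original post):

from pyspark.sql import Window
import pyspark.sql.functions as f

# Rank rows within each group of A by descending B; the largest B per group gets rank 1
w_rank = Window.partitionBy('A').orderBy(f.col('B').desc())
df.withColumn('rank', f.rank().over(w_rank)).show()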

How can I show more than 20 rows in Spark?

By default, Spark with Scala, Java, or Python (PySpark) displays only 20 rows from DataFrame show(), and column values are truncated to 20 characters. To display more than 20 rows, or the full column values, pass arguments to the show() method.
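
For instance, to display up to 50 rows with untruncated values (a sketch):

# Show up to 50 rows and print full column values without truncation
df.show(50, truncate=False)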


1 Answer

Databricks Support was able to reproduce the issue on Spark 2.4 but not on earlier versions. Apparently, it arises from a difference in the way the physical plan is formulated (I can post their response if requested). A fix is planned.

Meanwhile, here is one alternative solution to the original problem that does not fall prey to the version 2.4 issue:

df.withColumn("maxB", f.max('B').over(w)).drop('B').distinct().show()

+---+----+
|  A|maxB|
+---+----+
|  b|   3|
|  a|   8|
+---+----+
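
Note that this workaround returns only A and maxB rather than the original rows. If the full rows are needed, a plain aggregation joined back to the DataFrame avoids the window function entirely (a sketch, not part of the original answer):

# Compute per-group maxima without a window, then join back to keep the matching rows
agg = df.groupBy('A').agg(f.max('B').alias('maxB'))
df.join(agg, on='A').where(f.col('B') == f.col('maxB')).drop('maxB').show()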