Why does Spark groupBy.agg(min/max) of BigDecimal always return 0?

I'm trying to group by one column of a DataFrame, and generate the min and max values of a BigDecimal column within each of the resulting groups. The results always produce a very small (approximately 0) value.

(Similar min/max calls against a Double column produce the expected, non-zero values.)

As a simple example:

If I create the following DataFrame:

import org.apache.spark.sql.{functions => f}
import spark.implicits._  // needed for .toDF() when not running in the spark-shell

case class Foo(group: String, bd_value: BigDecimal, d_value: Double)

val rdd = spark.sparkContext.parallelize(Seq(
  Foo("A", BigDecimal("1.0"), 1.0),
  Foo("B", BigDecimal("10.0"), 10.0),
  Foo("B", BigDecimal("1.0"), 1.0),
  Foo("C", BigDecimal("10.0"), 10.0),
  Foo("C", BigDecimal("10.0"), 10.0),
  Foo("C", BigDecimal("10.0"), 10.0)
))

val df = rdd.toDF()

Selecting max of either the Double or BigDecimal column returns the expected result:

df.select(f.max("d_value")).show()

// +------------+
// |max(d_value)|
// +------------+
// |        10.0|
// +------------+

df.select(f.max("bd_value")).show()

// +--------------------+
// |       max(bd_value)|
// +--------------------+
// |10.00000000000000...|
// +--------------------+

But if I group-by then aggregate, I get a reasonable result for the Double column, but near-zero values for the BigDecimal column:

df.groupBy("group").agg(f.max("d_value")).show()

// +-----+------------+
// |group|max(d_value)|
// +-----+------------+
// |    B|        10.0|
// |    C|        10.0|
// |    A|         1.0|
// +-----+------------+

df.groupBy("group").agg(f.max("bd_value")).show()

// +-----+-------------+
// |group|max(bd_value)|
// +-----+-------------+
// |    B|     1.00E-16|
// |    C|     1.00E-16|
// |    A|      1.0E-17|
// +-----+-------------+

Why does Spark return a near-zero result for these min/max calls?

Asked Feb 11 '19 by Rick Haffey


1 Answer

TL;DR

There seems to be an inconsistency in how Spark treats the scale of BigDecimals that manifests in the particular case shown in the question. The code behaves as though it is converting BigDecimals to unscaled Longs using the scale of the BigDecimal object, but then converting back to BigDecimal using the scale of the schema.

This can be worked around by either

  • explicitly setting the scale on all BigDecimal values to match the DataFrame's schema using setScale (a sketch of this follows the list), or
  • manually specifying a schema and creating the DataFrame from an RDD[Row]
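
Applied to the question's DataFrame, the first workaround might look like the following sketch. The names (fixedRdd) are just for illustration, and the scale of 18 assumes the inferred decimal(38,18) schema shown in the long version below:

// force every BigDecimal to the schema's scale before building the DataFrame
val fixedRdd = spark.sparkContext.parallelize(Seq(
  Foo("A", BigDecimal("1.0").setScale(18), 1.0),
  Foo("B", BigDecimal("10.0").setScale(18), 10.0),
  Foo("B", BigDecimal("1.0").setScale(18), 1.0),
  Foo("C", BigDecimal("10.0").setScale(18), 10.0)
))

fixedRdd.toDF().groupBy("group").agg(f.max("bd_value")).show()
// max(bd_value) should now show 10.000000000000000000 / 1.000000000000000000
// instead of the near-zero values from the question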

Long Version

Here is what I think is happening on my machine with Spark 2.4.0.

In the groupBy.max case, Spark is going through UnsafeRow, converting the BigDecimal to an unscaled Long and storing it as a byte array in setDecimal (as verified with print statements). Then, when it later calls getDecimal, it converts the byte array back to a BigDecimal using the scale specified in the schema.

If the scale in the original value does not match the scale in the schema, this results in an incorrect value. For example,

val foo = BigDecimal(123456)

foo.scale
// Int = 0

val bytes = foo.underlying().unscaledValue().toByteArray()

// convert the bytes into a BigDecimal using the original scale -- correct value
val sameValue = BigDecimal(new java.math.BigInteger(bytes), 0)
// sameValue: scala.math.BigDecimal = 123456

// convert the bytes into a BigDecimal using scale 18 -- wrong value
val smaller = BigDecimal(new java.math.BigInteger(bytes), 18)
// smaller: scala.math.BigDecimal = 1.23456E-13
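
The same round-trip also reproduces the question's near-zero numbers exactly, under the same assumption that the schema's scale of 18 is applied on the way back out (the names ten and tenBytes are just for illustration):

// BigDecimal("10.0") carries scale 1 and unscaled value 100
val ten = BigDecimal("10.0")
ten.scale
// Int = 1

val tenBytes = ten.underlying().unscaledValue().toByteArray()

// reinterpret the unscaled value with the schema's scale of 18
BigDecimal(new java.math.BigInteger(tenBytes), 18)
// 1.00E-16 -- the max(bd_value) shown for groups B and C in the question

// likewise, BigDecimal("1.0") has unscaled value 10, which at scale 18 is 1.0E-17,
// the value shown for group A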

If I just select the max of the bd_value column, Spark doesn't seem to go through setDecimal. I haven't verified why, or where it goes instead.

But this would explain the values observed in the question. Using the same case class Foo:

// This BigDecimal has scale 0
val rdd = spark.sparkContext.parallelize(Seq(Foo("C", BigDecimal(123456), 123456.0)))

// And shows with scale 0 in the DF
rdd.toDF.show
+-----+--------+--------+
|group|bd_value| d_value|
+-----+--------+--------+
|    C|  123456|123456.0|
+-----+--------+--------+

// But the schema has scale 18
rdd.toDF.printSchema
root
 |-- group: string (nullable = true)
 |-- bd_value: decimal(38,18) (nullable = true)
 |-- d_value: double (nullable = false)


// groupBy + max corrupts the value in the same way as converting to bytes via the unscaled value, then back to BigDecimal with scale 18
rdd.toDF.groupBy("group").max("bd_value").show
+-----+-------------+
|group|max(bd_value)|
+-----+-------------+
|    C|  1.23456E-13|
+-----+-------------+

// This BigDecimal is forced to have the same scale as the inferred schema
val rdd1 = spark.sparkContext.parallelize(Seq(Foo("C", BigDecimal(123456).setScale(18), 123456.0)))

// verify the scale is 18 in the DF
rdd1.toDF.show
+-----+--------------------+--------+
|group|            bd_value| d_value|
+-----+--------------------+--------+
|    C|123456.0000000000...|123456.0|
+-----+--------------------+--------+


// And it works as expected
rdd1.groupBy("group").max("bd_value").show

+-----+--------------------+
|group|       max(bd_value)|
+-----+--------------------+
|    C|123456.0000000000...|
+-----+--------------------+

This would also explain why, as observed in the comment, it works fine when converted from an RDD[Row] with an explicit schema.

import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{StructType, StructField, StringType, DecimalType, DoubleType}

val rdd2 = spark.sparkContext.parallelize(Seq(Row("C", BigDecimal(123456), 123456.0)))

// schema with BigDecimal scale 18
val schema = StructType(Seq(
  StructField("group", StringType, true),
  StructField("bd_value", DecimalType(38, 18), true),
  StructField("d_value", DoubleType, false)
))

// createDataFrame interprets the value into the schema's scale
val df = spark.createDataFrame(rdd2, schema)

df.show

+-----+--------------------+--------+
|group|            bd_value| d_value|
+-----+--------------------+--------+
|    C|123456.0000000000...|123456.0|
+-----+--------------------+--------+
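
As a final sanity check (not in the original output), aggregating this explicitly-typed DataFrame should behave like the setScale example above:

df.groupBy("group").max("bd_value").show
// expected: 123456.000000000000000000 rather than a near-zero value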
Answered Dec 02 '22 by Jason