 

PySpark: DecimalType multiplication precision loss

When doing multiplication with PySpark, it seems PySpark is losing precision.

For example, when multiplying two decimals with precision (38,10), the result comes back as (38,6) and is rounded to 265.695000 instead of the exact product, which looks incorrect.

from decimal import Decimal
from pyspark.sql.types import DecimalType, StructType, StructField

schema = StructType([StructField("amount", DecimalType(38,10)), StructField("fx", DecimalType(38,10))])
df = spark.createDataFrame([(Decimal(233.00), Decimal(1.1403218880))], schema=schema)

df.printSchema()
df = df.withColumn("amount_usd", df.amount * df.fx)
df.printSchema()
df.show()

Result

>>> df.printSchema()
root
 |-- amount: decimal(38,10) (nullable = true)
 |-- fx: decimal(38,10) (nullable = true)
 |-- amount_usd: decimal(38,6) (nullable = true)

>>> df = df.withColumn("amount_usd", df.amount * df.fx)
>>> df.printSchema()
root
 |-- amount: decimal(38,10) (nullable = true)
 |-- fx: decimal(38,10) (nullable = true)
 |-- amount_usd: decimal(38,6) (nullable = true)

>>> df.show()
+--------------+------------+----------+
|        amount|          fx|amount_usd|
+--------------+------------+----------+
|233.0000000000|1.1403218880|265.695000|
+--------------+------------+----------+

Is this a bug? Is there a way to get the correct result?

asked Sep 16 '19 by blu


1 Answer

I think it is expected behavior.

Spark's Catalyst engine converts an expression written in an input language (e.g. Python) into its internal Catalyst representation, carrying the same type information, and then operates on that internal representation.
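You can see this rewriting from PySpark by printing the analyzed plan of the original df (just a quick check; the exact plan text varies by Spark version):

# Catalyst rewrites the decimal multiplication internally; expect to see
# something like CheckOverflow(promote_precision(amount) * promote_precision(fx),
# DecimalType(38,6)) in the analyzed/optimized plan (wording differs by version).
df.withColumn("amount_usd", df.amount * df.fx).explain(True)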

If you check the file sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala in Spark's source code, its doc comment says that it:

Calculates and propagates precision for fixed-precision decimals.

and

 * In particular, if we have expressions e1 and e2 with precision/scale p1/s1 and p2/s2
 * respectively, then the following operations have the following precision / scale:
 *   Operation    Result Precision                        Result Scale
 *   ------------------------------------------------------------------------
 *   e1 * e2      p1 + p2 + 1                             s1 + s2
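
As a quick illustration of that rule (plain Python, not Spark's code), the unadjusted result type of a product is:

# Unadjusted result type for decimal multiplication, per the comment above:
# precision = p1 + p2 + 1, scale = s1 + s2.
def multiply_result_type(p1, s1, p2, s2):
    return p1 + p2 + 1, s1 + s2

print(multiply_result_type(38, 10, 38, 10))  # (77, 20) -- before any adjustment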

Now let's look at the code for multiplication, where the function adjustPrecisionScale is called:

    case Multiply(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
      val resultType = if (SQLConf.get.decimalOperationsAllowPrecisionLoss) {
        DecimalType.adjustPrecisionScale(p1 + p2 + 1, s1 + s2)
      } else {
        DecimalType.bounded(p1 + p2 + 1, s1 + s2)
      }
      val widerType = widerDecimalType(p1, s1, p2, s2)
      CheckOverflow(Multiply(promotePrecision(e1, widerType), promotePrecision(e2, widerType)),
        resultType, nullOnOverflow)
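
Note the branch on decimalOperationsAllowPrecisionLoss: it corresponds to the configuration key spark.sql.decimalOperations.allowPrecisionLoss (true by default since Spark 2.3). If you disable it, Spark uses DecimalType.bounded instead of adjustPrecisionScale. A quick way to try it (set the flag before defining the column):

# With precision loss disabled, the result type should become bounded(77, 20) =
# decimal(38,20) instead of the adjusted decimal(38,6); if the exact result does
# not fit, Spark returns null (or errors under ANSI mode) rather than rounding.
spark.conf.set("spark.sql.decimalOperations.allowPrecisionLoss", "false")
df.withColumn("amount_usd", df.amount * df.fx).printSchema()
spark.conf.unset("spark.sql.decimalOperations.allowPrecisionLoss")  # restore default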

adjustPrecisionScale is where the magic happens. I've pasted the function here so you can see the logic:

  private[sql] def adjustPrecisionScale(precision: Int, scale: Int): DecimalType = {
    // Assumption:
    assert(precision >= scale)

    if (precision <= MAX_PRECISION) {
      // Adjustment only needed when we exceed max precision
      DecimalType(precision, scale)
    } else if (scale < 0) {
      // Decimal can have negative scale (SPARK-24468). In this case, we cannot allow a precision
      // loss since we would cause a loss of digits in the integer part.
      // In this case, we are likely to meet an overflow.
      DecimalType(MAX_PRECISION, scale)
    } else {
      // Precision/scale exceed maximum precision. Result must be adjusted to MAX_PRECISION.
      val intDigits = precision - scale
      // If original scale is less than MINIMUM_ADJUSTED_SCALE, use original scale value; otherwise
      // preserve at least MINIMUM_ADJUSTED_SCALE fractional digits
      val minScaleValue = Math.min(scale, MINIMUM_ADJUSTED_SCALE)
      // The resulting scale is the maximum between what is available without causing a loss of
      // digits for the integer part of the decimal and the minimum guaranteed scale, which is
      // computed above
      val adjustedScale = Math.max(MAX_PRECISION - intDigits, minScaleValue)

      DecimalType(MAX_PRECISION, adjustedScale)
    }
  }
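
To make the walkthrough below easy to reproduce, here is a line-by-line Python translation of that function (a sketch for experimentation only; the authoritative logic is the Scala above):

MAX_PRECISION = 38
MINIMUM_ADJUSTED_SCALE = 6

def adjust_precision_scale(precision, scale):
    # Mirrors DecimalType.adjustPrecisionScale shown above.
    assert precision >= scale
    if precision <= MAX_PRECISION:
        # No adjustment needed while we are within max precision.
        return precision, scale
    elif scale < 0:
        # Negative scale: keep it, cap the precision (see SPARK-24468).
        return MAX_PRECISION, scale
    else:
        int_digits = precision - scale
        min_scale_value = min(scale, MINIMUM_ADJUSTED_SCALE)
        adjusted_scale = max(MAX_PRECISION - int_digits, min_scale_value)
        return MAX_PRECISION, adjusted_scale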

Now let's walk through your example. We have

e1 = Decimal(233.00)
e2 = Decimal(1.1403218880)

Each has precision = 38 and scale = 10, so p1 = p2 = 38 and s1 = s2 = 10. The product of these two should have precision = p1 + p2 + 1 = 77 and scale = s1 + s2 = 20.

Note that MAX_PRECISION = 38 and MINIMUM_ADJUSTED_SCALE = 6 here.

Since p1 + p2 + 1 = 77 > 38, the adjustment branch is taken:

intDigits = precision - scale = 77 - 20 = 57
minScaleValue = Math.min(scale, MINIMUM_ADJUSTED_SCALE) = min(20, 6) = 6
adjustedScale = Math.max(MAX_PRECISION - intDigits, minScaleValue) = max(38 - 57, 6) = 6
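
Plugging the numbers into the Python sketch above gives the same answer:

# (p1 + p2 + 1, s1 + s2) = (77, 20) goes in, (38, 6) comes out.
print(adjust_precision_scale(38 + 38 + 1, 10 + 10))  # (38, 6)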

In the end, a DecimalType with precision = 38 and scale = 6 is returned. That's why you see the type of amount_usd as decimal(38,6).

And in the Multiply expression, both operands are first promoted to the wider type computed by widerDecimalType (still decimal(38,10) here), the multiplication is carried out, and CheckOverflow then rounds the product to fit the result type DecimalType(38,6).

If you run your code with DecimalType(38,6) inputs, i.e.

schema = StructType([StructField("amount", DecimalType(38,6)), StructField("fx", DecimalType(38,6))])
df = spark.createDataFrame([(Decimal(233.00), Decimal(1.1403218880))], schema=schema)

You will get

+----------+--------+----------+
|amount    |fx      |amount_usd|
+----------+--------+----------+
|233.000000|1.140322|265.695026|
+----------+--------+----------+

Why is the original result 265.695000 rather than 265.695026? Because in the original run the inputs keep their full (38,10) scale and only the product is rounded to scale 6 by CheckOverflow (265.694999904 rounds to 265.695000), whereas here the inputs themselves were rounded to 1.140322 before multiplying. But you get the idea.
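
You can reproduce that rounding with plain Python's decimal module (using string inputs so they are exact; this only illustrates the rounding, it is not Spark code):

from decimal import Decimal, ROUND_HALF_UP

# Full-precision product, then rounded to scale 6 -- the same value Spark shows.
exact = Decimal("233.0000000000") * Decimal("1.1403218880")
print(exact)                                               # 265.69499990400000000000
print(exact.quantize(Decimal("1.000000"), ROUND_HALF_UP))  # 265.695000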

From the Multiply code, you can see that the adjustment only kicks in when the result would exceed the maximum precision. If we lower the input precision to 18, so that p1 + p2 + 1 = 37 <= 38,

schema = StructType([StructField("amount", DecimalType(18,10)), StructField("fx", DecimalType(18,10))])

We get this:

+--------------+------------+------------------------+
|amount        |fx          |amount_usd              |
+--------------+------------+------------------------+
|233.0000000000|1.1403218880|265.69499990400000000000|
+--------------+------------+------------------------+

We get a better approximation to the result computed by Python's decimal module:

265.6949999039999754657515041
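
Incidentally, the trailing ...9999754657515041 digits in that Python value come from constructing Decimal from a float literal; with string inputs the product is exact (a side note, unrelated to Spark):

from decimal import Decimal

# Decimal(1.1403218880) stores the nearest binary float, not exactly 1.1403218880,
# which is why the pure-Python product above carries those extra trailing digits.
print(Decimal(1.1403218880))                        # not exactly 1.1403218880
print(Decimal("233.00") * Decimal("1.1403218880"))  # 265.694999904000 (exact)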

Hope this helps!

answered Oct 24 '22 by niuer