When doing multiplication with PySpark, it seems PySpark is losing precision.
For example, when multiplying two decimals of type (38,10), the result comes back as (38,6) and appears rounded to three decimal places, which is not the exact result.
from decimal import Decimal
from pyspark.sql.types import DecimalType, StructType, StructField
schema = StructType([StructField("amount", DecimalType(38,10)), StructField("fx", DecimalType(38,10))])
df = spark.createDataFrame([(Decimal(233.00), Decimal(1.1403218880))], schema=schema)
df.printSchema()
df = df.withColumn("amount_usd", df.amount * df.fx)
df.printSchema()
df.show()
Result
>>> df.printSchema()
root
|-- amount: decimal(38,10) (nullable = true)
|-- fx: decimal(38,10) (nullable = true)
|-- amount_usd: decimal(38,6) (nullable = true)
>>> df = df.withColumn("amount_usd", df.amount * df.fx)
>>> df.printSchema()
root
|-- amount: decimal(38,10) (nullable = true)
|-- fx: decimal(38,10) (nullable = true)
|-- amount_usd: decimal(38,6) (nullable = true)
>>> df.show()
+--------------+------------+----------+
| amount| fx|amount_usd|
+--------------+------------+----------+
|233.0000000000|1.1403218880|265.695000|
+--------------+------------+----------+
Is this a bug? Is there a way to get the correct result?
You can use format_number to format a number to the desired number of decimal places, as stated in the official API documentation: Formats numeric column x to a format like '#,###,###.##', rounded to d decimal places, and returns the result as a string column.
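For example, a quick sketch (assuming a DataFrame df that already has a numeric amount_usd column; note that format_number returns a string, so this is for display rather than further arithmetic):

from pyspark.sql import functions as F

# Round to 4 decimal places for display; the result is a string column
df.select(F.format_number("amount_usd", 4).alias("amount_usd_fmt")).show()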
The DecimalType must have fixed precision (the maximum total number of digits) and scale (the number of digits to the right of the decimal point). For example, (5, 2) can support values from -999.99 to 999.99. The precision can be up to 38, and the scale must be less than or equal to the precision.
I think it is expected behavior.
Spark's Catalyst engine converts an expression written in an input language (e.g. Python) into its internal Catalyst representation, including the type information, and then operates on that internal representation.
If you check the file sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala in Spark's source code, its stated purpose is to:
Calculates and propagates precision for fixed-precision decimals.
and
* In particular, if we have expressions e1 and e2 with precision/scale p1/s1 and p2/s2
* respectively, then the following operations have the following precision / scale:
*
*   Operation     Result Precision     Result Scale
*   ---------------------------------------------------
*   e1 * e2       p1 + p2 + 1          s1 + s2
Now let's look at the code for multiplication, where a function called adjustPrecisionScale is invoked:
case Multiply(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
  val resultType = if (SQLConf.get.decimalOperationsAllowPrecisionLoss) {
    DecimalType.adjustPrecisionScale(p1 + p2 + 1, s1 + s2)
  } else {
    DecimalType.bounded(p1 + p2 + 1, s1 + s2)
  }
  val widerType = widerDecimalType(p1, s1, p2, s2)
  CheckOverflow(Multiply(promotePrecision(e1, widerType), promotePrecision(e2, widerType)),
    resultType, nullOnOverflow)
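As an aside, the SQLConf flag referenced in that snippet is exposed as spark.sql.decimalOperations.allowPrecisionLoss (added in Spark 2.3, default true). If I read the code correctly, setting it to false takes the DecimalType.bounded branch instead, so the result keeps scale s1 + s2 (capped at 38) at the risk of overflowing to null. A quick way to experiment, assuming a Spark 2.3+ session:

# Hypothetical experiment: take the non-lossy branch instead.
# With precision loss disabled, the result type below should keep scale s1 + s2
# (capped at 38), i.e. decimal(38,20) for this example -- verify on your version.
spark.conf.set("spark.sql.decimalOperations.allowPrecisionLoss", "false")
df.withColumn("amount_usd", df.amount * df.fx).printSchema()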
adjustPrecisionScale is where the magic happens; I've pasted the function here so you can see the logic:
private[sql] def adjustPrecisionScale(precision: Int, scale: Int): DecimalType = {
  // Assumption:
  assert(precision >= scale)

  if (precision <= MAX_PRECISION) {
    // Adjustment only needed when we exceed max precision
    DecimalType(precision, scale)
  } else if (scale < 0) {
    // Decimal can have negative scale (SPARK-24468). In this case, we cannot allow a precision
    // loss since we would cause a loss of digits in the integer part.
    // In this case, we are likely to meet an overflow.
    DecimalType(MAX_PRECISION, scale)
  } else {
    // Precision/scale exceed maximum precision. Result must be adjusted to MAX_PRECISION.
    val intDigits = precision - scale
    // If original scale is less than MINIMUM_ADJUSTED_SCALE, use original scale value; otherwise
    // preserve at least MINIMUM_ADJUSTED_SCALE fractional digits
    val minScaleValue = Math.min(scale, MINIMUM_ADJUSTED_SCALE)
    // The resulting scale is the maximum between what is available without causing a loss of
    // digits for the integer part of the decimal and the minimum guaranteed scale, which is
    // computed above
    val adjustedScale = Math.max(MAX_PRECISION - intDigits, minScaleValue)
    DecimalType(MAX_PRECISION, adjustedScale)
  }
}
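For reference, here is a minimal Python transcription of that logic (my own sketch for the non-negative-scale case, not Spark code), using MAX_PRECISION = 38 and MINIMUM_ADJUSTED_SCALE = 6:

MAX_PRECISION = 38
MINIMUM_ADJUSTED_SCALE = 6

def adjust_precision_scale(precision, scale):
    """Sketch of DecimalType.adjustPrecisionScale for non-negative scales."""
    if precision <= MAX_PRECISION:
        # No adjustment needed when the requested precision fits
        return (precision, scale)
    int_digits = precision - scale
    min_scale_value = min(scale, MINIMUM_ADJUSTED_SCALE)
    # Keep as many fractional digits as possible without dropping integer digits,
    # but never fewer than the minimum guaranteed scale
    adjusted_scale = max(MAX_PRECISION - int_digits, min_scale_value)
    return (MAX_PRECISION, adjusted_scale)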
Now let's walk through your example. We have e1 = Decimal(233.00) and e2 = Decimal(1.1403218880), each with precision = 38 and scale = 10, so p1 = p2 = 38 and s1 = s2 = 10. The product of the two should therefore have precision = p1 + p2 + 1 = 77 and scale = s1 + s2 = 20.
Note that MAX_PRECISION = 38 and MINIMUM_ADJUSTED_SCALE = 6 here. Since p1 + p2 + 1 = 77 > 38, adjustPrecisionScale computes:
intDigits = precision - scale = 77 - 20 = 57
minScaleValue = Math.min(scale, MINIMUM_ADJUSTED_SCALE) = min(20, 6) = 6
adjustedScale = Math.max(MAX_PRECISION - intDigits, minScaleValue) = max(38 - 57, 6) = 6
In the end, a DecimalType with precision = 38 and scale = 6 is returned. That's why you see that the type of amount_usd is decimal(38,6).
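Plugging the walkthrough's numbers into the Python sketch above confirms the result type:

print(adjust_precision_scale(38 + 38 + 1, 10 + 10))  # (38, 6)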
And in the Multiply expression, both operands are first promoted to the wider decimal type, and the product is then cast to the result type DecimalType(38,6) by CheckOverflow.
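You can see both steps in the analyzed plan; the exact operator names vary by Spark version, but something like promote_precision and CheckOverflow(..., DecimalType(38,6)) should show up:

# Inspect the analyzed/optimized plans for the decimal promotion and the final cast
df.withColumn("amount_usd", df.amount * df.fx).explain(extended=True)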
If you instead declare the inputs as DecimalType(38,6), i.e.
schema = StructType([StructField("amount", DecimalType(38,6)), StructField("fx", DecimalType(38,6))])
df = spark.createDataFrame([(Decimal(233.00), Decimal(1.1403218880))], schema=schema)
then you will get:
+----------+--------+----------+
|amount |fx |amount_usd|
+----------+--------+----------+
|233.000000|1.140322|265.695026|
+----------+--------+----------+
Why is the final number 265.695000 rather than 265.695026? With the original (38,10) inputs, the exact product 265.6949999040 is computed first and only then rounded to the result scale of 6, which yields 265.695000; in the DecimalType(38,6) run above, fx had already been rounded to 1.140322 before the multiplication. But you get the idea.
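You can verify both numbers with Python's decimal module, comparing rounding the exact product to the result scale of 6 against rounding the inputs first:

from decimal import Decimal, ROUND_HALF_UP

# Exact product of the (38,10) inputs, then rounded to the result scale of 6
exact = Decimal("233.0000000000") * Decimal("1.1403218880")
print(exact)                                               # 265.69499990400000000000
print(exact.quantize(Decimal("0.000001"), ROUND_HALF_UP))  # 265.695000

# Inputs already rounded to scale 6 before multiplying, as in the (38,6) run
print(Decimal("233.000000") * Decimal("1.140322"))         # 265.695026000000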
From the Multiply code, you can see that it pays to stay well below the maximum precision when multiplying. If we lower the input precision to 18, i.e.
schema = StructType([StructField("amount", DecimalType(18,10)), StructField("fx", DecimalType(18,10))])
We get this:
+--------------+------------+------------------------+
|amount |fx |amount_usd |
+--------------+------------+------------------------+
|233.0000000000|1.1403218880|265.69499990400000000000|
+--------------+------------+------------------------+
This is a much closer approximation of the result computed by Python: 265.6949999039999754657515041
Hope this helps!