I designed the following function to work with arrays of any numeric type:
def array_sum[T](item: Traversable[T])(implicit n: Numeric[T]) = item.sum
// Registers a function as a UDF so it can be used in SQL statements.
sqlContext.udf.register("array_sumD", array_sum(_: Seq[Float]))
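For reference, the generic function itself works fine outside of Spark; a quick REPL-style check (my own sketch, nothing Spark-specific) sums any collection whose element type has a Numeric instance:
// Plain Scala usage, no Spark involved:
array_sum(Seq(5.0f, 1.0f, 2.0f)) // Float = 8.0
array_sum(Seq(5, 1, 2)) // Int = 8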
But passing an array of type float gives me the following error:
// Now we can use our function directly in Spark SQL.
sqlContext.sql("SELECT array_sumD(array(5.0,1.0,2.0)) as array_sum").show
Error:
cannot resolve 'UDF(array(5.0,1.0,2.0))' due to data type mismatch: argument 1 requires array<double> type, however, 'array(5.0,1.0,2.0)' is of array<decimal(2,1)> type;
The default data type for decimal values in Spark SQL is, well, decimal. If you cast the literals in the query to floats and use the same UDF, it works:
sqlContext.sql(
"""SELECT array_sumD(array(
| CAST(5.0 AS FLOAT),
| CAST(1.0 AS FLOAT),
| CAST(2.0 AS FLOAT)
|)) as array_sum""".stripMargin).show
The result, as expected:
+---------+
|array_sum|
+---------+
| 8.0|
+---------+
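To see where the decimal type comes from, you can print the schema Spark infers for the literal array (a quick check against the same sqlContext):
sqlContext.sql("SELECT array(5.0, 1.0, 2.0) AS a").printSchema()
// The element type is reported as decimal(2,1), which is exactly what the error above complains about.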
Alternatively, if you do want to use decimals (to avoid floating-point issues), you still have to cast to get the right precision, and you can't use Scala's nice Numeric and sum, because the decimals are passed to the UDF as java.math.BigDecimal. So your code would be:
def array_sum(item: Traversable[java.math.BigDecimal]) = item.reduce((a, b) => a.add(b))
// Registers a function as a UDF so it can be used in SQL statements.
sqlContext.udf.register("array_sumD", array_sum(_: Seq[java.math.BigDecimal]))
sqlContext.sql(
"""SELECT array_sumD(array(
| CAST(5.0 AS DECIMAL(38,18)),
| CAST(1.0 AS DECIMAL(38,18)),
| CAST(2.0 AS DECIMAL(38,18))
|)) as array_sum""".stripMargin).show
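One caveat with the reduce-based version (a sketch of my own, not part of the code above): reduce throws on an empty collection, so if empty arrays can appear in your data a fold with an explicit zero is safer. The variant and the registration name below are hypothetical:
def array_sum_safe(item: Traversable[java.math.BigDecimal]) = item.foldLeft(java.math.BigDecimal.ZERO)(_ add _) // returns 0 for an empty array
// Register under a separate (hypothetical) name, same pattern as above:
sqlContext.udf.register("array_sumD_safe", array_sum_safe(_: Seq[java.math.BigDecimal]))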