
SparkSQL function requires type Decimal

I designed the following function to work with arrays of any numeric type:

def array_sum[T](item:Traversable[T])(implicit n:Numeric[T]) = item.sum
// Registers a function as a UDF so it can be used in SQL statements.
sqlContext.udf.register("array_sumD", array_sum(_:Seq[Float]))

But when I try to pass it an array of floats, I get the following error:

// Now we can use our function directly in SparkSQL.
sqlContext.sql("SELECT array_sumD(array(5.0,1.0,2.0)) as array_sum").show

Error:

 cannot resolve 'UDF(array(5.0,1.0,2.0))' due to data type mismatch: argument 1 requires array<double> type, however, 'array(5.0,1.0,2.0)' is of array<decimal(2,1)> type;
asked Nov 06 '25 by nest

1 Answer

The default data type for decimal literals in Spark SQL is, well, decimal. If you cast the literals in the query to floats and use the same UDF, it works:

sqlContext.sql(
  """SELECT array_sumD(array(
    |  CAST(5.0 AS FLOAT),
    |  CAST(1.0 AS FLOAT),
    |  CAST(2.0 AS FLOAT)
    |)) as array_sum""".stripMargin).show

The result, as expected:

+---------+
|array_sum|
+---------+
|      8.0|
+---------+
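The generic UDF itself is sound; the mismatch is purely on the SQL side. A minimal plain-Scala sketch (outside Spark) shows the `Numeric`-based sum working for any numeric element type, including `Float`:

```scala
// Generic sum over any numeric element type, via the Numeric type class.
// Seq is used here for simplicity; the UDF above accepts Seq[Float] the same way.
def array_sum[T](item: Seq[T])(implicit n: Numeric[T]): T = item.sum

println(array_sum(Seq(5.0f, 1.0f, 2.0f))) // prints 8.0
println(array_sum(Seq(5, 1, 2)))          // prints 8
```

So once the query hands the UDF an `array<float>` instead of `array<decimal(2,1)>`, the same function works unchanged.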

Alternatively, if you do want to use decimals (to avoid floating-point issues), you'll still have to cast to get the right precision, and you won't be able to use Scala's nice Numeric and sum, since decimals are read as java.math.BigDecimal. So your code would be:

def array_sum(item:Traversable[java.math.BigDecimal]) = item.reduce((a, b) => a.add(b))

// Registers a function as a UDF so it can be used in SQL statements.
sqlContext.udf.register("array_sumD", array_sum(_:Seq[java.math.BigDecimal]))

sqlContext.sql(
  """SELECT array_sumD(array(
    |  CAST(5.0 AS DECIMAL(38,18)),
    |  CAST(1.0 AS DECIMAL(38,18)),
    |  CAST(2.0 AS DECIMAL(38,18))
    |)) as array_sum""".stripMargin).show
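
The reason for the `reduce` is that `java.math.BigDecimal` (unlike Scala's own `scala.math.BigDecimal`) has no `Numeric` instance in scope, so `.sum` won't compile. A plain-Scala sketch of that summing logic, outside Spark:

```scala
import java.math.BigDecimal

// java.math.BigDecimal has no built-in Scala Numeric instance,
// so .sum is unavailable; reduce with BigDecimal.add does the job.
def array_sum(item: Seq[BigDecimal]): BigDecimal =
  item.reduce((a, b) => a.add(b))

val total = array_sum(Seq(
  new BigDecimal("5.0"),
  new BigDecimal("1.0"),
  new BigDecimal("2.0")))
println(total) // prints 8.0
```

Note that BigDecimal equality is scale-sensitive, so comparisons are safest with `compareTo`.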
answered Nov 08 '25 by Tzach Zohar