I am trying to create a user-defined aggregate function (UDAF) in Java using Apache Spark SQL that returns multiple arrays on completion. I have searched online and cannot find any examples or suggestions on how to do this.
I am able to return a single array, but cannot figure out how to get the data in the correct format in the evaluate() method for returning multiple arrays.
The UDAF does work (I can print out the arrays in the evaluate() method); I just can't figure out how to return those arrays to the calling code, which is shown below for reference.
UserDefinedAggregateFunction customUDAF = new CustomUDAF();
DataFrame resultingDataFrame = dataFrame.groupBy().agg(
        customUDAF.apply(dataFrame.col("long_col"), dataFrame.col("double_col")).as("processed_data"));
I have included the whole custom UDAF class below; the key methods are dataType() and evaluate(), which appear first.
Any help or advice would be greatly appreciated. Thank you.
import java.util.ArrayList;
import java.util.List;

import org.apache.spark.sql.Row;
import org.apache.spark.sql.expressions.MutableAggregationBuffer;
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructType;

public class CustomUDAF extends UserDefinedAggregateFunction {

    @Override
    public DataType dataType() {
        // TODO: Is this the correct way to return 2 arrays?
        return new StructType()
                .add("longArray", DataTypes.createArrayType(DataTypes.LongType, false))
                .add("dataArray", DataTypes.createArrayType(DataTypes.DoubleType, false));
    }

    @Override
    public Object evaluate(Row buffer) {
        // Data conversion
        List<Long> longList = new ArrayList<Long>(buffer.getList(0));
        List<Double> dataList = new ArrayList<Double>(buffer.getList(1));
        // Processing of data (omitted)
        // TODO: How to get data into the format needed to return 2 arrays?
        return dataList;
    }

    @Override
    public StructType inputSchema() {
        return new StructType()
                .add("long", DataTypes.LongType)
                .add("data", DataTypes.DoubleType);
    }

    @Override
    public StructType bufferSchema() {
        return new StructType()
                .add("longArray", DataTypes.createArrayType(DataTypes.LongType, false))
                .add("dataArray", DataTypes.createArrayType(DataTypes.DoubleType, false));
    }

    @Override
    public void initialize(MutableAggregationBuffer buffer) {
        buffer.update(0, new ArrayList<Long>());
        buffer.update(1, new ArrayList<Double>());
    }

    @Override
    public void update(MutableAggregationBuffer buffer, Row row) {
        ArrayList<Long> longList = new ArrayList<Long>(buffer.getList(0));
        longList.add(row.getLong(0));
        ArrayList<Double> dataList = new ArrayList<Double>(buffer.getList(1));
        dataList.add(row.getDouble(1));
        buffer.update(0, longList);
        buffer.update(1, dataList);
    }

    @Override
    public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
        ArrayList<Long> longList = new ArrayList<Long>(buffer1.getList(0));
        longList.addAll(buffer2.getList(0));
        ArrayList<Double> dataList = new ArrayList<Double>(buffer1.getList(1));
        dataList.addAll(buffer2.getList(1));
        buffer1.update(0, longList);
        buffer1.update(1, dataList);
    }

    @Override
    public boolean deterministic() {
        return true;
    }
}
Update: Based on the answer by zero323 I was able to return two arrays using:
return new Tuple2<>(longArray, dataArray);
Getting the data back out of the resulting DataFrame was a bit of a struggle, but it amounted to deconstructing the DataFrame into Java Lists and then building it back up into a DataFrame.
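For completeness, a minimal sketch of roughly what that looked like; the reworked evaluate() plus the extraction of the two arrays from the single struct column. The use of first() and the local variable names are illustrative, not the exact code:

import scala.Tuple2;

@Override
public Object evaluate(Row buffer) {
    List<Long> longList = new ArrayList<Long>(buffer.getList(0));
    List<Double> dataList = new ArrayList<Double>(buffer.getList(1));
    // Processing of data (omitted)
    // A Tuple2 is converted to a struct with two array fields, matching dataType()
    return new Tuple2<>(longList, dataList);
}

// Extracting the two arrays from the struct column of the result:
Row resultRow = resultingDataFrame.first();
Row struct = resultRow.getStruct(0);       // struct of (longArray, dataArray)
List<Long> longs = struct.getList(0);      // field "longArray"
List<Double> doubles = struct.getList(1);  // field "dataArray"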
As far as I can tell, returning a tuple should be just enough. In Scala:
import org.apache.spark.sql.expressions._
import org.apache.spark.sql.types._
import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.{Row, Column}
object DummyUDAF extends UserDefinedAggregateFunction {
  def inputSchema = new StructType().add("x", StringType)

  def bufferSchema = new StructType()
    .add("buff", ArrayType(LongType))
    .add("buff2", ArrayType(DoubleType))

  // A struct of two arrays: this is the multi-array return type
  def dataType = new StructType()
    .add("xs", ArrayType(LongType))
    .add("ys", ArrayType(DoubleType))

  def deterministic = true

  // No-ops: this dummy only demonstrates the return type
  def initialize(buffer: MutableAggregationBuffer) = {}
  def update(buffer: MutableAggregationBuffer, input: Row) = {}
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row) = {}

  // Returning a Tuple2 produces the struct declared in dataType
  def evaluate(buffer: Row) = (Array(1L, 2L, 3L), Array(1.0, 2.0, 3.0))
}
val df = sc.parallelize(Seq(("a", 1), ("b", 2))).toDF("k", "v")
df.select(DummyUDAF($"k")).show(1, false)
// +---------------------------------------------------+
// |(DummyUDAF$(k),mode=Complete,isDistinct=false) |
// +---------------------------------------------------+
// |[WrappedArray(1, 2, 3),WrappedArray(1.0, 2.0, 3.0)]|
// +---------------------------------------------------+
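Since the result comes back as a single struct column, the individual arrays can be pulled out by aliasing the column and selecting its fields by name. A small sketch under that assumption (the alias "r" is arbitrary):

val result = df.select(DummyUDAF($"k").alias("r"))
result.select($"r.xs", $"r.ys").show(false)
// xs: [1, 2, 3], ys: [1.0, 2.0, 3.0]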