I am trying to create a user-defined aggregate function which I can call from Python. I tried to follow the answer to this question, and basically implemented the following (taken from here):
package com.blu.bla;
import java.util.ArrayList;
import java.util.List;
import org.apache.spark.sql.expressions.MutableAggregationBuffer;
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.Row;

public class MySum extends UserDefinedAggregateFunction {
    private StructType _inputDataType;
    private StructType _bufferSchema;
    private DataType _returnDataType;

    public MySum() {
        List<StructField> inputFields = new ArrayList<StructField>();
        inputFields.add(DataTypes.createStructField("inputDouble", DataTypes.DoubleType, true));
        _inputDataType = DataTypes.createStructType(inputFields);

        List<StructField> bufferFields = new ArrayList<StructField>();
        bufferFields.add(DataTypes.createStructField("bufferDouble", DataTypes.DoubleType, true));
        _bufferSchema = DataTypes.createStructType(bufferFields);

        _returnDataType = DataTypes.DoubleType;
    }

    // Schema of the input arguments.
    @Override public StructType inputSchema() {
        return _inputDataType;
    }

    // Schema of the intermediate aggregation buffer.
    @Override public StructType bufferSchema() {
        return _bufferSchema;
    }

    // Type of the final result.
    @Override public DataType dataType() {
        return _returnDataType;
    }

    @Override public boolean deterministic() {
        return true;
    }

    // Start with a null buffer so an all-null group yields null.
    @Override public void initialize(MutableAggregationBuffer buffer) {
        buffer.update(0, null);
    }

    // Fold one input row into the buffer.
    @Override public void update(MutableAggregationBuffer buffer, Row input) {
        if (!input.isNullAt(0)) {
            if (buffer.isNullAt(0)) {
                buffer.update(0, input.getDouble(0));
            } else {
                Double newValue = input.getDouble(0) + buffer.getDouble(0);
                buffer.update(0, newValue);
            }
        }
    }

    // Combine two partial aggregation buffers.
    @Override public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
        if (!buffer2.isNullAt(0)) {
            if (buffer1.isNullAt(0)) {
                buffer1.update(0, buffer2.getDouble(0));
            } else {
                Double newValue = buffer2.getDouble(0) + buffer1.getDouble(0);
                buffer1.update(0, newValue);
            }
        }
    }

    // Return the final value for a group.
    @Override public Object evaluate(Row buffer) {
        if (buffer.isNullAt(0)) {
            return null;
        } else {
            return buffer.getDouble(0);
        }
    }
}
I then compiled it with all its dependencies and ran pyspark with --jars myjar.jar.
In pyspark I did:
df = sqlCtx.createDataFrame([(1.0, "a"), (2.0, "b"), (3.0, "C")], ["A", "B"])
from pyspark.sql.column import Column, _to_java_column, _to_seq
from pyspark.sql import Row
def myCol(col):
    _f = sc._jvm.com.blu.bla.MySum.apply
    return Column(_f(_to_seq(sc, [col], _to_java_column)))
b = df.agg(myCol("A"))
I got the following error:
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-24-f45b2a367e67> in <module>()
----> 1 b = df.agg(myCol("A"))
<ipython-input-22-afcb8884e1db> in myCol(col)
4 def myCol(col):
5 _f = sc._jvm.com.blu.bla.MySum.apply
----> 6 return Column(_f(_to_seq(sc,[col], _to_java_column)))
TypeError: 'JavaPackage' object is not callable
I also tried adding --driver-class-path to the pyspark call but got the same result.
I also tried to access the Java class through a java_import:
from py4j.java_gateway import java_import
jvm = sc._gateway.jvm
java_import(jvm, "com.bla.blu.MySum")
def myCol2(col):
    _f = jvm.bla.blu.MySum.apply
    return Column(_f(_to_seq(sc, [col], _to_java_column)))
I also tried to simply create an instance of the class (as suggested here):
a = jvm.com.bla.blu.MySum()
All of these give the same error message.
I can't seem to figure out what the problem is.
If you run PySpark locally in an IDE (PyCharm, etc.) and want to use custom classes from a jar, you can put the jar into $SPARK_HOME/jars; it will be added to the classpath used to run Spark (check the code snippet in $SPARK_HOME/bin/spark-class2).
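If you are on a SparkSession-based setup (Spark 2.x), a similar effect can usually be had by pointing spark.jars at an absolute path when the session is built. A minimal sketch, with a placeholder jar path and app name:

# Minimal sketch (Spark 2.x assumed): spark.jars adds the jar to both driver
# and executor classpaths; the path below is a placeholder.
from pyspark.sql import SparkSession

spark = SparkSession.builder \
    .appName("udaf-test") \
    .config("spark.jars", "/absolute/path/to/myjar.jar") \
    .getOrCreate()
sc = spark.sparkContext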
So it seems the main issue was that none of the options for adding the jar (--jars, --driver-class-path, SPARK_CLASSPATH) work properly when given a relative path. This is probably because of a mismatch between the working directory inside IPython and the directory from which I ran pyspark.
Once I changed this to an absolute path, it worked (I haven't tested it on a cluster yet, but at least it works on a local installation).
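As a sanity check, py4j gives a hint about whether the jar was actually picked up: an unresolved name comes back as a JavaPackage (hence the "'JavaPackage' object is not callable" error), while a loaded class comes back as a JavaClass. A rough sketch of that check:

# Rough diagnostic sketch: if the jar is not on the driver classpath, py4j
# treats the dotted name as a package instead of a class.
from py4j.java_gateway import JavaClass, JavaPackage

cls = sc._jvm.com.blu.bla.MySum
if isinstance(cls, JavaPackage):
    print("jar not on the driver classpath; check the --jars path")
elif isinstance(cls, JavaClass):
    print("class resolved; the UDAF should be callable")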
Also, I am not sure whether this is also a bug in the answer here (that answer uses a Scala implementation), but in the Java implementation I needed to do:
def myCol(col):
    _f = sc._jvm.com.blu.bla.MySum().apply
    return Column(_f(_to_seq(sc, [col], _to_java_column)))
This is probably not very efficient, since it creates _f on every call; I should probably define _f outside the function instead (again, this would require testing on the cluster), but at least it now gives the correct result. A sketch of that variant follows.
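A minimal sketch of that variant (untested on a cluster), reusing a single JVM-side instance of MySum and the sc and df from above:

# Sketch: instantiate the Java UDAF once and reuse it, instead of creating a
# new _f on every call (untested on a cluster; assumes sc and df from above).
from pyspark.sql.column import Column, _to_java_column, _to_seq

_my_sum = sc._jvm.com.blu.bla.MySum()  # single JVM-side instance

def myCol(col):
    # Wrap the Java Column returned by MySum.apply in a Python Column.
    return Column(_my_sum.apply(_to_seq(sc, [col], _to_java_column)))

b = df.agg(myCol("A"))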