Wrapping a Java function in PySpark

I am trying to create a user-defined aggregate function (UDAF) that I can call from Python. I tried to follow the answer to this question. I basically implemented the following (taken from here):

package com.blu.bla;
import java.util.ArrayList;
import java.util.List;
import org.apache.spark.sql.expressions.MutableAggregationBuffer;
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.Row;

// A UDAF that sums a nullable double column: the buffer stays null until the
// first non-null input, so an all-null (or empty) input evaluates to null.
public class MySum extends UserDefinedAggregateFunction {
    private StructType _inputDataType;
    private StructType _bufferSchema;
    private DataType _returnDataType;

    public MySum() {
        List<StructField> inputFields = new ArrayList<StructField>();
        inputFields.add(DataTypes.createStructField("inputDouble", DataTypes.DoubleType, true));
        _inputDataType = DataTypes.createStructType(inputFields);

        List<StructField> bufferFields = new ArrayList<StructField>();
        bufferFields.add(DataTypes.createStructField("bufferDouble", DataTypes.DoubleType, true));
        _bufferSchema = DataTypes.createStructType(bufferFields);

        _returnDataType = DataTypes.DoubleType;
    }

    @Override public StructType inputSchema() {
        return _inputDataType;
    }

    @Override public StructType bufferSchema() {
        return _bufferSchema;
    }

    @Override public DataType dataType() {
        return _returnDataType;
    }

    @Override public boolean deterministic() {
        return true;
    }

    // Start the aggregation with a null buffer so an empty input yields null.
    @Override public void initialize(MutableAggregationBuffer buffer) {
        buffer.update(0, null);
    }

    // Fold one input row into the running buffer, ignoring null inputs.
    @Override public void update(MutableAggregationBuffer buffer, Row input) {
        if (!input.isNullAt(0)) {
            if (buffer.isNullAt(0)) {
                buffer.update(0, input.getDouble(0));
            } else {
                Double newValue = input.getDouble(0) + buffer.getDouble(0);
                buffer.update(0, newValue);
            }
        }
    }

    // Combine two partial buffers produced on different partitions.
    @Override public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
        if (!buffer2.isNullAt(0)) {
            if (buffer1.isNullAt(0)) {
                buffer1.update(0, buffer2.getDouble(0));
            } else {
                Double newValue = buffer2.getDouble(0) + buffer1.getDouble(0);
                buffer1.update(0, newValue);
            }
        }
    }

    // Return the final sum, or null if no non-null value was seen.
    @Override public Object evaluate(Row buffer) {
        if (buffer.isNullAt(0)) {
            return null;
        } else {
            return buffer.getDouble(0);
        }
    }
}

I then compiled it with all its dependencies and ran pyspark with --jars myjar.jar.
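For reference, the jar can also be attached when the context is created rather than via the command-line flag; a minimal sketch, with the absolute jar path as a placeholder:

from pyspark import SparkConf, SparkContext
from pyspark.sql import SQLContext

# spark.jars takes a comma-separated list of jars shipped with the job.
conf = SparkConf().set("spark.jars", "/absolute/path/to/myjar.jar")
sc = SparkContext(conf=conf)
sqlCtx = SQLContext(sc)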

In pyspark I did:

df = sqlCtx.createDataFrame([(1.0, "a"), (2.0, "b"), (3.0, "C")], ["A", "B"])
from pyspark.sql.column import Column, _to_java_column, _to_seq
from pyspark.sql import Row

def myCol(col):
    # Wrap the JVM UDAF's apply, converting the Python column to a Java Seq.
    _f = sc._jvm.com.blu.bla.MySum.apply
    return Column(_f(_to_seq(sc, [col], _to_java_column)))

b = df.agg(myCol("A"))
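The expected behaviour, if the call resolved, would simply be the sum of column "A" (1.0 + 2.0 + 3.0 = 6.0):

# Sanity check for when the wrapper works; agg returns a one-row DataFrame.
print(b.collect()[0][0])  # expected: 6.0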

Instead, I got the following error:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-24-f45b2a367e67> in <module>()
----> 1 b = df.agg(myCol("A"))

<ipython-input-22-afcb8884e1db> in myCol(col)
      4 def myCol(col):
      5     _f = sc._jvm.com.blu.bla.MySum.apply
----> 6     return Column(_f(_to_seq(sc,[col], _to_java_column)))

TypeError: 'JavaPackage' object is not callable

I also tried adding --driver-class-path to the pyspark call but got the same result.

I also tried to access the Java class through py4j's java_import:

from py4j.java_gateway import java_import
jvm = sc._gateway.jvm
java_import(jvm, "com.blu.bla.MySum")

def myCol2(col):
    # After java_import the class should be addressable by its simple name.
    _f = jvm.MySum.apply
    return Column(_f(_to_seq(sc, [col], _to_java_column)))

I also tried to simply instantiate the class (as suggested here):

a = jvm.com.blu.bla.MySum()

All of these attempts produce the same error message.

I can't seem to figure out what the problem is.

Asked Mar 08 '16 by Assaf Mendelson



1 Answer

So it seems the main issue was that none of the options for adding the jar (--jars, --driver-class-path, SPARK_CLASSPATH) work properly when given a relative path. This is probably because of a mismatch between the working directory inside ipython and the one from which I ran pyspark. It also explains the error message: when py4j cannot resolve a name on the JVM classpath, it returns a JavaPackage placeholder, and calling that placeholder raises "'JavaPackage' object is not callable".

Once I changed this to an absolute path, it worked (I haven't tested it on a cluster yet, but at least it works on a local installation).
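As an illustration of the fix, the path can be resolved to an absolute one in Python before being handed to Spark; a minimal sketch, assuming the jar sits in the current directory:

import os

# Resolve the jar to an absolute path so the ipython working directory no
# longer matters; the resulting value goes to --jars / --driver-class-path.
jar_path = os.path.abspath("myjar.jar")
print(jar_path)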

Also, I am not sure whether this is an issue in the answer here as well, since that answer uses a Scala implementation; with the Java implementation, however, I needed to instantiate the class before calling apply:

def myCol(col):
    # Note the (): the Java class must be instantiated before apply is available.
    _f = sc._jvm.com.blu.bla.MySum().apply
    return Column(_f(_to_seq(sc, [col], _to_java_column)))

This is probably not really efficient, as it creates _f on every call; instead I should probably define _f outside the function (again, this would require testing on the cluster), but at least it now gives the correct result. A sketch of that refactor follows.
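For completeness, a minimal sketch of that refactor, with the same caveat that it has not been tested on a cluster: instantiate the Java UDAF once and reuse its apply method.

from pyspark.sql.column import Column, _to_java_column, _to_seq

# One JVM-side instance, created once instead of on every call.
_my_sum = sc._jvm.com.blu.bla.MySum()

def myCol(col):
    return Column(_my_sum.apply(_to_seq(sc, [col], _to_java_column)))

b = df.agg(myCol("A"))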

Answered Sep 26 '22 by Assaf Mendelson