Reading a custom PySpark transformer

After messing with this for quite a while, I am finally able to save a pure Python custom transformer in Spark 2.3. But I get an error when loading the transformer back.

I checked the content of what was saved and found all the relevant variables in the file on HDFS. It would be great if someone could spot what I am missing in this simple transformer.

import pyspark.sql.functions as F

from pyspark.ml import Transformer
from pyspark.ml.param.shared import Param, Params, TypeConverters
from pyspark.ml.util import DefaultParamsReadable, DefaultParamsWritable

class AggregateTransformer(Transformer, DefaultParamsWritable, DefaultParamsReadable):
    aggCols = Param(Params._dummy(), "aggCols", "columns to group by",
                    TypeConverters.toListString)
    valCols = Param(Params._dummy(), "valCols", "columns to aggregate",
                    TypeConverters.toListString)

    def __init__(self,aggCols,valCols):
        super(AggregateTransformer, self).__init__()
        self._setDefault(aggCols=[''])
        self._set(aggCols = aggCols)
        self._setDefault(valCols=[''])
        self._set(valCols = valCols)

    def getAggCols(self):
        return self.getOrDefault(self.aggCols)

    def setAggCols(self, aggCols):
        self._set(aggCols=aggCols)

    def getValCols(self):
        return self.getOrDefault(self.valCols)

    def setValCols(self, valCols):
        self._set(valCols=valCols)

    def _transform(self, dataset):
        aggFuncs = []
        for valCol in self.getValCols():
            aggFuncs.append(F.sum(valCol).alias("sum_"+valCol))
            aggFuncs.append(F.min(valCol).alias("min_"+valCol))
            aggFuncs.append(F.max(valCol).alias("max_"+valCol))
            aggFuncs.append(F.count(valCol).alias("cnt_"+valCol))
            aggFuncs.append(F.avg(valCol).alias("avg_"+valCol))
            aggFuncs.append(F.stddev(valCol).alias("stddev_"+valCol))

        dataset = dataset.groupBy(self.getAggCols()).agg(*aggFuncs)
        return dataset
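
For reference, this is roughly how I save and reload it (the column names here are just placeholders and an active SparkSession is assumed):

agg = AggregateTransformer(aggCols=["group_col"], valCols=["value_col"])
agg.write().overwrite().save("/tmp/test")   # saving works fine

x = agg.load("/tmp/test")                   # loading fails with the error below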

I get this error when I load an instance of this transformer after saving it.

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-172-44e20f7e3842> in <module>()
----> 1 x = agg.load("/tmp/test")

/usr/hdp/current/spark2.3-client/python/pyspark/ml/util.py in load(cls, path)
    309     def load(cls, path):
    310         """Reads an ML instance from the input path, a shortcut of `read().load(path)`."""
--> 311         return cls.read().load(path)
    312 
    313 

/usr/hdp/current/spark2.3-client/python/pyspark/ml/util.py in load(self, path)
    482         metadata = DefaultParamsReader.loadMetadata(path, self.sc)
    483         py_type = DefaultParamsReader.__get_class(metadata['class'])
--> 484         instance = py_type()
    485         instance._resetUid(metadata['uid'])
    486         DefaultParamsReader.getAndSetParams(instance, metadata)

TypeError: __init__() missing 2 required positional arguments: 'aggCols' and 'valCols'
asked Sep 21 '18 by Subramaniam Ramasubramanian

1 Answer

Figured out the answer!

The problem was that the reader instantiates a new transformer object with no arguments, but the __init__ of my AggregateTransformer did not have default values for its arguments.

So changing the following line of code fixed the issue!

def __init__(self, aggCols=[], valCols=[]):

I am going to leave this question and answer here since it was incredibly difficult for me to find a working example of a pure Python transformer that could be saved and read back anywhere. It might help someone else looking for this.
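
For completeness, a minimal sketch of the corrected constructor plus a save/load round trip; the rest of the class is unchanged from the question, the column names are placeholders, and an active SparkSession is assumed:

from pyspark.ml import Transformer
from pyspark.ml.param.shared import Param, Params, TypeConverters
from pyspark.ml.util import DefaultParamsReadable, DefaultParamsWritable

class AggregateTransformer(Transformer, DefaultParamsWritable, DefaultParamsReadable):
    aggCols = Param(Params._dummy(), "aggCols", "columns to group by",
                    TypeConverters.toListString)
    valCols = Param(Params._dummy(), "valCols", "columns to aggregate",
                    TypeConverters.toListString)

    # Default values let DefaultParamsReader call AggregateTransformer()
    # with no arguments before it restores the params from the saved metadata.
    def __init__(self, aggCols=[], valCols=[]):
        super(AggregateTransformer, self).__init__()
        self._setDefault(aggCols=[''], valCols=[''])
        self._set(aggCols=aggCols, valCols=valCols)

    # getters, setters and _transform are the same as in the question

# Round trip now works.
agg = AggregateTransformer(aggCols=["group_col"], valCols=["value_col"])
agg.write().overwrite().save("/tmp/test")
x = AggregateTransformer.load("/tmp/test")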

answered Oct 20 '22 by Subramaniam Ramasubramanian