Serialize a custom transformer using python to be used within a Pyspark ML pipeline

I found the same discussion in the comments section of Create a custom Transformer in PySpark ML, but there is no clear answer there. There is also an unresolved JIRA corresponding to this: https://issues.apache.org/jira/browse/SPARK-17025.

Given that the Pyspark ML pipeline provides no option for saving a custom transformer written in Python, what are the other ways to get it done? How can I implement a _to_java method in my Python class that returns a compatible Java object?

asked Dec 30 '16 by TechnoIndifferent

2 Answers

As of Spark 2.3.0 there's a much, much better way to do this.

Simply extend DefaultParamsWritable and DefaultParamsReadable, and your class will automatically have write and read methods that save your params and that are used by the PipelineModel serialization system.
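As a minimal sketch of that requirement (the class name here is illustrative, not from the original answer), the only structural change is adding the two mixins:

    from pyspark.ml import Transformer
    from pyspark.ml.util import DefaultParamsReadable, DefaultParamsWritable

    class NoopTransformer(Transformer, DefaultParamsReadable, DefaultParamsWritable):
        # Any Params declared on this class are written and read automatically;
        # no _to_java/_from_java is needed for PipelineModel save()/load().
        def _transform(self, dataset):
            return dataset  # identity transform, just to show the mixins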

The docs are not really clear here, and I had to do a bit of source reading to understand how deserialization works:

  • PipelineModel.read instantiates a PipelineModelReader
  • PipelineModelReader loads metadata and checks if language is 'Python'. If it's not, then the typical JavaMLReader is used (what most of these answers are designed for)
  • Otherwise, PipelineSharedReadWrite is used, which calls DefaultParamsReader.loadParamsInstance

loadParamsInstance will find the class from the saved metadata, instantiate it, and call .load(path) on it. You can extend DefaultParamsReadable and get the DefaultParamsReader.load behavior automatically. If you do have specialized deserialization logic to implement, I would look at that load method as a starting place.
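If you do need specialized deserialization, one way to hook in (a sketch under my own assumptions; the reader subclass and its extra-state handling are hypothetical, not from the answer) is to override read() so it returns your own DefaultParamsReader subclass:

    from pyspark.ml import Transformer
    from pyspark.ml.util import (DefaultParamsReadable, DefaultParamsReader,
                                 DefaultParamsWritable)

    class ExtraStateReader(DefaultParamsReader):
        def load(self, path):
            # DefaultParamsReader.load restores the instance and its Params from metadata
            instance = super(ExtraStateReader, self).load(path)
            # ...restore any extra, non-Param state saved under `path` here...
            return instance

    class ExtraStateTransformer(Transformer, DefaultParamsReadable, DefaultParamsWritable):
        def _transform(self, dataset):
            return dataset

        @classmethod
        def read(cls):
            # hand the pipeline machinery our custom reader instead of the default one
            return ExtraStateReader(cls)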

On the opposite side:

  • PipelineModel.write will check if all stages are Java (implement JavaMLWritable). If so, the typical JavaMLWriter is used (what most of these answers are designed for)
  • Otherwise, PipelineWriter is used, which checks that all stages implement MLWritable and calls PipelineSharedReadWrite.saveImpl
  • PipelineSharedReadWrite.saveImpl will call .write().save(path) on each stage.

You can extend DefaultParamsWritable to get a write method that saves the metadata for your class and its params in the right format. If you have custom serialization logic to implement, I would look at DefaultParamsWriter as a starting point.
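Similarly, for custom serialization (again a sketch with a hypothetical writer subclass, not the library's prescribed pattern), you can override write() to return a DefaultParamsWriter subclass whose saveImpl adds your extra state after the default metadata is written:

    from pyspark.ml.util import DefaultParamsWritable, DefaultParamsWriter

    class ExtraStateWriter(DefaultParamsWriter):
        def saveImpl(self, path):
            # DefaultParamsWriter.saveImpl writes the class/params metadata JSON
            super(ExtraStateWriter, self).saveImpl(path)
            # ...write any extra, non-Param state under `path` here...

    # then, on the transformer class:
    #     def write(self):
    #         return ExtraStateWriter(self)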

OK, so finally, here is a pretty simple transformer that extends Params, with all of its parameters stored in the typical Params fashion:

from pyspark import keyword_only
from pyspark.ml import Transformer
from pyspark.ml.param.shared import HasOutputCols, Param, Params
from pyspark.ml.util import DefaultParamsReadable, DefaultParamsWritable
from pyspark.sql.functions import lit  # for the dummy _transform


class SetValueTransformer(
    Transformer, HasOutputCols, DefaultParamsReadable, DefaultParamsWritable,
):
    value = Param(
        Params._dummy(),
        "value",
        "value to fill",
    )

    @keyword_only
    def __init__(self, outputCols=None, value=0.0):
        super(SetValueTransformer, self).__init__()
        self._setDefault(value=0.0)
        kwargs = self._input_kwargs
        self._set(**kwargs)

    @keyword_only
    def setParams(self, outputCols=None, value=0.0):
        """
        setParams(self, outputCols=None, value=0.0)
        Sets params for this SetValueTransformer.
        """
        kwargs = self._input_kwargs
        return self._set(**kwargs)

    def setValue(self, value):
        """
        Sets the value of :py:attr:`value`.
        """
        return self._set(value=value)

    def getValue(self):
        """
        Gets the value of :py:attr:`value` or its default value.
        """
        return self.getOrDefault(self.value)

    def _transform(self, dataset):
        for col in self.getOutputCols():
            dataset = dataset.withColumn(col, lit(self.getValue()))
        return dataset

Now we can use it:

from pyspark.ml import Pipeline, PipelineModel

svt = SetValueTransformer(outputCols=["a", "b"], value=123.0)

p = Pipeline(stages=[svt])
df = sc.parallelize([(1, None), (2, 1.0), (3, 0.5)]).toDF(["key", "value"])
pm = p.fit(df)
pm.transform(df).show()
pm.write().overwrite().save('/tmp/example_pyspark_pipeline')
pm2 = PipelineModel.load('/tmp/example_pyspark_pipeline')
print('matches?', pm2.stages[0].extractParamMap() == pm.stages[0].extractParamMap())
pm2.transform(df).show()

Result:

+---+-----+-----+-----+
|key|value|    a|    b|
+---+-----+-----+-----+
|  1| null|123.0|123.0|
|  2|  1.0|123.0|123.0|
|  3|  0.5|123.0|123.0|
+---+-----+-----+-----+

matches? True
+---+-----+-----+-----+
|key|value|    a|    b|
+---+-----+-----+-----+
|  1| null|123.0|123.0|
|  2|  1.0|123.0|123.0|
|  3|  0.5|123.0|123.0|
+---+-----+-----+-----+
answered Sep 19 '22 by Benjamin Manns

I am not sure this is the best approach, but I too need the ability to save the custom Estimators, Transformers and Models I have created in Pyspark, and to support their use in the Pipeline API with persistence. Custom Pyspark Estimators, Transformers and Models can be created and used in the Pipeline API, but they cannot be saved. This poses a problem in production when the model training takes longer than an event prediction cycle.

In general, Pyspark Estimators, Transformers and Models are just wrappers around their Java or Scala equivalents; the Pyspark wrappers simply marshal parameters to and from Java via py4j. Any persistence of a model is then done on the Java side. This structure limits custom Pyspark Estimators, Transformers and Models to living only in the Python world.

In a previous attempt, I was able to save a single Pyspark model using pickle/dill serialization. That worked well, but it still did not allow saving or loading such a model from within the Pipeline API. Another SO post pointed me to the OneVsRest classifier, so I inspected its _to_java and _from_java methods; they do all the heavy lifting on the Pyspark side. It occurred to me that if there were a way to stash the pickle dump inside an existing, already-savable Java object, then it should be possible to save a custom Pyspark Estimator, Transformer or Model with the Pipeline API.
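For reference, the single-model pickling mentioned above is roughly this (a sketch, not the author's original code; MyTransformer is the class defined further down, and the file name is arbitrary):

    import dill

    trA = MyTransformer()                  # custom transformer defined below
    with open('my_transformer.dill', 'wb') as f:
        f.write(dill.dumps(trA))           # serialize the whole python object

    with open('my_transformer.dill', 'rb') as f:
        trB = dill.loads(f.read())         # restore an equivalent instance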

To that end, I found the StopWordsRemover to be the ideal object to hijack, because it has an attribute, stopwords, that is a list of strings. The dill.dumps method returns a pickled representation of the object as a string. The plan was to turn that string into a list and set the stopwords parameter of a StopWordsRemover to it. Even though it was a list of strings, I found that some of the characters would not marshal to the Java object, so each character is converted to an integer and each integer to a string (see the round-trip sketch below). This all works great for saving a single instance, and also when saving within a Pipeline, because the Pipeline dutifully calls the _to_java method of my Python class (we are still on the Pyspark side, so this works). But coming back to Pyspark from Java did not work in the Pipeline API.
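The encoding trick boils down to this round trip (a sketch mirroring the _to_java/_from_java code below; py_obj stands for any picklable Python object, and, like the answer's code, it assumes Python 2, where dill.dumps returns a str):

    import dill

    MAGIC_ID = '4c1740b00d3c4ff6806a1402321572cb'    # same marker the code below uses

    dmp = dill.dumps(py_obj)                         # pickled representation (a str on Python 2)
    as_strings = [str(ord(c)) for c in dmp]          # integers-as-strings marshal cleanly via py4j
    as_strings.append(MAGIC_ID)                      # lets PysparkPipelineWrapper recognize the stage

    # ...store as_strings as the stopwords of a StopWordsRemover, save, reload, then...
    recovered = dill.loads(''.join(chr(int(s)) for s in as_strings[:-1]))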

Because I am hiding my Python object in a StopWordsRemover instance, the Pipeline, when coming back to Pyspark, knows nothing about my hidden class; it only knows it has a StopWordsRemover instance. Ideally it would be great to subclass Pipeline and PipelineModel, but alas that brings us back to trying to serialize a Python object. To work around this, I created a PysparkPipelineWrapper that takes a Pipeline or PipelineModel and simply scans the stages, looking for a coded ID in the stopwords list (remember, that list is just the pickled bytes of my Python object); when it finds one, it unwraps the list back into my instance and stores it in the stage it came from. Below is code that shows how this all works.

For any custom Pyspark Estimator, Transformer or Model, just inherit from Identifiable, PysparkReaderWriter, MLReadable and MLWritable. Then, when loading a Pipeline or PipelineModel, pass it through PysparkPipelineWrapper.unwrap(pipeline).

This method does not make the Pyspark code usable from Java or Scala, but at least we can save and load custom Pyspark Estimators, Transformers and Models and work with the Pipeline API.

import dill
from pyspark.ml import Transformer, Pipeline, PipelineModel
from pyspark.ml.param import Param, Params
from pyspark.ml.util import Identifiable, MLReadable, MLWritable, JavaMLReader, JavaMLWriter
from pyspark.ml.feature import StopWordsRemover
from pyspark.ml.wrapper import JavaParams
from pyspark.context import SparkContext
from pyspark.sql import Row


class PysparkObjId(object):
    """
    A class to specify constants used to identify and set up python
    Estimators, Transformers and Models so they can be serialized on their
    own and from within a Pipeline or PipelineModel.
    """
    def __init__(self):
        super(PysparkObjId, self).__init__()

    @staticmethod
    def _getPyObjId():
        return '4c1740b00d3c4ff6806a1402321572cb'

    @staticmethod
    def _getCarrierClass(javaName=False):
        return 'org.apache.spark.ml.feature.StopWordsRemover' if javaName else StopWordsRemover


class PysparkPipelineWrapper(object):
    """
    A class to facilitate converting the stages of a Pipeline or PipelineModel
    that were saved from PysparkReaderWriter.
    """
    def __init__(self):
        super(PysparkPipelineWrapper, self).__init__()

    @staticmethod
    def unwrap(pipeline):
        if not (isinstance(pipeline, Pipeline) or isinstance(pipeline, PipelineModel)):
            raise TypeError("Cannot recognize a pipeline of type %s." % type(pipeline))

        stages = pipeline.getStages() if isinstance(pipeline, Pipeline) else pipeline.stages
        for i, stage in enumerate(stages):
            if (isinstance(stage, Pipeline) or isinstance(stage, PipelineModel)):
                stages[i] = PysparkPipelineWrapper.unwrap(stage)
            if isinstance(stage, PysparkObjId._getCarrierClass()) and stage.getStopWords()[-1] == PysparkObjId._getPyObjId():
                swords = stage.getStopWords()[:-1]  # strip the id
                lst = [chr(int(d)) for d in swords]
                dmp = ''.join(lst)
                py_obj = dill.loads(dmp)
                stages[i] = py_obj

        if isinstance(pipeline, Pipeline):
            pipeline.setStages(stages)
        else:
            pipeline.stages = stages
        return pipeline


class PysparkReaderWriter(object):
    """
    A mixin class so custom pyspark Estimators, Transformers and Models may
    support saving and loading directly or be saved within a Pipeline or PipelineModel.
    """
    def __init__(self):
        super(PysparkReaderWriter, self).__init__()

    def write(self):
        """Returns an MLWriter instance for this ML instance."""
        return JavaMLWriter(self)

    @classmethod
    def read(cls):
        """Returns an MLReader instance for our carrier class."""
        return JavaMLReader(PysparkObjId._getCarrierClass())

    @classmethod
    def load(cls, path):
        """Reads an ML instance from the input path, a shortcut of `read().load(path)`."""
        swr_java_obj = cls.read().load(path)
        return cls._from_java(swr_java_obj)

    @classmethod
    def _from_java(cls, java_obj):
        """
        Get the dummy stopwords that are the characters of the dill dump plus our guid
        and convert, via dill, back to our python instance.
        """
        swords = java_obj.getStopWords()[:-1]  # strip the id
        lst = [chr(int(d)) for d in swords]  # convert from string integer list to bytes
        dmp = ''.join(lst)
        py_obj = dill.loads(dmp)
        return py_obj

    def _to_java(self):
        """
        Convert this instance to a dill dump, then to a list of strings with the unicode integer values of each character.
        Use this list as a set of dummy stopwords and store in a StopWordsRemover instance.
        :return: Java object equivalent to this instance.
        """
        dmp = dill.dumps(self)
        pylist = [str(ord(d)) for d in dmp]  # convert bytes to string integer list
        pylist.append(PysparkObjId._getPyObjId())  # add our id so PysparkPipelineWrapper can id us.
        sc = SparkContext._active_spark_context
        java_class = sc._gateway.jvm.java.lang.String
        java_array = sc._gateway.new_array(java_class, len(pylist))
        for i in xrange(len(pylist)):
            java_array[i] = pylist[i]
        _java_obj = JavaParams._new_java_obj(PysparkObjId._getCarrierClass(javaName=True), self.uid)
        _java_obj.setStopWords(java_array)
        return _java_obj


class HasFake(Params):
    def __init__(self):
        super(HasFake, self).__init__()
        self.fake = Param(self, "fake", "fake param")

    def getFake(self):
        return self.getOrDefault(self.fake)


class MockTransformer(Transformer, HasFake, Identifiable):
    def __init__(self):
        super(MockTransformer, self).__init__()
        self.dataset_count = 0

    def _transform(self, dataset):
        self.dataset_count = dataset.count()
        return dataset


class MyTransformer(MockTransformer, Identifiable, PysparkReaderWriter, MLReadable, MLWritable):
    def __init__(self):
        super(MyTransformer, self).__init__()


def make_a_dataframe(sc):
    df = sc.parallelize([Row(name='Alice', age=5, height=80), Row(name='Alice', age=5, height=80), Row(name='Alice', age=10, height=80)]).toDF()
    return df


def test1():
    trA = MyTransformer()
    trA.dataset_count = 999
    print trA.dataset_count
    trA.save('test.trans')
    trB = MyTransformer.load('test.trans')
    print trB.dataset_count


def test2():
    trA = MyTransformer()
    pipeA = Pipeline(stages=[trA])
    print type(pipeA)
    pipeA.save('testA.pipe')
    pipeAA = PysparkPipelineWrapper.unwrap(Pipeline.load('testA.pipe'))
    stagesAA = pipeAA.getStages()
    trAA = stagesAA[0]
    print trAA.dataset_count


def test3():
    dfA = make_a_dataframe(sc)
    trA = MyTransformer()
    pipeA = Pipeline(stages=[trA]).fit(dfA)
    print type(pipeA)
    pipeA.save('testB.pipe')
    pipeAA = PysparkPipelineWrapper.unwrap(PipelineModel.load('testB.pipe'))
    stagesAA = pipeAA.stages
    trAA = stagesAA[0]
    print trAA.dataset_count
    dfB = pipeAA.transform(dfA)
    dfB.show()
answered Sep 22 '22 by dmbaker