
Create a custom Transformer in PySpark ML

I am new to Spark SQL DataFrames and ML on them (PySpark). How can I create a custom tokenizer which, for example, removes stop words and uses some libraries from nltk? Can I extend the default one?

Niko asked Sep 01 '15 12:09



1 Answer

Can I extend the default one?

Not really. The default Tokenizer is a subclass of pyspark.ml.wrapper.JavaTransformer and, like the other transformers and estimators in pyspark.ml.feature, delegates the actual processing to its Scala counterpart. Since you want to use Python, you should extend pyspark.ml.pipeline.Transformer (importable as pyspark.ml.Transformer) directly.
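Why is overriding `_transform` enough? Roughly, PySpark's public `Transformer.transform(dataset, params=None)` accepts an optional param map, runs on a copy of the transformer when that map is non-empty, and then calls `_transform`. The sketch below models only that dispatch in plain Python so it runs without Spark. `MiniTransformer` and `UpperCaser` are hypothetical names, and a list of dicts stands in for a DataFrame; this is an illustration, not the library's code.

```python
from copy import copy as shallow_copy
from typing import Any, Dict, List, Optional

Row = Dict[str, Any]  # a list of these stands in for a DataFrame


class MiniTransformer:
    """Hypothetical, much-simplified stand-in for pyspark.ml.Transformer."""

    def __init__(self) -> None:
        self.params: Dict[str, Any] = {}

    def copy(self, extra: Dict[str, Any]) -> "MiniTransformer":
        # Shallow copy with the extra params overlaid; self is left untouched.
        that = shallow_copy(self)
        that.params = {**self.params, **extra}
        return that

    def transform(self, dataset: List[Row],
                  params: Optional[Dict[str, Any]] = None) -> List[Row]:
        # Public entry point; subclasses normally leave it alone.
        if params is None:
            params = {}
        if not isinstance(params, dict):
            raise TypeError(f"Params must be a param map but got {type(params)}.")
        # With a non-empty param map, run _transform on a modified copy.
        target = self.copy(params) if params else self
        return target._transform(dataset)

    def _transform(self, dataset: List[Row]) -> List[Row]:
        raise NotImplementedError


class UpperCaser(MiniTransformer):
    """Hypothetical subclass: like the tokenizer, it only implements _transform."""

    def __init__(self) -> None:
        super().__init__()
        self.params = {"inputCol": "text", "outputCol": "upper"}

    def _transform(self, dataset: List[Row]) -> List[Row]:
        in_col, out_col = self.params["inputCol"], self.params["outputCol"]
        # Return new rows with one extra column, similar in spirit to withColumn.
        return [{**row, out_col: row[in_col].upper()} for row in dataset]


if __name__ == "__main__":
    rows = [{"text": "hi spark"}]
    print(UpperCaser().transform(rows))                          # [{'text': 'hi spark', 'upper': 'HI SPARK'}]
    print(UpperCaser().transform(rows, {"outputCol": "shout"}))  # [{'text': 'hi spark', 'shout': 'HI SPARK'}]
```

Passing a param map leaves the original instance unchanged, which is why a custom subclass can keep all of its logic in `_transform`.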

import nltk

from pyspark import keyword_only  ## < 2.0 -> pyspark.ml.util.keyword_only
from pyspark.ml import Transformer
from pyspark.ml.param import Param, Params, TypeConverters
from pyspark.ml.param.shared import HasInputCol, HasOutputCol
# Available in PySpark >= 2.3.0
from pyspark.ml.util import DefaultParamsReadable, DefaultParamsWritable
from pyspark.sql.functions import udf
from pyspark.sql.types import ArrayType, StringType


class NLTKWordPunctTokenizer(
        Transformer, HasInputCol, HasOutputCol,
        # Credits https://stackoverflow.com/a/52467470
        # by https://stackoverflow.com/users/234944/benjamin-manns
        DefaultParamsReadable, DefaultParamsWritable):

    # Declared at class level; Params copies it to each instance, so the
    # type converter is applied whenever stopwords is set (no need to
    # redefine it in __init__, which would drop the converter).
    stopwords = Param(Params._dummy(), "stopwords", "stopwords",
                      typeConverter=TypeConverters.toListString)

    @keyword_only
    def __init__(self, inputCol=None, outputCol=None, stopwords=None):
        super(NLTKWordPunctTokenizer, self).__init__()
        self._setDefault(stopwords=[])
        kwargs = self._input_kwargs
        self.setParams(**kwargs)

    @keyword_only
    def setParams(self, inputCol=None, outputCol=None, stopwords=None):
        kwargs = self._input_kwargs
        return self._set(**kwargs)

    def setStopwords(self, value):
        return self._set(stopwords=list(value))

    def getStopwords(self):
        return self.getOrDefault(self.stopwords)

    # Required in Spark >= 3.0
    def setInputCol(self, value):
        """
        Sets the value of :py:attr:`inputCol`.
        """
        return self._set(inputCol=value)

    # Required in Spark >= 3.0
    def setOutputCol(self, value):
        """
        Sets the value of :py:attr:`outputCol`.
        """
        return self._set(outputCol=value)

    def _transform(self, dataset):
        # Resolved on the driver; the set is shipped to executors with the UDF.
        stopwords = set(self.getStopwords())

        def f(s):
            # Runs on the executors, so nltk must be installed there as well.
            if s is None:  # keep nulls as nulls instead of failing in NLTK
                return None
            tokens = nltk.tokenize.wordpunct_tokenize(s)
            return [t for t in tokens if t.lower() not in stopwords]

        t = ArrayType(StringType())
        out_col = self.getOutputCol()
        in_col = dataset[self.getInputCol()]
        return dataset.withColumn(out_col, udf(f, t)(in_col))

Example usage (data from ML - Features):

# `spark` is an existing SparkSession (as in the pyspark shell).
# NLTK's stop word corpus must be downloaded once: nltk.download('stopwords')
sentenceDataFrame = spark.createDataFrame([
    (0, "Hi I heard about Spark"),
    (0, "I wish Java could use case classes"),
    (1, "Logistic regression models are neat")
], ["label", "sentence"])

tokenizer = NLTKWordPunctTokenizer(
    inputCol="sentence", outputCol="words",
    stopwords=nltk.corpus.stopwords.words('english'))

tokenizer.transform(sentenceDataFrame).show()
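To see what the UDF `f` does to each row without a Spark cluster or NLTK, here is a plain-Python sketch of the same per-row logic. It assumes that NLTK's `wordpunct_tokenize` splits text with the regular expression `\w+|[^\w\s]+` (runs of word characters, or runs of non-word, non-space characters), and it uses a tiny hypothetical stop word list rather than NLTK's English list, so the output differs from the `show()` call above.

```python
import re
from typing import Iterable, List, Optional

# Assumption: same pattern as NLTK's WordPunctTokenizer.
WORD_PUNCT = re.compile(r"\w+|[^\w\s]+")


def tokenize_without_stopwords(s: Optional[str],
                               stopwords: Iterable[str]) -> Optional[List[str]]:
    """Per-row logic of the UDF `f`, without NLTK or Spark."""
    if s is None:  # a null cell stays null
        return None
    stop = {w.lower() for w in stopwords}
    # Case-insensitive filtering: tokens are compared in lower case.
    return [t for t in WORD_PUNCT.findall(s) if t.lower() not in stop]


if __name__ == "__main__":
    # Tiny hypothetical stop word list, NOT NLTK's English list.
    stop = ["i", "about", "are", "the"]
    for sentence in ["Hi I heard about Spark",
                     "I wish Java could use case classes",
                     "Logistic regression models are neat"]:
        print(tokenize_without_stopwords(sentence, stop))
    # ['Hi', 'heard', 'Spark']
    # ['wish', 'Java', 'could', 'use', 'case', 'classes']
    # ['Logistic', 'regression', 'models', 'neat']
```

One small difference from the answer's UDF: this sketch also lowercases the stop words themselves. That makes no difference for NLTK's list, which is already lowercase.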

For a custom Python Estimator, see How to Roll a Custom Estimator in PySpark mllib.

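⚠ This answer depends on an internal API and is compatible with Spark 2.0.3, 2.1.1, 2.2.0 or later (SPARK-19348). The DefaultParamsReadable / DefaultParamsWritable mixins used above additionally require PySpark 2.3.0 or later. For code compatible with earlier Spark versions, please see revision 8.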
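The compatibility rule above is per release line (2.0.x from 2.0.3, 2.1.x from 2.1.1, everything from 2.2.0), so a plain `>=` comparison on the version string is not enough. The helper below is a hypothetical sketch that encodes that rule, plus the 2.3.0 floor stated in the code comment for the `DefaultParams*` mixins. In practice you might feed it `pyspark.__version__`.

```python
import re
from typing import Tuple


def parse_version(v: str) -> Tuple[int, int, int]:
    """Extract (major, minor, patch) from strings such as '3.5.1'."""
    m = re.match(r"(\d+)\.(\d+)\.(\d+)", v)
    if m is None:
        raise ValueError(f"unrecognised version string: {v!r}")
    major, minor, patch = (int(x) for x in m.groups())
    return major, minor, patch


def has_spark_19348_fix(v: str) -> bool:
    """True for 2.0.3+, 2.1.1+, or anything from 2.2.0 on (rule quoted above)."""
    major, minor, patch = parse_version(v)
    if (major, minor) >= (2, 2):
        return True
    if (major, minor) == (2, 1):
        return patch >= 1
    if (major, minor) == (2, 0):
        return patch >= 3
    return False


def supports_answer_code(v: str) -> bool:
    # The DefaultParamsReadable/Writable imports need 2.3.0 or later,
    # which also implies the SPARK-19348 fix.
    return parse_version(v) >= (2, 3, 0)
```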

zero323 answered Oct 05 '22 17:10
