My original question here was "Why using DecisionTreeModel.predict inside map function raises an exception?" and it is related to "How to generate tuples of (original lable, predicted label) on Spark with MLlib?".
When we use the Scala API, the recommended way of getting predictions for RDD[LabeledPoint] using DecisionTreeModel is to simply map over the RDD:
    val labelAndPreds = testData.map { point =>
      val prediction = model.predict(point.features)
      (point.label, prediction)
    }
Unfortunately, a similar approach in PySpark doesn't work so well:
    labelsAndPredictions = testData.map(
        lambda lp: (lp.label, model.predict(lp.features)))
    labelsAndPredictions.first()
Exception: It appears that you are attempting to reference SparkContext from a broadcast variable, action, or transforamtion. SparkContext can only be used on the driver, not in code that it run on workers. For more information, see SPARK-5063.
Instead, the official documentation recommends something like this:
    predictions = model.predict(testData.map(lambda x: x.features))
    labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions)
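To make the shape of that recommended pattern concrete, here is a minimal sketch in plain Python, with lists standing in for RDDs. The `predict_all` function and the sample data are hypothetical and are not part of MLlib. The point is only that predictions are computed in one bulk call and then paired with labels by position, which is also why `RDD.zip` expects both sides to have the same number of partitions and the same number of elements in each partition (true here, since both are produced by a `map` over `testData`).

```python
def predict_all(feature_rows):
    """Hypothetical bulk predictor: returns one prediction per feature row."""
    return [1.0 if sum(row) > 1.0 else 0.0 for row in feature_rows]


# (label, features) pairs standing in for RDD[LabeledPoint]; made-up values.
test_data = [(1.0, [0.9, 0.8]), (0.0, [0.1, 0.2]), (1.0, [0.2, 0.3])]

# Step 1: extract the features and predict in a single bulk call.
predictions = predict_all([features for _, features in test_data])

# Step 2: extract the labels and pair them with predictions by position.
labels = [label for label, _ in test_data]
labels_and_predictions = list(zip(labels, predictions))

print(labels_and_predictions)
```

The model is never referenced inside a per-element function here, which is the property the PySpark version relies on.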
So what is going on here? There is no broadcast variable here, and the Scala API defines predict as follows:
    /**
     * Predict values for a single data point using the model trained.
     *
     * @param features array representing a single data point
     * @return Double prediction from the trained model
     */
    def predict(features: Vector): Double = {
      topNode.predict(features)
    }

    /**
     * Predict values for the given data set using the model trained.
     *
     * @param features RDD representing data points to be predicted
     * @return RDD of predictions for each of the given data points
     */
    def predict(features: RDD[Vector]): RDD[Double] = {
      features.map(x => predict(x))
    }
So, at least at first glance, calling it from an action or transformation is not a problem, since prediction seems to be a local operation.
After some digging I figured out that the source of the problem is the JavaModelWrapper.call method invoked from DecisionTreeModel.predict. It accesses SparkContext, which is required to call a Java function:
callJavaFunc(self._sc, getattr(self._java_model, name), *a)
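The failure mode can be imitated without Spark. The sketch below is only an illustration under stated assumptions: `DriverOnlyContext`, `WrappedModel`, and `can_ship` are hypothetical stand-ins, and PySpark's actual serialization machinery differs in detail. It shows the general idea that a function shipped to workers must be serialized together with everything it references, so a model that holds a driver-only handle cannot travel with it.

```python
import pickle


class DriverOnlyContext:
    """Hypothetical stand-in for SparkContext: it refuses to be serialized."""

    def __reduce__(self):
        # pickle calls __reduce__ when serializing this object; raising here
        # imitates a handle that is only meaningful on the driver.
        raise RuntimeError("context can only be used on the driver")


class WrappedModel:
    """Hypothetical stand-in for a Python wrapper around a JVM model."""

    def __init__(self, sc):
        self._sc = sc  # analogous to the self._sc handle in the line above

    def predict(self, features):
        # A real wrapper would forward the call to the JVM through self._sc;
        # this placeholder just sums the features.
        return float(sum(features))


def can_ship(obj):
    """Return True if obj can be pickled, as objects sent to workers must be."""
    try:
        pickle.dumps(obj)
        return True
    except Exception:
        return False


model = WrappedModel(DriverOnlyContext())

print(model.predict([1.0, 2.0]))  # a local call on the "driver" works
print(can_ship([1.0, 2.0]))       # plain data can be shipped
print(can_ship(model))            # the model drags its context along
```

Under these assumptions, the bulk `predict` over an RDD avoids the problem because the Python model object stays on the driver and only the data is processed remotely.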
In the case of DecisionTreeModel.predict there is a recommended workaround, and all the required code is already a part of the Scala API, but is there any elegant way to handle problems like this in general?
The only solutions I can think of right now are rather heavyweight:
Communication using the default Py4J gateway is simply not possible. To understand why, we have to take a look at the following diagram from the PySpark Internals document [1]:
Since the Py4J gateway runs on the driver, it is not accessible to the Python interpreters, which communicate with JVM workers through sockets (see for example PythonRDD / rdd.py).
Theoretically it could be possible to create a separate Py4J gateway for each worker, but in practice it is unlikely to be useful. Even ignoring issues like reliability, Py4J is simply not designed to perform data-intensive tasks.
Are there any workarounds?
Using Spark SQL Data Sources API to wrap JVM code.
Pros: Supported, high level, doesn't require access to the internal PySpark API
Cons: Relatively verbose and not very well documented, limited mostly to the input data
Operating on DataFrames using Scala UDFs.
Pros: Easy to implement (see Spark: How to map Python with Scala or Java User Defined Functions?), no data conversion between Python and Scala if data is already stored in a DataFrame, minimal access to Py4J
Cons: Requires access to Py4J gateway and internal methods, limited to Spark SQL, hard to debug, not supported
Creating a high level Scala interface, similar to how it is done in MLlib.
Pros: Flexible, able to execute arbitrarily complex code. It can be done either directly on RDDs (see for example MLlib model wrappers) or with DataFrames (see How to use a Scala class inside Pyspark). The latter solution seems to be much more friendly since all ser-de details are already handled by the existing API.
Cons: Low level, requires data conversion, and, same as UDFs, requires access to Py4J and the internal API, not supported
Some basic examples can be found in Transforming PySpark RDD with Scala
Using an external workflow management tool to switch between Python and Scala / Java jobs, passing data through a DFS.
Pros: Easy to implement, minimal changes to the code itself
Cons: Cost of reading / writing data (Alluxio?)
Using a shared SQLContext (see for example Apache Zeppelin or Livy) to pass data between guest languages using registered temporary tables.
Pros: Well suited for interactive analysis
Cons: Not so much for batch jobs (Zeppelin) or may require additional orchestration (Livy)