Is there a way to run inference of a PyTorch model over a PySpark DataFrame in a vectorized way (using pandas_udf)?
A one-row UDF is pretty slow since the model state_dict() needs to be loaded for each row. I'm trying to use pandas_udf to speed this up, since all the operations can be vectorized efficiently in pandas/PyTorch.
I've looked at this Databricks post for inspiration, but it doesn't correspond exactly to my use case since I want to run prediction on an existing PySpark DataFrame.
I can get it to work using a one-row UDF in this simple example:
import torch
import torch.nn as nn
from pyspark.sql.functions import col, pandas_udf, PandasUDFType, udf
import pyspark.sql.functions as F
from pyspark.sql import SparkSession
from pyspark.sql.types import ArrayType, FloatType, DoubleType
import pandas as pd
import numpy as np
spark = SparkSession.builder.master('local[*]') \
    .appName("model_training") \
    .getOrCreate()
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.w = nn.Linear(5, 1)

    def forward(self, x):
        return self.w(x)
net = Net()
bc_model_state = spark.sparkContext.broadcast(net.state_dict())
df = spark.sparkContext.parallelize([[np.random.rand() for i in range(5)] for j in range(10)]).toDF()
df = df.withColumn('features', F.array([F.col(f"_{i}") for i in range(1, 6)]))
def get_model_for_eval():
    # Instantiate the model on the worker and load the broadcast state_dict
    model = Net()
    model.load_state_dict(bc_model_state.value)
    model.eval()
    return model
def one_row_predict(x):
    model = get_model_for_eval()
    t = torch.tensor(x, dtype=torch.float32)
    prediction = model(t).cpu().detach().item()
    return prediction
one_row_udf = udf(one_row_predict, FloatType())
df = df.withColumn('pred_one_row', one_row_udf(col('features')))
df.show()
Output:
+--------------------+-------------------+-------------------+-------------------+-------------------+--------------------+------------+
| _1| _2| _3| _4| _5| features|pred_one_row|
+--------------------+-------------------+-------------------+-------------------+-------------------+--------------------+------------+
| 0.8447505355266759| 0.3938414671838497|0.46347383092447003| 0.7694022276208854| 0.6152606009215115|[0.84475053552667...| 0.025048971|
|0.023782157504950607| 0.6434186254505012| 0.4090423037706754| 0.5466917794921007| 0.7855157903802007|[0.02378215750495...| 0.19694215|
| 0.5057589877333257| 0.7186078182786649| 0.9123361330966105| 0.601837718628886| 0.0773272396167538|[0.50575898773332...| 0.278222|
| 0.2815336141913932| 0.5196112020157087| 0.9646444599173869|0.04844988843812004|0.35445251642633047|[0.28153361419139...| 0.10699606|
| 0.3896101050146765|0.38732747821339863| 0.8516864705178889| 0.2500977280156421| 0.7781221754566505|[0.38961010501467...| -0.08206403|
| 0.8223344715797269| 0.9089425281658239|0.10088026161623431| 0.9920995834835098|0.40665125930441104|[0.82233447157972...| 0.3565607|
| 0.31167413110257425| 0.9778009876605741| 0.4717549025588036|0.24563879994222826| 0.7594244867194454|[0.31167413110257...| 0.18897778|
| 0.5667657426129576| 0.5383639427018171| 0.2983527299596511|0.18914810241640534|0.47854422807435326|[0.56676574261295...| 0.17796803|
| 0.6419824467244137|0.03992370080139418|0.38462617679839173| 0.709487894249459|0.23020927682221126|[0.64198244672441...| 0.15635887|
| 0.7972928622000178| 0.7700992684264264| 0.4387404431803098| 0.1340696629092989| 0.7072213018683782|[0.79729286220001...| 0.0500246|
+--------------------+-------------------+-------------------+-------------------+-------------------+--------------------+------------+
Trying to do the same thing in a vectorized way, this works:
def batch_predict(x):
    model = get_model_for_eval()
    xp = np.vstack(x)
    t = torch.tensor(xp, dtype=torch.float32)
    prediction = model(t).cpu().detach().numpy().flatten()
    return pd.Series(prediction)
df_pd = df.toPandas()
x = df_pd['features']
print(batch_predict(x))
But running it inside a pandas_udf fails:
batch_udf = pandas_udf(batch_predict, FloatType())
df = df.withColumn('pred_batch', batch_udf(col('features')))
df.show()
with:
20/02/11 10:13:01 ERROR Executor: Exception in task 2.0 in stage 1.0 (TID 3)
java.lang.IllegalArgumentException
at java.nio.ByteBuffer.allocate(ByteBuffer.java:334)
at org.apache.arrow.vector.ipc.message.MessageSerializer.readMessage(MessageSerializer.java:543)
at org.apache.arrow.vector.ipc.message.MessageChannelReader.readNext(MessageChannelReader.java:58)
at org.apache.arrow.vector.ipc.ArrowStreamReader.readSchema(ArrowStreamReader.java:132)
at org.apache.arrow.vector.ipc.ArrowReader.initialize(ArrowReader.java:181)
at org.apache.arrow.vector.ipc.ArrowReader.ensureInitialized(ArrowReader.java:172)
at org.apache.arrow.vector.ipc.ArrowReader.getVectorSchemaRoot(ArrowReader.java:65)
at org.apache.spark.sql.execution.python.ArrowPythonRunner$$anon$1.read(ArrowPythonRunner.scala:162)
at org.apache.spark.sql.execution.python.ArrowPythonRunner$$anon$1.read(ArrowPythonRunner.scala:122)
at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.hasNext(PythonRunner.scala:410)
at org.apache.spark.InterruptibleIterator.hasNext(InterruptibleIterator.scala:37)
at org.apache.spark.sql.execution.python.ArrowEvalPythonExec$$anon$2.<init>(ArrowEvalPythonExec.scala:98)
at org.apache.spark.sql.execution.python.ArrowEvalPythonExec.evaluate(ArrowEvalPythonExec.scala:96)
at org.apache.spark.sql.execution.python.EvalPythonExec$$anonfun$doExecute$1.apply(EvalPythonExec.scala:127)
at org.apache.spark.sql.execution.python.EvalPythonExec$$anonfun$doExecute$1.apply(EvalPythonExec.scala:89)
at org.apache.spark.rdd.RDD$$anonfun$mapPartitions$1$$anonfun$apply$23.apply(RDD.scala:801)
at org.apache.spark.rdd.RDD$$anonfun$mapPartitions$1$$anonfun$apply$23.apply(RDD.scala:801)
at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90)
at org.apache.spark.scheduler.Task.run(Task.scala:123)
at org.apache.spark.executor.Executor$TaskRunner$$anonfun$10.apply(Executor.scala:408)
at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1360)
at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:414)
at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
at java.lang.Thread.run(Thread.java:748)
Thanks for the help
So apparently this issue is due to an incompatibility between Spark 2.4.x and pyarrow >= 0.15: pyarrow 0.15 changed the Arrow IPC stream format, which Spark 2.4's Arrow reader does not understand. See the compatibility note for PyArrow >= 0.15.0 and Spark 2.3.x, 2.4.x in the Spark documentation.
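To confirm you are hitting this combination, a quick check of the installed versions (just a diagnostic sketch):
import pyspark
import pyarrow

# The java.lang.IllegalArgumentException from MessageSerializer.readMessage
# typically appears when Spark 2.4.x is paired with pyarrow >= 0.15
print(pyspark.__version__)   # e.g. 2.4.4
print(pyarrow.__version__)   # e.g. 0.15.1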
How I fixed it: run this code before creating the Spark session:
import os
# Force pyarrow to use the pre-0.15 IPC format that Spark 2.4.x expects
os.environ['ARROW_PRE_0_15_IPC_FORMAT'] = '1'
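For completeness, here is a minimal sketch of the fix applied end to end. It assumes local mode, where the Python workers inherit the driver's environment; on a cluster the variable also has to reach the executors, e.g. via spark.executorEnv.ARROW_PRE_0_15_IPC_FORMAT=1 or conf/spark-env.sh.
import os
# Must be set before the SparkSession (and its Python workers) start,
# so that pyarrow writes the pre-0.15 IPC format Spark 2.4 expects
os.environ['ARROW_PRE_0_15_IPC_FORMAT'] = '1'

from pyspark.sql import SparkSession
from pyspark.sql.functions import col, pandas_udf
from pyspark.sql.types import FloatType

spark = SparkSession.builder.master('local[*]') \
    .appName("model_inference") \
    .getOrCreate()

# ... recreate df, broadcast the state_dict and define batch_predict as above ...
batch_udf = pandas_udf(batch_predict, FloatType())
df = df.withColumn('pred_batch', batch_udf(col('features')))
df.show()
An alternative, if you can't control the environment, is to pin pyarrow below 0.15 (Spark 2.4.x was built against the older Arrow format); upgrading to Spark 3.x also removes the need for the flag.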