How to run inference of a pytorch model on pyspark dataframe (create new column with prediction) using pandas_udf?

Is there a way to run inference of a pytorch model over a pyspark dataframe in a vectorized way (using pandas_udf)?

A one-row udf is pretty slow, since the model state_dict() needs to be loaded for each row. I'm trying to use pandas_udf to speed this up, since all the operations can be vectorized efficiently in pandas/pytorch.

I've looked at this databricks post for inspiration, but it doesn't correspond exactly to my use case, since I want to run prediction on an existing pyspark dataframe.

I can get it to work using a one-row udf in this simple example:

import torch
import torch.nn as nn
from pyspark.sql.functions import col, pandas_udf, PandasUDFType, udf
import pyspark.sql.functions as F
from pyspark.sql import SparkSession
from pyspark.sql.types import ArrayType, FloatType, DoubleType
import pandas as pd
import numpy as np

spark = SparkSession.builder.master('local[*]') \
    .appName("model_training") \
    .getOrCreate()

class Net(nn.Module):

    def __init__(self):
        super(Net, self).__init__()
        self.w = nn.Linear(5, 1)

    def forward(self, x):
        return self.w(x)

net = Net()
bc_model_state = spark.sparkContext.broadcast(net.state_dict())


df = spark.sparkContext.parallelize([[np.random.rand() for i in range(5)] for j in range(10)]).toDF()
df = df.withColumn('features', F.array([F.col(f"_{i}") for i in range(1, 6)]))

def get_model_for_eval():
    # Load the broadcast state_dict into the model and switch to eval mode
    net.load_state_dict(bc_model_state.value)
    net.eval()
    return net

def one_row_predict(x):
    model = get_model_for_eval()
    t = torch.tensor(x, dtype=torch.float32)
    prediction = model(t).cpu().detach().item()
    return prediction

one_row_udf = udf(one_row_predict, FloatType())
df = df.withColumn('pred_one_row', one_row_udf(col('features')))
df.show()

Output:

+--------------------+-------------------+-------------------+-------------------+-------------------+--------------------+------------+
|                  _1|                 _2|                 _3|                 _4|                 _5|            features|pred_one_row|
+--------------------+-------------------+-------------------+-------------------+-------------------+--------------------+------------+
|  0.8447505355266759| 0.3938414671838497|0.46347383092447003| 0.7694022276208854| 0.6152606009215115|[0.84475053552667...| 0.025048971|
|0.023782157504950607| 0.6434186254505012| 0.4090423037706754| 0.5466917794921007| 0.7855157903802007|[0.02378215750495...|  0.19694215|
|  0.5057589877333257| 0.7186078182786649| 0.9123361330966105|  0.601837718628886| 0.0773272396167538|[0.50575898773332...|    0.278222|
|  0.2815336141913932| 0.5196112020157087| 0.9646444599173869|0.04844988843812004|0.35445251642633047|[0.28153361419139...|  0.10699606|
|  0.3896101050146765|0.38732747821339863| 0.8516864705178889| 0.2500977280156421| 0.7781221754566505|[0.38961010501467...| -0.08206403|
|  0.8223344715797269| 0.9089425281658239|0.10088026161623431| 0.9920995834835098|0.40665125930441104|[0.82233447157972...|   0.3565607|
| 0.31167413110257425| 0.9778009876605741| 0.4717549025588036|0.24563879994222826| 0.7594244867194454|[0.31167413110257...|  0.18897778|
|  0.5667657426129576| 0.5383639427018171| 0.2983527299596511|0.18914810241640534|0.47854422807435326|[0.56676574261295...|  0.17796803|
|  0.6419824467244137|0.03992370080139418|0.38462617679839173|  0.709487894249459|0.23020927682221126|[0.64198244672441...|  0.15635887|
|  0.7972928622000178| 0.7700992684264264| 0.4387404431803098| 0.1340696629092989| 0.7072213018683782|[0.79729286220001...|   0.0500246|
+--------------------+-------------------+-------------------+-------------------+-------------------+--------------------+------------+

Trying to do the same thing in a vectorized way, this works:

def batch_predict(x):
    model = get_model_for_eval()
    xp = np.vstack(x)
    t = torch.tensor(xp, dtype=torch.float32)
    prediction = model(t).cpu().detach().numpy().flatten()
    return pd.Series(prediction)

df_pd = df.toPandas()
x = df_pd['features']
print(batch_predict(x))

But running it inside a pandas_udf fails:

batch_udf = pandas_udf(batch_predict, FloatType())
df = df.withColumn('pred_batch', batch_udf(col('features')))
df.show()

with:

20/02/11 10:13:01 ERROR Executor: Exception in task 2.0 in stage 1.0 (TID 3)
java.lang.IllegalArgumentException
    at java.nio.ByteBuffer.allocate(ByteBuffer.java:334)
    at org.apache.arrow.vector.ipc.message.MessageSerializer.readMessage(MessageSerializer.java:543)
    at org.apache.arrow.vector.ipc.message.MessageChannelReader.readNext(MessageChannelReader.java:58)
    at org.apache.arrow.vector.ipc.ArrowStreamReader.readSchema(ArrowStreamReader.java:132)
    at org.apache.arrow.vector.ipc.ArrowReader.initialize(ArrowReader.java:181)
    at org.apache.arrow.vector.ipc.ArrowReader.ensureInitialized(ArrowReader.java:172)
    at org.apache.arrow.vector.ipc.ArrowReader.getVectorSchemaRoot(ArrowReader.java:65)
    at org.apache.spark.sql.execution.python.ArrowPythonRunner$$anon$1.read(ArrowPythonRunner.scala:162)
    at org.apache.spark.sql.execution.python.ArrowPythonRunner$$anon$1.read(ArrowPythonRunner.scala:122)
    at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.hasNext(PythonRunner.scala:410)
    at org.apache.spark.InterruptibleIterator.hasNext(InterruptibleIterator.scala:37)
    at org.apache.spark.sql.execution.python.ArrowEvalPythonExec$$anon$2.<init>(ArrowEvalPythonExec.scala:98)
    at org.apache.spark.sql.execution.python.ArrowEvalPythonExec.evaluate(ArrowEvalPythonExec.scala:96)
    at org.apache.spark.sql.execution.python.EvalPythonExec$$anonfun$doExecute$1.apply(EvalPythonExec.scala:127)
    at org.apache.spark.sql.execution.python.EvalPythonExec$$anonfun$doExecute$1.apply(EvalPythonExec.scala:89)
    at org.apache.spark.rdd.RDD$$anonfun$mapPartitions$1$$anonfun$apply$23.apply(RDD.scala:801)
    at org.apache.spark.rdd.RDD$$anonfun$mapPartitions$1$$anonfun$apply$23.apply(RDD.scala:801)
    at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
    at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
    at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
    at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
    at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
    at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
    at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
    at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
    at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
    at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
    at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
    at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
    at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90)
    at org.apache.spark.scheduler.Task.run(Task.scala:123)
    at org.apache.spark.executor.Executor$TaskRunner$$anonfun$10.apply(Executor.scala:408)
    at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1360)
    at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:414)
    at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
    at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
    at java.lang.Thread.run(Thread.java:748)

Thanks for the help

asked Feb 05 '20 by Hamza


1 Answer

So apparently this issue is due to an incompatibility between Spark 2.4.x and pyarrow >= 0.15: Arrow 0.15.0 changed its binary IPC format, and the older Arrow Java library bundled with Spark 2.4 can't read the new format (a quick version check is sketched after the links). See here:

  • https://issues.apache.org/jira/browse/SPARK-29367
  • https://arrow.apache.org/blog/2019/10/06/0.15.0-release/
  • https://spark.apache.org/docs/3.0.0-preview/sql-pyspark-pandas-with-arrow.html#usage-notes
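
A quick sanity check for whether you're in the affected combination (assuming nothing beyond pyspark and pyarrow being importable):

import pyarrow
import pyspark

# Spark 2.4.x bundles an Arrow Java library older than 0.15, so a Python-side
# pyarrow >= 0.15.0 writes IPC messages the JVM side can't parse.
print('pyspark:', pyspark.__version__)
print('pyarrow:', pyarrow.__version__)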

How I fixed it: run this code before creating the Spark session:

import os

os.environ['ARROW_PRE_0_15_IPC_FORMAT'] = '1'
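
This needs to happen before the Spark session (and its Python workers) start. For cluster mode, here's a minimal sketch of how I'd also propagate it to the executors; the spark.executorEnv setting is my assumption based on the usage notes linked above, which say the variable must be set on the executors as well (e.g. via conf/spark-env.sh). In local[*] mode, setting it in the driver process is enough:

import os

# Ask pyarrow >= 0.15 to write the legacy Arrow IPC format that the
# Arrow Java library bundled with Spark 2.4.x can still read.
os.environ['ARROW_PRE_0_15_IPC_FORMAT'] = '1'

from pyspark.sql import SparkSession

spark = SparkSession.builder.master('local[*]') \
    .appName("model_training") \
    .config("spark.executorEnv.ARROW_PRE_0_15_IPC_FORMAT", "1") \
    .getOrCreate()

Alternatively, pinning the Python side to pyarrow < 0.15 (e.g. pip install 'pyarrow<0.15') avoids the format mismatch entirely, and upgrading to Spark 3.x (which bundles a newer Arrow Java) removes the need for the workaround.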
answered Sep 26 '22 by Hamza