
pyspark grouped map IllegalArgumentException error

Tags:

python

pyspark

I'm having trouble getting GROUPED_MAP to work in pyspark. I've tried using sample code, including some from the spark git repo, without success. Any advice on what I need to change is appreciated.

For example:

from pyspark.sql import SparkSession
from pyspark.sql.utils import require_minimum_pandas_version, require_minimum_pyarrow_version

require_minimum_pandas_version()
require_minimum_pyarrow_version()


from pyspark.sql.functions import pandas_udf, PandasUDFType
spark = SparkSession.builder.master("local[*]").getOrCreate()
df = spark.createDataFrame(
    [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
    ("id", "v"))

@pandas_udf("id long, v double", PandasUDFType.GROUPED_MAP)
def subtract_mean(pdf):
    # pdf is a pandas.DataFrame
    v = pdf.v
    return pdf.assign(v=v - v.mean())

df.groupby("id").apply(subtract_mean).show()

Gives me the error:


py4j.protocol.Py4JJavaError: An error occurred while calling o61.showString.
: org.apache.spark.SparkException: Job aborted due to stage failure: Task 44 in stage 7.0 failed 1 times, most recent failure: Lost task 44.0 in stage 7.0 (TID 128, localhost, executor driver): java.lang.IllegalArgumentException

I believe pyspark is set up correctly, as this runs successfully for me:

from pyspark.sql.functions import udf, struct, col
from pyspark.sql.types import * 
from pyspark.sql import SparkSession
import pyspark.sql.functions as func
import pandas as pd


spark = SparkSession.builder.master("local[*]").getOrCreate()

def sum_diff(f1, f2):
    return [f1 + f2, f1-f2]

schema = StructType([
    StructField("sum", FloatType(), False),
    StructField("diff", FloatType(), False)
])

sum_diff_udf = udf(lambda row: sum_diff(row[0], row[1]), schema)

df = spark.createDataFrame(pd.DataFrame([[1., 2.], [2., 4.]], columns=['f1', 'f2']))

df_new = df.withColumn("sum_diff", sum_diff_udf(struct([col('f1'), col('f2')])))\
    .select('*', 'sum_diff.*')
df_new.show()
asked Feb 25 '26 00:02 by Peter

1 Answer

I had the same issue. For me it was solved by using the recommended version of PyArrow (0.15.1) and, since I was running Spark 2.4.x, setting an environment variable in conf/spark-env.sh for backwards compatibility:

ARROW_PRE_0_15_IPC_FORMAT=1

See full description here. Note that for Windows you'll need to rename conf/spark-env.sh to conf/spark-env.cmd as it won't pick up bash scripts. In that case the environment variable is:

set ARROW_PRE_0_15_IPC_FORMAT=1
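If editing conf/spark-env.sh isn't convenient, a sketch of the same fix from Python (assuming local mode, where the worker processes inherit the driver's environment) is to set the variable before the SparkSession is created:

```python
import os

# PyArrow 0.15.0 changed Arrow's binary IPC format; Spark 2.4.x expects the
# older format, and this flag tells PyArrow >= 0.15 to keep emitting it.
# It must be set before the SparkSession starts so workers pick it up.
os.environ["ARROW_PRE_0_15_IPC_FORMAT"] = "1"

# from pyspark.sql import SparkSession
# spark = SparkSession.builder.master("local[*]").getOrCreate()
```

On a real cluster this won't reach the executors from the driver script, which is why the conf/spark-env.sh approach above is the documented route.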
answered Feb 26 '26 13:02 by Poc275