I want to perform some conditional branching to avoid calculating unnecessary nodes, but I am noticing that if the source column in the condition is a UDF, then the otherwise branch is evaluated regardless:
@pandas_udf("double", PandasUDFType.SCALAR)
def udf_that_throws_exception(*cols):
raise Exception('Error')
@pandas_udf("int", PandasUDFType.SCALAR)
def simple_mul_udf(*cols):
result = cols[0]
for c in cols[1:]:
result *= c
return result
df = spark.range(0,5)
df = df.withColumn('A', lit(1))
df = df.withColumn('B', lit(2))
df = df.withColumn('udf', simple_mul('A','B'))
df = df.withColumn('sql', expr('A*B'))
df = df.withColumn('res', when(df.sql < 100, lit(1)).otherwise(udf_that_throws(lit(0))))
The above code works as expected: the condition in this case is always true, so my UDF that throws an exception is never called.
However, if I change the condition to use df.udf instead, then all of a sudden the otherwise UDF is called and I get the exception, even though the condition result has not changed.
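For clarity, the variant I am describing only changes the condition column; everything else stays the same:

# Same setup as above, but the condition now reads the UDF-derived column.
# This is the variant that raises the exception, even though the condition
# is still always true (udf = A*B = 2 for every row).
df = df.withColumn('res', when(df.udf < 100, lit(1)).otherwise(udf_that_throws_exception(lit(0))))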
I thought I might be able to hide this from the optimizer by removing the UDF from the condition, but the same result occurs regardless:
df = df.withColumn('cond', when(df.udf < 100, lit(1)).otherwise(lit(0)))
df = df.withColumn('res', when(df.cond == lit(1), lit(1)).otherwise(udf_that_throws_exception(lit(0))))
I imagine this has something to do with the way Spark optimizes, which is fine, but I am looking for any way to do this without incurring the cost. Any ideas?
Edit: per the request for more information. We are writing a processing engine that can accept an arbitrary model, and the code generates the graph. Along the way there are junctures where we make decisions based on the state of values at runtime. We make heavy use of pandas UDFs. So imagine a situation where we have multiple paths in the graph and, depending on some condition at runtime, we wish to follow one of those paths, leaving all others untouched.
I would like to encode this logic into the graph so there is no point where I have to collect and branch in the code.
The sample code I have provided is for demonstration purposes only. The issue I am facing is that if the column used in the IF statement is a UDF or, it seems, if it is derived from a UDF, then the OTHERWISE branch is always executed even if it is never actually used. If the IF/ELSE branches are cheap operations such as literals I wouldn't mind, but what if the column UDF (perhaps on both sides) results in a large aggregation or some other lengthy process whose result is simply thrown away?
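To make that concrete, the pattern looks roughly like the sketch below, where expensive_path_a and expensive_path_b are hypothetical pandas UDFs standing in for the costly branches:

# Hypothetical sketch only: expensive_path_a / expensive_path_b stand in for
# costly pandas UDFs. Only one branch should matter for a given row, yet both
# get computed because the condition column itself comes from a UDF.
df = df.withColumn(
    'next_node',
    when(df.udf < 100, expensive_path_a('A', 'B'))
    .otherwise(expensive_path_b('A', 'B'))
)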
In PySpark the UDFs are computed beforehand, and therefore you are getting this sub-optimal behaviour. You can also see it from the query plan:
== Physical Plan ==
*(2) Project [id#753L, 1 AS A#755, 2 AS B#758, pythonUDF1#776 AS udf#763, CASE WHEN (pythonUDF1#776 < 100) THEN 1.0 ELSE pythonUDF2#777 END AS res#769]
+- ArrowEvalPython [simple_mul_udf(1, 2), simple_mul_udf(1, 2), udf_that_throws_exception(0)], [id#753L, pythonUDF0#775, pythonUDF1#776, pythonUDF2#777]
+- *(1) Range (0, 5, step=1, splits=8)
The ArrowEvalPython operator is responsible for computing the UDFs, and after that the condition is evaluated in the Project operator.
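The plan shown above can be printed with explain(), for example:

# print the physical plan to see where the pandas UDFs are evaluated
df.explain()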
The reason why you get different behaviour when you use df.sql in your condition (the optimal behaviour) is that this is a special case in which the value in this column is constant (both columns A and B are constant), so the Spark optimizer can evaluate it beforehand (in the driver during query-plan processing, before the actual job runs on the cluster) and thus knows that the otherwise branch of the condition will never have to be evaluated. If the value in the sql column is dynamic (for example, like the id column), the behaviour is suboptimal again, because Spark does not know in advance that the otherwise part should never take place.
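For illustration, a condition built on a dynamic column such as id would look something like this, and the otherwise UDF would still be computed:

# id is not constant, so Spark cannot fold the branch away at planning time
# and udf_that_throws_exception is still evaluated for every batch.
df = df.withColumn('res', when(df.id < 100, lit(1)).otherwise(udf_that_throws_exception(lit(0))))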
If you want to avoid this suboptimal behaviour (calling the UDF in otherwise even though it is not needed), one possible solution is to evaluate the condition inside your UDF, for example as follows:
@pandas_udf("int", PandasUDFType.SCALAR)
def udf_with_cond(*cols):
result = cols[0]
for c in cols[1:]:
result *= c
if((result < 100).any()):
return result
else:
raise Exception('Error')
df = df.withColumn('res', udf_with_cond('A', 'B'))
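With the sample data above (A=1, B=2, so every product is 2), the condition holds for the whole batch and no exception is raised; a quick check:

# every product is 2 < 100, so the UDF simply returns the products
df.select('res').show()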