
How to pass a parameter to a user-defined function?

I have a user-defined function:

import math

from pyspark.sql.functions import udf, col
from pyspark.sql.types import FloatType

param1 = "A"

def calculate(type, pos):
    if param1 == "A":
        a, b = 0.05, -0.06
    else:
        a, b = 0.15, -0.16
    return a * math.pow(type, b) * max(pos, 1)

calc = udf(calculate, FloatType())

result = df.withColumn('col1', calc(col('type'), col('pos'))).groupBy('pk').sum('events')

I need to pass a parameter param1 to this udf. How can I do it?

Asked by Dinosaurius, Nov 13 '17



1 Answer

You can pass a constant into your udf by wrapping it with lit (or typedLit for complex types such as arrays and maps), like this:

In Python:

from pyspark.sql.functions import udf, col, lit

mult = udf(lambda value, multiplier: value * multiplier)
df = spark.sparkContext.parallelize([(1,), (2,), (3,)]).toDF()
df.select(mult(col("_1"), lit(3))).show()

In Scala:

import org.apache.spark.sql.functions.{udf, col, lit}

val mult = udf((value: Double, multiplier: Double) => value * multiplier)
val df = sparkContext.parallelize(1 to 10).toDF
df.select(mult(col("value"), lit(3))).show()
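Applied to the question's calculate function, the same idea is to make param1 a regular argument of the udf and bind it with lit at call time, instead of reading a global. A minimal sketch of the logic in plain Python (the Spark wiring, which assumes an existing df with type and pos columns, is shown in comments so the function itself can be checked without a cluster):

import math

# In Spark you would register this with udf(calculate, FloatType())
# and call it as:
#   calc = udf(calculate, FloatType())
#   df.withColumn('col1', calc(col('type'), col('pos'), lit(param1)))
def calculate(type_, pos, param1):
    # coefficients now depend on the passed-in parameter, not a global
    if param1 == "A":
        a, b = 0.05, -0.06
    else:
        a, b = 0.15, -0.16
    return a * math.pow(type_, b) * max(pos, 1)

result_a = calculate(2.0, 1, "A")  # uses a=0.05, b=-0.06
result_b = calculate(2.0, 1, "B")  # uses a=0.15, b=-0.16

Because param1 is now a column expression, each row evaluates the branch against the literal value you passed, so changing the constant requires no redefinition of the udf.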
Answered by Paul V, Sep 23 '22