Convert string list to binary list in pyspark

I have a DataFrame like this:

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

data = [("ID1", ['October', 'September', 'August']),
        ("ID2", ['August', 'June', 'May']),
        ("ID3", ['October', 'June'])]
df = spark.createDataFrame(data, ["ID", "MonthList"])
df.show(truncate=False)

+---+----------------------------+
|ID |MonthList                   |
+---+----------------------------+
|ID1|[October, September, August]|
|ID2|[August, June, May]         |
|ID3|[October, June]             |
+---+----------------------------+

I want to compare every row with a default list: for each month in the default list, assign 1 if it is present in the row's MonthList, else 0.

default_month_list = ['October', 'September', 'August', 'July', 'June', 'May']

So my expected output is:

+---+----------------------------+------------------+
|ID |MonthList                   |Binary_MonthList  |
+---+----------------------------+------------------+
|ID1|[October, September, August]|[1, 1, 1, 0, 0, 0]|
|ID2|[August, June, May]         |[0, 0, 1, 0, 1, 1]|
|ID3|[October, June]             |[1, 0, 0, 0, 1, 0]|
+---+----------------------------+------------------+

I can do this in plain Python, but I don't know how to do it in PySpark.
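For reference, the plain-Python version of this encoding might look like the sketch below. It is not the asker's code; the function name encode_months is hypothetical. It simply walks default_month_list in order and emits 1 or 0 per month, which is the per-row logic both PySpark answers reproduce.

```python
default_month_list = ['October', 'September', 'August', 'July', 'June', 'May']

data = [
    ("ID1", ['October', 'September', 'August']),
    ("ID2", ['August', 'June', 'May']),
    ("ID3", ['October', 'June']),
]

def encode_months(months, reference=default_month_list):
    """Return a 0/1 list in `reference` order: 1 if that month is in `months`."""
    present = set(months)  # set gives O(1) membership checks
    return [1 if m in present else 0 for m in reference]

# Encode every row of the sample data, keyed by ID
encoded = {row_id: encode_months(months) for row_id, months in data}

for row_id, bits in encoded.items():
    print(row_id, bits)
```

The output order always follows default_month_list, not the order of the months in each row, which is why ID2 ['August', 'June', 'May'] becomes [0, 0, 1, 0, 1, 1].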

asked Dec 31 '22 at 14:12 by Hardik Gupta



2 Answers

You can use a UDF like this:

from pyspark.sql.functions import udf, col
from pyspark.sql.types import ArrayType, IntegerType

default_month_list = ['October', 'September', 'August', 'July', 'June', 'May']

# For each month in default_month_list (in order), emit 1 if it is in the row's MonthList, else 0
def_month_list_func = udf(lambda x: [1 if i in x else 0 for i in default_month_list],
                          ArrayType(IntegerType()))

df = df.withColumn("Binary_MonthList", def_month_list_func(col("MonthList")))

df.show()
# output
+---+--------------------+------------------+
| ID|           MonthList|  Binary_MonthList|
+---+--------------------+------------------+
|ID1|[October, Septemb...|[1, 1, 1, 0, 0, 0]|
|ID2| [August, June, May]|[0, 0, 1, 0, 1, 1]|
|ID3|     [October, June]|[1, 0, 0, 0, 1, 0]|
+---+--------------------+------------------+
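One caveat with the lambda above: Spark passes Python None to a UDF when MonthList is null, and `i in None` raises TypeError, failing the job. A minimal null-safe sketch, assuming you want a null list to encode as all zeros (returning None instead is equally reasonable), pulls the logic into a named function (months_to_binary is a hypothetical name) that can be unit-tested without a Spark session:

```python
default_month_list = ['October', 'September', 'August', 'July', 'June', 'May']

def months_to_binary(months):
    """Null-safe body for the UDF: a null MonthList is treated as an empty list."""
    present = set(months or [])  # None -> empty set instead of TypeError
    return [1 if m in present else 0 for m in default_month_list]

# Registering it as a UDF works the same way as the lambda (needs PySpark):
# from pyspark.sql.functions import udf
# from pyspark.sql.types import ArrayType, IntegerType
# def_month_list_func = udf(months_to_binary, ArrayType(IntegerType()))

print(months_to_binary(['August', 'June', 'May']))
print(months_to_binary(None))
```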
answered Jan 03 '23 at 05:01 by pissall



How about using the built-in array_contains()? It needs no Python UDF:

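from pyspark.sql.functions import array, array_contains

# One array_contains() check per default month, cast from boolean to 0/1,
# then collected into a single array column in default_month_list order
df.withColumn('Binary_MonthList',
              array([array_contains('MonthList', c).astype('int') for c in default_month_list])).show()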
+---+--------------------+------------------+
| ID|           MonthList|  Binary_MonthList|
+---+--------------------+------------------+
|ID1|[October, Septemb...|[1, 1, 1, 0, 0, 0]|
|ID2| [August, June, May]|[0, 0, 1, 0, 1, 1]|
|ID3|     [October, June]|[1, 0, 0, 0, 1, 0]|
+---+--------------------+------------------+
answered Jan 03 '23 at 04:01 by jxc
