
How to update two columns with different values on the same condition in Pyspark?

Tags: python, pyspark

I have a DataFrame with column a. I would like to create two additional columns (b and c) based on column a. I could solve this by doing the same thing twice:

from pyspark.sql.functions import when

df = df.withColumn('b', when(df.a == 'something', 'x'))\
       .withColumn('c', when(df.a == 'something', 'y'))

I would like to avoid doing the same thing twice, as the condition on which b and c are updated is the same, and there are a lot of cases for column a. Is there a smarter solution to this problem? Could "withColumn" accept multiple columns perhaps?
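For context, a partial simplification is to factor the shared condition into a variable; this is only a sketch, and it still calls withColumn twice:

cond = df.a == 'something'  # name the shared condition once
df = df.withColumn('b', when(cond, 'x'))\
       .withColumn('c', when(cond, 'y'))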

asked Sep 04 '25 by Tamás Godányi

2 Answers

A struct is best suited in such a case: evaluate the condition once, pack both values into a struct, then expand it into separate columns. See the example below.

from pyspark.sql import functions as func

# The condition is evaluated once; both values are packed into a struct,
# and 'b_c_struct.*' expands the struct's fields into top-level columns.
spark.sparkContext.parallelize([('something',), ('foobar',)]).toDF(['a']). \
    withColumn('b_c_struct',
               func.when(func.col('a') == 'something',
                         func.struct(func.lit('x').alias('b'), func.lit('y').alias('c'))
                         )
               ). \
    select('*', 'b_c_struct.*'). \
    show()

# +---------+----------+----+----+
# |        a|b_c_struct|   b|   c|
# +---------+----------+----+----+
# |something|    {x, y}|   x|   y|
# |   foobar|      null|null|null|
# +---------+----------+----+----+

Just add a .drop('b_c_struct') after the select to remove the struct column and keep the individual fields.
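A minimal sketch of that variant, applied to the question's df with its column a:

from pyspark.sql import functions as func

# Build the struct once, expand its fields, then drop the intermediate column.
df = df.withColumn('b_c_struct',
                   func.when(func.col('a') == 'something',
                             func.struct(func.lit('x').alias('b'), func.lit('y').alias('c')))). \
    select('*', 'b_c_struct.*'). \
    drop('b_c_struct')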

answered Sep 07 '25 by samkart


With withColumn, you can only create or modify one column at a time. You can achieve this by mapping over the underlying RDD with a user-defined function, although it's not recommended (a DataFrame-native alternative on newer Spark versions is sketched after the example):

temp = spark.createDataFrame(
    [(1, )],
    schema=['col']
)

temp.show(10, False)
+---+
|col|
+---+
|1  |
+---+



# You can put your own logic in the UDF
def user_defined_function(val, col_name):
    if col_name == 'col2':
        val += 1
    elif col_name == 'col3':
        val += 2
    return val

temp = temp.rdd.map(
    lambda row: (row[0],
                 user_defined_function(row[0], 'col2'),
                 user_defined_function(row[0], 'col3'))
).toDF(['col', 'col2', 'col3'])
temp.show(3, False)
+---+----+----+
|col|col2|col3|
+---+----+----+
|1  |2   |3   |
+---+----+----+
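
As an aside, Spark 3.3+ adds DataFrame.withColumns, which accepts a dict of column expressions, so the shared condition only needs to be written once. A minimal sketch, assuming the question's df and condition:

from pyspark.sql import functions as F

# Spark 3.3+ only: withColumns takes a dict of {name: Column}.
cond = F.col('a') == 'something'  # the shared condition, defined once
df = df.withColumns({
    'b': F.when(cond, F.lit('x')),
    'c': F.when(cond, F.lit('y')),
})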
answered Sep 07 '25 by Jonathan