Pyspark dataframe: Count elements in array or list

Assume a DataFrame df that looks like this:

df.show()

Output:

+------+----------------+
|letter| list_of_numbers|
+------+----------------+
|     A|    [3, 1, 2, 3]|
|     B|    [1, 2, 1, 1]|
+------+----------------+

I want to count the number of occurrences of a specific element (here, 1) in the column list_of_numbers, producing something like this:

+------+----------------+----+
|letter| list_of_numbers|ones|
+------+----------------+----+
|     A|    [3, 1, 2, 3]|   1|
|     B|    [1, 2, 1, 1]|   3|
+------+----------------+----+

So far I have tried creating a UDF, and it works perfectly, but I'm wondering whether I can do this without defining any UDF.

asked Sep 28 '18 07:09 by Ala Tarighati


1 Answer

You can explode the array, keep only the exploded values equal to 1, then groupBy and count:

from pyspark.sql.functions import col, count, explode

df.select("*", explode("list_of_numbers").alias("exploded"))\
    .where(col("exploded") == 1)\
    .groupBy("letter", "list_of_numbers")\
    .agg(count("exploded").alias("ones"))\
    .show()
#+------+---------------+----+
#|letter|list_of_numbers|ones|
#+------+---------------+----+
#|     A|   [3, 1, 2, 3]|   1|
#|     B|   [1, 2, 1, 1]|   3|
#+------+---------------+----+

To keep all rows, even those where the count is 0, you can convert the exploded column into an indicator variable (1 if the element equals 1, otherwise 0), then groupBy and sum:

from pyspark.sql.functions import col, explode, sum as sum_

df.select("*", explode("list_of_numbers").alias("exploded"))\
    .withColumn("exploded", (col("exploded") == 1).cast("int"))\
    .groupBy("letter", "list_of_numbers")\
    .agg(sum_("exploded").alias("ones"))\
    .show()

Note that I imported pyspark.sql.functions.sum as sum_ so as not to shadow Python's built-in sum function.

answered Oct 09 '22 17:10 by pault