
How to count the trailing zeroes in an array column in a PySpark dataframe without a UDF

I have a DataFrame with a column containing an array of a fixed number of integers. How can I add a column to the df that holds the number of trailing zeroes in each array? I would like to avoid using a UDF for better performance.

For example, an input df:

>>> df.show()
+------------+
|           A|
+------------+
| [1,0,1,0,0]|
| [2,3,4,5,6]|
| [0,0,0,0,0]|
| [1,2,3,4,0]|
+------------+

And the desired output:

>>> trailing_zeroes(df).show()
+------------+-----------------+
|           A|   trailingZeroes|
+------------+-----------------+
| [1,0,1,0,0]|                2|
| [2,3,4,5,6]|                0|
| [0,0,0,0,0]|                5|
| [1,2,3,4,0]|                1|
+------------+-----------------+
asked Dec 04 '19 by David Taub


2 Answers

When you convert the array to a string, there are several ways to get to the result. Reversing the array first and matching only leading zeroes keeps multi-digit numbers such as 10 from being miscounted: no integer's string form starts with a 0 except 0 itself, so each matched 0 corresponds to exactly one zero element.

>>> from pyspark.sql.functions import length, regexp_extract, array_join, reverse
>>> 
>>> df = spark.createDataFrame([(1, [1, 2, 3]),
...                             (2, [2, 0]),
...                             (3, [0, 2, 3, 10]),
...                             (4, [0, 2, 3, 10, 0]),
...                             (5, [0, 1, 0, 0, 0]),
...                             (6, [0, 0, 0]),
...                             (7, [0, ]),
...                             (8, [10, ]),
...                             (9, [100, ]),
...                             (10, [0, 100, ]),
...                             (11, [])],
...                            schema=("id", "arr"))
>>> 
>>> 
>>> # reverse the array, join it into one string, then measure the run of leading zeroes
>>> df.withColumn("trailing_zero_count",
...               length(regexp_extract(array_join(reverse(df.arr), ""), "^(0+)", 0))
...               ).show()
+---+----------------+-------------------+
| id|             arr|trailing_zero_count|
+---+----------------+-------------------+
|  1|       [1, 2, 3]|                  0|
|  2|          [2, 0]|                  1|
|  3|   [0, 2, 3, 10]|                  0|
|  4|[0, 2, 3, 10, 0]|                  1|
|  5| [0, 1, 0, 0, 0]|                  3|
|  6|       [0, 0, 0]|                  3|
|  7|             [0]|                  1|
|  8|            [10]|                  0|
|  9|           [100]|                  0|
| 10|        [0, 100]|                  0|
| 11|              []|                  0|
+---+----------------+-------------------+
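
For completeness, the expression can be wrapped as the trailing_zeroes helper the question asks for. This is a minimal sketch; the column names A and trailingZeroes are assumptions taken from the question's example:

from pyspark.sql.functions import array_join, length, regexp_extract, reverse

def trailing_zeroes(df, arr_col="A", out_col="trailingZeroes"):
    # Column names are assumptions from the question's example.
    # Reverse the array, render it as one string, and count the leading zeroes.
    reversed_str = array_join(reverse(df[arr_col]), "")
    return df.withColumn(out_col, length(regexp_extract(reversed_str, "^(0+)", 0)))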
answered Sep 18 '22 by Oliver W.


Since Spark 2.4 you can use the higher-order function AGGREGATE to do this:

from pyspark.sql.functions import reverse

(
  df.withColumn("arr_rev", reverse("A"))  # process the array from the end
  .selectExpr(
    "arr_rev",
    # p is a flag that stays 1 until the first non-zero element is seen;
    # sum is incremented only while p is still 1, i.e. for the trailing run of zeroes
    "AGGREGATE(arr_rev, (1 AS p, CAST(0 AS LONG) AS sum), (buffer, value) -> (if(value != 0, 0, buffer.p), if(value = 0, buffer.sum + buffer.p, buffer.sum)), buffer -> buffer.sum) AS result"
  )
  .show()
)

Here A is your array of numbers. Just be careful with the data types: I am casting the initial value to LONG on the assumption that the numbers inside the array are also longs.
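
If you are on Spark 3.1 or later (an assumption; the answer above only requires 2.4), the same idea can be written with pyspark.sql.functions.aggregate directly and without the reverse, by resetting the counter whenever a non-zero element appears. A minimal sketch:

from pyspark.sql import functions as F

# Assumes Spark >= 3.1 and an array column named A, as in the question.
# Walking left to right, any non-zero element resets the counter, so only
# the zeroes at the very end of the array survive in the accumulator.
df.withColumn(
    "trailingZeroes",
    F.aggregate("A", F.lit(0), lambda acc, x: F.when(x == 0, acc + 1).otherwise(F.lit(0))),
).show()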

answered Sep 20 '22 by David Vrba