Spark 2.4 introduced the new SQL function slice, which can be used to extract a certain range of elements from an array column.
I want to define that range dynamically per row, based on an integer column that holds the number of elements to pick from the array.
However, simply passing that column to slice fails: the function appears to expect plain integers for the start and length arguments. Is there a way to do this without writing a UDF?
To visualize the problem with an example:
I have a dataframe with an array column arr that contains ['a', 'b', 'c'] in every row, and an end_idx column with the values 3, 1 and 2:
+---------+-------+
|arr |end_idx|
+---------+-------+
|[a, b, c]|3 |
|[a, b, c]|1 |
|[a, b, c]|2 |
+---------+-------+
I try to create a new column arr_trimmed like this:
import pyspark.sql.functions as F
l = [(['a', 'b', 'c'], 3), (['a', 'b', 'c'], 1), (['a', 'b', 'c'], 2)]
df = spark.createDataFrame(l, ["arr", "end_idx"])
df = df.withColumn("arr_trimmed", F.slice(F.col("arr"), 1, F.col("end_idx")))
I expect this code to create a new column with the values ['a', 'b', 'c'], ['a'] and ['a', 'b'].
Instead, it fails with TypeError: Column is not iterable.
You can do it by passing a SQL expression as follows:
df.withColumn("arr_trimmed", F.expr("slice(arr, 1, end_idx)"))
Here is the whole working example:
import pyspark.sql.functions as F
l = [(['a', 'b', 'c'], 3), (['a', 'b', 'c'], 1), (['a', 'b', 'c'], 2)]
df = spark.createDataFrame(l, ["arr", "end_idx"])
df.withColumn("arr_trimmed", F.expr("slice(arr, 1, end_idx)")).show(truncate=False)
+---------+-------+-----------+
|arr |end_idx|arr_trimmed|
+---------+-------+-----------+
|[a, b, c]|3 |[a, b, c] |
|[a, b, c]|1 |[a] |
|[a, b, c]|2 |[a, b] |
+---------+-------+-----------+
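If you prefer to stay in the expression-string style, an equivalent form (a sketch of the same idea, not tested against your data) is selectExpr, which parses the same SQL fragment:
df.selectExpr("arr", "end_idx", "slice(arr, 1, end_idx) AS arr_trimmed").show(truncate=False)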
In later Spark versions, slice also accepts Column arguments for start and length, so it can be called directly:
df.withColumn("arr_trimmed", F.slice("arr", F.lit(1), F.col("end_idx")))
David Vrba's example can be rewritten this way:
import pyspark.sql.functions as F
l = [(['a', 'b', 'c'], 3), (['a', 'b', 'c'], 1), (['a', 'b', 'c'], 2)]
df = spark.createDataFrame(l, ["arr", "end_idx"])
df.withColumn("arr_trimmed", F.slice("arr", F.lit(1), F.col("end_idx"))).show(truncate=False)
+---------+-------+-----------+
|arr |end_idx|arr_trimmed|
+---------+-------+-----------+
|[a, b, c]|3 |[a, b, c] |
|[a, b, c]|1 |[a] |
|[a, b, c]|2 |[a, b] |
+---------+-------+-----------+
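For completeness, on a version where neither approach is available, the UDF the question hoped to avoid would look roughly like this (a sketch, assuming the array holds strings):
from pyspark.sql.types import ArrayType, StringType

# Plain Python list slicing inside a UDF; slower than the built-in slice, but version-agnostic.
trim_udf = F.udf(lambda xs, n: xs[:n], ArrayType(StringType()))
df.withColumn("arr_trimmed", trim_udf("arr", "end_idx")).show(truncate=False)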