 

Select the last element of an Array in a DataFrame

I'm working on a project dealing with some nested JSON data that has a complicated schema. Basically, what I want to do is transform one of the columns in a DataFrame so that I select the last element of the array it contains. I'm totally stuck on how to do this. I hope this makes sense.

Below is an example of what I'm trying to accomplish:

import org.apache.spark.sql.functions.{col, split}
import spark.implicits._

val singersDF = Seq(
  ("beatles", "help,hey,jude"),
  ("romeo", "eres,mia"),
  ("elvis", "this,is,an,example")
).toDF("name", "hit_songs")

// split the comma-separated string into an array column
val actualDF = singersDF.withColumn(
  "hit_songs",
  split(col("hit_songs"), ",")
)

actualDF.show(false)
actualDF.printSchema() 

+-------+-----------------------+
|name   |hit_songs              |
+-------+-----------------------+
|beatles|[help, hey, jude]      |
|romeo  |[eres, mia]            |
|elvis  |[this, is, an, example]|
+-------+-----------------------+
root
 |-- name: string (nullable = true)
 |-- hit_songs: array (nullable = true)
 |    |-- element: string (containsNull = true)

The end goal is the following output: the last string in each hit_songs array.

I'm not worried about what the schema will look like afterwards.

+-------+---------+
|name   |hit_songs|
+-------+---------+
|beatles|jude     |
|romeo  |mia      |
|elvis  |example  |
+-------+---------+
asked Nov 30 '22 by fletchr

2 Answers

Since Spark 2.4, you can use element_at, which supports negative indexing, as this documentation quote shows:

element_at(array, index) - Returns element of array at given (1-based) index. If index < 0, accesses elements from the last to the first. Returns NULL if the index exceeds the length of the array.

With that, here's how to get the last element:

import org.apache.spark.sql.functions.element_at
import spark.implicits._ // for the $"..." column syntax

actualDF.withColumn("hit_songs", element_at($"hit_songs", -1))
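
Since element_at is also a Spark SQL built-in, the same lookup can be written as a SQL expression string; a minimal sketch against the question's actualDF:

// equivalent SQL-expression form of the same element_at call
actualDF.selectExpr("name", "element_at(hit_songs, -1) AS hit_songs")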

Reproducible example:

First, let's prepare a sample DataFrame with an array column:

import spark.implicits._ // for rdd.toDF

val columns = Seq("col1")
val data = Seq(Array(1, 2, 3))
val rdd = spark.sparkContext.parallelize(data)
val df = rdd.toDF(columns: _*)

which looks like this:

scala> df.show()
+---------+
|     col1|
+---------+
|[1, 2, 3]|
+---------+

Then, apply element_at to get the last element as follows:

scala> df.withColumn("last_value", element_at($"col1", -1)).show()
+---------+----------+
|     col1|last_value|
+---------+----------+
|[1, 2, 3]|         3|
+---------+----------+
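
One caveat, straight from the quoted documentation: an index past the end of the array yields NULL rather than an error, so out-of-range lookups fail silently. A quick sketch against the same df (the column name missing is just an illustration):

// per the documentation above, an index beyond the array's length
// produces null instead of throwing, so this adds a column of nulls
df.withColumn("missing", element_at($"col1", 5))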
answered Feb 04 '23 by Mohamed Ali JAMAOUI


You can use the size function to compute the index of the last item in the array (array indexing is 0-based, hence the minus(1)), and then pass it as the argument of Column.apply (explicitly or implicitly):

import org.apache.spark.sql.functions._
import spark.implicits._

actualDF.withColumn("hit_songs", $"hit_songs".apply(size($"hit_songs").minus(1)))

Or:

actualDF.withColumn("hit_songs", $"hit_songs"(size($"hit_songs").minus(1)))
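
If you prefer SQL expression strings, the same indexing can be written with expr; a minimal sketch, assuming the actualDF from the question (bracket indexing in Spark SQL is also 0-based, matching the minus(1) above):

import org.apache.spark.sql.functions.expr

// 0-based bracket indexing inside a SQL expression string
actualDF.withColumn("hit_songs", expr("hit_songs[size(hit_songs) - 1]"))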
answered Feb 04 '23 by Tzach Zohar