Assuming I have the following DataFrame:
+---+--------+---+----+----+
|grp|null_col|ord|col1|col2|
+---+--------+---+----+----+
| 1| null| 3|null| 11|
| 2| null| 2| xxx| 22|
| 1| null| 1| yyy|null|
| 2| null| 7|null| 33|
| 1| null| 12|null|null|
| 2| null| 19|null| 77|
| 1| null| 10| s13|null|
| 2| null| 11| a23|null|
+---+--------+---+----+----+
Here is the same sample DF with comments, sorted by grp and ord:
scala> df.orderBy("grp", "ord").show
+---+--------+---+----+----+
|grp|null_col|ord|col1|col2|
+---+--------+---+----+----+
| 1| null| 1| yyy|null|
| 1| null| 3|null| 11| # grp:1 - last value for `col2` (11)
| 1| null| 10| s13|null| # grp:1 - last value for `col1` (s13)
| 1| null| 12|null|null| # grp:1 - last values for `null_col`, `ord`
| 2| null| 2| xxx| 22|
| 2| null| 7|null| 33|
| 2| null| 11| a23|null| # grp:2 - last value for `col1` (a23)
| 2| null| 19|null| 77| # grp:2 - last values for `null_col`, `ord`, `col2`
+---+--------+---+----+----+
I would like to compress it, i.e. to group it by the "grp" column and, for each group, sort the rows by the "ord" column and take the last non-null value in each column (if there is one).
+---+--------+---+----+----+
|grp|null_col|ord|col1|col2|
+---+--------+---+----+----+
| 1| null| 12| s13| 11|
| 2| null| 19| a23| 77|
+---+--------+---+----+----+
I've seen similar questions, but my real DataFrame has over 250 columns, so I need a solution where I don't have to specify all the columns explicitly.
I can't wrap my head around it...
MCVE: the sample DataFrame can be created by parsing "/tmp/data.txt" (containing the table shown above) with readSparkOutput():
val df = readSparkOutput("file:///tmp/data.txt")
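Alternatively, if readSparkOutput() is not at hand, the same sample can be built directly (a minimal sketch; it assumes an active SparkSession with spark.implicits._ in scope and reproduces the table above):
import org.apache.spark.sql.functions.lit
import spark.implicits._

val df = Seq(
  (1, 3, None, Some(11)),
  (2, 2, Some("xxx"), Some(22)),
  (1, 1, Some("yyy"), None),
  (2, 7, None, Some(33)),
  (1, 12, None, None),
  (2, 19, None, Some(77)),
  (1, 10, Some("s13"), None),
  (2, 11, Some("a23"), None)
).toDF("grp", "ord", "col1", "col2").
  withColumn("null_col", lit(null).cast("string")).
  select("grp", "null_col", "ord", "col1", "col2") // match the column order shown above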
UPDATE: I think it should be similar to the following SQL:
SELECT
grp, ord, null_col, col1, col2
FROM (
SELECT
grp,
ord,
FIRST(null_col) OVER (PARTITION BY grp ORDER BY ord DESC) as null_col,
FIRST(col1) OVER (PARTITION BY grp ORDER BY ord DESC) as col1,
FIRST(col2) OVER (PARTITION BY grp ORDER BY ord DESC) as col2,
ROW_NUMBER() OVER (PARTITION BY grp ORDER BY ord DESC) as rn
FROM table_name) as v
WHERE v.rn = 1;
How can we dynamically generate such a Spark query?
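For reference, one way to run such a query without listing columns by hand is to generate the SQL string from df.columns (just a sketch; it assumes an active SparkSession spark, that the DataFrame is registered as a temp view named table_name, and that grp and ord are the only non-aggregated columns):
df.createOrReplaceTempView("table_name")

val aggCols = df.columns.filterNot(c => Seq("grp", "ord").contains(c))
val firstExprs = aggCols.map(c =>
  s"FIRST($c) OVER (PARTITION BY grp ORDER BY ord DESC) AS $c").mkString(",\n      ")

val query = s"""
  SELECT grp, ord, ${aggCols.mkString(", ")}
  FROM (
    SELECT
      grp,
      ord,
      $firstExprs,
      ROW_NUMBER() OVER (PARTITION BY grp ORDER BY ord DESC) AS rn
    FROM table_name) AS v
  WHERE v.rn = 1
"""
spark.sql(query).show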
I tried the following simplified approach:
import org.apache.spark.sql.expressions.Window
val win = Window
.partitionBy("grp")
.orderBy($"ord".desc)
val cols = df.columns.map(c => first(c, ignoreNulls=true).over(win).as(c))
which produces:
scala> cols
res23: Array[org.apache.spark.sql.Column] = Array(first(grp, true) OVER (PARTITION BY grp ORDER BY ord DESC NULLS LAST UnspecifiedFrame) AS `grp`, first(null_col, true) OVER (PARTITION BY grp ORDER BY ord DESC NULLS LAST UnspecifiedFrame) AS `null_col`, first(ord, true) OVER (PARTITION BY grp ORDER BY ord DESC NULLS LAST UnspecifiedFrame) AS `ord`, first(col1, true) OVER (PARTITION BY grp ORDER BY ord DESC NULLS LAST UnspecifiedFrame) AS `col1`, first(col2, true) OVER (PARTITION BY grp ORDER BY ord DESC NULLS LAST UnspecifiedFrame) AS `col2`)
but I couldn't pass it to df.select:
scala> df.select(cols.head, cols.tail: _*).show
<console>:34: error: no `: _*' annotation allowed here
(such annotations are only allowed in arguments to *-parameters)
df.select(cols.head, cols.tail: _*).show
Another attempt:
scala> df.select(cols.map(col): _*).show
<console>:34: error: type mismatch;
found : String => org.apache.spark.sql.Column
required: org.apache.spark.sql.Column => ?
df.select(cols.map(col): _*).show
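Both errors come down to how the varargs expansion works: cols is already an Array[Column], so it can be expanded as the whole argument list directly (a sketch; note this still returns one row per input row, so the result would still have to be reduced to one row per grp):
df.select(cols: _*).show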
Consider the following approach, which applies the Window function last(c, ignoreNulls=true), ordered by "ord" per "grp", to each of the selected columns, followed by a groupBy("grp") to fetch the first agg(colFcnMap) result:
import org.apache.spark.sql.functions._
import org.apache.spark.sql.expressions.Window
import spark.implicits._  // for .toDF on the Seq below
val df0 = Seq(
(1, 3, None, Some(11)),
(2, 2, Some("aaa"), Some(22)),
(1, 1, Some("s12"), None),
(2, 7, None, Some(33)),
(1, 12, None, None),
(2, 19, None, Some(77)),
(1, 10, Some("s13"), None),
(2, 11, Some("a23"), None)
).toDF("grp", "ord", "col1", "col2")
val df = df0.withColumn("null_col", lit(null))
df.orderBy("grp", "ord").show
// +---+---+----+----+--------+
// |grp|ord|col1|col2|null_col|
// +---+---+----+----+--------+
// | 1| 1| s12|null| null|
// | 1| 3|null| 11| null|
// | 1| 10| s13|null| null|
// | 1| 12|null|null| null|
// | 2| 2| aaa| 22| null|
// | 2| 7|null| 33| null|
// | 2| 11| a23|null| null|
// | 2| 19|null| 77| null|
// +---+---+----+----+--------+
val win = Window.partitionBy("grp").orderBy("ord").
rowsBetween(0, Window.unboundedFollowing)
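// The frame [current row, unbounded following] makes last(c, ignoreNulls = true)
// return the last non-null value from the current row through the end of the
// partition (ordered by "ord"), rather than just the current row's own value.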
val nonAggCols = Array("grp")
val cols = df.columns.diff(nonAggCols) // Columns to be aggregated
val colFcnMap = cols.zip(Array.fill(cols.size)("first")).toMap
// colFcnMap: scala.collection.immutable.Map[String,String] =
// Map(ord -> first, col1 -> first, col2 -> first, null_col -> first)
cols.foldLeft(df)((acc, c) =>
acc.withColumn(c, last(c, ignoreNulls=true).over(win))
).
groupBy("grp").agg(colFcnMap).
select(col("grp") :: colFcnMap.toList.map{case (c, f) => col(s"$f($c)").as(c)}: _*).
show
// +---+---+----+----+--------+
// |grp|ord|col1|col2|null_col|
// +---+---+----+----+--------+
// | 1| 12| s13| 11| null|
// | 2| 19| a23| 77| null|
// +---+---+----+----+--------+
Note that the final select is for stripping the function name (in this case first()) from the aggregated column names.
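For illustration, without that renaming select the aggregated DataFrame keeps Spark's generated names (a sketch; the order of the aggregated columns follows the Map and may vary):
val aggregated = cols.foldLeft(df)((acc, c) =>
    acc.withColumn(c, last(c, ignoreNulls=true).over(win))
  ).groupBy("grp").agg(colFcnMap)

aggregated.columns
// e.g. Array(grp, first(ord), first(col1), first(col2), first(null_col))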
I have worked something out; here is the code and output:
import org.apache.spark.sql.functions._
import spark.implicits._
val df0 = Seq(
(1, 3, None, Some(11)),
(2, 2, Some("aaa"), Some(22)),
(1, 1, Some("s12"), None),
(2, 7, None, Some(33)),
(1, 12, None, None),
(2, 19, None, Some(77)),
(1, 10, Some("s13"), None),
(2, 11, Some("a23"), None)
).toDF("grp", "ord", "col1", "col2")
df0.show()
//+---+---+----+----+
//|grp|ord|col1|col2|
//+---+---+----+----+
//| 1| 3|null| 11|
//| 2| 2| aaa| 22|
//| 1| 1| s12|null|
//| 2| 7|null| 33|
//| 1| 12|null|null|
//| 2| 19|null| 77|
//| 1| 10| s13|null|
//| 2| 11| a23|null|
//+---+---+----+----+
Ordering the data on the first two columns:
val df1 = df0.select("grp", "ord", "col1", "col2").orderBy("grp", "ord")
df1.show()
//+---+---+----+----+
//|grp|ord|col1|col2|
//+---+---+----+----+
//| 1| 1| s12|null|
//| 1| 3|null| 11|
//| 1| 10| s13|null|
//| 1| 12|null|null|
//| 2| 2| aaa| 22|
//| 2| 7|null| 33|
//| 2| 11| a23|null|
//| 2| 19|null| 77|
//+---+---+----+----+
Aggregate per grp: take max("ord"), collect the distinct non-null values of col1 and col2 into arrays, and then take the last element of each collected array:
val df2 = df1.groupBy("grp").agg(
  max("ord").alias("ord"),
  collect_set("col1").alias("col1"),
  collect_set("col2").alias("col2"))
val df3 = df2.
  withColumn("new_col1", $"col1".apply(size($"col1").minus(1))).
  withColumn("new_col2", $"col2".apply(size($"col2").minus(1)))
df3.show()
//+---+---+----------+------------+--------+--------+
//|grp|ord| col1| col2|new_col1|new_col2|
//+---+---+----------+------------+--------+--------+
//| 1| 12|[s12, s13]| [11]| s13| 11|
//| 2| 19|[aaa, a23]|[33, 22, 77]| a23| 77|
//+---+---+----------+------------+--------+--------+
You can drop the columns you don't need by using .drop("column_name")
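For example (a sketch continuing from df3 above, also renaming the new columns back to the original names):
val result = df3.
  drop("col1", "col2").
  withColumnRenamed("new_col1", "col1").
  withColumnRenamed("new_col2", "col2")
result.show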