The following Spark code correctly demonstrates what I want to do and generates the correct output with a tiny demo data set.
When I run this same general type of code on a large volume of production data, I am having runtime problems. The Spark job runs on my cluster for ~12 hours and fails out.
Just glancing at the code below, it seems inefficient to explode every row just to merge it back down. In the given test data set, the fourth row, with three values in array_value_1 and three values in array_value_2, will explode to 3*3, or nine, rows.
So, in a larger data set, a row with five such array columns, and ten values in each column, would explode out to 10^5 exploded rows?
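As a sanity check on that arithmetic, here is a minimal plain-Scala sketch (no Spark involved). It assumes each row can be summarized by the lengths of its array columns; `ExplodeCount` and `explodedRows` are hypothetical names used only for illustration.

```scala
// Plain-Scala sketch; ExplodeCount and explodedRows are hypothetical names.
object ExplodeCount {
  // Each inner Seq holds the lengths of one row's array columns.
  // Chained explodes emit one output row per combination of elements,
  // so a row contributes the product of its array lengths.
  def explodedRows(arrayLengthsPerRow: Seq[Seq[Int]]): Long =
    arrayLengthsPerRow.map(_.map(_.toLong).product).sum

  def main(args: Array[String]): Unit = {
    // The four demo rows: (2, 2), (2, 2), (2, 1), (3, 3)
    val demo = Seq(Seq(2, 2), Seq(2, 2), Seq(2, 1), Seq(3, 3))
    println(explodedRows(demo)) // 4 + 4 + 2 + 9 = 19

    // One row with five array columns of ten values each
    println(explodedRows(Seq(Seq.fill(5)(10)))) // 10^5 = 100000
  }
}
```

So the row count grows multiplicatively with the number of exploded array columns, not additively.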
Looking at the provided Spark functions, I see no out-of-the-box function that does what I want. I could supply a user-defined function. Are there any speed drawbacks to that?
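import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.functions.{col, collect_set, explode}
import org.apache.spark.sql.types._
import scala.collection.JavaConverters._

val sparkSession = SparkSession.builder
  .master("local")
  .appName("merge list test")
  .getOrCreate()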

val schema = StructType(
  StructField("category", IntegerType) ::
  StructField("array_value_1", ArrayType(StringType)) ::
  StructField("array_value_2", ArrayType(StringType)) ::
  Nil)

val rows = List(
  Row(1, List("a", "b"), List("u", "v")),
  Row(1, List("b", "c"), List("v", "w")),
  Row(2, List("c", "d"), List("w")),
  Row(2, List("c", "d", "e"), List("x", "y", "z"))
)
val df = sparkSession.createDataFrame(rows.asJava, schema)

val dfExploded = df
  .withColumn("scalar_1", explode(col("array_value_1")))
  .withColumn("scalar_2", explode(col("array_value_2")))
// This will output 19: 2*2 + 2*2 + 2*1 + 3*3 = 19
logger.info(s"dfExploded.count()=${dfExploded.count()}")

val dfOutput = dfExploded.groupBy("category").agg(
  collect_set("scalar_1").alias("combined_values_1"),
  collect_set("scalar_2").alias("combined_values_2"))
dfOutput.show()
Exploding could be inefficient, but fundamentally the operation you are trying to implement is simply expensive. Effectively it is just another groupByKey, and there is not much you can do here to make it better. Since you use Spark > 2.0, you could apply collect_list directly and flatten the result:
import org.apache.spark.sql.functions.{collect_list, udf}

val flatten_distinct = udf(
  (xs: Seq[Seq[String]]) => xs.flatten.distinct)

df
  .groupBy("category")
  .agg(
    flatten_distinct(collect_list("array_value_1")),
    flatten_distinct(collect_list("array_value_2"))
  )
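
To make it concrete what the UDF body does to one group, here is a small plain-Scala sketch that runs outside Spark. The input mimics what collect_list("array_value_1") would gather for category 1 in the demo data; `FlattenDistinctDemo` and `flattenDistinct` are hypothetical names for illustration only.

```scala
// Plain-Scala sketch; FlattenDistinctDemo and flattenDistinct are hypothetical names.
object FlattenDistinctDemo {
  // Same body as the UDF above, written as an ordinary function.
  def flattenDistinct(xs: Seq[Seq[String]]): Seq[String] = xs.flatten.distinct

  def main(args: Array[String]): Unit = {
    // The two array_value_1 arrays of category 1 in the demo data.
    // In Spark, the order in which collect_list gathers them is not guaranteed.
    val category1 = Seq(Seq("a", "b"), Seq("b", "c"))
    println(flattenDistinct(category1).mkString(", ")) // a, b, c
  }
}
```

Each input array is read once per group here, instead of once per combination of elements as with the chained explodes.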
In Spark >= 2.4 you can replace the udf with a composition of built-in functions:
import org.apache.spark.sql.functions.{array_distinct, flatten}
val flatten_distinct = (array_distinct _) compose (flatten _)
It is also possible to use a custom Aggregator, but I doubt any of these will make a huge difference.
If the sets are relatively large and you expect a significant number of duplicates, you could try aggregateByKey with mutable sets:
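import org.apache.spark.sql.functions.struct
import scala.collection.mutable.{Set => MSet}
// Needed for the $"..." syntax, .as[...] and .toDF
import sparkSession.implicits._

val rdd = df
  .select($"category", struct($"array_value_1", $"array_value_2"))
  .as[(Int, (Seq[String], Seq[String]))]
  .rdd

val agg = rdd
  .aggregateByKey((MSet[String](), MSet[String]()))(
    // Add one row's values to the per-partition sets in place
    {case ((accX, accY), (xs, ys)) => (accX ++= xs, accY ++= ys)},
    // Merge the sets built on different partitions
    {case ((accX1, accY1), (accX2, accY2)) => (accX1 ++= accX2, accY1 ++= accY2)}
  )
  .mapValues { case (xs, ys) => (xs.toArray, ys.toArray) }
  .toDF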
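
The two functions passed to aggregateByKey can be tried out without a cluster. The sketch below is plain Scala and only imitates the idea: it folds the rows of two pretend partitions into mutable sets and then merges the partial results. `MutableSetAggDemo`, `seqOp`, `combOp` and `zero` are hypothetical names, and the split into partitions is an assumption made for illustration.

```scala
import scala.collection.mutable.{Set => MSet}

// Plain-Scala sketch; all names here are hypothetical.
object MutableSetAggDemo {
  type Acc = (MSet[String], MSet[String])

  // A fresh pair of empty sets, playing the role of the zero value.
  def zero: Acc = (MSet[String](), MSet[String]())

  // Fold one row's arrays into the partition-local accumulator, in place.
  def seqOp(acc: Acc, row: (Seq[String], Seq[String])): Acc = {
    acc._1 ++= row._1
    acc._2 ++= row._2
    acc
  }

  // Merge the accumulators built on two partitions, reusing the first one.
  def combOp(a: Acc, b: Acc): Acc = {
    a._1 ++= b._1
    a._2 ++= b._2
    a
  }

  def main(args: Array[String]): Unit = {
    // The category 1 rows of the demo data, split across two pretend partitions.
    val partition1 = Seq((Seq("a", "b"), Seq("u", "v")))
    val partition2 = Seq((Seq("b", "c"), Seq("v", "w")))

    val merged = combOp(partition1.foldLeft(zero)(seqOp), partition2.foldLeft(zero)(seqOp))
    println(merged._1.toSeq.sorted.mkString(", ")) // a, b, c
    println(merged._2.toSeq.sorted.mkString(", ")) // u, v, w
  }
}
```

Duplicates are dropped as each row is added, so the accumulator never grows beyond the number of distinct values per key.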