How can I use the Spark DataFrame API to group by id, compute all value combinations within a group, and produce a single output DataFrame?
Example:
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

val testSchema = StructType(Array(
  StructField("id", IntegerType),
  StructField("value", StringType)))

val test_rows = Seq(
  Row(1, "a"),
  Row(1, "b"),
  Row(1, "c"),
  Row(2, "a"),
  Row(2, "d"),
  Row(2, "e")
)

val test_rdd = sc.parallelize(test_rows)
val test_df = sqlContext.createDataFrame(test_rdd, testSchema)
Expected output:
1 a b
1 a c
1 b c
2 a d
2 a e
2 d e
Best solution so far:
Perform a self-join, filter on id equality, and eliminate rows where the two values are equal:
val result = test_df.join(
  test_df.select(test_df.col("id").as("r_id"), test_df.col("value").as("r_value")),
  ($"id" === $"r_id") and ($"value" !== $"r_value")
).select("id", "value", "r_value")
+---+-----+-------+
| id|value|r_value|
+---+-----+-------+
| 1| a| b|
| 1| a| c|
| 1| b| a|
| 1| b| c|
| 1| c| a|
| 1| c| b|
| 2| a| d|
| 2| a| e|
| 2| d| a|
| 2| d| e|
| 2| e| a|
| 2| e| d|
+---+-----+-------+
Remaining problem: how do I eliminate duplicate pairs, e.g., (a,b) and (b,a), while performing the join?
Do you have an ordering on the objects in the value field? If so, it seems like you could just join the DataFrame with itself, requiring that the ids be identical and that the value from the left table be less than the value from the right table.
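For reference, a minimal sketch of that ordered self-join against the test_df defined in the question; it assumes the string values compare lexicographically, and the names right and pairs are just illustrative:

import sqlContext.implicits._

// Alias the right side so the joined columns have distinct names.
val right = test_df.select($"id".as("r_id"), $"value".as("r_value"))

// Requiring value < r_value keeps exactly one row per unordered pair.
val pairs = test_df
  .join(right, ($"id" === $"r_id") && ($"value" < $"r_value"))
  .select("id", "value", "r_value")

pairs.show()
// Should match the expected output from the question:
// (1, a, b), (1, a, c), (1, b, c), (2, a, d), (2, a, e), (2, d, e)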
[edit] If you don't have an ordering, and you have sufficiently few values per id, another solution is to use groupByKey and then build all combinations from the resulting sequence, which can be done more simply than creating all pairs and then keeping only half. (If you're using Scala, for example, I believe Seq's combinations method [doc] will do what you need.) This will perform much worse than the self-join approach for most datasets.
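A rough sketch of that groupByKey idea, again using test_df from the question; the RDD pattern match and the name combos are illustrative, and it assumes sqlContext.implicits._ is in scope for toDF:

import sqlContext.implicits._

val combos = test_df.rdd
  .map { case Row(id: Int, value: String) => (id, value) } // key by id
  .groupByKey()                                            // collect all values per id
  .flatMap { case (id, values) =>
    // combinations(2) emits each unordered pair exactly once,
    // so no ordering on the values is needed.
    values.toSeq.combinations(2).map { case Seq(v1, v2) => (id, v1, v2) }
  }
  .toDF("id", "value", "r_value")

combos.show()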