I have a table with two string-type columns (username, friend), and for each username I want to collect all of its friends on one row, concatenated as strings. For example: ('username1', 'friend1, friend2, friend3').
I know MySQL does this with GROUP_CONCAT. Is there any way to do this with Spark SQL?
The agg method (in its Java-specific variant) computes aggregates by specifying a map from column name to aggregate method; the resulting DataFrame will also contain the grouping columns. The available aggregate methods are avg, max, min, sum, and count.
You need to define a key or grouping for the aggregation, and you can define an aggregation function that specifies how the transformation is applied to the columns. Given multiple input values, the aggregation function produces one result per group.
The GROUP_CONCAT() function in MySQL is used to concatenate data from multiple rows into one field. It is an aggregate (GROUP BY) function that returns a string value if the group contains at least one non-NULL value.
Aggregations in Spark are similar to those in any relational database: they are a way to group data together in order to look at it from a higher level, and they can be performed on tables, joined tables, views, etc.
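As a minimal sketch of that map-style agg, applied to the question's data (this assumes a DataFrame df with columns "username" and "friend", as built later in the thread):

// Sketch: groupBy defines the key; agg takes a map from column name to aggregate method.
val counts = df
  .groupBy("username")
  .agg(Map("friend" -> "count"))
// None of the built-in methods (avg, max, min, sum, count) concatenates strings,
// which is why the answers below reach for a UDAF or collect_list instead.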
Before you proceed: this operation is yet another groupByKey. While it has multiple legitimate applications, it is relatively expensive, so be sure to use it only when required.
Not exactly a concise or efficient solution, but you can use a UserDefinedAggregateFunction, introduced in Spark 1.5.0:
import scala.collection.mutable.ArrayBuffer
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{ArrayType, StringType, StructType}
import org.apache.spark.unsafe.types.UTF8String

object GroupConcat extends UserDefinedAggregateFunction {
  def inputSchema = new StructType().add("x", StringType)
  def bufferSchema = new StructType().add("buff", ArrayType(StringType))
  def dataType = StringType
  def deterministic = true

  // Start each group with an empty buffer of strings.
  def initialize(buffer: MutableAggregationBuffer) = {
    buffer.update(0, ArrayBuffer.empty[String])
  }

  // Append each non-null input value to the buffer.
  def update(buffer: MutableAggregationBuffer, input: Row) = {
    if (!input.isNullAt(0))
      buffer.update(0, buffer.getSeq[String](0) :+ input.getString(0))
  }

  // Combine partial buffers from different partitions.
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row) = {
    buffer1.update(0, buffer1.getSeq[String](0) ++ buffer2.getSeq[String](0))
  }

  // Join the collected values into a single comma-separated string.
  def evaluate(buffer: Row) = UTF8String.fromString(
    buffer.getSeq[String](0).mkString(","))
}
Example usage:
val df = sc.parallelize(Seq(
  ("username1", "friend1"),
  ("username1", "friend2"),
  ("username2", "friend1"),
  ("username2", "friend3")
)).toDF("username", "friend")

df.groupBy($"username").agg(GroupConcat($"friend")).show

## +---------+---------------+
## | username|        friends|
## +---------+---------------+
## |username1|friend1,friend2|
## |username2|friend1,friend3|
## +---------+---------------+
You can also create a Python wrapper, as shown in "Spark: How to map Python with Scala or Java User Defined Functions?"
In practice it can be faster to extract the RDD, groupByKey, mkString, and rebuild the DataFrame.
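A rough sketch of that RDD route, assuming the same df("username", "friend") and the toDF implicits used above:

// Pull the rows out as an RDD of (username, friend) pairs, group per user,
// join the friends into one string, and rebuild a DataFrame.
val friendsDF = df.rdd
  .map(row => (row.getString(0), row.getString(1)))
  .groupByKey()
  .mapValues(_.mkString(","))
  .toDF("username", "friends")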
You can get a similar effect by combining the collect_list function (Spark >= 1.6.0) with concat_ws:
import org.apache.spark.sql.functions.{collect_list, concat_ws}

df.groupBy($"username")
  .agg(concat_ws(",", collect_list($"friend")).alias("friends"))
You can try the collect_list function:
sqlContext.sql("select A, collect_list(B), collect_list(C) from Table1 group by A
Or you can register a UDF, something like
sqlContext.udf.register("myzip", (a: Long, b: Long) => a + "," + b)
and then use this function in the query:
sqlContext.sql("select A, collect_list(myzip(B,C)) from tbl group by A")
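If you want a single comma-separated string rather than an array, a sketch of the same query with the collected list wrapped in concat_ws (assuming the same tbl and the registered myzip UDF):

// Collapse the collected array into one string; concat_ws accepts an
// array<string> column as of Spark 1.5.
sqlContext.sql(
  "select A, concat_ws(',', collect_list(myzip(B,C))) as grouped from tbl group by A")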