
Regrouping / Concatenating DataFrame rows in Spark

I have a DataFrame that looks like this:

scala> data.show
+-----+---+---------+
|label| id| features|
+-----+---+---------+
|  1.0|  1|[1.0,2.0]|
|  0.0|  2|[5.0,6.0]|
|  1.0|  1|[3.0,4.0]|
|  0.0|  2|[7.0,8.0]|
+-----+---+---------+

I want to regroup the features based on "id" so I can get the following:

scala> data.show
+---------+---+-----------------+
|    label| id| features        |
+---------+---+-----------------+
|  1.0,1.0|  1|[1.0,2.0,3.0,4.0]|
|  0.0,0.0|  2|[5.0,6.0,7.0,8.0]|
+---------+---+-----------------+

This is the code I am using to generate the mentioned DataFrame:

val rdd = sc.parallelize(List((1.0, 1, Vectors.dense(1.0, 2.0)), (0.0, 2, Vectors.dense(5.0, 6.0)), (1.0, 1, Vectors.dense(3.0, 4.0)), (0.0, 2, Vectors.dense(7.0, 8.0))))
val data = rdd.toDF("label", "id", "features")

I have been trying different things with both RDDs and DataFrames. The most "promising" approach so far has been to filter based on "id":

data.filter($"id".equalTo(1))

+-----+---+---------+
|label| id| features|
+-----+---+---------+
|  1.0|  1|[1.0,2.0]|
|  1.0|  1|[3.0,4.0]|
+-----+---+---------+

But I have two bottlenecks now:

1) How to automatize the filtering for all distinct values that "id" could have?

The following generates an error:

data.select("id").distinct.foreach(x => data.filter($"id".equalTo(x)))
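(Editor's note: this fails because a DataFrame cannot be referenced from inside the closure of another distributed operation; closures are serialized to the executors, where the driver-side DataFrame is not usable. The usual workaround is to bring the distinct ids back to the driver first. A rough sketch, assuming the `data` DataFrame and the `$` implicits from the snippets above:)

```scala
// Sketch: collect the distinct ids to the driver, then build one
// filtered DataFrame per id in driver-side code.
val ids = data.select("id").distinct.collect().map(_.getInt(0))
val perId = ids.map(i => i -> data.filter($"id" === i)).toMap
```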

2) How to concatenate the "features" that share a given "id"? I have not tried much since I am still stuck on 1).

Any suggestion is more than welcome.

Note: For clarification, "label" is always the same for every occurrence of "id". Sorry for the confusion; a simple extension of my task would also be to group the "labels" (example updated accordingly).
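(Editor's note: stripped of Spark, the transformation being asked for is an ordinary group-and-flatten. A plain-Scala sketch of the intended semantics, using arrays in place of MLlib vectors, which may help clarify the question:)

```scala
// Group rows by (label, id), preserve the original row order,
// and flatten the per-row feature arrays into one array per group.
val rows = Seq(
  (1.0, 1, Array(1.0, 2.0)),
  (0.0, 2, Array(5.0, 6.0)),
  (1.0, 1, Array(3.0, 4.0)),
  (0.0, 2, Array(7.0, 8.0)))

val grouped = rows.zipWithIndex
  .groupBy { case ((label, id, _), _) => (label, id) }      // key by (label, id)
  .map { case ((label, id), xs) =>
    (label, id, xs.sortBy(_._2).flatMap(_._1._3).toArray)   // keep input order
  }
  .toSeq
```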

Jacob asked Nov 25 '15 20:11


1 Answer

I believe there is no efficient way to achieve what you want, and the additional ordering requirement doesn't make the situation any better. The cleanest way I can think of is to use groupByKey like this:

import org.apache.spark.mllib.linalg.{Vectors, Vector}
import org.apache.spark.sql.functions.monotonicallyIncreasingId
import org.apache.spark.sql.Row
import org.apache.spark.rdd.RDD


val pairs: RDD[((Double, Int), (Long, Vector))] = data
  // Add row identifiers so we can keep desired order
  .withColumn("uid", monotonicallyIncreasingId)
  // Create PairwiseRDD where (label, id) is a key
  // and (row-id, vector) is the value
  .map{case Row(label: Double, id: Int, v: Vector, uid: Long) => 
    ((label, id), (uid, v))}

val rows = pairs.groupByKey.mapValues(xs => {
  val vs = xs
    .toArray
    .sortBy(_._1) // Sort by row id to keep order
    .flatMap(_._2.toDense.values) // flatmap vector values

  Vectors.dense(vs) // return concatenated vectors 

}).map{case ((label, id), v) => (label, id, v)} // Reshape

val grouped = rows.toDF("label", "id", "features")

grouped.show

// +-----+---+-----------------+
// |label| id|         features|
// +-----+---+-----------------+
// |  0.0|  2|[5.0,6.0,7.0,8.0]|
// |  1.0|  1|[1.0,2.0,3.0,4.0]|
// +-----+---+-----------------+

It is also possible to use a UDAF similar to the one I've proposed in SPARK SQL replacement for mysql GROUP_CONCAT aggregate function, but it is even less efficient than this approach.

zero323 answered Sep 17 '22 02:09