Scala Spark - split vector column into separate columns in a Spark DataFrame

I have a Spark DataFrame with a column of Vector values. The vectors are all n-dimensional, i.e. they all have the same length. I also have a list of column names, Array("f1", "f2", "f3", ..., "fn"), each corresponding to one element of the vector.

some_columns... | Features
      ...       | [0,1,0,..., 0]

to

some_columns... | f1 | f2 | f3 | ... | fn
      ...       | 0  | 1  | 0  | ... | 0

What is the best way to achieve this? One way I thought of is to create a new DataFrame with createDataFrame(Row(Features), featureNameList) and then join it with the old one, but createDataFrame requires a Spark context. I only want to transform the existing DataFrame. I also know about .withColumn("fi", value), but what do I do if n is large?

I'm new to Scala and Spark and couldn't find any good examples for this. I think this can be a common task. My particular case is that I used the CountVectorizer and wanted to recover each column individually for better readability instead of only having the vector result.

Logan Yang asked Apr 19 '18 02:04

1 Answer

One way is to convert the vector column to an array<double> and then use getItem to extract the individual elements.

// Written for spark-shell; elsewhere, also import spark.implicits._ for toDF and $"..."
import org.apache.spark.sql.functions._
import org.apache.spark.ml._

val df = Seq( (1 , linalg.Vectors.dense(1,0,1,1,0) ) ).toDF("id", "features")
//df: org.apache.spark.sql.DataFrame = [id: int, features: vector]

df.show(false)
//+---+---------------------+
//|id |features             |
//+---+---------------------+
//|1  |[1.0,0.0,1.0,1.0,0.0]|
//+---+---------------------+

// A UDF to convert VectorUDT to ArrayType
val vecToArray = udf( (xs: linalg.Vector) => xs.toArray )

// Add an ArrayType column
val dfArr = df.withColumn("featuresArr" , vecToArray($"features") )

// Names for the elements to extract.
// Indices are not bounds-checked, so `elements` should have exactly
// as many entries as each vector in column `features`.
val elements = Array("f1", "f2", "f3", "f4", "f5")

// Build one column expression per name: featuresArr[idx] aliased to that name
val sqlExpr = elements.zipWithIndex.map{ case (alias, idx) => col("featuresArr").getItem(idx).as(alias) }

// Extract the elements from dfArr
dfArr.select(sqlExpr : _*).show
//+---+---+---+---+---+
//| f1| f2| f3| f4| f5|
//+---+---+---+---+---+
//|1.0|0.0|1.0|1.0|0.0|
//+---+---+---+---+---+
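
The select above returns only the extracted columns, while the question asks to keep the other columns next to f1 ... fn. A minimal sketch of that variation follows, under two assumptions that are not part of the code above: Spark 3.0 or later, so that the built-in vector_to_array can stand in for the UDF, and a local SparkSession created only for illustration. The names `withArr`, `keep`, `split` and `result` are made up for this example.

```scala
import org.apache.spark.ml.functions.vector_to_array
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.col

// Assumption: a local session for illustration; in spark-shell `spark` already exists.
val spark = SparkSession.builder().master("local[1]").appName("split-vector").getOrCreate()
import spark.implicits._

val df = Seq((1, Vectors.dense(1, 0, 1, 1, 0))).toDF("id", "features")
val elements = Array("f1", "f2", "f3", "f4", "f5")

// Built-in vector -> array<double> conversion (Spark 3.0+), used here instead of the UDF
val withArr = df.withColumn("featuresArr", vector_to_array(col("features")))

// Every original column except the vector itself
val keep = df.columns.filter(_ != "features").map(col)

// One aliased expression per element name, as in the answer above
val split = elements.zipWithIndex.map { case (name, idx) =>
  col("featuresArr").getItem(idx).as(name)
}

// Original columns first, then f1 ... f5; the temporary array column is not selected
val result = withArr.select((keep ++ split): _*)
result.show()
```

Because everything goes into a single select, the number of names in `elements` does not matter, which should also cover the case where n is large. On Spark versions before 3.0 the `vecToArray` UDF from the answer can be used in place of vector_to_array.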
philantrovert answered Sep 21 '22 15:09