
Spark Dataframe select based on column index

How do I select the columns of a DataFrame at specific indexes in Scala?

For example, if a DataFrame has 100 columns and I want to extract only columns (10, 12, 13, 14, 15), how do I do that?

The following selects all columns of the DataFrame df whose names are listed in the array colNames:

df = df.select(colNames.head,colNames.tail: _*)

Suppose there is a similar array, colNos, that holds column indexes:

colNos = Array(10,20,25,45)

How do I adapt the df.select above to fetch only the columns at those indexes?

Vikas J asked Apr 22 '17 00:04


3 Answers

Example: Grab first 14 columns of Spark Dataframe by Index using Scala.

import org.apache.spark.sql.functions.col

// Gives array of names by index (first 14 cols for example)
val sliceCols = df.columns.slice(0, 14)
// Maps names & selects columns in dataframe
val subset_df = df.select(sliceCols.map(name=>col(name)):_*)

You cannot simply do this (as I tried and failed):

// Gives array of names by index (first 14 cols for example)
val sliceCols = df.columns.slice(0, 14)
// Does not compile: select has no overload that takes an Array[String]
val subset_df = df.select(sliceCols)

The reason is that select does not accept an Array[String] directly. You have to convert the Array[String] to an Array[org.apache.spark.sql.Column] and expand it with `: _*` for the select to work.
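As a possible alternative to the Column conversion, the sliced names can be passed to the String-based overload of select, using the same head/tail idiom as in the question. This is only a sketch: `firstN` and `SelectSlicedNamesSketch` are hypothetical names, the local SparkSession is there just to make the example self-contained, and the slice is assumed to be non-empty.

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}

object SelectSlicedNamesSketch {
  // Hypothetical helper: keeps the first n columns via select(col: String, cols: String*).
  // Assumes n >= 1 and the DataFrame has at least one column; .head throws on an empty slice.
  def firstN(df: DataFrame, n: Int): DataFrame = {
    val sliceCols = df.columns.slice(0, n)
    df.select(sliceCols.head, sliceCols.tail: _*)
  }

  def main(args: Array[String]): Unit = {
    // Local session only for illustration; in a real job one usually exists already.
    val spark = SparkSession.builder().master("local[1]").appName("slice-names").getOrCreate()
    import spark.implicits._

    val df = Seq((1, 2, 3, 4, 5, 6)).toDF() // columns _1 .. _6
    println(firstN(df, 3).columns.mkString(", "))
    spark.stop()
  }
}
```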

Or wrap it in a curried function (high five to my colleague for this):

import org.apache.spark.sql.DataFrame

// Subsets a DataFrame to the columns in the index range [beg_val, end_val).
def subset_frame(beg_val: Int = 0, end_val: Int)(df: DataFrame): DataFrame = {
  val sliceCols = df.columns.slice(beg_val, end_val)
  df.select(sliceCols.map(name => col(name)): _*)
}

// Get the first 25 columns as a subsetted DataFrame
val subset_df: DataFrame = df.transform(subset_frame(0, 25))
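The curried helper above takes a contiguous range, while the question asks about arbitrary positions. One way this might be generalized is sketched below. `selectByIndex` and `SelectByIndexSketch` are hypothetical names, positions are assumed to be zero-based and in range, and the local SparkSession is only there to keep the example self-contained.

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.functions.col

object SelectByIndexSketch {
  // Hypothetical helper: keeps the columns at the given zero-based positions,
  // in the order given. An out-of-range position throws when df.columns(i) is evaluated.
  def selectByIndex(colNos: Seq[Int])(df: DataFrame): DataFrame =
    df.select(colNos.map(i => col(df.columns(i))): _*)

  def main(args: Array[String]): Unit = {
    // Local session only for illustration.
    val spark = SparkSession.builder().master("local[1]").appName("select-by-index").getOrCreate()
    import spark.implicits._

    val df = Seq((1, 2, 3, 4, 5, 6)).toDF() // columns _1 .. _6
    val subset = df.transform(selectByIndex(Seq(0, 3, 5)))
    println(subset.columns.mkString(", "))
    spark.stop()
  }
}
```

Because the DataFrame is the last parameter list, the helper composes with transform in the same way as the range-based version.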
kevin_theinfinityfund answered Oct 04 '22 17:10


You can map over columns:

import org.apache.spark.sql.functions.col

df.select(colNos map df.columns map col: _*)

or:

df.select(colNos map (df.columns andThen col): _*)

or:

df.select(colNos map (col _ compose df.columns): _*)

All the methods shown above are equivalent and impose no performance penalty. The following mapping:

colNos map df.columns 

is just a local Array access (constant time for each index), and choosing between the String-based and Column-based variants of select doesn't affect the execution plan:

val df = Seq((1, 2, 3, 4, 5, 6)).toDF

val colNos = Seq(0, 3, 5)

df.select(colNos map df.columns map col: _*).explain
== Physical Plan ==
LocalTableScan [_1#46, _4#49, _6#51]

df.select("_1", "_4", "_6").explain
== Physical Plan ==
LocalTableScan [_1#46, _4#49, _6#51]
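Since the index lookup is a plain local array access, an invalid position fails on the driver before any Spark plan is built. If a clearer error is preferred, the lookup could be validated first. The sketch below is plain Scala with no Spark dependency; `namesAt` and `ColumnIndexSketch` are hypothetical names, and the `columns` array stands in for `df.columns`.

```scala
object ColumnIndexSketch {
  // Hypothetical helper: resolves zero-based positions to column names.
  // Returns Left with a message listing the invalid positions instead of throwing.
  def namesAt(columns: Array[String], colNos: Seq[Int]): Either[String, Seq[String]] = {
    val bad = colNos.filterNot(i => i >= 0 && i < columns.length)
    if (bad.nonEmpty)
      Left(s"positions out of range for ${columns.length} columns: ${bad.mkString(", ")}")
    else
      Right(colNos.map(i => columns(i)))
  }

  def main(args: Array[String]): Unit = {
    val columns = Array("_1", "_2", "_3", "_4", "_5", "_6") // stands in for df.columns
    println(namesAt(columns, Seq(0, 3, 5))) // all positions valid
    println(namesAt(columns, Seq(0, 9)))    // position 9 is out of range
  }
}
```

The resulting names can then be passed to either variant of select.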
zero323 answered Nov 10 '22 21:11


@user6910411's answer above works like a charm, and its number of tasks and logical plan are similar to those of my approach below, but my approach is a bit faster.

So I would suggest going with column names rather than column numbers. Column names are much safer and much lighter than numbers. You can use the following solution:

val colNames = Seq("col1", "col2", /* ...... */ "col99", "col100")

val selectColNames = Seq("col1", "col3" /* , .... selected column names ... */)

val selectCols = selectColNames.map(name => df.col(name))

df = df.select(selectCols:_*)

If you are reluctant to write out all 100 column names, there is a shortcut too:

val colNames = df.schema.fieldNames
Ramesh Maharjan answered Nov 10 '22 20:11