 

Stratified sampling in Spark


I have a data set which contains user and purchase data. Here is an example, where the first element is the userId, the second is the productId, and the third is a boolean flag.

(2147481832,23355149,1) (2147481832,973010692,1) (2147481832,2134870842,1) (2147481832,541023347,1) (2147481832,1682206630,1) (2147481832,1138211459,1) (2147481832,852202566,1) (2147481832,201375938,1) (2147481832,486538879,1) (2147481832,919187908,1) ...  

I want to take exactly 80% of each user's data and build one RDD from it, and take the remaining 20% and build another RDD; let's call them train and test. I would like to avoid groupBy, since it can create memory problems given that the data set is large. What's the best way to do this?

I could do the following, but this will not give 80% of each user's data.

// assign a random percentage key to each record, then keep keys below 80
val percentData = data.map(x => ((math.random * 100).toInt, (x._1, x._2, x._3)))
val train = percentData.filter(x => x._1 < 80).values.repartition(10).cache()
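(For comparison only, a global, non-stratified 80/20 split can be sketched with RDD.randomSplit; this has exactly the limitation described above, since each user only gets roughly 80% of their records:)

// Sketch: global split, not stratified per user.
val Array(approxTrain, approxTest) = data.randomSplit(Array(0.8, 0.2), seed = 2L)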
asked Aug 27 '15 by add-semi-colons

1 Answer

One possible solution is in Holden's answer, and here are some other solutions:

Using RDDs:

You can use the sampleByKeyExact transformation from the PairRDDFunctions class.

sampleByKeyExact(boolean withReplacement, scala.collection.Map fractions, long seed)
Return a subset of this RDD sampled by key (via stratified sampling) containing exactly math.ceil(numItems * samplingRate) for each stratum (group of pairs with the same key).

And this is how I would do it:

Consider the following list:

val seq = Seq(
  (2147481832,23355149,1),(2147481832,973010692,1),(2147481832,2134870842,1),(2147481832,541023347,1),
  (2147481832,1682206630,1),(2147481832,1138211459,1),(2147481832,852202566,1),(2147481832,201375938,1),
  (2147481832,486538879,1),(2147481832,919187908,1),(214748183,919187908,1),(214748183,91187908,1)
)

I would create a pair RDD, mapping the users as keys:

val data = sc.parallelize(seq).map(x => (x._1,(x._2,x._3))) 

Then I'll set up the fractions for each key as follows, since sampleByKeyExact takes a Map of fractions, one per key:

val fractions = data.map(_._1).distinct.map(x => (x,0.8)).collectAsMap 

What I have done here is map over the keys to find the distinct ones, then associate each with a fraction of 0.8. Finally, I collect the whole thing as a Map.
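For the example seq above, which has only two distinct keys, the collected map should look roughly like this:

// fractions: scala.collection.Map[Int,Double] = Map(2147481832 -> 0.8, 214748183 -> 0.8)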

To sample now:

import org.apache.spark.rdd.PairRDDFunctions
val sampleData = data.sampleByKeyExact(false, fractions, 2L)

or

val sampleData = data.sampleByKeyExact(withReplacement = false, fractions = fractions, seed = 2L)

You can check the counts on your data and on the data sample:

scala> data.count
// res10: Long = 12

scala> sampleData.count
// res11: Long = 10
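The question also asks for the remaining 20% as a test set. One possible follow-up (a sketch, not part of the original answer) is to subtract the sample from the full RDD; note that subtract matches whole (key, value) pairs, so exact duplicate records would be dropped together:

// Everything not picked for the 80% sample becomes the test set.
val testData = data.subtract(sampleData)
// For the example above, testData.count should be 12 - 10 = 2.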

Using DataFrames:

Let's consider the same data (seq) from the previous section.

val df = seq.toDF("keyColumn","value1","value2")
df.show
// +----------+----------+------+
// | keyColumn|    value1|value2|
// +----------+----------+------+
// |2147481832|  23355149|     1|
// |2147481832| 973010692|     1|
// |2147481832|2134870842|     1|
// |2147481832| 541023347|     1|
// |2147481832|1682206630|     1|
// |2147481832|1138211459|     1|
// |2147481832| 852202566|     1|
// |2147481832| 201375938|     1|
// |2147481832| 486538879|     1|
// |2147481832| 919187908|     1|
// | 214748183| 919187908|     1|
// | 214748183|  91187908|     1|
// +----------+----------+------+

We will need the underlying RDD, on which we create tuples by defining our key to be the first column:

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Row}

val data: RDD[(Int, Row)] = df.rdd.keyBy(_.getInt(0))

val fractions: Map[Int, Double] = data.map(_._1)
                                      .distinct
                                      .map(x => (x, 0.8))
                                      .collectAsMap

val sampleData: RDD[Row] = data.sampleByKeyExact(withReplacement = false, fractions, 2L)
                               .values

val sampleDataDF: DataFrame = spark.createDataFrame(sampleData, df.schema)
// you can use sqlContext.createDataFrame(...) instead for Spark 1.6

You can now check the counts on df and on the data sample:

scala> df.count
// res9: Long = 12

scala> sampleDataDF.count
// res10: Long = 10
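Similarly, if you need the remaining rows as a test DataFrame, one hedged option (not from the original answer) is except, which behaves like a set difference, counting duplicate rows once:

// Rows of df that are not in the sample form the test set.
val testDataDF: DataFrame = df.except(sampleDataDF)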

Since Spark 1.5.0 you can also use the DataFrameStatFunctions.sampleBy method:

df.stat.sampleBy("keyColumn", fractions, seed) 
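For example, reusing the fractions computed earlier (a sketch: sampleBy expects an immutable Map, so the collected map may need a .toMap, and unlike sampleByKeyExact it only approximates the requested fraction per key):

// Approximate per-key sampling directly on the DataFrame (Spark >= 1.5.0).
val sampleByDF = df.stat.sampleBy("keyColumn", fractions.toMap, 2L)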
answered Sep 28 '22 by eliasah