
Spark: How to perform undersampling on LabeledPoint?

I've got some unbalanced data in an RDD of LabeledPoint. What I want to do is select all positives and n times as many negatives (chosen randomly). For example, if I have 100 positives and 30000 negatives, I want to create a new set of LabeledPoint with all 100 positives and 300 negatives (n=3).

In a real scenario I don't know in advance how many positives and negatives I have.

Maju116 asked Mar 15 '23 02:03


1 Answer

Presumably your data is an RDD[LabeledPoint]. You can do something like the following:

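// Keep every positive example.
val pos = rdd.filter(_.label == 1.0)
val numPos = pos.count()
// takeSample returns a local Array and expects an Int sample size.
val neg = rdd.filter(_.label == 0.0).takeSample(withReplacement = false, num = (numPos * 3).toInt)
// Turn the sampled negatives back into an RDD before the union.
val undersample = pos.union(rdd.sparkContext.parallelize(neg))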

You can find the docs for takeSample, filter, and union in the Spark RDD API documentation.
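Since takeSample brings its result back to the driver as an Array, it may not suit a very large negative class. A rough alternative, assuming the labels are exactly 0.0 and 1.0, is to compute a sampling fraction and use RDD.sample, which stays distributed. The names Undersampling, negativeFraction, and undersample below are made up for illustration, and the default seed is arbitrary. Note that sample only gives approximately the requested number of negatives, not an exact count.

```scala
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.rdd.RDD

object Undersampling {
  // Fraction of negatives to keep so that roughly `ratio` negatives remain per positive.
  // Capped at 1.0 because sampling without replacement cannot keep more than everything.
  def negativeFraction(numPos: Long, numNeg: Long, ratio: Double): Double =
    if (numNeg == 0L) 0.0 else math.min(1.0, ratio * numPos / numNeg)

  // Keep all positives plus a random subset of negatives (about ratio:1),
  // without collecting the sampled negatives on the driver.
  def undersample(rdd: RDD[LabeledPoint], ratio: Double, seed: Long = 42L): RDD[LabeledPoint] = {
    val pos = rdd.filter(_.label == 1.0)
    val neg = rdd.filter(_.label == 0.0)
    val fraction = negativeFraction(pos.count(), neg.count(), ratio)
    pos.union(neg.sample(withReplacement = false, fraction = fraction, seed = seed))
  }
}
```

With the numbers from the question (100 positives, 30000 negatives, n=3) the fraction works out to 0.01, so Undersampling.undersample(rdd, 3.0) would keep around 300 negatives. Both counts trigger a pass over the data, so caching rdd beforehand may help.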

Matthew Graves answered Mar 17 '23 04:03
