I am curious if there is something similar to sklearn's http://scikit-learn.org/stable/modules/generated/sklearn.model_selection.StratifiedShuffleSplit.html for apache-spark in the latest 2.0.1 release.
So far I could only find https://spark.apache.org/docs/latest/mllib-statistics.html#stratified-sampling which does not seem to be a great fit for splitting a heavily imbalanced dataset into train/test samples.
In general, putting 80% of the data in the training set, 10% in the validation set, and 10% in the test set is a good split to start with. The optimal split between train, validation, and test sets depends on factors such as the use case, the structure of the model, the dimensionality of the data, and so on.
In machine learning, data splitting is typically done to avoid overfitting, the situation where a model fits its training data too well and fails to reliably fit additional data. The original data is typically split into three or four sets. To reduce the variance of a single train/test split, we can also perform cross-validation: it is very similar to a train/test split, but applied to more subsets. We split the data into k subsets, train on k-1 of those subsets, evaluate on the remaining one, and rotate through all k folds.
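If you want k-fold cross-validation in Spark ML rather than a single split, the CrossValidator estimator implements exactly this rotation. A minimal sketch, assuming a DataFrame data that already has features and label columns and a simple logistic regression to tune:

import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}

// `data` is assumed to be a DataFrame with "features" (Vector) and "label" (Double) columns.
val lr = new LogisticRegression()
val paramGrid = new ParamGridBuilder()
  .addGrid(lr.regParam, Array(0.1, 0.01))
  .build()

// 5-fold cross-validation: the data is split into 5 subsets; each candidate model is trained
// on 4 of them and evaluated on the remaining one, and the metric is averaged over the folds.
val cv = new CrossValidator()
  .setEstimator(lr)
  .setEvaluator(new BinaryClassificationEvaluator())
  .setEstimatorParamMaps(paramGrid)
  .setNumFolds(5)

val cvModel = cv.fit(data)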
If you want to split a data set 80/20 in Spark, you call df.randomSplit(Array(0.80, 0.20), seed), where seed is some number used to seed the random number generator.
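Note that neither randomSplit nor DataFrame.stat.sampleBy guarantees exact per-class counts; the proportions hold only in expectation, which is why the exact stratified approach below is useful for small or heavily imbalanced data. A rough sketch of both, using a toy DataFrame like the one in the rest of this post:

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[1]").getOrCreate()
val df = spark.createDataFrame(Seq((0, 0.0), (1, 1.0), (2, 0.0), (3, 1.0))).toDF("id", "label")

// Plain random 80/20 split: class proportions are preserved only in expectation.
val Array(train, test) = df.randomSplit(Array(0.8, 0.2), 42L)

// Approximate stratified sampling: draws roughly 80% of each label value.
val strataFractions = Map(0.0 -> 0.8, 1.0 -> 0.8)
val stratTrain = df.stat.sampleBy("label", strataFractions, 42L)
val stratTest  = df.except(stratTrain) // rows that were not sampled into the train set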
Let's assume we have a dataset like this:
+---+-----+
| id|label|
+---+-----+
| 0| 0.0|
| 1| 1.0|
| 2| 0.0|
| 3| 1.0|
| 4| 0.0|
| 5| 1.0|
| 6| 0.0|
| 7| 1.0|
| 8| 0.0|
| 9| 1.0|
+---+-----+
This dataset is perfectly balanced, but this approach will work for unbalanced data as well.
Now, let's augment this DataFrame with additional information that will be useful in deciding which rows should go to the train set. The steps are as follows:

1. Determine how many examples of every label should go to the train set, given some ratio.
2. Shuffle the rows of the DataFrame.
3. Partition the DataFrame by label and rank each label's observations using row_number().

We end up with the following data frame:
+---+-----+----------+
| id|label|row_number|
+---+-----+----------+
| 6| 0.0| 1|
| 2| 0.0| 2|
| 0| 0.0| 3|
| 4| 0.0| 4|
| 8| 0.0| 5|
| 9| 1.0| 1|
| 5| 1.0| 2|
| 3| 1.0| 3|
| 1| 1.0| 4|
| 7| 1.0| 5|
+---+-----+----------+
Note: the rows are shuffled (see the random order in the id column), partitioned by label (see the label column), and ranked.
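That augmentation is just a shuffle followed by a window function. Roughly, assuming df is the DataFrame shown above (the full, runnable version appears further down):

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{col, rand, row_number}

// Shuffle the rows globally, then number them within each label partition.
val w = Window.partitionBy(col("label")).orderBy(col("label"))
val dfRowNum = df
  .sort(rand())
  .select(col("*"), row_number().over(w) as "row_number")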
Let's assume that we would like to make an 80% split. In this case, we would like four 1.0 labels and four 0.0 labels to go to the training dataset, and one 1.0 label and one 0.0 label to go to the test dataset. We have this information in the row_number column, so now we can simply use it in a user-defined function (if row_number is less than or equal to four, the example goes to the train set).
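In sketch form (the per-label train counts are computed from the ratio in the full example below; the hard-coded 4s here are just 80% of the 5 examples per label in this toy dataset):

import org.apache.spark.sql.functions.{col, udf}

// 80% of 5 examples per label -> 4 examples of each label go to the train set.
val trainCountPerLabel: Map[Double, Long] = Map(0.0 -> 4L, 1.0 -> 4L)
val isTrain = udf((label: Double, rowNumber: Int) => rowNumber <= trainCountPerLabel(label))

val flagged = dfRowNum.withColumn("isTrainSet", isTrain(col("label"), col("row_number")))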
After applying the UDF, the resulting data frame is as follows:
+---+-----+----------+----------+
| id|label|row_number|isTrainSet|
+---+-----+----------+----------+
| 6| 0.0| 1| true|
| 2| 0.0| 2| true|
| 0| 0.0| 3| true|
| 4| 0.0| 4| true|
| 8| 0.0| 5| false|
| 9| 1.0| 1| true|
| 5| 1.0| 2| true|
| 3| 1.0| 3| true|
| 1| 1.0| 4| true|
| 7| 1.0| 5| false|
+---+-----+----------+----------+
Now, to get the train/test data one has to do:
val train = df.where(col("isTrainSet") === true)
val test = df.where(col("isTrainSet") === false)
These sorting and partitioning steps might be prohibitive for some really big datasets, so I suggest first filtering the dataset as much as possible. The physical plan is as follows:
== Physical Plan ==
*(3) Project [id#4, label#5, row_number#11, if (isnull(row_number#11)) null else UDF(label#5, row_number#11) AS isTrainSet#48]
+- Window [row_number() windowspecdefinition(label#5, label#5 ASC NULLS FIRST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS row_number#11], [label#5], [label#5 ASC NULLS FIRST]
+- *(2) Sort [label#5 ASC NULLS FIRST, label#5 ASC NULLS FIRST], false, 0
+- Exchange hashpartitioning(label#5, 200)
+- *(1) Project [id#4, label#5]
+- *(1) Sort [_nondeterministic#9 ASC NULLS FIRST], true, 0
+- Exchange rangepartitioning(_nondeterministic#9 ASC NULLS FIRST, 200)
+- LocalTableScan [id#4, label#5, _nondeterministic#9]
Here's a full working example (tested with Spark 2.3.0 and Scala 2.11.12):
import org.apache.spark.SparkConf
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import org.apache.spark.sql.functions.{col, row_number, udf, rand}
class StratifiedTrainTestSplitter {

  // Counts the examples of every label and computes how many of each should go to the train set.
  def getNumExamplesPerClass(ss: SparkSession, label: String, trainRatio: Double)(df: DataFrame): Map[Double, Long] = {
    df.groupBy(label).count().createOrReplaceTempView("labelCounts")
    val query = f"SELECT $label AS ratioLabel, count, cast(count * $trainRatio as long) AS trainExamples FROM labelCounts"
    import ss.implicits._
    ss.sql(query)
      .select("ratioLabel", "trainExamples")
      .map((r: Row) => r.getDouble(0) -> r.getLong(1))
      .collect()
      .toMap
  }

  // Shuffles the rows, ranks them within each label partition and flags the first
  // trainExamples rows of every label as belonging to the train set.
  def split(df: DataFrame, label: String, trainRatio: Double): DataFrame = {
    val w = Window.partitionBy(col(label)).orderBy(col(label))
    val rowNumPartitioner = row_number().over(w)

    val dfRowNum = df.sort(rand()).select(col("*"), rowNumPartitioner as "row_number")
    dfRowNum.show()

    val observationsPerLabel: Map[Double, Long] = getNumExamplesPerClass(df.sparkSession, label, trainRatio)(df)
    val addIsTrainColumn = udf((label: Double, rowNumber: Int) => rowNumber <= observationsPerLabel(label))

    dfRowNum.withColumn("isTrainSet", addIsTrainColumn(col(label), col("row_number")))
  }
}

object StratifiedTrainTestSplitter {

  def getDf(ss: SparkSession): DataFrame = {
    val data = Seq(
      (0, 0.0), (1, 1.0), (2, 0.0), (3, 1.0), (4, 0.0), (5, 1.0), (6, 0.0), (7, 1.0), (8, 0.0), (9, 1.0)
    )
    ss.createDataFrame(data).toDF("id", "label")
  }

  def main(args: Array[String]): Unit = {
    val spark: SparkSession = SparkSession
      .builder()
      .config(new SparkConf().setMaster("local[1]"))
      .getOrCreate()

    val df = new StratifiedTrainTestSplitter().split(getDf(spark), "label", 0.8)
    df.cache()

    df.where(col("isTrainSet") === true).show()
    df.where(col("isTrainSet") === false).show()
  }
}
Note: the labels are Doubles in this case. If your labels are Strings, you'll have to switch the types here and there.
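A rough sketch of the spots that change for String labels (hypothetical values, not part of the example above):

import org.apache.spark.sql.functions.udf

// getNumExamplesPerClass would return Map[String, Long] and read the label with r.getString(0):
//   .map((r: Row) => r.getString(0) -> r.getLong(1))
// and the UDF would take a String label:
val observationsPerLabel: Map[String, Long] = Map("ham" -> 4L, "spam" -> 4L) // example values only
val addIsTrainColumn = udf((label: String, rowNumber: Int) => rowNumber <= observationsPerLabel(label))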