
Sklearn's class "StratifiedShuffleSplit"

I'm a little confused about how Sklearn's StratifiedShuffleSplit class works.

The code below is from Géron's book "Hands On Machine Learning", chapter 2, where he does a stratified sampling.

from sklearn.model_selection import StratifiedShuffleSplit

split = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
for train_index, test_index in split.split(housing, housing["income_cat"]):
    strat_train_set = housing.loc[train_index]
    strat_test_set = housing.loc[test_index]

In particular, what is split.split doing?


Rafael Higa asked Jan 10 '20 00:01


2 Answers

Since you did not provide a dataset, I'll use a small sklearn sample to answer this question.

Prepare dataset

# generate data
import numpy as np
from sklearn.model_selection import StratifiedShuffleSplit
data = np.array([[1, 2], [3, 4], [1, 2], [3, 4], [1, 2], [3, 4]])
group_label = np.array([0, 0, 0, 1, 1, 1])

This creates a dataset, data, with 6 observations and 2 variables. group_label takes 2 values, denoting group 0 and group 1. In this case group 0 contains 3 samples, and so does group 1. In general, the groups do not need to be the same size.

Create a StratifiedShuffleSplit object instance

sss = StratifiedShuffleSplit(n_splits=5, test_size=0.5, random_state=0)
sss.get_n_splits(data, group_label)



In this step you create an instance of StratifiedShuffleSplit and tell it how to split: with random_state=0, split the data 5 times, each time putting 50% of the data in the test set. However, nothing is split yet; the data is only split when you call split in the next step.

Call the instance's split method to split the data

# split() returns a generator; nothing is computed until you iterate over it
type(sss.split(data, group_label))

# split data
for train_index, test_index in sss.split(data, group_label):
    print("TRAIN:", train_index, "TEST:", test_index)
    # use the index arrays to select rows from the data and the labels
    X_train, X_test = data[train_index], data[test_index]
    y_train, y_test = group_label[train_index], group_label[test_index]


TRAIN: [5 2 3] TEST: [4 1 0]
TRAIN: [5 1 4] TEST: [0 2 3]
TRAIN: [5 0 2] TEST: [4 3 1]
TRAIN: [4 1 0] TEST: [2 3 5]
TRAIN: [0 5 1] TEST: [3 4 2]

In this step, the splitter you defined in the previous step generates the 5 splits of the data one by one. For instance, in the first split the original data is shuffled and samples 5, 2 and 3 are selected as the train set; the selection is stratified by group_label, so each group appears in the train and test sets in roughly the same proportion as in the full data. In the second split the data is shuffled again and samples 5, 1 and 4 are selected as the train set, and so on.
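
The 6-sample example is too small to show the stratification clearly, so here is a minimal sketch on a larger, imbalanced toy dataset. The arrays X and y and the 80/20 class mix are assumptions made up for illustration; they are not from the question. It counts how many samples of each class land in the test set of each split.

```python
import numpy as np
from sklearn.model_selection import StratifiedShuffleSplit

# Hypothetical toy data: 100 samples, 80 in class 0 and 20 in class 1
X = np.arange(200).reshape(100, 2)
y = np.array([0] * 80 + [1] * 20)

sss = StratifiedShuffleSplit(n_splits=3, test_size=0.2, random_state=0)

test_class_counts = []
for train_index, test_index in sss.split(X, y):
    # Train and test indices of one split never overlap
    assert len(np.intersect1d(train_index, test_index)) == 0
    # Count how many samples of each class ended up in the test set
    counts = np.bincount(y[test_index], minlength=2)
    test_class_counts.append(counts.tolist())
    print("test size:", len(test_index), "class counts:", counts.tolist())
```

Each test set should keep the 80/20 class ratio of the full data, while the actual members differ from split to split because the data is reshuffled every time.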

Travis answered Oct 23 '22 13:10


split.split() returns the indices of the train samples and the test samples. It loops once for each of the splits specified (n_splits) and each time yields the train and test indices, which you can use to create the train and test datasets by filtering the whole dataset.
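
As a rough sketch of that filtering step with a pandas DataFrame: the DataFrame df, its column names and its letter index below are hypothetical, chosen only for illustration. The indices that split yields are positions (0, 1, 2, ...), so iloc is used here; loc with the same indices only gives the same result when the DataFrame has the default integer index.

```python
import numpy as np
import pandas as pd
from sklearn.model_selection import StratifiedShuffleSplit

# Hypothetical DataFrame with a non-default (letter) index
df = pd.DataFrame(
    {"feature": np.arange(10), "category": [0] * 5 + [1] * 5},
    index=list("abcdefghij"),
)

splitter = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=42)

# With n_splits=1 the generator yields a single (train, test) pair
train_index, test_index = next(splitter.split(df, df["category"]))

# The yielded indices are positional, so select rows by position
train_set = df.iloc[train_index]
test_set = df.iloc[test_index]

print(len(train_set), len(test_set))
print(test_set["category"].value_counts().to_dict())
```

With n_splits=1, taking the single pair with next() is equivalent to the for loop in the question, which simply runs once.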

Sunnysinh Solanki answered Oct 23 '22 13:10
