
Cross validation for MNIST dataset with pytorch and sklearn

I am new to PyTorch and am trying to implement a feed-forward neural network to classify the MNIST data set. I have some problems when trying to use cross-validation. My data has the following shapes: x_train: torch.Size([45000, 784]) and y_train: torch.Size([45000])

I tried to use KFold from sklearn.

from sklearn.model_selection import KFold
kfold = KFold(n_splits=10)

Here is the first part of my train method where I'm dividing the data into folds:

for train_index, test_index in kfold.split(x_train, y_train):
    x_train_fold = x_train[train_index]
    x_test_fold = x_test[test_index]
    y_train_fold = y_train[train_index]
    y_test_fold = y_test[test_index]
    print(x_train_fold.shape)
    for epoch in range(epochs):
        ...

The indices for the y_train_fold variable are right, they are simply [ 0 1 2 ... 4497 4498 4499], but those for x_train_fold are not: [ 4500 4501 4502 ... 44997 44998 44999]. The same goes for the test folds.

For the first iteration I want the variable x_train_fold to be the first 4500 pictures, in other words to have the shape torch.Size([4500, 784]), but it has the shape torch.Size([40500, 784]).

Any tips on how to get this right?

Kimmen asked Nov 22 '19 14:11



1 Answer

I think you're confused!

Ignore the second dimension for a moment. When you have 45000 points and use 10-fold cross-validation, what is the size of each fold? 45000/10, i.e. 4500.

This means each fold will contain 4500 data points; one fold is used for testing and the remaining nine for training, i.e.

For testing: one fold => 4500 data points => size: 4500
For training: remaining folds => 45000-4500 data points => size: 45000-4500=40500

Thus, for the first iteration, the first 4500 data points (indices 0 to 4499) will be used for testing and the rest for training. (See the image below.)
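The arithmetic above can be checked without any library. The sketch below is a plain-Python stand-in for an unshuffled k-fold split, where each test fold is a contiguous block of indices. Note that `contiguous_kfold` is a hypothetical helper written only for illustration; it is not sklearn's implementation, although it is meant to mirror what KFold does when shuffling is off.

```python
def contiguous_kfold(n_samples, n_splits):
    """Yield (train_indices, test_indices); each test fold is a contiguous block.

    Hypothetical helper for illustration only, not part of sklearn.
    """
    base, extra = divmod(n_samples, n_splits)
    start = 0
    for fold in range(n_splits):
        # Spread any remainder over the first `extra` folds.
        size = base + (1 if fold < extra else 0)
        stop = start + size
        test = list(range(start, stop))
        # Everything outside the test block is used for training.
        train = list(range(0, start)) + list(range(stop, n_samples))
        yield train, test
        start = stop


splits = list(contiguous_kfold(45000, 10))
first_train, first_test = splits[0]

print(len(first_train), len(first_test))  # 40500 4500
print(first_test[0], first_test[-1])      # 0 4499
print(first_train[0], first_train[-1])    # 4500 44999
```

So in the first iteration the small block (4500 indices starting at 0) is the test fold, and the large block (40500 indices starting at 4500) is the training fold.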

Given that your data is x_train: torch.Size([45000, 784]) and y_train: torch.Size([45000]), this is how your code should look (note that the test folds are also taken from x_train and y_train, not from x_test and y_test):

for train_index, test_index in kfold.split(x_train, y_train):  
    print(train_index, test_index)

    x_train_fold = x_train[train_index] 
    y_train_fold = y_train[train_index] 
    x_test_fold = x_train[test_index] 
    y_test_fold = y_train[test_index] 

    print(x_train_fold.shape, y_train_fold.shape) 
    print(x_test_fold.shape, y_test_fold.shape) 
    break 

[ 4500  4501  4502 ... 44997 44998 44999] [   0    1    2 ... 4497 4498 4499]
torch.Size([40500, 784]) torch.Size([40500])
torch.Size([4500, 784]) torch.Size([4500])

So, when you say

I want the variable x_train_fold to be the first 4500 picture... shape torch.Size([4500, 784]).

you're mistaken: this size corresponds to x_test_fold. In the first iteration, with 10 folds, x_train_fold will have 40500 points, so its size should be torch.Size([40500, 784]).
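To see the same pattern across all ten iterations rather than only the first, here is a minimal sketch that runs the full loop. It makes a few assumptions that are not in the question: small NumPy arrays (100 samples, 4 features) stand in for the 45000 x 784 tensors, and the `times_tested` and `fold_shapes` names are just bookkeeping introduced for this example. The training itself is left out.

```python
import numpy as np
from sklearn.model_selection import KFold

# Stand-in data (assumption): 100 samples with 4 features instead of 45000 x 784.
x = np.arange(100 * 4, dtype=np.float32).reshape(100, 4)
y = np.arange(100) % 10

kfold = KFold(n_splits=10)
times_tested = np.zeros(len(x), dtype=int)  # how often each sample lands in a test fold
fold_shapes = []

for train_index, test_index in kfold.split(x, y):
    # Both folds are sliced from the same arrays.
    x_train_fold, y_train_fold = x[train_index], y[train_index]
    x_test_fold, y_test_fold = x[test_index], y[test_index]

    times_tested[test_index] += 1
    fold_shapes.append((x_train_fold.shape, x_test_fold.shape))
    # ... train on the *_train_fold arrays, evaluate on the *_test_fold arrays ...

print(fold_shapes[0])                            # ((90, 4), (10, 4))
print(times_tested.min(), times_tested.max())    # 1 1
```

Every sample ends up in the test fold exactly once and in the training fold the other nine times, which is why the training fold is always the larger of the two.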

[Image: K-fold validation]

kHarshit answered Nov 20 '22 05:11