Train-Valid-Test split for custom dataset using PyTorch and TorchVision

I have some image data for a binary classification task and the images are organised into 2 folders as data/model_data/class-A and data/model_data/class-B.

There are a total of N images. I want to have a 70/20/10 split for train/val/test. I am using PyTorch and Torchvision for the task. Here is the code I have so far.

import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils, datasets, models

root = "data/model_data"  # contains the class-A/ and class-B/ folders
BATCH_SIZE = 32  # placeholder value
NUM_WORKER = 4  # placeholder value

data_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])

# This single transform (random augmentations included) ends up applied to every split
model_dataset = datasets.ImageFolder(root, transform=data_transform)
total_count = len(model_dataset)  # N images in total
train_count = int(0.7 * total_count)
valid_count = int(0.2 * total_count)
test_count = total_count - train_count - valid_count
train_dataset, valid_dataset, test_dataset = torch.utils.data.random_split(model_dataset, (train_count, valid_count, test_count))
train_dataset_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKER)
valid_dataset_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKER)
test_dataset_loader = torch.utils.data.DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKER)
dataloaders = {'train': train_dataset_loader, 'val': valid_dataset_loader, 'test': test_dataset_loader}

I feel that this isn't the correct way to do it, for two reasons:

  • I am applying the same transform to all the splits. (This is not what I want to do, obviously! The solution for this is most probably the answer here.)
  • Usually people first separate the original data into test/train and then they separate train into train/val, whereas I am directly separating the original data into train/val/test. (Is this correct?)
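The two-stage split described in the second bullet can be sketched with nested random_split calls. This is a minimal sketch on a hypothetical 100-sample TensorDataset standing in for the images; the seeded torch.Generator is an addition for reproducibility. Note that the validation size has to be computed from the original total (20 of 100), not from the 90 remaining samples, to keep the 70/20/10 ratio:

```python
import torch
from torch.utils.data import TensorDataset, random_split

# Hypothetical stand-in for the image dataset: 100 samples holding their own index.
dataset = TensorDataset(torch.arange(100))
gen = torch.Generator().manual_seed(0)  # fixed seed -> reproducible splits

# Stage 1: hold out 10% of the *original* data as the test set.
test_count = int(0.1 * len(dataset))
rest, test_set = random_split(
    dataset, [len(dataset) - test_count, test_count], generator=gen
)

# Stage 2: split the remaining 90 samples so that val is 20% of the original total.
valid_count = int(0.2 * len(dataset))
train_set, valid_set = random_split(
    rest, [len(rest) - valid_count, valid_count], generator=gen
)

print(len(train_set), len(valid_set), len(test_set))  # 70 20 10
```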

So, my question is: is what I am doing correct? (Probably not.)
And if it is not correct, how do I write the data loaders to get the required splits, so that I can apply separate transforms to each of train/val/test?

asked May 15 '20 04:05 by iamshnoo




1 Answer

Usually people first separate the original data into test/train and then they separate train into train/val, whereas I am directly separating the original data into train/val/test. (Is this correct?)

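Yes, it's fully correct, readable, and totally fine. A single random_split call draws one random permutation of the indices and cuts it into three disjoint pieces, which is statistically equivalent to first splitting off the test set and then splitting the rest into train/val.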
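A quick way to convince yourself: the sketch below uses a hypothetical 100-sample TensorDataset (with a seeded generator added for reproducibility) and checks that the three subsets returned by one random_split call have the requested sizes, are pairwise disjoint, and together cover every sample:

```python
import torch
from torch.utils.data import TensorDataset, random_split

total_count = 100  # hypothetical dataset size
dataset = TensorDataset(torch.arange(total_count))

train_count = int(0.7 * total_count)
valid_count = int(0.2 * total_count)
test_count = total_count - train_count - valid_count  # remainder goes to test

splits = random_split(
    dataset,
    (train_count, valid_count, test_count),
    generator=torch.Generator().manual_seed(42),
)

# Each split is a Subset; .indices holds the positions drawn from `dataset`.
index_sets = [set(s.indices) for s in splits]
print([len(s) for s in index_sets])                       # [70, 20, 10]
print(set.union(*index_sets) == set(range(total_count)))  # True: nothing lost
print(all(a.isdisjoint(b)
          for i, a in enumerate(index_sets)
          for b in index_sets[i + 1:]))                   # True: no overlap
```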

I am applying the same transform to all the splits. (This is not what I want to do, obviously! The solution for this is most probably the answer here.)

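Yes, that answer is one possibility, but it's needlessly verbose, to be honest. You can use the third-party library torchdata instead, installable with:

pip install torchdata

Documentation can be found here (disclaimer: I'm the author). Note that the PyPI name torchdata now belongs to PyTorch's own data-loading library; this project appears to have been renamed torchdatasets (pip install torchdatasets, imported as torchdatasets).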

It lets you easily map transformations over any torch.utils.data.Dataset, such as the training split. Your code would look like this (check the comments for what changed; I also reformatted it to make it easier to follow):

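import functools

import torch
import torchvision

import torchdata as td

root = "data/model_data"  # from the question
BATCH_SIZE = 32  # placeholder value
NUM_WORKER = 4  # placeholder value

normalize = torchvision.transforms.Normalize(
    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
)
data_transform = torchvision.transforms.Compose(
    [
        torchvision.transforms.RandomResizedCrop(224),
        torchvision.transforms.RandomHorizontalFlip(),
        torchvision.transforms.ToTensor(),
        normalize,
    ]
)
# Validation/test still need deterministic preprocessing: without ToTensor they
# would yield PIL images, which DataLoader's default collate cannot batch
eval_transform = torchvision.transforms.Compose(
    [
        torchvision.transforms.Resize(256),
        torchvision.transforms.CenterCrop(224),
        torchvision.transforms.ToTensor(),
        normalize,
    ]
)

# No transform here; each split gets its own transform below
model_dataset = torchvision.datasets.ImageFolder(root)
total_count = len(model_dataset)
train_count = int(0.7 * total_count)
valid_count = int(0.2 * total_count)
test_count = total_count - train_count - valid_count
train_dataset, valid_dataset, test_dataset = torch.utils.data.random_split(
    model_dataset, (train_count, valid_count, test_count)
)


def transform_image(transform, sample):
    # ImageFolder yields (image, label) pairs; transform only the image
    image, label = sample
    return transform(image), label


# random_split returns plain torch.utils.data.Subset objects (no .map), so wrap
# each split in torchdata's Dataset first. It works just like PyTorch's
# torch.utils.data.Dataset, but adds .map, caching etc. (see project's description)
train_dataset = td.datasets.WrapDataset(train_dataset).map(
    functools.partial(transform_image, data_transform)
)
valid_dataset = td.datasets.WrapDataset(valid_dataset).map(
    functools.partial(transform_image, eval_transform)
)
test_dataset = td.datasets.WrapDataset(test_dataset).map(
    functools.partial(transform_image, eval_transform)
)

# Rest of the code stays the same
train_dataset_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKER
)
valid_dataset_loader = torch.utils.data.DataLoader(
    valid_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKER
)
test_dataset_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKER
)
dataloaders = {
    "train": train_dataset_loader,
    "val": valid_dataset_loader,
    "test": test_dataset_loader,
}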

And yes, I agree that passing the transform before splitting isn't very clear; IMO this is much more readable.

answered Sep 18 '22 13:09 by Szymon Maszke