 

Validation dataset in PyTorch using DataLoaders

I want to load the MNIST dataset with PyTorch and Torchvision and divide it into training, validation, and test parts. So far I have:

import torch
import torchvision

def load_dataset(batch_size_train, batch_size_test):
    # Convert PIL images to float tensors in [0, 1]
    transform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor()])

    train_loader = torch.utils.data.DataLoader(
        torchvision.datasets.MNIST(
            '/data/', train=True, download=True, transform=transform),
        batch_size=batch_size_train, shuffle=True)

    test_loader = torch.utils.data.DataLoader(
        torchvision.datasets.MNIST(
            '/data/', train=False, download=True, transform=transform),
        batch_size=batch_size_test, shuffle=True)

    return train_loader, test_loader

How can I divide the training dataset into training and validation sets once it is wrapped in a DataLoader? I want to use the last 10,000 examples of the training dataset as the validation set. (I know I should use cross-validation for more accurate results; I just want a quick validation here.)

qalis asked Sep 27 '20 19:09




2 Answers

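Splitting the training dataset into training and validation sets in PyTorch turns out to be harder than it should be.

First, split the training set into training and validation subsets with random_split. These are Subset objects: a Subset is itself a Dataset, but it only stores a reference to the original dataset plus a list of indices, so no data is copied: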

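train = torchvision.datasets.MNIST('/data/', train=True, download=True,
                                   transform=torchvision.transforms.ToTensor())
train_subset, val_subset = torch.utils.data.random_split(
    train, [50000, 10000], generator=torch.Generator().manual_seed(1))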
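To see what random_split hands back, here is a small check on a toy stand-in dataset (an assumption, so it runs without downloading MNIST; seeded_split is a hypothetical helper). It shows that the results are Subset instances (and therefore Dataset instances), that a fixed-seed generator reproduces the same split, and that the two subsets do not overlap.

```python
import torch
from torch.utils.data import Dataset, Subset, TensorDataset, random_split

# Toy stand-in for the MNIST training set (assumption): 100 random 1x28x28 "images"
images = torch.rand(100, 1, 28, 28)
labels = torch.arange(100) % 10
toy_train = TensorDataset(images, labels)


def seeded_split(dataset, seed=1):
    # A fixed-seed generator makes random_split pick the same permutation every call
    return random_split(dataset, [80, 20], generator=torch.Generator().manual_seed(seed))


train_a, val_a = seeded_split(toy_train)
train_b, val_b = seeded_split(toy_train)

print(isinstance(train_a, Subset), isinstance(train_a, Dataset))  # a Subset is a Dataset
print(train_a.indices == train_b.indices)                        # same seed, same split
print(set(train_a.indices).isdisjoint(val_a.indices))            # no overlap between splits
```

Passing the generator is what makes the split repeatable across runs; without it, random_split draws from the global RNG.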

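If you need the raw tensors, index the underlying dataset with each subset's indices. Note that for MNIST, .data holds the raw uint8 images and .targets the labels, so transforms such as ToTensor() are not applied: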

X_train = train_subset.dataset.data[train_subset.indices]
y_train = train_subset.dataset.targets[train_subset.indices]

X_val = val_subset.dataset.data[val_subset.indices]
y_val = val_subset.dataset.targets[val_subset.indices]
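If you want the transformed tensors rather than the raw uint8 pixels, one option (a sketch, not the only way) is to pull the whole subset through a DataLoader as a single batch, so every item goes through the dataset's transform. FakeMNIST below is a hypothetical stand-in that mimics torchvision's MNIST layout (raw .data, integer .targets, transform applied in __getitem__) so the example runs offline; materialize is likewise an illustrative helper.

```python
import torch
from torch.utils.data import DataLoader, Dataset, Subset


class FakeMNIST(Dataset):
    """Hypothetical stand-in mimicking torchvision's MNIST: raw uint8 .data,
    integer .targets, and a transform applied only in __getitem__."""

    def __init__(self, n=50):
        self.data = torch.randint(0, 256, (n, 28, 28), dtype=torch.uint8)
        self.targets = torch.arange(n) % 10

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        # Roughly what ToTensor() does for a grayscale image: add a channel dim, scale to [0, 1]
        return self.data[i].unsqueeze(0).float() / 255, int(self.targets[i])


def materialize(subset):
    # One batch holding the whole subset, so the transform runs on every item
    loader = DataLoader(subset, batch_size=len(subset), shuffle=False)
    return next(iter(loader))


ds = FakeMNIST()
val_subset = Subset(ds, list(range(40, 50)))

raw = val_subset.dataset.data[val_subset.indices]  # uint8, shape (10, 28, 28), no transform
X_val, y_val = materialize(val_subset)             # float32, shape (10, 1, 28, 28)
print(raw.dtype, tuple(raw.shape))
print(X_val.dtype, tuple(X_val.shape))
```

With the real MNIST training set this loads every image of the subset into memory at once, which is fine for 10,000 small images but worth keeping in mind for larger datasets.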

Note that this gives plain tensors rather than Dataset objects, so you lose the dataset's transforms and its pairing of images with labels. For batched training it is simpler to pass the Subsets to DataLoader directly, since a Subset is itself a Dataset:

from torch.utils.data import DataLoader
BATCH_SIZE = 64  # example value
train_loader = DataLoader(dataset=train_subset, shuffle=True, batch_size=BATCH_SIZE)
val_loader = DataLoader(dataset=val_subset, shuffle=False, batch_size=BATCH_SIZE)
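random_split picks a random 10,000 examples. If you want exactly the last 10,000 examples, as the question asks, you can build the Subsets from index ranges instead. A minimal sketch, where split_last is a hypothetical helper and a 60-example toy dataset stands in for MNIST; with the real data you would call split_last(train, 10_000) to get 50,000 training and 10,000 validation examples.

```python
import torch
from torch.utils.data import DataLoader, Subset, TensorDataset


def split_last(dataset, n_val):
    """Hypothetical helper: first len(dataset) - n_val examples for training,
    last n_val examples for validation, with no shuffling of the split itself."""
    n_train = len(dataset) - n_val
    train_subset = Subset(dataset, list(range(n_train)))
    val_subset = Subset(dataset, list(range(n_train, len(dataset))))
    return train_subset, val_subset


# Toy dataset of 60 examples (assumption): feature i is just the float value i
toy = TensorDataset(torch.arange(60, dtype=torch.float32).unsqueeze(1), torch.arange(60) % 10)
train_subset, val_subset = split_last(toy, n_val=10)

train_loader = DataLoader(train_subset, batch_size=16, shuffle=True)  # reshuffled every epoch
val_loader = DataLoader(val_subset, batch_size=16, shuffle=False)     # fixed order for evaluation

print(len(train_subset), len(val_subset))
```

Only the DataLoader shuffles here; the split itself is deterministic, so the validation set is the same examples on every run.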
qalis answered Oct 08 '22 16:10


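If you'd like your splits to keep the class proportions of the full training set (a stratified split), you can use train_test_split from sklearn with its stratify argument and wrap the resulting indices in Subsets.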

import torchvision
from torch.utils.data import DataLoader, Subset
from sklearn.model_selection import train_test_split

VAL_SIZE = 0.1
BATCH_SIZE = 64

mnist_train = torchvision.datasets.MNIST(
    '/data/',
    train=True,
    download=True,
    transform=torchvision.transforms.Compose([torchvision.transforms.ToTensor()])
)
mnist_test = torchvision.datasets.MNIST(
    '/data/',
    train=False,
    download=True,
    transform=torchvision.transforms.Compose([torchvision.transforms.ToTensor()])
)

# generate indices: we split integer indices rather than the images themselves,
# stratifying on the labels so each class keeps its proportion in both splits
train_indices, val_indices = train_test_split(
    list(range(len(mnist_train))),
    stratify=mnist_train.targets.numpy(),
    test_size=VAL_SIZE,
    random_state=0,  # fixed seed for a reproducible split
)

# generate subsets based on indices
train_split = Subset(mnist_train, train_indices)
val_split = Subset(mnist_train, val_indices)

# create batches; shuffling only matters for training
train_batches = DataLoader(train_split, batch_size=BATCH_SIZE, shuffle=True)
val_batches = DataLoader(val_split, batch_size=BATCH_SIZE, shuffle=False)
test_batches = DataLoader(mnist_test, batch_size=BATCH_SIZE, shuffle=False)
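To confirm that a split really is stratified, you can compare per-class counts in each part with torch.bincount. The sketch below uses hypothetical, deliberately imbalanced labels (900 of class 0, 100 of class 1) instead of MNIST so the effect is easy to see; class_fractions is an illustrative helper.

```python
import torch
from sklearn.model_selection import train_test_split

# Hypothetical imbalanced labels (assumption): 900 examples of class 0, 100 of class 1
targets = torch.cat([torch.zeros(900, dtype=torch.long), torch.ones(100, dtype=torch.long)])

train_idx, val_idx = train_test_split(
    list(range(len(targets))),
    stratify=targets.numpy(),  # keep the 90/10 class ratio in both parts
    test_size=0.1,
    random_state=0,
)


def class_fractions(indices):
    # Fraction of each class among the given indices
    counts = torch.bincount(targets[indices], minlength=2)
    return (counts.double() / counts.sum()).tolist()


print(class_fractions(train_idx))
print(class_fractions(val_idx))
```

Both parts should report roughly a 0.9 / 0.1 ratio; with a plain random split the small class can end up over- or under-represented in the validation set.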
Eric answered Oct 08 '22 16:10