 

Taking subsets of a pytorch dataset


I have a network which I want to train on some dataset (as an example, say CIFAR10). I can create a data loader object via

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                          shuffle=True, num_workers=2)

My question is as follows: Suppose I want to make several different training iterations. Let's say I want at first to train the network on all images in odd positions, then on all images in even positions, and so on. In order to do that, I need to be able to access those images. Unfortunately, it seems that trainset does not allow such access. That is, trying to do trainset[:1000] or, more generally, trainset[mask] will throw an error.

I could do instead

trainset.train_data = trainset.train_data[mask]
trainset.train_labels = trainset.train_labels[mask]

and then

trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                          shuffle=True, num_workers=2)

However, that will force me to create a new copy of the full dataset in each iteration (since I have already changed trainset.train_data, I will need to redefine trainset). Is there some way to avoid this?

Ideally, I would like to have something "equivalent" to

trainloader = torch.utils.data.DataLoader(trainset[mask], batch_size=4,
                                          shuffle=True, num_workers=2)
Asked Nov 22 '17 by Miriam Farber


People also ask

What is RandomSampler in PyTorch?

A Sampler that returns random indices. In the C++ API, RandomSampler(int64_t size, Dtype index_dtype = torch::kInt64) constructs a RandomSampler with a size and a dtype for the stored indices. The constructor eagerly allocates all required indices, which is the sequence 0 ...
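On the Python side, the equivalent lives in torch.utils.data. Below is a minimal sketch (the TensorDataset of random numbers is hypothetical toy data, not from the question) showing how a RandomSampler passed to a DataLoader stands in for shuffle=True:

import torch
from torch.utils.data import DataLoader, RandomSampler, TensorDataset

# Toy dataset: 100 feature vectors with integer labels (illustrative only).
dataset = TensorDataset(torch.randn(100, 3), torch.randint(0, 10, (100,)))

# RandomSampler yields indices in random order; passing a sampler to
# DataLoader replaces shuffle=True (the two options are mutually exclusive).
sampler = RandomSampler(dataset)
loader = DataLoader(dataset, batch_size=4, sampler=sampler)

for features, labels in loader:
    pass  # each epoch draws batches in a freshly randomized order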

What is Collate_fn PyTorch?

A custom collate_fn can be used to customize collation, e.g., padding sequential data to the maximum length within a batch. collate_fn is called with a list of data samples each time and is expected to collate the input samples into a batch for yielding from the data loader iterator.
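As a sketch of what such a collate_fn might look like, here is padding of variable-length sequences with torch.nn.utils.rnn.pad_sequence; the sequence data and the pad_collate name are illustrative assumptions, not part of the question:

import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader

# Hypothetical variable-length sequences (a plain list works as a map-style dataset).
sequences = [torch.arange(n) for n in (3, 5, 2, 7)]

def pad_collate(batch):
    # batch is a list of samples; pad them to the length of the longest one.
    return pad_sequence(batch, batch_first=True, padding_value=0)

loader = DataLoader(sequences, batch_size=2, collate_fn=pad_collate)
for padded_batch in loader:
    print(padded_batch.shape)  # e.g. torch.Size([2, 5]) for the first batch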

What is Num_workers PyTorch?

num_workers denotes the number of worker processes that generate batches in parallel. A high enough number of workers ensures that CPU-side work is managed efficiently, i.e. that the bottleneck is the neural network's forward and backward passes on the GPU rather than data loading.
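A minimal sketch of setting num_workers follows; the dataset is made-up toy data, and the __main__ guard is included because it is required when num_workers > 0 on platforms that spawn worker processes (e.g. Windows):

import torch
from torch.utils.data import DataLoader, TensorDataset

def main():
    # Toy image-shaped data standing in for a real dataset.
    dataset = TensorDataset(torch.randn(1000, 3, 32, 32),
                            torch.randint(0, 10, (1000,)))
    # Two worker processes prepare upcoming batches in the background while
    # the model consumes the current one; tune num_workers to your CPU.
    loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=2)
    for images, labels in loader:
        pass

if __name__ == '__main__':
    main()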


1 Answer

torch.utils.data.Subset is easier, supports shuffle, and doesn't require writing your own sampler:

import torchvision
import torch

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=None)

evens = list(range(0, len(trainset), 2))
odds = list(range(1, len(trainset), 2))
trainset_1 = torch.utils.data.Subset(trainset, evens)
trainset_2 = torch.utils.data.Subset(trainset, odds)

trainloader_1 = torch.utils.data.DataLoader(trainset_1, batch_size=4,
                                            shuffle=True, num_workers=2)
trainloader_2 = torch.utils.data.DataLoader(trainset_2, batch_size=4,
                                            shuffle=True, num_workers=2)
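For comparison only (this is an alternative sketch, not part of the answer above), the same even/odd split can also be expressed with torch.utils.data.SubsetRandomSampler, which shuffles within the chosen indices, so shuffle=True is dropped:

import torch
from torch.utils.data import DataLoader, SubsetRandomSampler

# Assumes `trainset`, `evens` and `odds` defined as in the answer above.
sampler_1 = SubsetRandomSampler(evens)
sampler_2 = SubsetRandomSampler(odds)

# sampler and shuffle are mutually exclusive, so shuffle is omitted here.
trainloader_1 = DataLoader(trainset, batch_size=4, sampler=sampler_1, num_workers=2)
trainloader_2 = DataLoader(trainset, batch_size=4, sampler=sampler_2, num_workers=2)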
Answered Sep 17 '22 by jayelm