 

How do you alter the size of a PyTorch Dataset?

Say I am loading MNIST from torchvision.datasets.MNIST, but I only want to load 10,000 images in total. How would I slice the dataset to limit it to only that many data points? I understand that the DataLoader is a generator that yields data in chunks of the specified batch size, but how do you slice a Dataset itself?

from torchvision import datasets
from torch.utils.data import DataLoader

# `transform`, `args.batch_size`, and `kwargs` are defined elsewhere in the script
tr = datasets.MNIST('../data', train=True, download=True, transform=transform)
te = datasets.MNIST('../data', train=False, transform=transform)
train_loader = DataLoader(tr, batch_size=args.batch_size, shuffle=True, num_workers=4, **kwargs)
test_loader = DataLoader(te, batch_size=args.batch_size, shuffle=True, num_workers=4, **kwargs)
asked Jul 01 '17 by mikal94305


People also ask

What is the difference between a PyTorch Dataset and a PyTorch DataLoader?

Dataset stores the samples and their corresponding labels, and DataLoader wraps an iterable around the Dataset to enable easy access to the samples.
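As a quick illustration (a minimal sketch reusing the tr MNIST dataset from the question), indexing the Dataset returns a single sample, while iterating the DataLoader yields whole batches:

from torch.utils.data import DataLoader

image, label = tr[0]                      # one sample: a 1x28x28 tensor and an int label
loader = DataLoader(tr, batch_size=64, shuffle=True)
images, labels = next(iter(loader))       # one batch: 64x1x28x28 images and 64 labels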

What is torch.utils.data.Sampler?

torch.utils.data.Sampler classes are used to specify the sequence of indices/keys used in data loading. They represent iterable objects over the indices to datasets.
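For example, a SubsetRandomSampler restricts a DataLoader to a fixed set of indices, which is another way to cap how many samples are drawn (a sketch, again assuming the tr dataset from the question):

from torch.utils.data import DataLoader, SubsetRandomSampler

# draw batches only from the first 10,000 indices, in random order
sampler = SubsetRandomSampler(range(10000))
loader = DataLoader(tr, batch_size=64, sampler=sampler)  # a sampler replaces shuffle=True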


1 Answer

You can use torch.utils.data.Subset(), e.g. to take the first 10,000 elements:

import torch
import torch.utils.data as data_utils

indices = torch.arange(10000)          # indices of the first 10,000 samples
tr_10k = data_utils.Subset(tr, indices)
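The Subset behaves like any other Dataset, so it can be passed straight to a DataLoader (a sketch reusing the loader arguments from the question):

from torch.utils.data import DataLoader

train_loader_10k = DataLoader(tr_10k, batch_size=args.batch_size, shuffle=True, num_workers=4, **kwargs)
len(tr_10k)   # 10000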
answered Oct 05 '22 by iacob