Say I am loading MNIST from torchvision.datasets.MNIST, but I only want to load 10,000 images in total. How would I slice the dataset to limit it to that number of data points? I understand that the DataLoader is a generator yielding data in the size of the specified batch size, but how do you slice datasets?
from torchvision import datasets
from torch.utils.data import DataLoader

tr = datasets.MNIST('../data', train=True, download=True, transform=transform)
te = datasets.MNIST('../data', train=False, transform=transform)
train_loader = DataLoader(tr, batch_size=args.batch_size, shuffle=True, num_workers=4, **kwargs)
test_loader = DataLoader(te, batch_size=args.batch_size, shuffle=True, num_workers=4, **kwargs)
Dataset stores the samples and their corresponding labels, and DataLoader wraps an iterable around the Dataset to enable easy access to the samples.
torch.utils.data.Sampler classes are used to specify the sequence of indices/keys used in data loading; they are iterables over dataset indices.
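As a minimal sketch (assuming the tr dataset and args.batch_size from the question), a SubsetRandomSampler can restrict the loader to the first 10,000 indices without building a new dataset:

import torch
from torch.utils.data import DataLoader, SubsetRandomSampler

# draw batches only from the first 10,000 indices of the full training set
sampler = SubsetRandomSampler(torch.arange(10000))

# note: shuffle must be left off (default False) when a sampler is supplied
train_loader = DataLoader(tr, batch_size=args.batch_size, sampler=sampler, num_workers=4)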
You can use torch.utils.data.Subset()
e.g. for the first 10,000 elements:
import torch
import torch.utils.data as data_utils

# Subset exposes only the given indices of the wrapped dataset
indices = torch.arange(10000)
tr_10k = data_utils.Subset(tr, indices)
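The subset behaves like any other Dataset, so it can be passed straight to a DataLoader, for example reusing the batch size and worker settings from the question (those settings are assumptions here):

train_loader_10k = DataLoader(tr_10k, batch_size=args.batch_size, shuffle=True, num_workers=4, **kwargs)
# len(tr_10k) == 10000, and len(train_loader_10k) == ceil(10000 / args.batch_size)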