
PyTorch - Cannot slice torchvision MNIST dataset

In PyTorch, when using torchvision's MNIST dataset, we can get a single digit as follows:

from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Dataset, TensorDataset

tsfm = transforms.Compose([transforms.Resize((16, 16)),
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))])

mnist_ds = datasets.MNIST(root='../../../_data/mnist',train=True,download=True,
                          transform=tsfm)

digit_12 = mnist_ds[12]

Although many datasets support slicing, this one does not:

>>> digit_12_to_14 = mnist_ds[12:15]
ValueError: Too many dimensions: 3 > 2.

This is due to an Image.fromarray() call in the dataset's __getitem__() method.

Is it possible to use the MNIST dataset without using a DataLoader?
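
The "3 > 2" in the error message can be illustrated with a small sketch. It assumes the dataset keeps its raw images in one tensor of shape (N, 28, 28) and passes whatever the index selects to Image.fromarray() as a single-channel image; the `raw` tensor below is a hypothetical stand-in, not the real MNIST data.

```python
import torch

# Hypothetical stand-in for the raw image tensor held by the dataset
# (60 blank 28x28 images instead of the real MNIST data).
raw = torch.zeros(60, 28, 28, dtype=torch.uint8)

single = raw[12]     # integer index: one image with 2 dimensions
batch = raw[12:15]   # slice: a stack of three images with 3 dimensions

print(single.dim(), tuple(single.shape))
print(batch.dim(), tuple(batch.shape))
```

Under that assumption, an integer index yields a 2-D array that converts cleanly to one grayscale image, while a slice yields a 3-D array that cannot be interpreted as a single grayscale image.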


PS: I would like to avoid the DataLoader because sending batches to the GPU one at a time slows down training. I would rather send the entire dataset to the GPU at once, and for that I need access to the whole transformed dataset.

u2gilles asked Oct 17 '22 08:10



1 Answer

You can use torch.utils.data.Subset() to get an index-based slice of a torch Dataset, e.g.:

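import torch
import torch.utils.data as data_utils

# Indices 12, 13 and 14 of the original dataset
indices = torch.arange(12, 15)
# mnist_ds is the dataset defined in the question
mnist_12to14 = data_utils.Subset(mnist_ds, indices)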
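
A Subset still returns one (image, label) pair per index, so to hold the data as a single tensor (for example to move it to the GPU in one go) the items can be stacked once up front. The following is a minimal sketch under stated assumptions: `ToyDigits` is a hypothetical stand-in that, like the MNIST dataset, only accepts integer indices, and `materialize` is an illustrative helper, not part of torch or torchvision.

```python
import torch
from torch.utils.data import Dataset, Subset


class ToyDigits(Dataset):
    """Hypothetical stand-in for the transformed dataset: integer indexing only."""

    def __init__(self, n=20):
        # Image i is a 1x16x16 tensor filled with the value i.
        self.images = torch.zeros(n, 1, 16, 16) + torch.arange(n).view(n, 1, 1, 1)
        self.labels = [i % 10 for i in range(n)]

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, index):
        if isinstance(index, slice):
            raise TypeError("slicing is not supported")
        index = int(index)
        return self.images[index], self.labels[index]


def materialize(dataset, device="cpu"):
    """Stack every (image, label) pair of a map-style dataset into two tensors."""
    images, labels = zip(*(dataset[i] for i in range(len(dataset))))
    return torch.stack(images).to(device), torch.tensor(labels).to(device)


subset = Subset(ToyDigits(), range(12, 15))
x, y = materialize(subset)
print(x.shape, y.tolist())
```

The same helper could in principle be applied to the full dataset with device="cuda", assuming the transformed data fits in GPU memory. The transforms then run once, while the tensors are being built, rather than on every epoch.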
iacob answered Oct 21 '22 02:10
