Transforms not applying to the dataset

Tags: python, pytorch

I'm new to pytorch and would like to understand something.

I am loading MNIST as follows:

transform_train = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Resize(size, interpolation=2),
     # transforms.Grayscale(num_output_channels=1),
     transforms.RandomHorizontalFlip(p=0.5),
     transforms.Normalize((mean), (std))])


trainset = torchvision.datasets.MNIST(root='./data', train=True,
                                      download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True, num_workers=2)

However, when I explore the dataset, i.e. trainloader.dataset.train_data[0], I am getting a tensor in range [0,255] with shape (28,28).

What am I missing? Is this because the transforms are not applied directly to the dataloader but only at runtime? How can I explore my data otherwise?

Asked by jerpint

1 Answer

The transforms are applied when the __getitem__ method of the Dataset is called. For example, look at the __getitem__ method of the MNIST dataset class: https://github.com/pytorch/vision/blob/master/torchvision/datasets/mnist.py#L62

def __getitem__(self, index):
    """
    Args:
        index (int): Index
    Returns:
        tuple: (image, target) where target is index of the target class.
    """
    img, target = self.data[index], self.targets[index]

    # doing this so that it is consistent with all other datasets
    # to return a PIL Image
    img = Image.fromarray(img.numpy(), mode='L')

    if self.transform is not None:
        img = self.transform(img)

    if self.target_transform is not None:
        target = self.target_transform(target)

    return img, target

The __getitem__ method gets called when you index your MNIST instance for the training set, e.g.:

trainset[0]

For more information on __getitem__: https://docs.python.org/3.6/reference/datamodel.html#object.getitem
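To make the mechanism concrete, here is a minimal, purely illustrative sketch of the same pattern (ToyDataset and its fake data are made up for this example, not part of torchvision): indexing the object goes through __getitem__, so that is the only place the transform runs.

import torch
from torch.utils.data import Dataset

class ToyDataset(Dataset):
    """Illustrative only: raw data is stored untouched; the transform runs in __getitem__."""
    def __init__(self, transform=None):
        self.data = torch.arange(1, 11)       # raw "samples", stored as-is
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        sample = self.data[index]
        if self.transform is not None:
            sample = self.transform(sample)   # applied only when the item is fetched
        return sample

ds = ToyDataset(transform=lambda x: x * 100)
print(ds.data[0])  # tensor(1)   -> raw storage, transform not applied
print(ds[0])       # tensor(100) -> __getitem__ applies the transform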

The reason Resize and RandomHorizontalFlip should come before ToTensor is that they operate on PIL Images, and for consistency all torchvision datasets load their data as PIL Images first. In fact, you can see that behavior being enforced here:

img = Image.fromarray(img.numpy(), mode='L')

Once you have the PIL Image for the corresponding index, the transforms are applied with

if self.transform is not None:
    img = self.transform(img)

ToTensor converts the PIL Image to a torch.Tensor, scaling the pixel values from the [0, 255] range to [0.0, 1.0] in the process, and Normalize then subtracts the mean and divides by the standard deviation you provide.
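As a rough, self-contained illustration (the tiny 2x2 array and the 0.5/0.5 statistics below are made up for demonstration), ToTensor maps uint8 pixels in [0, 255] to floats in [0.0, 1.0], and Normalize then computes (x - mean) / std per channel:

import numpy as np
from PIL import Image
from torchvision import transforms

# a fake 2x2 grayscale "image" with uint8 pixel values
arr = np.array([[0, 64], [128, 255]], dtype=np.uint8)
pil_img = Image.fromarray(arr, mode='L')

to_tensor = transforms.ToTensor()
normalize = transforms.Normalize((0.5,), (0.5,))

t = to_tensor(pil_img)                   # shape (1, 2, 2), values now in [0.0, 1.0]
print(t.min().item(), t.max().item())    # 0.0 1.0
print(normalize(t))                      # (t - 0.5) / 0.5, values now in [-1.0, 1.0]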

Finally, any target transform is applied to the label with

if self.target_transform is not None:
    target = self.target_transform(target)

The processed image and the processed label are then returned; all of this happens in a single trainset[index] call. Accessing trainloader.dataset.train_data[0] bypasses __getitem__ entirely, which is why you see the raw uint8 tensor in [0, 255] with shape (28, 28). Putting it all together:

import torch
from torchvision.transforms import Compose, Resize, RandomHorizontalFlip, ToTensor, Normalize
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader

transform_train = Compose([Resize(28, interpolation=2),
                           RandomHorizontalFlip(p=0.5),
                           ToTensor(),
                           Normalize([0.], [1.])])

trainset = MNIST(root='./data', train=True, download=True,
                 transform=transform_train)
trainloader = DataLoader(trainset, batch_size=32, shuffle=True, num_workers=2)
print((trainset[0][0].size(), trainset[0][0].min(), trainset[0][0].max()))

shows

(torch.Size([1, 28, 28]), tensor(0.), tensor(1.))
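If the goal is to explore the data with the transforms applied, index the dataset (or iterate over the DataLoader) instead of reading the raw buffer. A short sketch, assuming the trainset and trainloader defined above (in recent torchvision the raw buffer is trainset.data, with train_data kept as a deprecated alias):

# raw storage: uint8 tensor in [0, 255], shape (28, 28), transforms NOT applied
raw = trainset.data[0]
print(raw.dtype, raw.shape, raw.max())

# transformed sample: goes through __getitem__, so Resize/Flip/ToTensor/Normalize all apply
img, label = trainset[0]
print(img.dtype, img.shape, img.min(), img.max())

# or pull a whole transformed batch from the DataLoader
images, labels = next(iter(trainloader))
print(images.shape)    # torch.Size([32, 1, 28, 28])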
Answered by iacolippo