I'm new to PyTorch and would like to understand something.
I am loading MNIST as follows:
transform_train = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Resize(size, interpolation=2),
     # transforms.Grayscale(num_output_channels=1),
     transforms.RandomHorizontalFlip(p=0.5),
     transforms.Normalize((mean), (std))])
trainset = torchvision.datasets.MNIST(root='./data', train=True,
                                      download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True, num_workers=2)
However, when I explore the dataset, e.g. trainloader.dataset.train_data[0], I get a tensor with values in the range [0, 255] and shape (28, 28).
What am I missing? Is this because the transforms are not applied to the stored data but only at runtime? How else can I explore my data?
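The transforms are applied only when the __getitem__ method of the Dataset is called. trainloader.dataset.train_data (renamed data in newer torchvision releases) is the raw uint8 tensor loaded from disk, so indexing it directly bypasses the transforms. For example, look at the __getitem__ method of the MNIST dataset class: https://github.com/pytorch/vision/blob/master/torchvision/datasets/mnist.py#L62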
def __getitem__(self, index):
    """
    Args:
        index (int): Index
    Returns:
        tuple: (image, target) where target is index of the target class.
    """
    img, target = self.data[index], self.targets[index]
    # doing this so that it is consistent with all other datasets
    # to return a PIL Image
    img = Image.fromarray(img.numpy(), mode='L')
    if self.transform is not None:
        img = self.transform(img)
    if self.target_transform is not None:
        target = self.target_transform(target)
    return img, target
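The same lazy pattern can be reproduced without PyTorch. The toy sketch below is hypothetical: the class ToyMNIST, the helper to_unit_range, and the three fake pixel values are made up for illustration and are not part of torchvision. It only mirrors the structure of __getitem__ above: the raw data is stored untouched, and transform / target_transform run only when an item is indexed.

```python
class ToyMNIST:
    """Hypothetical stand-in for torchvision's MNIST: raw storage, lazy transforms."""

    def __init__(self, data, targets, transform=None, target_transform=None):
        self.data = data              # raw 0-255 "pixels"; never modified
        self.targets = targets
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        img, target = self.data[index], self.targets[index]
        # Transforms run here, on every access, and return new values.
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target


def to_unit_range(pixels):
    # Mimics ToTensor's rescaling of 8-bit values from [0, 255] to [0.0, 1.0].
    return [p / 255 for p in pixels]


ds = ToyMNIST(data=[[0, 51, 255]], targets=[7], transform=to_unit_range)
print(ds.data[0])  # raw storage: [0, 51, 255]
print(ds[0])       # transformed on access: ([0.0, 0.2, 1.0], 7)
```

Looking at ds.data is the analogue of looking at train_data in the question: it always shows the untransformed values.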
The __getitem__ method gets called when you index your MNIST training-set instance, e.g.:
trainset[0]
For more information on __getitem__, see: https://docs.python.org/3.6/reference/datamodel.html#object.__getitem__
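As a small illustration of that data-model rule (the class Squares is a made-up example, not from PyTorch): square-bracket indexing is shorthand for calling __getitem__ on the object's type, and a class that defines only __getitem__ is even iterable through Python's legacy sequence protocol, which calls it with 0, 1, 2, ... until IndexError is raised.

```python
class Squares:
    """Hypothetical sequence-like class that computes each item on demand."""

    def __init__(self, n):
        self.n = n

    def __getitem__(self, index):
        if not 0 <= index < self.n:
            # IndexError tells the legacy iteration protocol to stop.
            raise IndexError(index)
        return index * index


sq = Squares(3)
print(sq[2])                        # 4
print(type(sq).__getitem__(sq, 2))  # 4, the explicit form of sq[2]
print(list(sq))                     # [0, 1, 4], built via __getitem__(0), (1), (2)
```

Nothing is computed until an index is requested, which is the same reason MNIST's transforms only show up when you index the dataset.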
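The reason why Resize and RandomHorizontalFlip should come before ToTensor is that they act on PIL Images, and for consistency torchvision's image datasets load their samples as PIL Images first. (Since torchvision 0.8, most transforms, including Resize and RandomHorizontalFlip, also accept tensors, but MNIST still hands a PIL Image to the pipeline.) You can see that MNIST forces this behavior through:
img = Image.fromarray(img.numpy(), mode='L')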
Once you have the PIL Image for the given index, the transforms are applied with:
if self.transform is not None:
    img = self.transform(img)
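Why the order inside Compose matters can be sketched in pure Python. torchvision's Compose applies its transforms one after another, each to the output of the previous one. The classes FakePILImage and FakeTensor and the functions compose, hflip and to_tensor below are hypothetical stand-ins, not torchvision APIs, and the strict type check in hflip models the PIL-only behavior of older torchvision releases.

```python
class FakePILImage:
    """Hypothetical stand-in for a PIL Image: rows of 0-255 ints."""
    def __init__(self, rows):
        self.rows = rows


class FakeTensor:
    """Hypothetical stand-in for a float tensor with values in [0.0, 1.0]."""
    def __init__(self, rows):
        self.rows = rows


def compose(*transforms):
    # Apply each transform to the output of the previous one, left to right.
    def apply(img):
        for t in transforms:
            img = t(img)
        return img
    return apply


def hflip(img):
    # Models an old PIL-only transform: anything but a "PIL image" is rejected.
    if not isinstance(img, FakePILImage):
        raise TypeError(f"hflip expects FakePILImage, got {type(img).__name__}")
    return FakePILImage([row[::-1] for row in img.rows])


def to_tensor(img):
    # Converts the "PIL image" into a "tensor" scaled to [0.0, 1.0].
    return FakeTensor([[p / 255 for p in row] for row in img.rows])


img = FakePILImage([[0, 255]])
good = compose(hflip, to_tensor)
print(good(img).rows)  # [[1.0, 0.0]]

bad = compose(to_tensor, hflip)
bad_error = None
try:
    bad(img)
except TypeError as err:
    bad_error = str(err)
print(bad_error)  # hflip expects FakePILImage, got FakeTensor
```

Putting the image-domain step first and the conversion last is the same ordering the corrected pipeline uses.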
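ToTensor converts the PIL Image (H x W x C, values in [0, 255]) into a torch.FloatTensor of shape (C x H x W) scaled to [0.0, 1.0], and Normalize then subtracts the mean and divides by the standard deviation you provide, channel by channel.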
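The arithmetic of those two steps can be checked by hand. This is a pure-Python sketch of the per-value math only, not torchvision code: the helper names are hypothetical, and mean=0.1307, std=0.3081 are assumed here because they are the values commonly quoted for MNIST, not values taken from the question.

```python
def to_tensor_scale(pixels):
    # ToTensor-style rescaling of 8-bit values: [0, 255] -> [0.0, 1.0].
    return [p / 255 for p in pixels]


def normalize(values, mean, std):
    # Normalize-style standardization applied to each value: (x - mean) / std.
    return [(v - mean) / std for v in values]


raw = [0, 255]                     # darkest and brightest 8-bit pixel
scaled = to_tensor_scale(raw)
normed = normalize(scaled, mean=0.1307, std=0.3081)
identity = normalize(scaled, mean=0.0, std=1.0)

print(scaled)                          # [0.0, 1.0]
print([round(v, 4) for v in normed])   # [-0.4242, 2.8215]
print(identity)                        # [0.0, 1.0]
```

With mean 0 and std 1, as in Normalize([0.], [1.]), normalization leaves the [0.0, 1.0] range from ToTensor unchanged.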
Next, if a target_transform was given, it is applied to the label:
if self.target_transform is not None:
    target = self.target_transform(target)
Finally, the processed image and the processed label are returned. All of this happens in a single trainset[index] call.
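from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import (Compose, InterpolationMode, Normalize,
                                    RandomHorizontalFlip, Resize, ToTensor)
# PIL-based transforms first, then ToTensor, then Normalize (which needs a tensor).
# InterpolationMode requires torchvision >= 0.9; older versions used interpolation=2 (PIL bilinear).
transform_train = Compose([Resize(28, interpolation=InterpolationMode.BILINEAR),
                           RandomHorizontalFlip(p=0.5),
                           ToTensor(),
                           Normalize([0.], [1.])])  # mean 0, std 1: leaves values unchanged
trainset = MNIST(root='./data', train=True, download=True,
                 transform=transform_train)
trainloader = DataLoader(trainset, batch_size=32, shuffle=True, num_workers=2)
# Indexing the dataset calls __getitem__, which runs the whole transform pipeline
img, label = trainset[0]
print(img.size(), img.min(), img.max())
shows
torch.Size([1, 28, 28]) tensor(0.) tensor(1.)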