 

What are transforms in PyTorch used for?


I am new to PyTorch and not an expert in CNNs. I have built a working classifier by following the tutorial they provide (Tutorial Pytorch), but I don't really understand what I am doing when loading the data.

They do some data augmentation and normalisation for training, but when I try to modify the parameters, the code does not work.

```python
from torchvision import transforms

# Data augmentation and normalization for training
# Just normalization for validation
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}
```

Am I extending my training dataset? I don't see the data augmentation.

Why does data loading stop working if I modify the value in transforms.RandomResizedCrop(224)?

Do I also need to transform the test dataset?

I am a bit confused with this data transformation that they do.

carioka88 asked Apr 24 '18 13:04


People also ask

Which transform converts an image into a PyTorch tensor?

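To convert an image to a tensor in PyTorch, use the PILToTensor() or ToTensor() transforms from the torchvision.transforms package. ToTensor() accepts a PIL image or a numpy.ndarray (H x W x C) and returns a float tensor (C x H x W), scaling uint8 values to [0, 1]; PILToTensor() accepts only a PIL image and returns a tensor with the image's original dtype (usually uint8) and no rescaling.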

What does transforms.Normalize do in PyTorch?

Normalization in PyTorch is done with torchvision.transforms.Normalize(mean, std). It normalizes a tensor image channel by channel with the given mean and standard deviation: output[channel] = (input[channel] - mean[channel]) / std[channel].

What does transforms.CenterCrop do?

CenterCrop(size) crops the given image (a PIL image or a tensor) at the center. A related transform from the same part of the docs, ColorJitter([brightness, contrast, …]), randomly changes the brightness, contrast, saturation and hue of an image.

Are PyTorch transforms differentiable?

Transforms are also used to normalize data or to create data augmentation (by flipping images, for example). But the most important point is that they are normally applied while loading data, outside the model's autograd graph, and transforms that operate on PIL images are not differentiable at all, so they are not meant to be part of a NN: you could not backpropagate errors through them.


2 Answers

transforms.Compose simply chains the transforms passed to it: they are applied to the input one by one, in order, with each transform receiving the output of the previous one.

Train transforms

  1. transforms.RandomResizedCrop(224): This crops a random region of your input image (by default covering 8% to 100% of the original area, with an aspect ratio between 3/4 and 4/3) and resizes it to (224, 224). The region may come from the top-left, the bottom-right, or anywhere in between, so this step is data augmentation. Also, changing this value changes the size of the images fed to your model, which may not play nicely with its fully-connected layers (they expect features of a fixed size), so changing it is not advised.
  2. transforms.RandomHorizontalFlip(): Once we have our (224, 224) image, it is flipped horizontally with probability 0.5 (the default). This is another part of data augmentation.
  3. transforms.ToTensor(): This converts the PIL image (H x W x C, values 0-255) into a PyTorch float tensor (C x H x W) with values scaled to [0, 1].
  4. transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]): This scales the input data. These values are the per-channel mean and standard deviation of the ImageNet training set, which the pretrained model used in the tutorial was trained with. Changing these values is also not advised.

Validation transforms

  1. transforms.Resize(256): First, your input image is resized so that its shorter side is 256 pixels; the aspect ratio is preserved.
  2. transforms.CenterCrop(224): Crops a (224, 224) patch from the center of the image.

The rest (ToTensor and Normalize) are the same as for training.

P.S.: You can read more about these transformations in the official docs

layog answered Sep 30 '22 07:09


For questions about how data augmentation works, I would refer you to this answer:

Data Augmentation in PyTorch

In short, assume your only transform is a random horizontal flip. When you iterate through a dataset of images, some are returned as originals and some are returned flipped (the originals of the flipped ones are not returned). In other words, one iteration over the dataset returns as many images as the original dataset contains, so the dataset is not enlarged. The augmentation happens on the fly instead: since the flip is drawn anew each time an image is loaded, the model sees different versions of the same images across epochs.

Ashkan372 answered Sep 30 '22 05:09