
How to use different data augmentation for Subsets in PyTorch

How to use different data augmentation (transforms) for different Subsets in PyTorch?

For instance:

train, test = torch.utils.data.random_split(dataset, [80000, 2000])

Both train and test will use the same transforms as dataset. How can I use a custom transform for each of these subsets?
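A minimal sketch of why this happens, assuming PyTorch is installed. `ToyDataset` is a hypothetical stand-in for a dataset such as ImageFolder that applies `self.transform` inside `__getitem__`. `random_split` returns `Subset` objects whose `.dataset` attribute points to the very same object, so a transform set on it affects both splits.

```python
import torch
from torch.utils.data import Dataset, random_split


class ToyDataset(Dataset):
    """Hypothetical stand-in for ImageFolder: stores raw items and applies
    self.transform (if set) in __getitem__."""

    def __init__(self, data, transform=None):
        self.data = data
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        x = self.data[idx]
        return self.transform(x) if self.transform is not None else x


dataset = ToyDataset(list(range(10)))
train, test = random_split(dataset, [8, 2],
                           generator=torch.Generator().manual_seed(0))

# Both Subsets wrap the same underlying object...
shares_dataset = train.dataset is test.dataset

# ...so setting a transform on it changes what *both* splits return.
dataset.transform = lambda x: x * 100
test_item = test[0]
```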

Asked by Fábio Perez, Aug 10 '18



1 Answer

My current solution is not very elegant, but it works:

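from copy import copy

from torch.utils.data import random_split
from torchvision import transforms

# full_dataset is an ImageFolder, train_size + test_size == len(full_dataset),
# and img_resolution is a (height, width) tuple; all defined elsewhere.
train_dataset, test_dataset = random_split(full_dataset, [train_size, test_size])

# Give the training split its own (shallow) copy of the dataset object so the
# two splits no longer share a single .transform attribute.
train_dataset.dataset = copy(full_dataset)

# Deterministic preprocessing for evaluation
test_dataset.dataset.transform = transforms.Compose([
    transforms.Resize(img_resolution),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

# Random augmentation for training
train_dataset.dataset.transform = transforms.Compose([
    transforms.RandomResizedCrop(img_resolution[0]),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])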

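Basically, I give one of the splits its own copy of the original dataset object, and then set a different transform on each split's underlying dataset. Because copy is shallow, the sample list is shared rather than duplicated; only the transform attribute ends up different.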
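The copy trick can be checked on a toy example. This is a sketch assuming PyTorch is installed; `ToyDataset` is a hypothetical stand-in for ImageFolder, and the lambdas stand in for the real `transforms.Compose([...])` pipelines.

```python
import copy

import torch
from torch.utils.data import Dataset, random_split


class ToyDataset(Dataset):
    """Hypothetical stand-in for ImageFolder: applies self.transform in __getitem__."""

    def __init__(self, data, transform=None):
        self.data = data
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        x = self.data[idx]
        return self.transform(x) if self.transform is not None else x


full_dataset = ToyDataset(list(range(10)))
train_dataset, test_dataset = random_split(
    full_dataset, [8, 2], generator=torch.Generator().manual_seed(0))

# Shallow copy: a new dataset object whose .data list is the same list object.
train_dataset.dataset = copy.copy(full_dataset)

# Assigning .transform rebinds the attribute on each object independently.
train_dataset.dataset.transform = lambda x: x + 1000  # stand-in for augmentation
test_dataset.dataset.transform = lambda x: x          # stand-in for eval transform
```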

Note: setting train_dataset.dataset.transform works because I'm using an ImageFolder dataset, which applies its .transform attribute to each image in __getitem__.
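An alternative that does not depend on the dataset exposing a `.transform` attribute is to keep the base dataset untransformed and wrap each `Subset` in a small dataset that applies its own transform. This is a sketch, assuming PyTorch is installed and that items are `(input, target)` pairs as ImageFolder returns them. `TransformedSubset` is a hypothetical helper, not part of PyTorch, and the lambda stands in for a real augmentation pipeline.

```python
import torch
from torch.utils.data import Dataset, TensorDataset, random_split


class TransformedSubset(Dataset):
    """Hypothetical helper (not part of PyTorch): wraps a Subset and applies
    its own transform to the input of each (input, target) pair."""

    def __init__(self, subset, transform=None):
        self.subset = subset
        self.transform = transform

    def __len__(self):
        return len(self.subset)

    def __getitem__(self, idx):
        x, y = self.subset[idx]
        if self.transform is not None:
            x = self.transform(x)
        return x, y


# Untransformed base dataset of (input, target) pairs; with ImageFolder you
# would construct it with transform=None.
base = TensorDataset(torch.arange(10.0), torch.arange(10))
train_subset, test_subset = random_split(
    base, [8, 2], generator=torch.Generator().manual_seed(0))

# Each split gets its own transform; the base dataset is never mutated.
train_dataset = TransformedSubset(train_subset, transform=lambda x: x + 1000)
test_dataset = TransformedSubset(test_subset, transform=None)
```

Because nothing is copied or mutated, the two splits cannot interfere with each other, and the same wrapper works for any dataset that returns `(input, target)` pairs.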

If anybody knows a better solution, please share with us!

Answered by Fábio Perez, Nov 10 '22