How to use different data augmentation (transforms) for different Subsets in PyTorch?
For instance:
train, test = torch.utils.data.random_split(dataset, [80000, 2000])
Both train and test will have the same transforms as dataset. How can I use custom transforms for each of these subsets?
The Dataset class is an abstract class used to define new (custom) types of datasets. TensorDataset, by contrast, is a ready-to-use class that represents your data as a list of tensors.
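My current solution is not very elegant, but it works:
from copy import copy

from torch.utils.data import random_split
from torchvision import transforms

train_dataset, test_dataset = random_split(full_dataset, [train_size, test_size])
# Both subsets initially point at full_dataset, so give the training split its own shallow copy
train_dataset.dataset = copy(full_dataset)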
test_dataset.dataset.transform = transforms.Compose([
    transforms.Resize(img_resolution),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])
train_dataset.dataset.transform = transforms.Compose([
    transforms.RandomResizedCrop(img_resolution[0]),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])
Basically, I'm defining a new dataset (which is a copy of the original dataset) for one of the splits, and then I define a custom transform for each split.
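To illustrate why the copy matters, here is a minimal sketch using a stand-in class. FakeDataset is hypothetical and only mimics a dataset that stores its transform as an attribute; it is not part of PyTorch. A shallow copy shares the underlying sample list but has its own attribute bindings, so reassigning .transform on the copy should leave the original untouched.

```python
from copy import copy


class FakeDataset:
    """Hypothetical stand-in for a dataset that keeps its transform as an attribute."""

    def __init__(self, samples, transform=None):
        self.samples = samples
        self.transform = transform


full = FakeDataset([1, 2, 3], transform="original")

# Shallow copy: a new object whose attributes initially refer to the same values
clone = copy(full)

# Rebinding the attribute on the copy does not touch the original object
clone.transform = "train"

print(full.transform, clone.transform)
# The sample list itself is shared, not duplicated
print(clone.samples is full.samples)
```

Without the copy, both splits would refer to the same object, and the second assignment to .transform would overwrite the first for both of them.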
Note: train_dataset.dataset.transform works because I'm using an ImageFolder dataset, which applies whatever is stored in its .transform attribute.
If anybody knows a better solution, please share with us!
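One possible alternative, offered as a sketch rather than a definitive answer, is to leave the underlying dataset untransformed and wrap each split in a small class that applies its own transform. TransformedSubset is a hypothetical name, not a PyTorch class. It only implements __len__ and __getitem__, which as far as I know is what a DataLoader needs from a map-style dataset; subclassing torch.utils.data.Dataset would work as well. Plain lists and lambdas stand in for the real splits and transforms so the example runs without any image data.

```python
class TransformedSubset:
    """Hypothetical wrapper that applies its own transform to items of a subset."""

    def __init__(self, subset, transform=None):
        self.subset = subset        # anything indexable that yields (sample, label)
        self.transform = transform  # callable applied to the sample only

    def __len__(self):
        return len(self.subset)

    def __getitem__(self, index):
        sample, label = self.subset[index]
        if self.transform is not None:
            sample = self.transform(sample)
        return sample, label


# Stand-in data: in practice these would be the two results of random_split
raw = [(x, x % 2) for x in range(10)]
train_part, test_part = raw[:8], raw[8:]

# Stand-in transforms: in practice these would be transforms.Compose([...]) pipelines
train_set = TransformedSubset(train_part, transform=lambda x: x * 10)
test_set = TransformedSubset(test_part, transform=lambda x: -x)

print(len(train_set), train_set[1], test_set[0])
```

With this approach the base dataset would be created without a transform (for example transform=None on ImageFolder) so that each wrapper receives the raw sample, and no copy of the dataset is needed.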