How to use different data augmentation (transforms) for different Subsets in PyTorch?
For instance:
train, test = torch.utils.data.random_split(dataset, [80000, 2000])
train and test will have the same transforms as dataset. How to use custom transforms for these subsets?
The Dataset class is an abstract class used to define new types of (custom) datasets. By contrast, TensorDataset is a ready-to-use class that wraps data you already hold as tensors.
My current solution is not very elegant, but works:
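from copy import copy

from torch.utils.data import random_split
from torchvision import transforms

# full_dataset, train_size, test_size and img_resolution (a sequence, since it is indexed below) are defined elsewhere.
train_dataset, test_dataset = random_split(full_dataset, [train_size, test_size])

# Both subsets reference the same underlying dataset, so give the training split its own (shallow) copy.
train_dataset.dataset = copy(full_dataset)

# Evaluation preprocessing; test_dataset.dataset is still the original full_dataset.
test_dataset.dataset.transform = transforms.Compose([
    transforms.Resize(img_resolution),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

# Training augmentation, set on the copy only.
train_dataset.dataset.transform = transforms.Compose([
    transforms.RandomResizedCrop(img_resolution[0]),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])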
Basically, I'm defining a new dataset (which is a copy of the original dataset) for one of the splits, and then I define a custom transform for each split.
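To illustrate why the copy matters, here is a minimal pure-Python sketch. FakeDataset and FakeSubset are hypothetical stand-ins (not PyTorch classes) that only mimic the relevant behaviour: a subset is a view that forwards indexing to a parent dataset, and the parent applies whatever is stored in its .transform attribute.

```python
from copy import copy


class FakeDataset:
    """Hypothetical stand-in for a dataset that applies self.transform on access."""

    def __init__(self, samples, transform=None):
        self.samples = samples
        self.transform = transform

    def __getitem__(self, index):
        sample = self.samples[index]
        return self.transform(sample) if self.transform else sample


class FakeSubset:
    """Hypothetical stand-in for a subset: a view onto a parent dataset."""

    def __init__(self, dataset, indices):
        self.dataset = dataset
        self.indices = indices

    def __getitem__(self, index):
        return self.dataset[self.indices[index]]


full = FakeDataset([1, 2, 3, 4])
train, test = FakeSubset(full, [0, 1, 2]), FakeSubset(full, [3])

# Right after the split, both subsets point at the very same parent object.
shared_before_copy = train.dataset is test.dataset

# Give the training split its own shallow copy, then set the transforms separately.
train.dataset = copy(full)
train.dataset.transform = lambda x: x * 10  # stands in for training augmentation
test.dataset.transform = lambda x: x + 1    # stands in for evaluation preprocessing

print(shared_before_copy, train[0], test[0])  # prints: True 10 5
```

Because the copy is shallow, the underlying samples are still shared between the two objects; only the .transform attribute differs.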
Note: train_dataset.dataset.transform works since I'm using an ImageFolder dataset, which uses the .transform attribute to perform the transforms.
If anybody knows a better solution, please share with us!
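One commonly suggested alternative, sketched here under assumptions rather than taken from the answer above, is to leave the base dataset without a split-specific transform and wrap each split in a small class that applies its own transform. TransformedSubset is a hypothetical name. It relies only on __len__ and __getitem__, so it is written without importing torch; in a real project you would probably subclass torch.utils.data.Dataset and wrap the Subset objects returned by random_split. Toy lists stand in for the splits below.

```python
class TransformedSubset:
    """Hypothetical wrapper: applies `transform` to the input of each (x, y) sample."""

    def __init__(self, subset, transform=None):
        self.subset = subset
        self.transform = transform

    def __len__(self):
        return len(self.subset)

    def __getitem__(self, index):
        x, y = self.subset[index]
        if self.transform is not None:
            x = self.transform(x)
        return x, y


# Toy data standing in for the two splits returned by random_split.
train_split = [(1, "a"), (2, "b")]
test_split = [(3, "c")]

# Each split gets its own transform; the underlying data is never modified.
train_dataset = TransformedSubset(train_split, transform=lambda x: x * 2)
test_dataset = TransformedSubset(test_split, transform=lambda x: x)
```

One caveat: if the base dataset already has a transform set, the wrapper's transform is applied on top of it, so the base dataset should return raw samples for this to behave as intended.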