PyTorch: Apply data augmentation on training data after random_split

I have a dataset that does not have separate folders for training and testing. I want to apply data augmentation with transforms only to the training data, after doing the split:

 train_data, valid_data = D.random_split(dataset, lengths=[train_size, valid_size])

Does anyone know how this can be achieved? I have a custom dataset that implements __init__ and __getitem__. The training and validation datasets are then passed to a DataLoader.

alice asked Apr 13 '26 18:04



1 Answer

You can wrap each split in a custom Dataset whose only job is to apply the transformations:

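from torch.utils.data import Dataset

class TrDataset(Dataset):
  """Wraps another dataset and applies transformations to each sample."""

  def __init__(self, base_dataset, transformations):
    super().__init__()
    self.base = base_dataset
    self.transformations = transformations

  def __len__(self):
    return len(self.base)

  def __getitem__(self, idx):
    # base must return the raw, untransformed sample and its label
    x, y = self.base[idx]
    return self.transformations(x), y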

Once you have this Dataset wrapper, you can have different transformations for the train and validation sets:

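import torch.utils.data as D  # assumed alias, matching the D used in the question

# train_size + valid_size must equal len(dataset)
raw_train_data, raw_valid_data = D.random_split(dataset, lengths=[train_size, valid_size])
train_data = TrDataset(raw_train_data, train_transforms)
valid_data = TrDataset(raw_valid_data, val_transforms)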
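
For completeness, here is a minimal end-to-end sketch of the same idea. ToyDataset, the noise-based train_transforms, the identity val_transforms, the 80/20 split, and the batch size are all placeholders invented for illustration, not part of the original question; in practice you would plug in your own dataset and, for images, something like a torchvision transform pipeline. The seeded generator is optional and only makes the split reproducible.

```python
import torch
from torch.utils.data import DataLoader, Dataset, random_split


class ToyDataset(Dataset):
    """Hypothetical stand-in for the custom dataset: returns raw (x, y) pairs."""

    def __init__(self, n):
        self.x = torch.arange(n, dtype=torch.float32).unsqueeze(1)  # shape (n, 1)
        self.y = torch.arange(n) % 2

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        return self.x[idx], self.y[idx]


class TrDataset(Dataset):
    """Same wrapper as above: applies a transform on top of any dataset."""

    def __init__(self, base_dataset, transformations):
        super().__init__()
        self.base = base_dataset
        self.transformations = transformations

    def __len__(self):
        return len(self.base)

    def __getitem__(self, idx):
        x, y = self.base[idx]
        return self.transformations(x), y


def train_transforms(x):
    # Placeholder augmentation: add a small amount of Gaussian noise.
    return x + 0.01 * torch.randn_like(x)


def val_transforms(x):
    # Validation samples are passed through unchanged.
    return x


dataset = ToyDataset(100)
train_size = int(0.8 * len(dataset))
valid_size = len(dataset) - train_size  # lengths must sum to len(dataset)

# A seeded generator makes the split reproducible across runs.
generator = torch.Generator().manual_seed(0)
raw_train_data, raw_valid_data = random_split(
    dataset, lengths=[train_size, valid_size], generator=generator
)

train_data = TrDataset(raw_train_data, train_transforms)
valid_data = TrDataset(raw_valid_data, val_transforms)

train_loader = DataLoader(train_data, batch_size=16, shuffle=True)
valid_loader = DataLoader(valid_data, batch_size=16, shuffle=False)
```

Because the transform lives in the wrapper rather than in the underlying dataset, the two splits can share the same base dataset without the training augmentation leaking into validation.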
Shai answered Apr 16 '26 19:04
