I am training a GAN on the CIFAR-10 dataset in PyTorch (and hence don't need train/val/test splits), and I want to combine the two torchvision.datasets.CIFAR10 datasets in the snippet below into a single torch.utils.data.DataLoader
iterator. My current solution is something like:
import torchvision
import torch
batch_size = 128
# transform is needed so the DataLoaders yield tensors rather than PIL images
transform = torchvision.transforms.ToTensor()
cifar_trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=False, transform=transform)
cifar_testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=False, transform=transform)
cifar_dl1 = torch.utils.data.DataLoader(cifar_trainset, batch_size=batch_size, num_workers=12,
                                        persistent_workers=True, shuffle=True, pin_memory=True)
cifar_dl2 = torch.utils.data.DataLoader(cifar_testset, batch_size=batch_size, num_workers=12,
                                        persistent_workers=True, shuffle=True, pin_memory=True)
And then in my training loop I have something like:
for dl in [cifar_dl1, cifar_dl2]:
    for data in dl:
        # training
The problem with this approach is that each DataLoader spawns its own workers: I have found that 12 workers is optimal for my setup and this task, so I am now declaring 24 workers in total, which is clearly too many, not to mention the start-up cost of re-iterating over each DataLoader every epoch, even with persistent_workers=True set on both.
Any solutions to this problem would be much appreciated.
You can use ConcatDataset from the torch.utils.data module.
Code Snippet:
import torch
import torchvision
batch_size = 128
# transform so the combined DataLoader yields tensors rather than PIL images
transform = torchvision.transforms.ToTensor()
cifar_trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=False, transform=transform)
cifar_testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=False, transform=transform)
# Concatenate the train and test splits into one dataset
cifar_dataset = torch.utils.data.ConcatDataset([cifar_trainset, cifar_testset])
cifar_dataloader = torch.utils.data.DataLoader(cifar_dataset, batch_size=batch_size, num_workers=12,
                                               persistent_workers=True, shuffle=True, pin_memory=True)
for data in cifar_dataloader:
    # training
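As a quick sanity check (a minimal sketch reusing the variable names from the snippet above), you can confirm that ConcatDataset simply chains the two splits back-to-back, so its length is the sum of the CIFAR-10 train and test sets:
# ConcatDataset indexes the underlying datasets sequentially,
# so the combined length should be 50,000 train + 10,000 test images.
print(len(cifar_trainset), len(cifar_testset), len(cifar_dataset))
# Expected output: 50000 10000 60000
Since a single DataLoader now drives iteration, only one pool of 12 workers is created, and with shuffle=True the train and test images are mixed within each epoch.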