I want to know how to use torch.utils.data.DataLoader in PyTorch, especially in the multi-worker case.
I found that one batch output from the DataLoader always comes from a single worker. I expected the DataLoader to have a queue that stores data from all of the workers and to shuffle the data in that queue, so that each output batch is drawn from random workers. I think this is how tf.data.Dataset works in TensorFlow.
Can we implement a similar function in PyTorch? I want to load a dataset from big serialized files (like TFRecord) using multiple workers. In this case, mixing the source files within one batch, which means mixing the workers' sources, is important.
Please refer to the following code:
import random
import time

import torch


class MyDataset(torch.utils.data.Dataset):
    def __len__(self):
        return 50

    def __getitem__(self, idx):
        info = torch.utils.data.get_worker_info()
        time.sleep(random.uniform(0, 1))
        print("[{}]:{}".format(info.id, idx))
        return idx, info.id


if __name__ == '__main__':
    dataset = MyDataset()
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=5, shuffle=False, num_workers=2)
    for batch in dataloader:
        print(batch)
Output:
[0]:0
[1]:5
[0]:1
[1]:6
[0]:2
[0]:3
[1]:7
[0]:4
[tensor([0, 1, 2, 3, 4]), tensor([0, 0, 0, 0, 0])]
[1]:8
[1]:9
[tensor([5, 6, 7, 8, 9]), tensor([1, 1, 1, 1, 1])]
[0]:10
[0]:11
[1]:15
[1]:16
[0]:12
[1]:17
...
Here, [0, 1, 2, 3, 4] and [0, 0, 0, 0, 0] in [tensor([0, 1, 2, 3, 4]), tensor([0, 0, 0, 0, 0])] mean that this batch contains the 0th to 4th data items, all of which came from worker id 0.
Note that shuffle=True does not solve this problem; it only changes the indices of the data.
In this case, I want to get a batch like [tensor([0, 5, 1, 6, 2]), tensor([0, 1, 0, 1, 0])].
I've implemented something simple to solve a similar problem, where I have large video files as training data and each worker is responsible for loading and preprocessing a single file and then yielding samples from it. The problem is that, as the OP describes, with PyTorch's default data loading mechanism each batch contains samples from only a single video file.
First, let's review the problem. In this simplified code example, each worker repeatedly yields a tensor containing its zero-indexed worker id. With a batch size of 32 and 4 workers, we want each batch to contain 8 zeros, 8 ones, 8 twos and 8 threes.
from collections import defaultdict

import torch as T
import torch.utils.data as tdata


class Dataset(tdata.IterableDataset):
    def __init__(self, batch_size: int):
        self._bs = batch_size

    def __iter__(self):
        worker_info = tdata.get_worker_info()
        if not worker_info:
            raise NotImplementedError('Not implemented for num_workers=0')
        for _ in range(self._bs):
            yield T.tensor([worker_info.id])


batch_size = 32
num_workers = 4
dataset = Dataset(batch_size)
loader = tdata.DataLoader(dataset,
                          batch_size=batch_size,
                          num_workers=num_workers)

for batch in loader:
    counts = defaultdict(int)
    for n in batch.numpy().flatten():
        counts[n] += 1
    print(dict(counts))
Instead the code prints:
{0: 32}
{1: 32}
{2: 32}
{3: 32}
This means that the first batch contains samples only from worker 0, the second only from worker 1, and so on. To remedy this, we set the batch size in the DataLoader to batch_size // num_workers and use a simple wrapper over the DataLoader that pools samples from each worker into a single batch:
def pooled_batches(loader):
    loader_it = iter(loader)
    while True:
        samples = []
        for _ in range(loader.num_workers):
            try:
                samples.append(next(loader_it))
            except StopIteration:
                pass
        if len(samples) == 0:
            break
        else:
            yield T.cat(samples, dim=0)


batch_size = 32
num_workers = 4
dataset = Dataset(batch_size)
per_worker = batch_size // num_workers
loader = tdata.DataLoader(dataset,
                          batch_size=per_worker,
                          num_workers=num_workers)

for batch in pooled_batches(loader):
    counts = defaultdict(int)
    for n in batch.numpy().flatten():
        counts[n] += 1
    print(dict(counts))
And the code now prints
{0: 8, 1: 8, 2: 8, 3: 8}
{0: 8, 1: 8, 2: 8, 3: 8}
{0: 8, 1: 8, 2: 8, 3: 8}
{0: 8, 1: 8, 2: 8, 3: 8}
as expected.
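For the original use case of large serialized files, the same pooling trick can be combined with per-worker file sharding in an IterableDataset. The sketch below is an illustration only: ShardedFileDataset, load_samples and the 'data/*.records' glob are hypothetical placeholders for your own file format and directory layout, and pooled_batches is the wrapper defined above.

import glob

import torch.utils.data as tdata


class ShardedFileDataset(tdata.IterableDataset):
    """Each worker iterates over its own disjoint subset of the source files."""

    def __init__(self, file_paths):
        self._file_paths = list(file_paths)

    def __iter__(self):
        worker_info = tdata.get_worker_info()
        if worker_info is None:
            raise NotImplementedError('Not implemented for num_workers=0')
        # Shard the file list so each worker reads different files.
        my_files = self._file_paths[worker_info.id::worker_info.num_workers]
        for path in my_files:
            # load_samples is a placeholder for whatever deserialization your
            # file format needs (e.g. reading a TFRecord-like file).
            for sample in load_samples(path):
                yield sample


batch_size = 32
num_workers = 4
per_worker = batch_size // num_workers
dataset = ShardedFileDataset(sorted(glob.glob('data/*.records')))
loader = tdata.DataLoader(dataset,
                          batch_size=per_worker,
                          num_workers=num_workers)

for batch in pooled_batches(loader):
    ...  # each pooled batch now mixes samples from up to num_workers files

With this layout, each pooled batch draws per_worker samples from every still-active worker, so samples from different source files end up in the same batch, which is what the question asked for.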