How to ensure that a batch contains samples from all workers with PyTorch's DataLoader?

I want to know how to use torch.utils.data.DataLoader in PyTorch, especially in a multi-worker case.

I found that each batch returned by the DataLoader always comes from a single worker. I expected the DataLoader to hold a queue that stores data from all of the workers and to shuffle the queued samples so that each output batch contains data from multiple workers. I think this is how tf.data.Dataset works in TensorFlow. Can we implement something similar in PyTorch? I want to load a dataset from big serialized files (like TFRecord) using multiple workers. In this case, mixing the source files within one batch, which means mixing the output of different workers, is important.

Please refer to the following code:

import random
import time

import torch


class MyDataset(torch.utils.data.Dataset):
    def __len__(self):
        return 50

    def __getitem__(self, idx):
        # Identify which worker is loading this sample.
        info = torch.utils.data.get_worker_info()

        # Simulate variable loading/preprocessing time.
        time.sleep(random.uniform(0, 1))
        print("[{}]:{}".format(info.id, idx))
        return idx, info.id


if __name__ == '__main__':
    dataset = MyDataset()
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=5, shuffle=False, num_workers=2)
    for batch in dataloader:
        print(batch)

Output:

[0]:0
[1]:5
[0]:1
[1]:6
[0]:2
[0]:3
[1]:7
[0]:4
[tensor([0, 1, 2, 3, 4]), tensor([0, 0, 0, 0, 0])]
[1]:8
[1]:9
[tensor([5, 6, 7, 8, 9]), tensor([1, 1, 1, 1, 1])]
[0]:10
[0]:11
[1]:15
[1]:16
[0]:12
[1]:17
...

Here, [0, 1, 2, 3, 4] and [0, 0, 0, 0, 0] in [tensor([0, 1, 2, 3, 4]), tensor([0, 0, 0, 0, 0])] mean that this batch contains the 0th to 4th data items and that all of them came from worker id 0. Note that shuffle=True does not solve this problem; it only shuffles the order of the indices, and each batch of indices is still handed to a single worker.
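
As a quick check, here is a minimal sketch (assuming the same MyDataset and dataset objects defined above) that confirms shuffle=True still yields single-worker batches: the worker-id tensor of every batch contains only one distinct value.

# Minimal sketch, assuming the MyDataset/dataset defined above.
shuffled_loader = torch.utils.data.DataLoader(dataset, batch_size=5, shuffle=True, num_workers=2)
for indices, worker_ids in shuffled_loader:
    print(worker_ids.unique())  # a single id per batch, e.g. tensor([0]) or tensor([1])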

In this case, I want to get a batch like: [tensor([0, 5, 1, 6, 2]), tensor([0, 1, 0, 1, 0])].

Asked Nov 07 '22 by ymfj

1 Answer

I've implemented something simple to solve a similar problem, where I have large video files as training data and each worker is responsible for loading and preprocessing a single file and then yielding samples from it. The problem is that, as the OP describes, with PyTorch's default data loading mechanism each batch contains samples from only a single video file.

First, let's review the problem. In this simplified code example, each worker yields single-element tensors containing its zero-indexed worker id. With a batch size of 32 and 4 workers, we want each batch to contain 8 zeros, 8 ones, 8 twos and 8 threes.

from collections import defaultdict

import torch as T
import torch.utils.data as tdata


class Dataset(tdata.IterableDataset):
    def __init__(self, batch_size: int):
        self._bs = batch_size

    def __iter__(self):
        worker_info = tdata.get_worker_info()
        if not worker_info:
            raise NotImplementedError('Not implemented for num_workers=0')
        # Each worker yields batch_size samples, each a single-element
        # tensor holding that worker's zero-indexed id.
        for _ in range(self._bs):
            yield T.tensor([worker_info.id])


batch_size = 32
num_workers = 4
dataset = Dataset(batch_size)
loader = tdata.DataLoader(dataset,
                          batch_size=batch_size,
                          num_workers=num_workers)


for batch in loader:
    counts = defaultdict(int)
    for n in batch.numpy().flatten():
        counts[n] += 1
    print(dict(counts))

Instead the code prints:

{0: 32}
{1: 32}
{2: 32}
{3: 32}

Meaning that the first batch contains samples only from worker 0, the second only from worker 1, and so on. To remedy this, we will set the batch size in the DataLoader to batch_size // num_workers and use a simple wrapper around the DataLoader that pools one per-worker batch from each worker into the final batch:

def pooled_batches(loader):
    loader_it = iter(loader)
    while True:
        # Take the next num_workers batches; by default the DataLoader
        # returns them round-robin, one from each worker.
        samples = []
        for _ in range(loader.num_workers):
            try:
                samples.append(next(loader_it))
            except StopIteration:
                pass
        if len(samples) == 0:
            break
        else:
            # Concatenate the per-worker batches into one pooled batch.
            yield T.cat(samples, dim=0)


batch_size = 32
num_workers = 4
dataset = Dataset(batch_size)
per_worker = batch_size // num_workers
loader = tdata.DataLoader(dataset,
                          batch_size=per_worker,
                          num_workers=num_workers)

for batch in pooled_batches(loader):
    counts = defaultdict(int)
    for n in batch.numpy().flatten():
        counts[n] += 1
    print(dict(counts))

And the code now prints

{0: 8, 1: 8, 2: 8, 3: 8}
{0: 8, 1: 8, 2: 8, 3: 8}
{0: 8, 1: 8, 2: 8, 3: 8}
{0: 8, 1: 8, 2: 8, 3: 8}

as expected.
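
Note that this pooling assumes batch_size is divisible by num_workers, and that within a pooled batch the samples are grouped worker by worker rather than interleaved. If you also want the samples mixed within each batch, a minimal sketch (assuming single-tensor batches as in the example above) is to permute each pooled batch:

def shuffled_pooled_batches(loader):
    # Sketch: shuffle samples within each pooled batch so the workers'
    # contributions are interleaved rather than stacked in blocks.
    for batch in pooled_batches(loader):
        perm = T.randperm(batch.shape[0])
        yield batch[perm]

For datasets that return tuples of tensors (like the OP's MyDataset), the same permutation would need to be applied to every tensor in the batch.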

Answered Nov 24 '22 by Agost Biro