
Implementing an “infinite loop” Dataset & DataLoader in PyTorch




I’d like to implement an infinite loop Dataset & DataLoader. Here’s what I tried:

class Infinite(Dataset):
    def __len__(self):
        return HPARAMS.batch_size
#         return 1<<30 # This causes huge memory usage.
    def __getitem__(self, idx):
        """Randomly generates one new example."""
        return sample_func_to_be_parallelized()

infinite_loader = DataLoader(
    dataset=Infinite(),
    batch_size=HPARAMS.batch_size,
    num_workers=16,
    worker_init_fn=lambda worker_id: np.random.seed(worker_id),
)

while True:
    for idx, data in enumerate(infinite_loader):
        # forward + backward on "data"
        pass

As you can see, the main challenge here is the __len__() method. If I return a large enough number there, like 1<<30, memory usage jumps to 10+ GB on the first iteration of the training loop. After a while the workers are killed, presumably due to OOM.

If I put a small number there, like 1 or BATCH_SIZE, the sampled “data” in the train loop will be periodically duplicated. This is not what I want as I’d like new data to be generated & trained on at every iteration.

My guess is that something in the stack caches a large number of objects, which would explain the excessive memory usage, but from a casual look at the Python side of things I can’t pinpoint where.

Can someone advise what’s the best way to have what I want implemented? (Use DataLoader’s parallel loading, while simultaneously guaranteeing every batch loaded is entirely new.)

Covi asked Jan 25 '19 05:01


3 Answers

This seems to be working without periodically duplicating the data:

import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader

BATCH_SIZE = 2  # two samples per batch, matching the output below


class Infinite(Dataset):

    def __len__(self):
        return BATCH_SIZE

    def __getitem__(self, idx):
        return torch.randint(0, 10, (3,))

data_loader = DataLoader(Infinite(), batch_size=BATCH_SIZE, num_workers=16)

batch_count = 0
while True:
    batch_count += 1
    print(f'Batch {batch_count}:')

    data = next(iter(data_loader))
    # forward + backward on "data"  

    if batch_count == 5:
        break

Batch 1:
tensor([[4, 7, 7],
        [0, 8, 0]])
Batch 2:
tensor([[6, 8, 6],
        [2, 6, 7]])
Batch 3:
tensor([[6, 6, 2],
        [8, 7, 0]])
Batch 4:
tensor([[9, 4, 8],
        [2, 4, 1]])
Batch 5:
tensor([[9, 6, 1],
        [2, 7, 5]])

So I think the problem is in your function sample_func_to_be_parallelized().

Edit: If instead of torch.randint(0, 10, (3,)) I use np.random.randint(10, size=3) in __getitem__ (as an example of the sample_func_to_be_parallelized()), then the data is indeed duplicated at each batch. See this issue.

So if you are using NumPy's RNG somewhere in your sample_func_to_be_parallelized(), then the workaround is to use

worker_init_fn=lambda worker_id: np.random.seed(np.random.get_state()[1][0] + worker_id) 

and to reset the seed by np.random.seed() before each call of data = next(iter(data_loader)).
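
To make the duplication concrete, here is a minimal simulation in plain NumPy. It is not DataLoader's actual code path: it assumes each worker starts as a fork holding a copy of the parent's NumPy RNG state, and simulate_worker_draws is a hypothetical helper. The reseed branch mirrors the "base seed plus worker_id" offset from the workaround above.

```python
import numpy as np


def simulate_worker_draws(num_workers, reseed):
    """Return one draw per simulated worker (hypothetical helper).

    Assumption: every worker begins with a copy of the parent's NumPy
    RNG state, as happens when worker processes are forked.
    """
    parent_state = np.random.get_state()
    base_seed = int(parent_state[1][0])
    draws = []
    for worker_id in range(num_workers):
        rng = np.random.RandomState()
        rng.set_state(parent_state)  # copy of the parent's state
        if reseed:
            # Same offset idea as the worker_init_fn workaround.
            rng.seed((base_seed + worker_id) % 2**32)
        draws.append(rng.randint(10, size=3).tolist())
    return draws


np.random.seed(0)
same = simulate_worker_draws(4, reseed=False)
different = simulate_worker_draws(4, reseed=True)

print("without reseeding:", same)
print("with reseeding:   ", different)
```

Without reseeding, every simulated worker produces the same three numbers; with the offset seed the streams diverge.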

Andreas K. answered Nov 18 '22 16:11


DataLoader samples your dataset without replacement. To do this, it generates a random permutation of indices between 0 and len(dataset). My guess is that this permutation is responsible for eating up most of your memory. I don't think PyTorch APIs support infinite collections, but you could try forking the code in DataLoader and doing it yourself. You could use the batch_sampler param and pass in a custom variant, implemented based on RandomSampler. This will allow you to keep the parallel loading part of DataLoader.
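
One way the batch_sampler suggestion might look, offered as a sketch rather than a recipe verified on every PyTorch version. InfiniteBatchSampler and RandomExamples are hypothetical names, and torch.randint stands in for the real sampling function. The sampler yields index lists lazily, so no permutation over a huge range is ever built; since the dataset ignores its index, every index is simply 0.

```python
from itertools import islice

import torch
from torch.utils.data import DataLoader, Dataset, Sampler


class RandomExamples(Dataset):
    """Hypothetical dataset: ignores the index and returns a new example."""

    def __len__(self):
        return 1

    def __getitem__(self, idx):
        return torch.randint(0, 10, (3,))


class InfiniteBatchSampler(Sampler):
    """Hypothetical sampler: lazily yields batches of indices without end."""

    def __init__(self, batch_size):
        self.batch_size = batch_size

    def __iter__(self):
        while True:
            # The dataset ignores the index, so index 0 is enough.
            yield [0] * self.batch_size


loader = DataLoader(
    RandomExamples(),
    batch_sampler=InfiniteBatchSampler(batch_size=4),
    num_workers=0,  # raise this to use parallel workers
)

# The loader never ends on its own, so take a bounded slice here.
first_batches = list(islice(loader, 5))
for batch in first_batches:
    print(batch.shape)
```

Because the sampler has no end, len(loader) is undefined and the training loop has to be stopped explicitly, here with islice.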

That being said, the protocol of iteration based on __len__ and __getitem__ just isn't suited for infinite collections. You may be better off reimplementing your Dataset.__len__ to just return 1, your Dataset.__getitem__ to always return a new sample, regardless of the index, and then sampling n times with replacement from this dataset. Technically, it will ask n times for the 0-th sample, but since you override __getitem__ to return different samples, this will effectively do what you're looking for.
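
A minimal sketch of this idea, assuming RandomSampler's replacement and num_samples arguments are used to request the single index repeatedly. FreshSamples and make_loader are hypothetical names, and torch.randint again stands in for the real sampling function.

```python
import torch
from torch.utils.data import DataLoader, Dataset, RandomSampler


class FreshSamples(Dataset):
    """Hypothetical dataset of length 1 that returns a new sample each call."""

    def __len__(self):
        return 1

    def __getitem__(self, idx):
        # idx is always 0 here and is deliberately ignored.
        return torch.randint(0, 10, (3,))


def make_loader(batch_size, num_batches, num_workers=0):
    """Build a loader that draws index 0 batch_size * num_batches times."""
    dataset = FreshSamples()
    sampler = RandomSampler(
        dataset,
        replacement=True,
        num_samples=batch_size * num_batches,
    )
    return DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=sampler,
        num_workers=num_workers,
    )


loader = make_loader(batch_size=4, num_batches=3)
batches = list(loader)
for batch in batches:
    print(batch)
```

Here num_samples bounds a single pass over the loader, so an endless training run would still need an outer loop or a large num_samples. If the sampling function relies on NumPy's RNG and num_workers is greater than 0, the seeding caveat from the first answer still applies.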

Jatentaki answered Nov 18 '22 16:11


Try using cycle from itertools. Here is an example for a simple dataset:


from itertools import cycle

import torch
from torch.utils.data import Dataset, DataLoader

# Create some dummy data.
data = torch.tensor([[0, 0],
                     [1, 1],
                     [2, 2],
                     [3, 3]])

class DataSet(Dataset):
    """Our dataset. Iterates over tensor data"""

    def __init__(self, data):
        self.data = data
        self.n = self.data.shape[0]

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        return self.data[idx]

bs = 1  # batch size
workers = 1  # number of workers

dataset = DataSet(data)
data_loader = DataLoader(dataset, batch_size=bs, shuffle=False, num_workers=workers)

# Infinite loop.
print(f'batch size: {bs} | number of workers: {workers}')
for i, data in cycle(enumerate(data_loader)):
    print(i, data)


batch size: 1 | number of workers: 1
0 tensor([[0, 0]])
1 tensor([[1, 1]])
2 tensor([[2, 2]])
3 tensor([[3, 3]])
0 tensor([[0, 0]])
1 tensor([[1, 1]])
2 tensor([[2, 2]])
3 tensor([[3, 3]])

batch size: 2 | number of workers: 2
0 tensor([[0, 0],
        [1, 1]])
1 tensor([[2, 2],
        [3, 3]])
0 tensor([[0, 0],
        [1, 1]])
1 tensor([[2, 2],
trsvchn answered Nov 18 '22 18:11
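
One caveat worth checking against the requirement in the question: itertools.cycle stores each item from its first pass and replays those stored items afterwards, so randomly generated batches would repeat. The sketch below contrasts it with a generator that restarts the loader on every pass. FreshBatches is a hypothetical pure-Python stand-in for a DataLoader whose batches are generated on the fly, and repeat_forever is an assumed helper name.

```python
import random
from itertools import cycle, islice


class FreshBatches:
    """Hypothetical stand-in for a DataLoader: each pass yields new random batches."""

    def __init__(self, batches_per_pass):
        self.batches_per_pass = batches_per_pass
        self.passes = 0  # counts how many times iteration was (re)started

    def __iter__(self):
        self.passes += 1
        for _ in range(self.batches_per_pass):
            yield [random.random() for _ in range(2)]


def repeat_forever(loader):
    """Restart the loader whenever it is exhausted instead of replaying a cache."""
    while True:
        yield from loader


# cycle() iterates the source once, then replays its saved copies.
cached_source = FreshBatches(batches_per_pass=2)
cached = list(islice(cycle(cached_source), 6))

# repeat_forever() calls iter() on the source again for every pass.
fresh_source = FreshBatches(batches_per_pass=2)
fresh = list(islice(repeat_forever(fresh_source), 6))

print("passes with cycle:         ", cached_source.passes)
print("passes with repeat_forever:", fresh_source.passes)
```

For a fixed dataset such as the one in this answer the difference does not matter, but with shuffle=True or a randomly generating __getitem__ only the restarting variant yields new batches on each pass.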
