I’d like to implement an infinite loop Dataset & DataLoader. Here’s what I tried:
import numpy as np
from torch.utils.data import Dataset, DataLoader

class Infinite(Dataset):
    def __len__(self):
        return HPARAMS.batch_size
        # return 1 << 30  # This causes huge memory usage.
    def __getitem__(self, idx):
        """Randomly generates one new example."""
        return sample_func_to_be_parallelized()
infinite_loader = DataLoader(
    dataset=Infinite(), 
    batch_size=HPARAMS.batch_size, 
    num_workers=16,
    worker_init_fn=lambda worker_id: np.random.seed(worker_id),  
)
while True:
    for idx, data in enumerate(infinite_loader):
        ...  # forward + backward on “data”
As you can see, the main challenge here is the __len__() method. If I put a large enough number there, like 1 << 30, memory usage jumps to 10+ GB on the first iteration of the training loop, and after a while the workers are killed, presumably due to OOM.
If I put a small number there, like 1 or BATCH_SIZE, the sampled “data” in the training loop is periodically duplicated. This is not what I want, as I’d like new data to be generated and trained on at every iteration.
I’m guessing the culprit of the excessive memory usage is that a bunch of things get cached somewhere in the stack; from a casual look at the Python side of things I can’t pinpoint where.
Can someone advise on the best way to implement what I want? (Use DataLoader’s parallel loading while guaranteeing that every batch loaded is entirely new.)
This seems to be working without periodically duplicating the data:
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
BATCH_SIZE = 2
class Infinite(Dataset):
    def __len__(self):
        return BATCH_SIZE
    def __getitem__(self, idx):
        return torch.randint(0, 10, (3,))
data_loader = DataLoader(Infinite(), batch_size=BATCH_SIZE, num_workers=16)
batch_count = 0
while True:
    batch_count += 1
    print(f'Batch {batch_count}:')
    data = next(iter(data_loader))
    print(data)
    # forward + backward on "data"  
    if batch_count == 5:
        break
Result:
Batch 1:
tensor([[4, 7, 7],
        [0, 8, 0]])
Batch 2:
tensor([[6, 8, 6],
        [2, 6, 7]])
Batch 3:
tensor([[6, 6, 2],
        [8, 7, 0]])
Batch 4:
tensor([[9, 4, 8],
        [2, 4, 1]])
Batch 5:
tensor([[9, 6, 1],
        [2, 7, 5]])
So I think the problem is in your function sample_func_to_be_parallelized().
Edit: If instead of torch.randint(0, 10, (3,)) I use np.random.randint(10, size=3) in __getitem__ (as an example of the sample_func_to_be_parallelized()), then the data is indeed duplicated in each batch. See this issue.
So if you are using NumPy's RNG somewhere in your sample_func_to_be_parallelized(), then the workaround is to use
worker_init_fn=lambda worker_id: np.random.seed(np.random.get_state()[1][0] + worker_id) 
and to reset the seed with np.random.seed() before each call to data = next(iter(data_loader)).
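Putting the pieces together, here is a minimal sketch of that workaround (the seed_worker helper is hypothetical; Infinite and BATCH_SIZE mirror the example above, with NumPy's RNG standing in for your sample_func_to_be_parallelized()):
import numpy as np
from torch.utils.data import Dataset, DataLoader

BATCH_SIZE = 2

class Infinite(Dataset):
    def __len__(self):
        return BATCH_SIZE
    def __getitem__(self, idx):
        # Stand-in for sample_func_to_be_parallelized(): uses NumPy's RNG.
        return np.random.randint(10, size=3)

def seed_worker(worker_id):
    # Hypothetical helper: each forked worker inherits the main process's NumPy
    # state, so offsetting the inherited base seed by worker_id gives every
    # worker a distinct seed each time a fresh iterator is created.
    np.random.seed(np.random.get_state()[1][0] + worker_id)

data_loader = DataLoader(Infinite(), batch_size=BATCH_SIZE,
                         num_workers=16, worker_init_fn=seed_worker)

while True:
    np.random.seed()                # re-randomize the base seed in the main process
    data = next(iter(data_loader))  # a fresh iterator re-runs seed_worker in each worker
    # forward + backward on "data"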
DataLoader samples your dataset without replacement. To do this, it generates a random permutation of indices between 0 and len(dataset). My guess is that this permutation is responsible for eating up most of your memory. I don't think the PyTorch APIs support infinite collections, but you could try forking the code in DataLoader and doing it yourself.
You could use the batch_sampler param and pass in a custom variant implemented based on RandomSampler. This will let you keep the parallel loading part of DataLoader.
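A rough sketch of that approach (the InfiniteBatchSampler class is hypothetical; Infinite and HPARAMS are the names from the question):
import random
from torch.utils.data import DataLoader

class InfiniteBatchSampler:
    """Hypothetical sampler: yields batches of random indices forever, so the
    DataLoader iterator never ends and no huge index permutation is built."""
    def __init__(self, dataset_len, batch_size):
        self.dataset_len = dataset_len
        self.batch_size = batch_size

    def __iter__(self):
        while True:
            yield [random.randrange(self.dataset_len) for _ in range(self.batch_size)]

# Note: batch_size, shuffle, sampler and drop_last must not be passed
# together with batch_sampler.
infinite_loader = DataLoader(
    dataset=Infinite(),  # the question's dataset; its __getitem__ ignores idx
    batch_sampler=InfiniteBatchSampler(dataset_len=1, batch_size=HPARAMS.batch_size),
    num_workers=16,
)

for data in infinite_loader:  # never terminates; break once training is done
    ...  # forward + backward on "data"
Because the sampler draws indices with replacement and never materializes a permutation the size of len(dataset), memory usage should stay flat regardless of how long you train.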
That being said, the protocol of iteration based on __len__ and __getitem__ just isn't suited for infinite collections. You may be better off reimplementing your Dataset.__len__ to just return 1, your Dataset.__getitem__ to always return a new sample, regardless of the index, and then sampling n times with replacement from this dataset. Technically, it will ask n times for the 0-th sample, but since you override __getitem__ to return different samples, this will effectively do what you're looking for.
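A sketch of that variant, using torch.utils.data.RandomSampler with replacement (the num_samples value is an arbitrary illustration; Infinite, HPARAMS and sample_func_to_be_parallelized are the question's names):
from torch.utils.data import Dataset, DataLoader, RandomSampler

class Infinite(Dataset):
    def __len__(self):
        return 1  # a single "virtual" index
    def __getitem__(self, idx):
        return sample_func_to_be_parallelized()  # ignores idx, always a fresh sample

dataset = Infinite()
# Draw index 0 over and over, with replacement; num_samples is arbitrary and
# just sets how long one "epoch" of the loader lasts.
sampler = RandomSampler(dataset, replacement=True, num_samples=1_000_000)
loader = DataLoader(dataset, batch_size=HPARAMS.batch_size,
                    sampler=sampler, num_workers=16)

while True:  # restart the (long but finite) loader to keep training indefinitely
    for data in loader:
        ...  # forward + backward on "data"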
Try using cycle from itertools. Here is an example with a simple dataset:
Code:
from itertools import cycle
import torch
from torch.utils.data import Dataset, DataLoader
# Create some dummy data.
data = torch.tensor([[0, 0],
                     [1, 1],
                     [2, 2],
                     [3, 3]])
class DataSet(Dataset):
    """Our dataset. Iterates over tensor data"""
    def __init__(self, data):
        self.data = data
        self.n = self.data.shape[0]
    def __len__(self):
        return self.n
    def __getitem__(self, idx):
        return self.data[idx]
bs = 1  # batch size
workers = 1  # number of workers
dataset = DataSet(data)
data_loader = DataLoader(dataset, batch_size=bs, shuffle=False, num_workers=workers)
# Infinite loop.
print(f'batch size: {bs} | number of workers: {workers}')
for i, data in cycle(enumerate(data_loader)):
    print(i, data)
Output:
batch size: 1 | number of workers: 1
0 tensor([[0, 0]])
1 tensor([[1, 1]])
2 tensor([[2, 2]])
3 tensor([[3, 3]])
0 tensor([[0, 0]])
1 tensor([[1, 1]])
2 tensor([[2, 2]])
3 tensor([[3, 3]])
...
batch size: 2 | number of workers: 2
0 tensor([[0, 0],
        [1, 1]])
1 tensor([[2, 2],
        [3, 3]])
0 tensor([[0, 0],
        [1, 1]])
1 tensor([[2, 2],
...