I’d like to implement an infinite loop Dataset & DataLoader. Here’s what I tried:
class Infinite(Dataset):

    def __len__(self):
        return HPARAMS.batch_size
        # return 1<<30  # This causes huge memory usage.

    def __getitem__(self, idx):
        """Randomly generates one new example."""
        return sample_func_to_be_parallelized()


infinite_loader = DataLoader(
    dataset=Infinite(),
    batch_size=HPARAMS.batch_size,
    num_workers=16,
    worker_init_fn=lambda worker_id: np.random.seed(worker_id),
)

while True:
    for idx, data in enumerate(infinite_loader):
        # forward + backward on "data"
As you can see, the main challenge here is the __len__() method. If I put a large enough number there, like 1<<30, the symptom is that memory usage jumps to 10+ GB on the first iteration of the training loop. After a while the workers are killed, presumably due to OOM.
If I put a small number there, like 1 or BATCH_SIZE, the sampled "data" in the training loop is periodically duplicated. This is not what I want, as I'd like new data to be generated and trained on at every iteration.
I'm guessing the culprit of the excessive memory usage is that somewhere in the stack a bunch of things are cached; from a casual look at the Python side of things I can't pinpoint where.
Can someone advise on the best way to implement what I want? (Use DataLoader's parallel loading while simultaneously guaranteeing that every batch loaded is entirely new.)
This seems to be working without periodically duplicating the data:
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader

BATCH_SIZE = 2


class Infinite(Dataset):

    def __len__(self):
        return BATCH_SIZE

    def __getitem__(self, idx):
        return torch.randint(0, 10, (3,))


data_loader = DataLoader(Infinite(), batch_size=BATCH_SIZE, num_workers=16)

batch_count = 0
while True:
    batch_count += 1
    print(f'Batch {batch_count}:')

    data = next(iter(data_loader))
    print(data)
    # forward + backward on "data"

    if batch_count == 5:
        break
Result:
Batch 1:
tensor([[4, 7, 7],
        [0, 8, 0]])
Batch 2:
tensor([[6, 8, 6],
        [2, 6, 7]])
Batch 3:
tensor([[6, 6, 2],
        [8, 7, 0]])
Batch 4:
tensor([[9, 4, 8],
        [2, 4, 1]])
Batch 5:
tensor([[9, 6, 1],
        [2, 7, 5]])
So I think the problem is in your function sample_func_to_be_parallelized().

Edit: If instead of torch.randint(0, 10, (3,)) I use np.random.randint(10, size=3) in __getitem__ (as an example of the sample_func_to_be_parallelized()), then the data is indeed duplicated at each batch. See this issue.

So if you are using numpy's RNG somewhere in your sample_func_to_be_parallelized(), then the workaround is to use

worker_init_fn=lambda worker_id: np.random.seed(np.random.get_state()[1][0] + worker_id)

and to reset the seed with np.random.seed() before each call of data = next(iter(data_loader)).
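For reference, here is a minimal sketch of that workaround, assuming __getitem__ itself stands in for a NumPy-based sample_func_to_be_parallelized() (the Infinite class below is just the toy dataset from above):

import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader

BATCH_SIZE = 2


class Infinite(Dataset):

    def __len__(self):
        return BATCH_SIZE

    def __getitem__(self, idx):
        # Stand-in for sample_func_to_be_parallelized(); it relies on NumPy's RNG.
        return torch.from_numpy(np.random.randint(10, size=3))


data_loader = DataLoader(
    Infinite(),
    batch_size=BATCH_SIZE,
    num_workers=2,
    # Derive a distinct per-worker seed from the state inherited from the parent process.
    worker_init_fn=lambda worker_id: np.random.seed(np.random.get_state()[1][0] + worker_id),
)

for batch_count in range(1, 6):
    np.random.seed()                 # re-seed the parent before a fresh set of workers starts
    data = next(iter(data_loader))   # new iterator -> new workers -> new seeds
    print(f'Batch {batch_count}:')
    print(data)

With the default fork start method on Linux, each call to iter(data_loader) starts fresh workers whose NumPy state is copied from the parent, which is why re-seeding the parent before each call makes a difference.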
DataLoader samples your dataset without replacement. To do this, it generates a random permutation of indices between 0 and len(dataset). My guess is that this permutation is responsible for eating up most of your memory. I don't think the PyTorch APIs support infinite collections, but you could try forking the code in DataLoader and doing it yourself.

You could use the batch_sampler param, and pass in a custom variant implemented based on RandomSampler. This will allow you to keep the parallel loading part of DataLoader.
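For illustration, here is a minimal sketch of such a sampler; InfiniteBatchSampler and Synthetic are made-up names, and the sampler simply yields lists of random indices forever so the DataLoader iterator never runs out:

import torch
from torch.utils.data import DataLoader, Dataset


class InfiniteBatchSampler:
    """Yields lists of random indices forever, so iterating the DataLoader never stops."""

    def __init__(self, num_indices, batch_size):
        self.num_indices = num_indices
        self.batch_size = batch_size

    def __iter__(self):
        while True:  # never raises StopIteration
            yield torch.randint(self.num_indices, (self.batch_size,)).tolist()


class Synthetic(Dataset):

    def __len__(self):
        return 1  # irrelevant here; the batch_sampler decides which indices get requested

    def __getitem__(self, idx):
        return torch.randint(0, 10, (3,))  # stand-in for sample_func_to_be_parallelized()


loader = DataLoader(Synthetic(),
                    batch_sampler=InfiniteBatchSampler(num_indices=1, batch_size=2),
                    num_workers=2)

for step, data in enumerate(loader):  # iterates indefinitely
    print(step, data)
    # forward + backward on "data"
    if step == 4:
        break

Note that len(loader) is undefined here, since the sampler has no __len__; you can only iterate over it.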
That being said, the protocol of iteration based on __len__ and __getitem__ just isn't suited for infinite collections. You may be better off reimplementing your Dataset.__len__ to just return 1, your Dataset.__getitem__ to always return a new sample regardless of the index, and then sampling n times with replacement from this dataset. Technically, it will ask n times for the 0-th sample, but since you override __getitem__ to return different samples, this will effectively do what you're looking for.
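A minimal sketch of that suggestion, using the built-in RandomSampler with replacement=True (the length-1 dataset and the choice of num_samples below are just placeholders):

import torch
from torch.utils.data import DataLoader, Dataset, RandomSampler


class Fresh(Dataset):

    def __len__(self):
        return 1  # a single "virtual" index

    def __getitem__(self, idx):
        # Ignore idx and generate a brand-new sample every time it is requested.
        return torch.randint(0, 10, (3,))


dataset = Fresh()
# Draw 10 indices with replacement; every draw hits index 0, but each fetch generates new data.
sampler = RandomSampler(dataset, replacement=True, num_samples=10)
loader = DataLoader(dataset, batch_size=2, sampler=sampler, num_workers=2)

for data in loader:  # 5 batches of 2 freshly generated samples
    print(data)
    # forward + backward on "data"

Wrapping this pass in a while True loop (or choosing a larger num_samples) gives an effectively infinite stream without needing a huge __len__.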
Try using cycle from itertools. Here is an example for a simple dataset:
Code:
from itertools import cycle

import torch
from torch.utils.data import Dataset, DataLoader

# Create some dummy data.
data = torch.tensor([[0, 0],
                     [1, 1],
                     [2, 2],
                     [3, 3]])


class DataSet(Dataset):
    """Our dataset. Iterates over tensor data."""

    def __init__(self, data):
        self.data = data
        self.n = self.data.shape[0]

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        return self.data[idx]


bs = 1       # batch size
workers = 1  # number of workers

dataset = DataSet(data)
data_loader = DataLoader(dataset, batch_size=bs, shuffle=False, num_workers=workers)

# Infinite loop.
print(f'batch size: {bs} | number of workers: {workers}')
for i, data in cycle(enumerate(data_loader)):
    print(i, data)
Output:
batch size: 1 | number of workers: 1
0 tensor([[0, 0]])
1 tensor([[1, 1]])
2 tensor([[2, 2]])
3 tensor([[3, 3]])
0 tensor([[0, 0]])
1 tensor([[1, 1]])
2 tensor([[2, 2]])
3 tensor([[3, 3]])
...
batch size: 2 | number of workers: 2
0 tensor([[0, 0],
          [1, 1]])
1 tensor([[2, 2],
          [3, 3]])
0 tensor([[0, 0],
          [1, 1]])
1 tensor([[2, 2],
...