I was trying to reset the dataloader manually but was unable to. I tried everything in https://discuss.pytorch.org/t/how-could-i-reset-dataloader-or-count-data-batch-with-iter-instead-of-epoch/22902/4 but had no luck. Does anyone know how to reset the data loader AND keep the shuffling/randomness of the batches from breaking?
It seems that the DataLoader shuffles the whole dataset and forms new batches at the beginning of every epoch.
Data loader. Combines a dataset and a sampler, and provides an iterable over the given dataset. The DataLoader supports both map-style and iterable-style datasets with single- or multi-process loading, customizing loading order and optional automatic batching (collation) and memory pinning.
Shuffling the data: shuffle is another argument passed to the DataLoader class. It takes a Boolean value (True/False). If shuffle is set to True, all the samples are shuffled before being loaded in batches.
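A minimal sketch of the reshuffling behavior, assuming a toy TensorDataset of 8 integer samples (the dataset and batch size here are illustrative, not from the thread):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy dataset of 8 samples, loaded in batches of 4 with shuffle=True.
ds = TensorDataset(torch.arange(8))
loader = DataLoader(ds, batch_size=4, shuffle=True)

# Iterating the loader twice simulates two epochs; each pass draws a
# fresh permutation, so the batch order usually differs between them.
epoch1 = [batch[0].tolist() for batch in loader]
epoch2 = [batch[0].tolist() for batch in loader]
```

Every sample still appears exactly once per pass; only the order changes.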
DataLoader in your case is supposed to return a list: the output of each iteration is (inputs batch, labels batch). E.g. with a batch size of 64, the 64 labels correspond to the 64 images in the batch.
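To make the batch structure concrete, here is a sketch with a made-up dataset of 128 fake 28x28 "images" and integer labels (the shapes are assumptions for illustration):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Fake data: 128 single-channel 28x28 images with labels in [0, 10).
images = torch.randn(128, 1, 28, 28)
labels = torch.randint(0, 10, (128,))
loader = DataLoader(TensorDataset(images, labels), batch_size=64)

# Each iteration yields (inputs batch, labels batch):
inputs_batch, labels_batch = next(iter(loader))
# inputs_batch has 64 images; labels_batch has the matching 64 labels.
```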
To reset a DataLoader, just enumerate the loader again: each call to enumerate(loader) starts from the beginning. To avoid breaking transforms that use random values, reset the random seed each time the DataLoader is initialized:
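A quick sketch of the restart behavior, using a tiny unshuffled dataset so both passes are deterministic (the dataset is illustrative):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

ds = TensorDataset(torch.arange(6))
loader = DataLoader(ds, batch_size=2)  # shuffle=False: deterministic order

# Each fresh iteration over the loader starts from the first batch again.
first_pass = [batch[0].tolist() for batch in loader]
second_pass = [batch[0].tolist() for batch in loader]
```

With shuffle=True the second pass also starts from the beginning, but with a new permutation, which is exactly why the seed-reset trick below is needed for reproducibility.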
import random
import numpy as np
import torch

def seed_init_fn(worker_id):
    # Re-seed every RNG so each pass over the loader is reproducible.
    seed = args.seed + worker_id  # args.seed: your experiment's base seed
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)

loader = torch.utils.data.DataLoader(...., worker_init_fn=seed_init_fn)
while True:
    for i, data in enumerate(loader):
        # will always yield the same data
        ...
See worker_init_fn in the documentation:
https://pytorch.org/docs/master/data.html#torch.utils.data.DataLoader
Here is a better example:
https://github.com/pytorch/pytorch/issues/5059#issuecomment-404232359
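Besides worker_init_fn, a related way to get identical shuffle order on every "reset" is to pass a freshly seeded torch.Generator to the DataLoader; this controls the shuffling sampler itself. A minimal sketch, assuming a toy dataset and a hypothetical fixed seed of 0:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

ds = TensorDataset(torch.arange(8))

def make_loader():
    # Re-seeding the generator before each pass makes the "reset"
    # loader reproduce the same shuffled batch order every time.
    g = torch.Generator()
    g.manual_seed(0)  # hypothetical fixed seed
    return DataLoader(ds, batch_size=4, shuffle=True, generator=g)

pass1 = [batch[0].tolist() for batch in make_loader()]
pass2 = [batch[0].tolist() for batch in make_loader()]
# Both passes yield the identical (but still shuffled) batch order.
```

The generator argument seeds the RandomSampler used for shuffling, whereas worker_init_fn seeds the RNGs inside each worker process; the two address different sources of randomness.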