PyTorch DataLoader uses identical random transformation across each epoch

There is a bug in PyTorch/NumPy where, when loading batches in parallel with a DataLoader (i.e. setting num_workers > 1), the same NumPy random seed is used for each worker, so any random functions applied produce identical results across parallelized batches. This can be resolved by passing a function that seeds NumPy differently in each worker to the worker_init_fn argument, as in the example below.

However, the issue persists across multiple epochs.

Minimal example:

import numpy as np
from torch.utils.data import Dataset, DataLoader

class RandomDataset(Dataset):
    def __getitem__(self, index):
        return np.random.randint(0, 1000, 2)

    def __len__(self):
        return 4

dataset = RandomDataset()
dataloader = DataLoader(dataset, batch_size=1, 
                        num_workers=2, 
                        worker_init_fn = lambda x: np.random.seed(x))

for epoch in range(3):
    print(f'\nEpoch {epoch}')
    for batch in dataloader:
        print(batch)

As you can see, while parallelized batches within an epoch now produce different results, the results are identical across epochs:

Epoch 0
tensor([[684, 559]])
tensor([[ 37, 235]])
tensor([[629, 192]])
tensor([[908,  72]])

Epoch 1
tensor([[684, 559]])
tensor([[ 37, 235]])
tensor([[629, 192]])
tensor([[908,  72]])

Epoch 2
tensor([[684, 559]])
tensor([[ 37, 235]])
tensor([[629, 192]])
tensor([[908,  72]])
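The repetition can be reproduced without PyTorch. The sketch below is a hypothetical single-process model of the DataLoader above (simulate_epoch is illustrative, not a PyTorch API). It assumes that, as with the default persistent_workers=False, every epoch starts fresh workers that are re-initialised through the seeding function. Since worker_init_fn only ever receives the worker IDs 0 and 1, each epoch replays the same two NumPy streams. The batches here are listed per worker rather than interleaved.

```python
import numpy as np

def simulate_epoch(seed_for_worker, num_workers=2, items_per_worker=2):
    """Hypothetical stand-in for one DataLoader epoch.

    Each simulated worker is re-initialised with seed_for_worker(worker_id),
    then draws items_per_worker samples like RandomDataset.__getitem__.
    """
    batches = []
    for worker_id in range(num_workers):
        np.random.seed(seed_for_worker(worker_id))  # what worker_init_fn does
        for _ in range(items_per_worker):
            batches.append(np.random.randint(0, 1000, 2).tolist())
    return batches

# Seeding with the worker ID alone: the seeds are 0 and 1 in every epoch.
epochs = [simulate_epoch(lambda worker_id: worker_id) for _ in range(3)]
for epoch, batches in enumerate(epochs):
    print(f"Epoch {epoch}: {batches}")
```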

How can this behaviour be fixed?


Calling np.random.seed() with no argument, e.g. worker_init_fn = lambda _: np.random.seed(), appears to fix this. Are there any issues with this workaround?
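The workaround can be compared with a seeded alternative in a hypothetical single-process model. Here a master generator, loosely mimicking the main process's RNG, draws a fresh base seed each epoch; run and its parameters are illustrative names, not PyTorch APIs. Calling np.random.seed() with no argument reseeds from operating-system entropy, so the epochs do differ, but the draws can no longer be reproduced by seeding the main process. Deriving each worker's seed from the per-epoch base seed keeps both properties.

```python
import numpy as np

def run(seed_for_worker, num_epochs=3, num_workers=2, items_per_worker=2,
        master_seed=0):
    """Hypothetical model of several DataLoader epochs.

    A master generator draws a new base seed per epoch (an assumption that
    loosely mimics PyTorch). seed_for_worker(base_seed, worker_id) returns the
    value passed to np.random.seed; None means "seed from OS entropy".
    """
    master = np.random.default_rng(master_seed)
    epochs = []
    for _ in range(num_epochs):
        base_seed = int(master.integers(0, 2**31))
        batches = []
        for worker_id in range(num_workers):
            np.random.seed(seed_for_worker(base_seed, worker_id))
            for _ in range(items_per_worker):
                batches.append(np.random.randint(0, 1000, 2).tolist())
        epochs.append(batches)
    return epochs

# Workaround from the question: ignore the arguments, use OS entropy.
entropy_a = run(lambda base_seed, worker_id: None)
entropy_b = run(lambda base_seed, worker_id: None)

# Seeded alternative: derive the worker seed from the per-epoch base seed.
derived_a = run(lambda base_seed, worker_id: base_seed + worker_id)
derived_b = run(lambda base_seed, worker_id: base_seed + worker_id)
```

In this model the no-argument version is fine for varying the data between epochs; what it gives up is run-to-run reproducibility.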

asked Jan 29 '26 06:01 by iacob


1 Answer

The best way I can think of is to seed NumPy and Python's random module from the seed that PyTorch sets in each worker, which changes every epoch:

import random
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader

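def worker_init_fn(worker_id):
    # Inside a worker, torch.initial_seed() returns the seed PyTorch assigned to
    # that worker; it is redrawn each time the DataLoader is iterated (i.e. every
    # epoch, with the default non-persistent workers).
    torch_seed = torch.initial_seed()
    random.seed(torch_seed + worker_id)
    # np.random.seed() only accepts values in [0, 2**32 - 1], so reduce
    # torch_seed first to make sure torch_seed + worker_id < 2**32.
    if torch_seed >= 2**30:
        torch_seed = torch_seed % 2**30
    np.random.seed(torch_seed + worker_id)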

class RandomDataset(Dataset):
    def __getitem__(self, index):
        return np.random.randint(0, 1000, 2)

    def __len__(self):
        return 4

dataset = RandomDataset()
dataloader = DataLoader(dataset, batch_size=1, 
                        num_workers=2, 
                        worker_init_fn = worker_init_fn)

for epoch in range(3):
    print(f'\nEpoch {epoch}')
    for batch in dataloader:
        print(batch)

Output:

Epoch 0
tensor([[593, 191]])
tensor([[207, 469]])
tensor([[976, 714]])
tensor([[ 13, 119]])

Epoch 1
tensor([[836, 664]])
tensor([[138, 836]])
tensor([[409, 313]])
tensor([[  2, 221]])

Epoch 2
tensor([[269, 888]])
tensor([[315, 619]])
tensor([[892, 774]])
tensor([[ 70, 771]])
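The % 2**30 step is there because np.random.seed only accepts integers in [0, 2**32 - 1], while the seed PyTorch assigns to a worker can be far larger. A slightly simpler way to get a valid value is to reduce the sum modulo 2**32. The helper below is a hypothetical sketch of that reduction; in a real worker_init_fn its first argument would be torch.initial_seed() (or torch.utils.data.get_worker_info().seed).

```python
import numpy as np

def numpy_seed_from_torch_seed(torch_seed: int, worker_id: int) -> int:
    """Hypothetical helper: map a non-negative torch seed plus the worker ID
    into the range [0, 2**32 - 1] accepted by np.random.seed."""
    return (torch_seed + worker_id) % 2**32

# A seed this large would make np.random.seed raise ValueError,
# but the reduced per-worker values are accepted.
big_seed = 2**63 - 1
seeds = [numpy_seed_from_torch_seed(big_seed, worker_id) for worker_id in range(2)]
for s in seeds:
    np.random.seed(s)
```

Both variants keep distinct workers on distinct seeds; the modulo-2**32 form just avoids the branch.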

Alternatively, you can seed NumPy and random with int(time.time()) plus the worker ID (otherwise workers started within the same second get the same seed), assuming each epoch takes more than 1 second to run.
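A minimal sketch of the time-based variant, assuming the same per-worker hook. Because int(time.time()) is identical for workers started within the same second, the worker ID is added to keep them apart. time_based_seed is a hypothetical helper, and the resulting draws cannot be reproduced between runs.

```python
import random
import time

import numpy as np

def time_based_seed(worker_id: int) -> int:
    """Hypothetical worker seed: Unix time in whole seconds plus the worker ID,
    reduced into the range accepted by np.random.seed."""
    return (int(time.time()) + worker_id) % 2**32

def worker_init_fn(worker_id):
    # Same signature as the DataLoader hook; seeds both RNGs from the clock.
    seed = time_based_seed(worker_id)
    random.seed(seed)
    np.random.seed(seed)

seeds = [time_based_seed(worker_id) for worker_id in range(2)]
worker_init_fn(0)
```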

answered Jan 31 '26 21:01 by Tu Bui