PyTorch next(iter(training_loader)) extremely slow with simple data; can't use num_workers?

Here x_dat and y_dat are just really long 1-dimensional tensors.

from torch.utils.data import DataLoader, Dataset, SubsetRandomSampler

class FunctionDataset(Dataset):
    def __init__(self):
        x_dat, y_dat = data_product()

        self.length = len(x_dat)
        self.y_dat = y_dat
        self.x_dat = x_dat

    def __getitem__(self, index):
        sample = self.x_dat[index]
        label = self.y_dat[index]
        return sample, label

    def __len__(self):
        return self.length

...

data_set = FunctionDataset()

...

training_sampler = SubsetRandomSampler(train_indices)
validation_sampler = SubsetRandomSampler(validation_indices)

training_loader = DataLoader(data_set, sampler=training_sampler, batch_size=params['batch_size'], shuffle=False)
validation_loader = DataLoader(data_set, sampler=validation_sampler, batch_size=valid_size, shuffle=False)
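To reproduce the setup above without the elided parts, here is a self-contained sketch. The stand-ins are hypothetical and not the asker's code: `data_product` generates sine data, the train/validation split is an assumed random 80/20 split, and `params['batch_size']` is an assumed 64.

```python
import math

import torch
from torch.utils.data import DataLoader, Dataset, SubsetRandomSampler


def data_product(n=10_000):
    # Hypothetical stand-in for the elided data_product(): two long 1-D tensors.
    x = torch.linspace(-math.pi, math.pi, n)
    return x, torch.sin(x)


class FunctionDataset(Dataset):
    def __init__(self):
        x_dat, y_dat = data_product()
        self.length = len(x_dat)
        self.x_dat = x_dat
        self.y_dat = y_dat

    def __getitem__(self, index):
        # Called once per sample, i.e. batch_size times for every batch.
        return self.x_dat[index], self.y_dat[index]

    def __len__(self):
        return self.length


data_set = FunctionDataset()

# Assumed random 80/20 split; the original train/validation indices are elided.
indices = torch.randperm(len(data_set)).tolist()
split = int(0.8 * len(indices))
train_indices, validation_indices = indices[:split], indices[split:]

params = {"batch_size": 64}  # hypothetical value
valid_size = len(validation_indices)

training_loader = DataLoader(data_set, sampler=SubsetRandomSampler(train_indices),
                             batch_size=params["batch_size"])
validation_loader = DataLoader(data_set, sampler=SubsetRandomSampler(validation_indices),
                               batch_size=valid_size)
```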

I have also tried pinning memory (pin_memory=True) for both loaders. Setting num_workers > 0 gives me runtime errors between the processes (such as EOFError and interruption errors). I get each batch with:

x_val, target = next(iter(training_loader))

The entire dataset would fit into memory (or on the GPU), but I would like to emulate batches for this experiment. Profiling my process gives the following:

16276989 function calls (16254744 primitive calls) in 38.779 seconds

   Ordered by: cumulative time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
   1745/1    0.028    0.000   38.780   38.780 {built-in method builtins.exec}
        1    0.052    0.052   38.780   38.780 simple aprox.py:3(<module>)
        1    0.000    0.000   36.900   36.900 simple aprox.py:519(exploreHeatmap)
        1    0.000    0.000   36.900   36.900 simple aprox.py:497(optFromSample)
        1    0.033    0.033   36.900   36.900 simple aprox.py:274(train)
  705/483    0.001    0.000   34.495    0.071 {built-in method builtins.next}
      222    1.525    0.007   34.493    0.155 dataloader.py:311(__next__)
      222    0.851    0.004   12.752    0.057 dataloader.py:314(<listcomp>)
  3016001   11.901    0.000   11.901    0.000 simple aprox.py:176(__getitem__)
       21    0.010    0.000   10.891    0.519 simple aprox.py:413(validationError)
      443    1.380    0.003    9.664    0.022 sampler.py:136(__iter__)
  663/221    2.209    0.003    8.652    0.039 dataloader.py:151(default_collate)
      221    0.070    0.000    6.441    0.029 dataloader.py:187(<listcomp>)
      442    6.369    0.014    6.369    0.014 {built-in method stack}
  3060221    2.799    0.000    5.890    0.000 sampler.py:68(<genexpr>)
  3060000    3.091    0.000    3.091    0.000 tensor.py:382(<lambda>)
      222    0.001    0.000    1.985    0.009 sampler.py:67(__iter__)
      222    1.982    0.009    1.982    0.009 {built-in method randperm}
  663/221    0.002    0.000    1.901    0.009 dataloader.py:192(pin_memory_batch)
      221    0.000    0.000    1.899    0.009 dataloader.py:200(<listcomp>)
....
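The profile shows millions of `__getitem__` calls and noticeable time in `torch.stack` inside `default_collate`. The sketch below illustrates, roughly, what the per-sample path amounts to for one batch (one Python-level index per sample, then a stack) and that a single advanced-indexing call yields the identical batch. Since the data fits in memory, it also includes a hypothetical `iterate_minibatches` helper that emulates shuffled batches without a DataLoader. This is an alternative technique, not something the question or the answer uses; sizes and data are made up.

```python
import torch

x_dat = torch.randn(100_000)
y_dat = 2 * x_dat
batch_idx = torch.randperm(len(x_dat))[:4096]

# Roughly what the per-sample DataLoader path does for one batch: one
# __getitem__ call per index, then default_collate stacks the 0-d tensors.
per_sample = torch.stack([x_dat[i] for i in batch_idx.tolist()])

# The same batch from a single advanced-indexing call.
vectorized = x_dat[batch_idx]


def iterate_minibatches(x, y, batch_size):
    """Hypothetical helper: yield shuffled (x, y) batches from in-memory tensors."""
    perm = torch.randperm(len(x), device=x.device)  # one permutation per epoch
    for start in range(0, len(perm), batch_size):
        idx = perm[start:start + batch_size]
        # One gather per batch instead of one __getitem__ call per sample.
        yield x[idx], y[idx]


batches = list(iterate_minibatches(x_dat, y_dat, batch_size=4096))
```

To restrict such a helper to a training split, one could index the tensors with the training indices first (e.g. `x_dat[train_idx]`) and pass the result in.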

This suggests the data loader is immensely slow compared to the rest of my experiment (training the model and many other computations). What is going wrong, and what would be the best way to speed this up?

asked Dec 18 '22 by ZirconCode


1 Answer

When retrieving a batch with

x, y = next(iter(training_loader))

you actually create a new DataLoader iterator at each call (!). See this thread for more information.
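One way to see this is to count how often the sampler starts a new pass. The sketch uses a hypothetical `CountingSampler` subclass of `SubsetRandomSampler` and a toy `TensorDataset`. Each fresh iterator starts a new pass over the sampler (a full `randperm` of the subset), which is consistent with the profile in the question, where `randperm` is called as many times as the loader's `__next__`.

```python
import torch
from torch.utils.data import DataLoader, SubsetRandomSampler, TensorDataset


class CountingSampler(SubsetRandomSampler):
    """Hypothetical sampler subclass that counts how often a new pass starts."""

    def __init__(self, indices):
        super().__init__(indices)
        self.passes = 0

    def __iter__(self):
        self.passes += 1  # each new DataLoader iterator starts a fresh pass
        return super().__iter__()


data = TensorDataset(torch.arange(1000.0), torch.arange(1000.0))
sampler = CountingSampler(list(range(1000)))
loader = DataLoader(data, sampler=sampler, batch_size=10)

# Pattern from the question: a brand-new iterator for every batch.
for _ in range(5):
    x, y = next(iter(loader))
passes_fresh = sampler.passes

# Pattern from this answer: one iterator, many batches.
sampler.passes = 0
it = iter(loader)
for _ in range(5):
    x, y = next(it)
passes_reused = sampler.passes
```

A side effect of the first pattern is that every batch comes from a new permutation, so batches drawn this way can overlap within what is meant to be one epoch.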
What you should do instead is create the iterator once (per epoch):

training_loader_iter = iter(training_loader)

and then call next on that iterator for each batch:

for i in range(num_batches_in_epoch):
    x, y = next(training_loader_iter)
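A runnable sketch of the per-epoch pattern, on toy data with a hypothetical split and batch size. It creates one iterator per epoch and takes `len(training_loader)` batches from it, so every training sample is seen exactly once per epoch.

```python
import torch
from torch.utils.data import DataLoader, SubsetRandomSampler, TensorDataset

x_dat = torch.linspace(0, 1, 1000)
y_dat = 3 * x_dat
train_indices = list(range(800))  # hypothetical split
training_loader = DataLoader(TensorDataset(x_dat, y_dat),
                             sampler=SubsetRandomSampler(train_indices),
                             batch_size=64)

num_epochs = 3
num_batches_in_epoch = len(training_loader)  # ceil(800 / 64) batches
seen_per_epoch = []
for epoch in range(num_epochs):
    training_loader_iter = iter(training_loader)  # one iterator per epoch
    seen = 0
    for i in range(num_batches_in_epoch):
        x, y = next(training_loader_iter)
        seen += len(x)
        # ... forward pass, loss, backward pass and optimizer step go here ...
    seen_per_epoch.append(seen)
```

Writing `for x, y in training_loader:` inside the epoch loop has the same effect, since the `for` statement calls `iter()` once and then `next()` until the iterator is exhausted.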

I had a similar issue before, and this change also made the EOF errors you get when using multiple workers go away.
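A further option, beyond what this answer covers, targets the per-sample `__getitem__` calls that dominate the profile. With `batch_size=None`, automatic batching is disabled and each element the sampler yields goes to `__getitem__` unchanged. Passing a `BatchSampler` as the sampler therefore hands the dataset a whole list of indices per call. This assumes a PyTorch version that supports `batch_size=None`, which the older version in the question may not. `BatchIndexDataset` is a hypothetical variant of `FunctionDataset`.

```python
import torch
from torch.utils.data import BatchSampler, DataLoader, Dataset, SubsetRandomSampler


class BatchIndexDataset(Dataset):
    """Hypothetical variant of FunctionDataset whose __getitem__ takes a list of indices."""

    def __init__(self, x_dat, y_dat):
        self.x_dat = x_dat
        self.y_dat = y_dat

    def __getitem__(self, indices):
        # indices is a list of ints; advanced indexing returns the whole batch at once.
        return self.x_dat[indices], self.y_dat[indices]

    def __len__(self):
        return len(self.x_dat)


x_dat = torch.linspace(0, 1, 1000)
y_dat = x_dat ** 2
train_indices = list(range(800))  # hypothetical split

batch_sampler = BatchSampler(SubsetRandomSampler(train_indices), batch_size=64, drop_last=False)
# batch_size=None turns off automatic batching: each index list from the sampler
# is passed to __getitem__ as-is, so there is one call per batch, not per sample.
training_loader = DataLoader(BatchIndexDataset(x_dat, y_dat),
                             sampler=batch_sampler, batch_size=None)

batches = [(x, y) for x, y in training_loader]
```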

answered Jan 22 '23 by Shai