
pytorch data loader multiple iterations

Tags:

python

pytorch

I am using the Iris dataset to train a simple network with PyTorch.

import torch
import iris  # custom module providing the Iris dataset

trainset = iris.Iris(train=True)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=150,
                                          shuffle=True, num_workers=2)

dataiter = iter(trainloader)

The dataset itself has only 150 data points, so the PyTorch DataLoader iterates just once over the whole dataset because of the batch size of 150.

My question is: is there generally any way to tell the PyTorch DataLoader to repeat over the dataset once it is done with an iteration?

Thanks!

Update

Got it running :) I just created a subclass of DataLoader and implemented my own __next__().
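The original post does not include that implementation, but a minimal sketch of one way such a subclass could look is below (the class name InfiniteDataLoader is hypothetical):

import torch

class InfiniteDataLoader(torch.utils.data.DataLoader):
    """Hypothetical sketch: a DataLoader that never raises StopIteration."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._iterator = super().__iter__()

    def __iter__(self):
        return self

    def __next__(self):
        try:
            return next(self._iterator)
        except StopIteration:
            # the underlying iterator is exhausted: create a fresh one,
            # which re-shuffles the data when shuffle=True
            self._iterator = super().__iter__()
            return next(self._iterator)

With this, trainloader = InfiniteDataLoader(trainset, batch_size=150, shuffle=True, num_workers=2) can be drawn from indefinitely with next().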

asked Dec 08 '17 by arash javanmard



1 Answer

To complement the previous answers: to be comparable across datasets, it is often better to use the total number of steps rather than the total number of epochs as a hyperparameter. That is because the number of iterations should not rely on the dataset size, but on its complexity.
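For illustration, if you do want a step budget that matches a given number of epochs, it can be derived from the loader length (num_epochs is a hypothetical name here):

num_epochs = 100                           # hypothetical epoch budget
max_steps = num_epochs * len(trainloader)  # len(loader) == batches per epoch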

I am using the following code for training. It ensures that the data loader re-shuffles the data every time it is re-initialized.

# main training loop
generator = iter(trainloader)
for i in range(max_steps):
    try:
        # sample the next batch
        x, y = next(generator)
    except StopIteration:
        # restart the generator once it is exhausted;
        # re-creating it re-shuffles the data when shuffle=True
        generator = iter(trainloader)
        x, y = next(generator)

I will agree that this is not the most elegant solution, but it keeps me from having to rely on epochs for training.
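For context, here is a sketch of how that loop might be embedded in a complete training step; the model, optimizer, and criterion below are illustrative assumptions for the 4-feature, 3-class Iris data, not part of the original answer:

import torch
import torch.nn as nn

# illustrative assumptions: a minimal classifier for Iris (4 features, 3 classes)
model = nn.Linear(4, 3)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()
max_steps = 1000

generator = iter(trainloader)
for step in range(max_steps):
    try:
        x, y = next(generator)
    except StopIteration:
        generator = iter(trainloader)  # fresh iterator -> fresh shuffle
        x, y = next(generator)

    optimizer.zero_grad()
    loss = criterion(model(x.float()), y.long())
    loss.backward()
    optimizer.step()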

answered Oct 16 '22 by Victor Zuanazzi