 

How does `images, labels = dataiter.next()` work in PyTorch Tutorial?

In the tutorial cifar10_tutorial, how are images and labels assigned?

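import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms

# convert images to tensors and scale each RGB channel to [-1, 1]
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                          shuffle=True, num_workers=2)

def imshow(img):
    img = img / 2 + 0.5     # unnormalize
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))  # (C, H, W) -> (H, W, C) for matplotlib
    plt.show()


# get some random training images
dataiter = iter(trainloader)
# newer PyTorch versions no longer provide .next() on DataLoader iterators; use the built-in next()
images, labels = next(dataiter)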

How does the last line know how to automatically assign values to `images` and `labels` in `images, labels = dataiter.next()`?

I checked the DataLoader class and the DataLoaderIter class, but I think I need a bit more knowledge of iterators in general.

asked Nov 25 '18 02:11 by njho


People also ask

How does the PyTorch DataLoader work?

Data loader. Combines a dataset and a sampler, and provides an iterable over the given dataset. The DataLoader supports both map-style and iterable-style datasets with single- or multi-process loading, customizing loading order and optional automatic batching (collation) and memory pinning.
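As a rough illustration of "combines a dataset and a sampler ... with automatic batching", here is a pure-Python sketch. `TinyLoader` is a hypothetical class, not PyTorch's actual implementation: its `__iter__` shuffles the indices (the sampler's job), groups them into batches, and splits each batch into `[inputs, labels]` (a simplified collation). Multi-process loading and memory pinning are left out, and the `seed` parameter is only there to make the shuffle reproducible.

```python
import random


class TinyLoader:
    """Hypothetical, simplified stand-in for torch.utils.data.DataLoader."""

    def __init__(self, dataset, batch_size=4, shuffle=False, seed=None):
        self.dataset = dataset          # map-style: supports len() and [index]
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.rng = random.Random(seed)

    def __iter__(self):
        # "Sampler": decide the order in which indices are visited.
        indices = list(range(len(self.dataset)))
        if self.shuffle:
            self.rng.shuffle(indices)
        # Batching: take batch_size indices at a time.
        for start in range(0, len(indices), self.batch_size):
            samples = [self.dataset[i] for i in indices[start:start + self.batch_size]]
            # Simplified collation: split (input, label) pairs into two lists.
            inputs = [s[0] for s in samples]
            labels = [s[1] for s in samples]
            yield [inputs, labels]


# Toy dataset of (input, label) pairs: ("img0", 0), ("img1", 1), ...
data = [(f"img{i}", i % 10) for i in range(10)]
loader = TinyLoader(data, batch_size=4, shuffle=True, seed=0)

dataiter = iter(loader)          # a generator object, i.e. an iterator
images, labels = next(dataiter)  # first batch, unpacked into two lists
print(images, labels)
```

Because `__iter__` is written as a generator, each `iter(loader)` call starts a fresh pass over the data, which is roughly what happens when you loop over a DataLoader once per epoch.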

What does a PyTorch DataLoader return?

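A DataLoader yields one batch at a time. For a dataset of (input, label) pairs such as CIFAR-10, each batch is a two-element list (inputs batch, labels batch): an image tensor of shape (batch_size, 3, 32, 32) and a label tensor of shape (batch_size,). For example, with batch_size=64 the 64 labels correspond, in order, to the 64 images in the batch.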
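The "list of pairs in, [inputs, labels] out" step can be sketched in plain Python. `collate_pairs` is a hypothetical helper, not PyTorch's API; PyTorch's default collation additionally stacks the individual tensors into one batched tensor, which this sketch does not do.

```python
def collate_pairs(samples):
    """Hypothetical simplified collation: turn [(x0, y0), (x1, y1), ...]
    into [[x0, x1, ...], [y0, y1, ...]] (no tensor stacking)."""
    inputs, labels = zip(*samples)     # transpose the list of pairs
    return [list(inputs), list(labels)]


batch = collate_pairs([("img_a", 3), ("img_b", 7), ("img_c", 1)])
print(batch)  # [['img_a', 'img_b', 'img_c'], [3, 7, 1]]
images, labels = batch
```

Position i in the inputs list always lines up with position i in the labels list, which is why the 64 labels match the 64 images one to one.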

What does `len()` of a DataLoader return?

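`len(dataloader)` returns the number of batches, not the number of samples; it is the dataset's `__len__()` method that returns the total number of samples. For example, for a map-style dataset of 100,000 samples and batch_size=64, `len(dataset)` is 100,000 while `len(dataloader)` is 1563 (100,000 / 64 rounded up), or 1562 with drop_last=True.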
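The batch count can be checked with a little integer arithmetic. This is a sketch of the rule for a map-style dataset with automatic batching; `num_batches` is a hypothetical helper, not a PyTorch function.

```python
def num_batches(num_samples, batch_size, drop_last=False):
    """Batches per epoch for a map-style dataset (hypothetical helper)."""
    if drop_last:
        return num_samples // batch_size                 # incomplete last batch is dropped
    return (num_samples + batch_size - 1) // batch_size  # round up: last batch may be smaller


print(num_batches(100_000, 64))                  # 1563
print(num_batches(100_000, 64, drop_last=True))  # 1562
print(num_batches(50_000, 4))                    # 12500 (CIFAR-10 train set, batch_size=4)
```

Integer arithmetic is used for the rounding so the result never depends on floating-point precision.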


1 Answer

I think it is crucial to understand the difference between an iterable and an iterator. An iterable is an object you can iterate over, such as a list, a tuple, or a DataLoader; calling `iter()` on it returns an iterator. An iterator is an object that produces the iterable's items one at a time through its `__next__` method, which is what the built-in `next()` calls.
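To make the distinction concrete, here is a minimal sketch with two hypothetical classes: `Countdown` plays the role of the iterable (like a DataLoader) and `CountdownIterator` plays the role of the iterator that `iter()` hands back.

```python
class Countdown:
    """An iterable: each call to iter() hands out a fresh iterator."""

    def __init__(self, start):
        self.start = start

    def __iter__(self):
        return CountdownIterator(self.start)


class CountdownIterator:
    """An iterator: remembers its position and produces items via __next__."""

    def __init__(self, current):
        self.current = current

    def __iter__(self):
        return self  # iterators are themselves iterable

    def __next__(self):
        if self.current <= 0:
            raise StopIteration  # signals that there are no more items
        value = self.current
        self.current -= 1
        return value


numbers = Countdown(3)
it = iter(numbers)      # calls numbers.__iter__()
print(next(it))         # 3, via it.__next__()
print(list(numbers))    # [3, 2, 1], because list() gets a fresh iterator
```

Keeping the position in a separate iterator object is what lets the same iterable be looped over several times.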

Here is a simple example. Take an iterable (a tuple here), get an iterator from it with `iter()`, and call `next()` on the iterator repeatedly. Each call returns the next item until the end of the tuple is reached; after that, `next()` raises a StopIteration exception.

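test = (1, 2, 3)
tester = iter(test)  # get an iterator from the iterable

while True:
    try:
        next_item = next(tester)  # calls tester.__next__()
    except StopIteration:  # raised once the tuple is exhausted
        print("StopIteration: no more items")
        break
    print(next_item)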
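A `for` loop runs this same protocol for you: it calls `iter()` once, then `next()` until `StopIteration`, which it handles quietly instead of letting it propagate. This sketch writes that out by hand next to the equivalent `for` loop; the variable names are only for illustration.

```python
test = (1, 2, 3)

# What `for item in test: collected.append(item)` does, written out by hand.
collected = []
iterator = iter(test)
while True:
    try:
        item = next(iterator)
    except StopIteration:
        break  # the for loop stops here without raising
    collected.append(item)

# The same thing with an ordinary for loop.
collected_for = []
for item in test:
    collected_for.append(item)

print(collected, collected_for)  # [1, 2, 3] [1, 2, 3]
```

The same mechanism is why a training loop can be written as `for images, labels in trainloader:`; each step is one `next()` call followed by the same unpacking.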

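The DataLoader iterator you refer to works the same way, except that each `next()` call returns a whole batch: a two-element sequence holding a batch of images and a batch of labels. The assignment `images, labels = ...` is then ordinary Python sequence unpacking, just like `a, b = [1, 2]`.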
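The unpacking step can be tried without PyTorch. In this sketch the hypothetical generator `fake_batches` stands in for a DataLoader iterator and yields `[images, labels]` lists made of strings and ints instead of tensors.

```python
def fake_batches():
    """Hypothetical stand-in for a DataLoader iterator: each item is a
    two-element list [images, labels]."""
    yield [["img0", "img1"], [6, 9]]
    yield [["img2", "img3"], [4, 1]]


dataiter = fake_batches()
images, labels = next(dataiter)   # unpacks the two-element list
print(images, labels)             # ['img0', 'img1'] [6, 9]

# Unpacking only works when the number of names matches the number of elements.
try:
    images, labels, extra = next(dataiter)
except ValueError as err:
    print("ValueError:", err)
```

Nothing in the unpacking is specific to PyTorch: Python only checks that the right-hand side has exactly as many elements as there are names on the left.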

answered Oct 05 '22 08:10 by cvanelteren