How does PyTorch DataLoader interact with a PyTorch dataset to transform batches?

I'm creating a custom dataset for NLP-related tasks.

In the PyTorch custom dataset tutorial, we see that the __getitem__() method leaves room for a transform before it returns a sample:

def __getitem__(self, idx):
    if torch.is_tensor(idx):
        idx = idx.tolist()

    img_name = os.path.join(self.root_dir,
                            self.landmarks_frame.iloc[idx, 0])
    image = io.imread(img_name)

    ### SOME DATA MANIPULATION HERE ###
    # (in the tutorial, `landmarks` is read from self.landmarks_frame here)

    sample = {'image': image, 'landmarks': landmarks}
    if self.transform:
        sample = self.transform(sample)

    return sample

However, the code here:

        if torch.is_tensor(idx):
            idx = idx.tolist()

implies that multiple items can be retrieved at a time, which leaves me wondering:

  1. How does that transform work on multiple items? Take the custom transforms in the tutorial for example. They do not look like they could be applied to a batch of samples in a single call.

  2. Related, how does a DataLoader retrieve a batch of multiple samples in parallel and apply said transform if the transform can only be applied to a single sample?
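To make the indexing part concrete, here is a small sketch (with a made-up dataframe, not the tutorial's landmarks file) of why idx.tolist() would matter: pandas iloc accepts either a single integer or a list of integers.

```python
import pandas as pd

# hypothetical stand-in for self.landmarks_frame
frame = pd.DataFrame({"file": ["a.png", "b.png", "c.png"]})

one = frame.iloc[0, 0]           # single index -> one value
several = frame.iloc[[0, 2], 0]  # list of indices -> a subset (Series)
```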

rocksNwaves asked Feb 25 '21
People also ask

How does PyTorch data loader work?

Data loader. Combines a dataset and a sampler, and provides an iterable over the given dataset. The DataLoader supports both map-style and iterable-style datasets with single- or multi-process loading, customizing loading order and optional automatic batching (collation) and memory pinning.

What is the difference between a PyTorch dataset and a PyTorch DataLoader?

Dataset stores the samples and their corresponding labels, and DataLoader wraps an iterable around the Dataset to enable easy access to the samples.
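As a tiny sketch of that division of labor (toy tensors rather than real data):

```python
import torch
from torch.utils.data import TensorDataset, DataLoader

features = torch.arange(10.0).unsqueeze(1)  # 10 samples, 1 feature each
labels = torch.arange(10)

dataset = TensorDataset(features, labels)   # the Dataset stores samples + labels
loader = DataLoader(dataset, batch_size=4)  # the DataLoader iterates in batches

xb, yb = next(iter(loader))                 # first batch of 4 samples
```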

Why is DataLoader used in PyTorch?

Creating a PyTorch Dataset and managing it with a DataLoader keeps your data manageable and helps simplify your machine learning pipeline. A Dataset stores all your data, and a DataLoader can be used to iterate through the data, manage batches, transform the data, and much more.

What is batch in DataLoader?

Batch size is the number of samples processed before the model is updated. In a DataLoader, the batch_size argument sets how many samples are collated into each batch it yields.


2 Answers

  1. How does that transform work on multiple items? They work on multiple items through the data loader. A transform specifies what should happen to a single sample. The data loader takes your specified batch_size, makes n calls to the __getitem__ method of the dataset, and applies the transform to each sample individually. It then collates the n transformed samples into one batch, which is what the data loader emits.

  2. Related, how does a DataLoader retrieve a batch of multiple samples in parallel and apply said transform if the transform can only be applied to a single sample? Hopefully the above makes sense. Parallelization is handled by the data loader via its num_workers argument: PyTorch pickles the dataset and sends a copy to each worker process, and each worker calls __getitem__ (and thus the transform) for the indices it is assigned.
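A minimal sketch of that mechanism, with a toy dataset and transform (names made up for illustration):

```python
import torch
from torch.utils.data import Dataset, DataLoader

class ToyDataset(Dataset):
    """Sample i is the scalar i; the transform sees one sample at a time."""
    def __init__(self, n, transform=None):
        self.n = n
        self.transform = transform

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        sample = torch.tensor(float(idx))
        if self.transform:
            sample = self.transform(sample)  # applied per sample, not per batch
        return sample

# the loader makes 3 __getitem__ calls per batch, then collates the results
loader = DataLoader(ToyDataset(6, transform=lambda x: x * x), batch_size=3)
batches = [batch for batch in loader]        # 2 batches of 3 squared values
```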

John Stud answered Oct 29 '22

From the torchvision transforms documentation:

All transformations accept PIL Image, Tensor Image or batch of Tensor Images as input. Tensor Image is a tensor with (C, H, W) shape, where C is a number of channels, H and W are image height and width. Batch of Tensor Images is a tensor of (B, C, H, W) shape, where B is a number of images in the batch. Deterministic or random transformations applied on the batch of Tensor Images identically transform all the images of the batch.

This means that you can pass a batch of images, and the transform will be applied to the whole batch, as long as it has the expected shape. As for the indexing, a list of indices passed to the dataframe's iloc selects multiple rows at once, returning the requested subset.

Maura Pintor answered Oct 29 '22