
PyTorch DataLoader fails when the number of examples is not exactly divisible by the batch size

Tags:

pytorch

I have written a custom dataset class in PyTorch, but it fails when iterating over all the batches in an epoch. For example, suppose I have 100 data examples and my batch size is 9: it fails on the last iteration with a batch-size mismatch, because that batch contains only 1 example instead of a full batch. My custom dataset class is below, along with the for-loop I use to pull batches from the loader.

import os

import librosa
import torch
import torch.utils.data as tdata


class FlatDirectoryAudioDataset(tdata.Dataset):  # customized dataset

    def __init__(self, data_dir, transform=None):
        self.data_dir = data_dir
        self.transform = transform
        self.files = self.__setup_files()

    def __len__(self):
        """
        compute the length of the dataset
        :return: len => length of dataset
        """
        return len(self.files)

    def __setup_files(self):

        file_names = os.listdir(self.data_dir)
        files = []  # initialize to empty list

        for file_name in file_names:

            possible_file = os.path.join(self.data_dir, file_name)
            if os.path.isfile(possible_file) and (file_name.lower().endswith('.wav') or file_name.lower().endswith('.mp3')):
                files.append(possible_file)

        # return the files list
        return files


    def __getitem__(self, index):
        sample, _ = librosa.load(self.files[index], sr=16000)  # load audio resampled to 16 kHz

        if self.transform:
            sample = self.transform(sample)

        sample = torch.from_numpy(sample)
        return sample


from torch.utils.data import DataLoader 

my_dataset = FlatDirectoryAudioDataset(source_directory, transform=None)

dataloader_my = DataLoader(
    my_dataset,
    batch_size=batch_size,
    num_workers=0,
    shuffle=True)


for i, batch in enumerate(dataloader_my, 0):
    print(i)
    if batch.shape[0] != 16:
        print(batch.shape)
        assert batch.shape[0] == 16, "Something wrong with the batch size"



asked Jun 13 '19 by Shamane Siriwardhana


2 Answers

Short answer

Set drop_last=True to drop the last incomplete batch
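
For example, a minimal sketch of the relevant change, assuming the my_dataset and batch_size defined in the question:

from torch.utils.data import DataLoader

# drop_last=True skips the final batch when len(dataset) is not a multiple of
# batch_size, so every yielded batch has exactly batch_size items
dataloader_my = DataLoader(
    my_dataset,
    batch_size=batch_size,
    num_workers=0,
    shuffle=True,
    drop_last=True)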

Long answer

Based on your code I made a reduced version of your Dataset, and there is no error in how the batches are produced.

With a batch_size of 9 and 100 items, the last batch has just one item. Running the code below produces the following output.

With drop_last=False the last, smaller batch is still yielded and the "Different batch size" message is printed; with drop_last=True it is dropped.

0 <class 'torch.Tensor'> torch.Size([9, 1])
1 <class 'torch.Tensor'> torch.Size([9, 1])
2 <class 'torch.Tensor'> torch.Size([9, 1])
3 <class 'torch.Tensor'> torch.Size([9, 1])
4 <class 'torch.Tensor'> torch.Size([9, 1])
5 <class 'torch.Tensor'> torch.Size([9, 1])
6 <class 'torch.Tensor'> torch.Size([9, 1])
7 <class 'torch.Tensor'> torch.Size([9, 1])
8 <class 'torch.Tensor'> torch.Size([9, 1])
9 <class 'torch.Tensor'> torch.Size([9, 1])
10 <class 'torch.Tensor'> torch.Size([9, 1])
# depends on drop_last=True|False
11 <class 'torch.Tensor'> torch.Size([1, 1])
Different batch size (last batch) torch.Size([1, 1])

So the batches together account for all 100 items: 11 full batches of 9 plus a final batch of 1 (11 × 9 + 1 = 100).

from torch.utils.data import DataLoader
import os
import numpy as np
import torch
import torch.utils.data.dataset as tdata


class FlatDirectoryAudioDataset(tdata.Dataset):  # customized dataloader

    def __init__(self):
        self.files = self.__setup_files()

    def __len__(self):
        return len(self.files)

    def __setup_files(self):
        return np.array(range(100))

    def __getitem__(self, index):
        file = self.files[index]
        sample = np.array([file])
        sample = torch.from_numpy(sample)
        return sample


my_dataset = FlatDirectoryAudioDataset()

batch_size = 9

dataloader_my = DataLoader(
    my_dataset,
    batch_size=batch_size,
    num_workers=0,
    shuffle=True,
    drop_last=True)

for i, sample in enumerate(dataloader_my, 0):
    print(i, type(sample), sample.shape)
    if sample.shape[0] != batch_size:
        print("Different batch size (last batch)", sample.shape)

answered Oct 15 '22 by Clemens Tolboom


Use drop_last=True:

torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)

https://pytorch.org/docs/stable/data.html
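
As a quick, self-contained sanity check of the documented behaviour (toy data chosen here just for illustration):

import torch
from torch.utils.data import DataLoader, TensorDataset

toy = TensorDataset(torch.arange(100))                      # 100 dummy items
print(len(DataLoader(toy, batch_size=9)))                   # 12 batches; the last holds 1 item
print(len(DataLoader(toy, batch_size=9, drop_last=True)))   # 11 batches, all of size 9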

answered Oct 15 '22 by Hany Aly