I have written a custom dataset class in PyTorch, but it fails when iterating over all the batches in an epoch. For example, suppose I have 100 data examples and my batch size is 9. It fails on the last iteration because the final batch has size 1 instead of 9, so my batch-size check triggers. My custom dataset class is below, along with the for-loop I use to extract batches from the loader.
import os

import librosa
import torch
import torch.utils.data as tdata


class FlatDirectoryAudioDataset(tdata.Dataset):  # customized dataloader
    def __init__(self, data_dir, transform=None):
        self.data_dir = data_dir
        self.transform = transform
        self.files = self.__setup_files()

    def __len__(self):
        """
        compute the length of the dataset
        :return: len => length of dataset
        """
        return len(self.files)

    def __setup_files(self):
        file_names = os.listdir(self.data_dir)
        files = []  # initialize to empty list
        for file_name in file_names:
            possible_file = os.path.join(self.data_dir, file_name)
            if os.path.isfile(possible_file) and (file_name.lower().endswith('.wav') or file_name.lower().endswith('.mp3')):
                files.append(possible_file)
        # return the files list
        return files

    def __getitem__(self, index):
        sample, _ = librosa.load(self.files[index], sr=16000)
        if self.transform:
            sample = self.transform(sample)
        sample = torch.from_numpy(sample)
        return sample
from torch.utils.data import DataLoader

my_dataset = FlatDirectoryAudioDataset(source_directory, transform=None)

dataloader_my = DataLoader(
    my_dataset,
    batch_size=batch_size,
    num_workers=0,
    shuffle=True)

for (i, batch) in enumerate(dataloader_my, 0):
    print(i)
    if batch.shape[0] != batch_size:
        print(batch.shape)
        assert batch.shape[0] == batch_size, "Something wrong with the batch size"
Set drop_last=True to drop the last incomplete batch.
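A minimal sketch, assuming the same my_dataset and batch_size as in the question; the only change is the drop_last argument:

from torch.utils.data import DataLoader

# drop_last=True discards the final, smaller batch (the 1-item batch left
# over from 100 examples with batch_size 9), so every batch seen by the
# training loop has exactly batch_size items.
dataloader_my = DataLoader(
    my_dataset,
    batch_size=batch_size,
    num_workers=0,
    shuffle=True,
    drop_last=True)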
Based on your code, I made a reduced version of your dataset, and there is no error in how the batches are produced. Using 9 as batch_size with 100 items, the last batch simply has one item. Running the code below produces the output shown. With drop_last=False, the last line (the 1-item batch) is reached and the 'Different batch size' message is printed; with drop_last=True, that final batch is dropped.
0 <class 'torch.Tensor'> torch.Size([9, 1])
1 <class 'torch.Tensor'> torch.Size([9, 1])
2 <class 'torch.Tensor'> torch.Size([9, 1])
3 <class 'torch.Tensor'> torch.Size([9, 1])
4 <class 'torch.Tensor'> torch.Size([9, 1])
5 <class 'torch.Tensor'> torch.Size([9, 1])
6 <class 'torch.Tensor'> torch.Size([9, 1])
7 <class 'torch.Tensor'> torch.Size([9, 1])
8 <class 'torch.Tensor'> torch.Size([9, 1])
9 <class 'torch.Tensor'> torch.Size([9, 1])
10 <class 'torch.Tensor'> torch.Size([9, 1])
# depends on drop_last=True|False
11 <class 'torch.Tensor'> torch.Size([1, 1])
Different batch size (last batch) torch.Size([1, 1])
So the batches account for all 100 items in total: eleven batches of 9 plus one final batch of 1.
from torch.utils.data import DataLoader
import os
import numpy as np
import torch
import torch.utils.data.dataset as tdata


class FlatDirectoryAudioDataset(tdata.Dataset):  # customized dataloader
    def __init__(self):
        self.files = self.__setup_files()

    def __len__(self):
        return len(self.files)

    def __setup_files(self):
        # stand-in for the real file list: just the integers 0..99
        return np.array(range(100))

    def __getitem__(self, index):
        file = self.files[index]
        sample = np.array([file])
        sample = torch.from_numpy(sample)
        return sample


my_dataset = FlatDirectoryAudioDataset()

batch_size = 9
dataloader_my = DataLoader(
    my_dataset,
    batch_size=batch_size,
    num_workers=0,
    shuffle=True,
    drop_last=True)

for i, sample in enumerate(dataloader_my, 0):
    print(i, type(sample), sample.shape)
    if sample.shape[0] != batch_size:
        print("Different batch size (last batch)", sample.shape)
Use drop_last=True:

DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)
https://pytorch.org/docs/stable/data.html